# Merger Effects

In [99]:
from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from tqdm import tqdm
import pandas as pd

In [100]:
from auriga.snapshot import Snapshot
from auriga.images import figure_setup, set_axs_configuration
from auriga.settings import Settings
from auriga.support import find_indices

In [101]:
# Apply the project's shared matplotlib figure configuration from
# auriga.images.figure_setup (NOTE(review): the exact settings it changes are
# defined in that module, not here).
figure_setup()

In [102]:
# Global analysis settings; the plotting loop reads the galactic-component
# names (settings.components) and their colours (settings.component_colors).
settings = Settings()

## The Decomposition Phase-Space

In [103]:
def read_data(simulation: str) -> tuple:
    """Load a snapshot and tabulate the properties of its star particles.

    Parameters
    ----------
    simulation : str
        Snapshot identifier passed to ``Snapshot``, e.g. ``"au29_or_l4_s100"``.

    Returns
    -------
    tuple
        ``(df, halo_idx, subhalo_idx)``.

        * ``df`` is a ``pandas.DataFrame`` with one row per selected star
          (type 4 with a positive formation time).
        * ``halo_idx`` and ``subhalo_idx`` are ``s.halo_idx`` and
          ``s.subhalo_idx`` of the loaded snapshot. The plotting code uses them
          to select the main object.
    """
    settings = Settings()

    s = Snapshot(simulation=simulation,
                 loadonlytype=[0, 1, 2, 3, 4, 5])
    s.add_circularity()
    s.add_reference_to_potential()
    s.add_normalized_potential()
    s.add_stellar_age()
    s.add_metal_abundance(of="Fe", to="H")
    s.tag_particles_by_region(
        disc_std_circ=settings.disc_std_circ,
        disc_min_circ=settings.disc_min_circ,
        cold_disc_delta_circ=settings.cold_disc_delta_circ,
        bulge_max_specific_energy=settings.bulge_max_specific_energy)

    # Type-4 particles with a positive formation time. NOTE(review):
    # non-positive formation times presumably flag wind particles; confirm.
    is_real_star = (s.type == 4) & (s.stellar_formation_time > 0)

    pos = s.pos[is_real_star]
    vel = s.vel[is_real_star]

    # z-component of r x v (same formula np.cross uses), evaluated only for
    # the selected stars instead of every loaded particle, then scaled by the
    # expansion factor.
    jz = (pos[:, 0] * vel[:, 1] - pos[:, 1] * vel[:, 0]) * s.expansion_factor

    df = pd.DataFrame(
        {"Circularity": s.circularity[is_real_star],
         "zSpecificAngularMomentum_kpckm/s": jz,
         "ParticleID": s.ids[is_real_star],
         "NormalizedPotential": s.normalized_potential[is_real_star],
         "xPosition_ckpc": pos[:, 0],
         "yPosition_ckpc": pos[:, 1],
         "Mass_Msun": s.mass[is_real_star],
         "StellarAge_Gyr": s.stellar_age[is_real_star],
         "[Fe/H]": s.metal_abundance["Fe/H"][is_real_star],
         "RegionTag": s.region_tag[is_real_star],
         "Halo": s.halo[is_real_star],
         "Subhalo": s.subhalo[is_real_star]}
    )

    return df, s.halo_idx, s.subhalo_idx

In [110]:
# Simulation to analyse. Snapshot names are built as f"{simulation}_s{snum}".
# A tracked-particle ("scatter") selection exists only for au8_or_l4,
# au12_or_l4 and au29_or_l4.
simulation = "au29_or_l4"

In [111]:
# Snapshot at which the tracked ("scatter") particles are selected for each
# simulation that has one. The selected particles are the stars of subhalo 1
# of the snapshot's halo_idx. NOTE(review): subhalo 1 is presumably the
# merging satellite; confirm.
SCATTER_SELECTION_SNAPSHOT = {
    "au8_or_l4": 94,
    "au12_or_l4": 80,
    "au29_or_l4": 100,
}

# IDs of the tracked particles, or None if this simulation has no selection.
target_ids = None
if simulation in SCATTER_SELECTION_SNAPSHOT:
    selection_snum = SCATTER_SELECTION_SNAPSHOT[simulation]
    df, halo_idx, _ = read_data(simulation=f"{simulation}_s{selection_snum}")
    target_ids = df["ParticleID"][
        (df["Halo"] == halo_idx) & (df["Subhalo"] == 1)].to_numpy()

In [114]:
from pathlib import Path

# Snapshots to render. Use e.g. range(100, 127 + 1), optionally wrapped in
# tqdm, to produce the full sequence of frames.
snapshot_numbers = [99]

for snum in snapshot_numbers:
    df, halo_idx, subhalo_idx = read_data(simulation=f"{simulation}_s{snum}")

    # Rows of `df` holding the tracked particles present in this snapshot.
    # find_indices is filtered with `>= 0`, so negative values presumably mark
    # IDs not found here. NOTE(review): confirm against auriga.support.
    targets = None
    if target_ids is not None:
        idxs = find_indices(df["ParticleID"], target_ids)
        targets = df.iloc[idxs[idxs >= 0]]

    fig = plt.figure(figsize=(7.2, 2.5))
    gs = fig.add_gridspec(nrows=1, ncols=4, hspace=0.5, wspace=0.6)
    axs = gs.subplots(sharex=False, sharey=False)

    axs[0].set_xlim(-4000, 4000)
    axs[0].set_ylim(-1, 0)
    axs[0].set_xticks([-2000, 0, 2000])
    axs[0].set_yticks([-1, -.8, -.6, -.4, -.2, 0])
    axs[0].set_xlabel(
        r"$j_z$ [$\mathrm{kpc} ~ \mathrm{km} ~ \mathrm{s}^{-1}$]", fontsize=7)
    axs[0].set_ylabel(
        r"$\tilde{e} = e \, \left| e \right|_\mathrm{max}^{-1}$", fontsize=7)

    axs[1].set_ylim(-100, 100)
    axs[1].set_xlim(-100, 100)
    axs[1].set_xlabel(r"$x$ [ckpc]", fontsize=7)
    axs[1].set_ylabel(r"$y$ [ckpc]", fontsize=7)

    axs[2].set_ylim(-4, 3)
    axs[2].set_xlim(0, 14)
    axs[2].set_xlabel("Age [Gyr]", fontsize=7)
    axs[2].set_ylabel("[Fe/H]", fontsize=7)

    axs[3].set_ylim(0, 0.15)
    axs[3].set_xlim(-3, 1.5)
    axs[3].set_yticks([0, 0.04, 0.08, 0.12])
    axs[3].set_xticks([-2, -1, 0, 1])
    axs[3].set_xlabel('[Fe/H] [dex]', fontsize=7)
    axs[3].set_ylabel(r"$f_\star$", fontsize=7)

    for ax in axs.flatten():
        ax.set_box_aspect(1)
        ax.tick_params(axis="both", which="both", labelsize=6)

    # Stars of the main object (halo_idx / subhalo_idx from read_data).
    is_main_obj = (df["Halo"] == halo_idx) & (df["Subhalo"] == subhalo_idx)
    main = df[is_main_obj]

    # Panel 0: j_z vs normalised potential (star counts).
    axs[0].hist2d(
        main["zSpecificAngularMomentum_kpckm/s"],
        main["NormalizedPotential"],
        cmap='gnuplot2',
        bins=200,
        range=[axs[0].get_xlim(), axs[0].get_ylim()],
        norm=mcolors.LogNorm(vmin=1E0, vmax=1E3),
        rasterized=True)

    # Panel 1: face-on stellar mass map. Weights are mass per bin area, which
    # turns the summed mass into a surface density.
    n_bins = 200
    hist_range = axs[1].get_xlim()
    bin_area = (np.diff(hist_range)[0] / n_bins)**2
    axs[1].hist2d(
        main["xPosition_ckpc"],
        main["yPosition_ckpc"],
        cmap='gnuplot2',
        weights=main["Mass_Msun"] / bin_area,
        bins=n_bins,
        range=[hist_range, hist_range],
        norm=mcolors.LogNorm(vmin=1E4, vmax=1E9),
        rasterized=True)

    # Panel 2: age-metallicity relation (star counts).
    axs[2].hist2d(
        main["StellarAge_Gyr"],
        main["[Fe/H]"],
        cmap='gnuplot2',
        bins=200,
        range=[axs[2].get_xlim(), axs[2].get_ylim()],
        norm=mcolors.LogNorm(vmin=1E0, vmax=1E3),
        rasterized=True)

    # Overlay every tracked particle found in this snapshot, whether or not it
    # belongs to the main object, on panels 0-2.
    if targets is not None:
        for ax, x_col, y_col in (
                (axs[0], "zSpecificAngularMomentum_kpckm/s",
                 "NormalizedPotential"),
                (axs[1], "xPosition_ckpc", "yPosition_ckpc"),
                (axs[2], "StellarAge_Gyr", "[Fe/H]")):
            ax.scatter(targets[x_col], targets[y_col],
                       lw=0, s=0.1, c="tab:green")

    # Panel 3: [Fe/H] distribution of the main object (black) and of each
    # component, all as fractions of the main object's total star count.
    # Components are restricted to the main object so they are consistent
    # with that normalisation.
    # max(..., 1) avoids 0/0 when the main object has no stars; the curves
    # are then all zero.
    n_main = max(len(main), 1)
    fe_range = axs[3].get_xlim()
    hist, bin_edges = np.histogram(a=main["[Fe/H]"], bins=50, range=fe_range)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    axs[3].plot(bin_centers, hist / n_main, c="black", lw=1)
    for i, component in enumerate(settings.components):
        hist, _ = np.histogram(
            a=main["[Fe/H]"][main["RegionTag"] == i], bins=50,
            range=fe_range)
        axs[3].plot(
            bin_centers, hist / n_main, lw=1,
            c=settings.component_colors[component],
            label=component)
    axs[3].legend(loc="upper left", framealpha=0, fontsize=5)

    # Create the output directory on first use so savefig does not fail.
    out_dir = Path("../images/merger_effects") / simulation
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f"s{snum}.png")
    plt.close(fig)