# The Evolution of Stellar Positions

In [80]:
import matplotlib as mpl
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import warnings
from auriga.images import figure_setup, set_axs_configuration, add_redshift
from auriga.snapshot import Snapshot
from auriga.settings import Settings
from auriga.support import make_snapshot_number, multi_color_line
from auriga.parser import parse

In [17]:
# Apply the project's shared figure configuration before any plot is made
# (presumably global matplotlib style settings; confirm in auriga.images).
figure_setup()

In [10]:
# Module-level settings instance. `create_figure` and the driver loop at the
# end of the notebook read it as a global (components, labels, colors and
# reruns).
settings = Settings()

In [11]:
def get_stellar_ids_by_region(simulation: str) -> dict:
    """
    Collect the IDs of the stars that belong to each galactic region in a
    given snapshot.

    The snapshot is loaded with particle types 0 to 5, circularities are
    added and every particle is tagged by region. Only particles of type 4
    with a positive stellar formation time are kept.

    Parameters
    ----------
    simulation : str
        The simulation (snapshot) to load.

    Returns
    -------
    ids_by_region : dict
        A dictionary keyed by region tag (the values of
        `Settings().component_tags`) whose values are the IDs of the
        selected stars tagged with that region.
    """

    config = Settings()

    snap = Snapshot(simulation=simulation,
                    loadonlytype=[0, 1, 2, 3, 4, 5])
    snap.add_circularity()
    snap.tag_particles_by_region(
        disc_std_circ=1.0,
        disc_min_circ=0.4,
        cold_disc_delta_circ=0.25,
        bulge_max_specific_energy=-0.75,
    )

    # Selection shared by all regions: type-4 particles with a positive
    # formation time.
    star_mask = (snap.type == 4) & (snap.stellar_formation_time > 0)

    return {
        tag: snap.ids[star_mask & (snap.region_tag == tag)]
        for tag in config.component_tags.values()
    }

In [89]:
def _position_stats(values: np.ndarray) -> list:
    """Return [mean, median, std, 16th percentile, 84th percentile]."""
    return [
        np.mean(values),
        np.median(values),
        np.std(values),
        np.percentile(values, 16),
        np.percentile(values, 84),
    ]


def get_avg_props_of_ids_in_snapshot(simulation: str,
                                     ids_by_region: dict
                                     ) -> tuple:
    """
    Return summary statistics of the positions of the stars listed in
    `ids_by_region` (a dictionary of star IDs that populate each region) as
    found in the given snapshot.

    Parameters
    ----------
    simulation : str
        The simulation (snapshot) to analyze.
    ids_by_region : dict
        A dictionary with the star IDs of the particles in each region of
        the galaxy, keyed by region tag.

    Returns
    -------
    tuple
        A tuple `(time, redshift, props_by_region)`. `time` and `redshift`
        are those of the loaded snapshot. `props_by_region` maps each
        region tag to a list of 15 values: the mean, median, standard
        deviation, 16th percentile and 84th percentile of the spherical
        radius, followed by the same five statistics of the cylindrical
        radius, followed by the same five statistics of the absolute value
        of z. If none of the IDs of a region is found in the snapshot, all
        15 values are NaN.
    """

    s = Snapshot(simulation=simulation, loadonlytype=[4])
    s.add_extra_coordinates()

    props_by_region = {}

    for region_tag, region_ids in ids_by_region.items():
        idxs = s.get_idxs_of_ids(ids=region_ids)
        # Remove idxs with values -1 (unmatched particles).
        idxs = idxs[idxs >= 0]

        if idxs.shape[0] == 0:
            props_by_region[region_tag] = [np.nan] * 15
            continue

        # Select each coordinate once and reuse it for all five statistics.
        coordinates = (
            s.r[idxs],
            s.rho[idxs],
            np.abs(s.pos[idxs, 2]),
        )
        region_props = []
        for values in coordinates:
            region_props.extend(_position_stats(values))
        props_by_region[region_tag] = region_props

    return s.time, s.redshift, props_by_region

In [13]:
def get_avg_props_of_present_day_stars(simulation: str) -> pd.DataFrame:
    """
    Build a table with the time evolution of the positions of the stars
    that populate each galactic component in the last snapshot.

    The star IDs of each region are taken from the last snapshot of the
    simulation and then looked up in every snapshot with index equal to or
    larger than `Settings().first_snap`. RuntimeWarnings raised while doing
    so are silenced.

    Parameters
    ----------
    simulation : str
        The simulation to analyze.

    Returns
    -------
    pd.DataFrame
        One row per analyzed snapshot. Columns are `Time`, `Redshift` and,
        for each component prefix (`Halo`, `Bulge`, `ColdDisc`,
        `WarmDisc`), `{prefix}{stat}{coordinate}` with `stat` in `Mean`,
        `Median`, `Std`, `Perc16`, `Perc84` and `coordinate` in
        `SphericalRadius`, `CylindricalRadius`, `Height`.
    """

    settings = Settings()

    # Region tag -> column prefix. The tags are hard-coded here
    # (NOTE(review): assumed to agree with `settings.component_tags`).
    region_prefixes = {0: "Halo", 1: "Bulge", 2: "ColdDisc", 3: "WarmDisc"}

    # Column suffixes in the same order as the 15-element property list
    # returned by `get_avg_props_of_ids_in_snapshot`: coordinates vary
    # slowest, statistics vary fastest.
    coordinates = ("SphericalRadius", "CylindricalRadius", "Height")
    stats = ("Mean", "Median", "Std", "Perc16", "Perc84")
    columns_by_tag = {
        tag: [f"{prefix}{stat}{coordinate}"
              for coordinate in coordinates
              for stat in stats]
        for tag, prefix in region_prefixes.items()
    }

    # Insertion order of this dictionary defines the column order.
    data = {"Time": [], "Redshift": []}
    for columns in columns_by_tag.values():
        for column in columns:
            data[column] = []

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)

        # Get IDs by region of the last snapshot.
        _, rerun, resolution = parse(simulation)
        n_snapshots = make_snapshot_number(rerun, resolution)
        ids_by_region = get_stellar_ids_by_region(
            f"{simulation}_s{n_snapshots - 1}")

        for i in range(n_snapshots):
            if i < settings.first_snap:
                continue

            time, redshift, props_by_region = \
                get_avg_props_of_ids_in_snapshot(
                    simulation=f"{simulation}_s{i}",
                    ids_by_region=ids_by_region)

            data["Time"].append(time)
            # Keep one redshift per snapshot so that every row carries its
            # own value instead of a single scalar broadcast to all rows.
            data["Redshift"].append(redshift)

            for tag, columns in columns_by_tag.items():
                region_props = props_by_region[tag]
                for k, column in enumerate(columns):
                    data[column].append(region_props[k])

    return pd.DataFrame(data)

In [84]:
def create_figure(simulation: str):
    """
    Create and save the two position-evolution figures of a simulation.

    The first figure has three panels showing, as a function of time, the
    median and the 16th-84th percentile band of the cylindrical radius,
    the spherical radius and the absolute height of the stars of each
    component. The second figure has one panel per component showing the
    track of the mean and of the median position in the
    cylindrical radius-height plane, colored by time.

    The figures are written to
    `../images/stellar_position_evolution/{simulation}.pdf` and
    `../images/stellar_position_trajectories/{simulation}.pdf`. Component
    names, labels and colors are read from the module-level `settings`.

    Parameters
    ----------
    simulation : str
        The simulation to plot.
    """
    df = get_avg_props_of_present_day_stars(simulation)
    simulation_tag = r"$\texttt{" + simulation.upper() + "}$"

    # ------------------------------------------------------------------
    # Figure 1: time evolution of the median positions.
    # ------------------------------------------------------------------
    fig = plt.figure(figsize=(7.2, 2.0))
    grid = fig.add_gridspec(nrows=1, ncols=3, hspace=0.0, wspace=0.16)
    axs = grid.subplots(sharex=True, sharey=True)

    # Axes are shared, so limits, scale and ticks are set on the first one.
    first_ax = axs[0]
    first_ax.set_xlim(0, 14)
    first_ax.set_xticks([2, 4, 6, 8, 10, 12])
    first_ax.set_yscale("log")
    first_ax.set_ylim(2E-1, 4E2)
    first_ax.set_yticks([1E0, 1E1, 1E2])
    first_ax.set_yticklabels(["1", "10", "100"])

    # Column suffix and y-label of each of the three panels, left to right.
    quantities = ("CylindricalRadius", "SphericalRadius", "Height")
    y_labels = (
        r"$\mathrm{median} \left( r_{xy} \right)$ [ckpc]",
        r"$\mathrm{median} \left( r \right)$ [ckpc]",
        r"$\mathrm{median} \left( \left| z \right| \right)$ [ckpc]",
    )
    for panel, y_label in zip(axs, y_labels):
        panel.set_xlabel("Time [Gyr]")
        panel.set_ylabel(y_label)

    for component in settings.components:
        display_name = settings.component_labels[component]
        prefix = display_name.replace(" ", "")
        color = settings.component_colors[component]

        for panel, quantity in zip(axs, quantities):
            panel.plot(df["Time"], df[f"{prefix}Median{quantity}"],
                       lw=1.5, color=color, label=display_name)
            panel.fill_between(df["Time"],
                               df[f"{prefix}Perc16{quantity}"],
                               df[f"{prefix}Perc84{quantity}"],
                               lw=0, color=color, alpha=0.2)

    for panel in axs:
        # Vertical dashed reference line at 4 on the time axis.
        panel.plot([4.0, 4.0], panel.get_ylim(),
                   lw=1.5, ls="--", color="tab:gray", zorder=-5)
        add_redshift(ax=panel)

    axs[0].text(x=0.95, y=0.95, s=simulation_tag, size=7.0,
                ha='right', va='top', transform=axs[0].transAxes)

    axs[2].legend(loc="upper right", framealpha=0, fontsize=6.0)

    fig.savefig(f"../images/stellar_position_evolution/{simulation}.pdf")
    plt.close(fig)

    # ------------------------------------------------------------------
    # Figure 2: tracks in the cylindrical radius-height plane.
    # ------------------------------------------------------------------
    fig = plt.figure(figsize=(7.2, 2.0))
    grid = fig.add_gridspec(nrows=1, ncols=4, hspace=0.0, wspace=0.0)
    axs = grid.subplots(sharex=True, sharey=True)

    first_ax = axs[0]
    first_ax.set_xscale("log")
    first_ax.set_xlim(0.1, 200)
    first_ax.set_xticks([1, 10, 100], labels=["1", "10", "100"])
    for panel in axs:
        panel.set_xlabel(r"$r_{xy}$ [ckpc]")

    first_ax.set_yscale("log")
    first_ax.set_ylabel(r"$\left| z \right|$ [ckpc]")
    first_ax.set_ylim(0.1, 200)
    first_ax.set_yticks([1, 10, 100], labels=["1", "10", "100"])

    # One entry per track: statistic, colormap, vertical position of the
    # colorbar, vertical position of its caption, caption text, and the
    # colorbar ticks (only the lower colorbar shows tick labels).
    time_ticks = (0, 2, 4, 6, 8, 10, 12, 14)
    tracks = (
        ("Mean", "winter_r", 0.89, 0.9025, r"\textbf{MEAN}", ()),
        ("Median", "spring_r", 0.85, 0.8625, r"\textbf{MEDIAN}",
         time_ticks),
    )

    for i, component in enumerate(settings.component_labels):
        panel = axs[i]
        display_name = settings.component_labels[component]
        prefix = display_name.replace(" ", "")

        for stat, cmap_name, bar_y, caption_y, caption, ticks in tracks:
            track, norm, cmap = multi_color_line(
                x=df[f"{prefix}{stat}CylindricalRadius"].to_numpy(),
                y=df[f"{prefix}{stat}Height"].to_numpy(),
                c=df["Time"].to_numpy(),
                cmap=cmap_name, vmin=0.0, vmax=14.0, lw=3.0,
                return_params=True)
            panel.add_collection(track)

            bar_ax = panel.inset_axes([0.1, bar_y, 0.6, 0.025],
                                      transform=panel.transAxes)
            bar = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),
                               cax=bar_ax, orientation="horizontal")
            bar.outline.set_visible(False)
            bar.set_ticks(list(ticks))
            if ticks:
                bar.set_ticklabels(list(ticks), fontsize=5.0)
            bar.ax.tick_params(length=0.0)

            panel.text(s=caption, x=0.72, y=caption_y,
                       ha="left", va="center",
                       color=cmap(norm(14.0)), fontsize=4.0,
                       transform=panel.transAxes)

        panel.text(s=r"\textbf{" + str(display_name) + r"}",
                   x=0.95, y=0.05, ha="right", va="bottom",
                   fontsize=5.0, transform=panel.transAxes)
        panel.text(s="Time [Gyr]", x=0.4, y=0.93, ha="center",
                   va="bottom", fontsize=5.0, transform=panel.transAxes)
        # Diagonal joining the lower-left and upper-right axis limits.
        panel.plot(panel.get_xlim(), panel.get_ylim(),
                   ls="--", lw=0.5, c='k')

    axs[0].text(x=0.05, y=0.05, s=simulation_tag, size=7.0,
                ha='left', va='bottom', transform=axs[0].transAxes)

    fig.savefig(f"../images/stellar_position_trajectories/{simulation}.pdf")
    plt.close(fig)

In [93]:
# Create the figures of galaxies 16 to 30: always for the "or" run and, for
# the galaxies listed in `settings.reruns`, also for the "re" run.
for halo_number in range(16, 31):
    simulations = [f"au{halo_number}_or_l4"]
    if halo_number in settings.reruns:
        simulations.append(f"au{halo_number}_re_l4")
    for simulation_name in simulations:
        create_figure(simulation=simulation_name)