# Age-Metallicity Relation

In [35]:
from functools import lru_cache

from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from scipy.stats import binned_statistic
from scipy.stats import gaussian_kde
import pandas as pd
from tqdm import tqdm

In [2]:
from auriga.snapshot import Snapshot
from auriga.images import figure_setup, set_axs_configuration
from auriga.settings import Settings
from auriga.parser import parse

In [3]:
# Apply the project's figure configuration (auriga.images.figure_setup)
# before any figure is created below.
# NOTE(review): presumably sets global matplotlib rcParams — confirm.
figure_setup()

In [4]:
# Project-wide settings shared by the cells below (galaxy and rerun lists,
# galaxy groups, and the component tags, labels and colors of each region).
settings = Settings()

In [5]:
# Stellar ages are binned over [0, 14] Gyr using 14 bins of 1 Gyr each.
age_range = (0, 14)
n_bins = 14

In [6]:
def read_data(simulation: str) -> tuple:
    """
    Load the stars of a snapshot and return the age and the [Fe/H]
    abundance of those that belong to the main halo and subhalo.

    Parameters
    ----------
    simulation : str
        The simulation to consider.

    Returns
    -------
    tuple
        ``(stellar_age, fe_to_h)`` of the selected stars, aligned element
        by element.
    """
    snap = Snapshot(simulation=simulation, loadonlytype=[4])
    snap.add_stellar_age()
    snap.add_metal_abundance(of="Fe", to="H")

    # Type-4 particles with a positive formation time. The original code
    # calls these "real" stars; presumably this drops non-stellar type-4
    # particles (e.g. wind particles) — TODO confirm.
    formed_stars = (snap.type == 4) & (snap.stellar_formation_time > 0)
    in_main_object = ((snap.halo == snap.halo_idx)
                      & (snap.subhalo == snap.subhalo_idx))
    selection = formed_stars & in_main_object

    fe_to_h = snap.metal_abundance["Fe/H"][selection]
    return snap.stellar_age[selection], fe_to_h

In [7]:
def read_data_with_region(simulation: str) -> tuple:
    """
    Load a snapshot, tag its particles by galactic region and return the
    age, the [Fe/H] abundance and the region tag of the stars that belong
    to the main halo and subhalo.

    Parameters
    ----------
    simulation : str
        The simulation to consider.

    Returns
    -------
    tuple
        ``(stellar_age, fe_to_h, region_tag)`` of the selected stars,
        aligned element by element.
    """
    # Every particle type (0-5) is loaded, not only stars.
    # NOTE(review): presumably required by the circularity / potential
    # computations used for the region tagging — confirm.
    snap = Snapshot(simulation=simulation, loadonlytype=list(range(6)))
    snap.add_stellar_age()
    snap.add_circularity()
    snap.add_reference_to_potential()
    snap.add_normalized_potential()
    snap.tag_particles_by_region(disc_std_circ=1.0,
                                 disc_min_circ=0.4,
                                 cold_disc_delta_circ=0.25,
                                 bulge_max_specific_energy=-0.75)
    snap.add_metal_abundance(of="Fe", to="H")

    # Same star selection as ``read_data``: type-4 particles with a positive
    # formation time that belong to the main halo and subhalo.
    formed_stars = (snap.type == 4) & (snap.stellar_formation_time > 0)
    in_main_object = ((snap.halo == snap.halo_idx)
                      & (snap.subhalo == snap.subhalo_idx))
    selection = formed_stars & in_main_object

    stellar_age = snap.stellar_age[selection]
    fe_to_h = snap.metal_abundance["Fe/H"][selection]
    region_tag = snap.region_tag[selection]
    return stellar_age, fe_to_h, region_tag

### Age-Metallicity for All Subhalo Stars

In [34]:
def calc_median_metal_of_sample(simulations: list):
    """
    Calculate the median [Fe/H] of a sample of simulations in stellar age
    bins.

    The main-object stars of every simulation are stacked before binning,
    so galaxies with more stars weigh more in the median.

    Parameters
    ----------
    simulations : list
        A list of simulations whose stars make up the sample.

    Returns
    -------
    bin_centers : np.ndarray
        The centers of the ``n_bins`` age bins spanning ``age_range``.
    stat : np.ndarray
        The median [Fe/H] of the sample in each age bin; bins without
        stars are NaN.
    """
    ages = []
    metallicities = []
    for simulation in simulations:
        age, fe_to_h = read_data(simulation=simulation)
        ages.append(np.asarray(age))
        metallicities.append(np.asarray(fe_to_h))

    # Concatenate whole arrays instead of turning every star into a Python
    # object (``list(age)``), which is slow for millions of stars.
    ages = np.concatenate(ages) if ages else np.empty(0)
    metallicities = (np.concatenate(metallicities) if metallicities
                     else np.empty(0))

    stat, bin_edges, _ = binned_statistic(
        x=ages,
        values=metallicities,
        statistic=np.nanmedian,
        bins=n_bins,
        range=age_range
    )
    # The bins are uniform, so every center is half a bin width below the
    # upper edge of its bin.
    bin_centers = bin_edges[1:] - np.diff(bin_edges)[0] / 2

    return bin_centers, stat

In [35]:
def add_panel_for_galaxy(simulation: str, ax: plt.Axes):
    """
    Draw the [Fe/H]-age distribution of one simulation on ``ax``.

    The panel contains a 2D histogram of the main-object stars (logarithmic
    color scale from 1 to 1000 stars per bin), the median [Fe/H] in
    ``n_bins`` age bins over ``age_range``, and the simulation name in the
    upper-left corner. The histogram spans the current limits of ``ax``,
    so they must be set before calling this function.

    Parameters
    ----------
    simulation : str
        The simulation.
    ax : plt.Axes
        The ax to which to add the plot.

    Returns
    -------
    bin_centers : np.ndarray
        The centers of the age bins.
    stat : np.ndarray
        The median [Fe/H] of the galaxy in each age bin.
    """
    age, fe_to_h = read_data(simulation=simulation)

    # Only the drawn artist matters here: this figure has no colorbar, so
    # neither the histogram arrays nor the image handle are kept.
    ax.hist2d(age,
              fe_to_h,
              cmap='nipy_spectral',
              bins=200,
              range=[ax.get_xlim(), ax.get_ylim()],
              norm=mcolors.LogNorm(vmin=1E0, vmax=1E3),
              rasterized=True)

    stat, bin_edges, _ = binned_statistic(
        x=age,
        values=fe_to_h,
        statistic=np.nanmedian,
        bins=n_bins,
        range=age_range
    )
    # The bins are uniform, so every center is half a bin width below the
    # upper edge of its bin.
    bin_centers = bin_edges[1:] - np.diff(bin_edges)[0] / 2
    ax.plot(bin_centers, stat)

    ax.text(x=0.05,
            y=0.95,
            s=r"$\texttt{" + simulation.upper() + "}$",
            size=6.0, transform=ax.transAxes,
            ha='left', va='top',
            )

    return bin_centers, stat

In [36]:
def plot_age_metal_rel_for_sample(simulations: list,
                                  filename: str):
    """
    Plot the metallicity ([Fe/H] abundance) as a function of stellar age
    for the main-object stars of each simulation, one panel per simulation
    in a 6x5 grid, and save the figure as a PDF.

    Each panel shows the galaxy's 2D histogram and binned median, the
    binned median of the whole sample (dashed black line), and the mean
    squared error between both medians.

    Parameters
    ----------
    simulations : list
        A list of simulations to plot (at most one per panel, i.e. 30).
    filename : str
        The name of the output file, without extension. The figure is
        written to ``../images/age_metallicity/{filename}.pdf``.

    Raises
    ------
    ValueError
        If there are more simulations than panels in the figure.
    """
    nrows, ncols = 6, 5
    n_simulations = len(simulations)
    # Fail before any (expensive) snapshot is read, rather than silently
    # leaving the extra simulations out of the figure.
    if n_simulations > nrows * ncols:
        raise ValueError(
            f"The figure has {nrows * ncols} panels but {n_simulations} "
            f"simulations were given.")

    fig = plt.figure(figsize=(7.2, 7.2))
    gs = fig.add_gridspec(nrows=nrows, ncols=ncols, hspace=0.0, wspace=0.0)
    axs = gs.subplots(sharex=True, sharey=True)

    set_axs_configuration(
        xlim=age_range, ylim=(-4, 3),
        xticks=[2, 4, 6, 8, 10, 12], yticks=[-3, -2, -1, 0, 1, 2],
        xlabel="Age [Gyr]", ylabel="[Fe/H]",
        axs=axs, n_used=n_simulations)

    # Median of the whole sample, drawn in every panel as a reference.
    bin_centers, sample_median = calc_median_metal_of_sample(simulations)

    for idx, ax in enumerate(axs.flat):
        if idx >= n_simulations:
            ax.axis("off")
            continue

        _, stat = add_panel_for_galaxy(simulations[idx], ax=ax)
        ax.plot(bin_centers, sample_median, c='k', ls="--")
        # nanmean skips age bins that are empty (NaN) in either median; a
        # plain sum would turn the whole MSE into NaN.
        mse = np.nanmean((stat - sample_median) ** 2)
        ax.text(x=0.05,
                y=0.05,
                s=r"$\mathrm{MSE} = "
                  + str(np.round(mse, 3)).ljust(5, '0') + "$",
                size=6.0, transform=ax.transAxes,
                ha='left', va='bottom',
                )

    # Save once, after every panel has been drawn.
    fig.savefig(f"../images/age_metallicity/{filename}.pdf")
    plt.close(fig)

In [37]:
# # Create plots for all galaxies in two figures
# settings = Settings()
# originals = [f"au{i}_or_l4_s127" for i in settings.galaxies]
# reruns = [f"au{i}_re_l4_s251" for i in settings.reruns]
# plot_age_metal_rel_for_sample(simulations=originals, filename="originals_s127")
# plot_age_metal_rel_for_sample(simulations=reruns, filename="reruns_s251")

### Age-Metallicity by Region

In [8]:
@lru_cache(maxsize=8)
def _sample_region_medians(rerun: bool, resolution: str,
                           snapshot: str) -> tuple:
    """
    Return the binned median [Fe/H] versus age of each region for a whole
    sample of galaxies.

    The main-object stars of every galaxy in the sample
    (``Settings().reruns`` if ``rerun`` is true, ``Settings().galaxies``
    otherwise) at the given resolution and snapshot are stacked, split by
    region tag and binned in ``n_bins`` age bins over ``age_range``.

    The result is cached because every per-galaxy figure of the same sample
    uses the same reference curves, and computing them loads every snapshot
    of the sample. Call ``_sample_region_medians.cache_clear()`` after
    changing ``n_bins`` or ``age_range``.

    Parameters
    ----------
    rerun : bool
        Whether the sample is made of reruns.
    resolution : str
        The resolution level, as it appears in the simulation name.
    snapshot : str
        The snapshot number, as it appears in the simulation name.

    Returns
    -------
    tuple of np.ndarray
        One array of binned medians per value of
        ``Settings().component_tags``, in that mapping's order.
    """
    settings = Settings()
    galaxies = settings.reruns if rerun else settings.galaxies
    rerun_txt = "re" if rerun else "or"

    ages, abundances, region_tags = [], [], []
    for galaxy in galaxies:
        age, metal_ab, region_tag = read_data_with_region(
            simulation=f"au{galaxy}_{rerun_txt}_l{resolution}_s{snapshot}")
        ages.append(np.asarray(age))
        abundances.append(np.asarray(metal_ab))
        region_tags.append(np.asarray(region_tag))
    # Concatenate whole arrays instead of growing Python lists star by star.
    ages = np.concatenate(ages) if ages else np.empty(0)
    abundances = np.concatenate(abundances) if abundances else np.empty(0)
    region_tags = np.concatenate(region_tags) if region_tags else np.empty(0)

    medians = []
    for tag in settings.component_tags.values():
        is_region = region_tags == tag
        stat, _, _ = binned_statistic(
            x=ages[is_region],
            values=abundances[is_region],
            statistic=np.nanmedian,
            bins=n_bins,
            range=age_range
        )
        medians.append(stat)
    return tuple(medians)


def plot_age_metal_rel_by_region_for_galaxy(simulation: str):
    """
    Plot the metallicity ([Fe/H] abundance) as a function of stellar age
    for the main-object stars of one galaxy, one panel per region, and
    save the figure as a PDF.

    Each panel shows:

    - the 2D histogram of the region's stars;
    - the region's median age and median [Fe/H] (thin black lines and dot);
    - the galaxy's binned median [Fe/H] (blue line);
    - the binned median of the whole sample the galaxy belongs to (dashed
      black line), at the same resolution and snapshot;
    - the mean squared error between both binned medians.

    The median age and [Fe/H] of each region are also printed.

    Parameters
    ----------
    simulation : str
        The simulation to plot. The figure is written to
        ``../images/age_metallicity_by_region/{simulation}.pdf``.
    """
    settings = Settings()

    fig = plt.figure(figsize=(7.4, 2.0))
    gs = fig.add_gridspec(nrows=1, ncols=4, hspace=0.0, wspace=0.0)
    axs = gs.subplots(sharex=True, sharey=True)

    # Reference curves of the whole sample (originals or reruns). str/bool
    # keep the cache key hashable and produce the same simulation names.
    _, rerun, resolution, snapshot = parse(simulation)
    sample_data = _sample_region_medians(
        bool(rerun), str(resolution), str(snapshot))

    for ax in axs.flat:
        ax.tick_params(which='both', direction="in")
        ax.set_xlim(0, 14)
        ax.set_ylim(-4, 3)
        ax.set_xticks([2, 4, 6, 8, 10, 12])
        ax.set_yticks([-3, -2, -1, 0, 1, 2])
        ax.set_xlabel("Age [Gyr]")
        ax.set_ylabel("[Fe/H]")
        ax.label_outer()

    print("MEDIAN DATA")
    print("===========")

    age, metal_ab, region_tag = read_data_with_region(simulation=simulation)
    im = None
    # Walk the regions in ``component_tags`` order (the order of
    # ``sample_data``) instead of over the tags present in this galaxy. This
    # way each panel, its label and its reference curve always describe the
    # same region, even if the galaxy has no stars in some region.
    for idx, (ax, tag) in enumerate(
            zip(axs.flat, settings.component_tags.values())):
        label = settings.component_labels[settings.components[idx]]
        is_region = region_tag == tag
        region_age = age[is_region]
        region_feh = metal_ab[is_region]
        median_age = np.nanmedian(region_age)
        median_feh = np.nanmedian(region_feh)

        _, _, _, im = ax.hist2d(
            region_age,
            region_feh,
            cmap='nipy_spectral',
            bins=200,
            range=[ax.get_xlim(), ax.get_ylim()],
            norm=mcolors.LogNorm(vmin=1E0, vmax=1E3),
            rasterized=True)
        ax.text(x=0.05,
                y=0.05,
                s=label,
                size=8.0, transform=ax.transAxes,
                ha='left', va='bottom',
                )

        print("")
        print(f"{label}")
        print(f"[Fe/H]: {median_feh}")
        print(f"Age [Gyr]: {median_age}")

        # Cross-hair at the region's median age and median [Fe/H].
        ax.plot(ax.get_xlim(), [median_feh, median_feh], lw=.25, color='k')
        ax.plot([median_age, median_age], ax.get_ylim(), lw=.25, color='k')
        ax.plot(median_age, median_feh, marker='o', mfc='k', ms=2, mew=0)

        stat, bin_edges, _ = binned_statistic(
            x=region_age,
            values=region_feh,
            statistic=np.nanmedian,
            bins=n_bins,
            range=age_range
        )
        # The bins are uniform, so every center is half a bin width below
        # the upper edge of its bin.
        bin_centers = bin_edges[1:] - np.diff(bin_edges)[0] / 2
        ax.plot(bin_centers, stat, c='b', ls='-')
        ax.plot(bin_centers, sample_data[idx], c='k', ls="--")

        # nanmean skips age bins that are empty (NaN) in either median; a
        # plain sum would turn the whole MSE into NaN.
        mse = np.nanmean((stat - sample_data[idx]) ** 2)
        ax.text(
            x=0.05,
            y=0.95,
            s=r"$\mathrm{MSE} = " + str(np.round(mse, 3)).ljust(5, '0') + "$",
            size=8.0, transform=ax.transAxes,
            ha='left', va='top',
            )

    if im is not None:
        cbar = fig.colorbar(im, ax=axs[-1], orientation='vertical',
                            label=r'$N_\mathrm{stars}$',
                            pad=0)
        cbar.ax.set_yticks([1E0, 1E1, 1E2, 1E3])

    axs[0].text(x=axs[0].get_xlim()[0],
                y=axs[0].get_ylim()[1],
                s=r"$\texttt{" + simulation.upper() + "}$",
                size=8.0,
                ha='left', va='bottom',
                )

    fig.savefig(f"../images/age_metallicity_by_region/{simulation}.pdf")
    plt.close(fig)

In [70]:
# One by-region figure per original galaxy (Au1-Au30, snapshot 127), plus
# one for its rerun (snapshot 251) when the galaxy is listed in
# ``settings.reruns``.
for galaxy in range(1, 31):
    plot_age_metal_rel_by_region_for_galaxy(
        simulation=f"au{galaxy}_or_l4_s127")
    if galaxy not in settings.reruns:
        continue
    plot_age_metal_rel_by_region_for_galaxy(
        simulation=f"au{galaxy}_re_l4_s251")

### Median Scatter

In [19]:
# Median [Fe/H] and median age of every region of every original galaxy
# (snapshot 127). Entry k of the three arrays belongs to one
# (galaxy, region) pair; ``component`` holds the region tag of that pair.
component = []
abundance_median = []
age_median = []

for galaxy in tqdm(settings.galaxies):
    age, metal_ab, region_tag = read_data_with_region(
        simulation=f"au{galaxy}_or_l4_s127")
    # Only the tag values are needed, so iterate ``.values()``.
    for tag in settings.component_tags.values():
        is_region = region_tag == tag
        component.append(tag)
        abundance_median.append(np.nanmedian(metal_ab[is_region]))
        age_median.append(np.nanmedian(age[is_region]))

component = np.array(component)
abundance_median = np.array(abundance_median)
age_median = np.array(age_median)



In [None]:
# Median age versus median [Fe/H] of each region for every original galaxy,
# one panel per region. The other regions are drawn in silver for
# reference. The highlighted region gets marginal KDEs of its medians along
# the bottom (age) and right-hand ([Fe/H]) edges of the panel.
AGE_LIM = (0, 14)
FEH_LIM = (-1.0, 0.7)
AGE_KDE_HEIGHT = 0.2  # Peak height of the age KDE, in [Fe/H] units.
FEH_KDE_WIDTH = 2  # Peak width of the [Fe/H] KDE, in Gyr.
KDE_ALPHA = 0x30 / 0xFF  # Fill alpha of the KDEs (hex alpha channel 0x30).

fig, axs = plt.subplots(
    figsize=(7.4, 2.0), nrows=1, ncols=4, sharey=True, sharex=True,
    gridspec_kw={"hspace": 0.0, "wspace": 0.0})

for ax in axs.flat:
    ax.grid(True, ls='-', lw=0.25, c="gainsboro")
    ax.tick_params(which='both', direction="in")
    ax.set_xlim(*AGE_LIM)
    ax.set_xticks([2, 4, 6, 8, 10, 12])
    ax.set_xlabel("Age [Gyr]")
    ax.set_ylim(*FEH_LIM)
    ax.set_ylabel('median([Fe/H])')
    ax.set_axisbelow(True)
    ax.label_outer()

markers = ["o", "v", "^", "d"]
colors = list(settings.component_colors.values())
labels = list(settings.component_labels.values())

for i, ax in enumerate(axs.flat):
    for j, marker in enumerate(markers):
        highlighted = j == i
        in_region = component == j
        color = colors[j] if highlighted else "silver"
        ax.scatter(
            age_median[in_region],
            abundance_median[in_region],
            c=color, label=labels[j] if highlighted else None,
            zorder=10 if highlighted else 5,
            s=20, linewidths=0.4, edgecolors="white", marker=marker)

        if not highlighted:
            continue

        # Works for any matplotlib color, not only "tab:" names.
        fill = mcolors.to_rgba(color, alpha=KDE_ALPHA)

        # Age KDE, scaled to a peak of AGE_KDE_HEIGHT above the bottom edge.
        age_grid = np.linspace(*AGE_LIM, 100)
        age_density = gaussian_kde(age_median[in_region])(age_grid)
        ax.fill_between(
            x=age_grid, y1=FEH_LIM[0],
            y2=age_density / np.max(age_density) * AGE_KDE_HEIGHT
            + FEH_LIM[0],
            edgecolor=color, facecolor=fill)

        # [Fe/H] KDE, scaled to a peak of FEH_KDE_WIDTH left of the right
        # edge.
        feh_grid = np.linspace(*FEH_LIM, 100)
        feh_density = gaussian_kde(abundance_median[in_region])(feh_grid)
        feh_width = feh_density / np.max(feh_density) * FEH_KDE_WIDTH
        ax.fill_betweenx(
            y=feh_grid, x1=AGE_LIM[1] - feh_width, x2=AGE_LIM[1],
            edgecolor=color, facecolor=fill)

    ax.legend(loc="upper left", framealpha=0, fontsize=7.5)

fig.savefig("../images/age_metallicity_by_region/median_scatter.pdf")
plt.close(fig)

### MSE Analysis

In [8]:
# Hard-coded MSE values per galaxy (Au1-Au30, in order): the whole-galaxy
# MSE and the MSE of each region (H, B, CD and WD suffixes).
# NOTE(review): presumably transcribed from the MSE annotations of the
# figures above — confirm they are still up to date.
df = pd.DataFrame({
    "Galaxy": [f"Au{i}" for i in range(1, 31)],
    "MSE": [
        0.033, 0.006, 0.019, 0.057, 0.004, 0.002, 0.061, 0.007, 0.012, 0.003,
        0.043, 0.016, 0.009, 0.003, 0.099, 0.013, 0.039, 0.017, 0.032, 0.020,
        0.009, 0.035, 0.010, 0.013, 0.006, 0.024, 0.007, 0.008, 0.034, 0.040,
    ],
    "MSE_H": [
        0.035, 0.065, 0.062, 0.021, 0.033, 0.042, 0.038, 0.048, 0.073, 0.023,
        0.017, 0.013, 0.013, 0.010, 0.081, 0.035, 0.068, 0.071, 0.029, 0.027,
        0.016, 0.039, 0.073, 0.104, 0.062, 0.044, 0.056, 0.041, 0.062, 0.031,
    ],
    "MSE_B": [
        0.019, 0.044, 0.020, 0.065, 0.016, 0.060, 0.054, 0.027, 0.004, 0.002,
        0.048, 0.041, 0.002, 0.012, 0.167, 0.066, 0.028, 0.030, 0.057, 0.017,
        0.026, 0.051, 0.014, 0.032, 0.024, 0.018, 0.032, 0.004, 0.025, 0.033,
    ],
    "MSE_CD": [
        0.051, 0.006, 0.011, 0.039, 0.010, 0.002, 0.064, 0.011, 0.009, 0.002,
        0.045, 0.017, 0.002, 0.007, 0.085, 0.005, 0.016, 0.017, 0.042, 0.026,
        0.006, 0.015, 0.004, 0.008, 0.007, 0.012, 0.004, 0.004, 0.040, 0.015,
    ],
    "MSE_WD": [
        0.053, 0.026, 0.032, 0.048, 0.014, 0.037, 0.076, 0.034, 0.017, 0.003,
        0.047, 0.014, 0.004, 0.010, 0.106, 0.018, 0.024, 0.031, 0.053, 0.036,
        0.010, 0.027, 0.041, 0.028, 0.022, 0.017, 0.029, 0.008, 0.026, 0.029,
    ],
})

In [49]:
# Marker color of each galaxy (Au1-Au30, in order) for the MSE analysis:
# gray for the "Excluded" group (takes precedence), red for the
# "NotMilkyWayLike" group, blue otherwise. Membership tests avoid the
# fancy-indexing pitfalls of the previous version: an empty group produced
# a float index array (IndexError), a galaxy number 0 wrapped to the last
# entry, and the fixed-width string dtype would truncate longer color names.
_not_mw_like = set(settings.groups["NotMilkyWayLike"])
_excluded = set(settings.groups["Excluded"])
colors = np.array([
    "tab:gray" if au in _excluded
    else "tab:red" if au in _not_mw_like
    else "tab:blue"
    for au in range(1, 31)
])

In [68]:
# Whole-galaxy MSE versus the MSE of each region, one panel per region.
# NOTE(review): the column order below is assumed to match the order of
# ``settings.components`` used for the panel labels — confirm.
mse_columns = ["MSE_H", "MSE_B", "MSE_CD", "MSE_WD"]

fig, axs = plt.subplots(figsize=(7.4, 2.0), ncols=4,
                        gridspec_kw={"wspace": 0.0})

for ax in axs.flat:
    ax.set_xlim(0, 0.12)
    ax.set_ylim(0, 0.12)
    ax.set_xticks([0.02, 0.04, 0.06, 0.08, 0.10])
    ax.set_xlabel(r"$\mathrm{MSE}$")
    ax.grid(True, ls='-', lw=0.25, c='silver')
    ax.set_ylabel(r"$\mathrm{MSE}_\mathrm{Component}$")
    ax.label_outer()

for ax, column in zip(axs, mse_columns):
    ax.scatter(df["MSE"], df[column], zorder=10, c=colors,
               s=20, linewidths=0.5, edgecolors="white")

for ax, comp in zip(axs, settings.components):
    ax.text(x=0.5, y=0.95,
            s=settings.component_labels[comp],
            size=8.0, transform=ax.transAxes,
            ha='center', va='top')

fig.savefig("../images/age_metallicity_by_region/mse_analysis.pdf")
plt.close(fig)