# Stellar Mass by Region

In [61]:
from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
from multiprocessing import Pool

In [62]:
from auriga.snapshot import Snapshot
from auriga.images import figure_setup, set_axs_configuration
from auriga.settings import Settings

In [63]:
# Apply the project's figure configuration (auriga.images.figure_setup;
# presumably matplotlib style/rcParams — TODO confirm) and load the analysis
# settings used by the cells below: the galaxy list (settings.galaxies) and
# the component names/colours (settings.components,
# settings.component_colors).
figure_setup()
settings = Settings()

In [64]:
def read_data(simulation: str) -> pd.DataFrame:
    """
    Return a DataFrame with the region tag and mass of the selected stars.

    The snapshot ``simulation`` is loaded with particle types 0-5, its
    particles are tagged by region using the thresholds from ``Settings``,
    and only particles that are both

    * stars: ``type == 4`` with a positive ``stellar_formation_time``, and
    * in the main object: ``halo == halo_idx`` and
      ``subhalo == subhalo_idx``

    are kept.

    Parameters
    ----------
    simulation : str
        Snapshot identifier passed to ``Snapshot(simulation=...)``.

    Returns
    -------
    pd.DataFrame
        Columns ``"RegionTag"`` (from ``s.region_tag``) and ``"Mass_Msun"``
        (from ``s.mass``; the column name indicates solar masses), one row
        per selected particle.
    """
    # A local Settings instance keeps the function independent of the
    # notebook-level global (presumably so it can run in worker processes,
    # e.g. via multiprocessing.Pool — TODO confirm).
    settings = Settings()

    s = Snapshot(simulation=simulation,
                 loadonlytype=[0, 1, 2, 3, 4, 5])
    s.tag_particles_by_region(
        disc_std_circ=settings.disc_std_circ,
        disc_min_circ=settings.disc_min_circ,
        cold_disc_delta_circ=settings.cold_disc_delta_circ,
        bulge_max_specific_energy=settings.bulge_max_specific_energy)

    # Type-4 particles with a non-positive formation time are excluded
    # (presumably wind particles rather than real stars — TODO confirm).
    is_real_star = (s.type == 4) & (s.stellar_formation_time > 0)
    is_main_obj = (s.halo == s.halo_idx) & (s.subhalo == s.subhalo_idx)
    # Build the combined selection once and reuse it for every column.
    selected = is_real_star & is_main_obj

    return pd.DataFrame({
        "RegionTag": s.region_tag[selected],
        "Mass_Msun": s.mass[selected],
    })

In [65]:
def get_statistics(df: pd.DataFrame, n_components=None) -> np.ndarray:
    """
    Return the fraction of stellar mass in each component.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a ``"RegionTag"`` column holding the component index of
        each row and a ``"Mass_Msun"`` column holding its mass.
    n_components : int, optional
        Number of components to report. Defaults to
        ``len(settings.components)`` from the notebook-level ``settings``.

    Returns
    -------
    np.ndarray
        Float array of length ``n_components``. Entry ``i`` is the summed
        mass of rows tagged ``i`` divided by the summed mass of all rows.
        Rows whose tag is outside ``range(n_components)`` contribute to the
        total only, so the fractions then sum to less than one. If the total
        mass is zero (e.g. an empty ``df``), numpy division yields NaN
        entries (with a RuntimeWarning).
    """
    if n_components is None:
        n_components = len(settings.components)

    masses = df["Mass_Msun"]
    tags = df["RegionTag"]
    total_solar_masses = masses.sum()
    # Size the result from the component count instead of a hard-coded 4,
    # so any number of components is reported without IndexError or
    # phantom zero-valued entries.
    components_solar_masses = np.array(
        [masses[tags == i].sum() for i in range(n_components)], dtype=float)
    return components_solar_masses / total_solar_masses

In [66]:
# One snapshot identifier per entry of settings.galaxies, built from the
# "au<galaxy>_or_l4_s127" template.
simulations = list(map("au{}_or_l4_s127".format, settings.galaxies))

In [67]:
# Build a 6x5 grid of panels, one per simulation, each showing the stellar
# mass fraction of every component as a bar chart, then save it once as PDF.
fig = plt.figure(figsize=(7.2, 7.2))
gs = fig.add_gridspec(nrows=6, ncols=5, hspace=0.0, wspace=0.0)
axs = gs.subplots(sharex=True, sharey=True)

# Shared styling for every panel; label_outer() then hides the tick labels
# and y-axis label on inner panels of the grid.
for ax in axs.flat:
    ax.tick_params(which='both', direction="in")
    ax.set_xlim(-0.5, 3.5)
    ax.set_xticks([0, 1, 2, 3])
    ax.set_xticklabels([])
    ax.set_ylim(0, 1)
    ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8])
    ax.set_ylabel(r"$f_\star$")
    ax.label_outer()

# Colour of each component, indexed like the bars; looked up once rather
# than rebuilding the list for every bar and label.
component_colors = list(settings.component_colors.values())

for i, simulation in enumerate(simulations):
    fractions = get_statistics(read_data(simulation))
    # Indexing the flat iterator avoids copying the axes array each time and
    # still raises IndexError if there are more simulations than panels.
    current_ax = axs.flat[i]
    current_ax.text(
        x=0.05, y=0.95, size=6.0,
        s=r"$\texttt{" + simulation.upper() + "}$",
        ha="left", va="top", transform=current_ax.transAxes)
    # Component names are written only under the bottom row of panels.
    is_last_row = current_ax.get_subplotspec().is_last_row()

    for j, fraction in enumerate(fractions):
        color = component_colors[j]
        current_ax.bar(x=j, height=fraction, color=color,
                       width=0.5, linewidth=0)
        percentage = int(np.round(100 * fraction, 0))
        # Raw string so "\%" is a literal backslash-percent for LaTeX rather
        # than an invalid Python escape sequence.
        current_ax.text(j, fraction + 0.025,
                        s=r"$\textbf{" + str(percentage) + r"\%}$",
                        c=color, ha="center", va="bottom", size=5.0)

        if is_last_row:
            current_ax.text(j, -0.05, size=6.0,
                            s=r"$\textbf{" + settings.components[j] + "}$",
                            c=color, ha="center", va="top")

# Save and close only after every panel has been drawn (previously this ran
# inside the loop, rewriting the PDF once per simulation).
fig.savefig("../images/galaxy_decomposition/stellar_mass_distribution.pdf")
plt.close(fig)