# Star Formation

In [17]:
import matplotlib.pyplot as plt
import pandas as pd
from scipy.signal import savgol_filter
from multiprocessing import Pool
import os

In [18]:
# Restrict MKL, numexpr and OpenMP to one thread each, presumably so that the
# multiprocessing.Pool workers used further down do not oversubscribe the CPUs.
# NOTE(review): these variables are normally only honoured when the numerical
# libraries are first loaded. pandas and scipy (imported in the previous cell)
# presumably already import numpy, in which case the `import numpy` below is a
# cached no-op and these settings may have no effect. Consider setting them
# before any numerical import — TODO confirm.
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import numpy as np

In [19]:
from auriga.snapshot import Snapshot
from auriga.images import figure_setup
from auriga.settings import Settings
from auriga.support import make_snapshot_number
from auriga.paths import Paths
from auriga.parser import parse

In [20]:
# Project helper from auriga.images; presumably applies the shared matplotlib
# styling used by the figure below.
figure_setup()
# Global settings object. Later cells read the region-tagging thresholds
# (disc_std_circ, disc_min_circ, cold_disc_delta_circ) and the plotting
# metadata (components, component_colors, component_labels) from it.
settings = Settings()

In [21]:
# Thresholds passed to Snapshot.tag_particles_by_region. The three
# circularity cuts are taken from the shared settings object.
DISC_STD_CIRC, DISC_MIN_CIRC, COLD_DISC_DELTA_CIRC = (
    settings.disc_std_circ,
    settings.disc_min_circ,
    settings.cold_disc_delta_circ,
)

# Bulge specific-energy cut, set locally rather than read from settings.
# NOTE(review): the normalisation/units of this energy are not visible in
# this notebook — confirm against Snapshot.tag_particles_by_region.
BULGE_MAX_SPECIFIC_ENERGY = -0.6

# Appended to the output PDF name, presumably to tell apart figures made
# with different threshold choices.
SUFFIX = "_02"

In [22]:
def calculate_sfr_by_region(simulation: str,
                            first_snapshot: int = 40) -> np.ndarray:
    """Compute the star formation rate of each galactic region for one snapshot.

    Parameters
    ----------
    simulation : str
        Snapshot identifier understood by ``parse``, i.e. a simulation name
        followed by ``_s<number>`` (e.g. ``"au1_or_l4_s40"``).
    first_snapshot : int, optional
        Snapshots numbered below this are not analysed and yield NaN.
        Defaults to 40.

    Returns
    -------
    np.ndarray
        The four SFR values in the order returned by
        ``Snapshot.calculate_sfr_by_region`` (stored by
        ``calculate_sfr_by_region_evolution`` as the ``_H``, ``_B``, ``_CD``
        and ``_WD`` columns), or four NaNs for skipped snapshots.
    """
    _, _, _, snapshot = parse(simulation)

    # Early snapshots are skipped. The threshold is explicit because
    # settings.first_snap doesn't work here.
    if snapshot < first_snapshot:
        return np.full(4, np.nan)

    s = Snapshot(simulation, loadonlytype=[0, 1, 2, 3, 4, 5])
    s.tag_particles_by_region(
        disc_std_circ=DISC_STD_CIRC,
        disc_min_circ=DISC_MIN_CIRC,
        cold_disc_delta_circ=COLD_DISC_DELTA_CIRC,
        bulge_max_specific_energy=BULGE_MAX_SPECIFIC_ENERGY)
    return np.array(s.calculate_sfr_by_region())

In [23]:
def calculate_sfr_by_region_evolution(simulation: str) -> pd.DataFrame:
    """Build the per-region SFR time series for one simulation.

    Every snapshot of ``simulation`` is processed in parallel with
    ``calculate_sfr_by_region``. The results are combined with the snapshot
    times read from the simulation's ``temporal_data.csv``.

    Parameters
    ----------
    simulation : str
        Simulation identifier without a snapshot suffix (e.g. ``"au1_or_l4"``).

    Returns
    -------
    pd.DataFrame
        One row per snapshot with the per-region columns ``SFR_Msun/yr_H``,
        ``SFR_Msun/yr_B``, ``SFR_Msun/yr_CD`` and ``SFR_Msun/yr_WD``, their
        sum ``SFR_Msun/yr`` (NaN for snapshots that were skipped), and
        ``Time_Gyr``.
    """
    galaxy, rerun, resolution = parse(simulation)
    n_snapshots = make_snapshot_number(rerun, resolution)
    snapnums = [f"{simulation}_s{i}" for i in range(n_snapshots)]

    # The context manager terminates the worker processes as soon as the
    # (blocking) map returns. Without it they linger until the Pool is
    # garbage-collected, and this function is called once per galaxy.
    # NOTE(review): mapping a function defined in a notebook relies on the
    # "fork" start method; with "spawn" the workers cannot import it.
    with Pool() as pool:
        sfr = np.array(pool.map(calculate_sfr_by_region, snapnums))

    # Read the time of each snapshot.
    paths = Paths(galaxy, rerun, resolution)
    times = pd.read_csv(
        f"../{paths.results}temporal_data.csv", usecols=["Time_Gyr"])

    # Column order follows the order of the values returned per snapshot.
    data = {"SFR_Msun/yr_H": sfr[:, 0],
            "SFR_Msun/yr_B": sfr[:, 1],
            "SFR_Msun/yr_CD": sfr[:, 2],
            "SFR_Msun/yr_WD": sfr[:, 3],
            "SFR_Msun/yr": np.sum(sfr, axis=1),
            "Time_Gyr": times["Time_Gyr"].to_numpy()}

    return pd.DataFrame(data)

In [24]:
# Savitzky-Golay smoothing applied to every plotted SFR curve.
SAVGOL_WINDOW = 5
SAVGOL_POLYORDER = 1

fig = plt.figure(figsize=(7.2, 7.2))
gs = fig.add_gridspec(nrows=6, ncols=5, hspace=0.0, wspace=0.0)
axs = gs.subplots(sharex=True, sharey=True)

# Shared axis formatting; label_outer() keeps tick and axis labels only on
# the outer panels of the shared grid.
for ax in axs.flatten():
    ax.set_xlim(0, 14)
    ax.set_ylim(0.1, 20)
    ax.set_xticks([2, 4, 6, 8, 10, 12])
    ax.set_yscale("log")
    ax.set_yticks([0.1, 1, 10])
    ax.set_yticklabels([0.1, 1, 10])
    ax.set_xlabel("Time [Gyr]")
    ax.set_ylabel(r"SFR [$\mathrm{M}_\odot \, \mathrm{yr}^{-1}$]")
    ax.label_outer()

# One panel per galaxy (Au1 to Au30), filled row by row.
for idx, ax in enumerate(axs.flatten()):
    galaxy = idx + 1
    simulation = f"au{galaxy}_or_l4"
    # Skipped snapshots are NaN; drop them so the smoothing only sees
    # valid samples.
    sfr = calculate_sfr_by_region_evolution(simulation).dropna()
    for component in settings.components:
        ax.plot(sfr["Time_Gyr"],
                savgol_filter(sfr[f"SFR_Msun/yr_{component}"],
                              SAVGOL_WINDOW, SAVGOL_POLYORDER),
                c=settings.component_colors[component],
                lw=1.0, label=settings.component_labels[component], zorder=10)
    ax.plot(sfr["Time_Gyr"],
            savgol_filter(sfr["SFR_Msun/yr"], SAVGOL_WINDOW, SAVGOL_POLYORDER),
            c='k', lw=1.0, label="Total", zorder=10)
    ax.text(x=0.95, y=0.95, size=6.0,
            s=r"$\texttt{" + simulation.upper() + "}$",
            ha='right', va='top', transform=ax.transAxes)

    # Only the first (top-left) panel carries the legend.
    if idx == 0:
        ax.legend(loc="upper left", framealpha=0.0, fontsize=4.0)

    ax.xaxis.label.set_size(8.0)
    ax.yaxis.label.set_size(8.0)

# Save once, after every panel has been drawn.
fig.savefig(f"../images/sfr_by_region/originals{SUFFIX}.pdf")

plt.close(fig)