In [None]:
import os
from itertools import product
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from joblib import Parallel, delayed
from tqdm.auto import tqdm

# Apply the project-wide matplotlib style sheet. The path is relative, so it is
# resolved against the kernel's current working directory.
plt.style.use(style="../project/nbody6/plot/style.mplstyle")


In [None]:
# Plot metadata for every annular metric, in display order.
#   key   -> metric column name in the annular-stats tables
#   value -> (LaTeX symbol, LaTeX defining formula or None)
# A `None` formula marks a raw count (plotted on a log axis); fractions carry
# their defining ratio so it can be shown in the figure title.
metric_dict = dict(
    n_binary=(r"N_\mathrm{bin.\,sys.}", None),
    n_hard_binary=(r"N_\mathrm{hard\;bin.\,sys.}", None),
    n_unresolved_binary=(r"N_\mathrm{unres.\,bin.\,sys.}", None),
    n_wide_binary=(r"N_\mathrm{wide\,bin.\,sys.}", None),
    bin_frac=(
        r"f_\mathrm{bin.\,sys.}",
        r"\dfrac{N_\mathrm{bin.\,sys.}}{N_\mathrm{bin.\,sys.}+N_\mathrm{single}}",
    ),
    hard_bin_frac=(
        r"f_\mathrm{hard\;bin.\,sys.}",
        r"\dfrac{N_\mathrm{hard\;bin.\,sys.}}{N_\mathrm{bin.\,sys.}+N_\mathrm{single}}",
    ),
    unresolved_bin_frac=(
        r"f_\mathrm{unres.\,bin.\,sys.}",
        r"\dfrac{N_\mathrm{unres.\,bin.\,sys.}}{N_\mathrm{bin.\,sys.}+N_\mathrm{single}}",
    ),
)


## Prepare Timestamp-Aligned Annular Stats DataFrame

In [None]:
load_dotenv()

# OUTPUT_BASE comes from the environment or a .env file. Fail with a clear
# message instead of the opaque TypeError that Path(None) would raise.
_output_base_env = os.getenv("OUTPUT_BASE")
if not _output_base_env:
    raise RuntimeError(
        "OUTPUT_BASE is not set; define it in the environment or in a .env file"
    )
OUTPUT_BASE = Path(_output_base_env)
annular_stats_root_path = (OUTPUT_BASE / "annular_stats").resolve()

# Cached, timestamp-aligned result; delete this file to force a rebuild.
aligned_parquet_path = OUTPUT_BASE / "aligned_annular_stats.parquet"

# Initial-condition columns identifying a simulation run.
cat_keys = [
    "init_gc_radius",
    "init_metallicity",
    "init_mass_lv",
    "init_pos",
]

if aligned_parquet_path.exists():
    aligned_df = pd.read_parquet(aligned_parquet_path)
else:
    # The raw per-run CSVs are only needed when the cache has to be rebuilt.
    if not annular_stats_root_path.is_dir():
        raise FileNotFoundError(
            f"{annular_stats_root_path} does NOT exist or is NOT a directory"
        )
    # sorted() makes the concatenation order independent of the filesystem.
    csv_paths = sorted(annular_stats_root_path.glob("*.csv"))
    if not csv_paths:
        raise FileNotFoundError(f"no *.csv files found in {annular_stats_root_path}")

    full_annular_stats_df = (
        pd.concat([pd.read_csv(p) for p in csv_paths], ignore_index=True)
        .sort_values(by=cat_keys)
        .reset_index(drop=True)
    )

    # Keep only the "dist_dc_r_half_mass" radial binning and derive the
    # fraction metrics listed in metric_dict (denominator: binaries + singles).
    filtered_annular_stats_df = (
        full_annular_stats_df[
            full_annular_stats_df["dist_key"] == "dist_dc_r_half_mass"
        ]
        .assign(
            bin_frac=lambda df: df["n_binary"] / (df["n_binary"] + df["n_single"]),
            hard_bin_frac=lambda df: df["n_hard_binary"]
            / (df["n_binary"] + df["n_single"]),
            unresolved_bin_frac=lambda df: df["n_unresolved_binary"]
            / (df["n_binary"] + df["n_single"]),
        )
        .astype({col: "category" for col in cat_keys})
    )
    # Free the memory held by the unfiltered table.
    del full_annular_stats_df

    # Common unit-step timestamp grid from 0 up to the latest snapshot.
    uni_timestamp_grid = np.arange(
        0, filtered_annular_stats_df["timestamp"].max() + 1, 1
    )

    # Columns identifying one resampling group (order matches the groupby key tuple).
    group_cols = ["galactic_x", "radius"] + cat_keys

    def process_group(group_data):
        """Resample one group's metrics onto ``uni_timestamp_grid``.

        Metrics are averaged per timestamp, linearly interpolated in timestamp
        onto the common grid (``limit_area="inside"``: no extrapolation beyond
        the group's first/last snapshot) and tagged with the group's key columns.

        Parameters
        ----------
        group_data : tuple
            ``(group_key, group_df)`` as yielded by ``DataFrame.groupby``;
            ``group_key`` is ordered like ``group_cols``.

        Returns
        -------
        pandas.DataFrame
            One row per grid timestamp with the metric columns, the key
            columns and a ``timestamp`` column.
        """
        group_key, group_df = group_data
        return (
            group_df.groupby("timestamp", observed=True)[list(metric_dict.keys())]
            .mean()
            # Union with the grid so interpolation can anchor on the real snapshots.
            .reindex(np.union1d(group_df["timestamp"].to_numpy(), uni_timestamp_grid))
            .sort_index()
            .interpolate("index", limit_area="inside")
            .reindex(uni_timestamp_grid)
            .assign(
                **dict(zip(group_cols, group_key)),
                timestamp=lambda df: df.index.values,
            )
        )

    groups = list(
        filtered_annular_stats_df.groupby(
            group_cols,
            observed=True,
            sort=False,
            group_keys=False,
        )
    )
    # os.cpu_count() may return None when the count cannot be determined.
    n_cpu = os.cpu_count() or 1
    print(f"Using {n_cpu} CPU cores for parallel processing")
    aligned_dfs = Parallel(n_jobs=n_cpu)(
        delayed(process_group)(group)
        for group in tqdm(groups, desc="submit processing tasks", leave=False)
    )
    # Drop grid points where every metric is NaN (e.g. outside a group's time span).
    aligned_df = pd.concat(aligned_dfs, ignore_index=True).dropna(
        subset=list(metric_dict.keys()), how="all"
    )
    # Save to parquet so later runs can skip the rebuild.
    aligned_df.to_parquet(aligned_parquet_path, index=False, compression="zstd")

In [None]:
# Summary statistics across `init_pos`, the only initial-condition key omitted
# from the groupby below.
#
# The result keeps pandas' two-level column MultiIndex: key columns appear as
# e.g. ("radius", "") and statistics as e.g. ("bin_frac", "mean"). Top-level
# access such as agg_df["radius"] still returns a Series, and the plotting code
# indexes the statistics with (metric, stat) tuples.
agg_df = (
    aligned_df.groupby(
        [
            "galactic_x",
            "init_gc_radius",
            "init_metallicity",
            "init_mass_lv",
            "radius",
            "timestamp",
        ],
        observed=True,
        sort=False,
    )[list(metric_dict.keys())]
    .agg(["mean", "std", "median"])
    .reset_index()
)
assert isinstance(agg_df.columns, pd.MultiIndex), "plotting expects (metric, stat) columns"

In [None]:
# Print a few quantiles of the radial bin values to gauge the radial coverage.
for quantile in (0.1, 0.25, 0.5, 0.75, 0.9):
    q_label = f"{quantile}"
    radius_label = f"{agg_df['radius'].quantile(quantile)}"
    print(f"{q_label:>5} quantile radius -> {radius_label:>4}")

## Visualize Annular Statistics Over Time

In [None]:
# Snapshot times (Myr) at which radial profiles are plotted.
timestamps_to_plot = [0, 10, 50, 100, 200, 300]

# ---- loop-invariant figure configuration ----------------------------------
# Mass levels in drawing order (8 first, 1 last); also used as colorbar ticks.
mass_levels = list(range(1, 9))[::-1]
# (init_metallicity, init_gc_radius) for each panel, filling the 2x3 grid row by row.
attr_pairs = [(2, 4), (2, 8), (6, 8), (14, 4), (14, 8), (14, 12)]
# Discrete palette with one colour per integer mass level 1..8.
cmap = mpl.colors.ListedColormap(
    [
        "#b2182b",
        "#d6604d",
        "#f4a582",
        "#fddbc7",
        "#d1e5f0",
        "#92c5de",
        "#4393c3",
        "#2166ac",
    ]
)
norm = mpl.colors.BoundaryNorm(boundaries=np.arange(0.5, 8.5 + 1, 1), ncolors=cmap.N)


def _y_axis_config(metric_label_tex, metric_formula_tex):
    """Return ``(title_tex, y_scale, y_range, make_major, make_minor)`` for a metric.

    Counts (``metric_formula_tex is None``) get a log y-axis; fractions get a
    linear 0-1 axis and their defining formula in the title. Tick locators are
    returned as zero-argument factories because a matplotlib Locator must not
    be shared between several axes.
    """
    if metric_formula_tex is None:
        return (
            metric_label_tex,
            "log",
            (0.9, 4000),
            lambda: mpl.ticker.LogLocator(base=10.0, subs=[1.0]),
            lambda: mpl.ticker.LogLocator(base=10.0, subs=np.arange(2, 10) * 0.1),
        )
    return (
        f"{metric_label_tex}={metric_formula_tex}",
        "linear",
        (0, 1),
        lambda: mpl.ticker.MultipleLocator(0.2),
        lambda: mpl.ticker.MultipleLocator(0.1),
    )


def _draw_mass_curves(ax, subgroup, metric_to_plot, radii):
    """Plot mean +/- std of ``metric_to_plot`` versus radius, one curve per mass level."""
    for mass_lv in mass_levels:
        mass_rows = subgroup[subgroup["init_mass_lv"] == mass_lv]
        # Test emptiness BEFORE reindexing: reindex(radii) turns an empty
        # selection into all-NaN rows, so a later check would never skip.
        if mass_rows.empty:
            continue
        mass_df = mass_rows.set_index("radius").reindex(radii)
        ax.errorbar(
            radii,
            mass_df[(metric_to_plot, "mean")].values,
            yerr=mass_df[(metric_to_plot, "std")].values,
            fmt="o-",
            capsize=3,
            color=cmap(norm(mass_lv)),
            alpha=0.8,
            label=f"M={mass_lv}",
        )


# Leave exactly one figure open so the notebook shows a single example; every
# other figure is closed right after saving to bound memory use.
demo_figure_shown = False

for timestamp_to_plot, metric_to_plot in (
    metric_pbar := tqdm(
        product(timestamps_to_plot, list(metric_dict.keys())),
        total=len(timestamps_to_plot) * len(metric_dict),
        leave=False,
        dynamic_ncols=True,
    )
):
    metric_pbar.set_description(f"Plotting `{metric_to_plot}` @ {timestamp_to_plot}Myr")

    # Export directory: one folder per (metric, timestamp).
    fig_export_path = (
        OUTPUT_BASE / "figures" / "annular" / f"{metric_to_plot}-{timestamp_to_plot}Myr"
    )
    fig_export_path.mkdir(parents=True, exist_ok=True)

    plot_df = agg_df[agg_df["timestamp"] == timestamp_to_plot].copy()
    # Shared x positions so every curve in every panel uses the same radii.
    radii = sorted(plot_df["radius"].unique())

    metric_label_tex, metric_formula_tex = metric_dict[metric_to_plot]
    metric_title, y_scale, y_plot_range, make_major, make_minor = _y_axis_config(
        metric_label_tex, metric_formula_tex
    )

    for dist_pc, group_df in (
        dist_pbar := tqdm(
            plot_df.groupby("galactic_x", observed=True, sort=False),
            total=plot_df["galactic_x"].nunique(),
            leave=False,
            dynamic_ncols=True,
        )
    ):
        dist_pbar.set_description(f"Plotting distance={dist_pc} pc")

        fig, axs = plt.subplots(
            nrows=2,
            ncols=3,
            figsize=(17, 8),
            dpi=300,
            constrained_layout=True,
            gridspec_kw=dict(hspace=0.06, wspace=0.08),
        )
        fig.suptitle(
            rf"${metric_title}$ @{timestamp_to_plot}Myr, "
            rf"$d_\mathrm{{observer}}={dist_pc}\,\mathrm{{pc}}$",
            fontsize=24,
            y=1.04,
            va="bottom",
        )

        for (init_metallicity, init_gc_radius), ax in zip(attr_pairs, axs.flat):
            subgroup = group_df[
                (group_df["init_metallicity"] == init_metallicity)
                & (group_df["init_gc_radius"] == init_gc_radius)
            ]
            _draw_mass_curves(ax, subgroup, metric_to_plot, radii)

            # Panel formatting: applied once per panel, including panels without data.
            ax.set_xlim(0, 17)
            ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(3, offset=1))
            ax.xaxis.set_minor_locator(mpl.ticker.MultipleLocator(1))
            ax.set_xlabel(r"$d\,/\,r_\mathrm{hm}$")

            ax.set_yscale(y_scale)
            ax.set_ylim(*y_plot_range)
            # Fresh locator instances per axis (locators must not be shared).
            ax.yaxis.set_major_locator(make_major())
            ax.yaxis.set_minor_locator(make_minor())
            ax.set_ylabel(rf"$\langle \mathrm{{{metric_label_tex}}} \rangle$")

            # Z_init = init_metallicity * 1e-3 (same scaling as before, clean formatting).
            ax.set_title(
                rf"$Z_{{init.}}={init_metallicity * 1e-3:g},\;"
                rf"R_\mathrm{{gc,\,init.}}={init_gc_radius}\ \mathrm{{kpc}}$",
                fontsize=20,
                y=1.02,
            )
            ax.grid(ls=":", lw=0.8, c="darkgrey")

        sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
        cbar = fig.colorbar(
            sm,
            ax=axs,
            orientation="vertical",
            fraction=0.08,
            pad=0.03,
            ticks=mass_levels,
        )
        cbar.set_label(r"$M_\mathrm{tot.\,init.}$", fontsize=20)
        cbar.ax.tick_params(direction="in", length=3)

        # Save first, then decide whether to keep the figure open for display.
        fig.savefig(
            fig_export_path
            / f"annular-{metric_to_plot}-{timestamp_to_plot}Myr-{dist_pc}pc.png",
            bbox_inches="tight",
            dpi=300,
        )
        if demo_figure_shown:
            plt.close(fig)
        else:
            demo_figure_shown = True