In [None]:
import numpy as np
import matplotlib.pyplot as plt
import contextlib
import os

import rubin_sim.maf as maf
from lbg_tools import TomographicBin

from u_band_strat import (
    single_col,
    colors,
    fig_dir,
    data_dir,
    det_bands,
    calc_lbg_density_metric,
)

In [None]:
# Grid of 5-sigma depths at which each band's number-density curve is sampled
m5 = np.linspace(23, 28, 100)

# One entry per survey stage: (key stored in `metrics`, sub-directory of
# data_dir that holds the maps, file-name prefix identifying the run)
SURVEY_STAGES = (
    ("y1", "m5_maps", "baseline_v4_0_1yrs"),
    ("y10", "m5_maps", "baseline_v4_0_10yrs"),
    ("cosmos", "m5_maps_deep", "COSMOS_10yrs"),
)


def load_exgal_m5(subdir, prefix, band):
    """Return the metric values saved in one ExgalM5 map file.

    Parameters
    ----------
    subdir : str
        Directory under ``data_dir`` that contains the map.
    prefix : str
        Leading part of the file name, e.g. ``"baseline_v4_0_1yrs"``.
    band : str
        Band whose map is loaded.

    Returns
    -------
    The ``metric_values`` attribute of the loaded ``maf.MetricBundle``.
    """
    path = data_dir / subdir / f"{prefix}_ExgalM5_{band}.npz"
    # Send stdout to os.devnull while loading so nothing the loader prints
    # ends up in the cell output
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        return maf.MetricBundle.load(path).metric_values


metrics = {band: dict() for band in "ugriz"}
for band in metrics:
    # Get the detection band
    det_band = det_bands[band]

    # Calculate number density curve
    metrics[band]["n"] = np.array(
        [TomographicBin(band, m5i).number_density for m5i in m5]
    )

    # Density at each survey stage, computed from the map for `band` (drop)
    # and the map for its detection band (det)
    for stage, subdir, prefix in SURVEY_STAGES:
        drop = load_exgal_m5(subdir, prefix, band)
        det = load_exgal_m5(subdir, prefix, det_band)
        metrics[band][stage] = calc_lbg_density_metric(band, drop, det)

In [None]:
# Figure: number density versus 5-sigma depth for every band, with one marker
# per survey stage placed on the corresponding curve.
fig, ax = plt.subplots(figsize=single_col, dpi=150)

for band in metrics:
    # Plot number density curve
    ax.plot(m5, metrics[band]["n"], c=colors[band], label=band)

    # Plot markers for each survey stage
    for stage, marker, size in zip(
        ["y1", "y10", "cosmos"],
        ["|", ".", "*"],
        [100, 50, 25],
    ):
        ax.scatter(
            # Invert the curve: find the depth at which it reaches this
            # stage's density. np.interp expects its x-coordinates (here the
            # density curve) to be increasing -- assumes density rises
            # monotonically with depth; TODO confirm.
            np.interp(metrics[band][stage], metrics[band]["n"], m5),
            metrics[band][stage],
            c=colors[band],
            marker=marker,
            s=size,
        )

ax.set(
    yscale="log",
    xlim=(m5.min(), m5.max()),
    ylim=(1e-3, 1e5),
    # Raw string: "\s" is not a valid escape sequence in a normal string
    # literal and triggers a SyntaxWarning on recent Python versions
    xlabel=r"5$\sigma$ depth in detection band",
    ylabel="Number density",
)
ax.legend(handlelength=1, frameon=False, fontsize=8, loc="lower right")

fig.savefig(fig_dir / "number_density.pdf", bbox_inches="tight")