In [None]:
import numpy as np
import matplotlib.pyplot as plt

from lbg_desc_forecast import (
    FisherMatrix,
    fig_dir,
    get_lbg_mappers,
    fisher_dir,
)

In [None]:
# DESI best-fit dark-energy parameters, used below as
# w(z) = w0 + wa * (1 - 1 / (1 + z))
w0, wa = -0.752, -0.86


def create_bin_data(year: int, bin: str) -> tuple[float, float, float]:
    """Build one forecast point (redshift, w, variance) for a dropout bin.

    Loads the fiducial Fisher matrix for the given year and bin, applies
    priors on the bin's nuisance parameters, fixes ``m_nu`` and ``wa``, and
    marginalizes down to the single last remaining parameter.

    Parameters
    ----------
    year : int
        Survey year; selects the Fisher-matrix file and the LBG mappers.
        Only ``year == 10`` with ``bin == "u"`` gets a finite ``dz`` prior.
    bin : str
        Dropout band, one of ``"u"``, ``"g"``, ``"r"``.

    Returns
    -------
    tuple[float, float, float]
        ``(z_mean, w, variance)`` where ``z_mean`` is the p(z)-weighted mean
        redshift of the bin restricted to z > 1.5, ``w`` is
        ``w0 + wa * (1 - 1 / (1 + z_mean))`` using the module-level ``w0`` and
        ``wa``, and ``variance`` is element ``[0, 0]`` of the marginalized
        covariance. NOTE: this is a variance, not a standard deviation --
        take the square root before using it as an error bar.

    Raises
    ------
    KeyError
        If ``bin`` is not one of ``"u"``, ``"g"``, ``"r"``.
    """
    # Load Fisher Matrix
    lbg = FisherMatrix.load(
        fisher_dir / f"y{year}_fiducial_{bin}bin.fisher_matrix.npz",
    )

    # Set priors on nuisance params
    # - Interloper fraction priors are ad hoc guesses
    # - Interloper bias priors are described in default_lbg
    # - Mag bias are faint-end slope uncertainties from GOLDRUSH IV (Table 6)
    lbg.set_prior(
        **{
            f"{bin}_dz": 0.004 if (year == 10 and bin == "u") else np.inf,
            f"{bin}_f_interlopers": 0.05,
            f"{bin}_g_bias_inter": 0.10 if bin == "u" else 0.11,
            f"{bin}_mag_bias": dict(u=0.02, g=0.03, r=0.04)[bin],
        }
    )

    # Ignore neutrino mass and wa
    lbg = lbg.fix(["m_nu", "wa"])

    # Marginalize over all LBG nuisance params
    # (assumes everything from index 6 onward is a nuisance param -- TODO confirm)
    lbg = lbg.marginalize(lbg.keys[6:])

    # Marginalize everything else except for w (the last remaining key)
    lbg = lbg.marginalize(lbg.keys[:-1])

    # Get the mapper; get_lbg_mappers is paired with the bands in "ugr" order
    mapper = dict(zip("ugr", get_lbg_mappers(year)))[bin]

    # Get the mean redshift, using only the part of p(z) above z = 1.5
    z, pz = mapper.tomographic_bin.pz
    high_z = z > 1.5
    z, pz = z[high_z], pz[high_z]

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to the old name on older NumPy versions.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    z_mean = trapezoid(z * pz, z) / trapezoid(pz, z)

    # Get w(z)
    w = w0 + wa * (1 - 1 / (1 + z_mean))

    return z_mean, w, lbg.covariance[0, 0]

In [None]:
# Forecast points keyed by survey year, then by dropout band ("u", "g", "r").
# Built with a comprehension so that no loop variable named `bin` is left
# behind at notebook scope, where it would shadow the builtin `bin()`.
forecast = {
    year: {band: create_bin_data(year, band) for band in "ugr"}
    for year in (1, 10)
}

In [None]:
fig, ax = plt.subplots(figsize=(4, 3), constrained_layout=True, dpi=150)

# DESI DR2 points; columns are plotted as (redshift, w, error bar on w)
desi = np.array(
    [
        [0.26, -0.85, 0.10],
        [0.79, -1.16, 0.14],
        [1.31, -1.48, 0.51],
        [1.83, -1.09, 1.125],
        [2.79, -1.94, 0.42],
    ]
)
ax.errorbar(
    desi[:, 0],
    desi[:, 1],
    yerr=desi[:, 2],
    ls="",
    capsize=5,
    c="k",
    label="DESI DR2",
)

# LSST forecasts: one colour per survey year, one point per dropout band.
# Each forecast entry is (z_mean, w, variance); errorbar expects a standard
# deviation, so take the square root of the variance before plotting.
for year, color in ((1, "C0"), (10, "C1")):
    for i, (z_mean, w_mean, w_var) in enumerate(forecast[year].values()):
        ax.errorbar(
            z_mean,
            w_mean,
            yerr=np.sqrt(w_var),
            c=color,
            capsize=5,
            ls="",
            # Only label the first point of each year so the legend has one entry
            label=f"LSST Y{year}" if i == 0 else None,
        )

# Best-fit w(z) curve from the w0/wa values defined above
z = np.linspace(0, 5, 1000)
w = w0 + wa * (1 - 1 / (1 + z))
ax.plot(z, w, color="gray")

# Reference line for a cosmological constant (w = -1)
ax.axhline(-1, c="silver", ls="--", zorder=0)

ax.set(xlabel="Redshift", ylabel="$w(z)$", xlim=(0, 5), ylim=(-2.5, 0.1))
ax.legend(frameon=False, loc="upper left")

fig.savefig(fig_dir / "binned_w_constraints.pdf", bbox_inches="tight")