In [None]:
import numpy as np
import matplotlib.pyplot as plt

from lbg_desc_forecast import load_joint_forecast, fig_dir

In [None]:
def plot_constraints(ax, p_to_marg, clustering, xcorr, lensing, clean=True):
    """Plot (w0, wa) contours and relative figures of merit for one scenario.

    Loads the ``srd``, ``lbg_solo`` and ``joint`` forecasts via
    ``load_joint_forecast`` for the requested data-vector combination and draws
    their 2D (w0, wa) contours on ``ax``:

    - LBG-only in colour C1;
    - SRD in black, dashed;
    - joint in colour C0.

    The panel is annotated with each forecast's (w0, wa) figure of merit divided
    by the SRD one. The SRD value is therefore always 1.00.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    p_to_marg : str
        Which per-band systematic family to leave free (marginalize over).
        Valid values:

        - one of ``"dz"``, ``"f_interlopers"``, ``"g_bias"``,
          ``"g_bias_inter"``, ``"mag_bias"``: that family is free, every other
          family is fixed for all of the u, g, r bands;
        - ``"none"``: fix every family;
        - ``"all"``: fix nothing.
    clustering, xcorr, lensing : bool
        Which data-vector pieces to include; forwarded to
        ``load_joint_forecast``.
    clean : bool, default True
        Forwarded to ``load_joint_forecast``.

    Raises
    ------
    ValueError
        If ``p_to_marg`` is not one of the options above. Without this check a
        typo would silently fix every family, i.e. behave like ``"none"``.
    """
    p_to_fix = ["dz", "f_interlopers", "g_bias", "g_bias_inter", "mag_bias"]
    if p_to_marg not in ("none", "all", *p_to_fix):
        raise ValueError(
            f"p_to_marg must be 'none', 'all', or one of {p_to_fix}; got {p_to_marg!r}"
        )

    # Fix every per-band systematic except the family being marginalized
    # ("all" fixes nothing; "none" matches no family, so everything is fixed).
    fix = (
        []
        if p_to_marg == "all"
        else [f"{band}_{p}" for p in p_to_fix if p != p_to_marg for band in "ugr"]
    )

    # Get forecast for configuration
    srd, lbg_solo, joint = load_joint_forecast(
        10,  # NOTE(review): presumably the survey year — confirm in load_joint_forecast
        clean=clean,
        fix=fix,
        clustering=clustering,
        xcorr=xcorr,
        lensing=lensing,
        set_prior=False,
    )

    params = ["w0", "wa"]
    ax.plot(*lbg_solo.contour_2d(params), c="C1", lw=0.5)
    ax.plot(*srd.contour_2d(params), c="k", ls="--", lw=0.5)
    ax.plot(*joint.contour_2d(params), c="C0", lw=0.5)

    # Figures of merit relative to the SRD forecast (computed once).
    fom_srd_abs = srd.figure_of_merit(params)
    fom_srd = 1.0
    fom_lbg = lbg_solo.figure_of_merit(params) / fom_srd_abs
    fom_jnt = joint.figure_of_merit(params) / fom_srd_abs

    # Stack the FoM labels in the lower-left corner, in axes coordinates,
    # coloured to match their contours.
    settings = dict(transform=ax.transAxes, va="center", ha="left", fontsize=8)
    for y, fom, color in zip(
        (0.32, 0.2, 0.08), (fom_srd, fom_lbg, fom_jnt), ("k", "C1", "C0")
    ):
        ax.text(0.03, y, f"{fom:.2f}", c=color, **settings)

In [None]:
# Rows: systematic scenario passed to `plot_constraints` as `p_to_marg`,
# paired with the label drawn on the right-hand edge of that row.
SYSTEMATIC_ROWS = [
    ("none", "no\nsystematics"),
    ("g_bias", "LBG bias"),
    ("mag_bias", "magnification\nbias"),
    ("dz", r"$\Delta z$"),  # raw string: "\D" is an invalid escape otherwise
    ("f_interlopers", "interloper\nfraction"),
    ("g_bias_inter", "interloper\nbias"),
    ("all", "fiducial"),
]

# Columns: panel title and the data-vector flags for that dataset combination.
DATASET_COLUMNS = [
    (r"$g \times g$", dict(clustering=True, xcorr=False, lensing=False)),
    (r"$\kappa_\text{CMB} \times g$", dict(clustering=False, xcorr=True, lensing=False)),
    (r"$2 \times 2\,$pt", dict(clustering=True, xcorr=True, lensing=False)),
    (r"$3 \times 2\,$pt", dict(clustering=True, xcorr=True, lensing=True)),
]

fig, axes = plt.subplots(
    len(SYSTEMATIC_ROWS),
    len(DATASET_COLUMNS),
    figsize=(5, 8),
    constrained_layout=True,
    sharex=True,
    sharey=True,
    dpi=400,
)

# Set axis labels
for ax in axes[-1, :]:
    ax.set_xlabel("$w_0$")
for ax in axes[:, 0]:
    ax.set_ylabel("$w_a$")

# Label each systematics row on the right-hand side of the grid.
ylabel_settings = dict(rotation=-90, va="bottom")
for ax, (_, row_label) in zip(axes[:, -1], SYSTEMATIC_ROWS):
    ax.yaxis.set_label_position("right")
    ax.set_ylabel(row_label, **ylabel_settings)

# One column per dataset combination, one row per systematics scenario.
for col, (title, combo) in enumerate(DATASET_COLUMNS):
    axes[0, col].set_title(title)
    for row, (p_to_marg, _) in enumerate(SYSTEMATIC_ROWS):
        plot_constraints(axes[row, col], p_to_marg, **combo)

# Axes are shared, so these limits apply to every panel.
axes[0, 0].set(
    xlim=(-1.16, -0.84),
    ylim=(-0.55, 0.55),
)

# Save before displaying so the written file never depends on what the
# backend does with the figure once it has been shown.
fig.savefig(fig_dir / "forecast_systematics.pdf", bbox_inches="tight")
plt.show()

In [None]:
fig, axes = plt.subplots(
    2, 4, figsize=(6, 3.5), constrained_layout=True, sharex=True, sharey=True, dpi=400
)

# Every panel uses the full 3x2pt data vector; panels differ only in which
# systematic family is marginalized over.
combo = dict(clustering=True, xcorr=True, lensing=True)

# (row, col) of each panel, its title, and the `p_to_marg` value to plot.
# The top-right slot (0, 3) is unused and hidden at the end of the cell.
summary_panels = [
    ((0, 0), "No systematics", "none"),
    ((0, 1), "LBG galaxy bias", "g_bias"),
    ((0, 2), "Magnification bias", "mag_bias"),
    ((1, 0), r"$\Delta z$", "dz"),  # raw string: "\D" is an invalid escape otherwise
    ((1, 1), "Low-$z$ interlopers", "f_interlopers"),
    ((1, 2), "Low-$z$ galaxy bias", "g_bias_inter"),
    ((1, 3), "Everything", "all"),
]
for (row, col), title, p_to_marg in summary_panels:
    axes[row, col].set_title(title)
    plot_constraints(axes[row, col], p_to_marg, **combo)

# Axes are shared, so these limits apply to every panel.
axes[0, 0].set(
    xlim=(-1.16, -0.84),
    ylim=(-0.55, 0.55),
)
for ax in axes[:, 0]:
    ax.set(ylabel="$w_a$")
for ax in axes[1, :]:
    ax.set(xlabel="$w_0$")

axes[0, -1].set_axis_off()