# metrics plot

In [None]:
import matplotlib.pyplot as plt
import mpl_scatter_density
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
from matplotlib import cm
from matplotlib import colors
from matplotlib.colors import ListedColormap
from astropy.visualization import LogStretch
from astropy.visualization.mpl_normalize import ImageNormalize
from scipy.special import softmax
from scipy.stats import gaussian_kde

In [None]:
# Load the test-set identifiers and the spectroscopic / photometric redshifts.
# np.load on an .npz returns a lazily-read archive that keeps the file open,
# so read the three arrays inside a ``with`` block and let it close the file.
with np.load("z_pred.npz") as data:
    test_id = data["test_id"]
    z_spec = data["z_spec"]
    z_phot = data["z_phot"]

In [None]:
# Shared matplotlib styling applied to every figure in this notebook:
# large text, white figure background, inward ticks on all four sides,
# and a serif font for both regular text and math text.
params = {
    **dict.fromkeys(
        (
            "legend.fontsize",
            "axes.labelsize",
            "axes.titlesize",
            "xtick.labelsize",
            "ytick.labelsize",
        ),
        "x-large",
    ),
    "figure.facecolor": "w",
    **dict.fromkeys(("xtick.top", "ytick.right"), True),
    **dict.fromkeys(("xtick.direction", "ytick.direction"), "in"),
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
}
plt.rcParams.update(params)

In [None]:
def hodges_lehmann(data, max_pairs=1e6, random_seed=200):
    """The Hodges-Lehmann estimator.

    The estimate is the median of the pairwise means ``(x_i + x_j) / 2``
    taken over all index pairs ``i <= j`` (each element is also paired
    with itself). When there are too many pairs, a random sample of
    ``max_pairs`` index pairs (drawn with replacement) is used instead.

    Adapted from code written by Rongpu Zhou.

    Args:
        data (1D array-like): Data set for which the estimator is being
            computed.
        max_pairs (int): If the number of distinct pairs ``n (n - 1) / 2``
            is larger than this, randomly sample ``max_pairs`` pairs.
        random_seed (int or None): Seed for randomly sampling pairs. The
            seed is applied to a private generator, so the global NumPy
            random state is left untouched. With ``None`` the global
            ``np.random`` state is used (and advanced).

    Returns:
        float: H-L estimate

    Raises:
        ValueError: If ``data`` is empty.
    """
    # Accept lists and pandas objects as well as arrays; positional fancy
    # indexing below requires a plain ndarray.
    data = np.asarray(data)
    n_data = len(data)
    if n_data == 0:
        raise ValueError("Must pass in non-empty array.")

    max_pairs = int(max_pairs)
    n_pairs = n_data * (n_data - 1) // 2

    if n_pairs <= max_pairs:
        # Upper triangle including the diagonal: every i < j pair plus the
        # i == j self-pairs. Unlike unpacking an itertools.combinations
        # array, this also works for a single-element input.
        ind1, ind2 = np.triu_indices(n_data)
    else:
        if random_seed is None:
            rng = np.random
        else:
            # A RandomState seeded this way draws the same numbers as
            # seeding the global generator, without the side effect.
            rng = np.random.RandomState(random_seed)
        ind1, ind2 = rng.choice(n_data, size=(max_pairs, 2)).transpose()

    pair_means = 0.5 * (data[ind1] + data[ind2])
    return np.median(pair_means)

In [None]:
def better_step(bin_edges, y, ax=None, **kwargs):
    """Draw a histogram-style step line from bin edges and bin heights.

    Each bin contributes two vertices, one at its left edge and one at its
    right edge, both at the bin's height, so the line is flat across every
    bin and jumps vertically at the shared edges. Any extra keyword
    arguments are forwarded unchanged to matplotlib's ``plot``.

    Args:
        bin_edges: The bin edges (a sliceable sequence). This should be one
            element longer than the bin heights ``y``.
        y: The bin heights (a sequence).
        ax (Optional): The axis to draw on; the current axis is used when
            omitted.

    Returns:
        The axis that was drawn on.
    """
    step_x = []
    for left_edge, right_edge in zip(bin_edges[:-1], bin_edges[1:]):
        step_x += [left_edge, right_edge]

    step_y = []
    for height in y:
        step_y += [height, height]

    target = plt.gca() if ax is None else ax
    target.plot(step_x, step_y, **kwargs)
    return target

In [None]:
class Metrics:
    """Compute and plot quality metrics for photometric redshifts.

    On construction the sample is restricted to ``z_min <= z_spec <= z_max``
    and the following attributes are computed for the selected objects:

    * ``delta_z_norm``: normalized residuals ``(z_phot - z_spec) / (1 + z_spec)``.
    * ``sigma_nmad``: ``1.4826 * median(|delta_z_norm - median(delta_z_norm)|)``.
    * ``bias``: mean of ``delta_z_norm``.
    * ``n_outlier`` / ``percent_outlier``: number and percentage of objects
      with ``|delta_z_norm| > outlier_threshold``.

    Args:
        z_phot (array): Predicted photometric redshifts.
        z_spec (array): Measured spectroscopic redshifts.
        z_min (float, optional): Lower spec-z limit of the sample. Defaults
            to the smallest ``z_spec``.
        z_max (float, optional): Upper spec-z limit of the sample. Defaults
            to the largest ``z_spec``.
        outlier_threshold (float, optional): Outlier cut on
            ``|delta_z_norm|``. Defaults to ``DEFAULT_OUTLIER_THRESHOLD``.
        mag (array, optional): Magnitudes, one per input object. They are
            restricted to the same redshift-selected objects as the
            redshifts. May also be assigned to ``Metrics.mag`` later.
        **kwargs: Accepted for backward compatibility and ignored.

    Raises:
        ValueError: If the inputs are empty, have mismatched shapes, or no
            object falls inside ``[z_min, z_max]``.
    """

    # Outlier cut used when the caller does not supply one; 0.05 is the
    # value this notebook passes explicitly everywhere else.
    DEFAULT_OUTLIER_THRESHOLD = 0.05

    def __init__(
        self,
        z_phot,
        z_spec,
        z_min=None,
        z_max=None,
        outlier_threshold=None,
        mag=None,
        **kwargs,
    ):
        z_phot = np.asarray(z_phot)
        z_spec = np.asarray(z_spec)
        if z_phot.shape != z_spec.shape:
            raise ValueError("z_phot and z_spec must have the same shape.")
        if z_spec.size == 0:
            raise ValueError("z_phot and z_spec must not be empty.")

        # The signature advertises None defaults; resolve them instead of
        # failing on a comparison against None.
        if z_min is None:
            z_min = float(np.nanmin(z_spec))
        if z_max is None:
            z_max = float(np.nanmax(z_spec))
        if outlier_threshold is None:
            outlier_threshold = self.DEFAULT_OUTLIER_THRESHOLD

        z_mask = (z_spec >= z_min) & (z_spec <= z_max)
        if not np.any(z_mask):
            raise ValueError("No objects satisfy z_min <= z_spec <= z_max.")

        # Kept so that magnitudes assigned later at full length can still
        # be aligned with the selected objects.
        self._z_mask = z_mask
        self.z_phot = z_phot[z_mask]
        self.z_spec = z_spec[z_mask]
        self.z_min = z_min
        self.z_max = z_max
        self.outlier_threshold = outlier_threshold

        if mag is not None:
            mag = np.asarray(mag)
            if mag.shape != z_spec.shape:
                raise ValueError("mag must have the same shape as z_spec.")
            mag = mag[z_mask]
        self.mag = mag

        # TODO make convenience functions for all
        # Normalized residuals as defined in Cohen et al. (2000) Section 3
        self.delta_z_norm = (self.z_phot - self.z_spec) / (1 + self.z_spec)
        # Normalized median absolute deviation
        self.sigma_nmad = 1.4826 * np.median(
            np.abs(self.delta_z_norm - np.median(self.delta_z_norm))
        )
        self.bias = np.mean(self.delta_z_norm)
        # Number of objects larger than outlier threshold
        self.n_outlier = np.sum(np.abs(self.delta_z_norm) > self.outlier_threshold)

        # Outlier percentage
        self.percent_outlier = self.n_outlier * 100.0 / len(self.z_spec)

    def _gaussian(self, x, mean=0, sigma=1):
        """Return the normal probability density N(mean, sigma) at ``x``."""
        return np.exp((-0.5 * ((x - mean) / sigma) ** 2)) / np.sqrt(2 * np.pi) / sigma

    @staticmethod
    def _bin_mask(values, edges, index):
        """Select ``values`` in bin ``index``; only the last bin is right-closed."""
        lower, upper = edges[index], edges[index + 1]
        if index == len(edges) - 2:
            # Without this the object(s) sitting exactly on the final edge
            # (the maximum value) would belong to no bin at all.
            return (values >= lower) & (values <= upper)
        return (values >= lower) & (values < upper)

    def _aligned_mag(self):
        """Return ``self.mag`` as an array matching the selected objects."""
        if self.mag is None:
            raise ValueError("Magnitudes are required: pass `mag` or set `Metrics.mag`.")
        mag = np.asarray(self.mag)
        if len(mag) == len(self.delta_z_norm):
            return mag
        z_mask = getattr(self, "_z_mask", None)
        if z_mask is not None and len(mag) == len(z_mask):
            # Full-length magnitudes assigned after construction: apply the
            # same redshift selection that was applied to the redshifts.
            return mag[z_mask]
        raise ValueError("mag does not match the number of objects in the sample.")

    def _binned_bias_and_scatter(self, values, num_bins=10):
        """Bin edges, Hodges-Lehmann bias and sigma_NMAD per equal-population bin."""
        _, edges = pd.qcut(values, num_bins, retbins=True)  # equal population bins
        bias_bin = np.zeros(num_bins)
        sigma_nmad_bins = np.zeros(num_bins)
        for i in range(num_bins):
            residuals = self.delta_z_norm[self._bin_mask(values, edges, i)]
            bias_bin[i] = hodges_lehmann(residuals, max_pairs=1e6)
            sigma_nmad_bins[i] = 1.4826 * np.median(
                np.abs(residuals - np.median(residuals))
            )
        return edges, bias_bin, sigma_nmad_bins

    def _plot_binned_metrics(
        self,
        values,
        edges,
        bias_bin,
        sigma_nmad_bins,
        xlabel,
        xlim,
        legend_loc,
        show,
        ax,
        fig,
        subplot_kwargs,
    ):
        """Residual density vs. ``values`` with the per-bin bias and scatter overlaid."""
        if ax is None:
            fig, ax = plt.subplots(
                subplot_kw={"projection": "scatter_density"}, **subplot_kwargs
            )
        elif fig is None:
            fig = ax.figure

        ax.scatter_density(
            values,
            self.delta_z_norm,
            cmap="Greys",
            dpi=15,
            downres_factor=1,
            alpha=0.6,
        )

        better_step(edges, bias_bin, ax=ax, color="C1", linewidth=3, label="bias")
        better_step(
            edges,
            sigma_nmad_bins,
            ax=ax,
            color="C0",
            ls="--",
            linewidth=3,
            label=r"$\sigma_{\mathrm{NMAD}}$",
        )

        ax.set_ylabel(r"$\dfrac{\Delta z}{1 + z_\mathrm{spec}}$", fontsize=40)
        ax.set_xlabel(xlabel, fontsize=40)
        ax.axhline(0, linestyle="--", color="black", lw=3)
        ax.tick_params(axis="both", which="major", labelsize=25)
        ax.tick_params(axis="both", which="minor", labelsize=25)
        ax.set_xlim(*xlim)
        ax.set_ylim(-0.013, 0.013)
        ax.grid()
        ax.legend(loc=legend_loc, fontsize=30)

        if show:
            plt.show()

        return fig, ax

    def phot_vs_spec(self, show=False, ax=None, fig=None, **kwargs):
        """Photo-z vs. spec-z.

        Draws a scatter-density map of ``z_phot`` against ``z_spec`` with the
        one-to-one line, the dashed outlier boundaries
        ``z_spec +/- outlier_threshold * (1 + z_spec)``, a colorbar and a
        text box listing sigma_NMAD, the outlier percentage and the bias.
        sigma_NMAD and the outlier percentage are also printed.

        Args:
            show (bool): Call ``plt.show()`` before returning.
            ax: Existing axis created with the ``scatter_density``
                projection. A new figure and axis are created when omitted.
            fig: Figure owning ``ax``; defaults to ``ax.figure``.
            **kwargs: Forwarded to ``plt.subplots`` when a figure is created.

        Returns:
            tuple: ``(fig, ax)``.
        """
        if ax is None:
            fig, ax = plt.subplots(
                subplot_kw={"projection": "scatter_density"}, **kwargs,
            )
        elif fig is None:
            # The colorbar below needs a figure even when only ``ax`` is given.
            fig = ax.figure

        print(f"Normalized MAD: {self.sigma_nmad:.6f}")
        print(f"{self.outlier_threshold:.2f} outliers: {self.percent_outlier:.6f}%")

        x = np.linspace(self.z_min, self.z_max, 10)
        outlier_upper = x + self.outlier_threshold * (1 + x)
        outlier_lower = x - self.outlier_threshold * (1 + x)
        ax.plot(x, outlier_upper, "k--")
        ax.plot(x, outlier_lower, "k--")

        # viridis with its lowest entry replaced by white, so that empty
        # pixels blend into the background. pyplot.get_cmap is used because
        # matplotlib.cm.get_cmap has been removed from recent matplotlib.
        viridis = plt.get_cmap("viridis", 256)
        newcolors = viridis(np.linspace(0, 1, 256))
        newcolors[:1, :] = np.array([1, 1, 1, 1])
        viridis_white = ListedColormap(newcolors, name="viridis_white")

        # Plot scatter density
        scatter_density = ax.scatter_density(
            self.z_spec,
            self.z_phot,
            cmap=viridis_white,
            dpi=30,
            downres_factor=1,
        )
        # Attach the colorbar to this axis explicitly so that it does not
        # depend on which axis happens to be current in a multi-panel figure.
        cbar = fig.colorbar(scatter_density, ax=ax, fraction=0.046, pad=0.04)
        cbar.ax.tick_params(labelsize=20)
        cbar.set_label(label="Number of galaxies per pixel", fontsize=30)

        ax.plot(x, x, linewidth=1.5, color="grey")

        ax.set_xlim([self.z_min, self.z_max])
        ax.set_ylim([self.z_min, self.z_max])
        ax.set_xlabel(r"$z_{\mathrm{spec}}$", fontsize=40)
        ax.set_ylabel(r"$z_{\mathrm{phot}}$", fontsize=40)
        ax.yaxis.grid(alpha=0.8, ls="--")
        ax.xaxis.grid(alpha=0.8)
        ax.set_aspect("equal")
        # Hide the first x tick label.
        xticks = ax.xaxis.get_major_ticks()
        xticks[0].label1.set_visible(False)
        ax.tick_params(axis="both", which="major", labelsize=25)
        ax.tick_params(axis="both", which="minor", labelsize=25)

        textstr = "\n".join(
            (
                r"$\sigma_{\mathrm{NMAD}}=%.5f$" % (self.sigma_nmad,),
                r"$\mathrm{f}_{\mathrm{outlier}}=%.2f$" % (self.percent_outlier,) + "%",
                r"$\langle \frac{\Delta z}{1+z_{\mathrm{spec}}} \rangle=%.5f$"
                % (self.bias,),
            )
        )

        # these are matplotlib.patch.Patch properties
        props = dict(boxstyle="round, pad=0.7", facecolor="white", alpha=1)

        # place a text box in upper left in axes coords
        ax.text(
            0.05,
            0.97,
            textstr,
            transform=ax.transAxes,
            fontsize=30,
            verticalalignment="top",
            bbox=props,
        )

        if show:
            plt.show()

        return fig, ax

    def metrics_vs_z(self, show=False, ax=None, fig=None, **kwargs):
        """σ_NMAD and bias as a function of redshift.

        The sample is split into 10 equal-population spec-z bins. The
        per-bin bias is the Hodges-Lehmann estimate of the normalized
        residuals (not the mean used for ``self.bias``).

        Args:
            show (bool): Call ``plt.show()`` before returning.
            ax: Existing ``scatter_density`` axis; created when omitted.
            fig: Figure owning ``ax``; defaults to ``ax.figure``.
            **kwargs: Forwarded to ``plt.subplots`` when a figure is created.

        Returns:
            tuple: ``(fig, ax)``.
        """
        edges, bias_bin, sigma_nmad_bins = self._binned_bias_and_scatter(self.z_spec)
        return self._plot_binned_metrics(
            self.z_spec,
            edges,
            bias_bin,
            sigma_nmad_bins,
            xlabel=r"$z_{\mathrm{spec}}$",
            xlim=(self.z_min, self.z_max),
            legend_loc="lower right",
            show=show,
            ax=ax,
            fig=fig,
            subplot_kwargs=kwargs,
        )

    def metrics_vs_mag(
        self, show=False, ax=None, fig=None, mag_lim=(15, 17.9), **kwargs
    ):
        """σ_NMAD and bias as a function of magnitude.

        Same as ``metrics_vs_z`` but binned in ``self.mag``. The per-bin
        sigma_NMAD values are printed.

        Args:
            show (bool): Call ``plt.show()`` before returning.
            ax: Existing ``scatter_density`` axis; created when omitted.
            fig: Figure owning ``ax``; defaults to ``ax.figure``.
            mag_lim (tuple): x-axis limits in magnitudes.
            **kwargs: Forwarded to ``plt.subplots`` when a figure is created.

        Returns:
            tuple: ``(fig, ax)``.

        Raises:
            ValueError: If no magnitudes are available or their length does
                not match the sample.
        """
        mag = self._aligned_mag()
        edges, bias_bin, sigma_nmad_bins = self._binned_bias_and_scatter(mag)
        print(sigma_nmad_bins)
        # The constant-z_phot guide lines that used to be drawn here were
        # computed in redshift coordinates and therefore fell outside this
        # magnitude axis; they have been dropped.
        return self._plot_binned_metrics(
            mag,
            edges,
            bias_bin,
            sigma_nmad_bins,
            xlabel="$r$ magnitude",
            xlim=mag_lim,
            legend_loc="lower left",
            show=show,
            ax=ax,
            fig=fig,
            subplot_kwargs=kwargs,
        )

    def error_dist(self, show=False, ax=None, fig=None, **kwargs):
        """Histogram of normalized redshift residuals.

        Overlays a Gaussian with mean ``self.bias`` and width
        ``self.sigma_nmad`` on a density-normalized histogram, shades a
        band just outside the outlier threshold on both sides, and prints
        the bias and sigma_NMAD.

        Args:
            show (bool): Call ``plt.show()`` before returning.
            ax: Existing axis; a new figure and axis are created when omitted.
            fig: Figure owning ``ax``; defaults to ``ax.figure``.
            **kwargs: Forwarded to ``plt.subplots`` when a figure is created.

        Returns:
            tuple: ``(fig, ax)``.
        """
        print("Bias: {:.6f}".format(self.bias))
        print("Sigma MAD: {:.6f}".format(self.sigma_nmad))

        if ax is None:
            fig, ax = plt.subplots(**kwargs)
        elif fig is None:
            fig = ax.figure

        pop, bins, patches = ax.hist(
            self.delta_z_norm,
            bins=150,
            histtype="stepfilled",
            color="C0",
            alpha=0.5,
            density=True,
        )

        bin_width = bins[1] - bins[0]
        x = np.linspace(bins.min(), bins.max(), 501)
        ax.plot(
            x,
            # Scale the unit-area Gaussian by the histogram's total area.
            self._gaussian(x, self.bias, self.sigma_nmad) * pop.sum() * bin_width,
            c="C1",
            ls="-",
            lw=3,
        )
        ax.grid(alpha=0.5)
        ax.set_ylabel("Relative Frequency", fontsize=40)
        ax.set_xlabel(r"$\dfrac{\Delta z}{1 + z_\mathrm{spec}}$", fontsize=40)
        # Width of the shaded band drawn beyond each outlier boundary.
        ax_pad = 0.01
        ax.axvspan(
            self.outlier_threshold,
            self.outlier_threshold + ax_pad,
            color="gray",
            alpha=0.3,
            lw=0,
        )
        ax.axvspan(
            -1 * self.outlier_threshold - ax_pad,
            -1 * self.outlier_threshold,
            color="gray",
            alpha=0.3,
            lw=0,
        )
        ax.set_xlim(
            [-1 * self.outlier_threshold - ax_pad, self.outlier_threshold + ax_pad]
        )
        ax.tick_params(axis="both", which="major", labelsize=25)
        ax.tick_params(axis="both", which="minor", labelsize=25)
        # Hide the first y tick label.
        yticks = ax.yaxis.get_major_ticks()
        yticks[0].label1.set_visible(False)
        if show:
            plt.show()

        return fig, ax

    def full_diagnostic(self, figsize=(26, 13), show=False):
        """Draw all three diagnostics on a single figure.

        Layout: photo-z vs. spec-z over the full left column, metrics vs.
        redshift top right, residual histogram bottom right.

        Args:
            figsize (tuple): Figure size passed to ``plt.figure``.
            show (bool): Call ``plt.show()`` before returning.

        Returns:
            tuple: Three ``(fig, ax)`` pairs, one per panel, exactly as
            returned by ``phot_vs_spec``, ``metrics_vs_z`` and
            ``error_dist``.
        """
        fig = plt.figure(figsize=figsize)
        ax0 = plt.subplot2grid((2, 2), (0, 0), rowspan=2, projection="scatter_density")
        ax1 = plt.subplot2grid((2, 2), (0, 1), projection="scatter_density")
        ax2 = plt.subplot2grid((2, 2), (1, 1))
        # Pass the figure to every panel so none of them reports fig=None.
        ax0 = self.phot_vs_spec(ax=ax0, fig=fig)
        ax1 = self.metrics_vs_z(ax=ax1, fig=fig)
        ax2 = self.error_dist(ax=ax2, fig=fig)

        if show:
            plt.show()

        return ax0, ax1, ax2

In [None]:
# Metrics for the sample with 0 <= z_spec <= 0.4, counting objects with
# |delta z| / (1 + z_spec) > 0.05 as outliers.
metrics = Metrics(
    z_phot=z_phot,
    z_spec=z_spec,
    z_min=0,
    z_max=0.4,
    outlier_threshold=0.05,
)

In [None]:
# Photo-z vs. spec-z density plot; figsize is forwarded to plt.subplots.
fig, ax = metrics.phot_vs_spec(show=True, figsize=(12, 12))
# Written relative to the working directory; ./figs must already exist.
fig.savefig("./figs/redshift_comparison.pdf", dpi=300, bbox_inches="tight")

Bootstrap uncertainties on the metrics shown above (σ_NMAD, outlier fraction, and bias)

In [None]:
# Unseeded generator, so the bootstrap draws (and the printed spreads) differ
# from run to run.
rng = np.random.default_rng()
# Index array of shape (len(z_phot), 10000): each column is one bootstrap
# resample, drawn with replacement, of the same size as the data set.
# NOTE(review): this holds len(z_phot) * 10000 integers in memory at once.
bootmask = rng.choice(len(z_phot),size=(len(z_phot), 10000), replace=True)

In [None]:
# Apply the same index array to both redshift vectors so that every resampled
# entry keeps its (z_phot, z_spec) pairing; results have bootmask's shape.
z_phot_boot = z_phot[bootmask]
z_spec_boot = z_spec[bootmask]

In [None]:
# Recompute the three metrics for every bootstrap resample at once. All
# reductions run along axis 0, giving one value per resample (per column).
# NOTE(review): this uses the full z_phot/z_spec arrays and a hard-coded 0.05
# threshold rather than the redshift-selected sample and threshold held by
# `metrics` -- confirm the two are meant to agree.
delta_z_norm = (z_phot_boot - z_spec_boot) / (1 + z_spec_boot)
# Per-resample sigma_NMAD: the per-column median is broadcast back over rows.
sigma_nmad = 1.4826 * np.median(
            np.abs(delta_z_norm - np.median(delta_z_norm, axis=0)), axis=0
        )
# Per-resample mean normalized residual.
bias = np.mean(delta_z_norm, axis=0)
# Per-resample count and percentage of objects beyond the 0.05 cut.
n_outlier = np.sum(np.abs(delta_z_norm) > 0.05, axis=0)
percent_outlier = n_outlier * 100.0 / len(z_spec)

In [None]:
# Standard deviation of each metric across the bootstrap resamples, printed
# in the order: sigma_NMAD, outlier percentage, bias.
print(np.std(sigma_nmad))
print(np.std(percent_outlier))
print(np.std(bias))

In [None]:
# Load the catalogue labels into a DataFrame. The .npz archive is opened in a
# ``with`` block so the underlying file is closed once the array is read.
# NOTE(review): allow_pickle=True can execute arbitrary code while loading;
# only use it on catalogue files from a trusted source.
with np.load(
    "/data/bid13/photoZ/data/pasquet2019/sdss_vagc.npz", allow_pickle=True
) as vagc_file:
    cat = pd.DataFrame(vagc_file["labels"])
# Index by specObjID (keeping the column itself) and reorder/select the rows
# so that they line up one-to-one with test_id.
cat = cat.set_index("specObjID", drop=False)
cat = cat.loc[test_id]

In [None]:
# `metrics` only holds the objects with z_min <= z_spec <= z_max, whereas the
# catalogue rows cover every entry of test_id. Apply the same redshift
# selection to the magnitudes (as a plain array, so indexing is positional)
# so that they stay aligned with metrics.delta_z_norm.
in_z_range = (z_spec >= metrics.z_min) & (z_spec <= metrics.z_max)
metrics.mag = cat["dered_petro_r"].to_numpy()[in_z_range]

# Two side-by-side panels sharing the residual axis: metrics vs. redshift on
# the left, metrics vs. magnitude on the right.
fig, ax = plt.subplots(
    1,
    2,
    figsize=(24, 8),
    sharey=True,
    subplot_kw={"projection": "scatter_density"},
)
metrics.metrics_vs_z(show=False, ax=ax[0])
metrics.metrics_vs_mag(show=False, ax=ax[1])
# The y label is already shown on the left panel.
ax[1].set_ylabel("")
fig.subplots_adjust(wspace=0.1)
fig.savefig("./figs/metrics_vs_z-mag.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Histogram of the normalized residuals; figsize is forwarded to plt.subplots.
fig, ax = metrics.error_dist(show=True, figsize=(12, 8))
# Written relative to the working directory; ./figs must already exist.
fig.savefig("./figs/residual_dist.pdf", dpi=300, bbox_inches="tight")

# Plot performance vs data size

In [None]:
# Whitespace-separated table of performance numbers versus training fraction.
# sep=r"\s+" is the documented replacement for the deprecated
# delim_whitespace=True option.
data = pd.read_csv("performance.txt", sep=r"\s+")

In [None]:
# Two-panel figure of performance versus training-set size:
# left panel (ax1) shows sigma_NMAD, right panel (ax2) the outlier fraction.
# The bottom x axis of each panel is the training fraction in percent; a
# twinned top x axis shows the same quantity as an object count.
fig, axs = plt.subplots(1, 2, figsize=(60,20))
ax1, ax2 = axs

# Converts a training percentage to an object count (count = N_GAL * pct / 100).
# NOTE(review): presumably the total number of galaxies available -- confirm.
N_GAL = 516525

# --- Left panel: sigma_NMAD ---
# zorder=100 keeps the "This work" curve on top of the comparison curves.
ax1.plot(
    100 * data["frac_train"],
    data["sigma_nmad"],
    lw=10,
    c="C0",
    marker="o",
    markersize=25,
    label=r"This work",
    zorder=100,
)
ax1.plot(
    100 * data["frac_train"],
    data["sigma_nmad_psqt"],
    lw=8,
    c="gray",
    ls = ":",
    marker="o",
    markersize=25,
    label=r"Pasquet et al. 2019",
)

ax1.plot(
    100 * data["frac_train"],
    data["sigma_nmad_hyt"],
    lw=8,
    c="darkgrey",
    ls="--",
    marker="o",
    markersize=25,
    label=r"Hayat et al. 2021",
)

# Drawn as unconnected "x" markers rather than a line.
ax1.scatter(
    100 * data["frac_train"],
     data["sigma_nmad_beck"],
    marker="x",
    s=1000,
    c="k",
    linewidths=5,
    label=r"Beck et al. 2016",
)

ax1.grid(ls ="--", lw=2, alpha=0.8)
ax1.set_ylim(0.0075, 0.0145)
ax1.set_xlabel("Size of training set (%)", fontsize=80, labelpad=50)
ax1.set_ylabel(r"$\sigma_{\mathrm{NMAD}}$", fontsize=80, labelpad=40)
ax1.tick_params(axis="both", which="major", labelsize=60)
ax1.tick_params(axis="both", which="minor", labelsize=60)
ax1.tick_params(axis='y', which='major', pad=10)
# Log-scaled x axis with explicit ticks at the listed percentages.
ax1.set_xscale("log")
ax1.set_xticks([1,2,20,50,80])
ax1.set_xticklabels(["1","2","20","50","80"])
ax1.legend(loc="lower left", prop={"size": 50})

# Top axis: the bottom axis' limits and tick positions converted from
# percent to object counts via N_GAL.
ax1_ = ax1.twiny()
ax1_.set_xlim(N_GAL*np.array(ax1.get_xlim())/100)
ax1_.set_xscale("log")
ax1_.set_xticks(N_GAL*np.array([0.01,0.02,0.2,0.5,0.8]))
ax1_.set_xticklabels((N_GAL*np.array([0.01,0.02,0.2,0.5,0.8])).astype(int).astype(str), rotation=45, ha="left")
ax1_.tick_params(axis="both", which="major", labelsize=60)
ax1_.tick_params(axis="both", which="minor", labelsize=60)
ax1_.set_xlabel("Size of training set (count)", fontsize=80, labelpad=50)

# --- Right panel: outlier fraction ---
ax2.plot(
    100 * data["frac_train"],
    data["f_outlier"],
    #     ls="--",
    lw=10,
    c="C0",
    marker="o",
    markersize=25,
    label=r"This work",
    zorder=100,
)
ax2.plot(
    100 * data["frac_train"],
    data["f_outlier_psqt"],

    lw=8,
    c="gray",
    marker="o",
    ls=":",
    markersize=25,
    label=r"Pasquet et al. 2019",
)
ax2.plot(
    100 * data["frac_train"],
    data["f_outlier_hyt"],
    ls="--",
    lw=8,
    c="darkgrey",
    marker="o",
    markersize=25,
    label=r"Hayat et al. 2021",
)
ax2.scatter(
    100 * data["frac_train"],
    data["f_outlier_beck"],
    marker="x",
    s=1000,
    c="k",
    linewidths=5,
    label=r"Beck et al. 2016", zorder=6
)

ax2.set_ylabel(r"$\mathrm{f}_{\mathrm{outlier}}$ (%)", fontsize=80, labelpad=40)
ax2.set_xlabel("Size of training set (%)", fontsize=80, labelpad=50)

# ax2.set_ylim(0.15, 1.45)
# Both axes are logarithmic here; ticks are placed and labelled explicitly.
ax2.set_xscale("log")
ax2.set_yscale("log")
ax2.set_yticks([0.2,.3,.4,.5,.6,.7,.8,.9,1])
ax2.set_yticklabels([ "0.2", "0.3", "0.4","0.5","0.6","0.7","0.8","0.9","1"])
ax2.set_xticks([1,2,20,50,80])
ax2.set_xticklabels(["1","2","20","50","80"])
ax2.grid(ls ="--", lw=2, alpha=0.8)
ax2.tick_params(axis="both", which="major", labelsize=60)
ax2.tick_params(axis="both", which="minor", labelsize=60)
ax2.tick_params(axis='y', which='major', pad=10)

ax2.legend(loc="lower left", prop={"size": 50})

# Top axis in object counts, built the same way as for the left panel.
ax2_ = ax2.twiny()
ax2_.set_xlim(N_GAL*np.array(ax2.get_xlim())/100)
ax2_.set_xscale("log")
ax2_.set_xticks(N_GAL*np.array([0.01,0.02,0.2,0.5,0.8]))
ax2_.set_xticklabels((N_GAL*np.array([0.01,0.02,0.2,0.5,0.8])).astype(int).astype(str), rotation=45,ha="left")
ax2_.tick_params(axis="both", which="major", labelsize=60)
ax2_.tick_params(axis="both", which="minor", labelsize=60)
ax2_.set_xlabel("Size of training set (count)", fontsize=80, labelpad=50)
# Line up the axis labels across the two panels before saving.
fig.align_labels()
fig.savefig("./figs/performance_vs_data.pdf", dpi=300, bbox_inches="tight")

### Same plot as above, restyled for ARAA

In [None]:
# Variant of the performance-vs-training-size figure. Compared with the
# previous cell, the first curve is labelled "Dey et al. 2021", all three
# line series use colour C0 (distinguished by line style only), and the
# Beck et al. markers use colour C1 with thicker strokes.
fig, axs = plt.subplots(1, 2, figsize=(60,20))
ax1, ax2 = axs

# Converts a training percentage to an object count (count = N_GAL * pct / 100).
# NOTE(review): presumably the total number of galaxies available -- confirm.
N_GAL = 516525

# --- Left panel: sigma_NMAD ---
ax1.plot(
    100 * data["frac_train"],
    data["sigma_nmad"],
    lw=10,
    c="C0",
    marker="o",
    markersize=25,
    label=r"Dey et al. 2021",
    zorder=100,
)
ax1.plot(
    100 * data["frac_train"],
    data["sigma_nmad_psqt"],
    lw=8,
    c="C0",
    ls = ":",
    marker="o",
    markersize=25,
    label=r"Pasquet et al. 2019",
)

ax1.plot(
    100 * data["frac_train"],
    data["sigma_nmad_hyt"],
    lw=8,
    c="C0",
    ls="--",
    marker="o",
    markersize=25,
    label=r"Hayat et al. 2021",
)

# Drawn as unconnected "x" markers rather than a line.
ax1.scatter(
    100 * data["frac_train"],
     data["sigma_nmad_beck"],
    marker="x",
    s=1000,
    c="C1",
    linewidths=10,
    label=r"Beck et al. 2016",
)

ax1.grid(ls ="--", lw=2, alpha=0.8)
ax1.set_ylim(0.0075, 0.0145)
ax1.set_xlabel("Size of training set (%)", fontsize=80, labelpad=50)
ax1.set_ylabel(r"$\sigma_{\mathrm{NMAD}}$", fontsize=80, labelpad=40)
ax1.tick_params(axis="both", which="major", labelsize=60)
ax1.tick_params(axis="both", which="minor", labelsize=60)
ax1.tick_params(axis='y', which='major', pad=10)
# Log-scaled x axis with explicit ticks at the listed percentages.
ax1.set_xscale("log")
ax1.set_xticks([1,2,20,50,80])
ax1.set_xticklabels(["1","2","20","50","80"])
ax1.legend(loc="lower left", prop={"size": 50})

# Top axis: the bottom axis' limits and tick positions converted from
# percent to object counts via N_GAL.
ax1_ = ax1.twiny()
ax1_.set_xlim(N_GAL*np.array(ax1.get_xlim())/100)
ax1_.set_xscale("log")
ax1_.set_xticks(N_GAL*np.array([0.01,0.02,0.2,0.5,0.8]))
ax1_.set_xticklabels((N_GAL*np.array([0.01,0.02,0.2,0.5,0.8])).astype(int).astype(str), rotation=45, ha="left")
ax1_.tick_params(axis="both", which="major", labelsize=60)
ax1_.tick_params(axis="both", which="minor", labelsize=60)
ax1_.set_xlabel("Size of training set (count)", fontsize=80, labelpad=50)

# --- Right panel: outlier fraction ---
ax2.plot(
    100 * data["frac_train"],
    data["f_outlier"],
    #     ls="--",
    lw=10,
    c="C0",
    marker="o",
    markersize=25,
    label=r"Dey et al. 2021",
    zorder=100,
)
ax2.plot(
    100 * data["frac_train"],
    data["f_outlier_psqt"],

    lw=8,
    c="C0",
    marker="o",
    ls=":",
    markersize=25,
    label=r"Pasquet et al. 2019",
)
ax2.plot(
    100 * data["frac_train"],
    data["f_outlier_hyt"],
    ls="--",
    lw=8,
    c="C0",
    marker="o",
    markersize=25,
    label=r"Hayat et al. 2021",
)
ax2.scatter(
    100 * data["frac_train"],
    data["f_outlier_beck"],
    marker="x",
    s=1000,
    c="C1",
    linewidths=10,
    label=r"Beck et al. 2016", zorder=6
)

ax2.set_ylabel(r"$\mathrm{f}_{\mathrm{outlier}}$ (%)", fontsize=80, labelpad=40)
ax2.set_xlabel("Size of training set (%)", fontsize=80, labelpad=50)

# ax2.set_ylim(0.15, 1.45)
# Both axes are logarithmic here; ticks are placed and labelled explicitly.
ax2.set_xscale("log")
ax2.set_yscale("log")
ax2.set_yticks([0.2,.3,.4,.5,.6,.7,.8,.9,1])
ax2.set_yticklabels([ "0.2", "0.3", "0.4","0.5","0.6","0.7","0.8","0.9","1"])
ax2.set_xticks([1,2,20,50,80])
ax2.set_xticklabels(["1","2","20","50","80"])
ax2.grid(ls ="--", lw=2, alpha=0.8)
ax2.tick_params(axis="both", which="major", labelsize=60)
ax2.tick_params(axis="both", which="minor", labelsize=60)
ax2.tick_params(axis='y', which='major', pad=10)

ax2.legend(loc="lower left", prop={"size": 50})

# Top axis in object counts, built the same way as for the left panel.
ax2_ = ax2.twiny()
ax2_.set_xlim(N_GAL*np.array(ax2.get_xlim())/100)
ax2_.set_xscale("log")
ax2_.set_xticks(N_GAL*np.array([0.01,0.02,0.2,0.5,0.8]))
ax2_.set_xticklabels((N_GAL*np.array([0.01,0.02,0.2,0.5,0.8])).astype(int).astype(str), rotation=45,ha="left")
ax2_.tick_params(axis="both", which="major", labelsize=60)
ax2_.tick_params(axis="both", which="minor", labelsize=60)
ax2_.set_xlabel("Size of training set (count)", fontsize=80, labelpad=50)
# Line up the axis labels across the two panels before saving.
fig.align_labels()
fig.savefig("./figs/performance_vs_data_ARAA.pdf", dpi=300, bbox_inches="tight")