# Evaluation of a selection of models on 60km -> 2.2km-4x over Birmingham

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv
    
import math
import os

import cftime
import iris
import iris.analysis.cartography
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pysteps
import scipy
import xarray as xr

from mlde_notebooks import plot_map, show_samples, distribution_figure, plot_mean_bias, plot_std_bias, scatter_plots, seasonal_distribution_figure, compute_gridspec, freq_density_plot
from mlde_notebooks.psd import plot_psd
from mlde_utils import cp_model_rotated_pole, TIME_PERIODS
from mlde_notebooks.data import prep_eval_data
from mlde_notebooks import create_map_fig, qq_plot, STYLES

In [None]:
def reasonable_quantiles(da):
    """Return quantile levels that resolve the upper tail of ``da``'s values.

    The levels are nine evenly spaced values per decade of the tail:
    0.1..0.9, then 0.91..0.99, then 0.991..0.999, and so on. There is one
    such decade for each power of ten in ``da.size``, followed by 1.0 (the
    maximum). Deeper tails are only included when there are enough values to
    estimate them.

    Args:
        da: array-like with a ``size`` attribute (e.g. numpy array or xarray
            DataArray) that the quantiles will later be computed from.

    Returns:
        1-D numpy array of increasing quantile levels ending in 1.0.

    Raises:
        ValueError: if ``da`` is empty, since no quantile can be estimated.
    """
    size = int(da.size)
    if size == 0:
        raise ValueError("cannot choose quantiles for an empty array")
    # floor(log10(size)) via the digit count: exact for integers, unlike log10 of 1/size.
    n_decades = len(str(size)) - 1
    decades = [
        np.linspace((1 - 10**(i + 1)) + 10**i, 1 - 10**i, 9)
        for i in range(-1, -n_decades - 1, -1)
    ]
    return np.concatenate(decades + [[1.0]])

In [None]:
# Data split to evaluate; passed to prep_eval_data and used to pick the <split>.nc files in the PSD section
split = "val"
# Ensemble members passed to prep_eval_data
ensemble_members = [
    "01",
    "04",
    "05",
    "06",
    "07",
    "08",
    "09",
    "10",
    "11",
    "12",
    "13",
    "15",
]
# Passed to prep_eval_data; presumably the number of samples drawn per model run — TODO confirm
samples_per_run = 3
# Model configurations keyed by source ("CPM" / "GCM"); each source's list is passed as-is to
# prep_eval_data, which interprets the keys (fq_model_id, checkpoint, input_xfm, label, ...).
data_configs = {
    "CPM": [
        {
            "fq_model_id": "id-linpr",
            "checkpoint": "epoch-0",
            "input_xfm": "none",
            "label": "LR precip (interp)",
            "deterministic": True,
            "dataset": "bham_gcmx-4x_12em_linpr_eqvt_random-season",
        },
        {
            "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-20",
            "input_xfm": "stan",
            "label": "Diffusion 12em",
            "dataset": "bham_gcmx-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
    ],
    "GCM": [
        {
            "fq_model_id": "id-linpr",
            "checkpoint": "epoch-0",
            "input_xfm": "none",
            "label": "LR precip (interp)",
            "deterministic": True,
            "dataset": "bham_60km-4x_12em_linpr_eqvt_random-season",
        },
        {
            "fq_model_id": "score-sde/subvpsde/xarray_cncsnpp_continuous/bham-4x_12em_PslT4V4_random-season-IstanTsqrturrecen",
            "checkpoint": "epoch-20",
            "input_xfm": "pixelmmsstan",
            "label": "Diffusion bc 12em",
            "dataset": "bham_60km-4x_12em_psl-temp4th-vort4th_eqvt_random-season",
            "deterministic": False,
        },
    ],
}
# Values selected from the "model" coordinate of the CPM source in the subregion plots
# (presumably the "label" values above become that coordinate — confirm in prep_eval_data)
highlighted_cpm_models = ["Diffusion 12em"]
# the datasets to use for comparisons like PSD which need default datasets with CPM-based hi-res precip and GCM-based lo-res precip respectively
# NOTE(review): both entries below are identical (a 60km dataset) although the "CPM" entry is
# described as CPM-based hi-res precip, and this dict is not referenced in the cells shown here
# (the PSD section reads gcm_lr_lin_pr_dataset and cpm_hr_pr_dataset instead). Confirm or remove.
simulation_pr_datasets = {
    "GCM": "bham_60km-4x_linpr_random-season",
    "CPM": "bham_60km-4x_linpr_random-season"
}
# Dataset whose "linpr" variable gives the low-resolution interpolated precip for the PSD plots
gcm_lr_lin_pr_dataset = "bham_60km-4x_12em_linpr_eqvt_random-season"
# Dataset whose "target_pr" variable gives the high-resolution CPM precip for the PSD plots
cpm_hr_pr_dataset = "bham_gcmx-4x_12em_linpr_eqvt_random-season"
# Free-text description rendered as Markdown in the next cell
desc = """
Describe in more detail the models being compared
"""

In [None]:
# Render `desc` as the Markdown output of this cell (last expression is displayed).
IPython.display.Markdown(desc)

In [None]:
# Load the evaluation data for every source: model samples plus the CPM target.
merged_ds = {}
for source_name, source_configs in data_configs.items():
    merged_ds[source_name] = prep_eval_data(
        source_configs,
        split,
        ensemble_members=ensemble_members,
        samples_per_run=samples_per_run,
    )

# The CPM target precip is the reference that every comparison below is made against.
cpm_pr = merged_ds["CPM"]["target_pr"]
merged_ds

## Pixel distribution

In [None]:
# Quantile levels sized to the full CPM target; reused by the CPM, GCM and seasonal distribution cells.
quantiles = reasonable_quantiles(cpm_pr)

### CPM

In [None]:
# Pixel distribution of the "CPM" source samples vs the CPM target, pooling ensemble members,
# time and both grid dimensions when taking quantiles.
distribution_figure(merged_ds["CPM"], cpm_pr, quantiles, quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"])

### GCM

In [None]:
# Pixel distribution of the "GCM" source samples, still compared against the CPM target,
# pooling ensemble members, time and both grid dimensions when taking quantiles.
distribution_figure(merged_ds["GCM"], cpm_pr, quantiles, quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"])

### Seasonal

In [None]:
# Seasonal pixel distributions for each source, always against the CPM target.
for source, source_ds in merged_ds.items():
    heading = f"<h4>{source}</h4>"
    IPython.display.display_html(heading, raw=True)
    seasonal_distribution_figure(
        source_ds,
        cpm_pr,
        quantiles,
        quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"],
    )

### Subregions

In [None]:
# Two 16x16-cell subregions, given as positional (isel) slices of the model grid.
subregions = {
    "SE": dict(grid_latitude=slice(10, 26), grid_longitude=slice(38, 54)),
    "NW": dict(grid_latitude=slice(44, 60), grid_longitude=slice(18, 34)),
}

fig, axd = create_map_fig([["subregions"]])
subregion_ax = axd["subregions"]

# A field of ones on the model grid, used only as a template to shade each subregion
template = xr.ones_like(
    merged_ds["CPM"]["pred_pr"].isel(ensemble_member=0, sample_id=0, model=0, time=0)
)
# Different constants so the two subregions are distinguishable on the map
plot_map(10 * template.isel(**subregions["NW"]), ax=subregion_ax, style="precip")
plot_map(1 * template.isel(**subregions["SE"]), ax=subregion_ax, style="precip")

subregion_ax.set_extent([-2, 3, -2.5, 2.5], crs=cp_model_rotated_pole)

#### NW

In [None]:
srname = "NW"
srbnds = subregions[srname]

# Pixel distributions restricted to the subregion, per source.
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
    region_ds = source_ds.isel(**srbnds)
    # Only the highlighted models are shown for the CPM source
    if source == "CPM":
        region_ds = region_ds.sel(model=highlighted_cpm_models)
    sr_cpm_pr = cpm_pr.isel(**srbnds)
    quantiles = reasonable_quantiles(sr_cpm_pr)
    distribution_figure(
        region_ds,
        sr_cpm_pr,
        quantiles,
        quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"],
    )

##### NW Winter

In [None]:
srname = "NW"
srseason = "DJF"
# Look the bounds up from srname in this cell instead of reusing `srbnds` from whichever
# subregion cell ran last; otherwise samples and CPM target could come from different regions.
srbnds = subregions[srname]
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
    sr_ds = source_ds.isel(**srbnds)
    srseason_mask = sr_ds["time.season"] == srseason
    srseason_mask_sample_ds = sr_ds.sel(time=srseason_mask)
    # Only the highlighted models are shown for the CPM source
    if source == "CPM":
        srseason_mask_sample_ds = srseason_mask_sample_ds.sel(model=highlighted_cpm_models)
    srseason_mask_cpm_pr = cpm_pr.isel(**srbnds).sel(time=srseason_mask)
    quantiles = reasonable_quantiles(srseason_mask_cpm_pr)
    distribution_figure(srseason_mask_sample_ds, srseason_mask_cpm_pr, quantiles, quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"])

#### SE

In [None]:
srname = "SE"
srbnds = subregions[srname]

# Pixel distributions restricted to the subregion, per source.
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
    region_ds = source_ds.isel(**srbnds)
    # Only the highlighted models are shown for the CPM source
    if source == "CPM":
        region_ds = region_ds.sel(model=highlighted_cpm_models)
    sr_cpm_pr = cpm_pr.isel(**srbnds)
    quantiles = reasonable_quantiles(sr_cpm_pr)
    distribution_figure(
        region_ds,
        sr_cpm_pr,
        quantiles,
        quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"],
    )

##### SE Summer

In [None]:
srname = "SE"
srseason = "JJA"
# Look the bounds up from srname in this cell instead of reusing `srbnds` from whichever
# subregion cell ran last; otherwise samples and CPM target could come from different regions.
srbnds = subregions[srname]
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
    sr_ds = source_ds.isel(**srbnds)
    srseason_mask = sr_ds["time.season"] == srseason
    srseason_mask_sample_ds = sr_ds.sel(time=srseason_mask)
    # Only the highlighted models are shown for the CPM source
    if source == "CPM":
        srseason_mask_sample_ds = srseason_mask_sample_ds.sel(model=highlighted_cpm_models)
    srseason_mask_cpm_pr = cpm_pr.isel(**srbnds).sel(time=srseason_mask)
    quantiles = reasonable_quantiles(srseason_mask_cpm_pr)
    distribution_figure(srseason_mask_sample_ds, srseason_mask_cpm_pr, quantiles, quantile_dims=["ensemble_member", "time", "grid_latitude", "grid_longitude"])

### Scatter plots

In [None]:
# for source in ds["source"].values:
#     IPython.display.display_html(f"<h5>{source}</h5>", raw=True)
#     scatter_plots(merged_ds.sel(source=source), cpm_pr)

## Mean distribution

In [None]:
# Domain-mean precip (averaged over both grid dimensions) for target and samples, per source.
mean_ds = {}
for source_name, source_ds in merged_ds.items():
    precip_vars = source_ds[["target_pr", "pred_pr"]]
    mean_ds[source_name] = precip_vars.mean(dim=["grid_latitude", "grid_longitude"])

cpm_mean_pr = mean_ds["CPM"]["target_pr"]
quantiles = reasonable_quantiles(cpm_mean_pr)

### CPM Scatter plots

In [None]:
source = "CPM"
source_mean_ds = mean_ds[source]

# The first panel pools all seasons (no heading); the rest are one per season.
panels = [(None, source_mean_ds), *source_mean_ds.groupby("time.season")]
for season, panel_ds in panels:
    if season is not None:
        IPython.display.display_html(f"<h4>{season}</h4>", raw=True)
    fig = plt.figure(layout="constrained")
    scatter_plots(panel_ds, fig=fig)
    plt.show()

## Bias $\frac{\mu_{sample}-\mu_{CPM}}{\mu_{CPM}}$

### All

In [None]:
# Mean bias of each source's samples relative to the CPM target over the whole split.
for source, source_ds in merged_ds.items():
    heading = f"<h4>{source}</h4>"
    IPython.display.display_html(heading, raw=True)
    plot_mean_bias(source_ds, cpm_pr)

### All variability

In [None]:
# Per-ensemble-member relative bias (%) of each model's sample mean vs the CPM time mean.
# The CPM time mean does not depend on the source or model, so compute it once.
target_mean = cpm_pr.mean(dim=["time"])
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h4>{source}</h4>", raw=True)
    for model, model_pr in source_ds["pred_pr"].groupby("model", squeeze=False):
        IPython.display.display_html(f"<h4>{model}</h4>", raw=True)
        sample_mean = model_pr.squeeze("model").mean(dim=["sample_id", "time"])
        bias_ratio = 100 * (sample_mean - target_mean) / target_mean
        facets = bias_ratio.plot.pcolormesh(
            col="ensemble_member",
            col_wrap=6,
            transform=cp_model_rotated_pole,
            subplot_kws=dict(projection=cp_model_rotated_pole),
            **STYLES["discreteBias"],
        )
        for facet_ax in facets.axs.flatten():
            facet_ax.coastlines()

        plt.show()

### Seasonal

In [None]:
# Seasonal mean bias: each season's samples against the CPM target for the same season.
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h4>{source}</h4>", raw=True)
    for season, seasonal_ds in source_ds.groupby("time.season"):
        IPython.display.display_html(f"<h5>{season}</h5>", raw=True)
        in_season = cpm_pr["time.season"] == season
        seasonal_cpm_pr = cpm_pr.sel(time=in_season)
        plot_mean_bias(seasonal_ds, seasonal_cpm_pr)

## Standard deviation ratio $\sigma_{sample}/\sigma_{CPM}$

### All

In [None]:
# Standard-deviation comparison of each source's samples with the CPM target over the whole split.
for source, source_ds in merged_ds.items():
    heading = f"<h4>{source}</h4>"
    IPython.display.display_html(heading, raw=True)
    plot_std_bias(source_ds, cpm_pr)

### Seasonal

In [None]:
# Seasonal standard-deviation comparison: each season's samples against that season's CPM target.
for source, source_ds in merged_ds.items():
    IPython.display.display_html(f"<h4>{source}</h4>", raw=True)
    for season, seasonal_ds in source_ds.groupby("time.season"):
        IPython.display.display_html(f"<h5>{season}</h5>", raw=True)
        in_season = cpm_pr["time.season"] == season
        seasonal_cpm_pr = cpm_pr.sel(time=in_season)
        plot_std_bias(seasonal_ds, seasonal_cpm_pr)

## Climate change signal

In [None]:
# Boolean mask over the CPM source's time coordinate selecting the historic period (bounds
# inclusive). It is also applied to the GCM source below, which assumes both sources share
# the same time coordinate — TODO confirm.
historical_ts_mask = (merged_ds["CPM"]["time"] >= TIME_PERIODS["historic"][0]) & (merged_ds["CPM"]["time"] <= TIME_PERIODS["historic"][1])
# Seasons for which the climate change signal is evaluated
ccs_seasons = ["DJF", "JJA"]

### Pixel quantiles

In [None]:
# Climate change signal in the pixel distribution. For each source:
# Q-Q plots and log frequency densities of each time period's samples and CPM precip
# against the historic CPM distribution (annual, then per season in ccs_seasons),
# followed by per-model Q-Q plots with one line per sample_id to show sample variability.
pixel_dims = ["ensemble_member", "time", "grid_latitude", "grid_longitude"]
for source in merged_ds.keys():
    source_ds = merged_ds[source]
    IPython.display.display_html(f"<h4>{source}</h4>", raw=True)
    IPython.display.display_html("<h5>Annual</h5>", raw=True)
    historical_cpm_pr = cpm_pr.sel(time=historical_ts_mask)
    quantiles = reasonable_quantiles(historical_cpm_pr)
    historical_cpm_pr_quantiles = historical_cpm_pr.quantile(quantiles, dim=pixel_dims)

    fig, axd = plt.subplot_mosaic([TIME_PERIODS.keys()], figsize=(18, 6))
    for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
        ts_mask = (source_ds["time"] >= ts_start) & (source_ds["time"] <= ts_end)
        ts_pred_pr = source_ds.sel(time=ts_mask)["pred_pr"]
        ts_cpm_pr = cpm_pr.sel(time=ts_mask)

        ts_sample_quantiles = ts_pred_pr.quantile(quantiles, dim=pixel_dims)
        ts_cpm_quantiles = ts_cpm_pr.quantile(quantiles, dim=pixel_dims)
        ts_quantiles = xr.concat([ts_sample_quantiles, ts_cpm_quantiles.expand_dims(model=["\u200BCPM"])], dim="model")
        qq_plot(axd[ts_key], historical_cpm_pr_quantiles, ts_quantiles, title=f"{ts_key} sample quantiles vs historic CPM quantiles", xlabel="historic CPM precip (mm day$^{-1}$)", tr=200, guide_label=None)
    plt.show()

    fig, axes = plt.subplot_mosaic([TIME_PERIODS.keys()], figsize=(18, 6), constrained_layout=True)
    for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
        ts_mask = (source_ds["time"] >= ts_start) & (source_ds["time"] <= ts_end)
        ts_pred_pr = source_ds.sel(time=ts_mask)["pred_pr"]
        ts_cpm_pr = cpm_pr.sel(time=ts_mask)

        ts_pr = xr.concat([ts_pred_pr, ts_cpm_pr.expand_dims(model=[f"\u200B{ts_key} CPM"])], dim="model")
        freq_density_plot(axes[ts_key], ts_pr, historical_cpm_pr, title=f"Log density of {ts_key} samples and CPM precip", target_label="Historic CPM", grouping_key="model")
    plt.show()

    for season in ccs_seasons:
        IPython.display.display_html(f"<h5>{season}</h5>", raw=True)
        season_mask = source_ds["time.season"] == season

        historical_cpm_pr = cpm_pr.sel(time=historical_ts_mask & season_mask)
        quantiles = reasonable_quantiles(historical_cpm_pr)
        historical_cpm_pr_quantiles = historical_cpm_pr.quantile(quantiles, dim=pixel_dims)

        fig, axd = plt.subplot_mosaic([TIME_PERIODS.keys()], figsize=(18, 6))
        for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
            ts_mask = (source_ds["time"] >= ts_start) & (source_ds["time"] <= ts_end)
            ts_pred_pr = source_ds.sel(time=ts_mask & season_mask)["pred_pr"]
            ts_cpm_pr = cpm_pr.sel(time=ts_mask & season_mask)

            ts_sample_quantiles = ts_pred_pr.quantile(quantiles, dim=pixel_dims)
            ts_cpm_quantiles = ts_cpm_pr.quantile(quantiles, dim=pixel_dims)
            ts_quantiles = xr.concat([ts_sample_quantiles, ts_cpm_quantiles.expand_dims(model=["\u200BCPM"])], dim="model")
            qq_plot(axd[ts_key], historical_cpm_pr_quantiles, ts_quantiles, title=f"{season}: {ts_key} sample vs historic CPM quantiles", xlabel=f"historic {season} CPM precip (mm day$^{-1}$)", tr=200, guide_label=None)
        plt.show()

        fig, axes = plt.subplot_mosaic([TIME_PERIODS.keys()], figsize=(18, 6), constrained_layout=True)
        for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
            ts_mask = (source_ds["time"] >= ts_start) & (source_ds["time"] <= ts_end)
            ts_pred_pr = source_ds.sel(time=ts_mask & season_mask)["pred_pr"]
            ts_cpm_pr = cpm_pr.sel(time=ts_mask & season_mask)

            ts_pr = xr.concat([ts_pred_pr, ts_cpm_pr.expand_dims(model=[f"\u200B{ts_key} CPM"])], dim="model")
            freq_density_plot(axes[ts_key], ts_pr, historical_cpm_pr, title=f"{season}: Log density of {ts_key} samples and CPM precip", target_label="Historic CPM", grouping_key="model")
        plt.show()

    for season in ccs_seasons:
        IPython.display.display_html(f"<h5>{season}</h5>", raw=True)
        season_mask = source_ds["time.season"] == season
        # Build this season's historic CPM reference here, so that the x-axis quantiles
        # belong to the same season, and use the same quantile levels, as the model quantiles.
        historical_cpm_pr = cpm_pr.sel(time=historical_ts_mask & season_mask)
        quantiles = reasonable_quantiles(historical_cpm_pr)
        historical_cpm_pr_quantiles = historical_cpm_pr.quantile(quantiles, dim=pixel_dims)
        for model, model_ds in source_ds.groupby("model", squeeze=False):
            fig, axes = plt.subplot_mosaic(
                [TIME_PERIODS.keys()], figsize=(10.5, 3.5), constrained_layout=True
            )
            fig.suptitle(f"{model} quantile variability")
            for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
                ts_mask = (model_ds["time"] >= ts_start) & (model_ds["time"] <= ts_end)
                ts_pred_pr = model_ds.sel(time=ts_mask & season_mask)["pred_pr"].squeeze("model")
                ts_model_quantiles = ts_pred_pr.quantile(quantiles, dim=pixel_dims)

                qq_plot(
                    axes[ts_key],
                    historical_cpm_pr_quantiles,
                    ts_model_quantiles,
                    title=f"{ts_key} sample vs historic CPM",
                    grouping_key="sample_id",
                    alpha=0.5,
                    show_legend=False,
                    xlabel="historic CPM precip (mm day$^{-1}$)",
                )
            plt.show()

### Mean change

In [None]:
# Maps of the relative change (%) in seasonal mean precip of each non-historic period vs the
# historic period: the CPM change, each model's change, and each model's change minus the CPM change.
for source in merged_ds.keys():
    IPython.display.display_html(f"<h4>{source}</h4>", raw=True)

    for season in ccs_seasons:
        IPython.display.display_html(f"<h5>{season}</h5>", raw=True)
        season_mask = merged_ds[source]["time.season"] == season
        # Historic seasonal means, averaged over ensemble members and time (and samples for the models)
        cpm_pr_historical_mean = cpm_pr.sel(time=historical_ts_mask & season_mask).mean(dim=["ensemble_member", "time"])
        sample_historical_mean = merged_ds[source].sel(time=historical_ts_mask & season_mask)["pred_pr"].mean(dim=["ensemble_member", "time", "sample_id"])
        for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
            if ts_key == "historic": continue
            IPython.display.display_html(f"<h5>{ts_key}</h5>", raw=True)

            ts_mask = (merged_ds[source]["time"] >= ts_start) & (merged_ds[source]["time"] <= ts_end)
            
            # Relative changes in percent for the CPM and for each model, then model minus CPM
            cpm_pr_ts_mean = cpm_pr.sel(time=ts_mask & season_mask).mean(dim=["ensemble_member", "time"])
            cpm_ts_change = 100*(cpm_pr_ts_mean - cpm_pr_historical_mean)/cpm_pr_historical_mean
            
            sample_ts_mean = merged_ds[source].sel(time=ts_mask & season_mask)["pred_pr"].mean(dim=["ensemble_member", "time", "sample_id"])
            sample_ts_change = 100*(sample_ts_mean - sample_historical_mean)/sample_historical_mean
            sample_ts_diff = sample_ts_change - cpm_ts_change
            
            # Layout: a row of change maps (models + CPM) above a row of difference maps
            target_name = "CPM"
            models = merged_ds[source]["model"].values
            change_grid_spec = compute_gridspec(models, target_name)
            # NOTE(review): `models + " diff"` relies on elementwise str addition on a numpy array;
            # this works for object dtype, but for fixed-width "<U" dtype it needs NumPy >= 2 — confirm.
            diff_grid_spec = compute_gridspec(models + " diff", "")
            grid_spec = np.concatenate([change_grid_spec, diff_grid_spec])
            fig, axd = plt.subplot_mosaic(
                grid_spec,
                figsize=(grid_spec.shape[1] * 5.5, grid_spec.shape[0] * 5.5),
                subplot_kw=dict(projection=cp_model_rotated_pole),
                constrained_layout=True,
            )
            # The difference row was laid out with an empty target name, so its "" panel is hidden
            axd[""].axis('off')
            
            ax = axd[target_name]
            plot_map(
                cpm_ts_change,
                ax,
                title=f"$(\\mu_{{CPM}}^{{{ts_key}}} - \\mu_{{CPM}}^{{hist}})/\\mu_{{CPM}}^{{hist}}$",
                style="discreteBias",
                add_colorbar=False,
            )
            
            for model in models:
                ax = axd[model]
                
                pcm = plot_map(
                    sample_ts_change.sel(model=model),
                    ax,
                    title=f"{model}: $(\\mu_{{ML}}^{{{ts_key}}} - \\mu_{{ML}}^{{hist}})/\\mu_{{ML}}^{{hist}}$",
                    style="discreteBias",
                    add_colorbar=False,
                )
                
                ax = axd[model + " diff"]
                
                change_pcm = plot_map(
                    sample_ts_diff.sel(model=model),
                    ax,
                    title=f"{model}: $((\\mu_{{ML}}^{{{ts_key}}} - \\mu_{{ML}}^{{hist}})/\\mu_{{ML}}^{{hist}}) - ((\\mu_{{CPM}}^{{{ts_key}}} - \\mu_{{CPM}}^{{hist}})/\\mu_{{CPM}}^{{hist}})$",
                    style="discreteBias",
                    # center=0,
                    add_colorbar=False,
                )
            
            # One colorbar per row, built from the last model's mappable; every panel in a row
            # uses the same "discreteBias" style.
            fig.colorbar(pcm, ax=list(axd[model] for model in models)+[axd[target_name]], location="left", shrink=0.8, extend="both")
            fig.colorbar(change_pcm, ax=list(axd[model+" diff"] for model in models)+[axd[""]], location="left", shrink=0.8, extend="both")
            plt.show()

### Mean change variability

In [None]:
# Per-ensemble-member maps of how far each model's relative change in seasonal mean precip
# (non-historic period vs historic, in %) departs from the CPM's relative change (in %).
for source in merged_ds.keys():
    IPython.display.display_html(f"<h4>{source} variability</h4>", raw=True)

    for season in ccs_seasons:
        season_mask = merged_ds[source]["time.season"] == season
        cpm_pr_historical_mean = cpm_pr.sel(time=historical_ts_mask & season_mask).mean(dim=["time"])
        sample_historical_mean = merged_ds[source].sel(time=historical_ts_mask & season_mask)["pred_pr"].mean(dim=["time", "sample_id"])
        for ts_key, (ts_start, ts_end) in TIME_PERIODS.items():
            if ts_key == "historic":
                continue

            ts_mask = (merged_ds[source]["time"] >= ts_start) & (merged_ds[source]["time"] <= ts_end)

            cpm_pr_ts_mean = cpm_pr.sel(time=ts_mask & season_mask).mean(dim=["time"])
            # Percent, the same units as sample_ts_change below; without the factor of 100
            # the difference would subtract a fraction from a percentage.
            cpm_ts_change = 100*(cpm_pr_ts_mean - cpm_pr_historical_mean)/cpm_pr_historical_mean

            for model, model_pr in merged_ds[source]["pred_pr"].sel(time=ts_mask & season_mask).groupby("model", squeeze=False):
                model_historical_mean = sample_historical_mean.sel(model=model)
                model_sample_ts_mean = model_pr.squeeze("model").mean(dim=["time", "sample_id"])
                sample_ts_change = 100*(model_sample_ts_mean - model_historical_mean)/model_historical_mean
                sample_ts_diff = sample_ts_change - cpm_ts_change

                IPython.display.display_html(f"<h5>{ts_key} {season} {model}</h5>", raw=True)
                g = sample_ts_diff.plot.pcolormesh(col="ensemble_member", col_wrap=6, transform=cp_model_rotated_pole, subplot_kws=dict(projection=cp_model_rotated_pole), **STYLES["discreteBias"])
                for ax in g.axs.flatten():
                    ax.coastlines()

                plt.show()

## PSD

In [None]:
def _load_pr_examples(dataset_name, variable, split):
    """Load one precip variable of a dataset split, stacked into PSD examples.

    Reads ``$DERIVED_DATA/moose/nc-datasets/<dataset_name>/<split>.nc``. The
    variable is scaled by 3600*24 (presumably a per-second rate to a per-day
    rate) and labelled ``mm day-1``. Ensemble member and time are stacked into
    a single ``example`` dimension, ordered (example, grid_latitude,
    grid_longitude).

    Raises:
        RuntimeError: if the DERIVED_DATA environment variable is not set.
    """
    derived_data = os.getenv("DERIVED_DATA")
    if derived_data is None:
        raise RuntimeError("DERIVED_DATA environment variable is not set (expected via %dotenv)")
    path = os.path.join(derived_data, "moose", "nc-datasets", dataset_name, f"{split}.nc")
    # Load into memory inside the context manager so the netCDF file handle is released.
    # The scaling below would read the data into memory anyway.
    with xr.open_dataset(path) as dataset:
        pr = dataset[variable].load()
    return (
        (pr * 3600 * 24)
        .assign_attrs({"units": "mm day-1"})
        .stack(example=["ensemble_member", "time"])
        .transpose("example", "grid_latitude", "grid_longitude")
    )


# Low-resolution interpolated precip and high-resolution CPM precip used as PSD references
gcm_lr_lin_pr = _load_pr_examples(gcm_lr_lin_pr_dataset, "linpr", split)

cpm_hr_pr = _load_pr_examples(cpm_hr_pr_dataset, "target_pr", split)

### CPM

In [None]:
# PSD of the "CPM" source samples alongside the hi-res CPM and interpolated lo-res reference precip.
cpm_source_pred_pr = (
    merged_ds["CPM"]["pred_pr"]
    .stack(example=["ensemble_member", "sample_id", "time"])
    .transpose("model", "example", "grid_latitude", "grid_longitude")
)
fig, axd = plt.subplot_mosaic([["PSD"]], tight_layout=True)  # optionally figsize=(12, 12)
ax = axd["PSD"]
plot_psd(cpm_hr_pr, gcm_lr_lin_pr=gcm_lr_lin_pr, pred_pr=cpm_source_pred_pr, ax=ax)
plt.show()

### GCM

In [None]:
# PSD of the "GCM" source samples alongside the hi-res CPM and interpolated lo-res reference precip.
gcm_source_pred_pr = (
    merged_ds["GCM"]["pred_pr"]
    .stack(example=["ensemble_member", "sample_id", "time"])
    .transpose("model", "example", "grid_latitude", "grid_longitude")
)
fig, axd = plt.subplot_mosaic([["PSD"]], tight_layout=True)  # optionally figsize=(12, 12)
ax = axd["PSD"]
plot_psd(cpm_hr_pr, gcm_lr_lin_pr=gcm_lr_lin_pr, pred_pr=gcm_source_pred_pr, ax=ax)
plt.show()

## FSS

## Correlation