# Raw diffusion model output

This file of samples from the diffusion model includes the raw model output values (`raw_pred`) as well as the `pred_pr` values obtained by applying the inverse of the target transform.

In [None]:
%reload_ext autoreload

%autoreload 2

%reload_ext dotenv
%dotenv

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from mlde_utils import samples_path, workdir_path, dataset_split_path
from mlde_utils.transforms import load_transform

from mlde_notebooks.data import si_to_mmday

In [None]:
# Which model, checkpoint and sample file to inspect.
model_id = "score-sde/subvpsde/ukcp18_12em_cncsnpp_continuous/paper-12em"
dataset = "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
split = "test"
ensemble_member = "01"
checkpoint = "epoch-20"
input_xfm = "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-stan"
predictions_filename = "predictions-RSAKzVXxTGynBkwrstiyVL.nc"

# Directory holding the generated samples for this model/checkpoint/dataset combination.
samples_dir = samples_path(
    workdir_path(model_id),
    checkpoint=checkpoint,
    dataset=dataset,
    input_xfm=input_xfm,
    split=split,
    ensemble_member=ensemble_member,
)
# One sample file: contains raw_pred (direct model output) and pred_pr.
pred_with_raw_ds = xr.open_dataset(samples_dir / predictions_filename)

# Simulation data for the same split. Selecting with a one-element list keeps
# ensemble_member as a dimension (presumably to match the samples' layout).
sim_ds = xr.open_dataset(dataset_split_path(dataset, split)).sel(ensemble_member=[ensemble_member])

# Merge predictions and simulation variables onto shared coordinates.
ds = xr.combine_by_coords([pred_with_raw_ds, sim_ds], data_vars="minimal")

ds

In [None]:
# Target transform saved alongside the model; used below to map between the model's
# raw output space and precipitation. It is a pickle, so only load it from a trusted workdir.
target_xfm_path = workdir_path(model_id) / "transforms" / "target.pickle"
target_xfm = load_transform(target_xfm_path)

## Distribution

Below are three histograms: (1) the full distribution of values coming directly from the diffusion model; (2) the distribution of values below -1, which are physically impossible; and (3) the values in a window spread equally around -1, bounded below by the most negative value produced by the diffusion model.

In [None]:
def plot_hist(values, bins, title):
    """Plot a histogram of an xarray DataArray on a fresh figure with the given title."""
    fig, ax = plt.subplots()
    values.plot.hist(bins=bins, ax=ax)
    ax.set_title(title)
    plt.show()


# Raw model values below this threshold are treated as physically impossible.
threshold = -1

raw_pred = ds["raw_pred"]
# Mask the whole dataset (not just raw_pred) so later cells can compare the bad raw
# values with other variables at the same points; all other points become NaN.
bad_points = ds.where(raw_pred < threshold)

nbad_vals = bad_points["raw_pred"].count()
# np.product was deprecated and removed in NumPy 2.0; .size is the element count.
total_vals = raw_pred.size

print(f"{nbad_vals.item()} values below {threshold} out of {total_vals} values")

plot_hist(raw_pred, bins=100, title="All raw_pred values")

plot_hist(bad_points["raw_pred"], bins=50, title=f"raw_pred values below {threshold}")

# Window symmetric about the threshold: from the most negative value up to that
# value reflected in the threshold, i.e. threshold - (min - threshold).
raw_min = raw_pred.min()
window_upper = threshold - (raw_min - threshold)
plot_hist(
    raw_pred.where(raw_pred <= window_upper),
    bins=50,
    title=f"raw_pred values in a window symmetric about {threshold}",
)

### Quantiles

In [None]:
# Quantiles of the below-threshold values: the minimum, the extreme low tail and the median.
bad_quantile_levels = [0, 0.01, 0.1, 0.5]
bad_points["raw_pred"].quantile(bad_quantile_levels)

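Reflect the bad values in -1 (i.e. place each one the same distance above -1 as it was below -1), which moves them into the valid space. Then use the inverse of the target transform to convert quantiles of the reflected values into precipitation. Note that the upper quantiles of the reflected values correspond to the most extreme (lowest) bad values.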

In [None]:
# Reflect each bad value in -1: a value d below -1 becomes d above -1.
reflected_bad_raw_pred = -1 - (1 + bad_points["raw_pred"])
# Upper quantiles of the reflected values correspond to the lowest original bad values.
reflected_quantiles = reflected_bad_raw_pred.quantile([0.5, 0.9, 0.99, 1])
# Map back through the inverse target transform, then convert to mm/day via si_to_mmday.
reflected_pr = target_xfm.invert({"target_pr": reflected_quantiles})["target_pr"]
si_to_mmday(reflected_pr)

## Relationship with CPM values

In [None]:
# plt.scatter flattens its inputs, so x and y must share the same dimension order or
# the points get mis-paired. Dataset.where can broadcast variables that lacked some of
# the condition's dims (adding them in a different position), and xr.broadcast may
# also prepend missing dims. Transposing to raw_pred's order makes the pairing explicit.
bad_raw = bad_points["raw_pred"]
_, bad_target = xr.broadcast(bad_raw, bad_points["target_pr"])
bad_target = bad_target.transpose(*bad_raw.dims)

# CPM target mapped into the same (transformed) space as the raw model output.
_, bad_target_xfm = xr.broadcast(bad_raw, target_xfm.transform(bad_points)["target_pr"])
bad_target_xfm = bad_target_xfm.transpose(*bad_raw.dims)

fig, ax = plt.subplots()
ax.scatter(bad_raw.values.ravel(), bad_target.values.ravel(), alpha=0.1)
ax.set(xlabel="raw_pred", ylabel="target_pr", title="Bad raw_pred vs CPM target_pr")
plt.show()

fig, ax = plt.subplots()
ax.scatter(bad_raw.values.ravel(), bad_target_xfm.values.ravel(), alpha=0.1)
ax.set(
    xlabel="raw_pred",
    ylabel="transformed target_pr",
    title="Bad raw_pred vs transformed CPM target_pr",
)
plt.show()