In [None]:
from functools import partial
import logging
logging.disable(logging.WARNING)

from copy import deepcopy

import numpy as np
import xarray as xr
import torch
from torch.utils.data import DataLoader

import lightning.pytorch as pl

from k_diffusion.sampling import *
from ddm_dynamical.scheduler import EDMSamplingScheduler, LinearScheduler, BinarizedScheduler
from ddm_dynamical.sampler import KDiffusionSampler
from ddm_dynamical.parameterization import VParam
from ddm_dynamical.weighting import ExponentialWeighting
from ddm_dynamical.utils import normalize_gamma

from diffusion_nextsim.data import TrajectoryDataset
from diffusion_nextsim.surrogate.diffusion import residual_preprocessing
from diffusion_nextsim.utils import estimate_crps_ens, estimate_crps_gauss, get_fft_stats, estimate_spectrum

from hydra import compose, initialize
from hydra.utils import instantiate

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import matplotlib.gridspec as mpl_gs
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Apply the project's matplotlib styles in order; "wiley" is applied after
# "paper" and therefore wins for any rcParams both of them set.
plt.style.use(["paper", "wiley"])

In [None]:
# Seed torch's RNGs (this also drives the DataLoader shuffling and the noise draw below).
torch.manual_seed(42)
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Permit faster reduced-precision float32 matmul algorithms where supported.
torch.set_float32_matmul_precision('high')

# Load data

In [None]:
# Training trajectories from the regional zarr store, paired with the regional
# auxiliary NetCDF file. The meaning of n_cycles is defined by TrajectoryDataset
# (presumably the number of steps per trajectory — TODO confirm).
train_dataset = TrajectoryDataset(
    "../data/nextsim/train_regional.zarr",
    "../data/auxiliary/ds_auxiliary_regional.nc",
    n_cycles=2,
)

In [None]:
# One shuffled batch (up to 1024 samples) is enough for the loss curves below.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1024,
    shuffle=True,
)
train_data = next(iter(train_loader))

In [None]:
# Validation trajectories, built exactly like the training set but from the
# validation zarr store.
val_dataset = TrajectoryDataset(
    "../data/nextsim/validation_regional.zarr",
    "../data/auxiliary/ds_auxiliary_regional.nc",
    n_cycles=2,
)

In [None]:
# A single shuffled validation batch (up to 1024 samples), mirroring train_data.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1024,
    shuffle=True,
)
val_data = next(iter(val_loader))

# Load model

In [None]:
# Compose the surrogate test configuration with the diffusion experiment,
# the uvit_l network and the checkpoint to evaluate.
with initialize(
    version_base=None,
    config_path="../configs",
    job_name="predict_surrogate",
):
    cfg = compose(
        config_name="surrogate_test.yaml",
        overrides=[
            "+experiments/diffusion=diffusion",
            "+computer=laputa",  # presumably machine-specific settings — TODO confirm
            "network=uvit_l",
            "ckpt_path='../data/models/diffusion/diff_l_exp/best.ckpt'",
        ],
    )

In [None]:
# Build the surrogate from the config, then restore the trained weights stored
# under the checkpoint's "state_dict" key (strict key matching, the default).
model = instantiate(cfg.surrogate)
_ = model.load_state_dict(
    torch.load(cfg.ckpt_path, map_location=device)["state_dict"]
)
# Move to the evaluation device and switch to eval mode for inference.
model = model.to(device)
model = model.eval()

# Evaluate the weighted diffusion loss across log signal-to-noise levels

In [None]:
@torch.no_grad()
def get_error(data, noise, gamma):
    """Compute the weighted diffusion error of ``model`` on one batch.

    The target is the increment between the last two states of each
    trajectory, standardized with ``model.decoder.mean`` / ``model.decoder.std``.
    It is diffused with ``noise`` at log signal-to-noise ratio ``gamma``
    (variance preserving: ``alpha**2 = sigmoid(gamma)``,
    ``sigma**2 = 1 - alpha**2``, hence ``alpha**2 / sigma**2 == exp(gamma)``).
    ``model.network`` is conditioned on the encoded earlier states and forcings.
    Its prediction is scored with ``model.param.estimate_errors`` and scaled by
    ``model.weighting(gamma)``. Uses the notebook-level ``model`` and
    ``device``. Gradients are disabled.

    Args:
        data: Batch dict from the TrajectoryDataset loaders with keys
            "state_traj", "forcing_traj", "labels" and "mesh".
        noise: Gaussian noise with exactly the residual target's shape. Reusing
            the same tensor across calls compares levels/splits under
            identical noise.
        gamma: Log signal-to-noise ratio (a 0-d tensor in this notebook).

    Returns:
        The weighted error tensor, moved to the CPU. Its shape is whatever
        ``estimate_errors`` returns. The callers average it over dims
        (0, 2, 3), so it is presumably per element of the residual — TODO
        confirm.

    Raises:
        ValueError: If ``noise.shape`` differs from the residual's shape.
            Without this check the mismatch would surface as an opaque
            broadcast error, or silently broadcast a smaller batch.
    """
    # Conditioning input: all states but the last, plus the forcings. For each,
    # every dim between the batch dim and the last two is folded into one
    # channel dim. reshape == view here, but it also tolerates non-viewable
    # layouts.
    state_in = data["state_traj"][:, :-1].to(device)
    forcing_in = data["forcing_traj"].to(device)
    in_tensor = torch.cat((
        state_in.reshape(state_in.size(0), -1, *state_in.shape[-2:]),
        forcing_in.reshape(forcing_in.size(0), -1, *forcing_in.shape[-2:])
    ), dim=-3)

    # Target: last-step increment, standardized with the decoder statistics.
    residual = (data["state_traj"][:, -1] - data["state_traj"][:, -2]).to(device)
    residual = (residual - model.decoder.mean) / model.decoder.std
    if noise.shape != residual.shape:
        raise ValueError(
            f"noise shape {tuple(noise.shape)} does not match residual shape "
            f"{tuple(residual.shape)}; draw noise matching this batch."
        )

    ## Diffuse model: forward diffusion at log-SNR gamma.
    alpha_sq = torch.sigmoid(gamma)
    alpha = alpha_sq.sqrt()
    sigma = (1 - alpha_sq).sqrt()
    noised_residual = alpha * residual + sigma * noise

    ## Estimate prediction with diffusion model, conditioned on the encoded input.
    encoded = model.encoder(in_tensor)
    network_input = torch.cat(
        (noised_residual, encoded), dim=1
    )
    normalized_gamma = normalize_gamma(
        gamma, model.gamma_min, model.gamma_max
    ).view(-1, 1)
    prediction = model.network(
        network_input,
        normalized_gamma=normalized_gamma,
        labels=data["labels"].to(device),
        mesh=data["mesh"].to(device)
    )

    ## Estimate loss: parameterization-specific error, scaled by the loss weighting.
    error_diffusion = model.param.estimate_errors(
        prediction,
        in_data=noised_residual,
        target=residual,
        noise=noise,
        alpha=alpha,
        sigma=sigma,
        gamma=gamma,
    )
    weighted_error = model.weighting(gamma) * error_diffusion
    return weighted_error.cpu()

In [None]:
# One fixed draw of Gaussian noise, reused for every gamma level and both splits.
# Its shape is taken from one state snapshot of the training batch (the shape
# of the residual target in get_error) rather than hard-coding the batch size,
# channel count and grid size. The draw is identical to the hard-coded form
# when those agree.
noise = torch.randn(train_data["state_traj"][:, -1].shape, device=device)

In [None]:
# 101 evenly spaced log-SNR levels from -20 to 20 inclusive (step 0.4).
gamma_levels = torch.linspace(start=-20, end=20, steps=101, device=device)

In [None]:
# Weighted training error per gamma level, averaged over dims (0, 2, 3):
# one row per level, one column per remaining dim-1 entry (variable).
error_train = torch.stack([
    get_error(train_data, noise, gamma).mean(dim=(0, 2, 3))
    for gamma in tqdm(gamma_levels)
])

In [None]:
# Same curve on the validation batch, with the same noise draw as for training.
error_val = torch.stack([
    get_error(val_data, noise, gamma).mean(dim=(0, 2, 3))
    for gamma in tqdm(gamma_levels)
])

In [None]:
# Weighted diffusion error against log-SNR. Thin dashed lines show the individual
# variables (columns of error_*), solid lines their mean, for train and validation.
log_snr = gamma_levels.cpu().numpy()
fig, ax = plt.subplots(figsize=(4, 2))
ax.grid(alpha=0.5)
plt_train_var = ax.plot(log_snr, error_train.numpy(), c="C1", ls="--", label="Train variable", lw=0.7, alpha=0.7)
plt_train_mean, = ax.plot(log_snr, error_train.mean(dim=1).numpy(), c="firebrick", label="Train mean")
plt_val_var = ax.plot(log_snr, error_val.numpy(), c="0.5", ls="--", label="Val variable", lw=0.7, alpha=0.7)
plt_val_mean, = ax.plot(log_snr, error_val.mean(dim=1).numpy(), c="black", ls="-", label="Val mean")
# Hand-placed arrow in axes coordinates; it only lines up with the fixed limits below.
ax.annotate(
    "Velocities", xy=(0.62, 0.125), xytext=(0.7, 0.25),
    xycoords=ax.transAxes,
    arrowprops=dict(
        facecolor='black', width=0.1, headwidth=3, headlength=3,
        zorder=100
    ),
    bbox=dict(boxstyle='square,pad=0', fc='#ffffff99', ec='none')
)
ax.set_ylabel("Weighted error")
ax.set_ylim(0, 0.145)
# Raw string: "\l" is not a valid Python escape (SyntaxWarning on >= 3.12);
# the rendered text is unchanged.
ax.set_xlabel(r"Log signal-to-noise ratio $\lambda(\tau)$")
ax.set_xlim(-20, 20)
# One legend entry per group; the labels come from the handles themselves.
ax.legend(handles=[plt_train_var[0], plt_train_mean, plt_val_var[0], plt_val_mean])
fig.savefig("figures/fig_app_b4_diff_error.png", dpi=300)