In [None]:
from copy import deepcopy
import logging
from pathlib import Path

logging.basicConfig(level=logging.WARNING)

import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl

from tqdm.notebook import tqdm

from hydra import compose, initialize
from hydra.utils import instantiate

from diffusion_nextsim.data import TrajectoryDataset
from diffusion_nextsim.decoder import StochasticDecoder, FFTSampler
from diffusion_nextsim.utils import estimate_crps_ens, get_fft_stats, estimate_spectrum

import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when one is available and fall back to the CPU otherwise,
# so that the notebook still executes on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "paper" and "wiley" are not built-in matplotlib styles; they have to be
# available in the local matplotlib style library. They are applied in this
# order, so "wiley" wins wherever both sheets set the same rcParam.
plt.style.use("paper")
plt.style.use("wiley")

# Load model

In [None]:
# Hydra overrides selecting the deterministic surrogate experiment and the
# checkpoint that is loaded further down.
surrogate_overrides = [
    "+experiments/deterministic=deterministic",
    "+computer=laputa",
    "network.n_embedding=64",
    "ckpt_path='../data/models/deterministic/deterministic/best.ckpt'",
]

# Compose the test configuration from the config directory at ../configs.
with initialize(
    version_base=None,
    config_path="../configs",
    job_name="predict_surrogate",
):
    cfg = compose(config_name="surrogate_test.yaml", overrides=surrogate_overrides)

In [None]:
# The configured surrogate is instantiated only to obtain its class; the
# trained model is then restored from the checkpoint and put into eval mode.
surrogate_cls = type(instantiate(cfg.surrogate))
model = surrogate_cls.load_from_checkpoint(cfg.ckpt_path, map_location=device)
model = model.eval()

# Load data

In [None]:
# Input files of the validation dataset: the trajectory store and the
# auxiliary dataset.
trajectory_store = "../data/nextsim/validation_regional.zarr"
auxiliary_file = "../data/auxiliary/ds_auxiliary_regional.nc"

# n_cycles=2: presumably each sample spans two model cycles (the cell below
# reads trajectory steps 0 and 1) -- TODO confirm against TrajectoryDataset.
dataset = TrajectoryDataset(trajectory_store, auxiliary_file, n_cycles=2)

In [None]:
# Sequential (unshuffled) loader over the validation dataset.
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)
# Materialise every batch once so later cells can iterate over an in-memory list.
data = list(dataloader)

# Cycle the model to get initial residuals

In [None]:
# One-step residuals (target state minus surrogate prediction), collected
# batch by batch and concatenated along the sample dimension at the end.
residuals = []

for batch in tqdm(data, total=len(data)):
    # First trajectory step as initial condition, first two forcing steps as input.
    init_state = batch["state_traj"][:, 0].to(device)
    forcing_pair = batch["forcing_traj"][:, :2].to(device)

    # Collapse all leading non-batch dimensions into the channel dimension
    # and stack states and forcings along it.
    state_channels = init_state.reshape(init_state.size(0), -1, *init_state.shape[-2:])
    forcing_channels = forcing_pair.reshape(forcing_pair.size(0), -1, *forcing_pair.shape[-2:])
    net_input = torch.cat((state_channels, forcing_channels), dim=-3)

    # All-zero conditioning labels with three entries per sample.
    labels = torch.zeros(
        init_state.size(0), 3, dtype=init_state.dtype, device=init_state.device
    )
    # NOTE(review): mask and mesh are moved to the device but not used in this
    # cell -- confirm whether these transfers are still needed.
    mask = batch["mask"].to(device)
    mesh = batch["mesh"].to(device)

    with torch.no_grad():
        network_output = model.network(model.encoder(net_input), labels=labels)
        prediction = model.decoder.to_prediction(network_output, init_state)

    # Compare against the second trajectory step on the CPU.
    target_state = batch["state_traj"][:, 1]
    residuals.append(target_state - prediction.cpu())
residuals = torch.cat(residuals, dim=0)

# Get cross-covariance

In [None]:
# Bring the second residual dimension to the front and flatten everything
# else, giving one row of samples for each of the five residual channels.
res_perts = residuals.transpose(0, 1).reshape(5, -1)
# Centre every row before estimating second moments.
res_perts = res_perts - res_perts.mean(dim=1, keepdims=True)

# Unbiased sample covariance between the five rows.
n_dof = res_perts.size(1) - 1
cross_cov = torch.matmul(res_perts, res_perts.T) / n_dof

# Standard deviations from the diagonal, then normalise rows and columns
# to obtain the correlation matrix.
res_std = torch.sqrt(torch.diagonal(cross_cov))
cross_corr = (cross_cov / res_std.unsqueeze(1)) / res_std.unsqueeze(0)

In [None]:
# Hand-specified correlation matrix: identity plus three symmetric
# off-diagonal pairs, keyed by (row, column).
modelled_corr = torch.eye(5)
off_diagonal = {(0, 1): 0.57, (0, 2): -0.05, (3, 4): -0.06}
for (row, col), rho in off_diagonal.items():
    modelled_corr[row, col] = rho
    modelled_corr[col, row] = rho

# Estimated standard deviations scaled by a fixed factor per channel.
std_scaling = torch.tensor([1.05, 1.35, 1.1, 1.02, 1.02])
modelled_std = res_std * std_scaling

# Covariance = correlation scaled by the standard deviations on both sides.
modelled_cov = (modelled_corr * modelled_std.unsqueeze(1)) * modelled_std.unsqueeze(0)

## Print the correlation tables as LaTeX

In [None]:
# Row and column labels used for the five-by-five LaTeX tables.
indexes = [
    "SIT",
    "SIC",
    "SID",
    "SIU",
    "SIV",
]

In [None]:
# Estimated correlations with the estimated standard deviations appended as
# an extra column; the transpose turns that column into the last table row.
# The sigma label is a raw string: "\s" is not a valid escape sequence and
# raises a SyntaxWarning on recent Python versions (the text is unchanged).
estimated_latex = pd.DataFrame(cross_corr.numpy(), index=indexes, columns=indexes)
estimated_sigma = pd.Series(res_std.numpy(), index=indexes).to_frame(r"$\sigma$")
estimated_latex = pd.concat((estimated_latex, estimated_sigma), axis=1).T
print(estimated_latex.round(2).to_latex(float_format="%.2f"))

In [None]:
# Correlations and standard deviations used for sampling, laid out like the
# table of estimated values: sigma is appended as a column and becomes the
# last row after the transpose.
# The sigma label is a raw string: "\s" is not a valid escape sequence and
# raises a SyntaxWarning on recent Python versions (the text is unchanged).
sampling_latex = pd.DataFrame(modelled_corr.numpy(), index=indexes, columns=indexes)
sampling_sigma = pd.Series(modelled_std.numpy(), index=indexes).to_frame(r"$\sigma$")
sampling_latex = pd.concat((sampling_latex, sampling_sigma), axis=1).T
print(sampling_latex.round(2).to_latex(float_format="%.2f"))

# Define and decompose FFT

In [None]:
# Average the residual spectra over all samples. The FFT is linear, so
# transforming this average back yields the sample-mean residual field.
res_fft = torch.fft.fft2(residuals).mean(dim=0)

# Decompose the mean field into a periodic and a smooth (static) component.
# Based on https://github.com/jacobkimmel/ps_decomp
res_avg = torch.fft.ifft2(res_fft).real
# Remove the spatial mean of every channel; it is kept separately in res_mean.
res_mean = res_avg.mean(dim=(-2, -1), keepdims=True)
res_avg = res_avg-res_mean

res_fft = torch.fft.fft2(res_avg)

# Grid size is taken from the data instead of being hard-coded.
n_rows, n_cols = res_avg.shape[-2:]

# V component: boundary image that is non-zero only on the border and holds
# the jump across the periodic wrap-around in each direction.
res_v = torch.zeros_like(res_avg)
res_v[:, 0, :] = res_avg[:, -1, :]-res_avg[:, 0, :]
res_v[:, -1, :] = res_avg[:, 0, :]-res_avg[:, -1, :]
# The column jumps are accumulated (+=): the four corner pixels belong to
# both a border row and a border column and must carry both contributions.
# A plain assignment here would overwrite the row contribution.
res_v[:, :, 0] += res_avg[:, :, -1]-res_avg[:, :, 0]
res_v[:, :, -1] += res_avg[:, :, 0]-res_avg[:, :, -1]
res_v_fft = torch.fft.fft2(res_v)

# smooth component: divide the boundary spectrum by the spectral
# denominator of the decomposition.
q = torch.arange(n_rows)[:, None]
r = torch.arange(n_cols)[None, :]
den = 2 * torch.cos(2*torch.pi*q/n_rows) + 2 * torch.cos(2*torch.pi*r/n_cols) - 4
# den vanishes at the zero frequency; guard the division there.
s = torch.where(den != 0, res_v_fft / den, 0.)
# Zero the DC coefficient of *every* channel. Indexing with s[0, 0] on the
# (channel, row, column) tensor would instead wipe the complete first
# frequency row of channel 0.
s[..., 0, 0] = 0.
res_smooth = torch.fft.ifft2(s).real
# Spectrum of the periodic component = full spectrum minus smooth spectrum.
res_periodic_fft = res_fft-s

In [None]:
# Static shift: the per-channel spatial mean plus the smooth component
# (res_mean broadcasts over the two spatial dimensions).
fft_shift = res_mean + res_smooth

## Plot FFTs

In [None]:
fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(5, 5*2/3), dpi=150)

# Same bare look for every panel: white background, no ticks, and no left,
# right or bottom spine.
for panel in ax.flat:
    panel.set_facecolor("white")
    panel.xaxis.set_visible(False)
    panel.set_yticks([])
    for side in ("left", "right", "bottom"):
        panel.spines[side].set_visible(False)

# One panel per channel, filled row by row; the sixth axes stays empty and
# is removed below.
# NOTE(review): res_fft is the output of torch.fft.fft2 and therefore
# complex-valued -- confirm which real quantity is meant to be shown here.
cell_edges = np.arange(65)
panel_labels = ("(a) SIT", "(b) SIC", "(c) SID", "(d) SIU", "(e) SIV")
for channel, (panel, label) in enumerate(zip(ax.flat, panel_labels)):
    panel.pcolormesh(
        cell_edges, cell_edges, res_fft[channel].numpy(), cmap="coolwarm"
    )
    panel.text(0.02, 0.98, label, ha="left", va="top", transform=panel.transAxes)

fig.delaxes(ax[1, 2])
fig.savefig("figures/fig_app_a1_stoch_fields.png")

# Write decoder into checkpoint

In [None]:
# Sampler built from the periodic spectrum, the static shift and the
# modelled covariance computed in the cells above.
fft_sampler = FFTSampler(res_periodic_fft, fft_shift, modelled_cov)

# Stochastic decoder wrapping the sampler; the remaining arguments are taken
# from the decoder section of the Hydra config.
decoder = StochasticDecoder(
    fft_sampler,
    cfg.decoder.mean,
    cfg.decoder.std,
    cfg.decoder.lower_bound,
    cfg.decoder.upper_bound,
)

In [None]:
# Start from the checkpoint the model was loaded from (cfg.ckpt_path) rather
# than a second hard-coded copy of the path, so the residual statistics and
# the stored weights cannot drift apart when the override is changed.
# NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
source_ckpt_path = Path(cfg.ckpt_path)
stochastic_ckpt = torch.load(source_ckpt_path, map_location="cpu")

# Remove every state-dict entry whose key starts with "decoder" ...
decoder_keys = [k for k in stochastic_ckpt["state_dict"].keys() if k.startswith("decoder")]
for k in decoder_keys:
    del stochastic_ckpt["state_dict"][k]

# ... and store the state of the stochastic decoder under the "decoder." prefix.
for k, v in decoder.state_dict().items():
    stochastic_ckpt["state_dict"][f"decoder.{k:s}"] = v

# Written next to the source checkpoint as "stochastic.ckpt".
torch.save(stochastic_ckpt, source_ckpt_path.with_name("stochastic.ckpt"))