In [None]:
from functools import partial
import types
import logging
logging.disable(logging.WARNING)

from copy import deepcopy

import xarray as xr
import numpy as np
from numpy import ma
import torch
import pandas as pd
from distributed import LocalCluster, Client

import lightning.pytorch as pl

from ddm_dynamical.scheduler import EDMSamplingScheduler, LinearScheduler
from ddm_dynamical.sampler import HeunSampler
from ddm_dynamical.parameterization import VParam
from diffusion_nextsim.data import TrajectoryDataset
from diffusion_nextsim.data.utils import get_mesh, estimate_rot2curv, rotate_uv2curv
from diffusion_nextsim.deformation import estimate_deform
from diffusion_nextsim.freedrift import OnlyAtmosphereModule

from tqdm.notebook import trange

from hydra import compose, initialize
from hydra.utils import instantiate

import matplotlib.pyplot as plt

In [None]:
# Allow TF32-style reduced-precision float32 matmuls for speed, and fix the
# torch RNG seed so the stochastic ensemble below is reproducible.
torch.set_float32_matmul_precision('high')
torch.manual_seed(42);

In [None]:
n_leadtime = 100
n_ens = 16

# Load data

In [None]:
# Regional nextsim test data (opened lazily from zarr) and the auxiliary
# fields of the same region (including the `mask` used further below).
ds_regional = xr.open_zarr(store="../data/nextsim/test_regional.zarr")
ds_aux = xr.open_dataset(filename_or_obj="../data/auxiliary/ds_auxiliary_regional.nc")

In [None]:
# Initial time of the roll-out; used as the lower bound of the time selection.
start_time = "2017-11-10 03:00"

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Hydra overrides selecting the deterministic surrogate, its network size and
# the checkpoint to evaluate, on top of surrogate_test.yaml.
deterministic_overrides = [
    "+experiments/deterministic=deterministic",
    "+computer=laputa",
    "network=uvit_s",
    "ckpt_path='../data/models/deterministic/deterministic/best.ckpt'",
]
with initialize(version_base=None, config_path="../configs", job_name="predict_surrogate"):
    cfg = compose(config_name="surrogate_test.yaml", overrides=deterministic_overrides)

In [None]:
# Select the configured state / forcing variables from `start_time` onwards and
# keep every `time_stride`-th time step, giving up to n_leadtime + 1 snapshots
# (the initial state plus one per lead time).
# `isel(time=...)` addresses the time dimension by name instead of relying on
# it being the leading axis, as positional `[...]` slicing would.
time_stride = 2
lead_window = slice(None, n_leadtime * time_stride + 1, time_stride)
ds_states = (
    ds_regional["state_data"]
    .sel(var_names_1=cfg.data.state_variables)
    .sel(time=slice(start_time, None))
    .isel(time=lead_window)
)
ds_forcings = (
    ds_regional["forcing_data"]
    .sel(var_names_2=cfg.data.forcing_variables)
    .sel(time=slice(start_time, None))
    .isel(time=lead_window)
)

In [None]:
# Save the selected nextsim states as the reference trajectory, split into one
# data variable per entry of `var_names_1`.
nextsim_reference = ds_states.to_dataset("var_names_1")
nextsim_reference.to_netcdf("data/consistency_nextsim.nc")

In [None]:
# Static model inputs moved to the device as float tensors.
# The mask gets two leading singleton dims and the mesh one leading dim
# (assumes ds_aux["mask"] is 2-D (y, x) - TODO confirm).
mask = torch.from_numpy(ds_aux["mask"].values).float().to(device)[None, None, ...]
mesh = torch.from_numpy(get_mesh(ds_aux, length_scale=1.)).float().to(device)[None, ...]

## Rotate the u/v forcings (uas, vas) to the curvilinear grid

In [None]:
# Rotate the uas/vas forcing components with rotate_uv2curv and rebuild the
# forcing stack in the order (tus, huss, rotated uas, rotated vas), laid out
# as (time, var_names_2, y, x).
# NOTE: this overwrites `ds_forcings` and drops the `var_names_2` labels, so
# the cell must run exactly once after the selection cell; re-running it fails
# on `.sel(var_names_2=...)`.
rot2curv_params = estimate_rot2curv(ds_aux)
forcing_parts = {
    name: ds_forcings.sel(var_names_2=name, drop=True)
    for name in ("tus", "huss", "uas", "vas")
}
rotated_wind = rotate_uv2curv(forcing_parts["uas"], forcing_parts["vas"], *rot2curv_params)
ds_forcings = (
    xr.concat(
        (forcing_parts["tus"], forcing_parts["huss"], *rotated_wind),
        dim="var_names_2",
    )
    .transpose("time", "var_names_2", "y", "x")
)

# Predict with the deterministic model

## Load model

In [None]:
# `instantiate` is only used to resolve the surrogate's class from the config;
# `load_from_checkpoint` then rebuilds the module and loads the trained weights.
surrogate_cls = type(instantiate(cfg.surrogate))
model = (
    surrogate_cls
    .load_from_checkpoint(cfg.ckpt_path, map_location=device, strict=False)
    .to(device)
    .eval()
)

## Predict

In [None]:
# Autoregressive roll-out of the deterministic surrogate: starting from the
# first selected state, each prediction is fed back as the next input together
# with the forcings at steps k-1 and k.
# The forcings are materialised on the device once, instead of re-evaluating
# the lazy xarray selection (and its rotation) at every step.
forcings_all = torch.from_numpy(ds_forcings.values).float().to(device)
predictions = [ds_states[0].values]
states = torch.from_numpy(predictions[0]).float().to(device)[None, None, ...]
with torch.no_grad():
    for k in trange(1, n_leadtime+1):
        forcings = forcings_all[k-1:k+1][None, ...]
        curr_pred = model(states, forcings, mask=mask, mesh=mesh).squeeze(0)
        predictions.append(curr_pred.cpu().numpy())
        # Feed the prediction back without a device -> host -> device round trip.
        states = curr_pred.float()[None, None, ...]

In [None]:
# Wrap the stacked roll-out in the coordinates/metadata of the matching
# nextsim states and save it alongside the reference.
det_array = np.stack(predictions)
det_predictions = ds_states[:n_leadtime+1].copy(data=det_array)
(
    det_predictions
    .to_dataset("var_names_1")
    .to_netcdf("data/consistency_deterministic.nc")
)

# Predict with the diffusion model

## Load model

In [None]:
# Hydra overrides for the residual diffusion surrogate: Heun sampler, large
# UViT network, the deterministic checkpoint it builds on and its own checkpoint.
diffusion_overrides = [
    "+experiments/diffusion=residual",
    "+computer=laputa",
    "sampler=heun",
    "network=uvit_l",
    "surrogate.ckpt_det='../data/models/deterministic/deterministic/best.ckpt'",
    "ckpt_path='../data/models/diffusion/resdiff_l_exp/best.ckpt'",
]
with initialize(version_base=None, config_path="../configs", job_name="predict_surrogate"):
    cfg_diff = compose(config_name="surrogate_test.yaml", overrides=diffusion_overrides)

In [None]:
# Build the residual diffusion surrogate from the config and load the trained
# weights from the `state_dict` entry of its checkpoint.
model_diff = instantiate(cfg_diff.surrogate)
# Load on the CPU first: the checkpoint can hold more than the weights, and
# only `state_dict` is needed before the model is moved to `device`.
# NOTE(review): recent PyTorch releases default `torch.load` to
# weights_only=True, which may reject checkpoints with pickled
# hyper-parameters - confirm against the installed torch version.
checkpoint = torch.load(cfg_diff.ckpt_path, map_location="cpu")
load_report = model_diff.load_state_dict(checkpoint["state_dict"], strict=False)
del checkpoint
# strict=False tolerates key mismatches; surface them instead of discarding them.
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"model_diff checkpoint: {len(load_report.missing_keys)} missing, "
        f"{len(load_report.unexpected_keys)} unexpected keys (strict=False)"
    )
model_diff = model_diff.to(device).eval()

## Predict

In [None]:
# Ensemble roll-out with the diffusion surrogate: all n_ens members start from
# the same initial state and are advanced together as one batch, each step
# using the forcings at steps k-1 and k.
# The forcings are materialised on the device once instead of re-evaluating
# the lazy xarray selection at every step.
forcings_all = torch.from_numpy(ds_forcings.values).float().to(device)
# Ensemble-sized view of the mask under its own name, so the global batch-1
# `mask` is not overwritten and the deterministic cells stay re-runnable.
mask_ens = mask.expand(n_ens, -1, -1, -1)
# NOTE(review): the deterministic roll-out passes the real `mesh` from
# get_mesh(), whereas this call passes an all-ones tensor shaped like the mask.
# Kept as in the original - confirm this is intended for the diffusion model.
mesh_ens = torch.ones_like(mask_ens)

# np.repeat (rather than np.broadcast_to) yields a writable array, so
# torch.from_numpy does not warn about a read-only input.
pred_diff = [np.repeat(ds_states[0].values[None, ...], n_ens, axis=0)]
states = torch.from_numpy(pred_diff[0]).float().to(device)[:, None, ...]
with torch.no_grad():
    for k in trange(1, n_leadtime+1):
        # Same forcing window for every ensemble member (expanded view, no copy).
        forcings = forcings_all[k-1:k+1][None, ...].expand(n_ens, -1, -1, -1, -1)
        curr_pred = model_diff(states, forcings, mask=mask_ens, mesh=mesh_ens)
        pred_diff.append(curr_pred.cpu().numpy())
        # Feed the prediction back as the next input, kept on the device.
        states = curr_pred.float()[:, None, ...]

In [None]:
# Insert an `ens` dimension after time into the nextsim state metadata, fill it
# with the stacked ensemble roll-out and save it.
ens_members = np.arange(n_ens)
diff_array = np.stack(pred_diff)
diff_predictions = (
    ds_states
    .expand_dims(ens=ens_members, axis=1)
    .copy(data=diff_array)
)
(
    diff_predictions
    .to_dataset("var_names_1")
    .to_netcdf("data/consistency_diffusion.nc")
)