In [None]:
import numpy as np
import torch
import xarray as xr
import pandas as pd
from distributed import LocalCluster, Client
import wandb

from diffusion_nextsim.data.utils import get_mesh, estimate_rot2curv, rotate_uv2curv
from diffusion_nextsim.freedrift import OnlyAtmosphereModule, SeaIceVelocityModule, FreedriftModel

In [None]:
# Start a local Dask cluster with 32 worker processes and connect a client
# to it, so that the lazy xarray/dask computations below run in parallel.
# The bare `client` expression renders the client summary as the cell output.
cluster = LocalCluster(n_workers=32)
client = Client(cluster)
client

# Load data

In [None]:
# Auxiliary dataset (used further below to build the mesh and to estimate
# the rotation of the wind vectors).
ds_aux = xr.open_dataset("../data/auxiliary/ds_auxiliary_regional.nc")

# Test dataset, opened lazily from the zarr store.
ds_test = xr.open_zarr("../data/nextsim/test_regional.zarr")

In [None]:
# Build forecast trajectories from the test data:
#   1. a rolling window of 61 consecutive time steps becomes a new
#      "lead_time" dimension,
#   2. only every second entry of the window is kept (31 lead times),
#   3. time steps with missing values (e.g. incomplete windows at the start
#      of the series) are dropped,
#   4. dimensions are ordered as (time, lead_time, ..., y, x).
ds_rolled = (
    ds_test
    .rolling(time=61)
    .construct("lead_time")
    .isel(lead_time=slice(None, None, 2))
    .dropna("time")
    .transpose("time", "lead_time", ..., "y", "x")
)

In [None]:
# Initial conditions: the state at the first lead time of every trajectory.
initial = ds_rolled["state_data"].isel(lead_time=0)

In [None]:
# Mesh derived from the auxiliary dataset; passed unchanged to the model in
# `looping_func`. NOTE(review): the meaning/units of `length_scale=1.` are
# defined by `get_mesh` -- confirm there.
mesh = get_mesh(ds_aux, length_scale=1.)

In [None]:
# Forcing channels from index 2 onwards; the first two of them are used as
# the wind components (presumably u and v -- confirm against the dataset).
wind_raw = ds_rolled["forcing_data"][:, :, 2:]

# Rotate the two components onto the curvilinear model grid, using the
# rotation estimated from the auxiliary dataset.
wind_rotated = rotate_uv2curv(
    wind_raw[:, :, 0],
    wind_raw[:, :, 1],
    *estimate_rot2curv(ds_aux)
)

# Stack the rotated components along "var_names" and restore the
# (time, lead_time, var_names, y, x) layout expected by `log_experiment`.
wind_forcing = xr.concat(wind_rotated, dim="var_names").transpose(
    "time", "lead_time", "var_names", "y", "x"
)

In [None]:
# State channels from index 3 onwards, taken from the test data itself and
# used as "perfect" velocity forcing. The variable dimension is renamed to
# "var_names" so that it matches the forcing core dims in `log_experiment`.
ice_velocity = (
    ds_rolled["state_data"][:, :, 3:]
    .rename({"var_names_1": "var_names"})
    .transpose("time", "lead_time", "var_names", "y", "x")
)

# Define looping and logging functions

In [None]:
def looping_func(state, forcing, mesh, model):
    """Roll a model forward in time from a single initial state.

    Parameters
    ----------
    state : np.ndarray
        Initial state with the variables on the first axis. At least four
        channels are required: channels 0-2 are post-processed after every
        step and channels 3 and onwards are overwritten with the output of
        ``model.velocity_module``.
    forcing : np.ndarray
        Forcing with the lead time on the first axis. Step ``k`` receives
        the two consecutive slices ``forcing[k:k+2]``.
    mesh : np.ndarray
        Mesh passed unchanged (as float tensor) to every model call.
    model : callable
        Called as ``model(state, forcing_pair, mesh)`` and expected to
        return the next state as tensor. Must expose ``velocity_module``.

    Returns
    -------
    np.ndarray
        Trajectory stacked along a new first axis: the initial state
        followed by one entry per step, ``forcing.shape[0]`` entries in
        total.
    """
    predictions = [state]
    # Work on copies so that the arrays handed in by the caller are never
    # modified by the in-place post-processing below.
    curr_state = torch.from_numpy(np.copy(state)).float()
    curr_forcing = torch.from_numpy(np.copy(forcing)).float()
    mesh_tensor = torch.from_numpy(mesh).float()
    n_steps = curr_forcing.size(0) - 1
    with torch.no_grad():
        for k in range(n_steps):
            curr_state = model(curr_state, curr_forcing[k:k+2], mesh_tensor)
            # Post-processing: channel 1 is bounded to [0, 1]. Where it
            # exceeded 1, channels 0 and 2 are scaled down by the same
            # factor (the 1E-7 guards against a division by zero).
            # NOTE(review): channel 1 is presumably the sea-ice
            # concentration -- confirm against the dataset.
            correction = curr_state[1].clamp(min=0, max=1)/(curr_state[1].clamp(min=0)+1E-7)
            curr_state[0] = (curr_state[0] * correction).clamp(min=0)
            curr_state[1] = curr_state[1].clamp(min=0, max=1)
            curr_state[2] = (curr_state[2] * correction).clamp(min=0, max=1)
            # The velocity channels are diagnosed from the forcing valid at
            # the end of the step.
            curr_state[3:] = model.velocity_module(curr_forcing[k+1])
            # `Tensor.numpy()` shares memory with the tensor. Store a
            # snapshot, otherwise a model that updates its state in place
            # would silently overwrite all earlier predictions.
            predictions.append(curr_state.clone().numpy())
    return np.stack(predictions, axis=0)

In [None]:
def log_experiment(model, forcing, exp_name):
    """Run one forecast experiment on the test data and log its scores.

    The forecast is started from the notebook-level ``initial`` states,
    rolled out with ``looping_func`` on the notebook-level ``mesh`` and
    compared to the notebook-level ``ds_rolled["state_data"]``. These three
    names must be defined before this function is called.

    Parameters
    ----------
    model : callable
        Model handed to ``looping_func``.
    forcing : xr.DataArray
        Forcing with the dimensions ``lead_time``, ``var_names``, ``y`` and
        ``x``; all remaining dimensions are looped over together with
        ``initial``.
    exp_name : str
        Name of the Weights & Biases run.

    Returns
    -------
    None
        The scores are logged as the table ``test/scores`` with one row per
        lead time and the columns ``rmse_<var>``, ``mae_<var>`` and
        ``iterations``.
    """
    # The run is used as context manager such that it is always finished,
    # also if the prediction or the scoring raises an exception.
    with wandb.init(
        dir="/tmp/wandb",
        project="test_diffusion_nextsim_regional",
        entity="tobifinn",
        name=exp_name
    ) as run:
        # get lazy prediction, evaluated once and kept in the cluster memory
        prediction = xr.apply_ufunc(
            looping_func,
            initial,
            forcing,
            input_core_dims=[["var_names_1", "y", "x"], ["lead_time", "var_names", "y", "x"]],
            output_core_dims=[["lead_time", "var_names_1", "y", "x"]],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[float],
            kwargs={"mesh": mesh, "model": model},
            dask_gufunc_kwargs={"allow_rechunk": True}
        ).persist()

        # estimate errors, averaged over all samples and grid points
        error = prediction-ds_rolled["state_data"]

        mae = np.abs(error).mean(["time", "y", "x"])
        mae = mae.compute().to_pandas()
        rmse = np.sqrt((error**2).mean(["time", "y", "x"]))
        rmse = rmse.compute().to_pandas()
        # Each table is renamed from its own columns, such that the naming
        # does not depend on the order of these two statements.
        mae.columns = [f"mae_{c:s}" for c in mae.columns]
        rmse.columns = [f"rmse_{c:s}" for c in rmse.columns]

        scores = pd.concat((rmse, mae), axis=1)
        # The index holds the lead time, logged as its own column
        scores['iterations'] = scores.index

        # log errors
        wb_scores = wandb.Table(dataframe=scores)
        run.log({"test/scores": wb_scores})

# Run experiments

In [None]:
# Free drift driven by the rotated wind forcing, interp_mode="nearest".
model = FreedriftModel(
    OnlyAtmosphereModule(),
    dt_model=1200,
    dt_forcing=12*3600,
    interp_mode="nearest",
)
log_experiment(model, wind_forcing, "freedrift_nearest")

In [None]:
# Free drift driven by the rotated wind forcing, interp_mode="linear".
model = FreedriftModel(
    OnlyAtmosphereModule(),
    dt_model=1200,
    dt_forcing=12*3600,
    interp_mode="linear",
)
log_experiment(model, wind_forcing, "freedrift_linear")

In [None]:
# Free drift driven by the rotated wind forcing, interp_mode="cubic".
model = FreedriftModel(
    OnlyAtmosphereModule(),
    dt_model=1200,
    dt_forcing=12*3600,
    interp_mode="cubic",
)
log_experiment(model, wind_forcing, "freedrift_cubic")

In [None]:
# Reference experiment: the velocities are taken from the test data itself
# (`ice_velocity`) instead of being derived from the wind,
# interp_mode="linear".
model = FreedriftModel(
    SeaIceVelocityModule(),
    dt_model=1200,
    dt_forcing=12*3600,
    interp_mode="linear",
)
log_experiment(model, ice_velocity, "freedrift_perfect")

In [None]:
# Reference experiment with the velocities taken from the test data itself
# (`ice_velocity`), interp_mode="cubic".
model = FreedriftModel(
    SeaIceVelocityModule(),
    dt_model=1200,
    dt_forcing=12*3600,
    interp_mode="cubic",
)
log_experiment(model, ice_velocity, "freedrift_perfect_cubic")