
# NerF + QG Loss

The full QG equation is given by:

$$
\begin{aligned}
\partial_t q + \det \boldsymbol{J}(q, \psi) &= 0
\end{aligned}
$$

where:

* $q=\nabla^2 \psi$
* $\det \boldsymbol{J}(q, \psi)=\partial_x q\partial_y\psi - \partial_y q\partial_x\psi$.

We are interested in finding a neural field (NerF) method that takes the spatio-temporal coordinates, $\mathbf{x}_\phi$, as input and outputs a vector containing the potential vorticity (PV), $q$, and the stream function, $\psi$, i.e. $\mathbf{y}_\text{obs}$.

$$
\mathbf{y}_\text{obs} = \boldsymbol{f_\theta}(\mathbf{x}_\phi) + \epsilon, \hspace{5mm}\epsilon \sim \mathcal{N}(0, \sigma^2)
$$

We use a SIREN network, which is a fully connected neural network with the $\sin$ activation function.

* **Data Inputs**: `256x256x11`
* **Data Outputs**: `2`


In [None]:
import os
import sys

from pyprojroot import here

# Opt in to PyTorch's MPS fallback; set here, ahead of the `torch` import in
# the following cell.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Resolve the repository root (marker file `.root`) and the experiment
# folder (marker file `.local`).
root = here(project_files=[".root"])
exp = here(
    project_files=[".local"],
    relative_project_path=root.joinpath("experiments/dc21a"),
)

# Make both locations importable.
sys.path.extend([str(root), str(exp)])

In [None]:
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import xarray as xr

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from ml_collections import config_dict
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from inr4ssh._src.datamodules.osse_2020a import AlongTrackDataModule

pl.seed_everything(123)

import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from loguru import logger

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# from ml_collections import config_dict

# cfg = config_dict.ConfigDict()

# # logging args
# cfg.log = config_dict.ConfigDict()
# cfg.log.mode = "online" #"disabled"
# cfg.log.project ="inr4ssh"
# cfg.log.entity = "ige"
# cfg.log.log_dir = "/Users/eman/code_projects/logs/"
# cfg.log.resume = False

# # data args
# cfg.data = config_dict.ConfigDict()
# cfg.data.data_dir =  f"/Users/eman/code_projects/torchqg/data/qgsim_simple_128x128.nc"

# # preprocessing args
# cfg.pre = config_dict.ConfigDict()
# cfg.pre.noise = 0.01
# cfg.pre.dt = 1.0
# cfg.pre.time_min = 500
# cfg.pre.time_max = 511
# cfg.pre.seed = 123

# # train/test args
# cfg.split = config_dict.ConfigDict()
# cfg.split.train_prct = 0.9

# # dataloader args
# cfg.dl = config_dict.ConfigDict()
# cfg.dl.batchsize_train = 2048
# cfg.dl.batchsize_val = 1_000
# cfg.dl.batchsize_test = 5_000
# cfg.dl.batchsize_predict = 10_000
# cfg.dl.num_workers = 0
# cfg.dl.pin_memory = False

# # loss arguments
# cfg.loss = config_dict.ConfigDict()
# cfg.loss.qg = True
# cfg.loss.alpha = 1e-4

# # optimizer args
# cfg.optim = config_dict.ConfigDict()
# cfg.optim.warmup = 10
# cfg.optim.num_epochs = 100
# cfg.optim.learning_rate = 1e-4

# # trainer args
# cfg.trainer = config_dict.ConfigDict()
# cfg.trainer.accelerator = None
# cfg.trainer.devices = 1
# cfg.trainer.grad_batches = 1

In [None]:
# from inr4ssh._src.io import transform_dict

# cfg = get_config()

# cfg.to_dict()

## Data Module

Now we will put all of the preprocessing routines together. This is **very important** for a few reasons:

1. It collapses all of the operations in a modular way
2. It makes the workflow reproducible for the next person
3. It makes it very easy for the PyTorch-Lightning framework down the line.

In [None]:
from ml_collections import config_dict

# Root experiment configuration; every section below is attached to it.
config = config_dict.ConfigDict()

# --- logging (consumed by the WandbLogger cell) ---
config.logger = config_dict.ConfigDict(
    {
        "mode": "disabled",  # alternative: "online"
        "project": "inr4ssh",
        "entity": "ige",
        "log_dir": "/Users/eman/code_projects/logs/",
        "resume": False,
    }
)

# --- data locations ---
data = config_dict.ConfigDict(
    {
        "dataset_dir": "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/nadir4.nc",
        "ref_dir": "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM*",
    }
)
config.data = data

# --- preprocessing: temporal subset ---
subset_time = config_dict.ConfigDict(
    {
        "subset_time": True,
        "time_min": "2012-10-22",
        "time_max": "2012-11-01",  # alternative: "2012-12-02"
    }
)

# --- preprocessing: spatial subset ---
subset_spatial = config_dict.ConfigDict(
    {
        "subset_spatial": True,
        "lon_min": -65.0,
        "lon_max": -55.0,
        "lat_min": 33.0,
        "lat_max": 43.0,
    }
)

# --- preprocessing: coordinate transformations ---
transform = config_dict.ConfigDict(
    {
        "time_transform": "minmax",
        "time_min": "2012-01-01",
        "time_max": "2013-01-01",
    }
)

config.preprocess = config_dict.ConfigDict()
config.preprocess.subset_time = subset_time
config.preprocess.subset_spatial = subset_spatial
config.preprocess.transform = transform

# --- train/valid split ---
traintest = config_dict.ConfigDict({"train_prct": 0.9, "seed": 42})
config.traintest = traintest

# --- evaluation grid and period ---
evaluation = config_dict.ConfigDict(
    {
        "lon_min": -65.0,
        "lon_max": -55.0,
        "dlon": 0.1,
        "lat_min": 33.0,
        "lat_max": 43.0,
        "dlat": 0.1,
        "time_min": "2012-10-22",
        "time_max": "2012-11-01",  # alternative: "2012-12-02"
        "dt_freq": 1,
        "dt_unit": "D",
        "time_resample": "1D",
    }
)
config.evaluation = evaluation

# Explicit subset switches; with the values set above these are no-ops.
config.preprocess.subset_spatial.subset_spatial = True
config.preprocess.subset_time.subset_time = True


config

In [None]:
# Weights & Biases logger, configured entirely from `config.logger`.
# NOTE: `config.to_dict()` is a snapshot taken here; sections added to
# `config` in later cells (dataloader, model, loss, ...) are not part of it.
wandb_logger = WandbLogger(
    config=config.to_dict(),
    mode=config.logger.mode,
    project=config.logger.project,
    entity=config.logger.entity,
    dir=config.logger.log_dir,
    # Honour the configured flag rather than a hard-coded literal.
    resume=config.logger.resume,
)

In [None]:
# dataloader
config.dataloader = dataloader = config_dict.ConfigDict()
# train dataloader
dataloader.batchsize_train = 32
dataloader.num_workers_train = 0
dataloader.shuffle_train = True
dataloader.pin_memory_train = False
# valid dataloader
dataloader.batchsize_valid = 32
dataloader.num_workers_valid = 0
dataloader.shuffle_valid = False
dataloader.pin_memory_valid = False
# test dataloader
dataloader.batchsize_test = 32
dataloader.num_workers_test = 0
dataloader.shuffle_test = False
dataloader.pin_memory_test = False
# predict dataloader
dataloader.batchsize_predict = 32
dataloader.num_workers_predict = 0
dataloader.shuffle_predict = False
dataloader.pin_memory_predict = False

In [None]:
# initialize data module
# Build the data module from the notebook config and prepare its datasets.
dm = AlongTrackDataModule(root=None, config=config, download=False)
dm.setup()

# One loader per stage (the `ds_` prefix is historical; these are dataloaders).
ds_train, ds_valid, ds_test, ds_predict = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
    dm.predict_dataloader(),
)

In [None]:
import math

# Take the first ten training samples and show the shape of each entry.
data = dm.ds_train[:10]

tuple(data[key].shape for key in ("spatial", "temporal", "output"))

In [None]:
# Index the full training set once and reuse it for both extrema
# (previously the dataset was sliced twice).
spatial_all = dm.ds_train[:]["spatial"]
spatial_all.min(), spatial_all.max()

In [None]:
# Sample network input/target: spatial columns followed by temporal columns.
y_init = data["output"]
x_init = torch.cat((data["spatial"], data["temporal"]), dim=1)
x_init.shape, y_init.shape

### Transformations

**Spatial**:

> We want to transform this from degrees to radians


**Temporal**:

> We want to transform this from time to sines and cosines

## NerF

This is a standard neural field (NerF).

In [None]:
# MODEL
config.model = model = config_dict.ConfigDict()
model.model = "siren"
# encoder specific
model.encoder = config_dict.placeholder(str)
# generalized
model.num_layers = 5
model.hidden_dim = 256
model.use_bias = True
model.final_activation = "identity"
# SIREN SPECIFIC
model.model_seed = 42
model.w0_initial = 30.0
model.w0 = 1.0
model.final_scale = 1.0
model.c = 6.0
# # MODULATED SIREN
# model.latent_dim = 256
# model.num_layers_latent = 3
# model.operation = "sum"
# # MULTIPLICATIVE FILTER NETWORKS
# model.input_scale = 256.0
# model.weight_scale = 1.0
# model.alpha = 6.0
# model.beta = 1.0

In [None]:
from inr4ssh._src.models.models_factory import model_factory

# Take the input/output widths from the sample batch so the network always
# matches the data. These values were previously computed but ignored in
# favour of the hard-coded literals 3 and 1.
dim_in = x_init.shape[1]
dim_out = y_init.shape[1]

net = model_factory(
    model=config.model.model,
    dim_in=dim_in,
    dim_out=dim_out,
    config=config.model,
)

In [None]:
# Smoke test: forward the sample batch through the freshly built network.
out = net(x_init)

## Experiment

In [None]:
# Spatial transform settings ("identity"), plus a per-axis scaler.
spatial_cfg = config_dict.ConfigDict()
spatial_cfg.transform = "identity"
spatial_cfg.scaler = [1.0 / math.pi, 1.0 / (math.pi / 2.0)]
config.transform_spatial = spatial_cfg

# Temporal transform settings ("identity").
temporal_cfg = config_dict.ConfigDict()
temporal_cfg.transform = "identity"
config.transform_temporal = temporal_cfg

In [None]:
from inr4ssh._src.transforms.utils import (
    spatial_transform_factory,
    temporal_transform_factory,
)

# Instantiate the coordinate transforms described by the config sections.
spatial_transform, temporal_transform = (
    spatial_transform_factory(config.transform_spatial),
    temporal_transform_factory(config.transform_temporal),
)

In [None]:
# Data-fit loss.
config.loss = config_dict.ConfigDict({"loss": "mse", "reduction": "mean"})

# Optimiser.
config.optimizer = config_dict.ConfigDict(
    {"optimizer": "adam", "learning_rate": 1e-4}
)

# Learning-rate schedule.
config.lr_scheduler = config_dict.ConfigDict(
    {
        "lr_scheduler": "warmcosine",
        "num_epochs": 10,
        "warmup_epochs": 5,
        "max_epochs": 10,
        "warmup_lr": 0.0,
        "eta_min": 0.0,
    }
)

In [None]:
from inr4ssh._src.trainers.osse_2020a import INRModel

# Lightning module wrapping the network, the coordinate transforms and the
# optimisation settings.
learner_kwargs = dict(
    model=net,
    spatial_transform=spatial_transform,
    temporal_transform=temporal_transform,
    loss_config=config.loss,
    optimizer_config=config.optimizer,
    lr_scheduler_config=config.lr_scheduler,
)
learn = INRModel(**learner_kwargs)

In [None]:
# run_path = "ige/inr4ssh/1st3rtl0"
# model_path = "checkpoints/epoch=990-step=39640.ckpt"

In [None]:
# from inr4ssh._src.io import get_wandb_config, get_wandb_model

In [None]:
# best_model = get_wandb_model(run_path, model_path)
# best_model.download(replace=True)

### Callbacks

In [None]:
from inr4ssh._src.callbacks.utils import get_callbacks

In [None]:
# Callback switches handed to `get_callbacks` in the next cell.
config.callbacks = config_dict.ConfigDict(
    {
        # logging / checkpointing
        "wandb": True,
        "model_checkpoint": True,
        # early stopping
        "early_stopping": False,
        "patience": 20,
        "watch_model": False,
        # progress bar
        "tqdm": True,
        "tqdm_refresh": 10,
    }
)

In [None]:
# Build the callback list from the switches above; the W&B logger is passed in as well.
callbacks = get_callbacks(config.callbacks, wandb_logger)

### Learner

In [None]:
# state = torch.load(best_model.name, map_location=torch.device("cpu"))

In [None]:
# state["state_dict"]

In [None]:
# kwargs,
# net = SirenNet(**kwargs)
# net.load_state_dict(state_dict)

In [None]:
# learn = INRModel.load_from_checkpoint(
#     best_model.name,
#     model=net,
#     loss_data=nn.MSELoss("mean"),
#     reg_pde=reg_loss,
#     learning_rate=cfg.optim.learning_rate,
#     warmup=cfg.optim.warmup,
#     num_epochs=cfg.optim.num_epochs,
#     alpha=cfg.loss.alpha,
#     qg=cfg.loss.qg,
# )

### Trainer

In [None]:
# Trainer settings, assembled locally and attached to `config` at the end.
trainer_cfg = config_dict.ConfigDict()
# Train for as many epochs as the LR schedule covers.
trainer_cfg.num_epochs = config.lr_scheduler.num_epochs
trainer_cfg.accelerator = "mps"  # alternatives: "cpu", "gpu"
trainer_cfg.devices = 1
# Strategy: typed (str) placeholder, left unset here.
trainer_cfg.strategy = config_dict.placeholder(str)
trainer_cfg.num_nodes = 1
trainer_cfg.grad_batches = 10
trainer_cfg.dev_run = False

config.trainer = trainer_cfg

In [None]:
# Single source of truth for the device: the notebook config. A stray
# `accelerator = "cpu"` used to be assigned here but was never passed to the
# Trainer, which silently contradicted `config.trainer.accelerator`.
accelerator = config.trainer.accelerator

trainer = Trainer(
    min_epochs=1,
    max_epochs=config.trainer.num_epochs,
    accelerator=accelerator,
    devices=config.trainer.devices,
    # Configured but previously unused; both values equal the Trainer defaults.
    num_nodes=config.trainer.num_nodes,
    fast_dev_run=config.trainer.dev_run,
    enable_progress_bar=True,
    logger=wandb_logger,
    callbacks=callbacks,
    accumulate_grad_batches=config.trainer.grad_batches,
)

### Train

In [None]:
# Optimise the learner on the loaders provided by the data module.
trainer.fit(model=learn, datamodule=dm)

## Results

### Test

In [None]:
# res = trainer.test(learn, dataloaders=dm.test_dataloader())

### Inference

In [None]:
# Time the inference pass with a monotonic clock; `time.time()` follows the
# wall clock and can jump if the system time is adjusted.
t0 = time.perf_counter()

# Predict on the test loader and stitch the per-batch outputs together.
predictions = trainer.predict(
    learn, dataloaders=dm.test_dataloader(), return_predictions=True
)
predictions = torch.cat(predictions)

t1 = time.perf_counter() - t0  # elapsed seconds
print(f"Time Taken: {t1:.2f} secs")

In [None]:
# Predictions -> DataFrame -> gridded Dataset indexed by (longitude, latitude, time).
pred_array = predictions.detach().numpy()
df_pred = dm.ds_test.create_predict_df(pred_array)

df_indexed = df_pred.reset_index().set_index(["longitude", "latitude", "time"])
ds_pred = df_indexed.to_xarray()
ds_pred

In [None]:
# Persist the gridded predictions; the metric cells below reload this same file.
ds_pred.to_netcdf("/Volumes/EMANS_HDD/data/dc20a_osse/test/results/test.nc")

In [None]:
# Reload the saved predictions, so the metrics can be run without re-training.
ds_pred = xr.open_dataset("/Volumes/EMANS_HDD/data/dc20a_osse/test/results/test.nc")

In [None]:
# Synthetic baseline: the reference field plus Gaussian noise (std 0.01).
noise = 0.01 * np.random.randn(*ds_pred["ssh_model"].shape)
ds_pred["ssh_model_noise"] = ds_pred["ssh_model"] + noise

#### Metrics: Statistics

In [None]:
from inr4ssh._src.metrics.field.stats import nrmse_spacetime, rmse_space, nrmse_time

#### Normalized RMSE (Space-Time)

In [None]:
# Score the prediction against the reference over all of space and time.
nrmse_field = nrmse_spacetime(ds_pred["ssh_model_predict"], ds_pred["ssh_model"])
nrmse_xyt = nrmse_field.values

logger.info(f"Leaderboard SSH RMSE score =  {nrmse_xyt:.2f}")
wandb_logger.log_metrics({"nrmse_mu": nrmse_xyt})

##### Error Variability (Temporal)


In [None]:
# nRMSE as returned by `nrmse_time`; its standard deviation is reported as
# the error variability.
rmse_t = nrmse_time(ds_pred["ssh_model_predict"], ds_pred["ssh_model"])
err_var_time = rmse_t.std().values

logger.info(f"Error Variability =  {err_var_time:.2f}")
wandb_logger.log_metrics({"nrmse_std": err_var_time})

In [None]:
# Plot `rmse_t` on a single axis with labelled axes.
fig, ax = plt.subplots()
rmse_t.plot(ax=ax)

ax.set_xlabel("Time")
ax.set_ylabel("nRMSE")

fig.tight_layout()
plt.show()

##### Error Variability (Spatial)

In [None]:
# Spatial RMSE of the model prediction. This previously scored the synthetic
# `ssh_model_noise` field, unlike every other metric in this notebook.
rmse_xy = rmse_space(ds_pred["ssh_model_predict"], ds_pred["ssh_model"])

In [None]:
# Image of the spatial RMSE field (transposed before plotting).
fig, ax = plt.subplots()

rmse_map = rmse_xy.T
rmse_map.plot.imshow(ax=ax)

fig.tight_layout()
plt.show()

#### Metrics: PSD

In [None]:
from inr4ssh._src.metrics.psd import (
    psd_isotropic_score,
    psd_spacetime_score,
    wavelength_resolved_spacetime,
    wavelength_resolved_isotropic,
)

In [None]:
# Start from the saved predictions again.
ds_pred = xr.open_dataset("/Volumes/EMANS_HDD/data/dc20a_osse/test/results/test.nc")

# Synthetic baseline: the reference field plus Gaussian noise (std 0.01).
reference = ds_pred["ssh_model"]
ds_pred["ssh_model_noise"] = reference + 0.01 * np.random.randn(*reference.shape)

# Express the time axis as the number of days elapsed since the first timestamp.
time_norm = np.timedelta64(1, "D")
elapsed = ds_pred.time - ds_pred.time[0]
ds_pred["time"] = elapsed / time_norm

#### PSD Score: Space-Time

* Space-Time Average
* Isotropic

In [None]:
# Time-Longitude (Lat avg) PSD Score
# Space-time PSD score; the reference field is passed first, the prediction second.
psd_score = psd_spacetime_score(ds_pred["ssh_model"], ds_pred["ssh_model_predict"])

In [None]:
# Shortest resolved wavelengths derived from the space-time PSD score.
spatial_resolved, time_resolved = wavelength_resolved_spacetime(psd_score)

logger.info(
    f"Shortest Spatial Wavelength Resolved = {spatial_resolved:.2f} (degree lon)"
)
logger.info(f"Shortest Temporal Wavelength Resolved = {time_resolved:.2f} (days)")

# Log each metric with its own call, as before.
for metric_name, metric_value in (
    ("wavelength_space_deg", spatial_resolved),
    ("wavelength_time_days", time_resolved),
):
    wandb_logger.log_metrics({metric_name: metric_value})

In [None]:
# Isotropic (Time avg) PSD Score
# Isotropic PSD score; the reference field is passed first, the prediction second.
psd_iso_score = psd_isotropic_score(ds_pred["ssh_model"], ds_pred["ssh_model_predict"])

In [None]:
# Shortest resolved isotropic wavelength, evaluated at the 0.5 score level.
space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)

logger.info(f"Shortest Spatial Wavelength Resolved = {space_iso_resolved:.2f} (degree)")
wandb_logger.log_metrics({"wavelength_iso_degree": space_iso_resolved})

#### Summary

In [None]:
import pandas as pd

# Column headers of the leaderboard table, in display order.
leaderboard_columns = [
    "Method",
    "µ(RMSE) ",
    "σ(RMSE)",
    "λx (degree)",
    "λt (days)",
    "λr (degree)",
    "Notes",
    "Reference",
]

# One row: this experiment's metrics, in the same order as the headers.
leaderboard_row = [
    "SIREN GF/GF",
    nrmse_xyt,
    err_var_time,
    spatial_resolved,
    time_resolved,
    space_iso_resolved,
    "GF/GF",
    "eval_siren.ipynb",
]

data = [leaderboard_row]
Leaderboard = pd.DataFrame(data, columns=leaderboard_columns)

print("Summary of the leaderboard metrics:")
print(Leaderboard.to_markdown())

Summary of the leaderboard metrics:
|    | Method      |   µ(RMSE)  |    σ(RMSE) |   λx (degree) |   λt (days) |   λr (degree) | Notes   | Reference        |
|---:|:------------|-----------:|-----------:|--------------:|------------:|--------------:|:--------|:-----------------|
|  0 | SIREN GF/GF |  -0.101113 | 0.00143019 |     0.0374532 |        2.75 |       2.72525 | GF/GF   | eval_siren.ipynb |

In [None]:
# Close the W&B run. NOTE(review): later cells still reference `trainer`,
# which was built with `wandb_logger` — confirm this ordering is intended.
wandb.finish()

---
**DATA**

* convert this reference grid to `lat,lon,time,sossheig`
* create dataloader
* Make predictions
* Create xr.dataset from predictions

---
**Metrics**

* RMSE Metrics
* PSD Metrics

In [None]:
# res = trainer.test(learn, dataloaders=dm.test_dataloader())

# results["data"] = res

In [None]:
# import wandb

# wandb.finish()

### Predictions

In [None]:
# Time the inference pass with a monotonic clock; `time.time()` follows the
# wall clock and can jump if the system time is adjusted.
t0 = time.perf_counter()

# Predict through the data module and stitch the per-batch outputs together.
predictions = trainer.predict(learn, datamodule=dm, return_predictions=True)
predictions = torch.cat(predictions)

t1 = time.perf_counter() - t0  # elapsed seconds
print(f"Time Taken: {t1:.2f} secs")

In [None]:
# ds_pred = dm.create_predictions_ds(predictions)

from inr4ssh._src.operators import differential_simp as diffops_simp

from inr4ssh._src.operators import differential as diffops

In [None]:
ds_pred

In [None]:
# Facet plot of every 4th time step of the `predict` variable.
# NOTE(review): the cells above store the prediction as `ssh_model_predict`;
# confirm that `ds_pred` really carries a variable named `predict` here.
predict_snapshots = ds_pred.predict.thin(time=4)
predict_snapshots.plot.imshow(
    col="time",
    col_wrap=4,
    cmap="viridis",
    robust=True,
)

In [None]:
# ds_pred.predict.hvplot.image(x="Longitude", y="Latitude", width=500, height=400, cmap="viridis")

In [None]:
# ds_pred = dm.create_predictions_ds(predictions)
# ds_pred

In [None]:
from tqdm.notebook import tqdm, trange

In [None]:
learn.model.eval()

# Run on whichever device the trained network currently lives on, so the
# inputs and the weights never end up on different devices.
device = next(learn.model.parameters()).device

coords, truths, preds, grads, qs = [], [], [], [], []
for ibatch in tqdm(dm.predict_dataloader()):
    # Derivatives w.r.t. the inputs are needed, so autograd must be enabled
    # even though the model is in eval mode.
    with torch.enable_grad():
        # Fresh leaf tensors that track gradients (this replaces the
        # deprecated `torch.autograd.Variable` wrapper).
        ibatch["spatial"] = (
            ibatch["spatial"].detach().clone().to(device).requires_grad_(True)
        )
        ibatch["temporal"] = (
            ibatch["temporal"].detach().clone().to(device).requires_grad_(True)
        )
        ix = torch.cat([ibatch["spatial"], ibatch["temporal"]], dim=1)

        # prediction
        p_pred = learn.model(ix)

        # gradient of the prediction w.r.t. the spatial coordinates
        p_grad = diffops_simp.gradient(p_pred, ibatch["spatial"])

        # divergence of that gradient w.r.t. the spatial coordinates
        q = diffops_simp.divergence(p_grad, ibatch["spatial"])

    # Collect detached CPU copies so the autograd graph of each batch can be
    # freed instead of being kept alive for the whole loop.
    coords.append(ix.detach().cpu())
    preds.append(p_pred.detach().cpu())
    grads.append(p_grad.detach().cpu())
    qs.append(q.detach().cpu())

In [None]:
def _stack_to_numpy(chunks):
    """Concatenate the per-batch tensors and return them as a NumPy array."""
    return torch.cat(chunks).detach().numpy()


coords = _stack_to_numpy(coords)
preds = _stack_to_numpy(preds)
grads = _stack_to_numpy(grads)
qs = _stack_to_numpy(qs)

In [None]:
# Prediction table from the Trainer output, extended with derived fields.
pred_array = predictions.detach().numpy()
df_pred = dm.ds_predict.create_predict_df(pred_array)

# Velocity-like components from the gradient columns: u = -grads[:, 1],
# v = grads[:, 0]  (assumes the gradient columns follow the spatial column
# order — TODO confirm).
df_pred["u"], df_pred["v"] = -grads[:, 1], grads[:, 0]
df_pred["q"] = qs

df_indexed = df_pred.reset_index().set_index(["longitude", "latitude", "time"])
ds_pred = df_indexed.to_xarray()

### Figure I: Predictions

In [None]:
# Facet plot of every 4th time step of `q`.
q_snapshots = ds_pred.q.thin(time=4)
q_snapshots.plot.imshow(
    col="time",
    col_wrap=4,
    cmap="viridis",
    robust=True,
)

In [None]:
# Interactive view of the prediction.
# NOTE(review): no `hvplot` import is visible in this notebook, and `ds_pred`
# is indexed by longitude/latitude/time above — confirm that `pred`, `Nx` and
# `Ny` exist before running this cell.
ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure II: Ground Truth

In [None]:
# Facet plot of every time step of the `true` variable.
# NOTE(review): confirm that `ds_pred` carries a variable named `true`.
true_snapshots = ds_pred.true.thin(time=1)
true_snapshots.plot.imshow(
    col="time",
    col_wrap=3,
    cmap="viridis",
    robust=True,
)

In [None]:
# Interactive view of the ground truth.
# NOTE(review): no `hvplot` import is visible in this notebook, and `ds_pred`
# is indexed by longitude/latitude/time above — confirm that `true`, `Nx` and
# `Ny` exist before running this cell.
ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure III: Error (Truth - Prediction)

In [None]:
# Signed error (truth minus prediction) for every time step, drawn with a
# diverging colour map.
# NOTE(review): confirm that `ds_pred` carries `true` and `pred` variables.
signed_error = ds_pred.true - ds_pred.pred
signed_error.thin(time=1).plot.imshow(
    col="time",
    col_wrap=3,
    cmap="RdBu_r",
    robust=True,
)

In [None]:
# Interactive view of the signed error (truth minus prediction).
# NOTE(review): no `hvplot` import is visible in this notebook; `Nx`/`Ny` do
# not match the longitude/latitude index used above; and the static plot of
# the same quantity uses the diverging "RdBu_r" map rather than "viridis".
(ds_pred.true - ds_pred.pred).hvplot.image(
    x="Nx", y="Ny", width=500, height=400, cmap="viridis"
)