
# NerF + QG Loss

The full QG equation is given by:

$$
\begin{aligned}
\partial_t q + \det \boldsymbol{J}(q, \psi) &= 0
\end{aligned}
$$

where:

* $q=\nabla^2 \psi$
* $\det \boldsymbol{J}(q, \psi)=\partial_x q\partial_y\psi - \partial_y q\partial_x\psi$.

We are interested in finding a neural field (NerF) method that takes in the spatio-temporal coordinates, $\mathbf{x}_\phi$, and outputs a vector corresponding to the potential vorticity (PV), $q$, and the stream function, $\psi$, i.e. $\mathbf{y}_\text{obs}$.

$$
\mathbf{y}_\text{obs} = \boldsymbol{f_\theta}(\mathbf{x}_\phi) + \epsilon, \hspace{5mm}\epsilon \sim \mathcal{N}(0, \sigma^2)
$$

We use a SIREN network, which is a fully connected neural network with the $\sin$ activation function.

* **Data Inputs**: `256x256x11`
* **Data Outputs**: `2`


In [None]:
import os
import sys

from pyprojroot import here

# Must be set before any MPS op runs so unsupported ops fall back to the CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Walk up the directory tree until the marker files are found.
root = here(project_files=[".root"])
exp = here(
    relative_project_path=root.joinpath("experiments/dc21a"),
    project_files=[".local"],
)

# Make both the repository root and the experiment folder importable.
sys.path.extend([str(root), str(exp)])

In [None]:
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import xarray as xr

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from ml_collections import config_dict
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from inr4ssh._src.datamodules.osse_2020a import AlongTrackDataModule

pl.seed_everything(123)

import matplotlib.pyplot as plt
import seaborn as sns
import wandb


sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# from ml_collections import config_dict

# cfg = config_dict.ConfigDict()

# # logging args
# cfg.log = config_dict.ConfigDict()
# cfg.log.mode = "online" #"disabled"
# cfg.log.project ="inr4ssh"
# cfg.log.entity = "ige"
# cfg.log.log_dir = "/Users/eman/code_projects/logs/"
# cfg.log.resume = False

# # data args
# cfg.data = config_dict.ConfigDict()
# cfg.data.data_dir =  f"/Users/eman/code_projects/torchqg/data/qgsim_simple_128x128.nc"

# # preprocessing args
# cfg.pre = config_dict.ConfigDict()
# cfg.pre.noise = 0.01
# cfg.pre.dt = 1.0
# cfg.pre.time_min = 500
# cfg.pre.time_max = 511
# cfg.pre.seed = 123

# # train/test args
# cfg.split = config_dict.ConfigDict()
# cfg.split.train_prct = 0.9

# # dataloader args
# cfg.dl = config_dict.ConfigDict()
# cfg.dl.batchsize_train = 2048
# cfg.dl.batchsize_val = 1_000
# cfg.dl.batchsize_test = 5_000
# cfg.dl.batchsize_predict = 10_000
# cfg.dl.num_workers = 0
# cfg.dl.pin_memory = False

# # loss arguments
# cfg.loss = config_dict.ConfigDict()
# cfg.loss.qg = True
# cfg.loss.alpha = 1e-4

# # optimizer args
# cfg.optim = config_dict.ConfigDict()
# cfg.optim.warmup = 10
# cfg.optim.num_epochs = 100
# cfg.optim.learning_rate = 1e-4

# # trainer args
# cfg.trainer = config_dict.ConfigDict()
# cfg.trainer.accelerator = None
# cfg.trainer.devices = 1
# cfg.trainer.grad_batches = 1

In [None]:
# from inr4ssh._src.io import transform_dict

# cfg = get_config()

# cfg.to_dict()

In [None]:
# wandb_logger = WandbLogger(
#     config=cfg.to_dict(),
#     mode="offline",  # cfg.log.mode,
#     project=cfg.log.project,
#     entity=cfg.log.entity,
#     dir=cfg.log.log_dir,
#     resume=False,
# )

In [None]:
!ls /Users/eman/code_projects/torchqg/data/

## Data Module

Now we will put all of the preprocessing routines together. This is **very important** for a few reasons:

1. It collapses all of the operations in a modular way
2. It makes it reproducible for the next people
3. It makes it very easy for the PyTorch-Lightning framework down the line.

In [None]:
from ml_collections import config_dict

# Root experiment configuration. Every sub-section below is created as its own
# ConfigDict and bound, in the same statement, both to `config.<section>` and
# to a short module-level alias (`data`, `subset_time`, ...).
# NOTE: the aliases `data` and `transform` are rebound to other objects in
# later cells, so prefer `config.<section>` when reading values afterwards.
config = config_dict.ConfigDict()

# data directory
config.data = data = config_dict.ConfigDict()
data.dataset_dir = "/Volumes/EMANS_HDD/data/dc20a_osse/test/ml/nadir1.nc"

# preprocessing: temporal subset (ISO date strings)
config.preprocess = config_dict.ConfigDict()
config.preprocess.subset_time = subset_time = config_dict.ConfigDict()
subset_time.subset_time = True
subset_time.time_min = "2012-10-22"
subset_time.time_max = "2012-12-02"

# preprocessing: spatial subset (bounding box; presumably degrees -- confirm)
config.preprocess.subset_spatial = subset_spatial = config_dict.ConfigDict()
subset_spatial.subset_spatial = True
subset_spatial.lon_min = -65.0
subset_spatial.lon_max = -55.0
subset_spatial.lat_min = 33.0
subset_spatial.lat_max = 43.0

# transformations
# NOTE(review): "minmax" presumably rescales time using time_min/time_max
# below -- confirm against transform_factory.
config.preprocess.transform = transform = config_dict.ConfigDict()
transform.time_transform = "minmax"
transform.time_min = "2011-01-01"
transform.time_max = "2013-12-12"

# train/valid arguments
config.traintest = traintest = config_dict.ConfigDict()
traintest.train_prct = 0.9
traintest.seed = 42

# dataloader
config.dataloader = dataloader = config_dict.ConfigDict()
# train dataloader
dataloader.batchsize_train = 32
dataloader.num_workers_train = 2
dataloader.shuffle_train = True
dataloader.pin_memory_train = False
# valid dataloader
dataloader.batchsize_valid = 32
dataloader.num_workers_valid = 2
dataloader.shuffle_valid = False
dataloader.pin_memory_valid = False
# predict dataloader
dataloader.batchsize_predict = 32
dataloader.num_workers_predict = 4
dataloader.shuffle_predict = False
dataloader.pin_memory_predict = False

# EVALUATION: bounding box, grid spacing and time window of the evaluation grid
config.evaluation = evaluation = config_dict.ConfigDict()
evaluation.lon_min = -65.0
evaluation.lon_max = -55.0
evaluation.dlon = 0.1
evaluation.lat_min = 33.0
evaluation.lat_max = 43.0
evaluation.dlat = 0.1

evaluation.time_min = "2012-10-22"
evaluation.time_max = "2012-12-02"
evaluation.dt_freq = 1
evaluation.dt_unit = "D"
# Leftover from when the configuration came from a `get_demo_config` helper:

# config = get_demo_config()

# Redundant with the assignments above (both flags are already True); kept so
# the cell still forces the subsets on if the demo config is re-enabled.
config.preprocess.subset_spatial.subset_spatial = True
config.preprocess.subset_time.subset_time = True

# Display the assembled configuration as the cell output.
config

In [None]:
# initialize data module from the configuration assembled above
dm = AlongTrackDataModule(
    root=None,
    config=config,
    download=False,
)

# initialize datamodule params
# (must run before the dataloaders are requested below and before
# `dm.ds_train` / `dm.ds_predict` are accessed in later cells)
dm.setup()

# initialize dataloaders
# NOTE: despite the `_ds` suffix these names hold whatever the
# `*_dataloader()` methods return, not the underlying datasets.
train_ds = dm.train_dataloader()

valid_ds = dm.val_dataloader()

predict_ds = dm.predict_dataloader()

In [None]:
import math

data = dm.ds_train[:10]

data["spatial"].shape, data["temporal"].shape, data["output"].shape

In [None]:
x_init = torch.cat([data["spatial"], data["temporal"]], dim=1)
y_init = data["output"]
x_init.shape, y_init.shape

### Transformations

**Spatial**:

> We want to transform this from degrees to radians


**Temporal**:

> We want to transform this from time to sines and cosines

## NerF

This is a standard neural field.

In [None]:
from inr4ssh._src.models.siren import Siren, SirenNet

In [None]:
# Network sizes are taken from the sample batch built above: the input width is
# the number of concatenated spatial + temporal columns, the output width is
# the number of columns of the target.
dim_in = x_init.shape[1]
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 4
# NOTE(review): presumably the sine frequency scale of the hidden layers (w0)
# and of the first layer (w0_initial), and the weight-init constant (c) of
# SIREN -- confirm against inr4ssh._src.models.siren.
w0 = 1.0
w0_initial = 30.0
c = 6.0
# No activation on the output layer; a sigmoid was tried previously.
final_activation = None  # nn.Sigmoid()

net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    c=c,
    final_activation=final_activation,
)

In [None]:
out = net(x_init)

## Experiment

In [None]:
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from functools import partial
from typing import Dict, Any, cast
from torch.optim import Adam

In [None]:
class INRModel(pl.LightningModule):
    """Lightning wrapper that fits a coordinate network to observations.

    For every batch the ``"spatial"`` and ``"temporal"`` tensors are
    concatenated along ``dim=1`` and fed to ``model``; the prediction is
    compared with ``batch["output"]`` using the data loss.

    Args:
        model: network mapping the concatenated coordinates to the outputs.
        spatial_transform: stored with the hyperparameters; not otherwise
            used by this class.
        temporal_transform: stored with the hyperparameters; not otherwise
            used by this class.
        **kwargs: additional hyperparameters. Recognised keys:
            ``loss_data`` (callable ``loss(pred, target)``; defaults to
            ``nn.MSELoss(reduction="mean")``), ``learning_rate``
            (default ``1e-4``), ``warmup`` (default ``10``) and
            ``num_epochs`` (default ``100``).
    """

    def __init__(
        self,
        model,
        spatial_transform=None,
        temporal_transform=None,
        **kwargs,
    ):
        super().__init__()

        # nn.Modules are already checkpointed through the state dict; keeping
        # them out of hparams avoids pickling a second copy of the network
        # into every checkpoint / logger config.
        self.save_hyperparameters(ignore=["model", "loss_data"])
        self.model = model
        self.hyperparams = cast(Dict[str, Any], self.hparams)

        # Honour a caller-supplied loss; previously it was silently dropped.
        loss_data = kwargs.get("loss_data")
        self.loss_data = (
            nn.MSELoss(reduction="mean") if loss_data is None else loss_data
        )

    def forward(self, x):
        """Evaluate the wrapped network on the coordinates ``x``."""
        return self.model(x)

    def _data_loss(self, batch):
        """Return the data-fit loss of the model on ``batch``."""
        x, y = self._extract_spacetime(batch=batch, outputs=True)

        pred = self.forward(x)

        # torch losses take (input, target); the order matters as soon as a
        # non-symmetric loss is supplied through ``loss_data``.
        return self.loss_data(pred, y)

    def _extract_spacetime(self, batch, outputs=False):
        """Concatenate spatial and temporal coordinates of ``batch``.

        Returns:
            ``x`` when ``outputs`` is false, otherwise ``(x, batch["output"])``.
        """
        x = torch.cat([batch["spatial"], batch["temporal"]], dim=1)

        if outputs:
            return x, batch["output"]
        return x

    def _shared_step(self, batch, stage):
        """Compute the data loss and log it as ``"<stage>_loss"``."""
        loss = self._data_loss(batch)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss (logged as ``train_loss``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss (logged as ``val_loss``)."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Return the test loss (logged as ``test_loss``)."""
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return the model prediction for the coordinates in ``batch``."""
        x = self._extract_spacetime(batch=batch, outputs=False)
        return self.forward(x)

    def configure_optimizers(self):
        """Build Adam with a linear-warmup / cosine-annealing schedule.

        The learning rate, warmup length and total number of epochs are read
        from the saved hyperparameters, with the defaults listed in the class
        docstring.
        """
        optimizer = Adam(
            self.model.parameters(), lr=self.hyperparams.get("learning_rate", 1e-4)
        )

        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hyperparams.get("warmup", 10),
            max_epochs=self.hyperparams.get("num_epochs", 100),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

In [None]:
# Short demo schedule: one warmup epoch out of five training epochs.
learning_rate = 1e-4
warmup = 1
num_epochs = 5

learn = INRModel(
    model=net,
    # `reduction` must be passed by keyword: the first positional parameter of
    # nn.MSELoss is the deprecated `size_average`, so nn.MSELoss("mean") only
    # produced a mean reduction by accident (with a deprecation warning).
    loss_data=nn.MSELoss(reduction="mean"),
    learning_rate=learning_rate,
    warmup=warmup,
    num_epochs=num_epochs,
)

In [None]:
# run_path = "ige/inr4ssh/1st3rtl0"
# model_path = "checkpoints/epoch=990-step=39640.ckpt"

In [None]:
# from inr4ssh._src.io import get_wandb_config, get_wandb_model

In [None]:
# best_model = get_wandb_model(run_path, model_path)
# best_model.download(replace=True)

### Callbacks

In [None]:
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint

In [None]:
from pathlib import Path

In [None]:
# model_cb = ModelCheckpoint(

#     dirpath=str(Path(wandb_logger.experiment.dir).joinpath("checkpoints")),
#     monitor="val_loss",
#     mode="min",
#     save_top_k=1,
# )

In [None]:
callbacks = [
    # model_cb,
    TQDMProgressBar(refresh_rate=1),
]

### Learner

In [None]:
# state = torch.load(best_model.name, map_location=torch.device("cpu"))

In [None]:
# state["state_dict"]

In [None]:
# kwargs,
# net = SirenNet(**kwargs)
# net.load_state_dict(state_dict)

In [None]:
# learn = INRModel.load_from_checkpoint(
#     best_model.name,
#     model=net,
#     loss_data=nn.MSELoss("mean"),
#     reg_pde=reg_loss,
#     learning_rate=cfg.optim.learning_rate,
#     warmup=cfg.optim.warmup,
#     num_epochs=cfg.optim.num_epochs,
#     alpha=cfg.loss.alpha,
#     qg=cfg.loss.qg,
# )

### Trainer

In [None]:
# Train on the CPU; the device count, logger and gradient accumulation are
# left at the Trainer defaults (their config-driven versions are commented out).
accelerator = "cpu"

trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,  # same value that is handed to the LR scheduler
    accelerator=accelerator,
    # devices=cfg.trainer.devices,
    enable_progress_bar=True,
    # logger=wandb_logger,
    callbacks=callbacks,
    # accumulate_grad_batches=cfg.trainer.grad_batches,
)

### Train

In [None]:
trainer.fit(
    learn,
    datamodule=dm,
)

## Results

### Testing

In [None]:
evaluation.lon_min = -65.0
evaluation.lon_max = -55.0
evaluation.dlon = 0.1
evaluation.lat_min = 33.0
evaluation.lat_max = 43.0
evaluation.dlat = 0.1

evaluation.time_min = "2012-10-22"
evaluation.time_max = "2012-12-02"
evaluation.dt_freq = 1
evaluation.dt_unit = "D"

In [None]:
from inr4ssh._src.preprocess.coords import correct_coordinate_labels

# Glob pattern of the reference files. It is kept as a plain string because
# open_mfdataset documents glob expansion for string paths; the former
# `Path(...)` assignment was overwritten immediately and never used.
eval_ds_dir = "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM*"
ds = xr.open_mfdataset(eval_ds_dir, engine="netcdf4")

# Evaluation window and bounding box from the configuration.
time_min = evaluation.time_min
time_max = evaluation.time_max
lon_min = evaluation.lon_min
lon_max = evaluation.lon_max
lat_min = evaluation.lat_min
lat_max = evaluation.lat_max

# Crop to the evaluation domain, then average onto a daily time axis.
ds = (
    ds.sel(
        time=slice(time_min, time_max),
        lon=slice(lon_min, lon_max),
        lat=slice(lat_min, lat_max),
        drop=True,
    )
    .resample(time="1D")
    .mean()
)

# The cells below index the coordinates as "longitude" / "latitude".
ds = correct_coordinate_labels(ds)

ds

In [None]:
import pandas as pd

# Full (longitude, latitude, time) grid of the reference dataset. With the
# default meshgrid indexing the three arrays share one shape, so flattening
# them yields one row per grid point.
x, y, z = np.meshgrid(
    ds.coords["longitude"].data, ds.coords["latitude"].data, ds.coords["time"].data
)


# One row per grid point, with the column names the dataset below is told to read.
ds_ref_coords = pd.DataFrame(
    {"longitude": x.flatten(), "latitude": y.flatten(), "time": z.flatten()}
)

from inr4ssh._src.datasets.alongtrack import AlongTrackDataset
from inr4ssh._src.transforms.dataset import transform_factory

# NOTE(review): `transform` is built here but never passed to the dataset or
# the dataloader below -- confirm whether ds_eval should receive it.
transform = transform_factory(config.preprocess.transform)
ds_eval = AlongTrackDataset(
    ds_ref_coords, spatial_columns=["longitude", "latitude"], temporal_columns=["time"]
)
# NOTE(review): the loader wraps `dm.ds_predict`, not the `ds_eval` dataset
# created just above; `ds_eval` is only used for the length comparison in the
# next cell. Verify which dataset is intended.
dl_eval = torch.utils.data.DataLoader(
    dm.ds_predict,
    batch_size=config.dataloader.batchsize_predict,
    shuffle=config.dataloader.shuffle_predict,
    num_workers=config.dataloader.num_workers_predict,
    pin_memory=config.dataloader.pin_memory_predict,
)

In [None]:
len(ds_eval), len(dm.ds_predict)

In [None]:
# Time a full prediction pass over the evaluation dataloader.
t0 = time.time()
predictions = torch.cat(
    trainer.predict(learn, dataloaders=dl_eval, return_predictions=True)
)
t1 = time.time() - t0
print(f"Time Taken: {t1:.2f} secs")

---
**DATA**

* convert this reference grid to `lat,lon,time,sossheig`
* create dataloader
* Make predictions
* Create xr.dataset from predictions

---
**Metrics**

* RMSE Metrics
* PSD Metrics

In [None]:
ds

In [None]:
# NOTE(review): this cell is pasted from a method body -- `self` and `logger`
# are not defined anywhere in this notebook, so it cannot run as-is. Confirm
# whether it should use the notebook-level `config` or be removed.
ds = xr.open_dataset(self.config.data.dataset_dir)

# correct the labels
logger.info("Correcting labels...")
ds = correct_coordinate_labels(ds)

logger.info("Sorting array by time...")
ds = ds.sortby("time")

# temporal subset
if self.config.preprocess.subset_time.subset_time:
    logger.info("Subsetting temporal...")
    time_min = self.config.preprocess.subset_time.time_min
    time_max = self.config.preprocess.subset_time.time_max
    logger.debug(f"Time Min: {time_min} | Time Max: {time_max}...")
    ds = ds.sel(time=slice(time_min, time_max), drop=True)

# spatial subset
if self.config.preprocess.subset_spatial.subset_spatial:
    logger.info("Subseting spatial...")
    lon_min = self.config.preprocess.subset_spatial.lon_min
    lon_max = self.config.preprocess.subset_spatial.lon_max
    lat_min = self.config.preprocess.subset_spatial.lat_min
    lat_max = self.config.preprocess.subset_spatial.lat_max
    # Fixed: the message used to print lon_min for both bounds.
    logger.debug(f"Lon Min: {lon_min} | Lon Max: {lon_max}...")
    # Keep only the samples inside the bounding box (bounds inclusive).
    ds = ds.where(
        (ds["longitude"] >= lon_min)
        & (ds["longitude"] <= lon_max)
        & (ds["latitude"] >= lat_min)
        & (ds["latitude"] <= lat_max),
        drop=True,
    )

In [None]:
# res = trainer.test(learn, dataloaders=dm.test_dataloader())

# results["data"] = res

In [None]:
# import wandb

# wandb.finish()

### Predictions

In [None]:
# Time a full prediction pass over the datamodule's predict dataloader.
t0 = time.time()
predictions = torch.cat(
    trainer.predict(learn, datamodule=dm, return_predictions=True)
)
t1 = time.time() - t0
print(f"Time Taken: {t1:.2f} secs")

In [None]:
# ds_pred = dm.create_predictions_ds(predictions)

from inr4ssh._src.operators import differential_simp as diffops_simp

from inr4ssh._src.operators import differential as diffops

In [None]:
df_pred = dm.ds_predict.create_predict_df(predictions.detach().numpy())
ds_pred = df_pred.reset_index().set_index(["longitude", "latitude", "time"]).to_xarray()

In [None]:
ds_pred

In [None]:
ds_pred.predict.thin(time=4).plot.imshow(
    col="time",
    robust=True,
    col_wrap=4,
    cmap="viridis",
)

In [None]:
# ds_pred.predict.hvplot.image(x="Longitude", y="Latitude", width=500, height=400, cmap="viridis")

In [None]:
# ds_pred = dm.create_predictions_ds(predictions)
# ds_pred

In [None]:
from tqdm.notebook import tqdm, trange

In [None]:
# For every prediction batch, evaluate the network and differentiate its
# output with respect to the spatial inputs: `p_grad` is the gradient of the
# prediction and `q` the divergence of that gradient.
learn.model.eval()
coords, truths, preds, grads, qs = [], [], [], [], []
for ibatch in tqdm(dm.predict_dataloader()):
    with torch.enable_grad():
        # Fresh leaf tensors that track gradients; this replaces the
        # deprecated torch.autograd.Variable wrapper.
        ibatch["spatial"] = ibatch["spatial"].detach().clone().requires_grad_(True)
        ibatch["temporal"] = ibatch["temporal"].detach().clone().requires_grad_(True)

        # prediction
        ix = torch.cat([ibatch["spatial"], ibatch["temporal"]], dim=1)
        p_pred = learn.model(ix)

        # gradient
        p_grad = diffops_simp.gradient(p_pred, ibatch["spatial"])
        # p_grad = diffops.grad(p_pred, ix)
        # q
        q = diffops_simp.divergence(p_grad, ibatch["spatial"])
        # q = diffops.div(p_grad, ix)

    # collect
    # Store detached tensors so each batch's autograd graph can be freed
    # instead of being kept alive until the end of the loop.
    # truths.append(ibatch["output"])
    coords.append(ix.detach())
    preds.append(p_pred.detach())
    grads.append(p_grad.detach())
    qs.append(q.detach())

In [None]:
# Stack the per-batch tensors and convert each collection to a NumPy array.
# truths = torch.cat(truths).detach().numpy()
coords, preds, grads, qs = (
    torch.cat(chunks).detach().numpy() for chunks in (coords, preds, grads, qs)
)

In [None]:
# Dataframe of predictions built by the predict dataset.
df_pred = dm.ds_predict.create_predict_df(predictions.detach().numpy())
# Velocity-like fields from the gradient of the prediction: u is minus the
# second gradient column, v is the first.
# NOTE(review): assumes the spatial columns are ordered (longitude, latitude),
# i.e. u = -d(pred)/d(lat) and v = d(pred)/d(lon) -- confirm against the dataset.
df_pred["u"] = -grads[:, 1]
df_pred["v"] = grads[:, 0]
# Divergence of the gradient computed in the loop above.
df_pred["q"] = qs
# Re-index on the coordinates so the table becomes a gridded xarray Dataset.
ds_pred = df_pred.reset_index().set_index(["longitude", "latitude", "time"]).to_xarray()

### Figure I: Predictions

In [None]:
ds_pred.q.thin(time=4).plot.imshow(
    col="time",
    robust=True,
    col_wrap=4,
    cmap="viridis",
)

In [None]:
ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure II: Ground Truth

In [None]:
ds_pred.true.thin(time=1).plot.imshow(
    col="time",
    robust=True,
    col_wrap=3,
    cmap="viridis",
)

In [None]:
ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure III: Absolute Error

In [None]:
(ds_pred.true - ds_pred.pred).thin(time=1).plot.imshow(
    col="time",
    robust=True,
    col_wrap=3,
    cmap="RdBu_r",
)

In [None]:
(ds_pred.true - ds_pred.pred).hvplot.image(
    x="Nx", y="Ny", width=500, height=400, cmap="viridis"
)