# QG Simulation

* Fourier Feature Networks (FFN)
* Siren
* Modulated Siren (ModSiren)
* Multiplicative Filter Networks (MFN)
    * Fourier
    * Gabor

In [None]:
import sys, os
from pyprojroot import here

# walk up from the notebook directory until the `.root` marker file is found
root = here(project_files=[".root"])
local = root.joinpath("experiments/qg")

# make both the repository root and this experiment folder importable
sys.path.extend([str(root), str(local)])

In [None]:
from typing import Dict, Any, cast
import tabulate
from IPython.display import display, HTML

import numpy as np
import xarray as xr
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import ReLU

from tqdm.notebook import tqdm as tqdm
import os, imageio

from inr4ssh._src.models.mlp import MLP
from inr4ssh._src.models.activations import Swish
from inr4ssh._src.datamodules.qg import QGSimulation

import pytorch_lightning as pl
from inr4ssh._src.models.image import ImageModel
from inr4ssh._src.models.siren import Siren, SirenNet, Modulator, ModulatedSirenNet
from inr4ssh._src.models.mfn import FourierNet, GaborNet
from inr4ssh._src.models.activations import get_activation

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.loggers import WandbLogger

import hvplot.xarray

pl.seed_everything(123)

import matplotlib.pyplot as plt
import seaborn as sns
import wandb

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Config

In [None]:
import config
from simple_parsing import ArgumentParser

In [None]:
# Every experiment config group, registered in this order; `dest` becomes the
# attribute name on the parsed `args` namespace.
config_groups = [
    (config.Logging, "logging"),
    (config.DataDir, "data"),
    (config.PreProcess, "preprocess"),
    (config.Features, "features"),
    (config.TrainTestSplit, "traintest"),
    (config.DataLoader, "dataloader"),
    (config.Model, "model"),
    (config.Siren, "siren"),
    (config.MLP, "mlp"),
    (config.FFN, "ffn"),
    (config.ModulatedSiren, "modsiren"),
    (config.MFN, "mfn"),
    (config.Losses, "losses"),
    (config.Optimizer, "optimizer"),
    (config.LRScheduler, "lr_scheduler"),
    (config.Callbacks, "callbacks"),
    # (config.EvalData, "eval"),
    # (config.Metrics, "metrics"),
    # (config.Viz, "viz"),
]

parser = ArgumentParser()
for group_cls, dest in config_groups:
    parser.add_arguments(group_cls, dest=dest)

# parse an empty argv so only the dataclass defaults are used
args = parser.parse_args([])

In [None]:
# Paths can be overridden through environment variables instead of editing the notebook;
# the defaults are the author's local paths.
args.data.data_dir = os.environ.get(
    "QG_DATA_DIR", "/Users/eman/.CMVolumes/cal1_data/qg_data/public/"
)
# args.data.data_dir = "/Volumes/EMANS_HDD/data/qg_sim/"
# expand "~" explicitly rather than relying on the logger to do it
args.logging.log_dir = os.path.expanduser(
    os.environ.get("QG_LOG_DIR", "~/code_projects/logs/")
)
args.logging.mode = "disabled"

# preprocessing: per-axis coarsening factors (presumably 1 = no coarsening)
args.preprocess.coarsen_Nx = 1
args.preprocess.coarsen_Ny = 1
args.preprocess.coarsen_time = 5

args.traintest.noise = None

# train/test split: subsampling strides and fraction of missing data
args.traintest.step_Nx = 2
args.traintest.step_Ny = 2
args.traintest.step_time = 2
args.traintest.missing_data = 0.9

args.dataloader.batch_size = 4096  # 8192

# NOTE: this replaces the parsed `config.Model` group with the plain string that
# `model_factory` receives below.
args.model = "siren"
model_config = args.siren

## Logger

In [None]:
from inr4ssh._src.io import simpleargs_2_ndict

log_options = args.logging

# params_dict = simpleargs_2_ndict(args)

# Weights & Biases logger configured from the `logging` config group
wandb_kwargs = {
    "config": args,
    "mode": log_options.mode,
    "project": log_options.project,
    "entity": log_options.entity,
    "dir": log_options.log_dir,
    "resume": False,
}
wandb_logger = WandbLogger(**wandb_kwargs)

## Data

In [None]:
# from inr4ssh._src.data.qg import load_qg_data

# ds = load_qg_data(dm.data.data_dir)

# ds = ds.coarsen({"time": 2}, boundary="trim", coord_func="mean")
# ds

In [None]:
# Build the QG datamodule from the matching config groups (keyword names equal the
# `args` attribute names), then run setup() before the splits are inspected below.
dm = QGSimulation(
    **{
        group: getattr(args, group)
        for group in ("data", "preprocess", "traintest", "features", "dataloader")
    },
    # eval=args.eval
)

dm.setup()

In [None]:
# number of samples in each split (train / valid / test / predict)
len(dm.ds_train), len(dm.ds_valid), len(dm.ds_test), len(dm.ds_predict)

In [None]:
# import pandas as pd
# import numpy as np


# def array_2_da(coords, data, name="full_pred", coords_name: List[str]=["x", "y", "t"]):
#     return pd.DataFrame(np.hstack([coords, data]), columns=[coords_name]+[name]).set_index(
#         coords_name).to_xarray()

In [None]:
# one dataset holding both the prediction grid ("predict") and the observations ("train")
xr_data = xr.merge([dm.create_xr_dataset("predict"), dm.create_xr_dataset("train")])

In [None]:
# interactive view of the training observations
xr_data.train.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# xr_data.predict.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# data.obs.thin(time=2).plot.imshow(
#     col="time", robust=True, col_wrap=4, cmap="viridis",
# )

The input data is a coordinate vector, $\mathbf{x}_\phi$, of the image coordinates.

$$
\mathbf{x}_\phi \in \mathbb{R}^{D_\phi}
$$

where the coordinates are $(x, y, t)$ — two spatial dimensions plus time — so $D_\phi = 3$. We are interested in learning a function, $\boldsymbol{f}$, that takes a coordinate vector as input and outputs the scalar (or vector) field value at that coordinate.

$$
\mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta})
$$

### Data Module


#### Train-Test Split

In this example, we are only taking every other pixel for training and validation. It is a very simple and well-defined problem which each of the neural networks should be able to solve. The final test image is the original full resolution image.

Notice how we have `131_072` points for training and validation and `262_144` for testing (the `len(...)` cell above prints the exact counts for the current configuration). This is because we have *raveled* the data so that each point is a coordinate vector of `x, y, t`. So these are a lot of points...

In [None]:
# small (coordinates, targets) batch used to infer network widths and for smoke tests
x_init, y_init = dm.ds_train[:32]
x_init.shape, y_init.shape

In [None]:
# overall range of the input coordinates in the sample batch
x_init.min(), x_init.max()

### Optimizer

For this, we will use a simple Adam optimizer with a `learning_rate` of `1e-4`. From many studies, it appears that a lower learning rate works well with these methods because there is a lot of data. In addition, a bigger `batch_size` is also desirable. We set `num_epochs` to `100` for this first (data-only) fit; the QG-regularised stage further below uses `1_000`. Obviously, more epochs and a better learning rate scheduler would give better results, but this will be sufficient for this demo.

In [None]:
# settings for the data-only fit
num_epochs = 100  # Trainer max_epochs and the cosine-schedule length
learning_rate = 1e-4  # passed to ImageModel as `learning_rate`
warmup = 5  # linear-warmup epochs of the LR schedule

### Scheduler

<p align="center">
<img src="http://www.bdhammel.com/assets/learning-rate/resnet_loss.png" alt="drawing" width="300"/>
<figcaption align = "center">
  <b>Fig.1 - An example for learning rate reduction when the validation loss stagnates. Source: 
    <a href="http://www.bdhammel.com/assets/learning-rate/resnet_loss.png">Blog</a>
  </b>
  </figcaption>
</p>

A simple learning rate scheduler is `ReduceLROnPlateau`, which automatically reduces the learning rate as the validation loss stagnates, helping us squeeze out as much performance as possible from our models during training. With a (relatively) high starting `learning_rate` of `1e-4`, a `patience` of 5 epochs is reasonable: if the validation loss does not improve for 5 epochs, the learning rate is reduced by a factor of `0.1`. Note that the model below actually uses a linear warmup followed by cosine annealing (`LinearWarmupCosineAnnealingLR`, controlled by `warmup` and `num_epochs`); `ReduceLROnPlateau` is kept only as a commented-out alternative.

This is a rather crude (but effective) method but it tends to work well in some situations. A better method might be the `cosine_annealing` method or the `exponential_decay` method. See other [examples](https://www.kaggle.com/code/snnclsr/learning-rate-schedulers/).

### Loss

We are going with a very simple `loss` function: the *mean squared error* (MSE). This is given by:

$$
\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_n^N \left( \mathbf{y}_n - \boldsymbol{f}_{\boldsymbol{\theta}}(\mathbf{x}_n) \right)^2
$$

We won't code this from scratch, we will just use the PyTorch function, `nn.MSELoss`, and we will use the `mean` reduction parameter.


### PSNR

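It can also be useful to keep track of the peak signal-to-noise ratio (PSNR), which gives an indication of how well we are learning (note: the training code below does not log it). For data with dynamic range $\text{MAX}$:

$$
\text{PSNR} = 10 \log_{10}\left(\frac{\text{MAX}^2}{\text{MSE}}\right)
$$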

In [None]:
# experiment name -> output of trainer.test(...)
results = {}

## Experiment

In [None]:
# training batches per epoch (only referenced by the commented-out OneCycleLR option)
num_steps_per_epoch = len(dm.train_dataloader())

In [None]:
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from inr4ssh._src.losses.qg import qg_loss
from inr4ssh._src.operators.differential_simp import gradient
from inr4ssh._src.operators.differential import grad as grad_adv

In [None]:
class ImageModel(pl.LightningModule):
    """Lightning wrapper that fits a coordinate network to field values.

    Every step computes an MSE data term between predictions and targets. When
    the ``qg`` hyperparameter is truthy, the QG-residual regulariser from
    ``qg_loss`` is added, weighted by ``alpha``.

    All ``__init__`` arguments (including ``**kwargs``, except the network) are
    stored by ``save_hyperparameters`` and read back via ``self.hyperparams``:

    * ``optimizer``: ``"adam"`` or ``"lbfgs"``.
    * ``lr`` / ``learning_rate``: learning rate (``lr`` wins; default 1e-4).
    * ``qg`` / ``alpha``: enable / weight the QG regulariser.
    * ``f``, ``g``, ``Lr``: constants forwarded to ``qg_loss``.
    * ``warmup`` / ``num_epochs``: lengths of the warmup + cosine LR schedule.
    """

    def __init__(
        self,
        model,
        optimizer: str = "adam",
        qg: bool = True,
        alpha: float = 0.1,
        **kwargs,
    ):
        super().__init__()

        # The network is excluded from the hparams: it is already module state.
        self.save_hyperparameters(ignore=["model"])
        self.model = model
        self.hyperparams = cast(Dict[str, Any], self.hparams)
        self.loss = nn.MSELoss(reduction="mean")
        # self.reg = QGRegularization()

    def forward(self, x):
        return self.model(x)

    def _get_learning_rate(self) -> float:
        """Learning rate from ``lr``, else ``learning_rate``, else 1e-4."""
        # Callers pass ``learning_rate=...``; reading only ``lr`` silently
        # ignored it and always trained with the 1e-4 default.
        return self.hyperparams.get("lr", self.hyperparams.get("learning_rate", 1e-4))

    def _qg_loss(self, x):
        """Mean QG residual of the (scaled) network output at coordinates ``x``."""
        f = self.hyperparams.get("f", 0.0001)
        g = self.hyperparams.get("g", 9.81)

        # Derivatives w.r.t. the inputs need grad mode even in val/test, which
        # Lightning may run under no_grad or inference_mode; inference_mode(False)
        # plus a cloned leaf tensor makes autograd usable in both cases.
        with torch.inference_mode(False), torch.enable_grad():
            x_var = x.clone().detach().requires_grad_(True)
            # out-of-place scaling: in-place ops on a graph output can break backward
            out = self.model(x_var) * (f / g)
            loss = qg_loss(
                out,
                x_var,
                f=f,
                g=g,
                Lr=self.hyperparams.get("Lr", 1.0),
                reduction="mean",
            )

        return loss

    def _shared_step(self, batch, stage: str):
        """Data loss (+ optional QG regulariser) for one batch, logged as ``{stage}_*``."""
        x, y = batch

        pred = self.forward(x)
        loss_data = self.loss(pred, y)

        if self.hyperparams.get("qg", True):
            reg = self._qg_loss(x)
            loss = loss_data + self.hyperparams.get("alpha", 0.1) * reg

            self.log(f"{stage}_reg", reg, prog_bar=True)
            self.log(f"{stage}_data", loss_data, prog_bar=True)
        else:
            loss = loss_data

        self.log(f"{stage}_loss", loss, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        # targets are present in the batch but unused for prediction
        x, _ = batch
        return self.forward(x)

    def configure_optimizers(self):
        """Optimizer named by the ``optimizer`` hparam + linear-warmup cosine schedule."""
        name = self.hyperparams.get("optimizer", "adam")
        lr = self._get_learning_rate()

        if name == "adam":
            optimizer = Adam(self.model.parameters(), lr=lr)
        elif name == "lbfgs":
            # previously this branch silently constructed Adam
            optimizer = torch.optim.LBFGS(self.model.parameters(), lr=lr)
        else:
            # report the requested name (the old message used an unbound local)
            raise ValueError(f"Unrecognized optimizer: {name}")

        # Alternatives tried before: ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR.
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hyperparams.get("warmup", 10),
            max_epochs=self.hyperparams.get("num_epochs", 100),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

### Callbacks

In [None]:
# refresh the progress bar after every batch
progress_bar = TQDMProgressBar(refresh_rate=1)
callbacks = [progress_bar]

## Model

* Fourier Feature Networks (FFN)
* Siren
* Modulated Siren (ModSiren)
* Multiplicative Filter Networks (MFN)
    * Fourier
    * Gabor


In [None]:
from inr4ssh._src.models.models_factory import model_factory

In [None]:
# network input/output widths are read off the sample batch
dim_in, dim_out = x_init.shape[1], y_init.shape[1]

# args.ffn.encoder = "positional"
net = model_factory(args.model, dim_in, dim_out, model_config)
net

In [None]:
# dim_in = x_init.shape[1]
# dim_hidden = 256
# dim_out = y_init.shape[1]
# num_layers = 5
# activation = "swish"  # Swish()  # nn.ReLU()#
# final_activation = "identity"

# mlp_net = MLP(
#     dim_in=dim_in,
#     dim_hidden=dim_hidden,
#     dim_out=dim_out,
#     num_layers=num_layers,
#     activation=get_activation(activation),
#     final_activation=get_activation(final_activation),
# )

In [None]:
# Data-only baseline: alpha=0.0 and qg=False disable the QG regulariser.
# NOTE(review): configure_optimizers looks up an "lr" hyperparameter; confirm
# that `learning_rate` below is actually used as the learning rate.
learn = ImageModel(
    net,
    learning_rate=learning_rate,
    warmup=warmup,
    num_epochs=num_epochs,
    alpha=0.0,
    Lr=1.0,
    f=1.0,
    g=1.0,
    qg=False,
)

In [None]:
# smoke test: one forward pass of the untrained model on the sample batch
out = learn.forward(x_init)

# assert out.shape[0] == x_init.shape[0]

In [None]:
# Use Apple's MPS backend when it exists (the original setting); otherwise let
# Lightning pick a device instead of failing on non-Apple machines.
accelerator = "mps" if torch.backends.mps.is_available() else "auto"

trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator=accelerator,
    # devices=1,
    enable_progress_bar=True,
    logger=wandb_logger,
    callbacks=callbacks,
    # gradient_clip_val=1.0,
    # gradient_clip_algorithm="norm",
)

In [None]:
# fit the data-only model using the datamodule's dataloaders
trainer.fit(
    learn,
    datamodule=dm,
)

In [None]:
# evaluate on the test split; keep the metrics under the "data" key
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["data"] = res

In [None]:
# one row per experiment: its name and test loss
# (a "Num Parameters" column could be added from each model's parameter count)
table = [
    [name, f"{metrics[0]['test_loss']:4.4f}"] for name, metrics in results.items()
]
display(HTML(tabulate.tabulate(table, tablefmt="html", headers=["Model", "MSE"])))

In [None]:
# t0 = time.time()
# predictions over the datamodule's predict split, concatenated into a single tensor
predictions = torch.cat(trainer.predict(learn, dataloaders=dm, return_predictions=True))
# t1 = time.time() - t0

In [None]:
# wrap the flat prediction tensor into an xarray dataset (plotted via .pred / .true below)
ds_pred = dm.create_predictions_ds(predictions)
ds_pred

In [None]:
# predicted field at every other time step (robust colour limits)
ds_pred.pred.thin(time=2).plot.imshow(
    col="time",
    robust=True,
    col_wrap=4,
    cmap="viridis",
)

In [None]:
# reference field at the same time steps, for comparison
ds_pred.true.thin(time=2).plot.imshow(
    col="time",
    robust=True,
    col_wrap=4,
    cmap="viridis",
)

In [None]:
# ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# from inr4ssh._src.operators.differential import grad, laplacian, jac
# from torch.autograd import functional as f_grad

In [None]:
# # gradient once
# with torch.enable_grad():
#     x = dm.ds_train[:32][0]
#     x_var = torch.autograd.Variable(x.clone(), requires_grad=True)
#     u = learn.model(x_var)

#     x_var_space = x_var[:, :2].clone()
#     x_var_space = torch.autograd.Variable(x_var_space, requires_grad=True)
#     u_lap = laplacian(u, x_var_space)

#     # jacobian
#     u_jac = grad(u, x_var)

#     # laplacian
#     x_var_space = x_var[:, :2].clone()
#     x_var_space = torch.autograd.Variable(x_var_space, requires_grad=True)
#     u_lap = laplacian(u, x_var_space)

#     # gradient o laplacian
#     u_lap_jac = grad(u_lap, x_var)
#     assert u_jac.shape == x_var.shape
#     assert u_lap.shape == u.shape
#     assert u_lap_jac.shape == x_var.shape

#     loss = u_lap

#     loss = loss.square().mean()

In [None]:
# x_var.shape

In [None]:
# def f(x):
#     return learn.model(x).sum()

In [None]:
# u_grad_ = f_grad.jacobian(f, x_var)

# torch.testing.assert_close(u_grad, u_grad_)

In [None]:
# loss = qg_loss(out, x_var, reduction="mean")
# loss

## QG Regularization

In [None]:
# settings for the QG-regularised fine-tuning stage (override the values above)
num_epochs = 1_000
learning_rate = 1e-4
warmup = 50  # linear-warmup epochs of the LR schedule

In [None]:
class ImageModelQG(ImageModel):
    """``ImageModel`` used for the QG-regularised fine-tuning stage.

    Only ``configure_optimizers`` is overridden. It builds the optimizer named
    by the ``optimizer`` hyperparameter and a linear-warmup cosine-annealing
    schedule whose lengths come from ``warmup`` / ``num_epochs``.
    """

    def configure_optimizers(self):
        name = self.hyperparams.get("optimizer", "adam")
        # callers pass `learning_rate=`; an explicit `lr` still takes precedence
        lr = self.hyperparams.get("lr", self.hyperparams.get("learning_rate", 1e-4))

        if name == "adam":
            optimizer = Adam(self.model.parameters(), lr=lr)
        elif name == "lbfgs":
            # previously this branch silently constructed Adam
            optimizer = torch.optim.LBFGS(self.model.parameters(), lr=lr)
        else:
            # report the requested name (the old message used an unbound local)
            raise ValueError(f"Unrecognized optimizer: {name}")

        # Alternatives tried before: ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR.
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.hyperparams.get("warmup", 10),
            max_epochs=self.hyperparams.get("num_epochs", 100),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

In [None]:
# QG-regularised fine-tuning. NOTE: this wraps the *same* network object as
# `learn` (learn.model), so training it also changes `learn`'s weights.
# f, g and Lr are not passed, so ImageModel._qg_loss falls back to its
# defaults (f=0.0001, g=9.81, Lr=1.0).
learn_pinns = ImageModelQG(
    learn.model,
    learning_rate=learning_rate,
    alpha=0.1,
    qg=True,
    num_epochs=num_epochs,
    warmup=warmup,
)

In [None]:
# smoke test: forward pass of the wrapped (already data-fitted) network
out = learn_pinns.forward(x_init)

# assert out.shape[0] == x_init.shape[0]

In [None]:
# Use Apple's MPS backend when it exists (the original setting); otherwise let
# Lightning pick a device instead of failing on non-Apple machines.
accelerator = "mps" if torch.backends.mps.is_available() else "auto"

trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator=accelerator,
    # devices=1,
    enable_progress_bar=True,
    logger=wandb_logger,
    callbacks=callbacks,
    # gradient_clip_val=1.0,
    # gradient_clip_algorithm="norm",
)

In [None]:
# fine-tune with the QG regulariser enabled
trainer.fit(
    learn_pinns,
    datamodule=dm,
)

In [None]:
# test metrics of the QG-regularised model, stored under "qg"
res = trainer.test(learn_pinns, dataloaders=dm.test_dataloader())

results["qg"] = res

In [None]:
# one row per experiment: its name and total test loss
# (test_reg / test_data and a parameter count could be added as extra columns)
table = [
    [name, f"{metrics[0]['test_loss']:4.4f}"] for name, metrics in results.items()
]
display(HTML(tabulate.tabulate(table, tablefmt="html", headers=["Model", "Loss"])))

In [None]:
# Predict with the QG-fine-tuned module this trainer was fitted with
# (previously `learn`, which only works because it shares the same network).
# t0 = time.time()
predictions = trainer.predict(learn_pinns, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
# wrap the fine-tuned predictions into an xarray dataset
ds_pred = dm.create_predictions_ds(predictions)
ds_pred

In [None]:
# predicted field after QG fine-tuning, every other time step
ds_pred.pred.thin(time=2).plot.imshow(
    col="time",
    robust=True,
    col_wrap=4,
    cmap="viridis",
)

In [None]:
# reference field at the same time steps
ds_pred.true.thin(time=2).plot.imshow(
    col="time",
    robust=True,
    col_wrap=4,
    cmap="viridis",
)

In [None]:
# interactive view of the predicted field
ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# interactive view of the reference field
ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# per-column min / max of the sample coordinates (values and indices)
x_init.min(dim=0), x_init.max(dim=0)

```python
coords = coords.clone().detach().requires_grad_(True) # allows to take derivative w.r.t. input
SSH = self.firstnet(coords)
gradSSH = self.gradient(SSH, coords)
dSSHdx = gradSSH[:,0:1]
dSSHdy = gradSSH[:,1:2]
d2SHHd2x = self.gradient(dSSHdx, coords)[:,0:1]
d2SHHd2y = self.gradient(dSSHdy, coords)[:,1:2]
dQ = self.gradient(d2SHHd2x+d2SHHd2y, coords)
output = self.secondnet(self.Bnorm(torch.cat((dSSHdy,
                                              dSSHdx,
                                              d2SHHd2x+d2SHHd2y,
                                              dQ[:,0:1],
                                              dQ[:,1:2]),1)))
                                              #dQ[:,0:1] * dSSHdy,
                                              #dQ[:,1:2] * dSSHdx
output =  dSSHdx *  dQ[:,1:2] -   dSSHdy * dQ[:,0:1]
return (1e-5*dQ[:,2:3]-output), coords, SSH
```

In [None]:
# sample coordinates [Nx, Ny, T] as a leaf tensor that tracks gradients
# (replaces the deprecated torch.autograd.Variable wrapper)
x_var = x_init[:5].detach().requires_grad_(True)
# ssh: network output at those coordinates
ssh = learn.model(x_var)

$$
\mathbf{J} = \boldsymbol{J}(\mathbf{X})
$$

where:
* $\mathbf{X} \in \mathbb{R}^{N \times D_\phi}$
* $\boldsymbol{J}: \mathbb{R}^{N \times D_\phi} \rightarrow \mathbb{R}^{N \times D_\phi}$
* $\mathbf{J} \in \mathbb{R}^{N \times D_\phi}$

In [None]:
class QGRegularization(nn.Module):
    """QG residual regulariser: ``term1 - term2`` evaluated pointwise, then reduced.

    The two terms come from ``_qg_term1`` / ``_qg_term2`` (defined in later
    cells; resolved when the module is called).

    Parameters
    ----------
    f, g, Lr : float
        Constants forwarded to the term functions.
    reduction : {"mean", "sum", "none"}
        Reduction over points; ``"none"`` returns the pointwise residual.
    """

    def __init__(
        self, f: float = 1.0, g: float = 1.0, Lr: float = 1.0, reduction: str = "mean"
    ):
        # `super(QGRegularization).__init__()` never initialised nn.Module
        super().__init__()

        self.f = f
        self.g = g
        self.Lr = Lr
        self.reduction = reduction

    def forward(self, out, x):
        """Reduced QG residual of ``out`` w.r.t. ``x``.

        ``x`` must require grad and ``out`` must have been computed from it.
        """
        # first derivatives ∇u, same shape as x; graph kept for higher-order terms
        out_jac = torch.autograd.grad(
            out, x, grad_outputs=torch.ones_like(out), create_graph=True
        )[0]
        assert out_jac.shape == x.shape

        loss1 = _qg_term1(out_jac, x, self.f, self.g, self.Lr)
        loss2 = _qg_term2(out_jac, self.f, self.g, self.Lr)
        loss = loss1 - loss2

        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "none":
            return loss
        else:
            raise ValueError(f"Unrecognized reduction: {self.reduction}")

In [None]:
def qg_constants(f, g, L_r):
    """Return the QG coefficients ``(c_1, c_2, c_3)``.

    c_1 = f / g,  c_2 = 1 / L_r**2,  c_3 = c_1 * c_2
    """
    ratio = f / g
    inv_lr_sq = 1 / L_r**2
    return ratio, inv_lr_sq, ratio * inv_lr_sq

In [None]:
def _qg_term1(ssh_grad, x_var, f: float = 1.0, g: float = 1.0, L_r: float = 1.0):
    """
    t1 = ∂𝑡∇2𝑢 + 𝑐1 ∂𝑥𝑢 ∂𝑦∇2𝑢 − 𝑐1 ∂𝑦𝑢 ∂𝑥∇2𝑢
    Parameters:
    ----------
    ssh_grad: torch.Tensor, (B, Nx, Ny, T)
    x_var: torch.Tensor, (B,
    f: float, (,)
    g: float, (,)
    Lr: float, (,)

    Returns:
    --------
    loss : torch.Tensor, (B,)
    """

    x_var = x_var.requires_grad_(True)
    c_1, c_2, c_3 = qg_constants(f, g, L_r)
    # jacobian^2 x2, ∇2
    ssh_grad2 = diffops_simp.gradient(ssh_grad, x_var)
    assert ssh_grad2.shape == x_var.shape

    # split jacobian -> partial x, partial y, partial t
    ssh_grad2_x, ssh_grad2_y, ssh_grad2_t = torch.split(ssh_grad2, [1, 1, 1], dim=1)
    assert ssh_grad_x.shape == ssh_grad_y.shape == ssh_grad_t.shape

    # laplacian (spatial), nabla^2
    ssh_lap = ssh_grad2_x + ssh_grad2_y
    assert ssh_lap.shape == ssh_grad_x.shape == ssh_grad_y.shape

    # gradient of laplacian, ∇ ∇2
    ssh_grad_lap = diffops_simp.gradient(ssh_lap, x_var)
    assert ssh_grad_lap.shape == x_var.shape

    # split laplacian into partials
    ssh_grad_lap_x, ssh_grad_lap_y, ssh_grad_lap_t = torch.split(
        ssh_grad_lap, [1, 1, 1], dim=1
    )
    assert ssh_grad_lap_x.shape == ssh_grad_lap_y.shape == ssh_grad_lap_t.shape

    # term 1
    loss = (
        ssh_grad_lap_t
        + c_1 * ssh_grad_x * ssh_grad_lap_y
        - c_1 * ssh_grad_y * ssh_grad_lap_x
    )
    assert (
        loss.shape
        == ssh_grad_lap_t.shape
        == ssh_grad_lap_y.shape
        == ssh_grad_lap_x.shape
    )

    return loss


def _qg_term2(ssh_grad, f: float = 1.0, g: float = 1.0, Lr: float = 1.0):

    """
    t2 = 𝑐2 ∂𝑡(𝑢) + 𝑐3 ∂𝑥(𝑢) ∂𝑦(𝑢) − 𝑐3 ∂𝑦(𝑢) ∂𝑥(𝑢)

    Parameters:
    ----------
    ssh_grad: torch.Tensor, (B, Nx, Ny, T)
    f: float, (,)
    g: float, (,)
    Lr: float, (,)

    Returns:
    --------
    loss : torch.Tensor, (B,)
    """
    _, c_2, c_3 = qg_constants(f, g, Lr)

    # get partial derivatives | partial x, y, t
    ssh_grad_x, ssh_grad_y, ssh_grad_t = torch.split(ssh_jac, [1, 1, 1], dim=1)

    # calculate term 2
    loss = (
        c_2 * ssh_grad_t + c_3 * ssh_grad_x * ssh_grad_y - c_3 * ssh_grad_y * ssh_grad_x
    )

    return loss

In [None]:
def qg_loss(ssh, x, f, g, Lr, reduction: str = "none"):
    """QG residual ``term1 - term2`` of ``ssh`` w.r.t. coordinates ``x``.

    NOTE: this cell redefines the ``qg_loss`` imported from
    ``inr4ssh._src.losses.qg``. ``reduction`` is accepted so callers that pass
    ``reduction="mean"`` (e.g. ``ImageModel._qg_loss``) keep working after this
    cell has run.

    Parameters
    ----------
    ssh : torch.Tensor
        Network output computed from ``x``.
    x : torch.Tensor
        Input coordinates with ``requires_grad=True``.
    f, g, Lr : float
        QG constants forwarded to the term functions.
    reduction : {"none", "mean", "sum"}
        ``"none"`` (default, the previous behaviour) returns the pointwise residual.
    """
    # gradient, nabla x; graph kept for the higher-order derivatives in term 1
    ssh_jac = torch.autograd.grad(
        ssh, x, grad_outputs=torch.ones_like(ssh), create_graph=True
    )[0]
    assert ssh_jac.shape == x.shape

    # calculate term 1
    loss1 = _qg_term1(ssh_jac, x, f, g, Lr)
    # calculate term 2
    loss2 = _qg_term2(ssh_jac, f, g, Lr)

    loss = loss1 - loss2

    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unrecognized reduction: {reduction}")

In [None]:
# sample coordinates [Nx, Ny, T] as a leaf tensor that tracks gradients
x_var = x_init[:5].detach().requires_grad_(True)
# ssh: network output at those coordinates
ssh = learn.model(x_var)

qg_reg = QGRegularization(f=1.0, g=1.0, Lr=1.0)

loss = qg_reg(ssh, x_var)
# the default reduction="mean" yields a 0-d tensor, so there is no batch axis to index
assert loss.ndim == 0

**Term II** ($\nabla$)

In [None]:
# first derivatives of the sample output w.r.t. its inputs, ∇u
# (`ssh_jac` was previously never defined at module level)
ssh_jac = torch.autograd.grad(
    ssh, x_var, grad_outputs=torch.ones_like(ssh), create_graph=True
)[0]
assert ssh_jac.shape == x_var.shape

# get partial derivatives | partial x, y, t
ssh_grad_x, ssh_grad_y, ssh_grad_t = torch.split(ssh_jac, [1, 1, 1], dim=1)

assert ssh_grad_x.shape == ssh_grad_y.shape == ssh_grad_t.shape

In [None]:
# scratch constants for checking the terms by hand:
# c_1 = f / g, c_2 = 1 / constant_lr2 (constant_lr2 stands in for L_r**2), c_3 = c_1 * c_2
f, g = 1.1, 1.0
constant_lr2 = 1.1
c_1, c_2 = f / g, 1 / constant_lr2
c_3 = c_1 * c_2

In [None]:
# calculate term 2: c_2 ∂t u + c_3 ∂x u ∂y u − c_3 ∂y u ∂x u
# (the two c_3 products cancel algebraically)
t2 = c_2 * ssh_grad_t + c_3 * ssh_grad_x * ssh_grad_y - c_3 * ssh_grad_y * ssh_grad_x

assert t2.shape == ssh_grad_x.shape

### Term I ($\nabla \cdot \nabla^2$)

In [None]:
def _qg_term1(ssh_grad, x_var, f, g, L_r):
    """
    t1 = ∂𝑡∇2𝑢 + 𝑐1 ∂𝑥𝑢 ∂𝑦∇2𝑢 − 𝑐1 ∂𝑦𝑢 ∂𝑥∇2𝑢
    """
    c_1, c_2, c_3 = qg_constants(f, g, L_r)
    # jacobian^2 x2, ∇2
    ssh_grad2 = diffops_simp.gradient(ssh_grad, x_var)
    assert ssh_hessian.shape == x_var.shape

    # split jacobian -> partial x, partial y, partial t
    ssh_grad2_x, ssh_grad2_y, ssh_grad2_t = torch.split(ssh_grad2, [1, 1, 1], dim=1)
    assert ssh_grad_x.shape == ssh_grad_y.shape == ssh_grad_t.shape

    # laplacian (spatial), nabla^2
    ssh_lap = ssh_grad2_x + ssh_grad2_y
    assert ssh_lap.shape == ssh_grad_x.shape == ssh_grad_y.shape

    # gradient of laplacian, ∇ ∇2
    ssh_grad_lap = diffops_simp.gradient(ssh_lap, x_var)
    assert ssh_grad_lap.shape == x_var.shape

    # split laplacian into partials
    ssh_grad_lap_x, ssh_grad_lap_y, ssh_grad_lap_t = torch.split(
        ssh_grad_lap, [1, 1, 1], dim=1
    )
    assert ssh_grad_lap_x.shape == ssh_grad_lap_y.shape == ssh_grad_lap_t.shape

    # term 1
    t1 = (
        ssh_grad_lap_t
        + c_1 * ssh_grad_x * ssh_grad_lap_y
        - c_1 * ssh_grad_y * ssh_grad_lap_x
    )
    assert t1.shape == ssh_grad_lap_t.shape

    return loss

In [None]:
def _grad_wrt_x(y):
    """d(sum(y)) / d(x_var) for the sample batch, keeping the graph for higher orders."""
    return torch.autograd.grad(
        y, x_var, grad_outputs=torch.ones_like(y), create_graph=True
    )[0]


# first derivatives ∇u (`ssh_grad` was previously never defined at module level)
ssh_grad = _grad_wrt_x(ssh)
assert ssh_grad.shape == x_var.shape
ssh_grad_x, ssh_grad_y, ssh_grad_t = torch.split(ssh_grad, [1, 1, 1], dim=1)

# second derivatives: differentiate ∂x u and ∂y u separately; differentiating the
# whole (B, 3) gradient at once would sum its columns and mix in cross terms
ssh_grad2_x = _grad_wrt_x(ssh_grad_x)[:, 0:1]  # ∂xx u
ssh_grad2_y = _grad_wrt_x(ssh_grad_y)[:, 1:2]  # ∂yy u
assert ssh_grad2_x.shape == ssh_grad2_y.shape == ssh_grad_x.shape

# laplacian
ssh_lap = ssh_grad2_x + ssh_grad2_y
assert ssh_lap.shape == ssh_grad_x.shape

# gradient of laplacian, ∇ ∇2
ssh_grad_lap = _grad_wrt_x(ssh_lap)
assert ssh_grad_lap.shape == x_var.shape

ssh_grad_lap_x, ssh_grad_lap_y, ssh_grad_lap_t = torch.split(
    ssh_grad_lap, [1, 1, 1], dim=1
)
assert ssh_grad_lap_x.shape == ssh_grad_lap_y.shape == ssh_grad_lap_t.shape

# term 1 (c_1 comes from the scratch-constants cell above)
t1 = (
    ssh_grad_lap_t
    + c_1 * ssh_grad_x * ssh_grad_lap_y
    - c_1 * ssh_grad_y * ssh_grad_lap_x
)
assert t1.shape == ssh_grad_lap_t.shape

# same term via column indexing, as a cross-check of the split above
t1_ = (
    ssh_grad_lap[:, -1:]
    + c_1 * ssh_grad_x * ssh_grad_lap[:, 1:2]
    - c_1 * ssh_grad_y * ssh_grad_lap[:, 0:1]
)
torch.testing.assert_close(t1, t1_)

In [None]:
# inspect the term-1 residual on the sample batch
t1

$$
\partial_t \nabla^2 u + c_1 \partial_x u \partial_y \nabla^2 u -  c_1 \partial_y u \partial_x \nabla^2u
$$

---
$$
\underbrace{\partial_t \nabla^2 u + c_1 \partial_x u \partial_y \nabla^2 u -  c_1 \partial_y u \partial_x \nabla^2u}_{\nabla^2} - \underbrace{\left(c_2 \partial_t u + c_3 \partial_x u \partial_y u - c_3\partial_y u \partial_x u\right)}_{\nabla} = 0
$$

---
**Gradient** (order 1)

$$
\nabla u = 
\begin{bmatrix}
\nabla_t u \\ \nabla_x u \\ \nabla_y u
\end{bmatrix}
$$

---
**Gradient** (order 3)

$$
\nabla^3 u = 
\begin{bmatrix}
\nabla^3_t u \\ \nabla^3_x u \\ \nabla^3_y u
\end{bmatrix}
$$


$$
\nabla \cdot \nabla^2 u = 
\begin{bmatrix} 
\nabla_t & \nabla_x & \nabla_y
\end{bmatrix}
\cdot
\begin{bmatrix}
1 \\ c_1 \\ -c_1
\end{bmatrix}
\circ
\begin{bmatrix}
\nabla^2_t u \\ \nabla^2_x u \\ \nabla^2_y u
\end{bmatrix}
$$

In [None]:
# grad u

In [None]:
# interactive view of the predicted field
ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

In [None]:
# interactive view of the reference field
ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")