# NerF + QG Loss

The full QG equation is given by:

$$
\begin{aligned}
\partial_t q + \det \boldsymbol{J}(\psi, q) &= 0
\end{aligned}
$$

where:

* $q=\nabla^2 \psi$
* $\det \boldsymbol{J}(\psi, q)=\partial_x \psi\,\partial_y q - \partial_y \psi\,\partial_x q$.

We are interested in finding a neural field (NerF) that takes in the spatio-temporal coordinates, $\mathbf{x}_\phi$, and outputs a vector corresponding to the PV, $q$, and the stream function, $\psi$, i.e. $\mathbf{y}_\text{obs}$.

$$
\mathbf{y}_\text{obs} = \boldsymbol{f_\theta}(\mathbf{x}_\phi) + \epsilon, \hspace{5mm}\epsilon \sim \mathcal{N}(0, \sigma^2)
$$

We use a SIREN network, which is a fully connected neural network with the $\sin$ activation function.

* **Data Inputs**: `256x256x11`
* **Data Outputs**: `2`


In [None]:
import sys, os
from pyprojroot import here

# walk up the directory tree until the ".root" marker file is found
root = here(project_files=[".root"])

# make the project root importable from this notebook
sys.path.append(str(root))

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from ml_collections import config_dict
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger

from inr4ssh._src.datamodules.qg_sim import QGSimulation
from inr4ssh._src.models.siren import Siren, SirenNet
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from functools import partial
from typing import Dict, Any, cast
from torch.optim import Adam
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pathlib import Path
from inr4ssh._src.operators import differential_simp as diffops_simp
from inr4ssh._src.operators import differential as diffops
import copy
from tqdm.notebook import tqdm

pl.seed_everything(123)

import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from experiments.qg.configs.config_image import get_config
from experiments.qg.losses import RegQG


sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# from ml_collections import config_dict

# cfg = config_dict.ConfigDict()

# # logging args
# cfg.log = config_dict.ConfigDict()
# cfg.log.mode = "online" #"disabled"
# cfg.log.project ="inr4ssh"
# cfg.log.entity = "ige"
# cfg.log.log_dir = "/Users/eman/code_projects/logs/"
# cfg.log.resume = False

# # data args
# cfg.data = config_dict.ConfigDict()
# cfg.data.data_dir =  f"/Users/eman/code_projects/torchqg/data/qgsim_simple_128x128.nc"

# # preprocessing args
# cfg.pre = config_dict.ConfigDict()
# cfg.pre.noise = 0.01
# cfg.pre.dt = 1.0
# cfg.pre.time_min = 500
# cfg.pre.time_max = 511
# cfg.pre.seed = 123

# # train/test args
# cfg.split = config_dict.ConfigDict()
# cfg.split.train_prct = 0.9

# # dataloader args
# cfg.dl = config_dict.ConfigDict()
# cfg.dl.batchsize_train = 2048
# cfg.dl.batchsize_val = 1_000
# cfg.dl.batchsize_test = 5_000
# cfg.dl.batchsize_predict = 10_000
# cfg.dl.num_workers = 0
# cfg.dl.pin_memory = False

# # loss arguments
# cfg.loss = config_dict.ConfigDict()
# cfg.loss.qg = True
# cfg.loss.alpha = 1e-4

# # optimizer args
# cfg.optim = config_dict.ConfigDict()
# cfg.optim.warmup = 10
# cfg.optim.num_epochs = 100
# cfg.optim.learning_rate = 1e-4

# # trainer args
# cfg.trainer = config_dict.ConfigDict()
# cfg.trainer.accelerator = None
# cfg.trainer.devices = 1
# cfg.trainer.grad_batches = 1

In [None]:
from inr4ssh._src.io import transform_dict

# Start from the experiment defaults, then apply overrides for a local run.
cfg = get_config()

# data directory / resolution
cfg.data.data_dir = "/Users/eman/code_projects/torchqg/data/qgsim_simple_128x128.nc"
cfg.data.res = "128x128"

# log directory
# NOTE: the WandbLogger in this notebook reads `cfg.log.log_dir`. Previously
# only `cfg.data.log_dir` was set here, so the local override never reached
# the logger. `cfg.data.log_dir` is kept for backward compatibility.
cfg.data.log_dir = "/Users/eman/code_projects/logs/"
cfg.log.log_dir = cfg.data.log_dir

# logging args
cfg.log.mode = "online"

# preprocessing params
cfg.pre.time_max = 2

# model params
cfg.model.hidden_dims = 512

# dataloader params
cfg.dl.batchsize_train = 128
cfg.dl.batchsize_val = 2048
cfg.dl.num_workers = 0
cfg.dl.pin_memory = False

# optimization args (`optim`: data-only stage, `optim_qg`: QG-regularized stage)
cfg.optim.num_epochs = 2_000
cfg.optim.warmup = 50
cfg.optim_qg.num_epochs = 5_000
cfg.optim_qg.warmup = 50

# trainer params
cfg.trainer.grad_batches = 1
cfg.trainer.accelerator = "mps"
cfg.trainer_qg.accelerator = "cpu"


# cfg.to_dict()
cfg

In [None]:
# close any W&B run left open by a previous execution of this notebook
wandb.finish()

In [None]:
# Weights & Biases logger shared by both training stages in this notebook.
wandb_logger = WandbLogger(
    config=cfg.to_dict(),
    mode=cfg.log.mode,
    project=cfg.log.project,
    entity=cfg.log.entity,
    dir=cfg.log.log_dir,
    resume=False,
)

## Data Module

Now we will put all of the preprocessing routines together. This is **very important** for a few reasons:

1. It collapses all of the operations in a modular way
2. It makes it reproducible for the next people
3. It makes it very easy for the PyTorch-Lightning framework down the line.

In [None]:
# build the data module from the config and prepare its datasets
dm = QGSimulation(cfg)
dm.setup()

In [None]:
# number of items in the training dataset
len(dm.ds_train)

In [None]:
# small (inputs, targets) sample, used below to size and smoke-test the network
x_init, y_init = dm.ds_train[:10]

In [None]:
# inspect the sample shapes (column counts become dim_in / dim_out below)
x_init.shape, y_init.shape

## NerF

This is the standard neural field, trained on the data loss only.

In [None]:
# Network sizes are read from the sample batch: one input per coordinate
# column, one output per target column.
dim_in = x_init.shape[1]
# NOTE(review): hard-coded to 256 although `cfg.model.hidden_dims` is set to
# 512 in the config cell — confirm which value is intended.
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 4
w0 = 1.0
w0_initial = 30.0
c = 6.0
final_activation = None  # nn.Sigmoid()

net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    c=c,
    final_activation=final_activation,
)

In [None]:
# smoke test: one forward pass on the sample batch
out = net(x_init)

## PINNS Loss

$$
\partial_t \nabla^2 \psi + \det J(\psi, \nabla^2\psi) = 0
$$

In [None]:
# per-column maximum of the sample coordinates
x_init.max(dim=0)

## Experiment

In [None]:
# data-fit term: mean squared error, with the reduction taken from the config
data_loss = nn.MSELoss(reduction=cfg.loss.reduction)

### Callbacks

In [None]:
# keep only the single best checkpoint (lowest "val_loss") inside the W&B run directory
model_cb = ModelCheckpoint(
    dirpath=str(Path(wandb_logger.experiment.dir).joinpath("checkpoints")),
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

In [None]:
# checkpointing plus a progress bar refreshed every 9 batches
callbacks = [model_cb, TQDMProgressBar(refresh_rate=9)]

### Learner

In [None]:
from experiments.qg.trainer import INRModel

In [None]:
# Stage 1 learner: data loss only (no PDE regularizer, alpha = 0, qg disabled).
learn = INRModel(
    model=net,
    loss_data=data_loss,
    reg_pde=None,
    learning_rate=cfg.optim.learning_rate,
    warmup=cfg.optim.warmup,
    warmup_start_lr=cfg.optim.warmup_start_lr,
    eta_min=cfg.optim.eta_min,
    num_epochs=cfg.optim.num_epochs,
    alpha=0.0,
    qg=False,
)

### Trainer

In [None]:
# Stage 1 trainer: settings come from `cfg.optim` / `cfg.trainer`.
trainer = Trainer(
    min_epochs=1,
    max_epochs=cfg.optim.num_epochs,
    accelerator=cfg.trainer.accelerator,
    devices=cfg.trainer.devices,
    enable_progress_bar=True,
    logger=wandb_logger,
    callbacks=callbacks,
    accumulate_grad_batches=cfg.trainer.grad_batches,
)

### Train

In [None]:
# fit the data-only model; the dataloaders are provided by the data module
trainer.fit(
    learn,
    datamodule=dm,
)

## Results

### Testing

In [None]:
# res = trainer.test(learn, dataloaders=dm.test_dataloader())

# results["data"] = res

In [None]:
# import wandb

# wandb.finish()

### Predictions

In [None]:
# t0 = time.time()
# predict over the data module's prediction set and merge the per-batch outputs
predictions = trainer.predict(learn, datamodule=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
# let the data module package the flat predictions into a dataset
ds_pred = dm.create_predictions_ds(predictions)

In [None]:
# ds_pred = dm.create_predictions_ds(predictions)
# ds_pred

The full QG equation is given by:

$$
\begin{aligned}
\partial_t q + \det \boldsymbol{J}(\psi, q) &= 0
\end{aligned}
$$

where:

* $q=\nabla^2 \psi$
* $\det \boldsymbol{J}(\psi, q)=\partial_x \psi\,\partial_y q - \partial_y \psi\,\partial_x q$.

In [None]:
# learn.model.eval()
# coords, truths, preds, grads, qs = [], [], [], [], []
# for ix, iy in tqdm(dm.predict_dataloader()):
#     with torch.set_grad_enabled(True):
#         # prediction
#         ix = torch.autograd.Variable(ix.clone(), requires_grad=True)
#         p_pred = learn.model(ix)

#         # p_pred = p_pred.clone()
#         # p_pred.require_grad_ = True

#         # gradient
#         p_grad = diffops_simp.gradient(p_pred, ix)
#         # p_grad = diffops.grad(p_pred, ix)
#         # q
#         q = diffops_simp.divergence(p_grad, ix)
#         # q = diffops.div(p_grad, ix)

#     # collect
#     truths.append(iy.detach().cpu())
#     coords.append(ix.detach().cpu())
#     preds.append(p_pred.detach().cpu())
#     grads.append(p_grad.detach().cpu())
#     qs.append(q.detach().cpu())

In [None]:
# Evaluate the data-only model on the prediction set and derive, via autograd,
# the gradient of the predicted field and its divergence over axes (0, 1).
learn.model.eval()
coords, truths, preds, grads, qs = [], [], [], [], []
for ix, iy in tqdm(dm.predict_dataloader()):
    # autograd must be on: the derived fields are derivatives w.r.t. the inputs
    with torch.enable_grad():
        # fresh leaf tensor that tracks gradients
        # (replaces the deprecated `torch.autograd.Variable` wrapper)
        ix = ix.detach().clone().requires_grad_(True)

        # prediction
        p_pred = learn.model(ix)

        # gradient of the prediction w.r.t. the input coordinates
        p_grad = diffops_simp.gradient(p_pred, ix)

        # q: divergence of the gradient restricted to input axes (0, 1)
        q = diffops_simp.divergence(p_grad, ix, (0, 1))

    # collect detached CPU copies so no autograd graph is kept alive
    truths.append(iy.detach().cpu())
    coords.append(ix.detach().cpu())
    preds.append(p_pred.detach().cpu())
    grads.append(p_grad.detach().cpu())
    qs.append(q.detach().cpu())

In [None]:
# Stitch the per-batch tensors into one NumPy array per quantity.
coords, preds, truths, grads, qs = (
    torch.cat(chunks).numpy() for chunks in (coords, preds, truths, grads, qs)
)

In [None]:
# reference table from the data module
df_data = dm.create_predictions_df()

# sanity check: the collected coordinates / targets line up row-by-row with the table
np.testing.assert_array_almost_equal(coords, df_data[["Nx", "Ny", "steps"]])
np.testing.assert_array_almost_equal(truths, df_data[["p"]])

In [None]:
df_data["p_pred"] = preds
# velocities from the gradient of the predicted field
# (gradient column 0 <-> "Nx", column 1 <-> "Ny", per the coordinate check above)
# NOTE(review): the usual stream-function convention is u = -dψ/dy, v = dψ/dx;
# here u uses the first-coordinate derivative and v the second — confirm
# against the axis convention of the simulation data.
df_data["u_pred"] = -grads[:, 0]
df_data["v_pred"] = grads[:, 1]
df_data["q_pred"] = qs

# pivot the long table into a gridded dataset indexed by (Nx, Ny, steps)
xr_data = df_data.set_index(["Nx", "Ny", "steps"]).to_xarray()

### Figure I: Stream Function

In [None]:
# Shared facet layout: one panel per "steps" value, wrapped after three columns.
_facet = dict(col="steps", robust=True, col_wrap=3)

# prediction
xr_data.p_pred.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# reference
xr_data.p.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# absolute error
np.abs(xr_data.p_pred - xr_data.p).thin(steps=1).plot.imshow(cmap="Reds", **_facet)

In [None]:
# ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure II: PV

In [None]:
# predicted PV
xr_data.q_pred.thin(steps=1).plot.imshow(
    col="steps",
    robust=True,
    col_wrap=3,
    cmap="viridis",
)

# reference PV
xr_data.q.thin(steps=1).plot.imshow(
    col="steps",
    robust=True,
    col_wrap=3,
    cmap="viridis",
)
# absolute error: uses the sequential "Reds" map like the other error panels
# in this notebook (was "viridis")
np.abs(xr_data.q_pred - xr_data.q).thin(steps=1).plot.imshow(
    col="steps",
    robust=True,
    col_wrap=3,
    cmap="Reds",
)

In [None]:
# ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure III: U Velocity

In [None]:
# Shared facet layout: one panel per "steps" value, wrapped after three columns.
_facet = dict(col="steps", robust=True, col_wrap=3)

# prediction
xr_data.u_pred.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# reference
xr_data.u.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# absolute error
np.abs(xr_data.u_pred - xr_data.u).thin(steps=1).plot.imshow(cmap="Reds", **_facet)

In [None]:
# (ds_pred.true - ds_pred.pred).hvplot.image(
#     x="Nx", y="Ny", width=500, height=400, cmap="viridis"
# )

### Figure IV: V Velocity

In [None]:
# Shared facet layout: one panel per "steps" value, wrapped after three columns.
_facet = dict(col="steps", robust=True, col_wrap=3)

# prediction
xr_data.v_pred.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# reference
xr_data.v.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# absolute error
np.abs(xr_data.v_pred - xr_data.v).thin(steps=1).plot.imshow(cmap="Reds", **_facet)

## QG PDE Regularization

In [None]:
class RegQG(nn.Module):
    """QG Regularization Loss

    Penalizes the mean squared residual of

        ∂𝑡△𝑢 + det𝑱(𝑢,△𝑢) = 0,   det𝑱(𝑢,△𝑢) = ∂𝑥𝑢∂𝑦△𝑢 − ∂𝑦𝑢∂𝑥△𝑢

    where ``u = f(x)`` and every derivative is taken with autograd with
    respect to the input coordinates.

    NOTE: this notebook-local class shadows the ``RegQG`` imported from
    ``experiments.qg.losses`` at the top of the notebook.

    Parameters
    ----------
    alpha: float
        the weight for the loss to regularize
        default=1e-4
    trainable: bool
        if True, ``alpha`` is registered as an ``nn.Parameter``;
        otherwise it is stored as a (non-trainable) buffer
        default=False
    """

    def __init__(self, alpha: float = 1e-4, trainable=False):
        super().__init__()

        # single-element floating-point tensor; equivalent to the legacy
        # ``torch.Tensor([alpha])`` constructor
        alpha = torch.tensor([alpha], dtype=torch.get_default_dtype())

        if trainable:
            self.alpha = nn.Parameter(alpha, requires_grad=True)
        else:
            # buffer: follows the module across devices and is part of the state dict
            self.register_buffer("alpha", alpha)

    def forward(self, x, f):
        """The forward operation to compute the loss.

        Parameters
        ----------
        x: torch.Tensor, [Batch, Dims]
            the input tensor where the dims are [Dx,Dy,Dt]
        f: Callable[[torch.Tensor], torch.Tensor]
            the function that takes the coordinates and outputs a
            scalar value

        Returns
        -------
        loss: torch.Tensor, [1,]
            ``alpha`` times the batch-mean squared PDE residual
        """
        # gradients are required even if the caller runs under no_grad
        with torch.enable_grad():
            # detached tensor that tracks gradients w.r.t. the coordinates
            # (replaces the deprecated ``torch.autograd.Variable`` wrapper)
            x = x.detach().requires_grad_(True)

            u = f(x)

            # 𝛁𝑢 (the time component is not needed on its own)
            u_grad = diffops_simp.gradient(u, x)
            u_x, u_y, _ = torch.split(u_grad, [1, 1, 1], dim=1)

            # div𝛁𝑢 over the spatial axes [0, 1] = △𝑢
            u_lap = diffops_simp.divergence(u_grad, x, [0, 1])

            # 𝛁△𝑢
            u_lap_x, u_lap_y, u_lap_t = torch.split(
                diffops_simp.gradient(u_lap, x), [1, 1, 1], dim=1
            )

            # det𝑱(𝐮,△𝐮) = ∂𝑥𝐮∂𝑦△𝐮 − ∂𝑦𝐮∂𝑥△𝐮
            det_u_ulap = u_x * u_lap_y - u_y * u_lap_x

            # residual of ∂𝑡△𝐮 + det𝑱(𝐮,△𝐮) = 0, penalized against zero
            residual = u_lap_t + det_u_ulap
            pde_loss = F.mse_loss(residual, torch.zeros_like(residual))

            return self.alpha * pde_loss

In [None]:
# Same architecture as the data-only network, freshly initialized.
# NOTE: this instance is not used below — the QG learner is built from
# `copy.deepcopy(learn.model)` and its `model=net` argument is commented out.
dim_in = x_init.shape[1]
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 4
w0 = 1.0
w0_initial = 30.0
c = 6.0
final_activation = None  # nn.Sigmoid()

net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    c=c,
    final_activation=final_activation,
)

In [None]:
# data-fit term: mean squared error, with the reduction taken from the config
data_loss = nn.MSELoss(reduction=cfg.loss.reduction)

# QG PDE regularizer weighted by cfg.loss.alpha
# NOTE(review): the same alpha is also passed to INRModel below — confirm the
# weight is not applied twice.
reg_loss = RegQG(cfg.loss.alpha)

In [None]:
# Stage 2 learner: data loss plus the QG PDE regularizer.
learn_qg = INRModel(
    # warm start from the data-only network; the deep copy leaves `learn.model` untouched
    model=copy.deepcopy(learn.model),
    # model=net,
    loss_data=data_loss,
    reg_pde=reg_loss,
    learning_rate=cfg.optim_qg.learning_rate,
    warmup=cfg.optim_qg.warmup,
    warmup_start_lr=cfg.optim_qg.warmup_start_lr,
    eta_min=cfg.optim_qg.eta_min,
    # use the QG-stage epoch count so it matches the QG trainer's `max_epochs`
    # (was `cfg.optim.num_epochs`, the epoch count of the data-only stage)
    num_epochs=cfg.optim_qg.num_epochs,
    alpha=cfg.loss.alpha,
    qg=True,
)

In [None]:
# separate checkpoint folder for the QG stage; keep only the best "val_loss" checkpoint
model_cb = ModelCheckpoint(
    dirpath=str(Path(wandb_logger.experiment.dir).joinpath("checkpoints_qg")),
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# checkpointing plus a progress bar refreshed every 10 batches
callbacks = [model_cb, TQDMProgressBar(refresh_rate=10)]

In [None]:
# Stage 2 trainer: settings come from `cfg.optim_qg` / `cfg.trainer_qg`.
trainer = Trainer(
    min_epochs=1,
    max_epochs=cfg.optim_qg.num_epochs,
    accelerator=cfg.trainer_qg.accelerator,
    devices=cfg.trainer_qg.devices,
    enable_progress_bar=True,
    logger=wandb_logger,
    callbacks=callbacks,
    # NOTE(review): read from `cfg.trainer`, not `cfg.trainer_qg` — confirm this is intended
    accumulate_grad_batches=cfg.trainer.grad_batches,
)

In [None]:
# fit the QG-regularized model; the dataloaders are provided by the data module
trainer.fit(
    learn_qg,
    datamodule=dm,
)

In [None]:
# t0 = time.time()
# predict with the QG-regularized model and merge the per-batch outputs
predictions = trainer.predict(learn_qg, datamodule=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
# let the data module package the flat predictions into a dataset
ds_pred = dm.create_predictions_ds(predictions)

In [None]:
# ds_pred = dm.create_predictions_ds(predictions)
# ds_pred

The full QG equation is given by:

$$
\begin{aligned}
\partial_t q + \det \boldsymbol{J}(\psi, q) &= 0
\end{aligned}
$$

where:

* $q=\nabla^2 \psi$
* $\det \boldsymbol{J}(\psi, q)=\partial_x \psi\,\partial_y q - \partial_y \psi\,\partial_x q$.

In [None]:
# learn.model.eval()
# coords, truths, preds, grads, qs = [], [], [], [], []
# for ix, iy in tqdm(dm.predict_dataloader()):
#     with torch.set_grad_enabled(True):
#         # prediction
#         ix = torch.autograd.Variable(ix.clone(), requires_grad=True)
#         p_pred = learn.model(ix)

#         # p_pred = p_pred.clone()
#         # p_pred.require_grad_ = True

#         # gradient
#         p_grad = diffops_simp.gradient(p_pred, ix)
#         # p_grad = diffops.grad(p_pred, ix)
#         # q
#         q = diffops_simp.divergence(p_grad, ix)
#         # q = diffops.div(p_grad, ix)

#     # collect
#     truths.append(iy.detach().cpu())
#     coords.append(ix.detach().cpu())
#     preds.append(p_pred.detach().cpu())
#     grads.append(p_grad.detach().cpu())
#     qs.append(q.detach().cpu())

In [None]:
# per-column min / max of the training inputs
dm.ds_train[:][0].min(axis=0), dm.ds_train[:][0].max(axis=0)

In [None]:
# Evaluate the QG-regularized model on the prediction set and derive, via
# autograd, the gradient of the predicted field and its divergence over axes (0, 1).
learn_qg.model.eval()
coords, truths, preds, grads, qs = [], [], [], [], []
for ix, iy in tqdm(dm.predict_dataloader()):
    # autograd must be on: the derived fields are derivatives w.r.t. the inputs
    with torch.enable_grad():
        # fresh leaf tensor that tracks gradients
        # (replaces the deprecated `torch.autograd.Variable` wrapper)
        ix = ix.detach().clone().requires_grad_(True)

        # prediction
        p_pred = learn_qg.model(ix)

        # gradient of the prediction w.r.t. the input coordinates
        p_grad = diffops_simp.gradient(p_pred, ix)

        # q: divergence of the gradient restricted to input axes (0, 1)
        q = diffops_simp.divergence(p_grad, ix, (0, 1))

    # collect detached CPU copies so no autograd graph is kept alive
    truths.append(iy.detach().cpu())
    coords.append(ix.detach().cpu())
    preds.append(p_pred.detach().cpu())
    grads.append(p_grad.detach().cpu())
    qs.append(q.detach().cpu())

In [None]:
# Stitch the per-batch tensors into one NumPy array per quantity.
coords, preds, truths, grads, qs = (
    torch.cat(chunks).numpy() for chunks in (coords, preds, truths, grads, qs)
)

In [None]:
# fresh reference table from the data module (the columns added for the
# data-only model are not carried over)
df_data = dm.create_predictions_df()

# row-alignment checks are disabled for this run
# np.testing.assert_array_almost_equal(coords, df_data[["Nx", "Ny", "steps"]])
# np.testing.assert_array_almost_equal(truths, df_data[["p"]])

In [None]:
df_data["p_pred_qg"] = preds
# velocities from the gradient of the predicted field (gradient columns 0 and 1)
# NOTE(review): the usual stream-function convention is u = -dψ/dy, v = dψ/dx;
# here u uses the first-coordinate derivative and v the second — confirm
# against the axis convention of the simulation data.
df_data["u_pred_qg"] = -grads[:, 0]
df_data["v_pred_qg"] = grads[:, 1]
df_data["q_pred_qg"] = qs

# pivot the long table into a gridded dataset indexed by (Nx, Ny, steps);
# this rebinds `xr_data`, so the data-only `*_pred` variables are no longer in it
xr_data = df_data.set_index(["Nx", "Ny", "steps"]).to_xarray()

### Figure I: Stream Function

In [None]:
# xr_data.p_pred.thin(steps=1).plot.imshow(
#     col="steps",
#     robust=True,
#     col_wrap=3,
#     cmap="viridis",
# )

# Shared facet layout: one panel per "steps" value, wrapped after three columns.
_facet = dict(col="steps", robust=True, col_wrap=3)

# QG-regularized prediction
xr_data.p_pred_qg.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# reference
xr_data.p.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# absolute error
np.abs(xr_data.p_pred_qg - xr_data.p).thin(steps=1).plot.imshow(cmap="Reds", **_facet)

In [None]:
# ds_pred.pred.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure II: PV

In [None]:
# xr_data.q_pred.thin(steps=1).plot.imshow(
#     col="steps",
#     robust=True,
#     col_wrap=3,
#     cmap="viridis",
# )

# predicted PV (QG-regularized model)
xr_data.q_pred_qg.thin(steps=1).plot.imshow(
    col="steps",
    robust=True,
    col_wrap=3,
    cmap="viridis",
)

# reference PV
xr_data.q.thin(steps=1).plot.imshow(
    col="steps",
    robust=True,
    col_wrap=3,
    cmap="viridis",
)
# absolute error: uses the sequential "Reds" map like the other error panels
# in this notebook (was "viridis")
np.abs(xr_data.q_pred_qg - xr_data.q).thin(steps=1).plot.imshow(
    col="steps",
    robust=True,
    col_wrap=3,
    cmap="Reds",
)

In [None]:
# ds_pred.true.hvplot.image(x="Nx", y="Ny", width=500, height=400, cmap="viridis")

### Figure III: U Velocity

In [None]:
# xr_data.u_pred.thin(steps=1).plot.imshow(
#     col="steps",
#     robust=True,
#     col_wrap=3,
#     cmap="viridis",
# )

# Shared facet layout: one panel per "steps" value, wrapped after three columns.
_facet = dict(col="steps", robust=True, col_wrap=3)

# QG-regularized prediction
xr_data.u_pred_qg.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# reference
xr_data.u.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# absolute error
np.abs(xr_data.u_pred_qg - xr_data.u).thin(steps=1).plot.imshow(cmap="Reds", **_facet)

In [None]:
# (ds_pred.true - ds_pred.pred).hvplot.image(
#     x="Nx", y="Ny", width=500, height=400, cmap="viridis"
# )

### Figure IV: V Velocity

In [None]:
# xr_data.v_pred.thin(steps=1).plot.imshow(
#     col="steps",
#     robust=True,
#     col_wrap=3,
#     cmap="viridis",
# )

# Shared facet layout: one panel per "steps" value, wrapped after three columns.
_facet = dict(col="steps", robust=True, col_wrap=3)

# QG-regularized prediction
xr_data.v_pred_qg.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# reference
xr_data.v.thin(steps=1).plot.imshow(cmap="viridis", **_facet)

# absolute error
np.abs(xr_data.v_pred_qg - xr_data.v).thin(steps=1).plot.imshow(cmap="Reds", **_facet)

In [None]:
class RegQG(nn.Module):
    """QG Regularization Loss

    Penalizes the mean squared residual of

        ∂𝑡△𝑢 + det𝑱(𝑢,△𝑢) = 0,   det𝑱(𝑢,△𝑢) = ∂𝑥𝑢∂𝑦△𝑢 − ∂𝑦𝑢∂𝑥△𝑢

    where ``u = f(x)`` and every derivative is taken with autograd with
    respect to the input coordinates.

    NOTE: this cell re-defines the ``RegQG`` class that already appears in
    the "QG PDE Regularization" section of this notebook; both shadow the
    ``RegQG`` imported from ``experiments.qg.losses``.

    Parameters
    ----------
    alpha: float
        the weight for the loss to regularize
        default=1e-4
    trainable: bool
        if True, ``alpha`` is registered as an ``nn.Parameter``;
        otherwise it is stored as a (non-trainable) buffer
        default=False
    """

    def __init__(self, alpha: float = 1e-4, trainable=False):
        super().__init__()

        # single-element floating-point tensor; equivalent to the legacy
        # ``torch.Tensor([alpha])`` constructor
        alpha = torch.tensor([alpha], dtype=torch.get_default_dtype())

        if trainable:
            self.alpha = nn.Parameter(alpha, requires_grad=True)
        else:
            # buffer: follows the module across devices and is part of the state dict
            self.register_buffer("alpha", alpha)

    def forward(self, x, f):
        """The forward operation to compute the loss.

        Parameters
        ----------
        x: torch.Tensor, [Batch, Dims]
            the input tensor where the dims are [Dx,Dy,Dt]
        f: Callable[[torch.Tensor], torch.Tensor]
            the function that takes the coordinates and outputs a
            scalar value

        Returns
        -------
        loss: torch.Tensor, [1,]
            ``alpha`` times the batch-mean squared PDE residual
        """
        # gradients are required even if the caller runs under no_grad
        with torch.enable_grad():
            # detached tensor that tracks gradients w.r.t. the coordinates
            # (replaces the deprecated ``torch.autograd.Variable`` wrapper)
            x = x.detach().requires_grad_(True)

            u = f(x)

            # 𝛁𝑢 (the time component is not needed on its own)
            u_grad = diffops_simp.gradient(u, x)
            u_x, u_y, _ = torch.split(u_grad, [1, 1, 1], dim=1)

            # div𝛁𝑢 over the spatial axes [0, 1] = △𝑢
            u_lap = diffops_simp.divergence(u_grad, x, [0, 1])

            # 𝛁△𝑢
            u_lap_x, u_lap_y, u_lap_t = torch.split(
                diffops_simp.gradient(u_lap, x), [1, 1, 1], dim=1
            )

            # det𝑱(𝐮,△𝐮) = ∂𝑥𝐮∂𝑦△𝐮 − ∂𝑦𝐮∂𝑥△𝐮
            det_u_ulap = u_x * u_lap_y - u_y * u_lap_x

            # residual of ∂𝑡△𝐮 + det𝑱(𝐮,△𝐮) = 0, penalized against zero
            residual = u_lap_t + det_u_ulap
            pde_loss = F.mse_loss(residual, torch.zeros_like(residual))

            return self.alpha * pde_loss

$$
\begin{aligned}
\boldsymbol{\nabla}u &= \left[\partial_x u,\; \partial_y u,\; \partial_t u\right]^\top \\
\text{div}\,\boldsymbol{\nabla}u &= \partial_x \partial_x u + \partial_y \partial_y u = \boldsymbol{\triangle}u \\
\boldsymbol{\nabla} \boldsymbol{\triangle}u &= \left[\partial_x \boldsymbol{\triangle}u,\; \partial_y \boldsymbol{\triangle}u,\; \partial_t \boldsymbol{\triangle}u\right]^\top \\
\det\boldsymbol{J}(\mathbf{u},\boldsymbol{\triangle}\mathbf{u}) &= 
\partial_x \mathbf{u}\partial_y\boldsymbol{\triangle}\mathbf{u} - 
\partial_y \mathbf{u}\partial_x\boldsymbol{\triangle}\mathbf{u}\\
\partial_t\boldsymbol{\triangle}\mathbf{u} + \det\boldsymbol{J}(\mathbf{u},\boldsymbol{\triangle}\mathbf{u}) &= 0
\end{aligned}
$$