# Demo - Siren

In [None]:
import sys, os
from pyprojroot import here

# Locate the project root by walking up the directory tree until a directory
# containing the ".root" marker file is found.
root = here(project_files=[".root"])

# Put the project root on the import path so the project-local `inr4ssh`
# package imported in the next cell can be found.
sys.path.append(os.fspath(root))

In [None]:
import numpy as np
import torch
from torch import nn
from tqdm.notebook import tqdm as tqdm
import os, imageio

from inr4ssh._src.models.siren import Siren, SirenNet, Modulator, ModulatedSirenNet
from inr4ssh._src.models.activations import Sine
from inr4ssh._src.data.images import load_fox
from inr4ssh._src.features.coords import get_image_coordinates
from inr4ssh._src.datamodules.images import ImageFox, ImageCameraman
from torch.nn import ReLU
import pytorch_lightning as pl
from inr4ssh._src.models.image import ImageModel
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.loggers import WandbLogger

import matplotlib.pyplot as plt
import seaborn as sns

# Reset seaborn styling, then use the "talk" context with fonts scaled down to 70%.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# IPython: automatically reload edited modules before executing each cell.
%load_ext autoreload
%autoreload 2

## Data

The input data is a coordinate vector, $\mathbf{x}_\phi$, of the image coordinates.

$$
\mathbf{x}_\phi \in \mathbb{R}^{D_\phi}
$$

where the input dimension is $D_\phi = 2$, i.e. $\mathbf{x}_\phi = [x, y]^\top$. We want to learn a function, $\boldsymbol{f}$, that takes a coordinate vector as input and outputs the pixel value at that location (a scalar for a grayscale image, a vector for an RGB image).

$$
\mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta})
$$

In [None]:
# Load the fox image; in this notebook it is only used for display — the
# training data comes from the ImageFox datamodule.
img = load_fox()

In [None]:
# Quick look at the raw image before fitting anything.
fig, ax = plt.subplots()
ax.imshow(img)
plt.show()

### Data Module

In [None]:
# Coordinate -> pixel-value datamodule for the fox image. `.setup()` is
# chained, so this relies on `setup()` returning the datamodule itself.
dm = ImageFox(batch_size=4096).setup()
# Alternative image:
# dm = ImageCameraman(batch_size=4096).setup()

In [None]:
# Number of training samples (shown as the cell output).
len(dm.ds_train)

In [None]:
# Materialise every split as full tensors by slicing the whole dataset.
(X_train, y_train), (X_valid, y_valid), (X_test, y_test) = (
    split[:] for split in (dm.ds_train, dm.ds_valid, dm.ds_test)
)

# The skorch wrapper below carves out its own validation set (`train_split`),
# so fold the datamodule's validation split back into the training data.
X_train = torch.cat((X_train, X_valid))
y_train = torch.cat((y_train, y_valid))

## Siren Net


### Sine Activation Layer

In [None]:
# Small batch (first 32 training samples) used to smoke-test the layers below.
init = dm.ds_train[:32]
x_init, y_init = init
# Displayed: input (coordinate) and target shapes.
x_init.shape, y_init.shape

In [None]:
# Apply the sine activation to the raw coordinates. Presumably elementwise, so
# the displayed shape should match x_init — confirm against Sine's definition.
out = Sine()(x_init)

out.shape

### Siren Layer

$$
\mathbf{f}_\ell(\mathbf{x}) = \sin\left(\omega_0 \left(\mathbf{w}^{(\ell)}\mathbf{x} + \mathbf{b}^{(\ell)} \right)\right)
$$

In [None]:
# A single SIREN layer mapping coordinates straight to pixel values,
# f(x) = sin(w0 * (W x + b)) per the formula above.
dim_in, dim_out = x_init.shape[1], y_init.shape[1]
w0 = 1.0  # presumably the omega_0 frequency factor from the formula
c = 6.0  # constant forwarded to Siren (presumably the weight-init scale)

layer = Siren(dim_in=dim_in, dim_out=dim_out, w0=w0, c=c)

In [None]:
# Forward pass of the single Siren layer on the smoke-test batch.
out = layer(x_init)

### Siren Network

In [None]:
# Hyperparameters for the full SIREN network; in/out sizes come from the data.
dim_in, dim_out = x_init.shape[1], y_init.shape[1]
dim_hidden, num_layers = 256, 4
w0, w0_initial = 1.0, 30.0  # w0_initial presumably applies to the first layer only
c = 6.0  # NOTE(review): defined but not passed to SirenNet below — confirm intent
final_activation = nn.Sigmoid()  # squashes outputs into (0, 1)

siren_net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    final_activation=final_activation,
)

In [None]:
# Forward pass of the full network on the smoke-test batch.
out = siren_net(x_init)

In [None]:
# Display the module structure.
siren_net

## Training

In [None]:
# Pick the fastest backend that is actually usable on this machine.
# `torch.has_mps` / `torch.has_cuda` are deprecated and only report whether
# PyTorch was *built* with that backend, not whether a device is present,
# which can select "cuda" on a CUDA build running on a GPU-less host.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Optimization

In [None]:
from skorch.callbacks import EarlyStopping, LRScheduler

In [None]:
# Cut the learning rate by 10x whenever the validation loss fails to improve
# for 5 epochs.
lr_scheduler = LRScheduler(
    policy="ReduceLROnPlateau",
    monitor="valid_loss",
    mode="min",
    factor=0.1,
    patience=5,
)

# skorch's EarlyStopping (this shadows the Lightning EarlyStopping imported at
# the top): stop once the validation loss has not improved for 10 epochs.
estop_callback = EarlyStopping(monitor="valid_loss", patience=10)

# skorch expects callbacks as (name, callback) pairs.
callbacks = [("earlystopping", estop_callback), ("lrscheduler", lr_scheduler)]

In [None]:
from skorch.dataset import ValidSplit

# Hold out 50% of the (merged) training data as skorch's internal validation
# split — a float `cv` is the *validation* fraction, not the training fraction.
train_split = ValidSplit(cv=0.5, stratified=False)

### Model Wrapper

In [None]:
from skorch import NeuralNetRegressor

skorch_net = NeuralNetRegressor(
    # model and optimisation
    module=siren_net,
    optimizer=torch.optim.Adam,
    lr=0.01,
    max_epochs=200,
    batch_size=4096,
    # hardware and data loading
    device=device,
    iterator_train__num_workers=2,
    # validation split and callbacks defined above
    train_split=train_split,
    callbacks=callbacks,
)

In [None]:
# Train; skorch splits off a validation set via `train_split` and records
# per-epoch losses in `skorch_net.history` (plotted below).
skorch_net.fit(X_train, y_train)

In [None]:
fig, ax = plt.subplots()

# Learning curves recorded by skorch, one point per epoch.
loss_curves = (("train_loss", "Train Loss"), ("valid_loss", "Validation Loss"))
for metric, curve_label in loss_curves:
    ax.plot(skorch_net.history[:, metric], label=curve_label)

ax.set(yscale="log", xlabel="Epochs", ylabel="Mean Squared Error")
ax.legend()
plt.show()

In [None]:
# Predict pixel values at the test coordinates with the skorch-trained network.
y_pred = skorch_net.predict(X_test)

In [None]:
# Reassemble the per-coordinate predictions into an image via the datamodule.
img_pred = dm.coordinates_2_image(y_pred)

In [None]:
# Ground truth (left) next to the network's reconstruction (right).
# NOTE(review): `cmap` only has an effect if the images are single-channel.
fig, axs = plt.subplots(ncols=2)
panels = (
    (img, "True Image"),
    (img_pred, "Interpolated Image"),
)
for panel_ax, (picture, caption) in zip(axs, panels):
    panel_ax.imshow(
        picture,
        cmap="gray",
    )
    panel_ax.set_title(caption)
plt.tight_layout()
plt.show()

In [None]:
# Show the ground-truth image on its own for reference.
# NOTE(review): `cmap` only has an effect if `img` is single-channel.
plt.figure()
plt.imshow(img, cmap="gray")
plt.show()

In [None]:
# NOTE(review): redundant — recomputes the same `img_pred` as the earlier cell
# (`y_pred` has not changed since); safe to delete.
img_pred = dm.coordinates_2_image(y_pred)

## Experiment

In [None]:
import pytorch_lightning as pl

# Seed the global RNGs so the network initialisation below is reproducible.
pl.seed_everything(123)

#### Model & Dataset



In [None]:
# Fresh network (weights drawn after the seed above) and a fresh datamodule
# for the Lightning run; hyperparameters are reused from the Siren section.
# Statement order is kept so RNG consumption after seeding is unchanged.
learning_rate = 1e-4

siren_net = SirenNet(
    dim_in=dim_in,
    dim_out=dim_out,
    dim_hidden=dim_hidden,
    num_layers=num_layers,
    w0_initial=w0_initial,
    w0=w0,
    final_activation=final_activation,  # same Sigmoid instance as before (parameter-free)
)

dm = ImageFox(batch_size=4096).setup()

In [None]:
# Wrap the network in the project's LightningModule (it is passed to
# `trainer.fit` below) together with its learning rate.
learn = ImageModel(siren_net, learning_rate=learning_rate)

In [None]:
# Lightning callbacks (replaces the skorch `callbacks` list defined earlier):
# refresh the progress bar only every 100 batches.
callbacks = [TQDMProgressBar(refresh_rate=100)]

In [None]:
# Lightning trainer for the Siren experiment.
# `accelerator="auto"` uses MPS/CUDA when available and otherwise falls back
# to CPU; a hard-coded "mps" fails on any machine without Apple-silicon MPS.
trainer = Trainer(
    min_epochs=1,
    max_epochs=100,
    accelerator="auto",
    enable_progress_bar=True,
    logger=False,  # explicitly disable experiment logging for this demo
    callbacks=callbacks,
)

In [None]:
# Fit on the training split and validate on the held-out validation split, so
# validation metrics measure generalisation rather than training fit.
# Assumes ImageFox implements the standard `val_dataloader()` hook over its
# `ds_valid` split — confirm.
trainer.fit(
    learn,
    train_dataloaders=dm.train_dataloader(),
    val_dataloaders=dm.val_dataloader(),
)

In [None]:
# Evaluate the fitted model on the datamodule's test split.
trainer.test(learn, dataloaders=dm.test_dataloader())

In [None]:
# Predict with the datamodule's predict dataloader; Lightning returns a list of
# per-batch tensors, which are concatenated into one tensor.
predictions = torch.cat(trainer.predict(learn, datamodule=dm, return_predictions=True))

In [None]:
# Shape of the concatenated predictions.
predictions.shape

In [None]:
# Range of predicted values; expected within (0, 1) given the Sigmoid final
# activation, assuming ImageModel returns the raw network output — verify.
predictions.min(), predictions.max()

In [None]:
from einops import rearrange

In [None]:
# Reassemble the Lightning predictions into an image via the datamodule.
img_pred = dm.coordinates_2_image(predictions)

In [None]:
# Display the image reconstructed by the Lightning-trained network.
fig, ax = plt.subplots()
ax.imshow(img_pred, cmap="gray")
plt.show()