# Demo - Fourier Feature Networks

In [None]:
import sys, os
from pyprojroot import here

# Search parent directories for the ".root" marker file to locate the
# project root directory.
root = here(project_files=[".root"])

# Make the project root importable (e.g. the `ml4ssh` package). Skip the
# append when the entry is already present so re-running this cell does not
# keep growing sys.path with duplicates.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import torch
from torch import nn
from tqdm.notebook import tqdm as tqdm
import os, imageio

from ml4ssh._src.models.mlp import MLP
from ml4ssh._src.models.activations import Swish
from ml4ssh._src.data.images import load_fox
from ml4ssh._src.features import get_image_coordinates
from ml4ssh._src.datamodules.images import ImageFox, ImageCameraman
from torch.nn import ReLU
import pytorch_lightning as pl
from ml4ssh._src.models.image import ImageModel
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.loggers import WandbLogger

import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2

## Data

The input is a coordinate vector, $\mathbf{x}_\phi$, giving the location of a pixel in the image.

$$
\mathbf{x}_\phi \in \mathbb{R}^{D_\phi}
$$

where $D_\phi = 2$, i.e. $\mathbf{x}_\phi = [x, y]$ holds the two image coordinates. We want to learn a function, $\boldsymbol{f}$, that takes a coordinate vector as input and outputs the pixel value at that location: a scalar for a grayscale image, or a vector for a colour image.

$$
\mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta})
$$

In [None]:
# Load the fox image so it can be displayed below.
img = load_fox()

In [None]:
# Display the original fox image.
fig = plt.figure()
plt.imshow(img)
plt.show()

### Data Module

In [None]:
# Data module over the fox image's (coordinate, pixel value) samples.
# Shuffling is turned off here, presumably so samples stay in coordinate order.
# Alternative setup: ImageFox(batch_size=1024).setup()
dm = ImageFox(shuffle=False, batch_size=4096).setup()

In [None]:
# Number of samples in the training dataset.
n_train = len(dm.ds_train)
n_train

## Multi-layer Perceptron (MLP)


### Swish Activation Layer

In [None]:
# Take the first 32 training samples (inputs, targets) to sanity-check shapes.
x_init, y_init = dm.ds_train[:32]
x_init.shape, y_init.shape

In [None]:
# Apply a Swish activation to the sample inputs and inspect the output shape.
swish = Swish()
out = swish(x_init)

out.shape

In [None]:
# x_img = rearrange(out.numpy(), "(x y) c -> x y c", x=img.shape[0], y=img.shape[0])

In [None]:
# plt.imshow(x_img)
# plt.show()

### MLP Layer

$$
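\mathbf{f}_\ell(\mathbf{x}) = \sigma\left(\mathbf{W}^{(\ell)}\mathbf{x} + \mathbf{b}^{(\ell)} \right)
$$

where $\mathbf{W}^{(\ell)}$ and $\mathbf{b}^{(\ell)}$ are the weight matrix and bias vector of layer $\ell$, and $\sigma$ is the *swish* activation function.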

$$
\sigma(\mathbf{x}) = \mathbf{x} \odot \text{Sigmoid}(\mathbf{x})
$$

In [None]:
# Hyperparameters for a small coordinate MLP, derived partly from the sample batch.
dim_in = x_init.shape[1]   # input feature width
dim_out = y_init.shape[1]  # output width per sample
dim_hidden = 256
num_layers = 4
activation = Swish()  # alternative: nn.ReLU()
final_activation = nn.Sigmoid()

mlp_net = MLP(
    dim_in=dim_in, dim_hidden=dim_hidden, dim_out=dim_out,
    num_layers=num_layers, activation=activation,
    final_activation=final_activation,
)

In [None]:
# Forward pass of the untrained MLP on the 32-sample batch as a smoke test.
out = mlp_net(x_init)

In [None]:
# x_img = rearrange(out.detach().numpy(), "(x y) c -> x y c", x=img.shape[0], y=img.shape[0])

In [None]:
# Display the module; its repr lists the layers.
mlp_net

## Experiment

In [None]:
import pytorch_lightning as pl

# Seed the Python, NumPy and torch RNGs so the experiment is repeatable.
seed = 123
pl.seed_everything(seed)

#### Dataset



In [None]:
# Plain coordinate MLP to be trained on the fox image.
learning_rate = 1e-4
mlp_net = MLP(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    activation=activation,
    # Use the same Sigmoid output activation as the sanity-check MLP and the
    # Fourier-feature networks so the MLP-vs-FFN comparison differs only in
    # the input encoding. (Assumes pixel targets lie in [0, 1] — TODO confirm.)
    final_activation=final_activation,
)
# Shuffle training samples (the earlier data module used shuffle=False).
dm = ImageFox(batch_size=4096, shuffle=True).setup()

In [None]:
# Wrap the MLP in the Lightning module that the trainer fits, tests and
# predicts with below.
learn = ImageModel(mlp_net, learning_rate=learning_rate)

In [None]:
# Progress bar refreshed every 50 batches to keep notebook output light.
progress_bar = TQDMProgressBar(refresh_rate=50)
callbacks = [progress_bar]

In [None]:
# Use Apple's MPS backend when it is available and fall back to the CPU
# otherwise; a hard-coded "mps" makes Trainer construction fail elsewhere.
accelerator = "mps" if torch.backends.mps.is_available() else "cpu"

trainer = Trainer(
    min_epochs=1,
    max_epochs=100,
    accelerator=accelerator,
    enable_progress_bar=True,
    # `False` is the documented way to disable experiment logging.
    logger=False,
    callbacks=callbacks,
)

In [None]:
# Train the MLP; the data module supplies the dataloaders.
trainer.fit(model=learn, datamodule=dm)

In [None]:
# Evaluate the trained MLP on the data module's test dataloader.
test_loader = dm.test_dataloader()
trainer.test(learn, dataloaders=test_loader)

In [None]:
# Run prediction with the data module and concatenate the per-batch outputs
# into a single tensor.
# NOTE(review): the image reconstruction below assumes the predict dataloader
# is not shuffled (this dm was built with shuffle=True) — confirm in ImageFox.
batch_predictions = trainer.predict(learn, datamodule=dm, return_predictions=True)
predictions = torch.cat(batch_predictions)

In [None]:
# Convert the flat per-coordinate predictions back into an image array with
# the data module's helper (the output layout is defined by ImageFox).
img_pred = dm.coordinates_2_image(predictions)

In [None]:
# Display the MLP reconstruction.
# NOTE(review): matplotlib ignores `cmap` for RGB(A) arrays; it only applies
# if img_pred is single-channel.
fig = plt.figure()
plt.imshow(img_pred, cmap="gray")
plt.show()

### Encoders

In [None]:
from ml4ssh._src.models.encoders import (
    IdentityPositionalEncoding,
    NeRFPositionalEncoding,
    GaussianFourierFeatureTransform,
)
from ml4ssh._src.models.ffn import FourierFeatureMLP

In [None]:
# Input encoding for the Fourier-feature network. Only one encoder is used,
# so build just that one; switch by uncommenting an alternative.
# encoder = IdentityPositionalEncoding(in_dim=dim_in)
# encoder = NeRFPositionalEncoding(in_dim=dim_in, n=50)
encoder = GaussianFourierFeatureTransform(in_dim=dim_in, mapping_size=256, sigma=1.0)

In [None]:
# Encode the sample inputs and compare shapes before and after encoding.
out = encoder(x_init)
shapes = (x_init.shape, out.shape)
shapes

In [None]:
# Hyperparameters for the Fourier-feature network (same values as the MLP).
dim_in = x_init.shape[1]
dim_out = y_init.shape[1]
dim_hidden = 256
num_layers = 4
activation = Swish()
final_activation = nn.Sigmoid()

# No dim_in is passed; presumably the encoder determines the input width.
ffn_net = FourierFeatureMLP(
    encoder=encoder, dim_hidden=dim_hidden, dim_out=dim_out,
    num_layers=num_layers, activation=activation,
    final_activation=final_activation,
)

In [None]:
# Smoke-test the Fourier-feature network on the sample batch.
out = ffn_net(x_init)
(x_init.shape, out.shape)

In [None]:
# Fresh Fourier-feature network for training, reusing the encoder and
# hyperparameters defined above.
learning_rate = 1e-4
ffn_kwargs = dict(
    encoder=encoder,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    activation=activation,
    final_activation=final_activation,
)
ffn_net = FourierFeatureMLP(**ffn_kwargs)
# NOTE(review): the MLP experiment passes shuffle=True explicitly; this relies
# on ImageFox's default — confirm the two runs shuffle the same way.
dm = ImageFox(batch_size=4096).setup()

In [None]:
# Wrap the Fourier-feature network in the same Lightning module used for the MLP.
learn = ImageModel(ffn_net, learning_rate=learning_rate)

In [None]:
# Use Apple's MPS backend when it is available and fall back to the CPU
# otherwise; a hard-coded "mps" makes Trainer construction fail elsewhere.
accelerator = "mps" if torch.backends.mps.is_available() else "cpu"

trainer = Trainer(
    min_epochs=1,
    max_epochs=100,
    accelerator=accelerator,
    enable_progress_bar=True,
    # `False` is the documented way to disable experiment logging.
    logger=False,
    # Build a fresh progress-bar callback rather than reusing the instance
    # that was already attached to the MLP trainer.
    callbacks=[TQDMProgressBar(refresh_rate=50)],
)

In [None]:
# Train the Fourier-feature network; the data module supplies the dataloaders.
trainer.fit(model=learn, datamodule=dm)

In [None]:
# Evaluate the trained Fourier-feature network on the test dataloader.
test_loader = dm.test_dataloader()
trainer.test(learn, dataloaders=test_loader)

In [None]:
# Run prediction with the data module and concatenate the per-batch outputs
# into a single tensor.
batch_predictions = trainer.predict(learn, datamodule=dm, return_predictions=True)
predictions = torch.cat(batch_predictions)

In [None]:
# Convert the flat per-coordinate predictions back into an image array with
# the data module's helper (the output layout is defined by ImageFox).
img_pred = dm.coordinates_2_image(predictions)

In [None]:
# Display the Fourier-feature network's reconstruction.
fig = plt.figure()
plt.imshow(img_pred)
plt.show()