# Demo - Siren

In [None]:
import sys, os
from pyprojroot import here

# Search up the directory tree (pyprojroot's `here`) for the folder that
# contains the ".root" marker file -- that folder is the project root.
root = here(project_files=[".root"])

# Put the project root on the import path so local packages such as
# `ml4ssh` can be imported by the cells below.
sys.path.append(os.fspath(root))

In [None]:
import numpy as np
import torch
from torch import nn
from tqdm.notebook import tqdm as tqdm
import os, imageio

from ml4ssh._src.models.siren import Siren, SirenNet, Modulator, ModulatedSirenNet
from ml4ssh._src.models.activations import Sine
from torch.nn import ReLU

import matplotlib.pyplot as plt
import seaborn as sns

# Plot styling: reset seaborn to its defaults, then use the "talk" context
# with fonts scaled to 70%.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# IPython magics (not plain Python): enable the autoreload extension so that
# edits to imported modules such as ml4ssh are picked up without a restart.
%load_ext autoreload
%autoreload 2

## Data

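We treat the image as a function: a network with parameters $\boldsymbol{\theta}$ maps pixel coordinates $\mathbf{x}_\phi$ to pixel (RGB) values $\mathbf{u}$.

$$
\mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta})
$$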

In [None]:
# Download the chosen demo image and keep its RGB channels, divided by 255
# (assumes 8-bit pixels, giving values in [0, 1]).
# Only the "fox" image is center-cropped to a square; "earth" is used as-is.
# NOTE(review): later cells use img.shape[0] for both axes, i.e. they assume a
# square image -- confirm the earth image is square.
image = "fox"


if image == "earth":
    # EARTH
    image_url = "https://i0.wp.com/thepythoncodingbook.com/wp-content/uploads/2021/08/Earth.png?w=301&ssl=1"
    # [..., :3] keeps the first three channels (drops alpha if present)
    img = imageio.imread(image_url)[..., :3] / 255.0

elif image == "fox":
    # FOX
    image_url = "https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg"
    img = imageio.imread(image_url)[..., :3] / 255.0
    # center-crop a (2 * r) x (2 * r) = 512 x 512 square around the image center
    c = [img.shape[0] // 2, img.shape[1] // 2]
    r = 256
    img = img[c[0] - r : c[0] + r, c[1] - r : c[1] + r]

else:
    # Fail here rather than leaving `img` undefined for every later cell.
    raise ValueError(f"unknown image {image!r}; expected 'earth' or 'fox'")

In [None]:
# Display the target image on a fresh figure with a single axes.
plt.figure().add_subplot().imshow(img)
plt.show()

In [None]:
# Input pixel coordinates in the unit square: one (x, y) pair per pixel.
# Rows and columns get their own axis so the grid matches img even when it is
# not square; for a square image this equals meshgrid(coords, coords).
coords = np.linspace(0, 1, img.shape[0], endpoint=False)      # rows
col_coords = np.linspace(0, 1, img.shape[1], endpoint=False)  # columns
# meshgrid's default "xy" indexing yields (H, W) arrays, so the stacked grid
# is (H, W, 2), aligned with img's (H, W, channels) layout.
x_test = np.stack(np.meshgrid(col_coords, coords), -1)

# Test set: the full-resolution grid and image.
# Train set: every other pixel along each axis (2x-downsampled).
test_data = [x_test, img]
train_data = [x_test[::2, ::2], img[::2, ::2]]

In [None]:
from einops import rearrange

In [None]:
# Flatten each (x, y, c) grid into an (x*y, c) table: one row per pixel.
x_train, y_train = (rearrange(arr, "x y c -> (x y) c") for arr in train_data)
x_test, y_test = (rearrange(arr, "x y c -> (x y) c") for arr in test_data)

In [None]:
# Sanity check: (rows, coordinate dims) of the flattened train and test inputs.
(x_train.shape, x_test.shape)

## Siren Net


### Sine Activation Layer

In [None]:
x_train_tensor = torch.Tensor(x_train)
y_train_tensor = torch.Tensor(y_train)
x_train_tensor = torch.Tensor(x_test)
y_train_tensor = torch.Tensor(y_test)

In [None]:
x_train_tensor.min(), x_train_tensor.max()

In [None]:
out = Sine()(x_train_tensor)

out.shape

In [None]:
x_img = rearrange(out.numpy(), "(x y) c -> x y c", x=img.shape[0], y=img.shape[0])

In [None]:
# plt.imshow(x_img)
# plt.show()

### Siren Layer

$$
\mathbf{f}_\ell(\mathbf{x}) = \sin\left(\omega_0 \left(\mathbf{W}^{(\ell)}\mathbf{x} + \mathbf{b}^{(\ell)} \right)\right)
$$

In [None]:
# A single Siren layer mapping 2-D coordinates to 3 output channels.
# w0 presumably plays the role of omega_0 in the formula above; c is passed
# straight through to Siren (NOTE(review): likely an init-scale constant).
dim_in, dim_out = 2, 3
w0, c = 1.0, 6.0

layer = Siren(dim_in=dim_in, dim_out=dim_out, w0=w0, c=c)

In [None]:
# Forward a small batch (the first 100 coordinate rows) through the single layer.
out = layer(x_train_tensor[:100])

In [None]:
# x_img = rearrange(out.detach().numpy(), "(x y) c -> x y c", x=img.shape[0], y=img.shape[0])

In [None]:
# plt.imshow(x_img)
# plt.show()

### Siren Network

In [None]:
# Full Siren network: 2-D coordinates in, 3 channels out, num_layers layers of
# width dim_hidden. w0_initial is presumably the first layer's frequency and
# w0 that of the remaining layers.
dim_in, dim_hidden, dim_out = 2, 128, 3
num_layers = 5
w0, w0_initial = 1.0, 30.0
# NOTE(review): c is defined but not passed to SirenNet below -- confirm
# whether SirenNet should receive it.
c = 6.0

siren_net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
)

In [None]:
# Forward a small batch through the SirenNet built above. (`net` is only
# created later, in the ModulatedSirenNet cell, so calling it here raised NameError.)
out = siren_net(x_train_tensor[:100])

In [None]:
# x_img = rearrange(out.detach().numpy(), "(x y) c -> x y c", x=img.shape[0], y=img.shape[0])

In [None]:
# Display the SirenNet module structure (`net` is not defined yet at this point).
siren_net

### Modulated Siren

#### Modulator

In [None]:
# Modulator: maps a latent vector to a sequence of modulations (iterated and
# inspected in the cells below).
latent_dim_in = 512
latent_dim_hidden = 128
latent_num_layers = 5

# Latent code with a small Gaussian init (std 1e-2), sized from latent_dim_in
# so the parameter and the Modulator input size cannot drift apart.
latent = nn.Parameter(torch.zeros(latent_dim_in).normal_(0, 1e-2))

mod_layer = Modulator(
    dim_in=latent_dim_in, dim_hidden=latent_dim_hidden, num_layers=latent_num_layers
)

In [None]:
# Run the latent code through the Modulator; the result is iterated below
# (presumably one modulation tensor per layer).
mods = mod_layer(latent)

In [None]:
# Print the shape of every modulation produced by the Modulator, in order.
for shape in (m.shape for m in mods):
    print(shape)

#### Modulated Siren Layer

In [None]:
# Forward a batch through the SirenNet, modulated by the Modulator's output.
# `net` is not defined yet at this point (it is created in the
# ModulatedSirenNet cell further down), so the call raised NameError.
# NOTE(review): assumes SirenNet.forward accepts the modulations as a second
# argument; the Modulator above was built with dim_hidden/num_layers matching
# siren_net -- confirm.
out_modded = siren_net(x_train_tensor[:100], mods)

In [None]:
# Output shape of the modulated forward pass.
out_modded.shape

In [None]:
# Rebuild the latent code and Modulator (same settings as the Modulator
# section above).
latent_dim_in = 512
latent_dim_hidden = 128
latent_num_layers = 5

# Size the latent from latent_dim_in instead of a hard-coded 512.
latent = nn.Parameter(torch.zeros(latent_dim_in).normal_(0, 1e-2))

mod_layer = Modulator(
    dim_in=latent_dim_in, dim_hidden=latent_dim_hidden, num_layers=latent_num_layers
)

In [None]:
# from typing import Callable, Optional

# class ModulatedSirenNet(nn.Module):
#     def __init__(self,
#                  dim_in,
#                  dim_hidden,
#                  dim_out,
#                  num_layers: int=5,
#                  latent_dim: int=512,
#                  num_layers_latent: int=3,
#                  w0: float = 1.,
#                  w0_initial: float = 30.,
#                  c: float = 6.0,
#                  use_bias: bool = True,
#                  final_activation: Optional[nn.Module] = None,
#                  resnet: bool = False
#                 ):
#         super().__init__()
#         self.num_layers = num_layers
#         self.dim_hidden = dim_hidden

#         self.layers = nn.ModuleList([])
#         for ind in range(num_layers):
#             is_first = ind == 0
#             layer_w0 = w0_initial if is_first else w0
#             layer_dim_in = dim_in if is_first else dim_hidden
#             res_first = False

#             self.layers.append(Siren(
#                 dim_in = layer_dim_in,
#                 dim_out = dim_hidden,
#                 w0 = layer_w0,
#                 c = c,
#                 use_bias = use_bias,
#                 is_first = is_first,
#                 resnet = True if resnet and res_first else False
#             ))
#             if res_first:
#                 res_first = False

#             self.latent = nn.Parameter(torch.zeros(latent_dim).normal_(0, 1e-2))

#             self.modulator = Modulator(
#                 dim_in=latent_dim,
#                 dim_hidden=dim_hidden,
#                 num_layers=num_layers_latent,
#             )

#         final_activation = nn.Identity() if not exists(final_activation) else final_activation
#         self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

#     def forward(self, x):

#         mods = self.modulator(self.latent)

#         mods = cast_tuple(mods, self.num_layers)


#         for layer, mod in zip(self.layers, mods):
#             x = layer(x)

#             x *= rearrange(mod, 'd -> () d')

#         return self.last_layer(x)

In [None]:
# Modulated Siren network: a Siren body plus a latent-driven modulator.
# `operation` presumably selects how modulations are combined with each layer.
dim_in, dim_hidden, dim_out = 2, 128, 3
num_layers = 5
w0, w0_initial = 1.0, 30.0
# NOTE(review): c is defined but not passed to ModulatedSirenNet below -- confirm.
c = 6.0
latent_dim, num_layers_latent = 512, 3
operation = "add"

# Latent code supplied to the network at call time; small Gaussian init.
latent = nn.Parameter(torch.zeros(latent_dim).normal_(0, 1e-2))

net = ModulatedSirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    latent_dim=latent_dim,
    num_layers_latent=num_layers_latent,
    operation=operation,
)

In [None]:
# Forward a batch of coordinates through the modulated network, passing the latent code.
out = net(x_train_tensor[:100], latent)

In [None]:
# Sanity check: one prediction per input row, with as many channels as the
# RGB targets.
assert out.shape == y_train_tensor[:100].shape