In [None]:
import sys, os
from pyprojroot import here


# Locate the project root (the directory holding the `.local` marker) via
# pyprojroot, so the project-local `lib` package can be imported below.
root = here(project_files=[".local"])

# Make the project root importable. Guarded so re-running this cell does not
# keep appending duplicate entries to sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import lib

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx

from lib.siren import Siren, SirenNet
from lib.activations import Sine, ReLU

# Automatically reload edited project modules (e.g. `lib.*`) so code changes
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Siren Net


### Sine Activation Layer

In [None]:
# Sample input: a random vector of length n_dims drawn with its own sub-key.
n_dims = 10
key = jrandom.PRNGKey(123)
data_key, key = jrandom.split(key, num=2)
x = jrandom.normal(data_key, shape=(n_dims,))

# Build the sine activation (parameterised by w0) and apply it to the sample.
w0 = 1.0
activation = Sine(w0=w0)
out = activation(x)

print(activation)

### Siren Layer

In [None]:
# A single Siren layer (in_dim=n_dims, out_dim=2), initialised from a fresh sub-key.
layer_key, key = jrandom.split(key, num=2)
layer = Siren(
    in_dim=n_dims,
    out_dim=2,
    w0=1.0,
    c=6.0,
    key=layer_key,
)
out = layer(x)

print(layer)

### Siren Network

In [None]:
# Full SirenNet built from its own sub-key.
net_key, key = jrandom.split(key, num=2)

model = SirenNet(
    # layer sizes
    in_dim=n_dims,
    hidden_dim=32,
    n_hidden=5,
    out_dim=2,
    # frequency / initialisation hyper-parameters
    w0_initial=30,
    w0=1.0,
    c=6.0,
    # output head
    final_scale=1.0,
    final_activation=eqx.nn.Identity(),
    # randomness
    key=net_key,
)

out = model(x)

print(model)

In [None]:
# Display the SirenNet output for the sample input `x`.
out

### Modulated Siren

In [None]:
from typing import List


class Modulator(eqx.Module):
    """Latent-to-modulation MLP for a modulated SIREN.

    Every layer is ``Linear -> ReLU`` producing a ``hidden_dim`` vector. The
    first layer consumes the latent ``z`` directly; every later layer consumes
    ``concat(previous_output, z)``, so its input width is ``hidden_dim + in_dim``.

    Parameters
    ----------
    in_dim : int
        Size of the latent code ``z``.
    hidden_dim : int
        Size of each modulation vector.
    n_hidden : int
        Number of layers, i.e. number of modulation vectors returned.
    key : jax.random.PRNGKey
        Key used to initialise the layer weights.
    """

    layers: List[eqx.Module]

    def __init__(self, in_dim, hidden_dim, n_hidden, key):
        super().__init__()
        # One independent key per layer (no unused keys).
        keys = jrandom.split(key, n_hidden)

        layers = []
        for i, ikey in enumerate(keys):
            # Layers after the first receive [h_{i-1}, z] from __call__, so
            # their input width must include the latent size.
            layer_in = in_dim if i == 0 else hidden_dim + in_dim
            layers.append(
                eqx.nn.Sequential(
                    [eqx.nn.Linear(layer_in, hidden_dim, key=ikey), ReLU()]
                )
            )
        self.layers = layers

    def __call__(self, z):
        """
        Parameters
        ----------
        z : Array, shape (in_dim,)
            Latent code.

        Returns
        -------
        out : tuple of Array
            One ``(hidden_dim,)`` modulation vector per layer.
        """
        hidden = []
        layer_in = z
        for ilayer in self.layers:
            h = ilayer(layer_in)
            hidden.append(h)
            # Skip connection: the next layer also sees the raw latent.
            layer_in = jnp.concatenate([h, z], axis=-1)
        return tuple(hidden)

In [None]:
latent_dim = 512

# Independent keys for the modulator weights and the sample latent; a JAX key
# should never be reused across random calls.
mod_key, latent_key, key = jrandom.split(key, 3)

modulator = Modulator(latent_dim, model.hidden_dim, model.num_layers, mod_key)
latent = jrandom.normal(latent_key, (latent_dim,))
out = modulator(latent)

# Shapes of the returned modulation vectors (one per modulator layer).
tuple(h.shape for h in out)

In [None]:
from typing import Callable, List

Array = jnp.ndarray
from einops import rearrange


class ModulatedSirenNet(eqx.Module):
    """SIREN whose hidden activations are scaled by latent-driven modulations.

    A ReLU MLP (``layers``) maps a latent code to one ``hidden_dim``
    modulation vector per non-final SIREN layer. Each SIREN layer's output is
    multiplied elementwise by its modulation before being passed on. The
    modulator is skip-connected: its first layer consumes the latent, and
    later layers consume ``concat(previous_modulation, latent)``.

    Attributes
    ----------
    siren_net : eqx.Module
        Wrapped SIREN. Must expose ``layers`` (last entry is the output layer)
        and ``hidden_dim``.
    latent : jnp.ndarray, shape (latent_dim,)
        Latent code, sampled once at construction.
    layers : list of eqx.Module
        Modulator layers, one per non-final SIREN layer.
    """

    siren_net: eqx.Module
    latent: jnp.ndarray
    layers: List[eqx.Module]

    def __init__(
        self,
        siren_net: eqx.Module,
        latent_dim,
        key,
    ):
        super().__init__()
        hidden_dim = siren_net.hidden_dim
        # One modulator layer per SIREN layer that __call__ modulates
        # (every layer except the final output layer).
        n_mod = len(siren_net.layers) - 1

        # Independent keys for each modulator layer plus one for the latent.
        keys = jrandom.split(key, n_mod + 1)
        mod_keys, latent_key = keys[:-1], keys[-1]

        layers = []
        for i, ikey in enumerate(mod_keys):
            # Later layers receive [previous modulation, latent].
            layer_in = latent_dim if i == 0 else hidden_dim + latent_dim
            layers.append(
                eqx.nn.Sequential(
                    [eqx.nn.Linear(layer_in, hidden_dim, key=ikey), ReLU()]
                )
            )

        self.layers = layers
        self.siren_net = siren_net
        self.latent = jrandom.normal(key=latent_key, shape=(latent_dim,))

    def __call__(self, x: Array) -> Array:
        """Apply the modulated SIREN to ``x``.

        Parameters
        ----------
        x : Array
            Input accepted by the first SIREN layer.

        Returns
        -------
        Array
            Output of the SIREN's final layer.
        """
        mod_in = self.latent
        for siren_layer, mod_layer in zip(self.siren_net.layers[:-1], self.layers):
            x = siren_layer(x)
            mod = mod_layer(mod_in)
            # assumes every non-final SIREN layer outputs hidden_dim features;
            # `mod` broadcasts over any leading batch axes of `x`.
            x = x * mod
            mod_in = jnp.concatenate([mod, self.latent], axis=-1)
        return self.siren_net.layers[-1](x)

In [None]:
# Base SIREN (same configuration and key as the earlier SirenNet cell). It is
# named `siren` so that `model` unambiguously refers to the modulated wrapper.
siren = SirenNet(
    in_dim=n_dims,
    hidden_dim=32,
    n_hidden=5,
    out_dim=2,
    w0_initial=30,
    w0=1.0,
    c=6.0,
    key=net_key,
    final_scale=1.0,
    final_activation=eqx.nn.Identity(),
)

# Fresh sub-key for the modulator weights and latent rather than reusing `key`.
mod_net_key, key = jrandom.split(key, 2)
model = ModulatedSirenNet(siren_net=siren, latent_dim=latent_dim, key=mod_net_key)

output = model(x)
print(x.shape, output.shape)

In [None]:
# Display the modulated network's output for the sample input `x`.
output