In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to local source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from colora.networks import DNN
from colora.utils import init_net, split, merge
from jax.flatten_util import ravel_pytree
import jax

import flax.linen as nn

# ---- u network (wrapped later as u_hat) ----------------------------------
# Hyperparameters handed to DNN / init_net below.
x_dim = 1
output_field_dim = 2
width = 4
layers = ['C', 'C', 'C']
# presumably one activation per entry in `layers` -- confirm against colora.networks.DNN
activations = [nn.selu, nn.sigmoid, nn.sigmoid]

net = DNN(
    width=width,
    layers=layers,
    out_dim=output_field_dim,
    activation=activations,
)
params_init, u_apply = init_net(net, x_dim)

# ---- parameter partition --------------------------------------------------
# Divide the parameters of u into the online part (phi) and the offline part
# (theta), using 'alpha' as the filter handed to split.
lora_filter = 'alpha'
phi, theta = split(params_init, lora_filter)

# Flatten phi to a 1-D vector and keep the inverse map so that a flat vector
# can later be turned back into the phi pytree.
flat_phi, phi_unravel = ravel_pytree(phi)
n_phi = flat_phi.shape[0]

# ---- hyper network h -------------------------------------------------------
# Input is (time, mu); output must have exactly n_phi entries so it can be
# unravelled into phi.
mu_t_dim = 2
out_dim = n_phi
layers = ['D', 'D', 'D']
width = 4

net = DNN(width=width, layers=layers, out_dim=out_dim)
psi_init, h_net = init_net(net, mu_t_dim)


In [None]:
def build_u_hat(u_apply, phi_unravel):
    """Build a wrapper around the u network that takes theta and phi separately.

    Args:
        u_apply: callable invoked as ``u_apply(params, *args)`` with the
            combined parameters.
        phi_unravel: callable that converts a flat phi vector (e.g. the output
            of the hyper network h) back into the structured phi parameters.

    Returns:
        A function ``u_hat(theta, phi, *args)`` that unravels ``phi``, merges
        it with ``theta`` and forwards the combined parameters, followed by
        ``*args``, to ``u_apply``.
    """

    def u_hat(theta, phi, *args):
        """Evaluate u with offline params ``theta`` and a flat online vector ``phi``."""
        # Restore the structure of phi from its flat vector form.
        phi_tree = phi_unravel(phi)
        # Recombine both parameter groups into the full parameter set of u.
        combined_params = merge(theta, phi_tree)
        return u_apply(combined_params, *args)

    return u_hat

# Bind the wrapper to u's apply function and phi's unravel function (both
# created in the network-setup cell), giving u_hat(theta, phi, *args).
u_hat = build_u_hat(u_apply, phi_unravel)

In [None]:
import jax.numpy as jnp
# Dummy hyper-network input: a vector of ones with mu_t_dim entries.
mu_t = jnp.ones((mu_t_dim,))
# Evaluate h at its initial parameters; the result is used as the flat phi
# vector passed to u_hat.
phi = h_net(psi_init, mu_t)

In [None]:
# Dummy input for u: a vector of ones with x_dim entries.
x = jnp.ones((x_dim,))
# Evaluate u_hat with the offline params and the phi produced by h; the bare
# expression is left last so the notebook displays the result.
u_hat(theta, phi, x)