# Load model
> params.npy

In [None]:
import jax 
import jax.numpy as jnp
from flax import linen as nn 

from typing import Sequence

In [None]:
class SPINN4d(nn.Module):
    """Separable network over four coordinate axes ``(t, x, y, z)``.

    Every coordinate gets its own stack of tanh-activated ``Dense`` layers
    that ends in a linear layer of width ``r * out_dim``. For each output
    component, the matching ``r`` feature rows of the four branches are
    combined as an outer product and summed over the feature axis, which
    produces a 4-D grid of predictions.
    """

    # Hidden layer widths. The last entry is not used as a hidden width
    # (the loop runs over ``features[:-1]``); the final layer of each branch
    # has width ``r * out_dim`` instead.
    features: Sequence[int]
    # Number of feature rows combined per output component.
    r: int
    # Number of output components.
    out_dim: int

    @nn.compact
    def __call__(self, t, x, y, z):
        """Evaluate the network on the grid spanned by the four inputs.

        Args:
            t: 2-D array of coordinates along the first axis.
            x: 2-D array of coordinates along the second axis.
            y: 2-D array of coordinates along the third axis.
            z: 2-D array of coordinates along the fourth axis.

        Returns:
            A single ``(nt, nx, ny, nz)`` array when ``out_dim`` is 1,
            otherwise a list holding ``out_dim`` such arrays.
        """
        kernel_init = nn.initializers.glorot_normal()
        hidden_widths = self.features[:-1]
        rank = self.r

        # Branches are built in the order t, x, y, z; compact modules name
        # their submodules by creation order, so this order must not change.
        branches = []
        for coord in (t, x, y, z):
            hidden = coord
            for width in hidden_widths:
                hidden = nn.Dense(width, kernel_init=kernel_init)(hidden)
                hidden = nn.activation.tanh(hidden)
            hidden = nn.Dense(rank * self.out_dim, kernel_init=kernel_init)(hidden)
            # (n, r * out_dim) -> (r * out_dim, n): features on the leading axis.
            branches.append(jnp.transpose(hidden, (1, 0)))

        pred = []
        for component in range(self.out_dim):
            # Rows belonging to this output component in every branch.
            rows = slice(rank * component, rank * (component + 1))
            feat_t, feat_x, feat_y, feat_z = (branch[rows] for branch in branches)

            # Grow the outer product one axis at a time; the last step also
            # sums over the shared feature axis ``f``.
            grid_tx = jnp.einsum('ft, fx->ftx', feat_t, feat_x)
            grid_txy = jnp.einsum('ftx, fy->ftxy', grid_tx, feat_y)
            pred.append(jnp.einsum('ftxy, fz->txyz', grid_txy, feat_z))

        # A scalar field is returned bare; multi-component output stays a list.
        return pred[0] if len(pred) == 1 else pred

In [None]:
# Deterministic PRNG setup. `subkey` is split off here but is not consumed
# anywhere in this cell.
seed = 111
key, subkey = jax.random.split(jax.random.PRNGKey(seed), 2)

# Hyperparameters.
nc = 32          # rows in each dummy input used to initialise the model
features = 64    # width of every entry in `feat_sizes`
n_layers = 5     # number of entries in `feat_sizes`
feat_sizes = (features,) * n_layers
r = 128
out_dim = 3

model = SPINN4d(features=feat_sizes, r=r, out_dim=out_dim)

# Initialise with four identical (nc, 1) dummy inputs, one per coordinate.
# These parameters are placeholders: a later cell overwrites `params` with
# the values restored from the checkpoint.
params = model.init(key, *(jnp.ones((nc, 1)) for _ in range(4)))

# Compiled forward pass, called as apply_fn(params, t, x, y, z).
apply_fn = jax.jit(model.apply)

In [None]:
import orbax

In [None]:
# Restore the trained parameters and the training inputs from the local
# 'ckpt' directory.
import os

# Import the submodule explicitly: a bare `import orbax` is not guaranteed to
# make `orbax.checkpoint` available as an attribute.
import orbax.checkpoint

checkpoint = orbax.checkpoint.PyTreeCheckpointer()
# Resolve the directory against the current working directory first; some
# Orbax releases refuse relative checkpoint paths.
ckpt = checkpoint.restore(os.path.abspath('ckpt'))
params = ckpt['params']          # replaces the randomly initialised parameters
train_data = ckpt['train_data']  # indexed 0..3 by the inference cell

In [None]:
# Run the jitted forward pass on the four restored coordinate arrays. The
# model returns three output components (out_dim = 3); only the first is kept.
u_test, _, _ = apply_fn(params, *(train_data[axis] for axis in range(4)))
# Show the prediction at the first grid point as a quick sanity check.
u_test[0, 0, 0, 0]

Array(-0.82522225, dtype=float32)