In [6]:
%matplotlib inline
import os
import time
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from tqdm import trange
from jax import jvp, value_and_grad
from flax import linen as nn
from typing import Sequence
from functools import partial

In [43]:
# --- Point counts ---------------------------------------------------------
# NOTE(review): NC / NI / NB / NC_TEST are not used in the visible cells; they
# look like collocation / initial / boundary / test sample counts — confirm
# against the sampling code further down.
NC = 64
NI = 64
NB = 64
NC_TEST = 100

# --- Network shape --------------------------------------------------------
# The model is built with N_LAYERS Dense layers per coordinate network,
# each FEATURES units wide.
N_LAYERS = 4
FEATURES = 3

# --- Optimisation / logging -----------------------------------------------
# NOTE(review): presumably the learning rate, number of training iterations
# and logging interval — confirm against the training loop.
LR = 1e-3
EPOCHS = 50_000
LOG_ITER = 25_000

# --- Reproducibility ------------------------------------------------------
# Seed for the root JAX PRNG key.
SEED = 444

In [72]:
# Build the root PRNG key from SEED and split it once: `key` is kept for any
# later splits, `subkey` is the one consumed by parameter initialisation.
key, subkey = jax.random.split(jax.random.PRNGKey(SEED), num=2)

In [85]:
class SPINN(nn.Module):
    """Separable network over three coordinate axes.

    Each of the three inputs is fed through its own stack of Dense layers
    (tanh after every layer except the last). The three per-axis feature
    matrices are then multiplied together and summed over the shared feature
    axis, yielding one scalar per (x, y, z) grid combination.
    """

    # Widths of the Dense layers in every per-axis network; the last entry is
    # the size of the feature axis that gets contracted when merging.
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the network on the grid spanned by ``x``, ``y`` and ``z``.

        Args:
            x, y, z: 2-D arrays of shape ``(n_points, n_inputs)``; the 2-D
                requirement comes from the ``(1, 0)`` transpose below.

        Returns:
            Array of shape ``(len(x), len(y), len(z))``.
        """
        kernel_init = nn.initializers.glorot_normal()
        hidden_widths = self.features[:-1]
        out_width = self.features[-1]

        # One feature matrix per axis, stored feature-major: (out_width, n_points).
        factors = []
        for coord in (x, y, z):
            h = coord
            # Every nn.Dense(...) call creates a fresh layer, so the three
            # axes do not share weights.
            for width in hidden_widths:
                h = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(h))
            h = nn.Dense(out_width, kernel_init=kernel_init)(h)
            factors.append(jnp.transpose(h, (1, 0)))

        fx, fy, fz = factors
        # Per-feature outer product of the x and y factors ...
        plane = jnp.einsum('fx, fy->fxy', fx, fy)
        # ... then bring in z and sum out the feature axis.
        return jnp.einsum('fxy, fz->xyz', plane, fz)

In [88]:
# Toy walk-through of the two einsum contractions used to merge per-axis
# features: every operand is (features=3, points=2).
x = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
y = jnp.array([[1., 3.], [9., 10.], [11., 12.]])
z = jnp.array([[1., 2.], [3., 4.], [5., 6.]])

print(x.shape, x, sep="\n")
print(y.shape, y, sep="\n")

# Step 1: per-feature outer product of x and y -> (3, 2, 2).
xy = jnp.einsum('fx, fy->fxy', x, y)
print(xy.shape, xy, sep="\n")

print(z.shape, z, sep="\n")

# Step 2: multiply in z and sum over the feature axis -> (2, 2, 2).
xyz = jnp.einsum('fxy, fz->xyz', xy, z)
print(xyz.shape, xyz, sep="\n")

(3, 2)
[[1. 2.]
 [3. 4.]
 [5. 6.]]
(3, 2)
[[ 1.  3.]
 [ 9. 10.]
 [11. 12.]]
(3, 2, 2)
[[[ 1.  3.]
  [ 2.  6.]]

 [[27. 30.]
  [36. 40.]]

 [[55. 60.]
  [66. 72.]]]
(3, 2)
[[1. 2.]
 [3. 4.]
 [5. 6.]]
(2, 2, 2)
[[[357. 440.]
  [393. 486.]]

 [[440. 544.]
  [486. 604.]]]


In [87]:
# N_LAYERS Dense layers per coordinate network, each FEATURES units wide.
feat_sizes = (FEATURES,) * N_LAYERS
model = SPINN(feat_sizes)

# Two sample points per axis, as (2, 1) column vectors.
x = jnp.array([[1.], [2.]])
y = jnp.array([[4.], [5.]])
z = jnp.array([[7.], [8.]])

# Initialise parameters with the dedicated subkey, then run one jitted forward pass.
params = model.init(subkey, x, y, z)
apply_fn = jax.jit(model.apply)
u = apply_fn(params, x, y, z)

# One value per (x, y, z) combination of the sample points.
u.shape

(2, 2, 2)