In [1]:
import numpy as np
import numpy.linalg as LA
import jax
import jax.numpy as jnp
import jax.numpy.linalg as JLA

import pandas as pd
import matplotlib.pyplot as plt
import japanize_matplotlib
from jax.example_libraries import optimizers
from tqdm.notebook import trange
from functools import partial
from flax import linen as nn
from flax.experimental import nnx
from typing import Sequence, Callable, Tuple

In [3]:
# --- Experiment configuration -------------------------------------------
# Seed NumPy's global RNG so that H below (and every later np.random draw
# made after this cell) is reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

K = 50            # number of samples per mini-batch
noise_std = 0.75  # standard deviation of the additive Gaussian noise
n = 4             # length of each transmitted / received vector
h = 50            # width of the detector's hidden layers
H = jnp.array(np.random.randn(n, n))  # fixed n x n matrix with standard-normal entries
adam_lr = 0.01    # step size handed to the Adam optimizer

In [26]:
def mini_batch(K):
    """Draw one random mini-batch of K transmit/receive pairs.

    Every transmitted entry is +1 or -1 with equal probability. The received
    signal is ``H @ x`` plus Gaussian noise scaled by ``noise_std``. The
    function reads the module-level ``n``, ``H`` and ``noise_std`` and uses
    NumPy's global random state.

    Args:
        K: number of samples (columns) to generate.

    Returns:
        A pair ``(x.T, y.T)`` with one sample per row: the transmitted
        symbols and the corresponding noisy observations.
    """
    bits = np.random.binomial(1, 0.5, size=(n, K))
    symbols = 1 - 2.0 * bits  # bit 0 -> +1.0, bit 1 -> -1.0
    noise = jnp.array(np.random.randn(n, K))
    received = H @ symbols + noise_std * noise
    return symbols.T, received.T

In [34]:
class Detector(nn.Module):
    """Fully connected detector: Dense -> ReLU -> Dense -> ReLU -> Dense -> tanh.

    Attributes:
        hidden_dim: number of units in each of the two hidden Dense layers.
        output_dim: number of units in the final Dense layer.
    """
    hidden_dim : int
    output_dim : int

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` along its last axis.

        The final tanh keeps every output component within [-1, 1].
        """
        # Activations are taken from flax.linen itself so this linen module
        # does not depend on the experimental nnx namespace.
        x = nn.relu(nn.Dense(self.hidden_dim)(x))
        x = nn.relu(nn.Dense(self.hidden_dim)(x))
        return nn.tanh(nn.Dense(self.output_dim)(x))

In [35]:
# Draw one batch just to obtain correctly shaped example inputs.
x, y = mini_batch(K)
detector = Detector(hidden_dim=h, output_dim=n)
key = jax.random.PRNGKey(0)
# The detector consumes the received signal (loss_func applies it to Y), so
# initialise and sanity-check it with y rather than the transmitted symbols.
params = detector.init(key, y[:1])["params"]
# Shape check: one output row per sample in the batch.
detector.apply({"params": params}, y).shape

(50, 4)

In [36]:
@jax.jit
def get_dot(x):
    """Multiply ``x`` by its own transpose.

    For a 1-D vector the transpose is a no-op, so the result is the sum of
    squares of its entries.
    """
    return jnp.matmul(x, x.T)


# Apply get_dot independently to every slice along the leading axis and
# stack the results along the leading axis of the output.
batch_get_dot = jax.vmap(get_dot)

In [43]:
def loss_func(X, Y, params):
    """Average squared error between the detector output and ``X``.

    Runs the module-level ``detector`` on ``Y`` with the given parameters,
    passes the residual against ``X`` through ``batch_get_dot`` (one value
    per leading-axis slice) and returns the mean of those values.

    Args:
        X: target values the detector should reproduce.
        Y: inputs fed to the detector.
        params: detector parameters (the ``"params"`` collection).

    Returns:
        Scalar loss.
    """
    estimate = detector.apply({"params": params}, Y)
    per_sample_error = batch_get_dot(estimate - X)
    return jnp.mean(per_sample_error)

In [41]:
# Number of optimisation steps performed by train().
train_itr = 500

# Adam from jax.example_libraries: a triple of functions to create the
# optimizer state, advance it by one step, and read the parameters back out.
opt_init, opt_update, get_params = optimizers.adam(step_size=adam_lr)

@jax.jit
def step(x, y, step_num, opt_state):
    """Perform one Adam update on a single mini-batch.

    Jitted so the loss/gradient computation and the optimizer update are
    traced once and then run as a single compiled function, instead of being
    rebuilt and dispatched op-by-op on every training iteration.

    Args:
        x: target symbols for the batch.
        y: received signals for the batch (detector input).
        step_num: index of the current iteration, forwarded to the optimizer.
        opt_state: current optimizer state.

    Returns:
        ``(value, new_opt_state)``: the loss evaluated before the update and
        the optimizer state after it.
    """
    current_params = get_params(opt_state)
    # Differentiate with respect to the third argument, the parameters.
    value, grads = jax.value_and_grad(loss_func, argnums=2)(x, y, current_params)
    new_opt_state = opt_update(step_num, grads, opt_state)
    return value, new_opt_state

def train(params):
    """Train the detector for ``train_itr`` steps on freshly drawn mini-batches.

    Args:
        params: initial detector parameters used to build the optimizer state.

    Returns:
        The parameters extracted from the final optimizer state. If
        ``train_itr`` is 0 these are the initial parameters.
    """
    opt_state = opt_init(params)
    for itr in trange(train_itr, leave=False):
        x, y = mini_batch(K)
        value, opt_state = step(x, y, itr, opt_state)
        # Carriage return keeps the running loss on a single output line.
        print("\r"+"\rloss:{}".format(value), end=" ")
    # Hand the learned parameters back; without this the caller receives None.
    return get_params(opt_state)


In [42]:
# Run the training loop from the initial parameters and keep whatever train() returns.
trained_params = train(params)

  0%|          | 0/500 [00:00<?, ?it/s]

loss:0.33613890409469604 