In [1]:
import numpy as np
import numpy.linalg as LA
import jax
import jax.numpy as jnp
import jax.numpy.linalg as JLA

from jax.example_libraries import optimizers

import optax

from tqdm.notebook import trange
from functools import partial
from flax import linen as nn
from flax.experimental import nnx
from flax.training import train_state

In [2]:
# --- Experiment configuration ---
# Seed NumPy's global RNG so that the channel matrix H below and every
# np.random draw made later in the notebook are reproducible after
# "Restart Kernel -> Run All".
seed = 0
np.random.seed(seed)

K = 50            # number of samples per mini-batch
noise_std = 0.75  # scale applied to the standard-normal noise added to the received signal
n = 4             # length of each symbol vector; H is n x n
h = 50            # hidden width passed to the detector network
H = jnp.array(np.random.randn(n, n))  # random channel matrix, drawn once and then held fixed
adam_lr = 0.01    # learning rate handed to the Adam optimizers

In [3]:
def mini_batch(K):
    """Draw one random mini-batch of transmitted symbols and received signals.

    Each transmitted entry is +1 or -1 with equal probability. The received
    signal is ``H @ x`` plus Gaussian noise scaled by ``noise_std``. Reads the
    notebook-level ``n``, ``H`` and ``noise_std``.

    Args:
        K: number of samples in the batch.

    Returns:
        Tuple ``(x, y)``, both of shape ``(K, n)``: ``x`` is a NumPy array of
        +/-1 symbols and ``y`` is the corresponding noisy received batch.
    """
    # Draw order (bits first, then noise) is kept so the RNG stream is consumed identically.
    bits = np.random.binomial(1, 0.5, size=(n, K))
    symbols = -2.0 * bits + 1  # map {0, 1} -> {+1, -1}
    noise = noise_std * jnp.array(np.random.randn(n, K))
    received = H @ symbols + noise
    # Computation is column-per-sample; callers get row-per-sample.
    return symbols.T, received.T

In [4]:
class Detector(nn.Module):
    """Fully connected detector: Dense -> ReLU -> Dense -> ReLU -> Dense -> tanh.

    Attributes:
        hidden_dim: width of the two hidden Dense layers.
        output_dim: width of the output Dense layer.
    """
    hidden_dim : int
    output_dim : int

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x``; the final tanh keeps outputs in (-1, 1)."""
        # Activations come from flax.linen (`nn`), the same API that defines
        # this module, rather than from the experimental `nnx` namespace.
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(self.output_dim)(x)
        x = nn.tanh(x)
        return x

In [5]:
# Draw one batch only to obtain example inputs of the right shape.
x, y = mini_batch(K)
detector = Detector(hidden_dim=h, output_dim=n)
key = jax.random.PRNGKey(0)
# The detector is fed received signals y (see loss_func), so initialise and
# probe it with y rather than with the transmitted symbols x. x only worked
# here because H is square, which makes x and y the same shape.
params = detector.init(key, y[:1])["params"]
# Sanity check: one output row per sample, one column per symbol.
detector.apply({"params":params}, y).shape

(50, 4)

In [6]:
@jax.jit
def get_dot(x):
    """Return ``x @ x.T``; for a 1-D ``x`` this is its squared Euclidean norm."""
    return jnp.matmul(x, x.T)


# Row-wise version: maps get_dot over the leading axis (vmap's default axes),
# giving one value per row of a 2-D input.
batch_get_dot = jax.vmap(get_dot)

In [7]:
@jax.jit
def loss_func(X, Y, params):
    """Sum over the batch of the squared error between detector output and X.

    Args:
        X: target symbols, one row per sample.
        Y: detector inputs, one row per sample.
        params: detector parameters. Kept as the last argument because the
            training steps differentiate with ``argnums=-1``.

    Returns:
        Scalar loss: an unnormalised sum over samples, not a mean.
    """
    residual = detector.apply({"params": params}, Y) - X
    per_sample_sq_err = batch_get_dot(residual)
    return jnp.sum(per_sample_sq_err)

In [8]:
# Number of optimisation steps (one fresh mini-batch per step) in each training run.
train_itr = 500

# Adam from jax.example_libraries: returns the (init, update, get_params) triple
# that step() and train() below rely on.
opt_init, opt_update, get_params = optimizers.adam(step_size=adam_lr)

def step(x, y, step_num, opt_state):
    """Run one Adam update on a single mini-batch.

    Args:
        x: target symbols for the batch.
        y: received signals for the batch.
        step_num: iteration index passed on to the optimizer update.
        opt_state: current optimizer state, which carries the parameters.

    Returns:
        Tuple ``(loss, new_opt_state)``, where the loss is evaluated before
        the update is applied.
    """
    current_params = get_params(opt_state)
    # argnums=-1: differentiate with respect to the last argument (params).
    loss_and_grad = jax.value_and_grad(loss_func, argnums=-1)
    loss, gradient = loss_and_grad(x, y, current_params)
    return loss, opt_update(step_num, gradient, opt_state)

def train(params):
    """Train the detector for ``train_itr`` steps, starting from ``params``.

    A fresh random mini-batch of size ``K`` is drawn at every step and the
    current loss is printed on a single, continuously overwritten line.

    Args:
        params: initial detector parameters.

    Returns:
        The parameters extracted from the final optimizer state.
    """
    opt_state = opt_init(params)
    progress = trange(train_itr, leave=False)
    for itr in progress:
        batch_x, batch_y = mini_batch(K)
        value, opt_state = step(batch_x, batch_y, itr, opt_state)
        # Carriage returns keep the loss read-out on one line.
        print("\r"+"\rloss:{}".format(value), end=" ")
    return get_params(opt_state)

In [9]:
# Train from the initial `params` with the jax.example_libraries Adam optimizer.
trained_params = train(params)

  0%|          | 0/500 [00:00<?, ?it/s]

loss:12.050289154052734 

In [10]:
def neural_recog(params, num_loops=1000):
    """Count symbol errors of the neural detector on freshly drawn batches.

    Each detector output is hard-decided with ``jnp.sign`` and compared
    element-wise with the transmitted symbols.

    Args:
        params: detector parameters to evaluate.
        num_loops: number of mini-batches (each of size ``K``) to evaluate.

    Returns:
        Tuple ``(total_syms, error_syms)``: symbols tested and symbols whose
        hard decision differs from the transmitted value.
    """
    mismatches = 0
    for _ in range(num_loops):
        sent, received = mini_batch(K)
        decided = jnp.sign(detector.apply({"params": params}, received))
        mismatches = mismatches + jnp.sum(decided != sent)
    return num_loops * n * K, int(mismatches)

In [11]:
def ZF_recog(num_loops=1000):
    """Count symbol errors of the zero-forcing detector on freshly drawn batches.

    The received signal is multiplied by the inverse of ``H`` and the result
    is hard-decided with ``jnp.sign``.

    Args:
        num_loops: number of mini-batches (each of size ``K``) to evaluate.

    Returns:
        Tuple ``(total_syms, error_syms)``: symbols tested and symbols whose
        hard decision differs from the transmitted value.
    """
    # H is fixed, so invert it once outside the loop.
    H_inverse = LA.inv(H)
    wrong = 0
    for _ in range(num_loops):
        sent, received = mini_batch(K)
        # Batches are row-per-sample: transpose to apply the inverse, then transpose back.
        equalized = (H_inverse @ received.T).T
        wrong = wrong + jnp.sum(jnp.sign(equalized) != sent)
    return num_loops * n * K, int(wrong)

In [12]:
# Symbol error rate of the neural detector trained with the
# jax.example_libraries Adam optimizer (`trained_params`).
total_syms, error_syms = neural_recog(trained_params)
print("total_syms = ", total_syms)
print("error_syms = ", error_syms)
print("symbols error rate = ", error_syms/total_syms)

total_syms =  200000
error_syms =  5895
symbols error rate =  0.029475


In [13]:
# Baseline for comparison: symbol error rate of the zero-forcing detector.
total_syms, error_syms = ZF_recog()
print("total_syms = ", total_syms)
print("error_syms = ", error_syms)
print("symbols error rate = ", error_syms/total_syms)

total_syms =  200000
error_syms =  9522
symbols error rate =  0.04761


## optax version (same training, using `optax.adam` with a Flax `TrainState`)

In [14]:
# Same learning rate as before, now through optax's Adam.
tx = optax.adam(learning_rate=adam_lr)
# Bundle apply function, parameters and optimizer into one TrainState.
# Starts again from the initial `params`, not from `trained_params`.
state = train_state.TrainState.create(apply_fn=detector.apply, params=params, tx=tx)

In [15]:
def step(x, y, state):
    """Run one optimizer update on a single mini-batch using a TrainState.

    Note: this redefines the earlier ``step`` from the jax.example_libraries
    section of the notebook.

    Args:
        x: target symbols for the batch.
        y: received signals for the batch.
        state: current TrainState holding parameters and optimizer state.

    Returns:
        Tuple ``(loss, new_state)``, where the loss is evaluated before the
        update is applied.
    """
    # argnums=-1: differentiate with respect to the last argument (params).
    loss_and_grad = jax.value_and_grad(loss_func, argnums=-1)
    loss, gradient = loss_and_grad(x, y, state.params)
    return loss, state.apply_gradients(grads=gradient)

def train(state):
    """Train for ``train_itr`` steps, starting from the given TrainState.

    Note: this redefines the earlier ``train`` from the jax.example_libraries
    section of the notebook. A fresh random mini-batch of size ``K`` is drawn
    at every step and the current loss is printed on a single, continuously
    overwritten line.

    Args:
        state: initial TrainState.

    Returns:
        The TrainState after the last update.
    """
    progress = trange(train_itr, leave=False)
    for _ in progress:
        batch_x, batch_y = mini_batch(K)
        value, state = step(batch_x, batch_y, state)
        # Carriage returns keep the loss read-out on one line.
        print("\r"+"\rloss:{}".format(value), end=" ")
    return state

In [16]:
# Train from the initial TrainState with the optax-based training loop.
trained_state = train(state)

  0%|          | 0/500 [00:00<?, ?it/s]

loss:11.116448402404785 

In [17]:
# Symbol error rate of the neural detector trained with optax (`trained_state.params`).
total_syms, error_syms = neural_recog(trained_state.params)
print("total_syms = ", total_syms)
print("error_syms = ", error_syms)
print("symbols error rate = ", error_syms/total_syms)

total_syms =  200000
error_syms =  5467
symbols error rate =  0.027335
