# (Simple) RNN in Jax

In [32]:
import jax
import jax.numpy as np
import math
import numpy as onp
import tensorflow as tf
import time
import matplotlib.pyplot as plt

In [33]:
key = jax.random.PRNGKey(1)

## Data loading

tbc

# Network Architecture

A RNN (with no LSTM units) is a neural network with recurrent connections (dynamical system)
$$
\begin{array}{l l}
\pmb{x}(n+1) = \sigma(W\pmb{x}(n) + W^{in}\pmb{u}(n+1) + \pmb{b}) \\
\pmb{y}(n) = f(W^{out} \pmb{x}(n))
\end{array}
$$
Describes how the network activation state is updated and how output signal is generated. 

Input vector $\pmb{u}(n) \in \mathbb{R}^K$

Activation/state vector $\pmb{x}(n) \in \mathbb{R}^L$

Output vector $\pmb{y}(n) \in \mathbb{R}^M$

Bias vector $\pmb{b} \in \mathbb{R}^L$

$W^{in} \in \mathbb{R}^{L \times K}, W \in \mathbb{R}^{L \times L}, W^{out} \in \mathbb{R}^{M \times L}$ are weight matrices charecterizing the connections between neurons in the layers


## Initialization
At time $n = 0$ the recurrent network state $\mathbf{x}(0)$ is often set to the zero vector $\mathbf{x}(0) = \mathbf{0}$

In [34]:
print(jax.random.normal(key, (2,)))
print(jax.random.normal(key, (2,1)))

[-0.11617039  2.2125063 ]
[[-0.11617039]
 [ 2.2125063 ]]


In [35]:
def init_weight_matrix(in_dim, out_dim, key, scale=1e-2):
    w = jax.random.normal(key, (out_dim, in_dim))
    return scale*w

In [36]:
def init_bias(dim, key, scale=1e-2):
    b = jax.random.normal(key, (dim, ))
    return scale*b

In [37]:
# state, params = (W^{in}, W, W^{out}, b)
# sizes = (input dim, state dim, output dim.)
def init_network(sizes, key):
    keys = jax.random.split(key, len(sizes))
    params = {} # hasmap
    x = np.zeros(sizes[1])
    
    params["input matrix"] = init_weight_matrix(sizes[0], sizes[1], keys[0])
    params["state matrix"] = init_weight_matrix(sizes[1], sizes[1], keys[1])
    params["bias vector"] = init_bias(sizes[1], keys[2])
    params["output matrix"] = init_weight_matrix(sizes[1], sizes[2], keys[3])

    return x, params


In [38]:
sizes = [3, 5, 3]
x, params = init_network(sizes, key)
print(jax.tree_map(lambda x: x.shape, params)) # printing shape of network


{'bias vector': (5,), 'input matrix': (5, 3), 'output matrix': (10, 5), 'state matrix': (5, 5)}


In [39]:
# params = weights and bias
# u = input signal at time n
# x = state vector : returns state at time n
# b = bias vector
# n = time

def nextState(x, params, u):
    w_in = params["input matrix"]
    w = params["state matrix"]
    b = params["bias vector"]
    x = jax.nn.relu(np.dot(w, x) + np.dot(w_in, u) + b)
    return x

In [40]:
def readOut(x, params):
    # here we just use the identity function for now
    w_out = params["output matrix"]
    y = np.dot(w_out, x)
    return y

In [41]:
u = jax.random.normal(jax.random.PRNGKey(2), shape=(20, 3))
print(u)
for n in range(len(u)):
    print("at time")
    print(n)
    x = nextState(x, params, u[n])
    print(readOut(x, params))


[[-0.09380784 -1.8266813   1.2070532 ]
 [-1.3838278  -1.7946863   0.739586  ]
 [-0.4072965  -0.5142992   0.7693824 ]
 [ 1.7207135   0.8223581  -0.06358211]
 [ 1.0225011   0.4435419   0.8082118 ]
 [ 0.03334755  0.4504489   1.4941233 ]
 [-0.03919963  0.28523797  0.5016509 ]
 [ 0.1117193   0.5748015   1.949763  ]
 [-0.8118261  -0.33672643 -1.1827446 ]
 [-0.59318626  0.01991192 -0.7724237 ]
 [ 0.11230356  0.4132285   0.29627615]
 [-0.9084105   0.11455189 -0.80134195]
 [-0.7631346  -0.76251495 -0.926171  ]
 [ 1.9448895   0.7570497  -1.677625  ]
 [-0.04892532 -0.430676   -0.9932122 ]
 [-0.7845826  -0.3665046  -0.5999799 ]
 [ 1.1914481  -1.7223476  -0.12748261]
 [ 0.96800166  0.2489037   0.7475761 ]
 [ 0.9835043  -0.17859596 -0.04589614]
 [-0.2840927  -1.3884628   1.5984961 ]]
at time
0
[-5.8886588e-05  2.4645540e-04 -3.7043955e-04 -2.4138135e-04
 -4.8458238e-05 -2.5399908e-04  3.0565402e-04  2.2559115e-04
 -3.6294572e-04  1.3628873e-04]
at time
1
[-9.22519466e-05  2.28241232e-04 -3.42467072e