# (Simple) RNN in Jax

In [320]:
import jax
import jax.numpy as np
import math
import numpy as onp
import tensorflow as tf
import time
import matplotlib.pyplot as plt
from jax import grad, vmap, jit

In [321]:
key = jax.random.PRNGKey(1)

## Data loading

tbc

# Network Architecture

A RNN (with no LSTM units) is a neural network with recurrent connections (dynamical system)
$$
\begin{array}{l l}
\pmb{x}(n+1) = \sigma(W\pmb{x}(n) + W^{in}\pmb{u}(n+1) + \pmb{b}) \\
\pmb{y}(n) = f(W^{out} \pmb{x}(n))
\end{array}
$$
Describes how the network activation state is updated and how output signal is generated. 

Input vector $\pmb{u}(n) \in \mathbb{R}^K$

Activation/state vector $\pmb{x}(n) \in \mathbb{R}^L$

Output vector $\pmb{y}(n) \in \mathbb{R}^M$

Bias vector $\pmb{b} \in \mathbb{R}^L$

$W^{in} \in \mathbb{R}^{L \times K}, W \in \mathbb{R}^{L \times L}, W^{out} \in \mathbb{R}^{M \times L}$ are weight matrices charecterizing the connections between neurons in the layers


## Initialization
At time $n = 0$ the recurrent network state $\mathbf{x}(0)$ is often set to the zero vector $\mathbf{x}(0) = \mathbf{0}$

In [322]:
def init_weight_matrix(in_dim, out_dim, key, scale=1e-2):
    w = jax.random.normal(key, (out_dim, in_dim))
    return scale*w

In [323]:
def init_bias(dim, key, scale=1e-2):
    b = jax.random.normal(key, (dim, ))
    return scale*b

In [324]:
# state, params = (W^{in}, W, W^{out}, b)
# sizes = (input dim, state dim, output dim.)
def init_network(sizes, key):
    keys = jax.random.split(key, len(sizes))
    params = {} # hashmap
    # don't know if this is the best way to do it but this is to keep track of the state vector over time
    x = []
    x.append(np.zeros(sizes[1]))
    # as well as output signal
    y = []
    
    params["input matrix"] = init_weight_matrix(sizes[0], sizes[1], keys[0])
    params["state matrix"] = init_weight_matrix(sizes[1], sizes[1], keys[1])
    params["bias vector"] = init_bias(sizes[1], keys[2])
    params["output matrix"] = init_weight_matrix(sizes[1], sizes[2], keys[3])

    return x, y, params


In [325]:
sizes = [3, 5, 3]
x, y, params = init_network(sizes, key)
print(jax.tree_map(lambda x: x.shape, params)) # printing shape of network


{'bias vector': (5,), 'input matrix': (5, 3), 'output matrix': (3, 5), 'state matrix': (5, 5)}


$$\pmb{x}(n) = \sigma(W\pmb{x}(n-1) + W^{in}\pmb{u}(n) + \pmb{b})$$

In [326]:
# params = weights and bias
# u = input signal at time n
# x = state vector : returns state at time n
# b = bias vector
# n = time
# changed: the entire state vector and input signal is passed now
# adds new state vector to state signal and also returns new state vector
def nextState(params, x, u, n):
    w_in = params["input matrix"]
    w = params["state matrix"]
    b = params["bias vector"]
    x_new = jax.nn.relu(np.dot(w, x[-1]) + np.dot(w_in, u[n]) + b)
    x.append(x_new)
    return x_new

$$\pmb{y}(n) = f(W^{out} \pmb{x}(n))$$
Softmax makes the output vector a valid probability vector. Given $\textbf{v} = (v_1, ..., v_d)' \in \mathbb{R}^d$:
$$
f(\textbf{v}) = \text{softmax}(\textbf{v}) = \frac{1}{Z}(\exp(v_1), ..., \exp(v_d))'
$$
where $Z = \sum_{i=1, ..., d} \exp(v_i)$ is the normalization constant.

In [327]:
# adds new output vector to output signal and also returns new output vector
def readOut(params, x, y):
    w_out = params["output matrix"]
    y_new = jax.nn.softmax(np.dot(w_out, x[-1]))
    y.append(y_new)
    return y_new

Example run with random input signal
$$
\mathbf{u}(n)_{n=0,..., 20} \in \mathbb{R}^3
$$

In [328]:
u = jax.random.normal(jax.random.PRNGKey(2), shape=(20, 3))
for n in range(len(u)):
    nextState(params, x, u, n)
    readOut(params, x, y)
    #x.append(nextState(params, x, u, n))
    #y.append(readOut(params, x))


In [329]:
# some printing
# note: u is a jax numpy matrix
# x, y are lists whose elements are jax numpy arrays
# but u[i], x[i], y[i] are all jax numpy arrays
print(u[2])
print(type(u))
print(type(u[2]))
print(x[2])
print(type(x))
print(type(x[2]))
print(y[2])
print(type(y))
print(type(y[2]))

[-0.4072965 -0.5142992  0.7693824]
<class 'jaxlib.xla_extension.DeviceArray'>
<class 'jaxlib.xla_extension.DeviceArray'>
[5.4300297e-05 3.8219169e-03 0.0000000e+00 3.5909493e-02 0.0000000e+00]
<class 'list'>
<class 'jaxlib.xla_extension.DeviceArray'>
[0.33337    0.33329844 0.33333156]
<class 'list'>
<class 'jaxlib.xla_extension.DeviceArray'>


## Training
Time series prediction task $S = (\mathbf{u}(n), \mathbf{y}(n))_{n=1, ..., N}$ where $\mathbf{y}(n) = \mathbf{u}(n+1)$
For now: quadratic loss which is used in stationary tasks
$$
L(\hat{\mathbf{y}}(n), \mathbf{y}(n)) = \parallel \hat{\mathbf{y}}(n) - \mathbf{y}(n) \parallel^2
$$

In [332]:
def loss(params, x, y, y_true, n):
    y_hat = readOut(params, x, y)
    return np.square(np.subtract(y_hat, y_true[n]))

[0.44441688 0.44447714 0.4444393 ]
