# (Simple) RNN in Jax

In [213]:
import jax
import jax.numpy as np
import math
import numpy as onp
import tensorflow as tf
import time
import matplotlib.pyplot as plt
from jax import grad, vmap, jit
import pandas as pd
import itertools

In [214]:
key = jax.random.PRNGKey(1)

## Data loading

In [215]:
df = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names=["1", "2", "3", "4", "5", "6"])
print(df.head())

track = (df["1"].values)
time = (df["2"].values)
note_is_played = (df["3"].values)
channel = df["4"].values
notes = df["5"].values
velocity = df["6"].values


"""
print(track)
print(time)
print(note_is_played)
print(channel)
print(notes)
print(velocity)
"""

h = {}
h[36] = 0
h[38] = 1
h[41] = 2
h[42] = 3
h[43] = 4
h[45] = 5
h[82] = 6
print(h)

u = list(itertools.repeat(np.zeros(7), max(time)+1))

# this is not working correctly yet
for i in range(len(notes)):
    v = [0]*7
    v[h[notes[i]]] = velocity[i] / 127
    if (np.any(u[time[i]])):
        u[time[i]] = u[time[i]] + np.array(v)
    else:
        u[time[i]] = np.array(v)

print(jax.tree_map(lambda x: x.shape, u[236]))

   1    2           3  4   5   6
0  2    0   Note_on_c  9  42  96
1  2    0   Note_on_c  9  36  84
2  2   29   Note_on_c  9  36   0
3  2   29   Note_on_c  9  42   0
4  2  236   Note_on_c  9  38  61
{36: 0, 38: 1, 41: 2, 42: 3, 43: 4, 45: 5, 82: 6}
(7,)


# Network Architecture

A RNN (with no LSTM units) is a neural network with recurrent connections (dynamical system)
$$
\begin{array}{l l}
\pmb{x}(n+1) = \sigma(W\pmb{x}(n) + W^{in}\pmb{u}(n+1) + \pmb{b}) \\
\pmb{y}(n) = f(W^{out} \pmb{x}(n))
\end{array}
$$
Describes how the network activation state is updated and how output signal is generated. 

Input vector $\pmb{u}(n) \in \mathbb{R}^K$

Activation/state vector $\pmb{x}(n) \in \mathbb{R}^L$

Output vector $\pmb{y}(n) \in \mathbb{R}^M$

Bias vector $\pmb{b} \in \mathbb{R}^L$

$W^{in} \in \mathbb{R}^{L \times K}, W \in \mathbb{R}^{L \times L}, W^{out} \in \mathbb{R}^{M \times L}$ are weight matrices charecterizing the connections between neurons in the layers


## Initialization
At time $n = 0$ the recurrent network state $\mathbf{x}(0)$ is often set to the zero vector $\mathbf{x}(0) = \mathbf{0}$

In [216]:
def init_weight_matrix(in_dim, out_dim, key, scale=1e-2):
    w = jax.random.normal(key, (out_dim, in_dim))
    return scale*w

In [217]:
def init_bias(dim, key, scale=1e-2):
    b = jax.random.normal(key, (dim, ))
    return scale*b

In [218]:
# state, params = (W^{in}, W, W^{out}, b)
# sizes = (input dim, state dim, output dim.)
def init_network(sizes, key):
    keys = jax.random.split(key, len(sizes))
    params = {} # hashmap
    # don't know if this is the best way to do it but this is to keep track of the state vector over time
    x = []
    x.append(np.zeros(sizes[1]))
    # as well as output signal
    y = []
    
    params["input matrix"] = init_weight_matrix(sizes[0], sizes[1], keys[0])
    params["state matrix"] = init_weight_matrix(sizes[1], sizes[1], keys[1])
    params["bias vector"] = init_bias(sizes[1], keys[2])
    params["output matrix"] = init_weight_matrix(sizes[1], sizes[2], keys[3])

    return x, y, params


In [219]:
sizes = [7, 12, 7]
x, y, params = init_network(sizes, key)
print(jax.tree_map(lambda x: x.shape, params)) # printing shape of network


{'bias vector': (12,), 'input matrix': (12, 7), 'output matrix': (7, 12), 'state matrix': (12, 12)}


$$\pmb{x}(n) = \sigma(W\pmb{x}(n-1) + W^{in}\pmb{u}(n) + \pmb{b})$$

In [220]:
# params = weights and bias
# u = input signal at time n
# x = state vector : returns state at time n
# b = bias vector
# n = time
# changed: the entire state vector and input signal is passed now
# adds new state vector to state signal and also returns new state vector
def nextState(params, x, u, n):
    w_in = params["input matrix"]
    w = params["state matrix"]
    b = params["bias vector"]
    x_new = jax.nn.relu(np.dot(w, x[-1]) + np.dot(w_in, u[n]) + b)
    x.append(x_new)
    return x_new

$$\pmb{y}(n) = f(W^{out} \pmb{x}(n))$$
Softmax makes the output vector a valid probability vector. Given $\textbf{v} = (v_1, ..., v_d)' \in \mathbb{R}^d$:
$$
f(\textbf{v}) = \text{softmax}(\textbf{v}) = \frac{1}{Z}(\exp(v_1), ..., \exp(v_d))'
$$
where $Z = \sum_{i=1, ..., d} \exp(v_i)$ is the normalization constant.

(For now identity function is used)

In [221]:
# adds new output vector to output signal and also returns new output vector
def readOut(params, x, y):
    w_out = params["output matrix"]
    y_new = np.dot(w_out, x[-1])
    y.append(y_new)
    return y_new

Input signal
$$
\mathbf{u}(n)_{n=0,..., n_{\text{max midi file}}} \in \mathbb{R}^7
$$

In [222]:

for n in range(len(u)):
    nextState(params, x, u, n)
    readOut(params, x, y)


In [223]:
# some printing
# note: u is a jax numpy matrix
# x, y are lists whose elements are jax numpy arrays
# but u[i], x[i], y[i] are all jax numpy arrays
print(u[6000])
#print(type(u))
#print(type(u[2]))
print(x[6000])
#print(type(x))
#print(type(x[2]))
print(y[6000])
#print(type(y))
#print(type(y[2]))

[0. 0. 0. 0. 0. 0. 0.]
[0.0000000e+00 3.9604981e-03 0.0000000e+00 0.0000000e+00 2.0270241e-02
 0.0000000e+00 3.3458474e-03 0.0000000e+00 7.6153837e-03 7.5484539e-05
 0.0000000e+00 1.3498385e-03]
[-2.0801499e-04 -3.0292012e-04 -1.0615522e-04 -3.3284104e-04
 -1.1379622e-04 -1.4905985e-04  9.9559227e-05]


## Training
Time series prediction task $S = (\mathbf{u}(n), \mathbf{y}(n))_{n=1, ..., N}$ where $\mathbf{y}(n) = \mathbf{u}(n+1)$
For now: quadratic loss which is used in stationary tasks
$$
L(\hat{\mathbf{y}}(n), \mathbf{y}(n)) = \parallel \hat{\mathbf{y}}(n) - \mathbf{y}(n) \parallel^2
$$

In [224]:
def loss(params, x, y, y_true, n):
    y_hat = readOut(params, x, y)
    return np.square(np.subtract(y_hat, y_true[n]))