# (Simple) RNN in Jax

In [511]:
import jax
import jax.numpy as np
import math
import numpy as onp
import tensorflow as tf
import time
import matplotlib.pyplot as plt
from jax import grad, vmap, jit
import jax
import jax.numpy as np
import pandas as pd
import itertools
import functools

In [512]:
key = jax.random.PRNGKey(1)

## Pre - and Postprocessing
$$
  \mathcal{D} : \text{CSV} \rightarrow [0,1]^{K \times n_{\max}}
$$

In [513]:
def D(df):
    time = df["time"].values
    notes = df["notes"].values
    velocity = df["velocity"].values
    h = {}
    h[36] = 0
    h[38] = 1
    h[41] = 2
    h[42] = 3
    h[43] = 4
    h[45] = 5
    h[82] = 6

    u = list(itertools.repeat(np.zeros(7), max(time)+1))

    for i in range(len(notes)):
        v = [0]*7
        v[h[notes[i]]] = velocity[i] / 127
        if (np.any(u[time[i]])):        #If the current vector is not the zero vector
            u[time[i]] = u[time[i]] + np.array(v)
        else:
            u[time[i]] = np.array(v)
    
    y_train = u[1:]
    y_train.append(np.zeros(7))
    return np.array(u), np.array(y_train)

$$
\mathcal{D}^{-1} : [0,1]^{K \times n_{max}} \rightarrow \text{CSV}
$$

In [514]:
def D_inv(y):
    h_inv = {}
    h_inv[0] = 36
    h_inv[1] = 38
    h_inv[2] = 41
    h_inv[3] = 42
    h_inv[4] = 43
    h_inv[5] = 45
    h_inv[6] = 82

    dfo = pd.DataFrame([], columns= ["track", "time", "note_is_played"
                            ,"channel", "notes", "velocity"])
    time_new = []
    notes_new = []
    velocity_new = []

    for n in range(len(y)):
        if (np.any(y[n])):
            yn = y[n]
            for i in range(len(yn)):
                v = (int)(yn[i] * 127)
                if (v > 0):
                    time_new.append(n)
                    velocity_new.append(v)
                    notes_new.append(h_inv[i])

    track_new = [2]*len(time_new)       # arbitrary
    note_is_played_new = ["Note_on_c"]*len(time_new)
    channel_new = [9]*len(time_new)     # 9 for drum

    dfo["track"] = track_new
    dfo["note_is_played"] = note_is_played_new
    dfo["channel"] = channel_new
    dfo["time"] = time_new
    dfo["velocity"] = velocity_new
    dfo["notes"] = notes_new

    latestTime = time_new[-1] + 10

    pre = [[0,0, "Header", 1, 2, 480, ''],
            [1, 0, "Start_track", '', '', '', ''],
            [1, 0, "Time_signature", 4, 2, 24, 8],
            [1, 0, "Title_t", "\"from model\"", '', '', ''],
            [1, 0, "End_track", '', '', '', ''],
            [2, 0, "Start_track", '', '', '', '']]

    post = [[2, latestTime, "End_track", '', '', '', ''],
            [0, 0, "End_of_file", '', '', '', '']]

    dfo["filler"] = ''
    # adding pre and post
    for i in range(len(pre)):
        dfo.loc[i] = pre[i]
    for j in range(len(post)):
        dfo.loc[len(dfo)+j] = post[j]

    dfo.to_csv("new.csv", index = False, header = False)
    return dfo

## TBD: Data Loading

In [515]:
columnNames = ["track", "time", "note_is_played"
                            , "channel", "notes", "velocity"]

df = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names= columnNames)

u = []
y_train = []

u_i, ytrain_i = D(df)

u.append(u_i)
y_train.append(ytrain_i)

print(jax.tree_map(lambda x: x.shape, u[0]))
print(jax.tree_map(lambda x: x.shape, y_train[0]))

(7590, 7)
(7590, 7)


# Network Architecture

A RNN (with no LSTM units) is a neural network with recurrent connections (dynamical system)
$$
\begin{array}{l l}
          \mathbf{x}(n+1) = \sigma(W \mathbf{x}(n) + W^{in}\mathbf{u}(n+1) + \mathbf{b}) \\
          \mathbf{y}(n) = f(W^{out} \mathbf{x}(n)),
\end{array}
$$
Describes how the network activation state is updated and how output signal is generated. 

Input vector $\mathbf{u}(n) \in [0,1]^K$

Activation/state vector $\mathbf{x}(n) \in \mathbb{R}^L$

Output vector $\mathbf{y}(n) \in \mathbb{R}^K$

Bias vector $\mathbf{b} \in \mathbb{R}^L$

$W^{in} \in \mathbb{R}^{L \times K}, W \in \mathbb{R}^{L \times L}, W^{out} \in \mathbb{R}^{K \times L}$ are weight matrices charecterizing the connections between neurons in the layers


## Initialization
At time $n = 0$ the recurrent network state $\mathbf{x}(0)$ is often set to the zero vector $\mathbf{x}(0) = \mathbf{0}$

In [516]:
def init_weight_matrix(in_dim, out_dim, key, scale=1e-2):
    w = jax.random.normal(key, (out_dim, in_dim))
    return scale*w

In [517]:
def init_bias(dim, key, scale=1e-2):
    b = jax.random.normal(key, (dim, ))
    return scale*b

In [518]:
# state, params = (W^{in}, W, W^{out}, b)
# sizes = (input dim, state dim, output dim.)
def init_network(sizes, key):
    keys = jax.random.split(key, len(sizes))
    params = {} # hashmap
    # don't know if this is the best way to do it but this is to keep track of the state vector over time
    x = []
    x.append(np.zeros(sizes[1]))
    # as well as output signal
    y = []
    
    params["input matrix"] = init_weight_matrix(sizes[0], sizes[1], keys[0])
    params["state matrix"] = init_weight_matrix(sizes[1], sizes[1], keys[1])
    params["output matrix"] = init_weight_matrix(sizes[1], sizes[2], keys[3])
    params["bias vector"] = init_bias(sizes[1], keys[2])

    return x, y, params


In [519]:
K = 7 # K := input and output vector dim
L = 12 # Reservoir or State Vector dim
sizes = [K, L, K]
x, y, params = init_network(sizes, key)
print(jax.tree_map(lambda x: x.shape, params)) # printing shape of network


{'bias vector': (12,), 'input matrix': (12, 7), 'output matrix': (7, 12), 'state matrix': (12, 12)}


$$ \mathbf{x}(n+1) = \sigma(W \mathbf{x}(n) + W^{in}\mathbf{u}(n+1) + \mathbf{b})$$

In [520]:
# params = weights and bias
# u = input signal at time n
# x = state vector : returns state at time n
# b = bias vector
# n = time
# changed: the entire state vector and input signal is passed now
# adds new state vector to state signal and also returns new state vector
def nextState(params, x, u, n):
    w_in = params["input matrix"]
    w = params["state matrix"]
    b = params["bias vector"]
    x_new = jax.nn.relu(np.dot(w, x[-1]) + np.dot(w_in, u[n]) + b)
    x.append(x_new)
    return x_new

$$\mathbf{y}(n) = f(W^{out} \mathbf{x}(n))$$
$f$ is a function that ensures the readout $W^{out} \mathbf{x}(n) \in \mathbb{R}^{K} \mapsto \mathbf{y}(n) \in [0,1]^K$. In other words, that the output of the network has vectors whose elements are always between $0$ and $1$:
$$
       f(x) := \frac{1}{1+e^{-x}}
$$
Both $\sigma$ and $f$ are applied element-wise on a given vector.

In [521]:
# adds new output vector to output signal and also returns new output vector
def readOut(params, x, y):
    w_out = params["output matrix"]
    y_new = jax.nn.sigmoid(np.dot(w_out, x[-1]))
    y.append(y_new)
    return y_new

Input signal
$$
\mathbf{u}[n]_{n=0, ..., n_{max}} \in [0,1]^K
$$

In [522]:
"""
for n in range(len(u)):
    nextState(params, x, u, n)
    readOut(params, x, y)
"""


'\nfor n in range(len(u)):\n    nextState(params, x, u, n)\n    readOut(params, x, y)\n'

In [523]:
# some printing
# note: u is a jax numpy matrix
# x, y are lists whose elements are jax numpy arrays
# but u[i], x[i], y[i] are all jax numpy arrays
#print(u[20])
#print(type(u))
#print(type(u[2]))
#print(x[20])
#print(type(x))
#print(type(x[2]))
#print(y[20])
#print(type(y))
#print(type(y[2]))

## Training


In [524]:
# x_og is initially the zero vector
def forward_bp(params, u, x_og=np.zeros((L, ))):
    """ Loop over the time steps of the input sequence
    u[n] := [u_0, ..., u_{n_max}] where u_i \in [0, 1]^K or (K, )
    x_og: \in R^L or (L, )
    """
    Win, W, Wout, b = params.values()
    x = x_og.copy()

    def apply_fun_scan(params, x, ut):
        """ Perform single step update of the network.
        x:  (L, )
        un: (K, )
        """
        Win, W, Wout, b = params.values()
        x = jax.nn.relu(
            np.dot(Win, ut) + np.dot(W, x) + b
        )
        y = jax.nn.sigmoid(np.dot(Wout, x))
        return x, y

    f = functools.partial(apply_fun_scan, params)
    _, Y = jax.lax.scan(f, x, u)
    return Y

In [547]:
print(jax.tree_map(lambda x: x.shape, x))
print(jax.tree_map(lambda x: x.shape, u[0]))
Y = forward_bp(params, u[0])
print(jax.tree_map(lambda x: x.shape, Y))
print(jax.tree_map(lambda x: x.shape, Y[0]))

print(Y[7589])
print(Y[7590])

# u[n] (7,)
# y[n] = u[n+1] (7,)


[(12,)]
(7590, 7)
(7590, 7)
(7,)
[0.49994802 0.49992424 0.49997348 0.4999168  0.4999715  0.49996275
 0.5000249 ]
[0.49994802 0.49992424 0.49997348 0.4999168  0.4999715  0.49996275
 0.5000249 ]


## Loss function
Logistic regression is used for the case where you do not model loudness. Because there it is a categorical task (to hit or not to hit). But if you do model loudness you use the procedure linear regression. Meaning using the loss functions that are mentioned in the RNN section of the reader like MSE or quadratic loss.

Time series prediction task $S = (\mathbf{u}^{(i)}(n), \mathbf{y}^{(i)}(n))_{i=1, ..., N;n=1, ..., n_i}$ where $\mathbf{y}^{i}(n) = \mathbf{u}^{(i)}(n+1)$
For now: quadratic loss which is used in stationary tasks
$$
    L(\hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}}) = \parallel \hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}} \parallel^2
$$

In [526]:
def loss(params, u, y_true):
    y_hat = forward_bp(params, u)
    return np.square(np.linalg.norm(np.subtract(y_hat, y_true)))

In [532]:
print(u[0].shape)
print(y_train[0].shape)
loss(params, u[0], y_train[0])

(7590, 7)
(7590, 7)


DeviceArray(13261.639, dtype=float32)

$$
    \mathcal{A}(S) = \theta_{\text{opt}} = \underset{\theta \in \Theta}{\text{argmin}} \; \underbrace{\frac{1}{N} \sum^N_{i=1} L(\hat{\mathbf{Y}}_{\theta}^{\text{train}}, \mathbf{Y}^{\text{train}})}_{\mathcal{R}^{\text{emp}}(\theta)},
$$
$$
    \theta^{(n+1)} = \theta^{(n)} - \mu \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}),
$$