# (Simple) RNN in Jax

In [872]:
import jax
import jax.numpy as np
import math
import numpy as onp
import tensorflow as tf
import time
import matplotlib.pyplot as plt
from jax import grad, vmap, jit
import jax
import jax.numpy as np
import pandas as pd
import itertools
import functools

In [873]:
key = jax.random.PRNGKey(1)

## Pre - and Postprocessing
$$
\mathcal{D} : \text{CSV} \rightarrow [0,1]^{K \times n_i}
$$

In [874]:
def D(df):
    time = df["time"].values
    notes = df["notes"].values
    velocity = df["velocity"].values
    h = {}
    h[36] = 0
    h[38] = 1
    h[41] = 2
    h[42] = 3
    h[43] = 4
    h[45] = 5
    h[82] = 6

    u = list(itertools.repeat(np.zeros(7), max(time)+1))

    for i in range(len(notes)):
        v = [0]*7
        v[h[notes[i]]] = velocity[i] / 127
        if (np.any(u[time[i]])):        #If the current vector is not the zero vector
            u[time[i]] = u[time[i]] + np.array(v)
        else:
            u[time[i]] = np.array(v)
    
    y_train = u[1:]
    y_train.append(np.zeros(7))
    return np.array(u), np.array(y_train)

$$
  \mathcal{D}^{-1} : [0,1]^{K \times n_i} \rightarrow \text{CSV}
$$

In [875]:
def D_inv(y):
    h_inv = {}
    h_inv[0] = 36
    h_inv[1] = 38
    h_inv[2] = 41
    h_inv[3] = 42
    h_inv[4] = 43
    h_inv[5] = 45
    h_inv[6] = 82

    dfo = pd.DataFrame([], columns= ["track", "time", "note_is_played"
                            ,"channel", "notes", "velocity"])
    time_new = []
    notes_new = []
    velocity_new = []

    for n in range(len(y)):
        if (np.any(y[n])):
            yn = y[n]
            for i in range(len(yn)):
                v = (int)(yn[i] * 127)
                if (v > 0):
                    time_new.append(n)
                    velocity_new.append(v)
                    notes_new.append(h_inv[i])

    track_new = [2]*len(time_new)       # arbitrary
    note_is_played_new = ["Note_on_c"]*len(time_new)
    channel_new = [9]*len(time_new)     # 9 for drum

    dfo["track"] = track_new
    dfo["note_is_played"] = note_is_played_new
    dfo["channel"] = channel_new
    dfo["time"] = time_new
    dfo["velocity"] = velocity_new
    dfo["notes"] = notes_new

    latestTime = time_new[-1] + 10

    pre = [[0,0, "Header", 1, 2, 480, ''],
            [1, 0, "Start_track", '', '', '', ''],
            [1, 0, "Time_signature", 4, 2, 24, 8],
            [1, 0, "Title_t", "\"from model\"", '', '', ''],
            [1, 0, "End_track", '', '', '', ''],
            [2, 0, "Start_track", '', '', '', '']]

    post = [[2, latestTime, "End_track", '', '', '', ''],
            [0, 0, "End_of_file", '', '', '', '']]

    dfo["filler"] = ''
    # adding pre and post
    for i in range(len(pre)):
        dfo.loc[i] = pre[i]
    for j in range(len(post)):
        dfo.loc[len(dfo)+j] = post[j]

    dfo.to_csv("new.csv", index = False, header = False)
    return dfo

## TBD: Data Loading

In [876]:
columnNames = ["track", "time", "note_is_played"
                            , "channel", "notes", "velocity"]

df1 = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names= columnNames)
df2 = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names= columnNames)
df3 = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names= columnNames)
df4 = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names= columnNames)
dfl = [df1, df2, df3, df4]

S = []
for df in dfl:
    utrain_i, ytrain_i = D(df)
    S.append((utrain_i, ytrain_i))

print(type(S))
u_ex, y_ex = S[0]
print(jax.tree_map(lambda x: x.shape, u_ex))
print(jax.tree_map(lambda x: x.shape, y_ex))

<class 'list'>
(7590, 7)
(7590, 7)


# Network Architecture

A RNN (with no LSTM units) is a neural network with recurrent connections (dynamical system)
$$
\begin{array}{l l}
          \mathbf{x}(n+1) = \sigma(W \mathbf{x}(n) + W^{in}\mathbf{u}(n+1) + \mathbf{b}) \\
          \mathbf{y}(n) = f(W^{out} \mathbf{x}(n)),
\end{array}
$$
Describes how the network activation state is updated and how output signal is generated. 

Input vector $\mathbf{u}(n) \in [0,1]^K$

Activation/state vector $\mathbf{x}(n) \in \mathbb{R}^L$

Output vector $\mathbf{y}(n) \in \mathbb{R}^K$

Bias vector $\mathbf{b} \in \mathbb{R}^L$

$W^{in} \in \mathbb{R}^{L \times K}, W \in \mathbb{R}^{L \times L}, W^{out} \in \mathbb{R}^{K \times L}$ are weight matrices charecterizing the connections between neurons in the layers


## Initialization
At time $n = 0$ the recurrent network state $\mathbf{x}(0)$ is often set to the zero vector $\mathbf{x}(0) = \mathbf{0}$

In [877]:
def init_weight_matrix(in_dim, out_dim, key, scale=1e-2):
    w = jax.random.normal(key, (out_dim, in_dim))
    return scale*w

In [878]:
def init_bias(dim, key, scale=1e-2):
    b = jax.random.normal(key, (dim, ))
    return scale*b

In [879]:
# state, params = (W^{in}, W, W^{out}, b)
# sizes = (input dim, state dim, output dim.)
def init_network(sizes, key):
    keys = jax.random.split(key, len(sizes))
    # don't know if this is the best way to do it but this is to keep track of the state vector over time
    x = []
    x.append(np.zeros(sizes[1]))
    # as well as output signal
    y = []
    
    Win = init_weight_matrix(sizes[0], sizes[1], keys[0])
    W = init_weight_matrix(sizes[1], sizes[1], keys[1])
    Wout = init_weight_matrix(sizes[1], sizes[2], keys[3])
    b = init_bias(sizes[1], keys[2])

    return x, y, (Win, W, Wout, b)


In [880]:
K = 7 # K := input and output vector dim
L = 12 # Reservoir or State Vector dim
sizes = [K, L, K]
x, y, params = init_network(sizes, key)
print(jax.tree_map(lambda x: x.shape, params)) # printing shape of network

((12, 7), (12, 12), (7, 12), (12,))


$$ \mathbf{x}(n+1) = \sigma(W \mathbf{x}(n) + W^{in}\mathbf{u}(n+1) + \mathbf{b})$$

In [881]:
# params = weights and bias
# u = input signal at time n
# x = state vector : returns state at time n
# b = bias vector
# n = time
# changed: the entire state vector and input signal is passed now
# adds new state vector to state signal and also returns new state vector
def nextState(params, x, u, n):
    W_in, W, W_out, b = params

    x_new = jax.nn.relu(np.dot(W, x[-1]) + np.dot(W_in, u[n]) + b)
    x.append(x_new)
    return x_new

$$\mathbf{y}(n) = f(W^{out} \mathbf{x}(n))$$
$f$ is a function that ensures the readout $W^{out} \mathbf{x}(n) \in \mathbb{R}^{K} \mapsto \mathbf{y}(n) \in [0,1]^K$. In other words, that the output of the network has vectors whose elements are always between $0$ and $1$:
$$
       f(x) := \frac{1}{1+e^{-x}}
$$
Both $\sigma$ and $f$ are applied element-wise on a given vector.

In [882]:
# adds new output vector to output signal and also returns new output vector
def readOut(params, x, y):
    W_out = params[2]
    y_new = jax.nn.sigmoid(np.dot(W_out, x[-1]))
    y.append(y_new)
    return y_new

## Training


In [883]:
# x_og is initially the zero vector
def forward_bp(params, u, x_og=np.zeros((L, ))):
    """ Loop over the time steps of the input sequence
    u[n] := [u_0, ..., u_{n_max}] where u_i \in [0, 1]^K or (K, )
    x_og: \in R^L or (L, )
    """
    Win, W, Wout, b = params
    x = x_og.copy()

    def apply_fun_scan(params, x, ut):
        """ Perform single step update of the network.
        x:  (L, )
        un: (K, )
        """
        Win, W, Wout, b = params
        x = jax.nn.relu(
            np.dot(Win, ut) + np.dot(W, x) + b
        )
        y = jax.nn.sigmoid(np.dot(Wout, x))
        return x, y

    f = functools.partial(apply_fun_scan, params)
    _, Y = jax.lax.scan(f, x, u)
    return Y

batch_forward_bp = jax.vmap(forward_bp, in_axes=(None, 0))

## Loss function
Logistic regression is used for the case where you do not model loudness. Because there it is a categorical task (to hit or not to hit). But if you do model loudness you use the procedure linear regression. Meaning using the loss functions that are mentioned in the RNN section of the reader like MSE or quadratic loss.

Time series prediction task $S = (\mathbf{u}^{(i)}(n), \mathbf{y}^{(i)}(n))_{i=1, ..., N;n=1, ..., n_i}$ where $\mathbf{y}^{i}(n) = \mathbf{u}^{(i)}(n+1)$
For now: quadratic loss which is used in stationary tasks
$$
    L(\hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}}) = \parallel \hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}} \parallel^2
$$

Regularization
$$
\text{reg}(\theta) = \sum_{w \in \theta} w^2
$$

In [884]:

# could be made nicer
def getParameterVector(params):
    theta = []
    for w in params:
        for e in w:
            if (e.size > 1):
                for i in e:
                    theta.append(i)
            else:
                theta.append(e)
    return np.array(theta)

def reg(params):
    theta = getParameterVector(params)
    return np.sum(np.square(theta))

def accuracy(params, u, y_true):
    true = np.argmax(y_true, axis=1)
    pred = np.argmax(batch_forward_bp(params, u), axis=1)
    return np.mean(pred == true)


$$
\mathcal{R}^{\text{emp}}(\theta) = \frac{1}{N} \sum^{N}_{i=1} L(\hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}}) + r^2 \; \text{reg}(\theta)
$$

In [885]:
def loss(params, u, y_true, r):
    y_hat = batch_forward_bp(params, u)
    return np.square(np.linalg.norm(np.subtract(y_hat, y_true))) + (r*r)*reg(params)

$$
    \theta^{(n+1)} = \theta^{(n)} - \mu \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}),
$$

$$
   \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}) = 
\bigg(\frac{\partial  R^{emp}}{\partial  w_1}(\theta^{(n)}), ...,\frac{\partial  R^{emp}}{\partial w_D}(\theta^{(n)}) \bigg)',
$$

In [886]:
@jax.jit
def update(params, u, y_true, r, step_size=1e-2):
    grads = jax.grad(loss)(params, u, y_true, r)
    return [
        w - step_size * dw
        for w, dw in zip(params, grads)
    ]

In [887]:

"""
u_train, y_train = S[0]
u_batch = []
u_batch.append(u_train)
y_batch = []
y_batch.append(y_train)
u_batch = np.array(u_batch)
y_batch = np.array(y_batch)
print(u_train.shape)
print(y_train.shape)
print(accuracy(params, u_batch, y_batch))
"""


'\nu_train, y_train = S[0]\nu_batch = []\nu_batch.append(u_train)\ny_batch = []\ny_batch.append(y_train)\nu_batch = np.array(u_batch)\ny_batch = np.array(y_batch)\nprint(u_train.shape)\nprint(y_train.shape)\nprint(accuracy(params, u_batch, y_batch))\n'

$$
    \mathcal{A}(S) = \theta_{\text{opt}} = \underset{\theta \in \Theta}{\text{argmin}} \; \underbrace{\frac{1}{N} \sum^N_{i=1} L(\hat{\mathbf{Y}}_{\theta}^{\text{train}}, \mathbf{Y}^{\text{train}})}_{\mathcal{R}^{\text{emp}}(\theta)},
$$
$$
    \theta^{(n+1)} = \theta^{(n)} - \mu \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}),
$$

## Cross Validation

In [888]:
# s is a list of tuples [(u1, y1), (u2, y2), ...]
def unpack(s):
    f1 = map(lambda x: x[0], s)
    f2 = map(lambda x: x[1], s)
    u_batch = np.array(list(f1))
    y_batch = np.array(list(f2))
    return (u_batch, y_batch)

def train(params, u_train, y_train, u_test, y_test, r, n_epochs=1):
    train_acc = []
    test_acc = []
    train_loss = []
    test_loss = []
    for epoch in range(n_epochs):
        start_time = time.time()

        params = update(params, u_train, y_train, r)
        train_acc.append(accuracy(params, u_train, y_train))
        test_acc.append(accuracy(params, u_test, y_test))
        train_loss.append(loss(params, u_train, y_train, r))
        test_loss.append(loss(params, u_test, y_test, r))

        epoch_time = time.time() - start_time

        print(f'Epoch {epoch+1:>2} ({epoch_time:<.2f}s): ', end='')
        print(f'train loss {train_loss[-1]:<5.2f} test loss {test_loss[-1]:<5.2f}', end='| ')
        print(f'train acc {train_acc[-1]:<7.2%} test acc {test_acc[-1]:<7.2%}')
    
    return params

def train2(params, u_train, y_train, r, n_epochs=1):
    train_acc = []
    train_loss = []
    for epoch in range(n_epochs):
        start_time = time.time()

        params = update(params, u_train, y_train, r)
        train_acc.append(accuracy(params, u_train, y_train))
        train_loss.append(loss(params, u_train, y_train, r))

        epoch_time = time.time() - start_time

        print(f'Epoch {epoch+1:>2} ({epoch_time:<.2f}s): ', end='')
    return params


# given S, params, loss/emprical risk, r (regularization)
# Hyperparameters
step_size=1e-2
k = 2   # k fold cross validation

validation_risk_r = []
for r in range(0, 3):
    validation_risk = []
    for j in range(0, k):
        V = [S.pop(j)]  # validation set
        T = S         # reduced training set
        S.insert(j, V[0])
        u_train, y_train = unpack(T)
        u_val, y_val = unpack(V)
        params = train(params, u_train, y_train, u_val, y_val, r)
        validation_risk.append(loss(params, u_val, y_val, r))
    validation_risk_r.append(np.mean(np.array(validation_risk)))

r_opt = np.argmin(np.array(validation_risk_r))
u_train, y_train = unpack(V)
params = train2(params, u_train, y_train, r_opt)



Epoch  1 (4.97s): train loss 110.39 test loss 27.60| train acc 0.00%   test acc 0.00%  
Epoch  1 (4.91s): train loss 110.38 test loss 27.60| train acc 0.00%   test acc 0.00%  
Epoch  1 (2.84s): train loss 152.69 test loss 69.93| train acc 0.00%   test acc 0.00%  
Epoch  1 (2.88s): train loss 150.99 test loss 68.23| train acc 0.00%   test acc 0.00%  
Epoch  1 (2.82s): train loss 248.74 test loss 165.41| train acc 0.00%   test acc 0.00%  
Epoch  1 (2.88s): train loss 229.54 test loss 146.35| train acc 0.00%   test acc 0.00%  
Epoch  1 (4.03s): 

[DeviceArray([[-0.00119001,  0.01259624, -0.00131898,  0.00153466,
                0.01790148, -0.00584567, -0.00977399],
              [-0.00944623, -0.00674881, -0.0020011 ,  0.01082675,
               -0.00501226, -0.00223849, -0.00982482],
              [-0.00129515,  0.00074726, -0.00435336,  0.00055631,
               -0.00723396, -0.00704851, -0.01171278],
              [ 0.00877953,  0.00424988, -0.00521509, -0.00389891,
                0.01553905,  0.0143659 ,  0.01274391],
              [ 0.01746623, -0.0012544 ,  0.01016922,  0.01293689,
               -0.01183019,  0.01439269, -0.00593833],
              [ 0.00969443, -0.00107673,  0.00478851, -0.01253107,
                0.00870384, -0.003004  ,  0.00622379],
              [-0.01348182, -0.00261646,  0.00646225, -0.00182024,
                0.01281928,  0.00430677,  0.00874748],
              [ 0.01025784, -0.00033019,  0.00027282, -0.00199228,
                0.00259043,  0.00282665,  0.00316758],
              [ 0.009805