# RNN in Jax

In [22]:
import jax
import jax.numpy as np
import math
import numpy as onp
import tensorflow as tf
import time
import matplotlib.pyplot as plt
from jax import grad, vmap, jit
import jax
import jax.numpy as np
import pandas as pd
import itertools
import functools
import os
from shutil import copy
from shutil import move

In [23]:
# Root PRNG key (seed 1); handed to init_network further down to draw the initial weights.
key = jax.random.PRNGKey(1)
# When False, the cross-validation / training cell guarded by this flag is skipped.
TRAINING_ON = False

## Pre- and Postprocessing
$$
\mathcal{D} : \text{CSV} \rightarrow [0,1]^{K \times n_i}
$$

In [24]:
def D(df):
    """Convert a parsed midicsv note table into a dense per-tick signal.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the integer columns "time" (tick index), "notes"
        (MIDI note number; one of 36, 38, 41, 42, 43, 45, 82) and
        "velocity".

    Returns
    -------
    (u, y) : tuple of jax arrays, each of shape (max(time) + 1, 7)
        u[n] holds velocity / 127 for every note struck at tick n
        (several hits on the same note and tick are summed). y is u
        shifted one tick ahead and padded with a zero row, i.e.
        y[n] = u[n + 1].

    Raises
    ------
    ValueError
        If df contains no note events.
    KeyError
        If a note number is not part of the drum map.
    """
    # MIDI note number -> channel index in the 7-dimensional signal.
    h = {36: 0, 38: 1, 41: 2, 42: 3, 43: 4, 45: 5, 82: 6}

    # Named `ticks` so the imported `time` module is not shadowed.
    ticks = df["time"].values
    notes = df["notes"].values
    velocity = df["velocity"].values
    if len(ticks) == 0:
        raise ValueError("D() needs at least one note event")

    channels = onp.array([h[note] for note in notes])
    # Cast before accumulating so the sums are carried out in float32,
    # the dtype of the returned arrays.
    amplitudes = (velocity / 127).astype(onp.float32)

    # Accumulate on the host in a single pass. Going through jax arrays here
    # would cost one device dispatch (plus a blocking check) per note event.
    # add.at sums repeated (tick, channel) pairs instead of overwriting them.
    u = onp.zeros((int(ticks.max()) + 1, len(h)), dtype=onp.float32)
    onp.add.at(u, (ticks, channels), amplitudes)

    # Teacher signal: the input one tick later, zero after the last tick.
    y_train = onp.concatenate([u[1:], onp.zeros((1, len(h)), dtype=onp.float32)])
    return np.asarray(u), np.asarray(y_train)

$$
  \mathcal{D}^{-1} : [0,1]^{K \times n_i} \rightarrow \text{CSV}
$$

In [25]:
def D_inv(y):
    """Convert a dense per-tick signal back into a midicsv-style table.

    Parameters
    ----------
    y : array-like of shape (n_ticks, 7)
        Signal with values nominally in [0, 1]; y[n][i] * 127 (truncated)
        is the velocity of drum channel i at tick n.

    Returns
    -------
    pandas.DataFrame
        Columns "track", "time", "note_is_played", "channel", "notes",
        "velocity", "filler". Rows are: the fixed file/track header, one
        "Note_on_c" row per entry with velocity > 0 (in tick order, then
        channel order), and the End_track / End_of_file footer. End_track is
        placed 10 ticks after the last note, or at tick 10 if there is none.
    """
    # Channel index in the signal -> MIDI note number.
    h_inv = {0: 36, 1: 38, 2: 41, 3: 42, 4: 43, 5: 45, 6: 82}
    columns = ["track", "time", "note_is_played",
               "channel", "notes", "velocity", "filler"]

    # One host copy and one vectorised pass instead of a device round-trip
    # per element.
    signal = onp.asarray(y)
    if signal.size:
        # astype truncates toward zero like int(); 127 is the MIDI maximum,
        # so larger values (e.g. summed simultaneous hits) are clipped.
        velocities = onp.minimum((signal * 127).astype(onp.int64), 127)
        steps, channels = onp.nonzero(velocities > 0)  # row-major: tick, then channel
    else:
        velocities = onp.zeros((0, len(h_inv)), dtype=onp.int64)
        steps = channels = onp.zeros(0, dtype=onp.int64)

    # track 2 is arbitrary; channel 9 is the drum channel.
    note_rows = [
        [2, int(n), "Note_on_c", 9, h_inv[int(i)], int(velocities[n, i]), '']
        for n, i in zip(steps, channels)
    ]

    # An all-zero signal yields no notes; still emit a valid, empty track.
    latestTime = (note_rows[-1][1] if note_rows else 0) + 10

    pre = [[0, 0, "Header", 1, 2, 480, ''],
           [1, 0, "Start_track", '', '', '', ''],
           [1, 0, "Time_signature", 4, 2, 24, 8],
           [1, 0, "Title_t", "\"from model\"", '', '', ''],
           [1, 0, "End_track", '', '', '', ''],
           [2, 0, "Start_track", '', '', '', '']]

    post = [[2, latestTime, "End_track", '', '', '', ''],
            [0, 0, "End_of_file", '', '', '', '']]

    # Build the frame once, header first. Writing the header with .loc[i]
    # into an already-filled frame would overwrite the first note rows.
    return pd.DataFrame(pre + note_rows + post, columns=columns)

## Data Loading

$$
    \mathcal{M} : \text{MIDI} \rightarrow \text{CSV}
$$

In [26]:
"""
Copies the MIDI files over to the folder
where they can be converted.

"""
# Remember the notebook directory; later cells use relative paths from here.
cwd = os.getcwd()

try:
    os.chdir('../midi/NetworkInputMIDI')

    # Stage every input file next to the midicsv executable.
    fileNames = [m for m in os.listdir() if os.path.isfile(m)]
    for m in fileNames:
        copy(m, '../midicsv-1.1')

    os.chdir('../midicsv-1.1')
    for m in fileNames:
        # splitext copes with any extension length (".mid", ".midi", ...).
        csv_name = os.path.splitext(m)[0] + ".csv"
        # Quote both paths so file names containing spaces survive the shell.
        command = f'midicsv "{m}" "{csv_name}"'
        res = os.system(command)
        os.remove(m)

        if res != 0:
            # Do not ship a missing or partial conversion to the input folder.
            print(f"midicsv failed for {m} (exit status {res})")
            if os.path.exists(csv_name):
                os.remove(csv_name)
            continue

        try:
            move(csv_name, '../NetworkInputCSV')
        except OSError:
            # shutil.move raises shutil.Error (an OSError) when the target
            # already holds a file of that name; keep the existing one.
            print("csv file already there")
            os.remove(csv_name)
finally:
    # ensures that the cwd resets, even when a step above fails
    os.chdir(cwd)

csv file already there
csv file already there
csv file already there
csv file already there
csv file already there
csv file already there
csv file already there
csv file already there


Now import the CSV files as pandas DataFrames.

In [27]:
os.chdir('../midi/NetworkInputCSV')
try:
    print(os.getcwd())
    S = []  # S will be a list of tuples
            # each tuple contains two input signals over time
    columnNames = ["track", "time", "note_is_played"
                                , "channel", "notes", "velocity"]

    # The single demo sequence is entered this many times, so that S has
    # more than one element (presumably so the k-fold split has something
    # to divide; confirm).
    N_COPIES = 4

    # Loop through all the csv files in the network input folder.
    for csv_name in os.listdir():
        # for now: generalize later
        if csv_name == "drumDemo.csv":
            df = pd.read_csv(csv_name, skipfooter=2, skiprows=8, engine='python', names=columnNames)
            u_train, y_train = D(df)
            # The arrays are immutable, so sharing one tuple is equivalent
            # to parsing the file N_COPIES times.
            S.extend([(u_train, y_train)] * N_COPIES)
finally:
    # Restore the notebook directory even if reading or parsing fails.
    os.chdir(cwd)

if not S:
    raise FileNotFoundError("drumDemo.csv was not found in ../midi/NetworkInputCSV")

print(type(S))
u_ex, y_ex = S[0]
print(u_ex.shape)
print(y_ex.shape)

c:\Users\Matth\OneDrive\Documenten\NN\semester-project\midi\NetworkInputCSV
<class 'list'>
(7590, 7)
(7590, 7)


# Network Architecture

An RNN (with no LSTM units) is a neural network with recurrent connections, i.e. a dynamical system:
$$
\begin{array}{l l}
          \mathbf{x}(n+1) = \sigma(W \mathbf{x}(n) + W^{in}\mathbf{u}(n+1) + \mathbf{b}) \\
          \mathbf{y}(n) = f(W^{out} \mathbf{x}(n)),
\end{array}
$$
These equations describe how the network activation state is updated and how the output signal is generated.

Input vector $\mathbf{u}(n) \in [0,1]^K$

Activation/state vector $\mathbf{x}(n) \in \mathbb{R}^L$

Output vector $\mathbf{y}(n) \in \mathbb{R}^K$

Bias vector $\mathbf{b} \in \mathbb{R}^L$

$W^{in} \in \mathbb{R}^{L \times K}, W \in \mathbb{R}^{L \times L}, W^{out} \in \mathbb{R}^{K \times L}$ are weight matrices characterizing the connections between neurons in the layers.


## Initialization
At time $n = 0$ the recurrent network state $\mathbf{x}(0)$ is often set to the zero vector $\mathbf{x}(0) = \mathbf{0}$

In [28]:
def init_weight_matrix(in_dim, out_dim, key, scale=1e-2):
    """Return an (out_dim, in_dim) matrix of standard-normal draws times `scale`."""
    shape = (out_dim, in_dim)
    return scale * jax.random.normal(key, shape)

In [29]:
def init_bias(dim, key, scale=1e-2):
    """Return a length-`dim` vector of standard-normal draws times `scale`."""
    shape = (dim, )
    return scale * jax.random.normal(key, shape)

In [30]:
# state, params = (W^{in}, W, W^{out}, b)
# sizes = (input dim, state dim, output dim.)
def init_network(sizes, key):
    """Draw the initial RNN parameters.

    Parameters
    ----------
    sizes : sequence
        sizes[0] is the input dimension, sizes[1] the state dimension and
        sizes[2] the output dimension.
    key : jax PRNG key
        Split into four sub-keys, one per parameter array.

    Returns
    -------
    (Win, W, Wout, b) : tuple
        Shapes (state, input), (state, state), (output, state), (state,).
    """
    # Four arrays need four independent keys. Splitting into len(sizes) == 3
    # keys and reading index 3 is out of range; JAX clamps such an index,
    # which would make Wout and b share one key.
    k_in, k_rec, k_bias, k_out = jax.random.split(key, 4)

    Win = init_weight_matrix(sizes[0], sizes[1], k_in)
    W = init_weight_matrix(sizes[1], sizes[1], k_rec)
    Wout = init_weight_matrix(sizes[1], sizes[2], k_out)
    b = init_bias(sizes[1], k_bias)

    return (Win, W, Wout, b)


In [31]:
K = 7 # K := input and output vector dim
L = 125 # Reservoir or State Vector dim
sizes = [K, L, K]
params = init_network(sizes, key)
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
# alias is deprecated and no longer available in newer JAX releases.
print(jax.tree_util.tree_map(lambda x: x.shape, params)) # printing shape of network

((125, 7), (125, 125), (7, 125), (125,))


## Training


In [32]:
# x_og is initially the zero vector
def forward_bp(params, u, x_og=None):
    r"""Run the network over one input sequence and return every output.

    Parameters
    ----------
    params : (Win, W, Wout, b)
        Win: (L, K), W: (L, L), Wout: (K_out, L), b: (L,).
    u : array of shape (n_steps, K)
        Input sequence u[n], with u[n] \in [0, 1]^K.
    x_og : array of shape (L,), optional
        Initial state x(0) \in R^L. Defaults to the zero vector whose length
        is taken from W, so the function follows the parameters it is given
        instead of the global L at definition time.

    Returns
    -------
    Y : array of shape (n_steps, K_out)
        Y[n] = sigmoid(Wout x(n+1)), where
        x(n+1) = relu(Win u[n] + W x(n) + b).
    """
    Win, W, Wout, b = params
    if x_og is None:
        x_og = np.zeros((W.shape[0], ), dtype=W.dtype)

    def step(x, ut):
        """Single state update; x: (L,), ut: (K,) -> (new x, output y)."""
        x = jax.nn.relu(
            np.dot(Win, ut) + np.dot(W, x) + b
        )
        y = jax.nn.sigmoid(np.dot(Wout, x))
        return x, y

    # scan threads the state through time and stacks the per-step outputs.
    _, Y = jax.lax.scan(step, x_og, u)
    return Y

# Map over a leading batch axis of u; the parameters are shared.
batch_forward_bp = jax.vmap(forward_bp, in_axes=(None, 0))

## Loss function
Logistic regression applies when loudness is not modelled, because the task is then categorical (to hit or not to hit). When loudness is modelled, the task is a regression problem, so the loss functions mentioned in the RNN section of the reader, such as the MSE or quadratic loss, are used.

Time series prediction task $S = (\mathbf{u}^{(i)}(n), \mathbf{y}^{(i)}(n))_{i=1, ..., N;n=1, ..., n_i}$, where $\mathbf{y}^{(i)}(n) = \mathbf{u}^{(i)}(n+1)$.

For now: quadratic loss, which is used in stationary tasks.
$$
    L(\hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}}) = \parallel \hat{\mathbf{Y}}_{i, \theta}^{\text{train}} - \mathbf{Y}_i^{\text{train}} \parallel^2
$$

Regularization
$$
\text{reg}(\theta) = \sum_{w \in \theta} w^2
$$

In [33]:

def getParameterVector(params):
    """Flatten all parameter arrays into one 1-D vector theta.

    Parameters
    ----------
    params : sequence of arrays
        Arrays of any shape, e.g. (Win, W, Wout, b).

    Returns
    -------
    theta : 1-D array
        The entries of every array in row-major order, concatenated in the
        order the arrays appear in `params`. Empty if `params` is empty.
    """
    # ravel + concatenate is a handful of array ops. Looping over single
    # elements would add one traced op per weight under jit/grad, and rows
    # of size 1 would come out with a different shape than scalars.
    flat = [np.ravel(w) for w in params]
    if not flat:
        return np.zeros((0, ))
    return np.concatenate(flat)

def reg(params):
    """Regularizer: the sum of the squares of all entries of `params`."""
    flat = getParameterVector(params)
    squared = np.square(flat)
    return np.sum(squared)

$$
\mathcal{R}^{\text{emp}}(\theta) = \frac{1}{N} \sum^{N}_{i=1} L(\hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}}) + r^2 \; \text{reg}(\theta)
$$

In [34]:
def loss(params, u, y_true, alpha):
    """Regularized quadratic loss over a batch of sequences.

    Parameters
    ----------
    params : (Win, W, Wout, b)
    u : batch of input sequences, passed to batch_forward_bp.
    y_true : teacher outputs with the same shape as the network output.
    alpha : regularization strength; the penalty weight is alpha ** 2.

    Returns
    -------
    Scalar: sum of squared errors over every sequence, time step and
    channel, plus alpha**2 * reg(params).
    """
    y_hat = batch_forward_bp(params, u)
    residual = y_hat - y_true
    # Sum the squares directly. Squaring the norm goes through a sqrt, whose
    # derivative is infinite at 0, so the gradient turns NaN on an exact fit.
    # NOTE(review): the accompanying formula averages over the N sequences;
    # this sums over them. Confirm which is intended.
    return np.sum(np.square(residual)) + (alpha*alpha)*reg(params)

$$
    \theta^{(n+1)} = \theta^{(n)} - \mu \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}),
$$

$$
   \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}) = 
\bigg(\frac{\partial \mathcal{R}^{\text{emp}}}{\partial w_1}(\theta^{(n)}), ..., \frac{\partial \mathcal{R}^{\text{emp}}}{\partial w_D}(\theta^{(n)}) \bigg)',
$$

In [35]:
@jax.jit
def update(params, u, y_true, alpha, step_size=1e-2):
    """One gradient-descent step on `loss`.

    Parameters
    ----------
    params : pytree of arrays, e.g. the tuple (Win, W, Wout, b).
    u, y_true, alpha : forwarded unchanged to `loss`.
    step_size : learning rate mu.

    Returns
    -------
    New parameters w - step_size * dloss/dw, in the same container
    structure as `params`.
    """
    grads = jax.grad(loss)(params, u, y_true, alpha)
    # tree_map keeps the container type. Returning a list for tuple input
    # changes the pytree structure, so the next call would retrace and
    # recompile this function.
    return jax.tree_util.tree_map(
        lambda w, dw: w - step_size * dw, params, grads
    )

$$
    \mathcal{A}(S) = \theta_{\text{opt}} = \underset{\theta \in \Theta}{\text{argmin}} \; \underbrace{\frac{1}{N} \sum^N_{i=1} L(\hat{\mathbf{Y}}_{i, \theta}^{\text{train}}, \mathbf{Y}_i^{\text{train}})}_{\mathcal{R}^{\text{emp}}(\theta)},
$$
$$
    \theta^{(n+1)} = \theta^{(n)} - \mu \nabla \mathcal{R}^{\text{emp}}(\theta^{(n)}),
$$

## Cross Validation

In [36]:

"""
s is a list of tuples 
[(u1, y1), (u2, y2), ...] -> ((u_1, u_2, ...), (y_1, y_2, ...))
"""
def unpack(s):
    """Turn [(u1, y1), (u2, y2), ...] into (array of all u's, array of all y's).

    Each element of `s` is indexed at positions 0 and 1; the collected items
    are stacked along a new leading axis by np.array.
    """
    inputs = [pair[0] for pair in s]
    targets = [pair[1] for pair in s]
    return (np.array(inputs), np.array(targets))

def train(params, u_train, y_train, u_test, y_test, alpha, n_epochs=2):
    """Run `n_epochs` update steps and print train/test loss after each one.

    Every epoch calls update() once on the training batch, then evaluates
    loss() on the training and on the test batch with the new parameters.
    Returns the parameters after the last epoch.
    """
    train_loss, test_loss = [], []
    for epoch_no in range(1, n_epochs + 1):
        tic = time.time()

        params = update(params, u_train, y_train, alpha)
        train_loss.append(loss(params, u_train, y_train, alpha))
        test_loss.append(loss(params, u_test, y_test, alpha))

        elapsed = time.time() - tic

        header = f'Epoch {epoch_no:>2} ({elapsed:<.2f}s): '
        losses = f'train loss {train_loss[-1]:<5.2f} test loss {test_loss[-1]:<5.2f}'
        print(header + losses, end='| ')

    return params

def train2(params, u_train, y_train, r, n_epochs=1):
    """Run `n_epochs` update steps without any validation data.

    Parameters
    ----------
    params : current parameters, passed to update().
    u_train, y_train : training batch, passed to update().
    r : regularization strength, passed to update() as alpha.
    n_epochs : number of update steps.

    Returns
    -------
    The parameters after the last epoch.
    """
    for epoch in range(n_epochs):
        start_time = time.time()

        # No loss evaluation here: its value was never reported or returned,
        # so it only cost one extra forward pass per epoch.
        params = update(params, u_train, y_train, r)

        epoch_time = time.time() - start_time

        print(f'Epoch {epoch+1:>2} ({epoch_time:<.2f}s): ', end='')
    return params

if TRAINING_ON:
    # given S, params, loss/empirical risk, r (regularization)
    # Hyperparameters
    step_size = 1e-2    # NOTE: not forwarded; train()/update() use their own default (also 1e-2)
    k = 4               # k fold cross validation
    # r is the hyperparameter (here r = alpha for the regularization)
    r_candidates = [0, 1]
    validation_risk_r = []

    # Every fold needs at least one sequence, and the reduced training set
    # must not be empty.
    if not 2 <= k <= len(S):
        raise ValueError(
            f"k-fold cross validation needs 2 <= k <= len(S); got k={k}, len(S)={len(S)}"
        )

    # split S into exactly k disjoint, contiguous subsets; a remainder goes
    # into the last one
    n = len(S) // k
    S_k = [S[i * n : (i + 1) * n] for i in range(k - 1)] + [S[(k - 1) * n :]]

    for r in r_candidates:
        validation_risk = []
        for j in range(k):
            V = S_k[j]                                          # validation
            T = [x for i, fold in enumerate(S_k) if i != j
                   for x in fold]                               # reduced training set
            u_train, y_train = unpack(T)
            u_val, y_val = unpack(V)
            print(u_train.shape)
            print(u_val.shape)
            # Start every fold from the same fresh initialization. Carrying
            # params over would let a model that already trained on fold j
            # be validated on fold j.
            fold_params = train(init_network(sizes, key),
                                u_train, y_train, u_val, y_val, r)
            # Score without the penalty term so that different values of r
            # are compared on prediction error alone.
            validation_risk.append(loss(fold_params, u_val, y_val, 0))
        validation_risk_r.append(np.mean(np.array(validation_risk)))

    # argmin yields a position in r_candidates, not the value of r itself.
    r_opt = r_candidates[int(np.argmin(np.array(validation_risk_r)))]
    # Final fit on the complete data set, not on the last validation fold.
    u_train, y_train = unpack(S)
    params = train2(init_network(sizes, key), u_train, y_train, r_opt)






### Parameter Saving

## Music Generation
Get the file to prime the network with

In [37]:
def apply_fun_scan(params, x, un):
    """Advance the network by one time step.

    params: (Win, W, Wout, b)
    x:      (L, ) state at step n
    un:     (K, ) input
    Returns (state at step n+1, output computed from that new state).
    """
    Win, W, Wout, b = params
    pre_activation = np.dot(Win, un) + np.dot(W, x) + b
    x_next = jax.nn.relu(pre_activation)
    readout = np.dot(Wout, x_next)
    return x_next, jax.nn.sigmoid(readout)

# NOTE(review): this relative path is resolved against the notebook directory
# (the cwd is reset after loading); confirm a copy of drumDemo.csv lives there.
df1 = pd.read_csv('drumDemo.csv', skipfooter=2, skiprows=8, engine='python', names= columnNames)
u_prime = D(df1)[0]
# Never prime past the end of the file: JAX clamps out-of-range indices, so
# a too-large n_stop would silently repeat the last frame.
n_stop = min(1000, len(u_prime))
n_output = 8000

# Compile the step once; the loops below call it n_stop + n_output times.
step = jax.jit(apply_fun_scan)

y_signal = []
x = np.zeros((L,))
# Priming phase: drive the network with the recorded input.
for n in range(n_stop):
    x, y = step(params, x, u_prime[n])
    y_signal.append(y)

# Free-running phase: feed every output back in as the next input.
for n in range(n_output):
    x, y = step(params, x, y)
    y_signal.append(y)

# An untrained network may produce a signal that is (almost) zero everywhere,
# in which case D_inv finds no note events to write.
dfo = D_inv(np.array(y_signal))

$$
\mathcal{M}^{-1} : \text{CSV} \rightarrow \text{MIDI}
$$

In [43]:
"""
Puts generated CSV file in correct folder
And generates the MIDI file and puts that in the correct
folder as well.
"""
# for debugging: checking highest amplitude in output signal
"""
v = -99999999999999
for e in y_signal:
    if e.max() >= v:
        v = e.max()
print(v)
"""

try:
    os.chdir('../midi/NetworkOutputCSV')
    name = "rnn.csv"
    midi_name = os.path.splitext(name)[0] + ".mid"
    dfo.to_csv(name, index = False, header = False)
    copy(name, '../midicsv-1.1')
    os.chdir('../midicsv-1.1')

    command = "csvmidi" + " " + name + " " + midi_name
    res = os.system(command)
    print(res)
    os.remove(name)

    if res == 0:
        try:
            move(midi_name, '../NetworkOutputMIDI')
        except OSError:
            # shutil.move raises shutil.Error (an OSError) when the output
            # folder already holds a file of that name; the existing file is
            # kept and the new one stays in the midicsv folder.
            print("midi file already here")
            #os.remove(midi_name)
    else:
        # Nothing usable to move when the converter reports a failure.
        print(f"csvmidi failed (exit status {res}); no MIDI file was moved")
finally:
    # Always return to the notebook directory, even if a step above fails.
    os.chdir(cwd)




0
midi file already here
