In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from glob import glob
import librosa
import librosa.display
from sklearn.model_selection import train_test_split
import json

import jax
from jax import numpy as jnp, jit, grad, Array
from jax.typing import ArrayLike
import flax
from flax import linen as nn
from flax.training import train_state
from flax.core import FrozenDict
import optax

### Data Preproscessing
- Load data
- Convert raw audio to spectrograms
- Get notes at every timestep of each spectrogram
- Truncate data so all sequences are of equal length
- Split data into train and test sets
- Load proscessed data into TensorFlow datasets

In [2]:
data_files = glob('../musicnet/musicnet/*/*.wav')
label_files = glob('../musicnet/musicnet/*/*.csv')

In [3]:
sr = 22050
hop_length = 512
n_mels = 512

def wav_to_mel_spec(path: str) -> np.ndarray:
  y, _ = librosa.load(path)
  spec = librosa.feature.melspectrogram(y=y, sr=sr, hop_length=hop_length, n_mels=n_mels)
  return librosa.amplitude_to_db(spec, ref=np.max).T

In [4]:
data = {path[-8:-4]: wav_to_mel_spec(path) for path in data_files}

In [5]:
min_note = 21
max_note = 104
num_notes = max_note - min_note + 1

labels = {}

for path in label_files:
  key = path[-8:-4]
  df = pd.read_csv(path)
  label_mat = np.zeros((len(data[key]), num_notes), np.float32)

  for row in df.itertuples():
    note = row.note
    start = row.start_time // 1024
    end = row.end_time // 1024
    label_mat[start:end, note - min_note] = 1

  labels[key] = label_mat

In [6]:
keys = sorted(data.keys())

In [7]:
truncated_len = 512

truncated_data = []
truncated_labels = []

for key in keys:
  x = data[key]
  y = labels[key]
  for i in range(0, x.shape[0] - truncated_len + 1, truncated_len):
    truncated_data.append(x[i:i + truncated_len])
    truncated_labels.append(y[i:i + truncated_len])

truncated_data = np.array(truncated_data)
truncated_labels = np.array(truncated_labels)

In [8]:
x_train, x_test, y_train, y_test = train_test_split(truncated_data, truncated_labels)

In [9]:
batch_size = 16
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size=batch_size)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size=batch_size)

### LSTM Model

In [21]:
class LSTM(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    ScanLSTM = nn.scan(
      nn.LSTMCell, 
      variable_broadcast='params',
      split_rngs={'params': False}, 
      in_axes=1, 
      out_axes=1,
    )

    lstm = ScanLSTM(features=128)
    carry = lstm.initialize_carry(jax.random.key(0), x[:, 0].shape)
    carry, x = lstm(carry, x)

    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.features)(x)
    return x

In [22]:
TrainState = train_state.TrainState

def create_train_state(model: LSTM, rng_key: Array, learning_rate: float) -> TrainState:
  params = model.init(rng_key, x=x_train[:1])['params']
  tx = optax.adam(learning_rate=learning_rate)
  return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [23]:
@jit
def train_step(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> TrainState:

  def loss_fn(params: FrozenDict) -> Array:
    logits = state.apply_fn({'params': params}, x=batch[0])
    loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=batch[1]).mean()
    return loss
  
  grad_fn = grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

In [27]:
@jit
def compute_metrics(state: TrainState, batch: ArrayLike) -> tuple[float, float]:
  logits = state.apply_fn({'params': state.params}, x=batch[0])
  loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=batch[1]).mean()
  preds = jnp.round(nn.sigmoid(logits))
  acc = jnp.mean(preds == batch[1])
  return loss, acc

In [32]:
model = LSTM(num_notes)
state = create_train_state(model, jax.random.PRNGKey(0), learning_rate=1e-4)

In [33]:
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []

for batch in train_ds.as_numpy_iterator():
  loss, acc = compute_metrics(state, batch)
  train_loss_list.append(loss)
  train_acc_list.append(acc)

for batch in test_ds.as_numpy_iterator():
  loss, acc = compute_metrics(state, batch)
  test_loss_list.append(loss)
  test_acc_list.append(acc)

train_loss = sum(train_loss_list) / len(train_loss_list)
train_acc = sum(train_acc_list) / len(train_acc_list)
test_loss = sum(test_loss_list) / len(test_loss_list)
test_acc = sum(test_acc_list) / len(test_acc_list)

print(
  f'train loss: {train_loss},', 
  f'train acc: {train_acc},', 
  f'test loss: {test_loss},',
  f'test acc: {test_acc},',
)

train loss: 0.7163073420524597, train acc: 0.49891921877861023, test loss: 0.7164062857627869, test acc: 0.4987044334411621,


In [34]:
num_epochs = 5

for epoch in range(num_epochs):
  for batch in train_ds.as_numpy_iterator():
    state = train_step(state, batch)
  
  train_loss_list = []
  train_acc_list = []
  test_loss_list = []
  test_acc_list = []

  for batch in train_ds.as_numpy_iterator():
    loss, acc = compute_metrics(state, batch)
    train_loss_list.append(loss)
    train_acc_list.append(acc)

  for batch in test_ds.as_numpy_iterator():
    loss, acc = compute_metrics(state, batch)
    test_loss_list.append(loss)
    test_acc_list.append(acc)
  
  train_loss = sum(train_loss_list) / len(train_loss_list)
  train_acc = sum(train_acc_list) / len(train_acc_list)
  test_loss = sum(test_loss_list) / len(test_loss_list)
  test_acc = sum(test_acc_list) / len(test_acc_list)

  print(
    f'[epoch {epoch + 1}]', 
    f'train loss: {train_loss},', 
    f'train acc: {train_acc},', 
    f'test loss: {test_loss},',
    f'test acc: {test_acc},',
  )

[epoch 1] train loss: 0.14465948939323425, train acc: 0.9632760286331177, test loss: 0.14487622678279877, test acc: 0.9631381034851074,
[epoch 2] train loss: 0.1396123170852661, train acc: 0.9632760286331177, test loss: 0.1398719996213913, test acc: 0.9631381034851074,
[epoch 3] train loss: 0.1391063928604126, train acc: 0.9632760286331177, test loss: 0.1393844038248062, test acc: 0.9631381034851074,
[epoch 4] train loss: 0.13879244029521942, train acc: 0.9632760286331177, test loss: 0.13907095789909363, test acc: 0.9631381034851074,
[epoch 5] train loss: 0.13887061178684235, train acc: 0.9632760286331177, test loss: 0.13914521038532257, test acc: 0.9631381034851074,
[epoch 6] train loss: 0.13883280754089355, train acc: 0.9632760286331177, test loss: 0.1391068994998932, test acc: 0.9631381034851074,
[epoch 7] train loss: 0.13881346583366394, train acc: 0.9632760286331177, test loss: 0.13908827304840088, test acc: 0.9631381034851074,
[epoch 8] train loss: 0.1387949436903, train acc: 0.9

In [35]:
def unfreeze_dict(frozen_dict: FrozenDict[any]) -> dict[any]:
  unfrozen_dict = {}
  for k, v in frozen_dict.items():
    if isinstance(v, FrozenDict) or isinstance(v, dict):
      unfrozen_dict[k] = unfreeze_dict(v)
    else:
      unfrozen_dict[k] = v.tolist()
  return unfrozen_dict

params_save_path = 'checkpoints/lstm-params.json'
params_dict = unfreeze_dict(state.params)
with open(params_save_path, 'w') as f:
  json.dump(params_dict, f)