In [None]:
import json
import os
from collections.abc import Mapping
from glob import glob
from typing import Any

import numpy as np
import pandas as pd
import tensorflow as tf
import librosa
import librosa.display
from sklearn.model_selection import train_test_split

import jax
from jax import numpy as jnp, jit, grad, Array
from jax.typing import ArrayLike
import flax
from flax import linen as nn
from flax.training import train_state
from flax.core import FrozenDict
import optax

### Data Preprocessing
- Load data
- Convert raw audio to log-mel spectrograms
- Get the active notes at every timestep (frame) of each spectrogram
- Split each recording into fixed-length windows so all sequences have equal length
- Split data into train and test sets
- Load processed data into TensorFlow datasets

In [None]:
# Sort both lists: glob's order is filesystem-dependent, and later cells build
# sequences by iterating `label_files` directly, so an unsorted order would make
# the dataset (and any seeded split of it) non-reproducible.
data_files = sorted(glob('../musicnet/musicnet/*/*.wav'))
label_files = sorted(glob('../musicnet/musicnet/*/*.csv'))

In [None]:
sr = 22050
hop_length = 512
n_mels = 512

def wav_to_mel_spec(path: str) -> np.ndarray:
  """Load an audio file and return its log-power mel spectrogram.

  Args:
    path: Path to an audio file readable by librosa.

  Returns:
    Array of shape (num_frames, n_mels), time on the leading axis. Each row
    covers `hop_length` samples at `sr`. Values are in dB relative to the
    loudest bin of this file (so the maximum is 0 dB).
  """
  # Pass sr explicitly. librosa.load's default (22050) only coincidentally
  # matched `sr`; without this, changing `sr` above would desync loading and
  # the mel filterbank.
  y, _ = librosa.load(path, sr=sr)
  # melspectrogram returns a *power* spectrogram (power=2.0 by default), so it
  # must be converted with power_to_db. amplitude_to_db would square it again,
  # doubling the dB scale and halving the effective dynamic range.
  spec = librosa.feature.melspectrogram(y=y, sr=sr, hop_length=hop_length, n_mels=n_mels)
  # Transpose so time is the leading axis, matching the (frames, notes) labels.
  return librosa.power_to_db(spec, ref=np.max).T

In [None]:
# Map each recording ID (the 4 characters before the '.wav' extension, e.g.
# '1727') to its log-mel spectrogram.
data = {}
for wav_path in data_files:
  recording_id = wav_path[-8:-4]
  data[recording_id] = wav_to_mel_spec(wav_path)

In [None]:
min_note = 21
max_note = 104
num_notes = max_note - min_note + 1

# Label start/end times are sample indices. The original divisor 1024 equals
# hop_length * 44100 / sr, i.e. it presumes the labels are at 44.1 kHz while the
# spectrograms were computed at `sr` (TODO confirm against MusicNet docs).
# Derive the factor so it stays correct if `sr` or `hop_length` change.
label_sr = 44100
label_samples_per_frame = hop_length * label_sr // sr

labels = {}

for path in label_files:
  key = path[-8:-4]
  df = pd.read_csv(path)
  # One row per spectrogram frame, one column per note in [min_note, max_note];
  # 1.0 marks a frame in which that note is sounding.
  label_mat = np.zeros((len(data[key]), num_notes), np.float32)

  for row in df.itertuples():
    note = row.note
    if not min_note <= note <= max_note:
      # Previously a note below min_note gave a negative column index and was
      # silently written into a wrapped-around (wrong) column.
      raise ValueError(f'{path}: note {note} outside [{min_note}, {max_note}]')
    start = row.start_time // label_samples_per_frame
    end = row.end_time // label_samples_per_frame
    label_mat[start:end, note - min_note] = 1

  labels[key] = label_mat

In [None]:
# Fixed (sorted) order of recording IDs so windows are built deterministically.
keys = sorted(data)

In [None]:
truncated_len = 512

# Cut every recording into non-overlapping windows of `truncated_len` frames.
# Trailing frames that don't fill a whole window are dropped, and recordings
# shorter than one window contribute nothing.
truncated_data = []
truncated_labels = []

for key in keys:
  spec, frame_labels = data[key], labels[key]
  last_start = spec.shape[0] - truncated_len
  window_starts = range(0, last_start + 1, truncated_len)
  truncated_data.extend(spec[s:s + truncated_len] for s in window_starts)
  truncated_labels.extend(frame_labels[s:s + truncated_len] for s in window_starts)

truncated_data = np.array(truncated_data)
truncated_labels = np.array(truncated_labels)

In [None]:
# Fixed seed so the train/test split is reproducible across kernel restarts.
# NOTE(review): windows are split independently, so windows from the same
# recording can appear in both train and test; a per-recording (grouped) split
# would give a less optimistic test score.
x1_train, x1_test, y1_train, y1_test = train_test_split(
  truncated_data, truncated_labels, random_state=0,
)

In [None]:
# Wrap each numpy split in a tf.data pipeline yielding batches of 16 windows.
train_ds1, test_ds1 = (
  tf.data.Dataset.from_tensor_slices(split).batch(batch_size=16)
  for split in ((x1_train, y1_train), (x1_test, y1_test))
)

### LSTM Model

In [None]:
class LSTM(nn.Module):
  """Sequence-to-sequence model: a scanned LSTM followed by a 2-layer MLP head.

  The LSTM is scanned over axis 1 of `x`, so inputs are expected as
  (batch, time, input_dim). The output keeps the (batch, time) leading axes and
  has `features` channels, i.e. one prediction per timestep.

  Attributes:
    features: Output size of the final Dense layer.
    lstm_features: Hidden/cell size of the LSTM (previously hard-coded to 128).
    dense_features: Width of the hidden Dense layer (previously hard-coded to 128).
  """
  features: int
  lstm_features: int = 128
  dense_features: int = 128

  @nn.compact
  def __call__(self, x: ArrayLike) -> Array:
    # Lift the single-step LSTMCell over the time axis. Parameters are shared
    # across all steps (broadcast) and the param RNG is not split per step.
    ScanLSTM = nn.scan(
      nn.LSTMCell,
      variable_broadcast='params',
      split_rngs={'params': False},
      in_axes=1,
      out_axes=1,
    )

    lstm = ScanLSTM(features=self.lstm_features)
    # x[:, 0] is a single timestep, (batch, input_dim); the carry shape is
    # derived from it. LSTMCell's default carry_init is zeros, so the fixed key
    # does not introduce randomness.
    carry = lstm.initialize_carry(jax.random.key(0), x[:, 0].shape)
    _, x = lstm(carry, x)

    x = nn.Dense(features=self.dense_features)(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.features)(x)
    return x

In [None]:
TrainState = train_state.TrainState

def create_train_state(model: LSTM, x: ArrayLike, rng_key: Array, learning_rate: float) -> TrainState:
  """Initialise `model` on the sample input `x` and wrap it in an Adam TrainState.

  `x` is only run through `model.init`, which uses it to infer parameter shapes.
  The returned state holds the fresh params, `model.apply` as its apply_fn,
  and an Adam optimizer with the given learning rate.
  """
  variables = model.init(rng_key, x=x)
  optimizer = optax.adam(learning_rate=learning_rate)
  return TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
  )

In [None]:
@jit
def train_step_bce(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> TrainState:
  """Apply one optimizer step on the mean sigmoid binary cross-entropy.

  `batch` is an (inputs, binary targets) pair. Returns the updated state.
  """
  inputs, targets = batch

  def loss_fn(params: FrozenDict) -> Array:
    logits = state.apply_fn({'params': params}, x=inputs)
    per_element = optax.sigmoid_binary_cross_entropy(logits=logits, labels=targets)
    return per_element.mean()

  return state.apply_gradients(grads=grad(loss_fn)(state.params))

In [None]:
@jit
def compute_metrics_bce(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> tuple[Array, Array]:
  """Evaluate `state` on one (inputs, binary targets) batch.

  Returns:
    (loss, acc), both 0-d JAX arrays:
    - loss: mean sigmoid binary cross-entropy over all elements.
    - acc: fraction of individual target cells whose rounded sigmoid prediction
      (1 iff p > 0.5) equals the target.

  NOTE: with sparse multi-hot targets, this per-cell accuracy can be high even
  for a model that predicts all zeros; read it alongside the loss.
  """
  inputs, targets = batch
  logits = state.apply_fn({'params': state.params}, x=inputs)
  loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=targets).mean()
  preds = jnp.round(nn.sigmoid(logits))
  acc = jnp.mean(preds == targets)
  return loss, acc

In [None]:
# Model 1: per-frame multi-label note prediction (one logit per note column).
model1 = LSTM(features=num_notes)
state1 = create_train_state(
  model1,
  x1_train[:1],  # a single example is enough to infer parameter shapes
  jax.random.PRNGKey(0),
  learning_rate=1e-4,
)

In [None]:
# Baseline: evaluate the untrained model so later epochs have a reference point.
train_loss_list, train_acc_list = [], []
test_loss_list, test_acc_list = [], []

for dataset, loss_list, acc_list in (
  (train_ds1, train_loss_list, train_acc_list),
  (test_ds1, test_loss_list, test_acc_list),
):
  for batch in dataset.as_numpy_iterator():
    loss, acc = compute_metrics_bce(state1, batch)
    loss_list.append(loss)
    acc_list.append(acc)

# Unweighted mean over batches (a short final batch counts as much as a full one).
train_loss, train_acc, test_loss, test_acc = (
  sum(values) / len(values)
  for values in (train_loss_list, train_acc_list, test_loss_list, test_acc_list)
)

print(
  f'train loss: {train_loss},',
  f'train acc: {train_acc},',
  f'test loss: {test_loss},',
  f'test acc: {test_acc},',
)

In [None]:
num_epochs = 3

for epoch in range(num_epochs):
  # Training pass: one optimizer update per batch.
  for batch in train_ds1.as_numpy_iterator():
    state1 = train_step_bce(state1, batch)

  # Evaluation pass on both splits with the updated parameters.
  train_loss_list, train_acc_list = [], []
  test_loss_list, test_acc_list = [], []
  for dataset, loss_list, acc_list in (
    (train_ds1, train_loss_list, train_acc_list),
    (test_ds1, test_loss_list, test_acc_list),
  ):
    for batch in dataset.as_numpy_iterator():
      loss, acc = compute_metrics_bce(state1, batch)
      loss_list.append(loss)
      acc_list.append(acc)

  # Unweighted mean of the per-batch metrics.
  train_loss, train_acc, test_loss, test_acc = (
    sum(values) / len(values)
    for values in (train_loss_list, train_acc_list, test_loss_list, test_acc_list)
  )

  print(
    f'[epoch {epoch + 1}]',
    f'train loss: {train_loss},',
    f'train acc: {train_acc},',
    f'test loss: {test_loss},',
    f'test acc: {test_acc},',
  )

In [None]:
def unfreeze_dict(frozen_dict: Mapping[str, Any]) -> dict[str, Any]:
  """Recursively convert a (possibly frozen) parameter tree into JSON-ready data.

  Nested mappings (``FrozenDict``, ``dict``, or any other ``Mapping``) become
  plain ``dict``s. Every leaf is converted with ``np.asarray(...).tolist()``, so
  arrays become nested lists and 0-d arrays/scalars become Python numbers.

  Args:
    frozen_dict: Parameter tree, e.g. ``state.params``.

  Returns:
    A new nested ``dict`` containing only dicts, lists, and Python scalars,
    suitable for ``json.dump``.
  """
  # The old annotation ``FrozenDict[any]`` was invalid: ``any`` is the builtin
  # function, not ``typing.Any``, and FrozenDict takes two type parameters.
  # Checking for Mapping also covers FrozenDict (a Mapping subclass) and dict.
  return {
    k: unfreeze_dict(v) if isinstance(v, Mapping) else np.asarray(v).tolist()
    for k, v in frozen_dict.items()
  }

params_save_path1 = 'checkpoints/lstm1-params.json'
# open(..., 'w') does not create missing parent directories; make sure the
# checkpoint directory exists so a fresh checkout doesn't hit FileNotFoundError.
os.makedirs(os.path.dirname(params_save_path1) or '.', exist_ok=True)
params_dict1 = unfreeze_dict(state1.params)
with open(params_save_path1, 'w') as f:
  json.dump(params_dict1, f)

In [None]:
def matrix_to_notes(x: ArrayLike) -> list[tuple[float, float, int]]:
  notes = []
  for j in range(num_notes):
    in_note = False
    begin = 0
    for i in range(len(x)):
      if x[i, j] > 0:
        if not in_note:
          in_note = True
          begin = i * 1024
      else:
        if in_note:
          in_note = False
          end = i * 1024
          notes.append((begin, end, j + min_note))
  
  return notes

In [None]:
# Sanity check: decode recording 1727's label matrix back into sorted note events.
notes = sorted(matrix_to_notes(labels['1727']))

In [None]:
# Show a preview instead of dumping every note event into the notebook output.
print(f'{len(notes)} notes; first 20:', notes[:20])

In [None]:
# Number of consecutive notes per training sequence for model 2. Pieces have
# different note counts, so each piece is cut into windows of this length
# (dropping the ragged tail) to make `times`/`beats` stackable into one array.
notes_per_seq = 128

times = []
beats = []

for path in label_files:
  # Bug fix: the previous loop never re-read the CSV and iterated the stale `df`
  # left over from the label-matrix cell, so every "piece" was the same piece.
  piece_df = pd.read_csv(path)
  # Per note: (start_time, end_time) inputs and (start_beat, end_beat) targets.
  piece_times = piece_df[['start_time', 'end_time']].to_numpy(np.float32)
  piece_beats = piece_df[['start_beat', 'end_beat']].to_numpy(np.float32)

  for i in range(0, len(piece_df) - notes_per_seq + 1, notes_per_seq):
    times.append(piece_times[i:i + notes_per_seq])
    beats.append(piece_beats[i:i + notes_per_seq])

In [None]:
# NOTE(review): inputs are raw sample indices and go into the LSTM unscaled;
# large magnitudes tend to saturate LSTM gates. Consider rescaling
# (e.g. to seconds) — TODO evaluate.
x2_full = np.array(times, np.float32)
y2_full = np.array(beats, np.float32)

# Fixed seed so the split is reproducible across kernel restarts.
x2_train, x2_test, y2_train, y2_test = train_test_split(x2_full, y2_full, random_state=0)

# Each example is only a sequence of (start, end) pairs, so buffering the whole
# training set for shuffling is cheap. tf.data reshuffles on every new
# iteration by default, giving a fresh order each epoch.
train_ds2 = (
  tf.data.Dataset.from_tensor_slices((x2_train, y2_train))
  .shuffle(buffer_size=len(x2_train), seed=0)
  .batch(batch_size=16)
)
test_ds2 = tf.data.Dataset.from_tensor_slices((x2_test, y2_test)).batch(batch_size=16)

In [None]:
@jit
def train_step_mse(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> TrainState:
  """Apply one optimizer step on the mean squared error.

  `batch` is an (inputs, regression targets) pair. Returns the updated state.
  """
  inputs, targets = batch

  def loss_fn(params: FrozenDict) -> Array:
    preds = state.apply_fn({'params': params}, x=inputs)
    return optax.squared_error(predictions=preds, targets=targets).mean()

  grads = grad(loss_fn)(state.params)
  return state.apply_gradients(grads=grads)

In [None]:
@jit
def compute_metrics_mse(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> Array:
  """Return the mean squared error of `state` on one (inputs, targets) batch.

  The result is a single 0-d JAX array, not a (loss, acc) tuple; the previous
  `tuple[float, float]` return annotation was incorrect.
  """
  inputs, targets = batch
  preds = state.apply_fn({'params': state.params}, x=inputs)
  return optax.squared_error(predictions=preds, targets=targets).mean()

In [None]:
# Model 2: regress (start_beat, end_beat) per note from its (start_time, end_time).
model2 = LSTM(features=2)
state2 = create_train_state(
  model2,
  x2_train[:1],  # one example suffices to infer parameter shapes
  jax.random.PRNGKey(0),
  learning_rate=1e-3,
)

In [None]:
num_epochs = 100

for epoch in range(num_epochs):
  for batch in train_ds2.as_numpy_iterator():
    state2 = train_step_mse(state2, batch)

  # Only evaluate and log every 10th epoch to keep the output short.
  if (epoch + 1) % 10 != 0:
    continue

  train_loss_list = [compute_metrics_mse(state2, b) for b in train_ds2.as_numpy_iterator()]
  test_loss_list = [compute_metrics_mse(state2, b) for b in test_ds2.as_numpy_iterator()]

  # Unweighted mean of the per-batch losses.
  train_loss = sum(train_loss_list) / len(train_loss_list)
  test_loss = sum(test_loss_list) / len(test_loss_list)

  print(
    f'[epoch {epoch + 1}]',
    f'train loss: {train_loss},',
    f'test loss: {test_loss},',
  )

In [None]:
params_save_path2 = 'checkpoints/lstm2-params.json'
# open(..., 'w') does not create missing parent directories; make sure the
# checkpoint directory exists so a fresh checkout doesn't hit FileNotFoundError.
os.makedirs(os.path.dirname(params_save_path2) or '.', exist_ok=True)
params_dict2 = unfreeze_dict(state2.params)
with open(params_save_path2, 'w') as f:
  json.dump(params_dict2, f)