In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from glob import glob
import librosa
import librosa.display
from sklearn.model_selection import train_test_split
import music21 as m21
from midiutil import MIDIFile
import json

import jax
from jax import numpy as jnp, jit, grad, Array
from jax.typing import ArrayLike
import flax
from flax import linen as nn
from flax.training import train_state
from flax.core import FrozenDict
import optax

### Data Preproscessing for Note Sequence Classification
- Load data
- Convert raw audio to spectrograms
- Get notes at every timestep of each spectrogram
- Truncate data so all sequences are of equal length
- Split data into train and test sets
- Load proscessed data into TensorFlow datasets

In [2]:
data_files = glob('../musicnet/musicnet/*/*.wav')
label_files = glob('../musicnet/musicnet/*/*.csv')

In [3]:
sr = 22050
hop_length = 512
n_mels = 512

def wav_to_mel_spec(path: str) -> np.ndarray:
  y, _ = librosa.load(path)
  spec = librosa.feature.melspectrogram(y=y, sr=sr, hop_length=hop_length, n_mels=n_mels)
  return librosa.amplitude_to_db(spec, ref=np.max).T

In [4]:
data = {path[-8:-4]: wav_to_mel_spec(path) for path in data_files}

In [5]:
min_note = 21
max_note = 104
num_notes = max_note - min_note + 1

labels = {}

for path in label_files:
  key = path[-8:-4]
  df = pd.read_csv(path)
  label_mat = np.zeros((len(data[key]), num_notes), np.float32)

  for row in df.itertuples():
    note = row.note
    start = row.start_time // 1024
    end = row.end_time // 1024
    label_mat[start:end, note - min_note] = 1

  labels[key] = label_mat

In [6]:
keys = sorted(data.keys())

In [7]:
truncated_len = 512

truncated_data = []
truncated_labels = []

for key in keys:
  x = data[key]
  y = labels[key]
  for i in range(0, x.shape[0] - truncated_len + 1, truncated_len):
    truncated_data.append(x[i:i + truncated_len])
    truncated_labels.append(y[i:i + truncated_len])

truncated_data = np.array(truncated_data)
truncated_labels = np.array(truncated_labels)

In [8]:
x1_train, x1_test, y1_train, y1_test = train_test_split(truncated_data, truncated_labels)

In [9]:
train_ds1 = tf.data.Dataset.from_tensor_slices((x1_train, y1_train)).batch(batch_size=16)
test_ds1 = tf.data.Dataset.from_tensor_slices((x1_test, y1_test)).batch(batch_size=16)

### LSTM Model for Note Sequence Classification

In [10]:
class LSTM(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x: ArrayLike) -> Array:
    ScanLSTM = nn.scan(
      nn.LSTMCell, 
      variable_broadcast='params',
      split_rngs={'params': False}, 
      in_axes=1, 
      out_axes=1,
    )

    lstm = ScanLSTM(features=128)
    carry = lstm.initialize_carry(jax.random.key(0), x[:, 0].shape)
    carry, x = lstm(carry, x)

    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.features)(x)
    return x

In [11]:
TrainState = train_state.TrainState

def create_train_state(model: LSTM, x: ArrayLike, rng_key: Array, learning_rate: float) -> TrainState:
  params = model.init(rng_key, x=x)['params']
  tx = optax.adam(learning_rate=learning_rate)
  return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [12]:
def load_train_state(model: LSTM, learning_rate: float, params_path: str) -> TrainState:
  with open(params_path, 'r') as f:
    params_dict = json.load(f)
  
  def freeze_dict(unfrozen_dict: dict[any]) -> FrozenDict[any]:
    frozen_dict = {}
    for k, v in unfrozen_dict.items():
      if isinstance(v, dict):
        frozen_dict[k] = freeze_dict(v)
      else:
        frozen_dict[k] = jnp.array(v)
    return FrozenDict(frozen_dict) 

  params = freeze_dict(params_dict)
  tx = optax.adam(learning_rate=learning_rate)
  return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [13]:
@jit
def train_step_bce(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> TrainState:

  def loss_fn(params: FrozenDict) -> Array:
    logits = state.apply_fn({'params': params}, x=batch[0])
    loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=batch[1]).mean()
    return loss
  
  grad_fn = grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

In [14]:
@jit
def compute_metrics_bce(state: TrainState, batch: ArrayLike) -> tuple[float, float]:
  logits = state.apply_fn({'params': state.params}, x=batch[0])
  loss = optax.sigmoid_binary_cross_entropy(logits=logits, labels=batch[1]).mean()
  preds = jnp.round(nn.sigmoid(logits))
  acc = jnp.mean(preds == batch[1])
  return loss, acc

In [15]:
model1 = LSTM(num_notes)
# state1 = create_train_state(model1, x1_train[:1], jax.random.PRNGKey(0), learning_rate=1e-4)
state1 = load_train_state(model1, learning_rate=1e-4, params_path='checkpoints/lstm1-params.json')

In [16]:
def test_model1(epoch: int) -> None:
  train_loss_list = []
  train_acc_list = []
  test_loss_list = []
  test_acc_list = []

  for batch in train_ds1.as_numpy_iterator():
    loss, acc = compute_metrics_bce(state1, batch)
    train_loss_list.append(loss)
    train_acc_list.append(acc)

  for batch in test_ds1.as_numpy_iterator():
    loss, acc = compute_metrics_bce(state1, batch)
    test_loss_list.append(loss)
    test_acc_list.append(acc)

  train_loss = sum(train_loss_list) / len(train_loss_list)
  train_acc = sum(train_acc_list) / len(train_acc_list)
  test_loss = sum(test_loss_list) / len(test_loss_list)
  test_acc = sum(test_acc_list) / len(test_acc_list)

  print(
    f'[epoch {epoch}]',
    f'train loss: {train_loss},', 
    f'train acc: {train_acc},', 
    f'test loss: {test_loss},',
    f'test acc: {test_acc},',
  )

In [17]:
# num_epochs = 2

# for epoch in range(num_epochs):
#   for batch in train_ds1.as_numpy_iterator():
#     state1 = train_step_bce(state1, batch)
  
#   test_model1(epoch + 1)

test_model1(0)

[epoch 0] train loss: 0.1401427984237671, train acc: 0.9632152318954468, test loss: 0.13950517773628235, test acc: 0.9633268117904663,


In [18]:
def unfreeze_dict(frozen_dict: FrozenDict[any]) -> dict[any]:
  unfrozen_dict = {}
  for k, v in frozen_dict.items():
    if isinstance(v, FrozenDict) or isinstance(v, dict):
      unfrozen_dict[k] = unfreeze_dict(v)
    else:
      unfrozen_dict[k] = v.tolist()
  return unfrozen_dict

# params_save_path1 = 'checkpoints/lstm1-params.json'
# params_dict1 = unfreeze_dict(state1.params)
# with open(params_save_path1, 'w') as f:
#   json.dump(params_dict1, f)

### Data Preproscessing for Note Times to Beat Regression
- Load data
- Convert data to intervals (if from model 1 output)
- Truncate data
- Load new training data and labels
- Load data into TensorFlow datasets

In [19]:
def matrix_to_notes(x: ArrayLike) -> list[tuple[float, float, int]]:
  notes = []
  for j in range(num_notes):
    in_note = False
    begin = 0
    for i in range(len(x)):
      if x[i, j] > 0:
        if not in_note:
          in_note = True
          begin = i * 1024 / 44100
      else:
        if in_note:
          in_note = False
          end = i * 1024 / 44100
          notes.append((begin, end, j + min_note))
  
  return notes

In [160]:
note_value_set = set()

for path in label_files:
  df = pd.read_csv(path)
  for row in df.itertuples():
    note_value_set.add(row.note_value)

In [161]:
note_values_to_nums = {v: i for i, v in enumerate(sorted(note_value_set))}
nums_to_note_values = {v: k for k, v in note_values_to_nums.items()}

In [162]:
note_value_nums_to_beats = {
  0: 0.75,
  1: 3,
  2: 1.5,
  3: 0.375,
  4: 0.5,
  5: 2,
  6: 1,
  7: 0.25,
  8: 0.0625,
  9: 0.125,
  10: 1.25,
  11: 1.125,
  12: 1 / 3,
  13: 0.25 / 3,
  14: 0.0625 / 3,
  15: 0.125 / 3,
  16: 1,
  17: 4,
}

In [166]:
times = []
note_values = []

for path in label_files:
  piece_times = []
  piece_note_values = []

  df = pd.read_csv(path)

  for row in df.itertuples():
    piece_times.append([row.start_time / 44100, row.end_time / 44100])
    piece_note_values.append(note_values_to_nums[row.note_value])

  times.append(piece_times)
  note_values.append(piece_note_values)

In [222]:
truncated_len = 64

truncated_times = []
truncated_note_values = []

for x, y in zip(times, note_values):
  for i in range(0, len(x) - truncated_len + 1, truncated_len):
    truncated_times.append(x[i:i + truncated_len])
    truncated_note_values.append(y[i:i + truncated_len])

In [223]:
x2_full = np.array(truncated_times, np.float32)
y2_full = np.array(truncated_note_values, np.int32)

x2_train, x2_test, y2_train, y2_test = train_test_split(x2_full, y2_full)

train_ds2 = tf.data.Dataset.from_tensor_slices((x2_train, y2_train)).batch(batch_size=16)
test_ds2 = tf.data.Dataset.from_tensor_slices((x2_test, y2_test)).batch(batch_size=16)

In [230]:
class LSTM2(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x: ArrayLike) -> Array:
    ScanLSTM = nn.scan(
      nn.LSTMCell, 
      variable_broadcast='params',
      split_rngs={'params': False}, 
      in_axes=1, 
      out_axes=1,
    )

    lstm1 = ScanLSTM(features=128)
    carry = lstm1.initialize_carry(jax.random.key(0), x[:, 0].shape)
    carry, x = lstm1(carry, x)

    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=64)(x)
    x = nn.relu(x)
    x = nn.Dense(features=32)(x)
    x = nn.relu(x)
    x = nn.Dense(features=self.features)(x)
    return x

In [231]:
@jit
def train_step_cce(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> TrainState:

  def loss_fn(params: FrozenDict) -> Array:
    logits = state.apply_fn({'params': params}, x=batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss
  
  grad_fn = grad(loss_fn)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

In [232]:
@jit
def compute_metrics_cce(state: TrainState, batch: ArrayLike) -> tuple[float, float]:
  logits = state.apply_fn({'params': state.params}, x=batch[0])
  loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
  preds = jnp.argmax(logits, axis=2)
  acc = jnp.mean(preds == batch[1])
  return loss, acc

In [233]:
model2 = LSTM2(18)
state2 = create_train_state(model2, x2_train[:1], jax.random.PRNGKey(0), learning_rate=1e-4)

In [234]:
def test_model2(epoch: int) -> None:
  train_loss_list = []
  train_acc_list = []
  test_loss_list = []
  test_acc_list = []

  for batch in train_ds2.as_numpy_iterator():
    loss, acc = compute_metrics_cce(state2, batch)
    train_loss_list.append(loss)
    train_acc_list.append(acc)

  for batch in test_ds2.as_numpy_iterator():
    loss, acc = compute_metrics_cce(state2, batch)
    test_loss_list.append(loss)
    test_acc_list.append(acc)

  train_loss = sum(train_loss_list) / len(train_loss_list)
  train_acc = sum(train_acc_list) / len(train_acc_list)
  test_loss = sum(test_loss_list) / len(test_loss_list)
  test_acc = sum(test_acc_list) / len(test_acc_list)

  print(
    f'[epoch {epoch}]',
    f'train loss: {train_loss},', 
    f'train acc: {train_acc},', 
    f'test loss: {test_loss},',
    f'test acc: {test_acc},',
  )

In [235]:
num_epochs = 300

for epoch in range(num_epochs):
  for batch in train_ds2.as_numpy_iterator():
    state2 = train_step_cce(state2, batch)
  
  if (epoch + 1) % 10 == 0:
    test_model2(epoch + 1)

[epoch 10] train loss: 1.8079787492752075, train acc: 0.3520481288433075, test loss: 1.809807300567627, test acc: 0.353514164686203,
[epoch 20] train loss: 1.778620719909668, train acc: 0.3688410818576813, test loss: 1.7797287702560425, test acc: 0.3674545884132385,
[epoch 30] train loss: 1.7525087594985962, train acc: 0.379303514957428, test loss: 1.754604697227478, test acc: 0.37645891308784485,
[epoch 40] train loss: 1.742892861366272, train acc: 0.3841315507888794, test loss: 1.7450793981552124, test acc: 0.3815096616744995,
[epoch 50] train loss: 1.730048418045044, train acc: 0.3914686441421509, test loss: 1.7320880889892578, test acc: 0.3890003561973572,
[epoch 60] train loss: 1.7169861793518066, train acc: 0.39790603518486023, test loss: 1.7196967601776123, test acc: 0.39521780610084534,
[epoch 70] train loss: 1.7131528854370117, train acc: 0.4000147581100464, test loss: 1.7159243822097778, test acc: 0.3974764943122864,
[epoch 80] train loss: 1.7092208862304688, train acc: 0.402

In [211]:
params_save_path2 = 'checkpoints/lstm2-params.json'
params_dict2 = unfreeze_dict(state2.params)
with open(params_save_path2, 'w') as f:
  json.dump(params_dict2, f)

In [123]:
def create_midi_file(notes, output_file='output.mid', tempo=120):
  midi = MIDIFile(1)
  midi.addTempo(0, 0, tempo)
  
  for start, duration, note_num in notes:
    track = 0
    channel = 0  
    volume = 100

    midi.addNote(track, channel, note_num, start, duration, volume)

  with open(output_file, 'wb') as midi_file:
    midi.writeFile(midi_file)

In [124]:
def play_midi_file(file_path='output.mid'):
  mf = m21.midi.MidiFile()
  mf.open(file_path)
  mf.read()
  mf.close()

  score = m21.midi.translate.midiFileToStream(mf)
  
  print(f'Duration: {score.highestTime} seconds')

  sp = m21.midi.realtime.StreamPlayer(score)
  sp.play()

In [125]:
x_notes = matrix_to_notes(labels['2242'])
x_times = np.array([[[s, e] for s, e, _ in x_notes[:truncated_len]]], np.float32) / 44100

In [142]:
x_beats = state2.apply_fn({'params': state2.params}, x=x_times)
x_beats.shape

(1, 256, 2)

In [146]:
new_x_notes = []
for i in range(truncated_len):
  new_x_notes.append((float(x_beats[0, i, 0]), float(x_beats[0, i, 1]), x_notes[i][2]))

In [147]:
real_notes = []

df = pd.read_csv('../musicnet/musicnet/train_labels/2242.csv')

for row in df.itertuples():
  real_notes.append((row.start_beat, row.end_beat, row.note))

In [148]:
create_midi_file(new_x_notes)

KeyError: '550'

In [183]:
x1_test.shape, y1_test.shape

((2544, 512, 512), (2544, 512, 84))

In [184]:
x2_test.shape, y2_test.shape

((1022, 256, 2), (1022, 256))