In [None]:
import functools
import json
import os
from glob import glob

import numpy as np
import pandas as pd
import tensorflow as tf
import librosa
import librosa.display
from sklearn.model_selection import train_test_split
from IPython.display import Image, Audio
import music21 as m21

import jax
from jax import numpy as jnp, jit, grad, Array
from jax.typing import ArrayLike
import flax
from flax import linen as nn
from flax.training import train_state
from flax.core import FrozenDict
import optax

### Data Preprocessing

In [None]:
# Collect the Bach MIDI files. glob returns paths in arbitrary order, so sort
# them to keep the extracted note sequence (and the later split) reproducible.
data_files = sorted(glob('musicnet/musicnet_midis/musicnet_midis/Bach/*.mid'))
midi_data = []
skipped_files = []
for path in data_files:
  try:
    midi_data.append(m21.converter.parse(path))
  except Exception as err:
    # Best-effort loading: keep going, but report what was dropped instead of
    # silently shrinking the dataset.
    skipped_files.append(path)
    print(f'Skipping {path}: {err!r}')

In [None]:
def extract_notes(files):
  """Flatten parsed scores into one list of note/chord tokens.

  Args:
    files: iterable of parsed music21 streams.

  Returns:
    A list of strings. A single note is stored as its pitch string; a chord is
    stored as the comma-joined integers of its `normalOrder`.
  """
  notes = []
  for f in files:
    piece = m21.instrument.partitionByInstrument(f)
    # partitionByInstrument can return None when no instrument partition is
    # available; fall back to walking the original score in that case.
    parts = piece.parts if piece is not None else [f]
    for part in parts:
      for element in part.recurse():
        if isinstance(element, m21.note.Note):
          notes.append(str(element.pitch))
        elif isinstance(element, m21.chord.Chord):
          notes.append(','.join(str(n) for n in element.normalOrder))
  return notes

In [None]:
notes = extract_notes(midi_data)

# Vocabulary: every distinct token gets an integer id in sorted order.
notes_to_nums = {token: idx for idx, token in enumerate(sorted(set(notes)))}
num_unique_notes = len(notes_to_nums)

# Inverse lookup (id -> token), in the same insertion order as the forward map.
nums_to_notes = dict(zip(notes_to_nums.values(), notes_to_nums.keys()))

In [None]:
sample_len = 64

# Encode the whole token sequence once, then slice sliding windows out of it.
encoded_notes = [notes_to_nums[n] for n in notes]

# Each feature is a window of `sample_len` ids; its target is the id that
# immediately follows the window.
features = [
  encoded_notes[start:start + sample_len]
  for start in range(len(notes) - sample_len)
]
targets = encoded_notes[sample_len:]

In [None]:
# Inputs: (num_samples, sample_len, 1) float32, scaled by the vocabulary size.
x_full = np.asarray(features, dtype=np.float32)
x_full = x_full.reshape(len(features), sample_len, 1) / num_unique_notes

# Targets stay as integer class ids.
y_full = np.asarray(targets, dtype=np.int32)

# Hold out 20% of the windows; they also serve as generation seeds later.
x_train, x_seed, y_train, y_seed = train_test_split(
  x_full,
  y_full,
  test_size=0.2,
  random_state=0,
)

In [None]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_seed, y_seed))

batch_size = 16
# Size the shuffle buffer from the split that is actually being shuffled.
# The original used len(x_seed), the held-out split, which leaves the buffer
# smaller than the training set and so gives only a partial shuffle.
train_ds = train_ds.shuffle(buffer_size=len(x_train)).batch(batch_size)
test_ds = test_ds.batch(batch_size)

### Model

In [None]:
class LSTM(nn.Module):
  """Two stacked LSTM layers followed by two dense layers.

  The final dense layer has `num_unique_notes` outputs (a notebook-level
  global), one score per vocabulary entry.
  """

  def setup(self) -> None:
    """Create the scanned LSTM layers and the dense head."""
    # Lift the cell with nn.scan so it is stepped along axis 1 of its input and
    # its per-step outputs are stacked back on axis 1. 'params' is broadcast and
    # its RNG is not split, so every step uses the same weights.
    lstm_layer = nn.scan(
      nn.OptimizedLSTMCell,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False})
    
    # Attribute names here become the keys of the parameter tree.
    self.lstm1 = lstm_layer(128)
    self.lstm2 = lstm_layer(64)
    self.dense1 = nn.Dense(256)
    self.dense2 = nn.Dense(num_unique_notes)
  
  # NOTE(review): nn.remat is Flax's rematerialization (checkpoint) transform;
  # presumably used to trade recomputation for memory — confirm it is needed.
  @nn.remat
  def __call__(self, x: ArrayLike) -> Array:
    """Run the network on `x` and return the output of the last dense layer.

    `x` is scanned along axis 1; only the final step of the second LSTM's
    output sequence is fed to the dense layers.
    """
    # The initial carry is sized from a single step of the input (x[:, 0]).
    # NOTE(review): the fixed PRNGKey presumably does not matter if the carry
    # initializer is deterministic — confirm against the Flax version in use.
    carry, hidden = self.lstm1.initialize_carry(jax.random.PRNGKey(0), x[:, 0].shape)
    (carry, hidden), x = self.lstm1((carry, hidden), x)

    # The second layer consumes the full output sequence of the first.
    carry, hidden = self.lstm2.initialize_carry(jax.random.PRNGKey(1), x[:, 0].shape)
    (carry, hidden), x = self.lstm2((carry, hidden), x)

    # Keep only the last step along axis 1.
    x = x[:, -1]

    # The two dense layers are applied back to back, with no activation between.
    x = self.dense1(x)
    x = self.dense2(x)
    return x

In [None]:
TrainState = train_state.TrainState

def create_train_state(model: LSTM, rng_key: Array, learning_rate: float) -> TrainState:
  """Initialise `model` and wrap its parameters with an Adam optimizer.

  Parameters are initialised by tracing the model on the first training
  sample (`x_train[:1]`, a notebook-level global).
  """
  sample_input = x_train[:1]
  variables = model.init(rng_key, x=sample_input)
  optimizer = optax.adam(learning_rate=learning_rate)
  return TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
  )

In [None]:
def load_train_state(model: LSTM, learning_rate: float, params_path: str) -> TrainState:
  """Rebuild a train state from parameters stored as nested JSON.

  Args:
    model: module whose `apply` becomes the state's `apply_fn`.
    learning_rate: learning rate for a freshly created Adam optimizer
      (optimizer state is not restored).
    params_path: path of the JSON file holding the nested parameter lists.

  Returns:
    A `TrainState` whose params are a `FrozenDict` of `jnp` arrays.
  """
  with open(params_path, 'r') as f:
    params_dict = json.load(f)

  # The original annotations used the builtin function `any` as a type
  # argument; plain `dict` / `FrozenDict` are the valid equivalents.
  def freeze_dict(unfrozen_dict: dict) -> FrozenDict:
    """Recursively convert nested dicts to FrozenDicts and leaves to arrays."""
    return FrozenDict({
      k: freeze_dict(v) if isinstance(v, dict) else jnp.array(v)
      for k, v in unfrozen_dict.items()
    })

  params = freeze_dict(params_dict)
  tx = optax.adam(learning_rate=learning_rate)
  return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [None]:
@jit
def train_step(state: TrainState, batch: tuple[ArrayLike, ArrayLike]) -> TrainState:
  """Apply one gradient update computed on `batch`.

  Args:
    state: current train state (parameters, optimizer state, apply function).
    batch: pair `(inputs, integer_labels)`.

  Returns:
    The train state after one optimizer step.
  """

  def loss_fn(params: FrozenDict) -> Array:
    # Use the apply function bound to the state, as `accuracy` does, instead of
    # constructing a fresh LSTM(); the step then trains whichever model the
    # state was created for.
    logits = state.apply_fn({'params': params}, x=batch[0])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss

  grads = grad(loss_fn)(state.params)
  return state.apply_gradients(grads=grads)

In [None]:
@jit
def accuracy(state: TrainState, batch: ArrayLike):
  """Return the fraction of samples in `batch` whose argmax prediction is right.

  `batch` is indexed as a pair: element 0 holds the inputs, element 1 the
  integer labels.
  """
  inputs, labels = batch[0], batch[1]
  predicted = state.apply_fn({'params': state.params}, x=inputs).argmax(axis=1)
  return (predicted == labels).mean()

In [None]:
# Build the network and its initial train state (fixed seed for repeatable init).
model = LSTM()
state = create_train_state(
  model=model,
  rng_key=jax.random.PRNGKey(0),
  learning_rate=1e-3,
)

In [None]:
num_epochs = 50


def _dataset_accuracy(state: TrainState, dataset: tf.data.Dataset) -> float:
  """Return the accuracy of `state` over all samples in a batched `dataset`.

  Each batch's accuracy is weighted by its size, so a shorter final batch does
  not count as much as a full one. Returns NaN for an empty dataset.
  """
  correct = 0.0
  seen = 0
  for batch in dataset.as_numpy_iterator():
    batch_len = len(batch[1])
    correct += float(accuracy(state, batch)) * batch_len
    seen += batch_len
  return correct / seen if seen else float('nan')


for epoch in range(num_epochs):
  # One full pass of gradient updates over the training set.
  for batch in train_ds.as_numpy_iterator():
    state = train_step(state, batch)

  # Evaluate after the epoch's updates on both splits.
  train_acc = _dataset_accuracy(state, train_ds)
  test_acc = _dataset_accuracy(state, test_ds)

  print(f'[epoch {epoch + 1}] train acc: {train_acc}, test acc: {test_acc}')

In [None]:
def unfreeze_dict(frozen_dict: FrozenDict) -> dict:
  """Recursively convert a (Frozen)dict of arrays into JSON-serialisable data.

  Nested `FrozenDict`/`dict` values become plain dicts; every other value is
  converted to nested Python lists (or a scalar).
  """
  # The original annotations used the builtin function `any` as a type
  # argument; plain `FrozenDict` / `dict` are the valid equivalents.
  unfrozen_dict = {}
  for k, v in frozen_dict.items():
    if isinstance(v, (FrozenDict, dict)):
      unfrozen_dict[k] = unfreeze_dict(v)
    else:
      # np.asarray also accepts plain Python scalars/lists, which have no
      # .tolist() of their own.
      unfrozen_dict[k] = np.asarray(v).tolist()
  return unfrozen_dict

params_save_path = 'checkpoints/lstm-params.json'
params_dict = unfreeze_dict(state.params)

# open(..., 'w') does not create missing directories, so make sure the target
# directory exists before writing the checkpoint.
params_save_dir = os.path.dirname(params_save_path)
if params_save_dir:
  os.makedirs(params_save_dir, exist_ok=True)

with open(params_save_path, 'w') as f:
  json.dump(params_dict, f)

In [None]:
def generate_melody(note_len: int) -> m21.stream.Stream:
  """Generate `note_len` tokens greedily and return them as a music21 stream.

  A random window from `x_seed` is used as the starting context. At each step
  the argmax prediction is recorded, then appended to the context while the
  oldest step is dropped.

  Args:
    note_len: number of notes/chords to generate.

  Returns:
    A `m21.stream.Stream` with one note or chord per generated token, placed at
    consecutive integer offsets.
  """
  rand_idx = np.random.randint(0, len(x_seed))
  seed = x_seed[rand_idx:rand_idx + 1]
  music = []
  for _ in range(note_len):
    logits = state.apply_fn({'params': state.params}, x=seed)
    pred = int(jnp.argmax(logits, axis=1)[0])
    music.append(str(nums_to_notes[pred]))
    # Training inputs are note ids divided by num_unique_notes, so the fed-back
    # prediction must be scaled the same way (the original appended the raw id).
    next_step = np.array([[[pred / num_unique_notes]]], dtype=seed.dtype)
    seed = np.concatenate((seed[:, 1:], next_step), axis=1)

  melody = []
  for offset, token in enumerate(music):
    if ',' in token or token.isdigit():
      # Chord token: comma-joined integers (a lone integer is a one-note chord).
      # Build the chord once from all of its pitches; the original created and
      # appended a partial chord on every pitch.
      chord_notes = [m21.note.Note(int(pitch)) for pitch in token.split(',')]
      snippet = m21.chord.Chord(chord_notes)
    else:
      # Note token: a pitch string.
      snippet = m21.note.Note(token)
    snippet.offset = offset
    melody.append(snippet)

  melody_midi = m21.stream.Stream(melody)
  return melody_midi

In [None]:
# Generate a 128-token melody and export it as a MIDI file.
melody = generate_melody(note_len=128)
melody.write('midi', 'generated_music.mid')

In [None]:
# Display the id -> token vocabulary mapping (the full dict is rendered).
nums_to_notes