In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from glob import glob
import librosa
import librosa.display
import IPython.display as ipd
from sklearn.model_selection import train_test_split

import jax
from jax import numpy as jnp, jit, grad
import flax
from flax import linen as nn
from flax.training.train_state import TrainState
import optax

### Data Preproscessing

In [2]:
metadata = pd.read_csv('musicnet_metadata.csv')
train_data_files = glob('musicnet/musicnet/train_data/*.wav')
test_data_files = glob('musicnet/musicnet/test_data/*.wav')

In [3]:
def wav_to_mel_spec(path):
  y, sr = librosa.load(path)
  spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=512)
  return np.abs(librosa.amplitude_to_db(spec, ref=np.max))

In [4]:
train_data = [wav_to_mel_spec(path) for path in train_data_files]
test_data = [wav_to_mel_spec(path) for path in test_data_files]

In [8]:
train_data_ids = [int(path[-8:-4]) for path in train_data_files]
test_data_ids = [int(path[-8:-4]) for path in test_data_files]

train_labels = [metadata[metadata['id'] == i]['ensemble'].values[0] for i in train_data_ids]
test_labels = [metadata[metadata['id'] == i]['ensemble'].values[0] for i in test_data_ids]

In [9]:
labels_to_nums = {label: i for i, label in enumerate(set(train_labels))}
nums_to_labels = {i: label for label, i in labels_to_nums.items()}

In [10]:
x_train = np.array([x[:, :1024].reshape(512, 1024, 1) for x in train_data], np.float32)
x_test = np.array([x[:, :1024].reshape(512, 1024, 1) for x in test_data], np.float32)
y_train = np.array([labels_to_nums[label] for label in train_labels], np.int32)
y_test = np.array([labels_to_nums[label] for label in test_labels], np.int32)

### Model

In [11]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

batch_size = 10
train_ds = train_ds.shuffle(buffer_size=len(x_train)).batch(batch_size)
test_ds = test_ds.batch(batch_size)

In [77]:
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = nn.Conv(features=16, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = nn.Conv(features=8, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = x.reshape((x.shape[0], -1))
    
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=21)(x)
    return x

In [78]:
def create_train_state(model, rng, learning_rate):
  params = model.init(rng, jnp.ones((1, *x_train.shape[1:])))['params']
  tx = optax.adam(learning_rate)
  return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [79]:
@jit
def train_step(state, batch):
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch[0])
    return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
  
  grad_fn = grad(loss_fn)
  return state.apply_gradients(grads=grad_fn(state.params))

In [83]:
@jit
def compute_metrics(state, batch):
  logits = state.apply_fn({'params': state.params}, batch[0])
  loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
  preds = jnp.argmax(logits, axis=1)
  acc = jnp.mean(preds == batch[1])
  return acc

In [86]:
flax_model = CNN()
state = create_train_state(flax_model, jax.random.PRNGKey(0), learning_rate=0.0001)

In [87]:
num_epochs = 40

for epoch in range(num_epochs):
  for batch in train_ds.as_numpy_iterator():
    state = train_step(state, batch)
  
  train_acc_list = []
  test_acc_list = []

  for batch in train_ds.as_numpy_iterator():
    acc = compute_metrics(state, batch)
    train_acc_list.append(acc)

  for batch in test_ds.as_numpy_iterator():
    acc = compute_metrics(state, batch)
    test_acc_list.append(acc)
  
  train_acc = sum(train_acc_list) / len(train_acc_list)
  test_acc = sum(test_acc_list) / len(test_acc_list)

  print(f'[epoch {epoch + 1}] train acc: {train_acc}, test acc: {test_acc}')

[epoch 1] train acc: 0.03437500074505806, test acc: 0.10000000149011612
[epoch 2] train acc: 0.4781250059604645, test acc: 0.30000001192092896
[epoch 3] train acc: 0.4781250059604645, test acc: 0.30000001192092896
[epoch 4] train acc: 0.4937499761581421, test acc: 0.30000001192092896
[epoch 5] train acc: 0.49062496423721313, test acc: 0.30000001192092896
[epoch 6] train acc: 0.5312500596046448, test acc: 0.30000001192092896
[epoch 7] train acc: 0.5750000476837158, test acc: 0.30000001192092896
[epoch 8] train acc: 0.546875, test acc: 0.4000000059604645
[epoch 9] train acc: 0.5375000238418579, test acc: 0.4000000059604645
[epoch 10] train acc: 0.6124999523162842, test acc: 0.30000001192092896
[epoch 11] train acc: 0.640625, test acc: 0.30000001192092896
[epoch 12] train acc: 0.6312500238418579, test acc: 0.30000001192092896
[epoch 13] train acc: 0.637499988079071, test acc: 0.30000001192092896
[epoch 14] train acc: 0.6968749761581421, test acc: 0.4000000059604645
[epoch 15] train acc: 0