In [20]:
import numpy as np
import pandas as pd
import tensorflow as tf
from glob import glob
import librosa
import librosa.display
from sklearn.model_selection import train_test_split
import json

import jax
from jax import numpy as jnp, jit, value_and_grad
import flax
from flax import linen as nn
from flax.training import train_state
import optax

### Data Preproscessing

In [2]:
metadata = pd.read_csv('musicnet/musicnet_metadata.csv')
data_files = glob('musicnet/musicnet/*/*.wav')

In [3]:
n_mels = 512

def wav_to_mel_spec(path):
  y, sr = librosa.load(path)
  spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
  return np.abs(librosa.amplitude_to_db(spec, ref=np.max))

data = [wav_to_mel_spec(path) for path in data_files]

In [4]:
data_ids = [int(path[-8:-4]) for path in data_files]
labels = [metadata[metadata['id'] == i]['ensemble'].values[0] for i in data_ids]

In [5]:
labels_to_nums = {label: i for i, label in enumerate(sorted(set(labels)))}
nums_to_labels = {i: label for label, i in labels_to_nums.items()}

labels = [labels_to_nums[label] for label in labels]

In [6]:
new_data = []
new_labels = []

sample_len = 512

for x, y in zip(data, labels):
  for i in range(0, x.shape[1] - sample_len + 1, sample_len):
    new_data.append(np.expand_dims(x[:, i:i + sample_len], axis=2))
    new_labels.append(y)

In [7]:
x_full = np.array(new_data, np.float32)
y_full = np.array(new_labels, np.int32)

In [8]:
x_train, x_test, y_train, y_test = train_test_split(x_full, y_full, test_size=0.25, random_state=0)

### Model

In [9]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

batch_size = 16
train_ds = train_ds.shuffle(buffer_size=len(x_train)).batch(batch_size)
test_ds = test_ds.batch(batch_size)

In [10]:
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    x = (nn.Conv(features=8, kernel_size=(3, 3), use_bias=False))(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = nn.Conv(features=8, kernel_size=(3, 3), use_bias=False)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = nn.Conv(features=8, kernel_size=(3, 3), use_bias=False)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    x = x.reshape((x.shape[0], -1))
    
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=21)(x)
    return x

In [11]:
class TrainState(train_state.TrainState):
  batch_stats: any

def create_train_state(model, rng, learning_rate):
  variables = model.init(rng, x=jnp.ones((1, *x_train.shape[1:])), train=False)
  return TrainState.create(
    apply_fn=model.apply, 
    params=variables['params'],
    batch_stats=variables['batch_stats'], 
    tx=optax.adamw(learning_rate, weight_decay=1e-3),
  )

In [12]:
@jit
def train_step(state, batch):
  def loss_fn(params):
    logits, updates = state.apply_fn({
      'params': params, 
      'batch_stats': state.batch_stats,
    }, x=batch[0], train=True, mutable=['batch_stats'])

    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, updates
  
  grad_fn = value_and_grad(loss_fn, has_aux=True)
  (_, updates), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])
  return state

In [13]:
@jit
def accuracy(state, batch):
  logits = state.apply_fn({
    'params': state.params, 
    'batch_stats': state.batch_stats,
  }, x=batch[0], train=False)
  
  preds = jnp.argmax(logits, axis=1)
  acc = jnp.mean(preds == batch[1])
  return acc

In [14]:
flax_model = CNN()
state = create_train_state(flax_model, jax.random.PRNGKey(0), learning_rate=1e-4)



Metal device set to: Apple M3 Pro

systemMemory: 18.00 GB
maxCacheSize: 6.00 GB



In [15]:
num_epochs = 10

for epoch in range(num_epochs):
  for batch in train_ds.as_numpy_iterator():
    state = train_step(state, batch)
  
  train_acc_list = []
  test_acc_list = []

  for batch in train_ds.as_numpy_iterator():
    acc = accuracy(state, batch)
    train_acc_list.append(acc)

  for batch in test_ds.as_numpy_iterator():
    acc = accuracy(state, batch)
    test_acc_list.append(acc)
  
  train_acc = sum(train_acc_list) / len(train_acc_list)
  test_acc = sum(test_acc_list) / len(test_acc_list)

  print(f'[epoch {epoch + 1}] train acc: {train_acc}, test acc: {test_acc}')

[epoch 1] train acc: 0.6185528635978699, test acc: 0.571093738079071
[epoch 2] train acc: 0.7406662106513977, test acc: 0.677734375
[epoch 3] train acc: 0.7697538137435913, test acc: 0.680468738079071
[epoch 4] train acc: 0.9041579365730286, test acc: 0.770703136920929
[epoch 5] train acc: 0.8265610337257385, test acc: 0.71875
[epoch 6] train acc: 0.9856875538825989, test acc: 0.8402343988418579
[epoch 7] train acc: 0.9886244535446167, test acc: 0.826953113079071
[epoch 8] train acc: 0.9963389039039612, test acc: 0.8648437261581421
[epoch 9] train acc: 0.9972541928291321, test acc: 0.85546875
[epoch 10] train acc: 0.9988232254981995, test acc: 0.8675781488418579


In [35]:
params_save_path = 'checkpoints/params.json'
params_dict = {k1: {k2: v2.tolist() for k2, v2 in v1.items()} for k1, v1 in state.params.items()}
with open(params_save_path, 'w') as f:
  json.dump(params_dict, f)

batch_stats_save_path = 'checkpoints/batch_stats.json'
batch_stats_dict = {k1: {k2: v2.tolist() for k2, v2 in v1.items()} for k1, v1 in state.batch_stats.items()}
with open(batch_stats_save_path, 'w') as f:
  json.dump(batch_stats_dict, f)