In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from glob import glob
import librosa
import librosa.display
import IPython.display as ipd
import sounddevice as sd
import json

import jax
from jax import numpy as jnp, Array
from jax.typing import ArrayLike
import flax
from flax import linen as nn
from flax.core import FrozenDict

In [None]:
metadata = pd.read_csv('musicnet/musicnet_metadata.csv')
labels = metadata['ensemble'].values

# Class ids are assigned in alphabetical order of the distinct ensemble names.
# NOTE(review): presumably this ordering must match the one used when the
# checkpoint was trained (CNN has 21 outputs) — confirm.
nums_to_labels = dict(enumerate(sorted(set(labels))))
labels_to_nums = {label: class_id for class_id, label in nums_to_labels.items()}

In [None]:
class CNN(nn.Module):
  """Small 2-D CNN classifier for single-channel spectrogram "images".

  Three identical stages of Conv(8 filters, 3x3, no bias) -> BatchNorm ->
  ReLU -> 2x2 max-pool are followed by a 128 -> 64 -> `num_classes` dense head.

  Attributes:
    num_classes: Number of output logits. Defaults to 21, the width of the
      original network; it must match the checkpoint being loaded.
  """
  num_classes: int = 21

  @nn.compact
  def __call__(self, x: ArrayLike, training: bool) -> Array:
    """Return unnormalized class logits for a batch of inputs.

    Args:
      x: Channels-last input batch; this notebook feeds (batch, 512, 512, 1)
        spectrogram arrays.
      training: If True, BatchNorm normalizes with the current batch's
        statistics (per Flax's BatchNorm semantics the caller must then make
        'batch_stats' mutable in `apply`); otherwise the stored running
        averages are used.

    Returns:
      Logits of shape (batch, num_classes).
    """
    # Inline submodules are auto-named in creation order (Conv_0, BatchNorm_0,
    # Conv_1, ...), so building the stages in a loop produces exactly the same
    # parameter tree as writing them out three times.
    for _ in range(3):
      x = nn.Conv(features=8, kernel_size=(3, 3), use_bias=False)(x)
      x = nn.BatchNorm(use_running_average=not training)(x)
      x = nn.relu(x)
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Flatten everything except the batch axis for the dense head.
    x = x.reshape((x.shape[0], -1))

    x = nn.relu(nn.Dense(features=128)(x))
    x = nn.relu(nn.Dense(features=64)(x))
    return nn.Dense(features=self.num_classes)(x)

In [None]:
def _read_json(path: str):
  """Parse and return the JSON document stored at `path`."""
  with open(path, 'r') as f:
    return json.load(f)


def _to_frozen_arrays(tree):
  """Convert a JSON-decoded variable tree into nested FrozenDicts of jnp arrays.

  Dicts become FrozenDicts at any depth; every other value (the nested number
  lists that hold a tensor) is converted with `jnp.array`. For the usual
  {module_name: {var_name: values}} layout this gives the same result as a
  hand-written two-level comprehension, but deeper trees load as well.
  """
  if isinstance(tree, dict):
    return FrozenDict({key: _to_frozen_arrays(value) for key, value in tree.items()})
  return jnp.array(tree)


params_load_path = 'checkpoints/cnn-params.json'
loaded_params_dict = _read_json(params_load_path)
params = _to_frozen_arrays(loaded_params_dict)

batch_stats_load_path = 'checkpoints/cnn-batch_stats.json'
loaded_batch_stats_dict = _read_json(batch_stats_load_path)
batch_stats = _to_frozen_arrays(loaded_batch_stats_dict)

model = CNN()

In [None]:
n_mels = 512  # number of mel bands (rows of each spectrogram)

def wav_to_mel_spec(path: str) -> np.ndarray:
  """Load an audio file and return its mel spectrogram in (absolute) dB.

  The mel spectrogram is converted to dB relative to its own peak
  (`ref=np.max`) and then passed through `np.abs`. The result is indexed as
  (mel band, frame).

  NOTE(review): `melspectrogram` yields a power spectrogram by default, for
  which `librosa.power_to_db` is the conventional conversion. The
  `amplitude_to_db` call is kept unchanged because the checkpoint was presumably
  trained on exactly this transform — confirm before changing.
  """
  signal, sample_rate = librosa.load(path)
  mel = librosa.feature.melspectrogram(y=signal, sr=sample_rate, n_mels=n_mels)
  mel_db = librosa.amplitude_to_db(mel, ref=np.max)
  return np.abs(mel_db)

In [None]:
# glob() returns matches in arbitrary, filesystem-dependent order. Sorting makes
# the selected clips (and the data_files[i] indices used later) reproducible.
data_files = sorted(glob('musicnet/musicnet/*/*.wav'))
if not data_files:
  raise FileNotFoundError('no .wav files found under musicnet/musicnet/*/')

n_frames = 512  # time frames per example; presumably the width the checkpoint was trained on


def _clip_to_input(path: str) -> np.ndarray:
  """Return the first `n_frames` spectrogram frames of `path` with a channel axis: (n_mels, n_frames, 1)."""
  spec_db = wav_to_mel_spec(path)
  if spec_db.shape[1] < n_frames:
    raise ValueError(f'{path} has only {spec_db.shape[1]} frames; need {n_frames}')
  return spec_db[:, :n_frames, np.newaxis]


data = jnp.array([_clip_to_input(path) for path in data_files[:4]])

In [None]:
cnn_variables = {'params': params, 'batch_stats': batch_stats}
logits = model.apply(cnn_variables, x=data, training=False)

# Highest-scoring class id for each clip, translated back to its ensemble name.
predicted_ids = logits.argmax(axis=1)
[nums_to_labels[int(class_id)] for class_id in predicted_ids]

In [None]:
sample_path = data_files[3]  # the fourth of the four clips classified above
ipd.Audio(sample_path)

In [None]:
duration = 60  # seconds of microphone audio to capture
sr = 22050     # capture rate in Hz; also used for the mel spectrogram below

recording = sd.rec(int(duration * sr), samplerate=sr, channels=1, dtype=np.float32)
sd.wait()  # sd.rec returns immediately; block until the buffer is filled
y = recording.reshape(-1)  # (samples, 1) -> (samples,) mono signal

# Same mel / dB transform as wav_to_mel_spec, applied to the live recording.
spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
x = np.abs(librosa.amplitude_to_db(spec, ref=np.max))

In [None]:
# Cut frames 512..1023 out of the recording's spectrogram and add batch and
# channel axes -> (1, n_mels, 512, 1).
# Guarded on ndim so re-running this cell leaves an already-prepared `x` alone
# (slicing it a second time would yield an empty window and a reshape error).
if x.ndim == 2:
  if x.shape[1] < 1024:
    raise ValueError(
      f'recording spectrogram has only {x.shape[1]} frames; need at least 1024')
  x = x[:, 512:1024][np.newaxis, :, :, np.newaxis]

In [None]:
# Classify the recorded clip `x` (a batch of one, shaped in the previous cell),
# not the MusicNet batch `data`.
logit = model.apply({
  'params': params,
  'batch_stats': batch_stats,
}, x=x, training=False)

# Single-example batch: take row 0's arg-max class id and map it to its label.
nums_to_labels[int(jnp.argmax(logit, axis=1)[0])]

In [None]:
# Play the microphone recording back, then wait until playback has finished.
sd.play(y, samplerate=sr)
sd.wait()