<a href="https://colab.research.google.com/github/cloughurd/deep-piano/blob/master/Wav2Mid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
## Approach adapted from https://github.com/jsleep/wav2mid
## Data: MAPS piano database, http://www.tsi.telecom-paristech.fr/aao/en/2010/07/08/maps-database-a-piano-database-for-multipitch-estimation-and-automatic-transcription-of-music/

In [15]:
from google.colab import drive

# Mount Google Drive so model checkpoints and data archives persist across Colab sessions.
drive.mount('/content/gdrive')
# Project folders on Drive: checkpoints are written to model_dir, archives are read from data_dir.
model_dir, data_dir = ('/content/gdrive/My Drive/Winter 2020/DL/models/',
                       '/content/gdrive/My Drive/Winter 2020/DL/data/')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
## Downloads the full MAPS dataset; wget saves it as `download`, which the next cell unzips
!wget https://amubox.univ-amu.fr/s/iNG0xc5Td1Nv4rR/download

## Alternative: downloads only a smaller MAPS subset (uncomment to use instead of the line above)
# !wget http://students.cs.byu.edu/~bclough/maps.zip

In [0]:
## Unzips the full maps dataset
!unzip -q download
!rm download
!mkdir data

import os
from zipfile import ZipFile

for filename in os.listdir('MAPS/'):
  if 'zip' in filename:
    with ZipFile('MAPS/' + filename, 'r') as z:
      z.extractall('data/' + filename.split('.')[0])

!rm -r MAPS/

In [0]:
## Alternative to the full download: unzip the MAPS subset archive stored on the mounted Drive into /content/
!unzip -q /content/gdrive/My\ Drive/Winter\ 2020/DL/data/maps.zip -d /content/

In [1]:
!pip3 install torch 
!pip3 install torchvision
!pip3 install tqdm
!pip3 install pysoundfile



In [0]:
import glob
import os
import random

import librosa
import numpy as np
import pandas as pd
import pretty_midi
import soundfile as sf
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

assert torch.cuda.is_available()

In [0]:
# MIDI note numbers of the lowest and highest piano keys (A0 = 21, C8 = 108): 88 keys in total.
lowest_key = 21
highest_key = 108
octave_size = 12  # semitones per octave; CQT bins per octave = bin_multiple * octave_size
desired_sr = 22050  # sample rate (Hz) that audio is resampled to when loaded
window_size = 7  # consecutive CQT frames per input example, centred on the target frame (edges padded by window_size//2)
# Raise pretty_midi's MAX_TICK limit (it rejects MIDI files whose largest tick exceeds it) so long files still load.
pretty_midi.pretty_midi.MAX_TICK = 1e10

def wav_to_input(fn, bin_multiple=3):
  """Load an audio file and turn it into per-frame CQT context windows.

  Args:
    fn: path of the audio file to load.
    bin_multiple: CQT bins per semitone. The CQT starts at `lowest_key`, covers the
      keys lowest_key..highest_key, and uses `bin_multiple * octave_size` bins per octave.

  Returns:
    Array of shape (num_frames, window_size, num_bins). Row t holds the CQT magnitudes
    of the `window_size` frames centred on frame t. The edges are padded with the
    smallest magnitude in the piece, so every frame gets a full window.
  """
  bins_per_octave = bin_multiple * octave_size
  num_bins = (highest_key + 1 - lowest_key) * bin_multiple

  # sr is passed by keyword: librosa >= 0.10 makes it keyword-only.
  audio, _ = librosa.load(fn, sr=desired_sr)
  cqt = librosa.cqt(audio, sr=desired_sr, fmin=librosa.midi_to_hz(lowest_key),
                    bins_per_octave=bins_per_octave, n_bins=num_bins)
  del audio
  magnitudes = np.abs(cqt.T)  # time dimension first: (frames, bins)
  del cqt

  half = window_size // 2
  padded = np.pad(magnitudes, ((half, half), (0, 0)), 'constant',
                  constant_values=np.min(magnitudes))

  # Gather every window in one fancy-indexing step: row t of `index` is [t, ..., t+window_size-1].
  num_frames = len(padded) - window_size + 1
  index = np.arange(num_frames)[:, None] + np.arange(window_size)[None, :]
  return padded[index]

def midi_to_output(midi, x):
  """Build the binary piano-roll target aligned with the frames of `x`.

  Args:
    midi: a loaded `pretty_midi.PrettyMIDI` object.
    x: the network input for the same piece (from `wav_to_input`). Only `len(x)`,
      the number of frames, is used.

  Returns:
    Array of shape (len(x), highest_key - lowest_key + 1): 1 where a key sounds
    during a frame, 0 otherwise.
  """
  # Frame times. Neither this call nor wav_to_input's librosa.cqt sets hop_length, so
  # both use librosa's default and the frames line up. sr is keyword-only in librosa >= 0.10.
  times = librosa.frames_to_time(np.arange(len(x)), sr=desired_sr)
  roll = midi.get_piano_roll(fs=desired_sr, times=times)  # rows indexed by MIDI pitch
  roll = roll[lowest_key:highest_key + 1].T  # keep the piano keys only, time dimension first
  # Binarize without mutating: any positive velocity counts as "key on".
  return (roll > 0).astype(roll.dtype)

In [0]:
class MapsDataset(Dataset):
  """MAPS piano recordings paired with their MIDI transcriptions.

  Each item is one whole piece, or a random `chunk_size`-frame crop of it if longer.
  `x` holds the per-frame CQT windows from `wav_to_input`; `y` holds the matching
  binary piano roll from `midi_to_output`.
  """
  def __init__(self, root, chunk_size=600, subset=True):
    """
    Args:
      root: directory prefix, with a trailing slash, that the glob pattern is appended to.
      chunk_size: maximum number of frames per item; longer pieces are randomly cropped.
      subset: if True, read `root*.wav` (flat subset layout). Otherwise read the full
        MAPS layout `root*/*/MUS/MAPS_MUS*.wav`.
    """
    pattern = '*.wav' if subset else '*/*/MUS/MAPS_MUS*.wav'
    # Sorted so the file order (and therefore item indices) is reproducible.
    self.wav_files = sorted(glob.glob(root + pattern))
    self.chunk_size = chunk_size

  @staticmethod
  def _midi_filename(wav_path):
    """Return the .mid path that sits next to `wav_path` (same stem, extension swapped)."""
    # splitext drops only the final extension. split('.')[0] broke on roots such as './data/'.
    return os.path.splitext(wav_path)[0] + '.mid'

  def __getitem__(self, i):
    wav_path = self.wav_files[i]
    x = wav_to_input(wav_path)
    y = midi_to_output(pretty_midi.PrettyMIDI(self._midi_filename(wav_path)), x)
    if len(y) <= self.chunk_size:
      return x, y
    # randint is inclusive, so the crop may end exactly at the last frame.
    start = random.randint(0, len(y) - self.chunk_size)
    stop = start + self.chunk_size
    return x[start:stop], y[start:stop]

  def __len__(self):
    return len(self.wav_files)

In [0]:
# Each batch is one whole (possibly cropped) piece; train() squeezes away this batch dim of 1.
dataset = MapsDataset(root='MAPS_MUS/')
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [0]:
class ConvBlock(nn.Module):
  """Residual block of three conv layers with a (possibly projected) skip connection.

  Main path: conv-BN-ReLU, conv-BN-ReLU, conv-BN-Dropout2d. Skip path: identity, or a
  1x1 conv when the channel count changes. The sum goes through a final ReLU. With the
  defaults (kernel_size=3, padding=1) the spatial size is preserved.
  """
  def __init__(self, in_c, out_c, kernel_size=3, padding=1):
    super().__init__()
    self.net = nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
        nn.Conv2d(out_c, out_c, kernel_size=kernel_size, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
        nn.Conv2d(out_c, out_c, kernel_size=kernel_size, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.Dropout2d()
    )
    # A 1x1 conv matches channel counts so the residual sum is well-defined.
    self.skip = nn.Conv2d(in_c, out_c, kernel_size=1) if in_c != out_c else nn.Identity()
    self.final = nn.ReLU()

  def forward(self, x):
    return self.final(self.skip(x) + self.net(x))


class Net(nn.Module):
  """Frame-wise piano transcriber: a stack of ConvBlocks, global pooling, then a per-key sigmoid.

  Expects input of shape (batch, 1, H, W). In this notebook that is
  (frames, 1, window_size, num_cqt_bins). Returns (batch, num_keys) probabilities.

  Args:
    num_keys: number of output keys (88 = the full piano range).
  """
  def __init__(self, num_keys=88):
    super().__init__()
    self.net = nn.Sequential(
        ConvBlock(1, 3),
        ConvBlock(3, 8),
        ConvBlock(8, 8),
        ConvBlock(8, 16),
        nn.MaxPool2d((1, 2)),  # halve the frequency axis only; the time window is kept
        ConvBlock(16, 16),
        ConvBlock(16, 32),
        ConvBlock(32, 32),
        ConvBlock(32, 64),
        nn.MaxPool2d((1, 2)),
        ConvBlock(64, 64),
        ConvBlock(64, 64),
        ConvBlock(64, 64),
        nn.MaxPool2d((1, 2)),
        ConvBlock(64, 128),
        ConvBlock(128, 128),
        # Adaptive pooling to 1x1 replaces the hard-coded AvgPool2d((7, 33)). It gives the
        # same result for 7x264 inputs and also works for other window sizes / bin counts.
        nn.AdaptiveAvgPool2d(1)
    )
    self.final = nn.Sequential(
        nn.Linear(128, num_keys),
        nn.Sigmoid()
    )

  def forward(self, x):
    features = self.net(x).flatten(1)  # (batch, 128, 1, 1) -> (batch, 128)
    return self.final(features)

In [0]:
net = Net().cuda()

optimizer = torch.optim.Adam(net.parameters(), lr=3e-3)
# Net's last layer is a Sigmoid, so plain BCELoss on probabilities is the matching criterion.
# reduction='sum' adds the loss over every frame and key, so its magnitude grows with chunk length.
objective = nn.BCELoss(reduction='sum')
losses = []  # per-batch training losses; train() appends to this list

In [0]:
def train(num_epochs=25, save_freq=1):
  """Train the notebook-global `net` for `num_epochs` passes over `loader`.

  Uses globals defined in earlier cells: `net`, `loader`, `optimizer`, `objective`,
  `losses` and `model_dir`.

  Args:
    num_epochs: number of epochs (full passes over `loader`).
    save_freq: save the whole model to `model_dir` on every epoch whose index is
      divisible by `save_freq`. The last epoch is always saved as well.

  Returns:
    The global `losses` list, with one loss value appended per batch.
  """
  net.train()  # BatchNorm / Dropout2d must be in training mode even if the net was evaluated before
  for epoch in range(num_epochs):
    # The context manager closes the progress bar at the end of every epoch.
    with tqdm(loader, total=len(loader), position=0, leave=False) as progress:
      for x, y in progress:
        # Each batch is one piece: x (1, T, window, bins) -> (T, 1, window, bins), so the T frames
        # form the batch and Conv2d sees a single input channel; y (1, T, keys) -> (T, keys).
        x = x.squeeze(0).unsqueeze(1).float().cuda()
        y = y.squeeze(0).float().cuda()

        optimizer.zero_grad()
        loss = objective(net(x), y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        progress.set_description('epoch:{}, loss:{:.4f}'.format(epoch, losses[-1]))

    if epoch % save_freq == 0 or epoch == num_epochs - 1:
      torch.save(net, model_dir + 'transcriber' + str(epoch) + '.mod')
  return losses

In [0]:
# plt is not imported in the imports cell, so without this the plot below raises NameError.
import matplotlib.pyplot as plt

train()

# Per-batch training loss, in the order the batches were seen.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='Training loss', xlabel='batch', ylabel='loss')
plt.show()

epoch:0, loss:7455.4170:  92%|█████████▏| 138/150 [58:10<05:58, 29.88s/it]