In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
# Useful parameters for reproducibility
# SEED = 1234
# torch.manual_seed(SEED)
# torch.backends.cudnn.deterministic = True
from sklearn.model_selection import train_test_split

import random, os
from tqdm import tqdm

from torchvision import datasets, transforms, utils
from os import walk
from os.path import join, normpath
import pretty_midi

from math import floor
import time
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import pandas as pd

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

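Here, we define our dataset class. It lists every MIDI file under a directory, converts each file to a piano-roll tensor (time steps × pitches) the first time it is accessed, and caches the result.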

In [None]:
class PianoMusic(data.Dataset):
  """Dataset of MIDI files, each exposed as a (time_steps, pitches) piano-roll tensor.

  Every file found (recursively) under ``midi_dir`` becomes one item. Files are
  parsed with ``pretty_midi`` the first time they are accessed and the resulting
  tensor is cached in ``self.loaded``, so each file is decoded only once per
  dataset instance. ``transform`` (if any) is applied on every access, on top
  of the cached tensor.

  Args:
    midi_dir: Root directory to scan. ``None`` builds an empty dataset (used by
      ``_create_from_self``).
    transform: Optional callable applied to each item when it is returned.
  """

  def __init__(self, midi_dir=None, transform=None):
    super().__init__()
    if midi_dir is None:  # Empty dataset
      self.file_list = []
    else:
      # os.path.join builds valid paths on every OS (a hard-coded "\" separator
      # only works on Windows). Sorting makes the order - and therefore any
      # seeded split - independent of the filesystem's listing order.
      self.file_list = sorted(
        os.path.join(root, f_name)
        for root, _, f_names in os.walk(midi_dir)
        for f_name in f_names
      )
    # Decoded piano rolls, filled lazily by __getitem__ (None = not loaded yet).
    self.loaded = [None] * len(self.file_list)
    self.transform = transform

  def __len__(self):
    return len(self.file_list)

  def __getitem__(self, index):
    """Return item ``index`` as a (time_steps, pitches) tensor, transformed if configured.

    Raises:
      IndexError: if ``index`` is out of range.
    """
    # Only the list lookups are guarded, so an IndexError raised while parsing
    # a file or inside ``transform`` is not misreported as a missing item.
    try:
      music = self.loaded[index]
      file_name = self.file_list[index]
    except IndexError:
      raise IndexError("Item does not exist, have you loaded the MIDI files correctly?") from None

    if music is None:
      roll = torch.Tensor(pretty_midi.PrettyMIDI(file_name).get_piano_roll())
      # get_piano_roll() gives dim 0 = notes, dim 1 = time steps; we want time first.
      music = torch.transpose(roll, 0, 1)
      self.loaded[index] = music

    return music if self.transform is None else self.transform(music)

  def _create_from_self(self, file_list):
    """Return a new, not-yet-loaded dataset over ``file_list`` sharing this transform."""
    new_dataset = PianoMusic()
    new_dataset.file_list = file_list
    new_dataset.loaded = [None] * len(file_list)
    new_dataset.transform = self.transform
    return new_dataset

  def get_path(self, index):
    """Return the file path of item ``index``.

    Raises:
      IndexError: if ``index`` is out of range.
    """
    try:
      return self.file_list[index]
    except IndexError:
      raise IndexError("Item does not exist, have you loaded the MIDI files correctly?") from None

  def splits(self, test_size=0.3, random_state=None):
    """Split the files into (train, test) datasets with sklearn's ``train_test_split``.

    Args:
      test_size: Fraction (or absolute number) of files put in the test set.
      random_state: Forwarded to ``train_test_split``; pass an int for a
        reproducible split. ``None`` keeps the previous (unseeded) behaviour.
    """
    train_files, test_files = train_test_split(self.file_list, test_size=test_size, random_state=random_state)
    return self._create_from_self(train_files), self._create_from_self(test_files)

Here, we define a `collate_fn` and the transform applied to our data. A `collate_fn` tells the `DataLoader` how to merge individual samples into a batch. We need a custom one because not every tensor has the same length (some pieces of music are longer than others). Ours pads the shorter tensors with zeros so that every tensor in the batch has the length of the longest one, and it also returns the original lengths so that the model can ignore the padding. The transform binarises the piano roll, keeps only the 88 piano keys and normalises each time step so that its notes sum to one.

In [None]:
def collate_fn(batch):
  """Merge (time_steps, notes) tensors of different lengths into one padded batch.

  Returns:
    padded: time-major tensor (max_time_steps, batch_size, notes); shorter
      pieces are zero-padded at the end (``pad_sequence`` is not batch-first).
    lengths: 1-D tensor with the unpadded length of each piece, in batch order.
  """
  original_lengths = torch.tensor([piece.shape[0] for piece in batch])
  padded = pad_sequence(batch, padding_value=0)
  return padded, original_lengths

def transform(element):
  """Binarise a piano roll, keep the 88 piano keys and normalise every time step.

  ``element`` is a (time_steps, pitches) roll whose columns are MIDI pitches
  (presumably the 128 columns produced by ``PianoMusic``). Any non-zero velocity
  becomes 1, only columns 21..108 are kept, and each row is divided by its
  number of active notes so it sums to one. Silent rows stay all-zero.
  """
  active = (element != 0).float()[:, 21:109]  # 109 - 21 = 88 piano keys
  notes_per_step = active.sum(dim=1, keepdim=True)
  # Silent steps would divide by zero; their rows are all zeros, so dividing by 1 is a no-op.
  notes_per_step = notes_per_step.clamp(min=1)
  return active / notes_per_step

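Here are the helper functions we use to inspect and use the data. `plotPianoRoll` displays a piano roll (handy to "inspect" the data), `create_midi` writes a piano roll to a MIDI file, and `generate` extends a piano roll with a trained model. `plotPianoRoll` can only plot ONE piano roll at a time. Extending it to display several piano rolls in the same figure would be nice, but it is not a priority.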

In [None]:
def plotPianoRoll(pianoroll, title, figsize=(160, 60)):
    """Display one piano roll as a black-and-white image.

    Args:
        pianoroll: 2-D tensor laid out (time_steps, notes); it is transposed so
            time runs along the x axis and notes along the y axis.
        title: Figure title.
        figsize: Matplotlib figure size in inches. The default keeps the
            original (very large) canvas; pass a smaller size for quick looks.
    """
    # imshow converts its input to a NumPy array, which needs a CPU tensor
    # without autograd history.
    to_plot = torch.transpose(pianoroll.detach().cpu(), 0, 1)
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(to_plot, cmap='binary', interpolation='nearest')
    ax.set_title(title)
    ax.invert_yaxis()  # note index 0 at the bottom, higher notes above
    ax.set_xlabel('Nbr Timesteps')
    ax.set_ylabel('Note value')
    plt.show()
    # Release the (large) figure once displayed so repeated calls don't accumulate memory.
    plt.close(fig)

def create_midi(filename, pianoroll, instrument, velocity=90, time_step=0.01, lowest_pitch=21):
    """Write a piano roll to a single-instrument MIDI file.

    Each run of equal, non-zero consecutive values in a note's column becomes
    one MIDI note. Adjacent runs with different non-zero values therefore
    produce back-to-back notes. Binarise the roll first to get one note per
    sustained key.

    Args:
        filename: Output path passed to ``PrettyMIDI.write``.
        pianoroll: 2-D tensor laid out (time_steps, notes).
        instrument: General MIDI instrument name, e.g. 'Acoustic Grand Piano'.
        velocity: Velocity given to every note (the roll carries no dynamics).
        time_step: Duration in seconds of one roll row. 0.01 s means 100 rows
            per second. That presumably matches ``get_piano_roll()``'s default
            sampling rate used by ``PianoMusic``; verify if that call changes.
        lowest_pitch: MIDI pitch of note column 0; 21 matches the 88-key rolls
            produced by ``transform`` (columns 21..108).
    """
    pm = pretty_midi.PrettyMIDI()
    program = pretty_midi.instrument_name_to_program(instrument)
    track = pretty_midi.Instrument(program=program)

    # Walk the roll one key at a time (columns become rows after transposing).
    for key_index, key_row in enumerate(torch.transpose(pianoroll, 0, 1)):
        values, run_lengths = torch.unique_consecutive(key_row, return_counts=True)
        start = 0.0
        for value, run_length in zip(values.tolist(), run_lengths.tolist()):
            end = start + time_step * run_length
            if value != 0:
                track.notes.append(pretty_midi.Note(
                    velocity=velocity,
                    pitch=lowest_pitch + key_index,
                    start=start,
                    end=end,
                ))
            start = end

    pm.instruments.append(track)
    pm.write(filename)

def generate(network, first_chords, length, window_size=None):
    """Autoregressively extend a piano roll with ``network``'s predictions.

    Args:
        network: Callable taking a (time_steps, notes) roll and returning the
            row(s) to append, concatenated along dim 0. The models' ``forward``
            without ``lengths`` returns only the last step.
        first_chords: Seed roll, (time_steps, notes).
        length: Number of times ``network`` is called (new steps appended).
        window_size: If truthy, only the last ``window_size`` rows are fed to
            ``network``; otherwise the whole roll generated so far.

    Returns:
        The seed followed by the generated rows, detached from autograd.
    """
    pianoroll = first_chords
    # Inference only: recording an autograd graph for every step wastes memory.
    with torch.no_grad():
        for _ in range(length):
            context = pianoroll[-window_size:] if window_size else pianoroll
            next_step = network(context)
            pianoroll = torch.cat([pianoroll, next_step.detach()], dim=0)
    return pianoroll.detach()

In [None]:
class _RecurrentMusic(nn.Module):
    """Shared body of the next-step piano-roll models.

    Holds a time-major recurrent layer (``self.rnn``), a linear read-out
    (``self.fc``) and a softmax over the output notes (``self.output_softmax``).
    Subclasses only choose the recurrent layer; the attribute names are kept so
    previously saved ``state_dict`` files still load.
    """

    def __init__(self, rnn, hidden_dim, output_dim):
        super().__init__()
        self.rnn = rnn
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.output_softmax = nn.Softmax(dim=-1)

    def forward(self, musics, lengths=None):
        """Return softmax probabilities for the step following each sequence.

        Args:
            musics: Time-major input. Without ``lengths`` it is fed as-is and
                only the last time step is kept (result has a leading dim of 1).
                With ``lengths`` it is a zero-padded batch
                (max_time_steps, batch_size, note_dim).
            lengths: Unpadded length of each sequence in the batch, in any order.

        Returns:
            (batch_size, output_dim) in batch mode, otherwise (1, ..., output_dim).
        """
        if lengths is None:  # Single sequence, no padding
            output, _ = self.rnn(musics)
            output = output[-1:]
        else:  # Padded batch: pack it so the recurrence never sees padding
            packed_musics = nn.utils.rnn.pack_padded_sequence(musics, lengths, enforce_sorted=False)
            packed_output, _ = self.rnn(packed_musics)
            output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)
            # output is (longest_length, batch_size, hidden_dim). Pick, for each
            # sequence b, its own last valid step output[lengths[b] - 1, b];
            # using a single time index for the whole batch would read padding
            # (or a too-early step) for sequences of other lengths.
            last_steps = (output_lengths - 1).to(output.device)
            batch_index = torch.arange(output.size(1), device=output.device)
            output = output[last_steps, batch_index]
        return self.output_softmax(self.fc(output))


class LSTMMusic(_RecurrentMusic):
    """Next-step piano-roll model built on a (multi-layer) LSTM."""

    def __init__(self, note_dim, hidden_dim, output_dim, n_layers, dropout):
        super().__init__(nn.LSTM(note_dim, hidden_dim, num_layers=n_layers, dropout=dropout), hidden_dim, output_dim)


class GRUMusic(_RecurrentMusic):
    """Next-step piano-roll model built on a (multi-layer) GRU."""

    def __init__(self, note_dim, hidden_dim, output_dim, n_layers, dropout):
        super().__init__(nn.GRU(note_dim, hidden_dim, num_layers=n_layers, dropout=dropout), hidden_dim, output_dim)


class RNNMusic(_RecurrentMusic):
    """Next-step piano-roll model built on a vanilla (Elman) RNN.

    ``nonlinearity`` is passed to ``nn.RNN`` ('tanh' or 'relu').
    """

    def __init__(self, note_dim, hidden_dim, output_dim, n_layers, dropout, nonlinearity="tanh"):
        super().__init__(nn.RNN(note_dim, hidden_dim, num_layers=n_layers, dropout=dropout, nonlinearity=nonlinearity), hidden_dim, output_dim)

In [None]:
# Hyper-parameters shared by every model built in this notebook.
NOTE_DIM = OUTPUT_DIM = 88  # one input and one output unit per piano key (see `transform`)
HIDDEN_DIM = 128            # width of each recurrent layer
N_LAYERS = 3                # number of stacked recurrent layers
DROPOUT = 0.0               # dropout between recurrent layers; 0.0 disables it

In [None]:
def count_parameters(model):
    """Return how many trainable scalars (parameters with ``requires_grad``) ``model`` has."""
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total

In [None]:
def _prediction_steps(musics, musics_lengths, step_size, window_size, **progress_kwargs):
    """Yield ``(inputs, lengths, targets)`` for each next-step prediction in a padded batch.

    Target positions are ``t = 1, 1 + step_size, 1 + 2 * step_size, ...`` below
    ``len(musics)``. The context is ``musics[t - window_size:t]`` when
    ``window_size`` is truthy (positions with fewer than ``window_size`` earlier
    steps are skipped), otherwise the whole history ``musics[:t]``.

    Pieces whose unpadded length is ``<= t`` only have padding at ``t``. They are
    left out of that step, so the model is never trained or scored on predicting
    padding. Every kept piece therefore has a full, padding-free context of
    the same length.

    Args:
        musics: Zero-padded, time-major batch (max_time_steps, batch_size, notes).
        musics_lengths: 1-D tensor with the unpadded length of each piece.
        step_size: Distance between consecutive target positions.
        window_size: Context length, or a falsy value (None/False/0) for the whole history.
        **progress_kwargs: Extra ``tqdm`` options (e.g. ``position``).
    """
    for target_index in tqdm(range(1, len(musics), step_size), leave=False, **progress_kwargs):
        if window_size:
            if window_size > target_index:  # not enough earlier steps for a full window
                continue
            start = target_index - window_size
        else:
            start = 0
        keep = (musics_lengths > target_index).nonzero(as_tuple=True)[0]
        if len(keep) == 0:
            continue
        keep = keep.to(musics.device)
        inputs = musics[start:target_index].index_select(1, keep)
        targets = musics[target_index].index_select(0, keep)
        # Lengths stay on the CPU, as pack_padded_sequence expects.
        lengths = torch.full((len(keep),), target_index - start, dtype=torch.int64)
        yield inputs, lengths, targets


def train(model, dataloader, optimizer, criterion, step_size=4, window_size=None):
    """Run one training epoch of next-step prediction and return the mean step loss.

    Each batch is visited at several target positions (see ``_prediction_steps``)
    with one optimizer update per position. The loss uses ``criterion`` with its
    own reduction, computed only over pieces that have a real target.

    Args:
        model: Module called as ``model(inputs, lengths)``.
        dataloader: Yields ``(padded_musics, lengths)`` batches (see ``collate_fn``).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss module taking ``(predictions, targets)``.
        step_size: Distance between consecutive target positions.
        window_size: Context length; None/False/0 means the whole history.

    Returns:
        Average loss over all optimizer steps, or NaN if no step was run.
    """
    epoch_loss = 0.0
    step_count = 0

    model.train()

    for musics, musics_lengths in tqdm(dataloader, total=len(dataloader), leave=False):
        musics = musics.to(DEVICE)
        for inputs, lengths, targets in _prediction_steps(musics, musics_lengths, step_size, window_size, position=1):
            optimizer.zero_grad()
            predictions = model(inputs, lengths)
            loss = criterion(predictions, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            step_count += 1

    return epoch_loss / step_count if step_count else float('nan')


def evaluate(model, dataloader, criterion, step_size=4, window_size=None):
    """Return the mean next-step loss of ``model`` over ``dataloader`` without training.

    Uses exactly the same prediction steps as ``train`` (see
    ``_prediction_steps``), with the model in eval mode and gradients disabled.

    Returns:
        Average loss over all prediction steps, or NaN if no step was run.
    """
    epoch_loss = 0.0
    step_count = 0

    model.eval()

    with torch.no_grad():
        for musics, musics_lengths in tqdm(dataloader, total=len(dataloader), leave=False):
            musics = musics.to(DEVICE)
            for inputs, lengths, targets in _prediction_steps(musics, musics_lengths, step_size, window_size):
                predictions = model(inputs, lengths)
                loss = criterion(predictions, targets)
                epoch_loss += loss.item()
                step_count += 1

    return epoch_loss / step_count if step_count else float('nan')

In [None]:
def epoch_time(start_time, end_time):
    """Split the seconds between two ``time.time()`` stamps into whole (minutes, seconds)."""
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = elapsed - 60 * whole_minutes
    return whole_minutes, int(leftover_seconds)
  
def compute_epochs(model, train_dataloader, test_dataloader, criterion, optimizer, n_epochs, step_size=4, window_size=None, best_path=None, last_path=None, verbose=True):
  """Train ``model`` for ``n_epochs`` epochs, evaluating on the test set after each one.

  Args:
    model: Module trained in place. It is expected to already be on DEVICE,
      since ``train``/``evaluate`` move the batches there.
    train_dataloader, test_dataloader: Batches of ``(padded_musics, lengths)``.
    criterion: Loss module; moved to DEVICE here.
    optimizer: Optimizer over ``model``'s parameters.
    n_epochs: Number of epochs to run.
    step_size: Forwarded to ``train`` and ``evaluate``.
    window_size: Forwarded to ``train`` and ``evaluate``; ``None`` (the default)
      uses each piece's whole history as context.
    best_path: If given, the state_dict is saved there whenever the test loss
      reaches a new minimum.
    last_path: If given, the final state_dict is saved there.
    verbose: Print per-epoch time and losses.

  Returns:
    ``(model, train_losses, test_losses, epoch_times)``. Each list has one entry
    per epoch; ``epoch_times`` is in seconds.
  """
  best_test_loss = float('inf')

  criterion = criterion.to(DEVICE)

  train_losses = []
  test_losses = []
  epoch_times = []

  for epoch in range(n_epochs):
    start_time = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, step_size=step_size, window_size=window_size)
    test_loss = evaluate(model, test_dataloader, criterion, step_size=step_size, window_size=window_size)
    end_time = time.time()

    # Checkpoint the weights with the lowest test loss seen so far.
    if best_path and test_loss < best_test_loss:
      best_test_loss = test_loss
      torch.save(model.state_dict(), best_path)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    epoch_times.append(end_time - start_time)

    if verbose:
      epoch_mins, epoch_secs = epoch_time(start_time, end_time)
      print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
      print(f'\tTrain Loss: {train_loss}')
      print(f'\t Test Loss: {test_loss}')

  if last_path:
    torch.save(model.state_dict(), last_path)

  return model, train_losses, test_losses, epoch_times

class ModelProvider():
  """Build models lazily and cache them per architecture and index.

  Asking twice for the same architecture with the same ``index`` returns the
  same instance. New models are built with the module-level hyper-parameters
  (NOTE_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT) and moved to DEVICE.
  """

  def __init__(self):
    self.LSTMs = {}
    self.GRUs = {}
    self.RNNs = {}      # tanh RNNs
    # ReLU RNNs get their own cache: sharing ``self.RNNs`` with the tanh variant
    # made RNNrelu(i) return a cached tanh model whenever RNNtanh(i) ran first.
    self.RNNrelus = {}

  def LSTM(self, index):
    """Return the LSTMMusic cached under ``index``, building it on first use."""
    return self._fetch_model_(self.LSTMs, LSTMMusic, index)

  def GRU(self, index):
    """Return the GRUMusic cached under ``index``, building it on first use."""
    return self._fetch_model_(self.GRUs, GRUMusic, index)

  def RNNtanh(self, index):
    """Return the tanh RNNMusic cached under ``index``, building it on first use."""
    return self._fetch_model_(self.RNNs, RNNMusic, index, nonlinearity='tanh')

  def RNNrelu(self, index):
    """Return the ReLU RNNMusic cached under ``index``, building it on first use."""
    return self._fetch_model_(self.RNNrelus, RNNMusic, index, nonlinearity='relu')

  def _fetch_model_(self, models, ModelClass, index, **kwargs):
    """Return ``models[index]``, creating it with ``ModelClass`` (on DEVICE) if missing."""
    if index not in models:
      model = ModelClass(NOTE_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT, **kwargs)
      models[index] = model.to(DEVICE)  # Module.to moves in place and returns the module
    return models[index]

def _save_generation(model, sample, window_size, path_prefix, seed_steps, generated_steps, threshold):
  """Generate from the start of ``sample``; save the binarised roll to ``<path_prefix>.csv`` and ``.mid``."""
  roll = generate(model, sample[0:seed_steps], generated_steps, window_size)
  # Cells whose predicted value reaches ``threshold`` become played notes.
  pianoroll = torch.where(roll >= threshold, 1, 0)
  pd.DataFrame(pianoroll.detach().numpy()).to_csv(f'{path_prefix}.csv')
  create_midi(f'{path_prefix}.mid', pianoroll, 'Acoustic Grand Piano')


def generate_models(train_dataloader, test_dataloader, n_epochs=8, sample_dir="./js/mini", seed_steps=1000, generated_steps=1000, threshold=1/16):
  """Train every combination of the configured settings and save the results to disk.

  For each combination of model / loss / optimizer / step size / window size:
  train with ``compute_epochs`` (best and last weights saved), write the
  per-epoch losses and times to CSV, then generate a piece from a random seed
  with both the last and the best weights (saved as CSV and MIDI). Everything
  goes to a new ``models-generation-<n_epochs>epochs-<timestamp>`` directory.

  Args:
    train_dataloader, test_dataloader: Batches of ``(padded_musics, lengths)``.
    n_epochs: Epochs per combination.
    sample_dir: Directory of MIDI files the generation seeds are drawn from.
    seed_steps: Number of leading steps of the sampled piece given to the model.
    generated_steps: Number of generation steps.
    threshold: Generated cells ``>= threshold`` are kept as notes.
  """
  mp = ModelProvider()
  # Full search space (uncomment to widen the grid):
  # model_funcs = {'RNN-tanh': mp.RNNtanh, 'RNN-relu': mp.RNNrelu, 'LSTM': mp.LSTM, 'GRU': mp.GRU}
  # criterions = {'BCE': nn.BCELoss(), 'MSE': nn.MSELoss(), 'CE': nn.CrossEntropyLoss()}
  model_funcs = {'RNN-relu': mp.RNNrelu}
  criterions = {'MSE': nn.MSELoss()}
  optimizer_funcs = {'ADAM': optim.Adam}
  step_sizes = {'s100': 100, 's200': 200, 's400': 400, 's800': 800}
  window_sizes = {'wNone': None, 'w100': 100, 'w200': 200, 'w400': 400, 'w800': 800}

  # One run per combination, each with its own freshly built model (distinct index).
  params_list = []
  for k0, model_func in model_funcs.items():
    for k1, criterion in criterions.items():
      for k2, optimizer_func in optimizer_funcs.items():
        for k3, step_size in step_sizes.items():
          for k4, window_size in window_sizes.items():
            model = model_func(len(params_list))
            params_list.append({
              'name': f'{k0}-{k1}-{k2}-{k3}-{k4}',
              'model': model,
              'criterion': criterion,
              'optimizer': optimizer_func(model.parameters()),
              'step_size': step_size,
              'window_size': window_size,
            })

  directory = f'models-generation-{n_epochs}epochs-{time.time()}'
  os.makedirs(directory)

  # Built once (loop-invariant); it caches decoded files across iterations.
  sample_dataset = PianoMusic(sample_dir, transform=transform)

  for i, params_row in enumerate(params_list):
    name = params_row['name']
    model_dir = os.path.join(directory, name)
    best_path = os.path.join(model_dir, f'{name}-best.pt')
    print(f'Processing model {i+1}/{len(params_list)} {name}')
    os.makedirs(model_dir)

    model, train_losses, test_losses, epoch_times = compute_epochs(
      params_row['model'], train_dataloader, test_dataloader, params_row['criterion'], params_row['optimizer'], n_epochs,
      step_size=params_row['step_size'],
      window_size=params_row['window_size'],
      best_path=best_path,
      last_path=os.path.join(model_dir, f'{name}-last.pt'),
      verbose=True)

    # Saving losses and times to CSV
    pd.DataFrame.from_dict({'train_loss': train_losses, 'test_loss': test_losses, 'epoch_time': epoch_times}).to_csv(os.path.join(model_dir, f'{name}.csv'))

    # Same random seed piece for the last and the best weights of this model.
    random_sample = sample_dataset[random.randrange(len(sample_dataset))]

    # Generation with the final weights.
    model = model.cpu()
    _save_generation(model, random_sample, params_row['window_size'], os.path.join(model_dir, f'{name}-sample-last'), seed_steps, generated_steps, threshold)

    # Generation with the best weights (lowest test loss). The model now lives
    # on the CPU, so load the checkpoint there even if it was saved from a GPU.
    model.load_state_dict(torch.load(best_path, map_location='cpu'))
    _save_generation(model, random_sample, params_row['window_size'], os.path.join(model_dir, f'{name}-sample-best'), seed_steps, generated_steps, threshold)

# Seed every RNG the pipeline draws from (file split, batch shuffling, weight
# initialisation, seed-piece sampling) so repeated runs on the same files match.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32

dataset = PianoMusic("./js/all", transform=transform)

train_dataset, test_dataset = dataset.splits()

# num_workers=0 keeps PianoMusic's in-memory cache alive across epochs
# (non-persistent worker processes would each fill, then drop, their own copy).
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=0)
# No shuffling for evaluation: fixed batches (and therefore fixed padding) keep
# the test loss comparable across epochs, which best-checkpoint selection relies on.
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0)

generate_models(train_loader, test_loader, n_epochs=25)