# Symbolic Music Generation

The goal of this homework is to understand a simple RNN-based musical language model that generates performance MIDI.

## Download the MIDI file

We are going to use a single MIDI file from Saarland Music Data (SMD)

In [None]:
# # Download the MIDI files
# !gdown 1ORJ5ZYLYkL4NYLJXtcElKQRstxKL13NU

In [None]:
# %%capture
# !unzip gct634-SMD-MIDI.zip

## Check the MIDI files

Install python packages to handle MIDI

In [None]:
# %%capture
# !sudo apt install -y fluidsynth
# !pip install --upgrade pyfluidsynth
# !pip install pretty_midi

Open one of the MIDI files and inspect its note events.

In [None]:
import pandas as pd
import pretty_midi
import collections
import numpy as np

def midi_to_notes(pm) -> pd.DataFrame:
  """Flatten the first instrument of a PrettyMIDI-like object into a note table.

  Notes are ordered by onset time. Each row holds the note's `pitch`, `start`,
  `end`, `step` (onset minus the previous note's onset, 0 for the first note),
  `duration` (`end - start`) and `velocity`.

  Args:
    pm: object exposing `instruments[0].notes`, where every note has `pitch`,
      `start`, `end` and `velocity` attributes.

  Returns:
    A DataFrame with one row per note. If the instrument has no notes, an
    empty DataFrame with the same six columns is returned (previously this
    raised IndexError).
  """
  columns = ['pitch', 'start', 'end', 'step', 'duration', 'velocity']

  # Sort the notes by start time
  sorted_notes = sorted(pm.instruments[0].notes, key=lambda note: note.start)
  if not sorted_notes:
    # nothing to tabulate, e.g. a generated sequence without any note
    return pd.DataFrame(columns=columns)

  notes = collections.defaultdict(list)
  prev_start = sorted_notes[0].start

  for note in sorted_notes:
    start = note.start
    end = note.end
    notes['pitch'].append(note.pitch)
    notes['start'].append(start)
    notes['end'].append(end)
    notes['step'].append(start - prev_start)
    notes['duration'].append(end - start)
    notes['velocity'].append(note.velocity)
    prev_start = start

  return pd.DataFrame({name: np.array(notes[name]) for name in columns})

# Load one piece of the dataset and tabulate the notes of its first instrument.
midi_file = 'gct634-SMD-MIDI/Bach_BWV888-01_008_20110315-SMD.mid'
pm = pretty_midi.PrettyMIDI(midi_file)
raw_notes = midi_to_notes(pm)
# Last expression of the cell: the notebook displays the first 10 rows.
raw_notes.head(10)

Let's listen to the MIDI file.

In [None]:
import fluidsynth
import pretty_midi
import IPython.display as ipd

# Sampling rate passed to the synthesizer and to the audio player below.
sr = 44100
pm = pretty_midi.PrettyMIDI(midi_file)
synth_audio = pm.fluidsynth(fs=sr)

# play the 10-sec segment from the beginning (10 * sr samples)
synth_audio_seg = synth_audio[:10*sr]
ipd.Audio(synth_audio_seg, rate=sr)

## MIDI-Like Tokenizer

"EventSeq" is the core class that implements the MIDI-Like Tokenizer. It takes MIDI note messages and tokenizes them into four classes of tokens (NOTE_ON, NOTE_OFF, VELOCITY, and TIME_SHIFT). The code is quite long, so let's try to understand it little by little.

In [None]:
from pretty_midi import PrettyMIDI, Note, Instrument

# When truthy, EventSeq emits 'velocity' events and reserves a token range for them.
USE_VELOCITY = 1
# Velocity given to decoded notes until the first 'velocity' event is seen.
DEFAULT_VELOCITY = 64
# Length given to a decoded note that never receives a matching 'note_off'
# (presumably seconds, like pretty_midi note times -- confirm).
DEFAULT_NOTE_LENGTH = 1
# Lower bound on the length of a decoded note that is closed by a 'note_off'.
MIN_NOTE_LENGTH = 0.05

class Event:
    """A single tokenizer event: a type tag, a timestamp and a value."""

    def __init__(self, type, time, value):
        # plain attribute container; EventSeq reads and rewrites these fields
        self.type, self.time, self.value = type, time, value

class EventSeq:
    """MIDI-like event tokenizer.

    A list of notes is converted into a chronological stream of `Event`s of
    four types ('note_on', 'note_off', 'velocity', 'time_shift'). Each event
    maps to one integer token; the per-type token ranges are laid out back to
    back in the order returned by `feat_dims()`.
    """

    # MIDI pitches that are tokenized; notes outside this range are dropped
    pitch_range = range(21, 109)
    # velocities are clamped into this range before being quantized
    velocity_range = range(21, 109)
    # size of the 'velocity' token range
    velocity_steps = 32
    # 32 geometrically growing time-shift sizes
    # (presumably seconds, like pretty_midi note times -- confirm)
    time_shift_bins = 1.15 ** np.arange(32) / 65

    @staticmethod
    def from_note_seq(notes):
        """Build an EventSeq from notes exposing pitch/velocity/start/end.

        Every in-range note contributes an optional 'velocity' event and a
        'note_on' event at its start, plus a 'note_off' event at its end.
        Gaps between consecutive events are filled with 'time_shift' events.
        """
        note_events = []
        if USE_VELOCITY:
            velocity_bins = EventSeq.get_velocity_bins()

        for note in notes:
            if note.pitch in EventSeq.pitch_range:
                if USE_VELOCITY:
                    # clamp into velocity_range, then quantize to a bin index
                    velocity = note.velocity
                    velocity = max(velocity, EventSeq.velocity_range.start)
                    velocity = min(velocity, EventSeq.velocity_range.stop - 1)
                    velocity_index = np.searchsorted(velocity_bins, velocity)
                    note_events.append(Event('velocity', note.start, velocity_index))

                pitch_index = note.pitch - EventSeq.pitch_range.start
                note_events.append(Event('note_on', note.start, pitch_index))
                note_events.append(Event('note_off', note.end, pitch_index))

        # stable sort: events sharing a timestamp keep their insertion order
        note_events.sort(key=lambda event: event.time)
        events = []

        for i, event in enumerate(note_events):
            events.append(event)

            if event is note_events[-1]:
                break

            interval = note_events[i + 1].time - event.time
            shift = 0

            # Greedily cover the gap to the next event with the largest bin
            # that still fits; a remainder below the smallest bin is dropped.
            while interval - shift >= EventSeq.time_shift_bins[0]:
                index = np.searchsorted(EventSeq.time_shift_bins,
                                        interval - shift, side='right') - 1
                events.append(Event('time_shift', event.time + shift, index))
                shift += EventSeq.time_shift_bins[index]

        return EventSeq(events)

    @staticmethod
    def from_array(event_indeces):
        """Decode an iterable of token indices back into an EventSeq.

        Indices that fall outside every range of `feat_ranges()` are skipped.
        """
        time = 0
        events = []
        for event_index in event_indeces:
            for event_type, feat_range in EventSeq.feat_ranges().items():
                if feat_range.start <= event_index < feat_range.stop:
                    event_value = event_index - feat_range.start
                    events.append(Event(event_type, time, event_value))
                    if event_type == 'time_shift':
                        time += EventSeq.time_shift_bins[event_value]
                    break

        return EventSeq(events)

    @staticmethod
    def dim():
        """Return the vocabulary size (sum of all per-type token counts)."""
        return sum(EventSeq.feat_dims().values())

    @staticmethod
    def feat_dims():
        """Return an ordered mapping event type -> number of tokens of that type."""
        feat_dims = collections.OrderedDict()
        feat_dims['note_on'] = len(EventSeq.pitch_range)
        feat_dims['note_off'] = len(EventSeq.pitch_range)
        if USE_VELOCITY:
            feat_dims['velocity'] = EventSeq.velocity_steps
        feat_dims['time_shift'] = len(EventSeq.time_shift_bins)
        return feat_dims

    @staticmethod
    def feat_ranges():
        """Return an ordered mapping event type -> contiguous `range` of token indices."""
        offset = 0
        feat_ranges = collections.OrderedDict()
        for feat_name, feat_dim in EventSeq.feat_dims().items():
            feat_ranges[feat_name] = range(offset, offset + feat_dim)
            offset += feat_dim
        return feat_ranges

    @staticmethod
    def get_velocity_bins():
        """Return the evenly spaced velocity bin edges used for quantization."""
        n = EventSeq.velocity_range.stop - EventSeq.velocity_range.start
        return np.arange(
            EventSeq.velocity_range.start,
            EventSeq.velocity_range.stop,
            n / (EventSeq.velocity_steps - 1))

    def __init__(self, events=None):
        """Store a deep copy of `events` and recompute their timestamps.

        Args:
            events: iterable of `Event`; `None` (the default) means an empty
                sequence. A `None` default replaces the former mutable `[]`.
        """
        # Local import: the cell that defines this class does not import
        # `copy`; it is only imported by a later cell of the notebook.
        import copy

        events = [] if events is None else events
        for event in events:
            assert isinstance(event, Event)

        self.events = copy.deepcopy(events)

        # compute event times again, from the accumulated time shifts
        time = 0
        for event in self.events:
            event.time = time
            if event.type == 'time_shift':
                time += EventSeq.time_shift_bins[event.value]

    def to_note_seq(self):
        """Replay the events and return the resulting list of `Note` objects.

        A 'note_off' closes the most recent open note of the same pitch, which
        is kept at least MIN_NOTE_LENGTH long. Notes that are never closed get
        DEFAULT_NOTE_LENGTH. Velocities are converted to int at the end.
        """
        time = 0
        notes = []

        velocity = DEFAULT_VELOCITY
        velocity_bins = EventSeq.get_velocity_bins()

        # pitch -> most recent note of that pitch still waiting for a note_off
        last_notes = {}

        for event in self.events:
            if event.type == 'note_on':
                pitch = event.value + EventSeq.pitch_range.start
                note = Note(velocity, pitch, time, None)
                notes.append(note)
                last_notes[pitch] = note

            elif event.type == 'note_off':
                pitch = event.value + EventSeq.pitch_range.start

                if pitch in last_notes:
                    note = last_notes[pitch]
                    note.end = max(time, note.start + MIN_NOTE_LENGTH)
                    del last_notes[pitch]

            elif event.type == 'velocity':
                # the index may equal len(velocity_bins); clamp to the last bin
                index = min(event.value, velocity_bins.size - 1)
                velocity = velocity_bins[index]

            elif event.type == 'time_shift':
                time += EventSeq.time_shift_bins[event.value]

        for note in notes:
            if note.end is None:
                note.end = note.start + DEFAULT_NOTE_LENGTH

            note.velocity = int(note.velocity)

        return notes

    def to_array(self):
        """Encode the events as a 1-D numpy array of token indices.

        The dtype is uint8 when the vocabulary fits in 256 values, else uint16.
        """
        feat_idxs = EventSeq.feat_ranges()
        idxs = [feat_idxs[event.type][event.value] for event in self.events]
        dtype = np.uint8 if EventSeq.dim() <= 256 else np.uint16
        return np.array(idxs, dtype=dtype)


Let's load the MIDI file and tokenize it.

In [None]:
import pretty_midi
import copy, itertools, collections

# Load the example piece and sort the notes of its first instrument by onset.
midi_file = 'gct634-SMD-MIDI/Bach_BWV888-01_008_20110315-SMD.mid'
pm = pretty_midi.PrettyMIDI(midi_file)
instrument = pm.instruments[0]
sorted_notes = sorted(instrument.notes, key=lambda note: note.start)

# shift the first note to 0 sec
time_shift = sorted_notes[0].start
for note in sorted_notes:
  note.start -= time_shift
  note.end -= time_shift

# print out the first (up to) 10 notes for checking; slicing avoids an
# IndexError when the piece has fewer than 10 notes
for note in sorted_notes[:10]:
  print(note)

# translates the note sequences into the four classes (NOTE_ON, NOTE_OFF, VELOCITY, and TIME_SHIFT) of events.
event_seq = EventSeq.from_note_seq(sorted_notes)

# encodes the four classes of events into token indices.
event_data = event_seq.to_array()

print("--- tokens ---")
print(event_data[:20])

Let's convert the token indices back to MIDI note events.   

In [None]:
# Decode the token indices back into events, then into notes.
event_seq = EventSeq.from_array(event_data)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# print the first (up to) 10 notes; slicing avoids an IndexError on short sequences
for note in output_notes[:10]:
  print(note)

The MIDI note events have changed a little. Why? Let's listen to them. It's quite subtle.  

In [None]:
# Settings used whenever decoded notes are wrapped into a new PrettyMIDI object.
DEFAULT_RESOLUTION = 220
DEFAULT_TEMPO = 120
DEFAULT_SAVING_PROGRAM = 1

# make a prettyMIDI object holding a single instrument with the decoded notes
midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# deep copy so the instrument does not share Note objects with output_notes
inst.notes = copy.deepcopy(output_notes)
midi.instruments.append(inst)

# synthesize and play the first 10 * sr samples
synth_audio = midi.fluidsynth(fs=sr)
synth_audio_seg = synth_audio[:10*sr]
ipd.Audio(synth_audio_seg, rate=sr)

## Preprocessing MIDI files
Once you understand the tokenization, let's store each tokenized MIDI file as a separate token file (".data").

In [None]:
import os
import random
import torch

def find_files_by_extensions(root, exts=None):
    """Yield the paths of files under `root` whose name ends with one of `exts`.

    Args:
        root: directory that is walked recursively with `os.walk`.
        exts: iterable of suffixes such as ['.mid'], or a single suffix
            string. File names are lower-cased before the comparison, so
            suffixes should be given in lower case. `None` or an empty
            iterable matches every file.

    Yields:
        `os.path.join(directory, name)` for every matching file.
    """
    if isinstance(exts, str):
        # a bare string would otherwise be iterated character by character
        exts = (exts,)
    suffixes = tuple(exts) if exts else ()

    for path, _, files in os.walk(root):
        for name in files:
            if not suffixes or name.lower().endswith(suffixes):
                yield os.path.join(path, name)

# Load every MIDI file, tokenize it and store the tokens as a ".data" file
midi_path = 'gct634-SMD-MIDI/'
midi_files = list(find_files_by_extensions(midi_path, ['.mid']))

save_path = 'gct634-SMD-data/'
os.makedirs(save_path, exist_ok=True)
out_fmt = '{}.data'

for midi_file in midi_files:
  pm = pretty_midi.PrettyMIDI(midi_file)
  instrument = pm.instruments[0]
  sorted_notes = sorted(instrument.notes, key=lambda note: note.start)

  # shift the first note to 0 sec
  time_shift = sorted_notes[0].start
  for note in sorted_notes:
    note.start -= time_shift
    note.end -= time_shift

  # translates the note sequences into the four classes (NOTE_ON, NOTE_OFF, VELOCITY, and TIME_SHIFT) of events.
  event_seq = EventSeq.from_note_seq(sorted_notes)

  # encodes the four classes of events into token indices.
  event_data = event_seq.to_array()

  # save the tokenized file; splitext strips only the final extension, so
  # names that contain extra dots are neither truncated nor made to collide
  name = os.path.splitext(os.path.basename(midi_file))[0]
  save_name = os.path.join(save_path,out_fmt.format(name))
  torch.save(event_data, save_name)


## Dataloader

In [None]:
from torch.utils.data import Dataset

class EventDataset(Dataset):
    """Dataset that serves one random fixed-length token window per '.data' file."""

    def __init__(self, window_size):
        super().__init__()
        self.window_size = window_size
        self.data_path = 'gct634-SMD-data/'
        # every tokenized file found under data_path is one dataset item
        self.data_list = list(find_files_by_extensions(self.data_path, ['.data']))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load file `idx` and return a random window of `window_size` tokens."""
        tokens = torch.load(os.path.join(self.data_list[idx]), weights_only=False)
        # inclusive upper bound: the window may end exactly at the last token
        begin = random.randint(0, len(tokens) - self.window_size)
        window = tokens[begin:begin + self.window_size]
        return torch.LongTensor(window)

## Build the Model


In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.parallel as P


class LSTM(nn.Module):
    """Token language model: embedding -> multi-layer LSTM -> linear decoder."""

    def __init__(self, n_layers, n_hidden, n_dict, n_enc_dim):
        super().__init__()

        self.n_hidden, self.n_layers, self.n_dict = n_hidden, n_layers, n_dict

        # token index -> dense vector
        self.encoder = nn.Embedding(self.n_dict, n_enc_dim)

        # recurrent core; batch_first=True means inputs are (batch, time, feature)
        self.rnn = nn.LSTM(input_size=n_enc_dim,
                           hidden_size=self.n_hidden,
                           num_layers=self.n_layers,
                           batch_first=True)

        # hidden state -> one logit per token of the dictionary
        self.decoder = nn.Linear(in_features=self.n_hidden,
                                 out_features=self.n_dict)

        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, hidden, tau):
        """Return (log-probabilities over tokens, next hidden state).

        The logits are divided by the temperature `tau` before the
        log-softmax.
        """
        rnn_out, next_hidden = self.rnn(self.encoder(x), hidden)
        scaled_logits = self.decoder(rnn_out) / tau
        return self.log_softmax(scaled_logits), next_hidden


class Model(nn.Module):
    """Wrapper around `LSTM` that also creates initial hidden states."""

    def __init__(self, n_layers, n_hidden, n_dict, n_enc_dim):
        super().__init__()

        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.model = LSTM(n_layers, n_hidden, n_dict, n_enc_dim)

    def forward(self, x, hidden, tau):
        return self.model(x, hidden, tau)

    def init_hidden(self, batch_size, random_init=True):
        """Return two (n_layers, batch_size, n_hidden) tensors.

        Both tensors are standard-normal samples when `random_init` is true
        and all zeros otherwise.
        """
        make = torch.randn if random_init else torch.zeros
        shape = (self.n_layers, batch_size, self.n_hidden)
        return make(*shape), make(*shape)


## Train  the Model

This cell is the main part to train the model using data loaders and an optimizer.

In [None]:
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

class Runner(object):
    """Bundles a model with its optimizer, LR schedule and loss.

    Offers one training pass over a dataloader (`train`), unconditional
    sampling (`test`) and primed sampling (`test_primer`).
    """

    def __init__(self, model, lr, weight_decay):
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        # the scheduler is stepped once per batch, so the learning rate is
        # multiplied by 0.98 every 1000 batches
        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.98)
        self.learning_rate = lr
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = model.to(self.device)
        self.criterion = nn.NLLLoss().to(self.device)

    def _initial_hidden(self, batch_size, random_init=True):
        """Return the model's initial (h0, c0) pair moved to the run device."""
        h0, c0 = self.model.init_hidden(batch_size=batch_size, random_init=random_init)
        return h0.to(self.device), c0.to(self.device)

    def _sample_next(self, token, hidden, tau):
        """Feed one token of shape (1,) and sample the next token.

        Returns the sampled token (shape (1,)) and the updated hidden state.
        """
        log_probs, hidden = self.model(x=token.unsqueeze(0), hidden=hidden, tau=tau)
        # the model outputs log-probabilities; exp() turns them into weights
        next_token = torch.multinomial(log_probs.view(-1).exp(), 1)
        return next_token, hidden

    def train(self, dataloader):
        """Run one optimization pass over `dataloader`.

        Each batch is a (batch, time) tensor of token indices; the model
        predicts token t+1 from tokens up to t. The dataloader must yield at
        least one batch.

        Returns:
            (index of the last batch, number of batches, loss of the last batch)
        """
        train_batch_num = len(dataloader)
        self.model.train()

        for batch_idx, batch_data in enumerate(dataloader):
            batch_data = batch_data.to(self.device)
            # random initial state (the default of init_hidden) for training
            batch_hidden = self._initial_hidden(batch_data.shape[0])

            # forward: inputs are all tokens but the last, targets are shifted by one
            pred, _ = self.model(x=batch_data[:, :-1], hidden=batch_hidden, tau=1.0)
            pred, target = pred.reshape(-1, pred.shape[-1]), batch_data[:, 1:].reshape(-1)
            loss = self.criterion(pred, target)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        return batch_idx, train_batch_num, loss.item()

    def test(self, sequence, tau):
        """Sample `sequence - 1` tokens, starting from token 0 and a zero state.

        Returns:
            list of sampled token indices (numpy integers).
        """
        with torch.no_grad():
            self.model.eval()

            hidden = self._initial_hidden(1, random_init=False)
            token = torch.zeros((1,), dtype=torch.long).to(self.device)
            preds = []

            for _ in range(sequence - 1):
                token, hidden = self._sample_next(token, hidden, tau)
                preds.append(token.cpu().numpy()[0])

            return preds

    def test_primer(self, primer, sequence, tau):
        """Continue `primer` until the result holds `sequence` tokens.

        Args:
            primer: 1-D LongTensor of token indices (must not be empty).
            sequence: total length of the returned token list.
            tau: sampling temperature.

        Returns:
            list starting with the primer tokens, followed by the sampled
            continuation.

        Raises:
            ValueError: if `primer` is empty.
        """
        if len(primer) == 0:
            raise ValueError("primer must contain at least one token")

        with torch.no_grad():
            self.model.eval()
            primer = primer.to(self.device)

            # Build the hidden state from every primer token except the last.
            # The last token is the first input of the sampling loop; warming
            # up on the whole primer would feed that token to the model twice.
            hidden = self._initial_hidden(1, random_init=False)
            if len(primer) > 1:
                _, hidden = self.model(x=primer[:-1].unsqueeze(0), hidden=hidden, tau=tau)

            token = primer[-1].unsqueeze(0)
            preds = primer.tolist()

            for _ in range(sequence - len(primer)):
                token, hidden = self._sample_next(token, hidden, tau)
                preds.append(token.cpu().numpy()[0])

            return preds

Let's train the model!

In [None]:
from torch.utils.data import DataLoader

# Training hyper-parameters
batch_size = 48
window_size = 200
weight_decay = 0
learning_rate = 1e-3
# Number of passes over train_loader (one runner.train call per epoch)
NUM_EPOCHS = 10000

# LSTM spec
LSTM_n_enc_dim = 240
LSTM_n_layers = 3
LSTM_n_hidden = 256
# Vocabulary size; 240 equals EventSeq.dim() for the tokenizer settings above
# (88 note_on + 88 note_off + 32 velocity + 32 time_shift)
LSTM_n_dict = 240

# dataloader
dataset = EventDataset(window_size)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# model
model = Model(LSTM_n_layers, LSTM_n_hidden, LSTM_n_dict, LSTM_n_enc_dim)

# runner (owns the optimizer, the LR scheduler and the loss)
runner = Runner(model=model, lr = learning_rate, weight_decay = weight_decay)

for epoch in range(NUM_EPOCHS):
  batch_idx, train_batch_num, train_loss = runner.train(train_loader)
  # train_loss is the loss of the last batch of the epoch; log every 100 epochs
  if (epoch % 100) == 0:
    print('Epoch: {:03d}/{:03d}, Loss: {:5f}'.format(epoch + 1, NUM_EPOCHS, train_loss))



## Inference: Unconditional Generation

In [None]:
# Sample tokens from the trained model (runner.test returns sequence - 1 tokens).
sequence = 4000
tau = 1.0
output_tokens = runner.test(sequence,tau)

# Decode the sampled tokens into events and then into notes.
event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

# synthesize the whole generated piece and play it
synth_audio = gen_midi.fluidsynth(fs=sr)
ipd.Audio(synth_audio, rate=sr)

## Evaluating the Result Using Pitch, Step, and Duration Statistics

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):
  """Draw side-by-side histograms of the pitch, step and duration columns.

  Pitch uses 20 automatic bins; step and duration use 20 equal bins over the
  fixed ranges [0, 0.5] and [0, 2.0]. `drop_percentile` is accepted but not
  used by the body.
  """
  panels = [
      ("pitch", 20),
      ("step", np.linspace(0, 0.5, 21)),
      ("duration", np.linspace(0, 2.0, 21)),
  ]

  plt.figure(figsize=[15, 5])
  for position, (column, bins) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    sns.histplot(notes, x=column, bins=bins)


These are the statistics of the training data.

In [None]:
# Reference statistics: note table of the original piece.
midi_file = 'gct634-SMD-MIDI/Bach_BWV888-01_008_20110315-SMD.mid'
pm = pretty_midi.PrettyMIDI(midi_file)
raw_notes = midi_to_notes(pm)
plot_distributions(raw_notes)

Plot the statistics of the generated music.

In [None]:
# Same statistics for gen_midi, the piece built in the generation cell.
gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)

In [None]:
import numpy as np
from scipy.stats import entropy

# Small constant that keeps divisions and logarithms finite.
eps = 1e-9

def normalize_hist(hist):
    """Scale a histogram of counts so that it sums to (almost exactly) one."""
    counts = hist.astype(np.float64)
    total = np.sum(counts) + eps  # eps avoids 0/0 for an all-zero histogram
    return counts / total

def histogram_distance(p_hist, q_hist):
    """Compare two histograms after normalizing each of them.

    Returns a dict with the L1 and L2 distances, both directions of the KL
    divergence (computed on eps-shifted values) and, under "average", the
    mean of those four numbers.
    """
    p, q = normalize_hist(p_hist), normalize_hist(q_hist)
    diff = p - q

    distances = {}
    distances["L1"] = np.sum(np.abs(diff))
    distances["L2"] = np.sqrt(np.sum(diff ** 2))
    distances["KL(p||q)"] = entropy(p + eps, q + eps)
    distances["KL(q||p)"] = entropy(q + eps, p + eps)

    # computed last, so it averages exactly the four entries above
    distances["average"] = np.mean(list(distances.values()))
    return distances

def histogram_distance_notes(raw_notes, gen_notes, bins=10):
    """Compute `histogram_distance` for every column of two note tables.

    Both histograms of a column are built on the same bin edges, derived
    from the combined value range of the two tables. Letting `np.histogram`
    pick edges for each table separately would compare unrelated bins, so
    two distributions with the same shape but different ranges would get a
    distance of zero.

    Args:
        raw_notes: mapping column name -> values (e.g. a DataFrame).
        gen_notes: mapping with the same columns as `raw_notes`.
        bins: number of equal-width bins (10 is `np.histogram`'s default).

    Returns:
        dict column name -> result of `histogram_distance`.
    """
    result = {}
    for key in raw_notes.keys():
        raw_values = np.asarray(raw_notes[key], dtype=np.float64)
        gen_values = np.asarray(gen_notes[key], dtype=np.float64)
        edges = np.histogram_bin_edges(
            np.concatenate([raw_values, gen_values]), bins=bins)
        p = np.histogram(raw_values, bins=edges)[0]
        q = np.histogram(gen_values, bins=edges)[0]
        result[key] = histogram_distance(p, q)
    return result
  
def print_histogram_distance(raw_notes, gen_notes):
    """Print every distance of every column, then the mean of the per-column averages."""
    per_feature = histogram_distance_notes(raw_notes, gen_notes)
    for feature, metrics in per_feature.items():
        print(feature)
        for metric, score in metrics.items():
            print(f"  {metric}: {score:.4f}")
    overall = np.mean([metrics["average"] for metrics in per_feature.values()])
    print(f"Average distance: {overall:.4f}")
            
# Compare the training piece with the most recently generated notes.
print_histogram_distance(raw_notes, gen_notes)

## Conditional Generation with a Primer

In [None]:
sequence = 4000
tau = 1.0

# Use the first 100 tokens of a tokenized training piece as the primer.
date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
primer_data = torch.load(date_file, weights_only=False)

primer = primer_data[:100]
# Runner.test_primer already returns the primer tokens followed by the sampled
# continuation, so the primer is not prepended a second time here.
primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)

event_seq = EventSeq.from_array(primer_output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
primer_gen_midi.instruments.append(inst)

gen_notes = midi_to_notes(primer_gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)
synth_audio = primer_gen_midi.fluidsynth(fs=sr)
ipd.Audio(synth_audio, rate=sr)


# Credit

This homework is based on a PerformanceRNN implementation in this GitHub repository (https://github.com/hmi88/prnn/tree/master).   

The original code was modified by Juhan Nam, Hounsu Kim, and Jaeran Choi from the KAIST Music and Audio Computing Lab.

# Tau optimization

### Tau = 0.6

In [None]:
# Unconditional generation
sequence = 4000
# this cell belongs to the "Tau = 0.6" experiment (it previously ran with 0.8)
tau = 0.6
output_tokens = runner.test(sequence,tau)

event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

synth_audio = gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

In [None]:
# # Conditional generation
# date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
# primer_data = torch.load(date_file, weights_only=False)

# primer = primer_data[:100]
# primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)
# primer_output_tokens = list(primer) + primer_output_tokens

# event_seq = EventSeq.from_array(primer_output_tokens)
# output_notes = event_seq.to_note_seq()
# velocity_scale=0.8

# # restore the velocity scale
# for note in output_notes:
#   note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# # make a prettyMIDI object
# primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
# inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# inst.notes = copy.deepcopy(output_notes)
# primer_gen_midi.instruments.append(inst)

# gen_notes = midi_to_notes(primer_gen_midi)
# plot_distributions(gen_notes)
# print_histogram_distance(raw_notes, gen_notes)

# synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

### Tau = 0.8

In [None]:
# Unconditional generation
sequence = 4000
# sampling temperature of this experiment
tau = 0.8
output_tokens = runner.test(sequence,tau)

# decode the sampled tokens into events and then into notes
event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

# plot the statistics and print the distances to the training piece
gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

# audio is synthesized, but playback is left disabled
synth_audio = gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

In [None]:
# # Conditional generation
# date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
# primer_data = torch.load(date_file, weights_only=False)

# primer = primer_data[:100]
# primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)
# primer_output_tokens = list(primer) + primer_output_tokens

# event_seq = EventSeq.from_array(primer_output_tokens)
# output_notes = event_seq.to_note_seq()
# velocity_scale=0.8

# # restore the velocity scale
# for note in output_notes:
#   note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# # make a prettyMIDI object
# primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
# inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# inst.notes = copy.deepcopy(output_notes)
# primer_gen_midi.instruments.append(inst)

# gen_notes = midi_to_notes(primer_gen_midi)
# plot_distributions(gen_notes)
# print_histogram_distance(raw_notes, gen_notes)

# synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

### Tau = 1.2

In [None]:
# Unconditional generation
sequence = 4000
# this cell belongs to the "Tau = 1.2" experiment (it previously ran with 0.8)
tau = 1.2
output_tokens = runner.test(sequence,tau)

event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

synth_audio = gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

In [None]:
# # Conditional generation
# date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
# primer_data = torch.load(date_file, weights_only=False)

# primer = primer_data[:100]
# primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)
# primer_output_tokens = list(primer) + primer_output_tokens

# event_seq = EventSeq.from_array(primer_output_tokens)
# output_notes = event_seq.to_note_seq()
# velocity_scale=0.8

# # restore the velocity scale
# for note in output_notes:
#   note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# # make a prettyMIDI object
# primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
# inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# inst.notes = copy.deepcopy(output_notes)
# primer_gen_midi.instruments.append(inst)

# gen_notes = midi_to_notes(primer_gen_midi)
# plot_distributions(gen_notes)
# print_histogram_distance(raw_notes, gen_notes)

# synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# # ipd.Audio(synth_audio, rate=sr)

### Tau = 1.4

In [None]:
# Unconditional generation
sequence = 4000
# this cell belongs to the "Tau = 1.4" experiment (it previously ran with 0.8)
tau = 1.4
output_tokens = runner.test(sequence,tau)

event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# rescale the velocities around 64 (a scale below 1 pulls them toward 64)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

synth_audio = gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

In [None]:
# # Conditional generation
# date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
# primer_data = torch.load(date_file, weights_only=False)

# primer = primer_data[:100]
# primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)
# primer_output_tokens = list(primer) + primer_output_tokens

# event_seq = EventSeq.from_array(primer_output_tokens)
# output_notes = event_seq.to_note_seq()
# velocity_scale=0.8

# # restore the velocity scale
# for note in output_notes:
#   note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# # make a prettyMIDI object
# primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
# inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# inst.notes = copy.deepcopy(output_notes)
# primer_gen_midi.instruments.append(inst)

# gen_notes = midi_to_notes(primer_gen_midi)
# plot_distributions(gen_notes)
# print_histogram_distance(raw_notes, gen_notes)

# synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)

# Tempo Conditioning + augmentation

### Preprocess

In [None]:
import os
import random
import torch

def find_files_by_extensions(root, exts=None):
    """Yield the paths of files under `root` whose name ends with one of `exts`.

    Args:
        root: directory that is walked recursively with `os.walk`.
        exts: iterable of suffixes such as ['.mid'], or a single suffix
            string. File names are lower-cased before the comparison, so
            suffixes should be given in lower case. `None` or an empty
            iterable matches every file.

    Yields:
        `os.path.join(directory, name)` for every matching file.
    """
    if isinstance(exts, str):
        # a bare string would otherwise be iterated character by character
        exts = (exts,)
    suffixes = tuple(exts) if exts else ()

    for path, _, files in os.walk(root):
        for name in files:
            if not suffixes or name.lower().endswith(suffixes):
                yield os.path.join(path, name)
                
# Load MIDI files and tokenize it to store as a ".data"
midi_path = 'gct634-SMD-MIDI/'
midi_files = list(find_files_by_extensions(midi_path, ['.mid']))

# NOTE(review): this is the same directory the un-augmented preprocessing cell
# writes to and EventDataset reads from, so the augmented files end up next to
# (and, for factor 1.0, overwrite) the original token files.
save_path = 'gct634-SMD-data/'
os.makedirs(save_path, exist_ok=True)
# placeholders: base file name, then an optional '_tempo<factor>' suffix
out_fmt = '{}{}.data'

# Tempo augmentation
# Each factor multiplies every note time, so a factor above 1 stretches the
# piece in time and a factor below 1 compresses it.
tempo_factors = [0.8, 0.9, 1.0, 1.1, 1.2]

for midi_file in midi_files:
    pm = pretty_midi.PrettyMIDI(midi_file)
    instrument = pm.instruments[0]
    sorted_notes = sorted(instrument.notes, key=lambda note: note.start)

    # shift the first note to 0 sec
    time_shift = sorted_notes[0].start
    for note in sorted_notes:
        note.start -= time_shift
        note.end -= time_shift

    base_name = os.path.basename(midi_file).split('.')[0]
    
    for tempo_factor in tempo_factors:
        # build time-scaled copies; the notes in sorted_notes stay untouched
        aug_notes = []
        for note in sorted_notes:
            aug_note = pretty_midi.Note(
                velocity=note.velocity,
                pitch=note.pitch,
                start=note.start * tempo_factor,
                end=note.end * tempo_factor
            )
            aug_notes.append(aug_note)

        # translates the note sequences into the four classes (NOTE_ON, NOTE_OFF, VELOCITY, and TIME_SHIFT) of events.
        event_seq = EventSeq.from_note_seq(aug_notes)

        # encodes the four classes of events into token indices.
        event_data = event_seq.to_array()

        # save the tokenized file; factor 1.0 gets no suffix, every other
        # factor is encoded in the name as '_tempo<factor>' (parsed back by
        # TempoAugDataset.__getitem__)
        suffix = '' if tempo_factor == 1.0 else f'_tempo{tempo_factor}'
        save_name = os.path.join(save_path, out_fmt.format(base_name, suffix))
        torch.save(event_data, save_name)

### Dataloader

In [None]:

from torch.utils.data import Dataset
import random
import torch

class TempoAugDataset(Dataset):
    """Random token windows plus the tempo factor encoded in each file name."""

    def __init__(self, window_size, augment_tempo=True):
        super().__init__()
        self.window_size = window_size
        self.data_path = 'gct634-SMD-data/'
        files = list(find_files_by_extensions(self.data_path, ['.data']))
        if not augment_tempo:
            # keep only files whose path carries no '_tempo' marker
            files = [f for f in files if '_tempo' not in f]
        self.data_list = files

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return (window of `window_size` tokens, 1-element tempo tensor)."""
        file_name = os.path.join(self.data_list[idx])
        tokens = torch.load(file_name, weights_only=False)

        # inclusive upper bound: the window may end exactly at the last token
        begin = random.randint(0, len(tokens) - self.window_size)
        window = tokens[begin:begin + self.window_size]

        # '<name>_tempo<factor>.data' -> factor; files without the marker -> 1.0
        base_name = os.path.basename(file_name)
        if '_tempo' in base_name:
            tempo_factor = float(base_name.split('_tempo')[1].split('.data')[0])
        else:
            tempo_factor = 1.0

        return torch.LongTensor(window), torch.FloatTensor([tempo_factor])

### Model

In [None]:

import os

import torch
import torch.nn as nn
import torch.nn.parallel as P

class TempoLSTM(nn.Module):
    """LSTM language model over event tokens with optional tempo conditioning.

    Tokens are embedded, optionally offset by a learned linear projection of
    the scalar tempo factor (the same offset at every time step), passed
    through a stacked LSTM and decoded into log-probabilities over the
    vocabulary.

    Args:
        n_layers: Number of stacked LSTM layers.
        n_hidden: LSTM hidden size.
        n_dict: Vocabulary size (embedding rows and decoder outputs).
        n_enc_dim: Embedding size, which is also the LSTM input size.
        tempo_cond: Whether to add the tempo projection to the embeddings.
    """

    def __init__(self, n_layers, n_hidden, n_dict, n_enc_dim, tempo_cond=True):
        super(TempoLSTM, self).__init__()

        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_dict = n_dict
        self.n_enc_dim = n_enc_dim
        self.tempo_cond = tempo_cond

        # define encoder
        self.encoder = nn.Embedding(self.n_dict, n_enc_dim)

        # Tempo embedding: maps the scalar tempo factor into embedding space
        if self.tempo_cond:
            self.tempo_proj = nn.Linear(1, n_enc_dim)

        # define rnn cell
        self.rnn = nn.LSTM(input_size=n_enc_dim,
                           hidden_size=self.n_hidden,
                           num_layers=self.n_layers,
                           batch_first=True)

        # define decoder
        self.decoder = nn.Linear(in_features=self.n_hidden,
                                 out_features=self.n_dict)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    @staticmethod
    def _tempo_as_column(tempo, like):
        """Coerce ``tempo`` into a 2-D tensor on ``like``'s device and dtype."""
        if tempo is None:
            raise ValueError('tempo must be given when tempo_cond=True')
        if not isinstance(tempo, torch.Tensor):
            # Python scalar (float or int): one value shared by the whole batch.
            return torch.tensor([[float(tempo)]], dtype=like.dtype, device=like.device)
        tempo = tempo.to(device=like.device, dtype=like.dtype)
        if tempo.dim() == 0:
            return tempo.reshape(1, 1)
        if tempo.dim() == 1:
            # (batch,) -> (batch, 1)
            return tempo.unsqueeze(1)
        return tempo

    def forward(self, x, hidden, tau, tempo=None):
        """Compute next-token log-probabilities.

        Args:
            x: Token indices of shape ``(batch, seq)``.
            hidden: Initial state passed straight to ``nn.LSTM``.
            tau: Softmax temperature; the decoder logits are divided by it.
            tempo: Tempo factor as a Python number, a 0-D or ``(batch,)``
                tensor, or a ``(batch, 1)`` tensor. Required when the model
                was built with ``tempo_cond=True``; ignored otherwise.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` has shape
            ``(batch, seq, n_dict)`` and ``hidden`` is the final LSTM state.

        Raises:
            ValueError: If tempo conditioning is on and ``tempo`` is None.
        """
        x_encoder = self.encoder(x)

        if self.tempo_cond:
            tempo_emb = self.tempo_proj(self._tempo_as_column(tempo, x_encoder))
            # Broadcast the per-sequence offset over every time step.
            x_encoder = x_encoder + tempo_emb.unsqueeze(1).expand_as(x_encoder)

        x_encoder, x_hidden = self.rnn(x_encoder, hidden)
        x_decoder = self.decoder(x_encoder) / tau  # temperature scaling
        x_pred = self.log_softmax(x_decoder)

        return x_pred, x_hidden


class TempoModel(nn.Module):
    """Thin wrapper that owns a ``TempoLSTM`` and builds its initial state."""

    def __init__(self, n_layers, n_hidden, n_dict, n_enc_dim, tempo_cond=True):
        super().__init__()

        self.n_layers, self.n_hidden = n_layers, n_hidden
        self.model = TempoLSTM(n_layers=n_layers,
                               n_hidden=n_hidden,
                               n_dict=n_dict,
                               n_enc_dim=n_enc_dim,
                               tempo_cond=tempo_cond)

    def forward(self, x, hidden, tau, tempo=None):
        """Delegate to the wrapped ``TempoLSTM`` and return its result."""
        return self.model(x, hidden, tau, tempo)

    def init_hidden(self, batch_size, random_init=True):
        """Return two ``(n_layers, batch_size, n_hidden)`` state tensors.

        The tensors are standard-normal samples when ``random_init`` is True
        and all zeros otherwise.
        """
        state_shape = (self.n_layers, batch_size, self.n_hidden)
        make = torch.randn if random_init else torch.zeros
        first = make(*state_shape)
        second = make(*state_shape)
        return first, second

### Runner

In [None]:
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

class TempoRunner(object):
    """Training and sampling helper around a tempo-aware language model.

    Args:
        model: Module exposing ``forward(x, hidden, tau, tempo)`` and
            ``init_hidden(batch_size, random_init)``.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
    """

    def __init__(self, model, lr, weight_decay):
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        # The scheduler is stepped once per batch, so the learning rate is
        # multiplied by 0.98 every 1000 optimizer steps.
        self.scheduler = StepLR(self.optimizer, step_size=1000, gamma=0.98)
        self.learning_rate = lr
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = model.to(self.device)
        self.criterion = nn.NLLLoss().to(self.device)

    def _init_state(self, batch_size, random_init=True):
        """Return the model's initial ``(h0, c0)`` state on the run device."""
        h0, c0 = self.model.init_hidden(batch_size, random_init=random_init)
        return h0.to(self.device), c0.to(self.device)

    def _sample(self, token, hidden, n_steps, tau, tempo):
        """Feed ``token`` (shape ``(1,)``) and sample ``n_steps`` tokens."""
        sampled = []
        for _ in range(n_steps):
            log_probs, hidden = self.model(x=token.unsqueeze(0), hidden=hidden, tau=tau, tempo=tempo)
            # The model returns log-probabilities; exponentiate to sample.
            token = torch.multinomial(log_probs.reshape(-1).exp(), 1)
            sampled.append(token.cpu().numpy()[0])
        return sampled

    def train(self, dataloader):
        """Run one pass over ``dataloader`` with teacher forcing.

        Each batch is ``(tokens, tempo)``; the model sees ``tokens[:, :-1]``
        and is trained to predict ``tokens[:, 1:]``. The dataloader must yield
        at least one batch.

        Returns:
            ``(index of the last batch, number of batches, last batch loss)``.
        """
        train_batch_num = len(dataloader)
        self.model.train()

        for batch_idx, (batch_data, batch_tempo) in enumerate(dataloader):
            batch_tempo = batch_tempo.to(self.device)
            batch_data = batch_data.to(self.device)
            batch_hidden = self._init_state(batch_data.shape[0])

            # forward
            pred, _ = self.model(x=batch_data[:, :-1], hidden=batch_hidden, tau=1.0, tempo=batch_tempo)
            pred, target = pred.reshape(-1, pred.shape[-1]), batch_data[:, 1:].reshape(-1)
            loss = self.criterion(pred, target)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

        return batch_idx, train_batch_num, loss.item()

    def test(self, sequence, tau, tempo=1.0):
        """Sample ``sequence - 1`` tokens from scratch.

        Generation starts from token index 0 with an all-zero hidden state.

        Returns:
            List of sampled token indices.
        """
        with torch.no_grad():
            self.model.eval()

            hidden = self._init_state(batch_size=1, random_init=False)
            start_token = torch.zeros((1,), dtype=torch.long).to(self.device)

            return self._sample(start_token, hidden, sequence - 1, tau, tempo)

    def test_primer(self, primer, sequence, tau, tempo=1.0):
        """Continue ``primer`` until the sequence holds ``sequence`` tokens.

        Args:
            primer: 1-D ``LongTensor`` of token indices to continue from.
            sequence: Total length of the returned sequence, primer included.
            tau: Sampling temperature.
            tempo: Tempo factor forwarded to the model.

        Returns:
            List holding the primer tokens followed by the sampled tokens.

        Raises:
            ValueError: If ``primer`` is empty.
        """
        if len(primer) == 0:
            raise ValueError('primer must contain at least one token')

        with torch.no_grad():
            self.model.eval()

            primer = primer.to(self.device)
            hidden = self._init_state(batch_size=1, random_init=False)

            # Warm the state up on everything BUT the last primer token: that
            # token is the first input of the sampling loop, so including it
            # here would feed it to the LSTM twice.
            if len(primer) > 1:
                _, hidden = self.model(x=primer[:-1].unsqueeze(0), hidden=hidden, tau=tau, tempo=tempo)

            # start generation from last token of primer
            preds = primer.tolist()
            preds.extend(self._sample(primer[-1].unsqueeze(0), hidden,
                                      sequence - len(primer), tau, tempo))

            return preds

### Train

In [None]:
from torch.utils.data import DataLoader

# Training hyper-parameters. These names are notebook globals that the
# ablation cells further down reuse.
batch_size = 48
window_size = 200  # tokens per training window cut from each file
weight_decay = 0
learning_rate = 1e-3
# Each dataset item is one .data file, so one "epoch" draws a single random
# window per file.
NUM_EPOCHS = 10000

# LSTM spec
LSTM_n_enc_dim = 240
LSTM_n_layers = 3
LSTM_n_hidden = 256
# Vocabulary size: number of embedding rows and of decoder outputs.
# NOTE(review): assumed to match the tokenizer's vocabulary - confirm.
LSTM_n_dict = 240

# dataloader (tempo-augmented files included: augment_tempo defaults to True)
dataset = TempoAugDataset(window_size)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# model (tempo conditioning on: tempo_cond defaults to True)
model = TempoModel(LSTM_n_layers, LSTM_n_hidden, LSTM_n_dict, LSTM_n_enc_dim)

# runner: owns the optimizer, the LR scheduler and the loss
runner = TempoRunner(model=model, lr = learning_rate, weight_decay = weight_decay)

for epoch in range(NUM_EPOCHS):
  # train_loss is the loss of the LAST batch of the pass, not an average
  batch_idx, train_batch_num, train_loss = runner.train(train_loader)
  # report every 100th epoch
  if (epoch % 100) == 0:
    print('Epoch: {:03d}/{:03d}, Loss: {:5f}'.format(epoch + 1, NUM_EPOCHS, train_loss))

### Infer: Unconditional Generation

In [None]:
# runner.test samples `sequence - 1` tokens, starting from token index 0
# with an all-zero hidden state.
sequence = 4000
# Sampling temperature: the decoder logits are divided by tau.
tau = 1.0
# tempo is left at runner.test's default of 1.0
output_tokens = runner.test(sequence,tau)

# Decode the token indices back into notes.
# NOTE(review): EventSeq is defined earlier in the notebook - confirm that
# from_array accepts a plain list of indices.
event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# restore the velocity scale
# (each velocity's distance from 64 is multiplied by velocity_scale)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# deep copy so the instrument does not share note objects with output_notes
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

# Note table of the generated MIDI, passed with raw_notes to the plotting and
# histogram-distance helpers.
# NOTE(review): those helpers are defined elsewhere - their exact metrics are
# not visible here.
gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

# render to audio and show a player for the generated piece
synth_audio = gen_midi.fluidsynth(fs=sr)
ipd.Audio(synth_audio, rate=sr)

### Infer: Conditional Generation

In [None]:
# Total number of tokens in the result, primer included.
sequence = 4000
# Sampling temperature: the decoder logits are divided by tau.
tau = 1.0

date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
primer_data = torch.load(date_file, weights_only=False)

# Use the first 100 tokens of the piece as the primer.
primer = primer_data[:100]
# test_primer already returns the primer followed by the sampled
# continuation, so the primer must NOT be prepended again here (doing so
# would put the first 100 tokens into the output twice).
primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)

event_seq = EventSeq.from_array(primer_output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# restore the velocity scale
# (each velocity's distance from 64 is multiplied by velocity_scale)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
primer_gen_midi.instruments.append(inst)

# Note table of the generated MIDI, passed with raw_notes to the plotting and
# histogram-distance helpers.
gen_notes = midi_to_notes(primer_gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)


# Ablation

### Only tempo conditioning

In [None]:
# Ablation: tempo conditioning WITHOUT tempo augmentation.
# window_size, batch_size, the LSTM_* sizes, learning_rate, weight_decay and
# NUM_EPOCHS come from the Train cell above.

# dataloader (augment_tempo=False drops the '_tempo' files, so every item
# carries tempo factor 1.0)
dataset = TempoAugDataset(window_size, augment_tempo=False)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# model
model = TempoModel(LSTM_n_layers, LSTM_n_hidden, LSTM_n_dict, LSTM_n_enc_dim, tempo_cond=True)

# runner: owns the optimizer, the LR scheduler and the loss.
# Rebinding `model` and `runner` replaces the ones trained in earlier cells.
runner = TempoRunner(model=model, lr = learning_rate, weight_decay = weight_decay)

for epoch in range(NUM_EPOCHS):
  # train_loss is the loss of the LAST batch of the pass, not an average
  batch_idx, train_batch_num, train_loss = runner.train(train_loader)
  if (epoch % 100) == 0:
    print('Epoch: {:03d}/{:03d}, Loss: {:5f}'.format(epoch + 1, NUM_EPOCHS, train_loss))

In [None]:
# Unconditional generation with whichever `runner` was trained last.
# runner.test samples `sequence - 1` tokens from token index 0 and a zero
# hidden state; tempo stays at its default of 1.0.
sequence = 4000
# Sampling temperature: the decoder logits are divided by tau.
tau = 1.0
output_tokens = runner.test(sequence,tau)

# decode token indices back into notes
event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# restore the velocity scale
# (each velocity's distance from 64 is multiplied by velocity_scale)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

# Note table of the generated MIDI, passed with raw_notes to the plotting and
# histogram-distance helpers.
gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

# render to audio and show a player for the generated piece
synth_audio = gen_midi.fluidsynth(fs=sr)
ipd.Audio(synth_audio, rate=sr)

In [None]:
# Total number of tokens in the result, primer included.
sequence = 4000
# Sampling temperature: the decoder logits are divided by tau.
tau = 1.0

date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
primer_data = torch.load(date_file, weights_only=False)

# Use the first 100 tokens of the piece as the primer.
primer = primer_data[:100]
# test_primer already returns the primer followed by the sampled
# continuation, so the primer must NOT be prepended again here (doing so
# would put the first 100 tokens into the output twice).
primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)

event_seq = EventSeq.from_array(primer_output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# restore the velocity scale
# (each velocity's distance from 64 is multiplied by velocity_scale)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
primer_gen_midi.instruments.append(inst)

# Note table of the generated MIDI, passed with raw_notes to the plotting and
# histogram-distance helpers.
gen_notes = midi_to_notes(primer_gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)


### Only tempo augmentation

In [None]:
# Ablation: tempo augmentation WITHOUT tempo conditioning.
# window_size, batch_size, the LSTM_* sizes, learning_rate, weight_decay and
# NUM_EPOCHS come from the Train cell above.

# dataloader (augment_tempo=True keeps the '_tempo' files)
dataset = TempoAugDataset(window_size, augment_tempo=True)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# model (tempo_cond=False: the tempo value passed to forward is ignored)
model = TempoModel(LSTM_n_layers, LSTM_n_hidden, LSTM_n_dict, LSTM_n_enc_dim, tempo_cond=False)

# runner: owns the optimizer, the LR scheduler and the loss.
# Rebinding `model` and `runner` replaces the ones trained in earlier cells.
runner = TempoRunner(model=model, lr = learning_rate, weight_decay = weight_decay)

for epoch in range(NUM_EPOCHS):
  # train_loss is the loss of the LAST batch of the pass, not an average
  batch_idx, train_batch_num, train_loss = runner.train(train_loader)
  if (epoch % 100) == 0:
    print('Epoch: {:03d}/{:03d}, Loss: {:5f}'.format(epoch + 1, NUM_EPOCHS, train_loss))

In [None]:
# Unconditional generation
# runner.test samples `sequence - 1` tokens from token index 0 and a zero
# hidden state. It still passes tempo=1.0, which a model built with
# tempo_cond=False ignores.
sequence = 4000
# Sampling temperature: the decoder logits are divided by tau.
tau = 1.0
output_tokens = runner.test(sequence,tau)

# decode token indices back into notes
event_seq = EventSeq.from_array(output_tokens)
output_notes = event_seq.to_note_seq()
velocity_scale=0.8

# restore the velocity scale
# (each velocity's distance from 64 is multiplied by velocity_scale)
for note in output_notes:
  note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# make a prettyMIDI object
gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
inst.notes = copy.deepcopy(output_notes)
gen_midi.instruments.append(inst)

# Note table of the generated MIDI, passed with raw_notes to the plotting and
# histogram-distance helpers.
gen_notes = midi_to_notes(gen_midi)
plot_distributions(gen_notes)
print_histogram_distance(raw_notes, gen_notes)

# render to audio and show a player for the generated piece
synth_audio = gen_midi.fluidsynth(fs=sr)
ipd.Audio(synth_audio, rate=sr)

In [None]:
# # Conditional generation
# date_file = 'gct634-SMD-data/Bach_BWV888-01_008_20110315-SMD.data'
# primer_data = torch.load(date_file, weights_only=False)

# primer = primer_data[:100]
# primer_output_tokens = runner.test_primer(torch.LongTensor(primer), sequence, tau)
# primer_output_tokens = list(primer) + primer_output_tokens

# event_seq = EventSeq.from_array(primer_output_tokens)
# output_notes = event_seq.to_note_seq()
# velocity_scale=0.8

# # restore the velocity scale
# for note in output_notes:
#   note.velocity = int((note.velocity - 64) * velocity_scale + 64)

# # make a prettyMIDI object
# primer_gen_midi = PrettyMIDI(resolution=DEFAULT_RESOLUTION, initial_tempo=DEFAULT_TEMPO)
# inst = Instrument(DEFAULT_SAVING_PROGRAM, False, 'NoteSeq')
# inst.notes = copy.deepcopy(output_notes)
# primer_gen_midi.instruments.append(inst)

# gen_notes = midi_to_notes(primer_gen_midi)
# plot_distributions(gen_notes)
# print_histogram_distance(raw_notes, gen_notes)

# synth_audio = primer_gen_midi.fluidsynth(fs=sr)
# ipd.Audio(synth_audio, rate=sr)