# Attentive Music

I plan to use a Transformer architecture to generate musical MIDI sequences.

In [19]:
from music21 import *
import os, sys
import numpy as np
from tqdm import tqdm_notebook as tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torchsample.modules import ModuleTrainer
import pickle
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda:0")

## Data

I've found a [dataset](https://github.com/jukedeck/nottingham-dataset) of MIDI files.

In [2]:
PATH="../nottingham-dataset/MIDI"
files = [f for f in os.listdir(PATH) if os.path.isfile(PATH+'/'+f)]
files[:10]

['waltzes7.mid',
 'reelsa-c79.mid',
 'reelsr-t57.mid',
 'jigs211.mid',
 'morris29.mid',
 'reelsu-z8.mid',
 'jigs156.mid',
 'ashover5.mid',
 'reelsa-c32.mid',
 'morris10.mid']

From [this](https://www.hackerearth.com/blog/machine-learning/jazz-music-using-deep-learning/) tutorial for parsing MIDI.

In [3]:
def get_notes(file_list, PATH):  
    notes = []  
    for file in tqdm(file_list):  
    # converting .mid file to stream object
        midi = converter.parse(PATH + '/' + file)  
        notes_to_parse = [] 
        try:  
            # Given a single stream, partition into a part for each unique instrument  
            parts = instrument.partitionByInstrument(midi)  
        except:  
            pass  
        if parts: # if parts has instrument parts   
            notes_to_parse = parts.parts[0].recurse()  
        else:  
            notes_to_parse = midi.flat.notes  
        for element in notes_to_parse:   
            if isinstance(element, note.Note):  
                # if element is a note, extract pitch   
                notes.append(str(element.pitch))  
            elif(isinstance(element, chord.Chord)):  
                # if element is a chord, append the normal form of the   
                # chord (a list of integers) to the list of notes.   
                notes.append('.'.join(str(n) for n in element.normalOrder)) 
    
    with open('data/notes', 'wb') as filepath:  
        pickle.dump(notes, filepath)  
    return notes

In [4]:
# Create notes again
# notes = get_notes(files, PATH)

# Load from previously saved version
if os.path.getsize('data/notes') > 0:
    with open('data/notes', 'rb') as f:
        unpickler = pickle.Unpickler(f)
        notes = unpickler.load()

In [5]:
pitchnames = sorted(set(item for item in notes))
note_to_int = dict((note, number) for number, note in enumerate(pitchnames))

In [6]:
int_notes = [note_to_int[x] for x in notes]; int_notes[:10]

[88, 111, 34, 108, 103, 88, 34, 110, 88, 94]

In [7]:
bs = 8

In [8]:
xs = np.array([np.array(int_notes[i*bs:(i+1)*bs]) for i in range(len(int_notes)//bs)])
ys = np.array([int_notes[(i+1)*bs] for i in range(len(int_notes)//bs)])

In [9]:
xs[:10]

array([[ 88, 111,  34, 108, 103,  88,  34, 110],
       [ 88,  94,  67, 118,  94,  88,  34, 110],
       [ 88, 111,  34, 108, 103,  88,  34, 110],
       [ 88,  94,  44, 108,  97,  83, 103,  34],
       [ 88, 111,  34, 108, 103,  88,  34, 110],
       [ 88,  94,  67, 118,  94,  88,  34, 110],
       [ 88, 111,  34, 108, 103,  88,  34, 110],
       [ 88,  94,  44, 108,  97,  83, 103,  34],
       [ 88, 111,  34, 108, 103,  88,  34, 110],
       [ 88,  94,  67, 118,  94,  88,  34, 110]])

These are the next notes in the sequence for each sequence in `xs`.

In [10]:
ys[:10]

array([88, 88, 88, 88, 88, 88, 88, 88, 88, 88])

But our y data will need to be one-hot encoded for our training to work.

In [11]:
def one_hot(batch,vocab_size):
    ones = torch.eye(vocab_size)
    return ones.index_select(0,batch)

In [78]:
xs.shape

(30727, 8)

In [81]:
x_tr, x_val, y_tr, y_val = train_test_split(xs[:30720], ys[:30720], test_size=0.25)

In [82]:
def tensor(from_int):
    return torch.from_numpy(np.array(from_int)).long()

We need to create a class for our dataset.

In [83]:
class MusicData(Dataset):

    def __init__(self, x_data, y_data):
        self.len = len(x_data)
        self.x_data = tensor(x_data)
        self.y_data = tensor(y_data)
            
    def __getitem__(self, index):
        return self.x_data[index], one_hot(self.y_data[index],vocab_size=120).squeeze(0).long()
    
    def __len__(self):
        return self.len

In [84]:
tr_data = MusicData(x_tr, y_tr)
val_data = MusicData(x_val, y_val)

tr_loader = DataLoader(dataset=tr_data,
                       batch_size=32,
                       shuffle=True,
                       num_workers=1,
                       pin_memory=True)
val_loader = DataLoader(dataset=val_data,
                        batch_size=32,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)

## LSTM

Let's first try an LSTM as a simple example.

In [106]:
class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, num_layers, batch_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.vocab_size = vocab_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_size, dropout=0.5)
        self.hidden2tag = nn.Linear(hidden_dim, vocab_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # Before we've done anything, we dont have any hidden state.
        # Refer to the Pytorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros(self.batch_size, self.num_layers, self.hidden_dim).cuda(),
                torch.zeros(self.batch_size, self.num_layers, self.hidden_dim).cuda())

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(
            embeds.view(sentence.size(0),sentence.size(1), -1), self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(sentence.size(0),sentence.size(1), -1))
#         tag_scores = F.log_softmax(tag_space, dim=-1).view(self.batch_size, -1, self.vocab_size)
        tag_scores = F.log_softmax(tag_space, dim=-1)
        return tag_scores

In [107]:
model = LSTMTagger(embedding_dim=50,hidden_dim=128,vocab_size=len(note_to_int), num_layers=8, batch_size=32).cuda()
loss_function = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(model.parameters(), lr=1)

for epoch in tqdm(range(4), desc='Epoch'):  # again, normally you would NOT do 4 epochs, it is toy data
    for i, (inputs, labels) in enumerate(tqdm(tr_loader, desc='Batch')):
        
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()
        optimizer.zero_grad()

        # Also, we need to clear out the hidden state of the LSTM,
        # detaching it from its history on the last instance.
        model.hidden = model.init_hidden()

        # Step 2. Get our inputs ready for the network, that is, turn them into
        # Tensors of word indices.
        inputs, labels = Variable(inputs), Variable(labels)

        # Step 3. Run our forward pass.
        tag_scores = model(inputs.cuda())

        # Step 4. Compute the loss, gradients, and update the parameters by
        #  calling optimizer.step()
        loss = loss_function(tag_scores.cuda(), labels.cuda())
        loss.backward()
        sys.stdout.write('\r'+str(loss))
        optimizer.step()

HBox(children=(IntProgress(value=0, description='Epoch', max=4, style=ProgressStyle(description_width='initial…

HBox(children=(IntProgress(value=0, description='Batch', max=720, style=ProgressStyle(description_width='initi…

tensor(2.0794, device='cuda:0', grad_fn=<NllLoss2DBackward>)

Process Process-40:
Traceback (most recent call last):
  File "/home/jhbremner/.conda/envs/attentive-music/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/jhbremner/.conda/envs/attentive-music/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jhbremner/.conda/envs/attentive-music/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/home/jhbremner/.conda/envs/attentive-music/lib/python3.7/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/home/jhbremner/.conda/envs/attentive-music/lib/python3.7/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/jhbremner/.conda/envs/attentive-music/lib/python3.7/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/home/jh

KeyboardInterrupt: 

In [73]:
torch.save(model.state_dict(), 'lstm_model')

In [None]:
model = LSTMTagger(embedding_dim=50,hidden_dim=128,vocab_size=len(note_to_int), num_layers=8, batch_size=32)
model.load_state_dict(torch.load('lstm_model'))

### Test

In [61]:
next(iter(tr_loader))[0][:10]

tensor([[ 94, 100,  81,  88,  88,  57, 103, 100],
        [110,  88,  94,  67,  88, 103,  34, 108],
        [ 94,  88,  81,  89, 119,  44, 108, 108],
        [ 94, 110, 102, 107, 110, 115,  47,  94],
        [102,  57, 118, 118, 110, 118,  67,  88],
        [107,  47, 110, 107, 102,  96,  83,  93],
        [ 94,  97,  83, 103, 108,  97, 103,  34],
        [ 89, 111,  12, 103, 103,  34, 108, 103],
        [100,  34, 108, 103, 100,  94,  67, 118],
        [119, 111, 108, 111, 119, 103,  94,  67]])

In [67]:
tags = next(iter(tr_loader))[1]

In [90]:
np.argmax(tags.cpu().detach().numpy(), axis=1)

array([ 88, 111, 107, 111, 108, 103, 111, 102, 103,  97, 103, 102,  89,
        88, 118, 103, 107, 102, 108, 103,  34, 103,  89, 118,  97, 119,
        88,  91,  97, 100,  94, 108])

In [99]:
model(next(iter(tr_loader))[0].cuda()).size()

tensor([[[-4.7930, -4.7547, -4.7560,  ..., -4.7498, -4.7386, -4.8113],
         [-4.7958, -4.7464, -4.7515,  ..., -4.7518, -4.7369, -4.8163],
         [-4.7985, -4.7524, -4.7524,  ..., -4.7527, -4.7391, -4.8094],
         ...,
         [-4.7953, -4.7445, -4.7464,  ..., -4.7536, -4.7385, -4.8071],
         [-4.7961, -4.7466, -4.7508,  ..., -4.7602, -4.7323, -4.8092],
         [-4.7944, -4.7486, -4.7491,  ..., -4.7545, -4.7389, -4.8050]],

        [[-4.7887, -4.7514, -4.7542,  ..., -4.7551, -4.7361, -4.8152],
         [-4.7948, -4.7476, -4.7497,  ..., -4.7550, -4.7355, -4.8122],
         [-4.8025, -4.7467, -4.7532,  ..., -4.7530, -4.7386, -4.8017],
         ...,
         [-4.7903, -4.7481, -4.7435,  ..., -4.7522, -4.7374, -4.8046],
         [-4.7904, -4.7485, -4.7486,  ..., -4.7542, -4.7427, -4.8121],
         [-4.7942, -4.7501, -4.7520,  ..., -4.7568, -4.7390, -4.8053]],

        [[-4.7927, -4.7517, -4.7482,  ..., -4.7545, -4.7369, -4.8143],
         [-4.7891, -4.7529, -4.7540,  ..., -4

In [95]:
preds = model(next(iter(tr_loader))[0].cuda()).cpu().detach().numpy()

In [96]:
preds.shape

(32, 8, 120)

In [77]:
np.argmax(preds, axis=2)+1

array([[88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 78, 88, 88, 88],
       [88, 88, 88, 88, 78, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 88, 88, 88, 88, 88, 88],
       [88, 88, 42, 88, 88, 88, 88, 88],
       [88, 88, 

In [75]:
tags.size()

torch.Size([32, 120])