In [20]:
import os
import glob
import pickle

import torch
import music21 as m21

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook as tqdm

BATCH_SIZE = 32

class MusicDataset(Dataset):
    def __init__(self, force_update=False, seq_len=100):
        self._read_notes(force_update)
        self._prepare_sequences(seq_len)
        
    def _read_notes(self, force_update=False):
        path = 'notes.pk'
        if not force_update and os.path.exists(path):
            with open(path, mode='rb') as f:
                notes = pickle.load(f)
            print(f'Read {len(notes)} already cached notes!')
        else:
            notes = []
            for file in tqdm(glob.glob('../../res/dataset/raw/*.mid'), ncols=800):
                midi = m21.converter.parse(file)

                notes_to_parse = None
                try: # file has instrument parts
                    s2 = m21.instrument.partitionByInstrument(midi)
                    notes_to_parse = s2.parts[0].recurse() 
                except: # file has notes in a flat structure
                    notes_to_parse = midi.flat.notes

                for element in notes_to_parse:
                    if isinstance(element, m21.note.Note):
                        notes.append(str(element.pitch))
                    elif isinstance(element, m21.chord.Chord):
                        notes.append('.'.join(str(n) for n in element.normalOrder))

            with open(path, 'wb') as filepath:
                pickle.dump(notes, filepath)

            print(f'Read {len(otes)} notes!')

        n_vocab = len(set(notes))
        print(f'{n_vocab} unique notes found.')
        
        self.notes = notes
        self.n_vocab = n_vocab
    
    def _prepare_sequences(self, seq_len=100):
        pitchnames = sorted(set(item for item in self.notes))
        note2idx = {note: number for number, note in enumerate(pitchnames)}

        features = []
        labels = []

        for i in range(0, len(self.notes) - seq_len, 1):
            seq_inputs = self.notes[i:i + seq_len]
            seq_output = self.notes[i + seq_len]
            features.append([note2idx[x] for x in seq_inputs])
            labels.append(note2idx[seq_output])

        features = torch.tensor(features, dtype=torch.float).view((-1, seq_len, 1))
        features = features / float(self.n_vocab)

        self.features = features
        self.labels = labels
        
    def __getitem__(self, i):
        return self.features[i], self.labels[i]

    def __len__(self) -> int:
        return len(self.features)
    
use_cuda = torch.cuda.is_available()
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
dataset = MusicDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, **kwargs)

Read 2089 already cached notes!
58 unique notes found.


In [53]:
from torch import nn
from torch.nn import functional as F

class NNP_RNN(nn.Module):
    def __init__(self, hd=250, layers=3):
        super(NNP_RNN, self).__init__()
            
        self.hd = hd
        
        dropout = 0.2 if layers > 1 else 0
        self.lstm = nn.LSTM(input_size=1,
                            hidden_size=hd,
                            num_layers=layers,
                            dropout=dropout,
                            batch_first=True)

        self.dense_1 = nn.Linear(in_features=250,
                                 out_features=100)
        
        
        self.dense_2 = nn.Linear(in_features=100,
                                 out_features=58)
        
    def forward(self, x):
        x, _ = self.lstm(x)
        
        x = self.dense_1(x)
        x = self.dense_2(x)
        
        x = x[:, -1, :]
        return F.softmax(x, dim=1)
    
    def init_hidden(dims):
        return torch.zeros(device=device)
        
    
model = NNP_RNN()

In [55]:
def train(model, data, epochs=100):
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(1, epochs + 1):
        epoch_loss = 0
        for features, labels in tqdm(data, ncols=800):
            optimizer.zero_grad()
            
            batch_size = features.size(0)
            output = model(features)
            
            loss = criterion(output, labels)
            loss.backward()
                
            optimizer.step()
            
            epoch_loss += loss * batch_size
            
        epoch_loss /= len(data.dataset)
        print(f'Loss epoch #{epoch} = {epoch_loss:.10f}')

train(model, dataloader)

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=63), HTML(value='')), layout=Layout(display='…

Loss epoch #1 = 3.9916024208


HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=63), HTML(value='')), layout=Layout(display='…

KeyboardInterrupt: 