# In [1]:
from music21 import converter, instrument, note, chord, stream, midi
import glob
import os
import gzip
import tarfile    
from torchvision import datasets                  
import numpy as np
import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# In [None]:
def download_data(filepath):
    """Download (if missing) and unpack the Mozart sonata MIDI archive.

    filepath: str, directory that holds / receives ``mozart_sonatas.tar.gz``.

    Side effects:
        * downloads ``mozart_sonatas.tar.gz`` into ``filepath`` unless it already exists;
        * decompresses it to ``filepath/mozart_sonatas.tar``;
        * extracts that tar into the directory ``filepath/mozart_sonatas``.
    """
    import shutil

    archive_name = 'mozart_sonatas.tar.gz'
    gzip_path = os.path.join(filepath, archive_name)
    tar_path = os.path.join(filepath, 'mozart_sonatas.tar')
    extract_dir = os.path.join(filepath, 'mozart_sonatas')

    if not os.path.exists(gzip_path):
        datasets.utils.download_url(
            'https://github.com/Foundations-of-Applied-Mathematics/Data/raw/master/RNN/mozart_sonatas.tar.gz',
            filepath, archive_name, None)

    print('Extracting {}'.format(archive_name))
    # Stream the decompression in chunks instead of reading the whole archive into memory.
    with gzip.open(gzip_path, 'rb') as zip_f, open(tar_path, 'wb') as out_f:
        shutil.copyfileobj(zip_f, out_f)

    print('Untarring {}'.format('mozart_sonatas.tar'))
    # `with` guarantees the archive is closed. Where supported, the 'data' filter rejects
    # members that would land outside extract_dir (absolute paths, '..', device files).
    with tarfile.open(tar_path) as tar:
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(extract_dir, filter='data')
        else:
            tar.extractall(extract_dir)
    
def music_to_lists(filepath):
    """Parse every MIDI file in a directory into three parallel lists of strings.

    filepath: str, path to the directory holding the .mid files
        (in my case: "/content/drive/MyDrive/mozart_sonatas/mozart_sonatas").

    RETURNS:
        (myNotes, myOffsets, myDurations): parallel lists with one entry per note or
        chord found in the first part returned by ``instrument.partitionByInstrument``
        for each file (or in the whole score when no partition is available).
          * note   -> its pitch name, ``str(element.pitch)``
          * chord  -> its pitch names joined with "."
          * offset / duration -> ``str`` of ``element.offset`` / ``duration.quarterLength``
        Any other element type is ignored.
    """
    myNotes, myOffsets, myDurations = [], [], []
    # Sorted so the output order is reproducible (os.listdir order is arbitrary).
    # Hidden files (e.g. .DS_Store) and sub-directories are not scores; skip them.
    sonatas = [
        name for name in sorted(os.listdir(filepath))
        if not name.startswith('.') and os.path.isfile(os.path.join(filepath, name))
    ]
    for sonata in tqdm(sonatas):
        # Named `score` rather than `midi` so it does not shadow the music21 `midi` module.
        score = converter.parse(os.path.join(filepath, sonata))
        partitioned = instrument.partitionByInstrument(score)
        # NOTE(review): some music21 versions return None here when the file carries no
        # instrument information; walk the whole score in that case instead of crashing.
        if partitioned is not None:
            notes_to_parse = partitioned.parts[0].recurse()
        else:
            notes_to_parse = score.recurse()
        for element in notes_to_parse:
            if isinstance(element, note.Note):
                pitch_name = str(element.pitch)
            elif isinstance(element, chord.Chord):
                pitch_name = '.'.join(str(pitch) for pitch in element.pitches)
            else:
                continue
            myNotes.append(pitch_name)
            myOffsets.append(str(element.offset))
            myDurations.append(str(element.duration.quarterLength))

    return myNotes, myOffsets, myDurations
                
                
def clean_lists(notes, offsets, lengths):
    """Turn raw parsed events into bounded, relative time features.

    notes, offsets, lengths: parallel lists as produced by ``music_to_lists``.
        Offsets and lengths may be numeric strings such as "1.5" or "1/3", or numbers.

    Processing:
      * events whose length is <= 0 are dropped;
      * lengths are capped at 4;
      * each offset becomes the gap to the previous kept event's offset. The gap is
        measured from 0 instead for the first event, and whenever consecutive offsets
        differ by more than 8. The gap is then capped at 4;
      * a "REST" immediately following another "REST" is skipped.

    RETURNS: (new_notes, new_offsets, new_lengths), three parallel lists.
    Raises ValueError if an offset/length string is not a number.
    """
    from fractions import Fraction

    def _to_number(value):
        # Parse "1.5" / "1/3" without eval(): eval would execute arbitrary text from the data.
        return float(Fraction(value)) if isinstance(value, str) else value

    parsed = [(name, _to_number(off), _to_number(length))
              for name, off, length in zip(notes, offsets, lengths)]
    kept = [event for event in parsed if event[2] > 0]

    new_notes, new_offsets, new_lengths = [], [], []
    for i, (name, offset, length) in enumerate(kept):
        length = min(length, 4)
        if i == 0:
            reference = 0
        else:
            prev_offset = kept[i - 1][1]
            # A large jump (e.g. a new piece restarting its offsets) is measured from 0.
            reference = 0 if abs(offset - prev_offset) > 8 else prev_offset
        # Collapse runs of rests; every kept length is already > 0 after the filter above.
        if i > 0 and name == "REST" and kept[i - 1][0] == "REST":
            continue
        new_notes.append(name)
        new_offsets.append(min(offset - reference, 4))
        new_lengths.append(length)

    return new_notes, new_offsets, new_lengths

def make_unique_map(somelist):
    """Map each distinct value of ``somelist`` to its rank in sorted order.

    somelist: sequence of str, or sequence of numbers (list or NumPy array).
        Numbers are rounded to 4 decimals first, so values that agree to 4 places
        share one index. The type check looks only at the first element.

    RETURNS: dict {value: index}, with indices 0..k-1 in ascending order of value.
        Keys are the scalars produced by ``np.unique``. Empty input gives {}.
    """
    # len() rather than truthiness so NumPy arrays are accepted too.
    if len(somelist) == 0:
        return {}
    if not isinstance(somelist[0], str):
        somelist = [round(value, 4) for value in somelist]
    return {value: index for index, value in enumerate(np.unique(somelist))}

def loadData(data_size, batch_size, lists, test_size=0.1, random_state=None):
    """Build train/test DataLoaders of (window, next-event) pairs.

    The parallel id lists (one per feature) are interleaved into a single token stream
    ``[a0, b0, c0, a1, b1, c1, ...]``. Sample ``i`` uses these slices:
      * input  = ``data_size`` tokens starting at index ``len(lists) * i``;
      * target = the ``len(lists)`` tokens right after that window.

    data_size: number of tokens in each input window.
    batch_size: batch size of both loaders (incomplete last batches are dropped).
    lists: sequence of equal-length lists of integer ids.
    test_size: fraction of samples held out, forwarded to ``train_test_split``.
    random_state: optional seed forwarded to ``train_test_split`` for a reproducible split.

    RETURNS: TestLoader, TrainLoader, X_train, X_test, y_train, y_test
        (X_* are input windows, y_* the matching targets).
    Raises ValueError if the lists do not all have the same length.
    """
    scale = len(lists)
    if any(len(feature) != len(lists[0]) for feature in lists):
        raise ValueError("all lists must have the same length")

    # zip(*lists) yields one event (tuple of features) at a time; flattening interleaves them.
    flat_list = [token for event in zip(*lists) for token in event]
    data = torch.LongTensor(flat_list)

    # Integer arithmetic avoids float rounding in the window count; a window is kept only
    # when its whole target fits inside the data.
    n_windows = max(0, (len(data) - data_size) // scale)
    inputs = [data[scale * i:scale * i + data_size] for i in range(n_windows)]
    targets = [data[scale * i + data_size:scale * (i + 1) + data_size] for i in range(n_windows)]

    X_train, X_test, y_train, y_test = train_test_split(
        inputs, targets, test_size=test_size, random_state=random_state)

    tens_data_train = [[x, y] for x, y in zip(X_train, y_train)]
    tens_data_test = [[x, y] for x, y in zip(X_test, y_test)]
    TrainLoader = DataLoader(tens_data_train, shuffle=True, drop_last=True, batch_size=batch_size)
    TestLoader = DataLoader(tens_data_test, drop_last=True, batch_size=batch_size)
    return TestLoader, TrainLoader, X_train, X_test, y_train, y_test

def predict(more_notes, sizes, model, criterion, optimizer, test_loader, reverse_mapping, rand_choice=None):
    """Autoregressively extend a seed window taken from ``test_loader``.

    A batch of ``test_loader`` is chosen (index ``rand_choice``, random if None), and a
    random sample of that batch is used as the seed window. The model then runs
    ``more_notes`` times. Each step appends the arg-max token of each output segment and
    slides the window forward by ``len(sizes)`` tokens.

    more_notes: number of new events to generate.
    sizes: segment sizes of the model output ``(n_note, n_off, n_len)``. Offset tokens are
        shifted by ``sizes[0]`` and length tokens by ``sizes[0] + sizes[1]``.
    model: callable ``model(x, H) -> (output, H)`` with an ``initHidden(batch_size)`` method.
    criterion, optimizer: unused; kept for interface compatibility.
    test_loader: sized iterable of ``(inputs, targets)`` batches (e.g. a DataLoader or list).
    reverse_mapping: dict with keys "note", "off", "len", each mapping an index back to a value.
    rand_choice: index of the seed batch; chosen at random when None.

    RETURNS: (pred_note, pred_off, pred_len), the seed window's events followed by the
        generated ones.
    Raises IndexError if ``rand_choice`` is past the end of ``test_loader``.
    """
    import itertools

    if rand_choice is None:
        rand_choice = np.random.randint(0, len(test_loader))
    # DataLoaders are iterable but not indexable, so walk to the requested batch.
    seed_batch = next(itertools.islice(iter(test_loader), rand_choice, None), None)
    if seed_batch is None:
        raise IndexError('rand_choice {} is out of range for test_loader'.format(rand_choice))
    seed_inputs = seed_batch[0]
    # Pick any sample of the batch; the bound comes from the batch itself, so any batch size works.
    rand_index = np.random.randint(0, len(seed_inputs))
    temp_data = seed_inputs[rand_index].reshape(1, -1)
    og_data = temp_data

    offset_shift = sizes[0]
    length_shift = sizes[0] + sizes[1]
    preds = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        H = model.initHidden(1)
        while len(preds) < more_notes * len(sizes):
            output, H = model(temp_data, H)
            p = torch.argmax(output[0][:offset_shift])
            q = torch.argmax(output[0][offset_shift:length_shift]) + offset_shift
            r = torch.argmax(output[0][length_shift:]) + length_shift
            # Move to the CPU so the new tokens can be concatenated with the seed window
            # (which comes from the loader) and converted to Python ints.
            step = torch.stack((p, q, r)).cpu()
            preds += step.tolist()
            # Slide the window one event forward: drop the oldest tokens, append the new ones.
            temp_data = torch.cat((temp_data.reshape(-1), step))[len(sizes):].reshape(1, -1)

    # og_data has shape (1, window); its single row holds the seed tokens.
    seed_tokens = og_data[0].tolist()
    pred_note = [reverse_mapping["note"][t] for t in seed_tokens[0::3] + preds[0::3]]
    pred_off = [reverse_mapping["off"][t - offset_shift] for t in seed_tokens[1::3] + preds[1::3]]
    pred_len = [reverse_mapping["len"][t - length_shift] for t in seed_tokens[2::3] + preds[2::3]]

    return pred_note, pred_off, pred_len
