In [None]:
!pip install music21 lightning midi2audio
!apt-get update
!apt-get install -y fluidsynth musescore3

# Imports

In [None]:
import os
import glob

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,random_split
from torch.utils.tensorboard import SummaryWriter
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.utilities.model_summary import ModelSummary

from music21 import corpus,converter,instrument,note,stream,chord,duration
from midi2audio import FluidSynth

from collections import Counter
from IPython.display import Image,Audio
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time
from datetime import timedelta
import pickle

# Constants and data preparation

In [None]:
# Seed torch's global RNG so the train/val split, weight initialisation, dropout and the
# multinomial sampling used during generation are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

# Sorted because glob returns paths in arbitrary, filesystem-dependent order; a fixed
# order keeps the concatenated token stream (and therefore the dataset) reproducible.
files = sorted(glob.glob("/kaggle/input/classical-music-midi/chopin/*.mid"))
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Parsing and chordifying every MIDI file can be slow, so a pickled list of the chordified
# scores (built once with the fallback branch below, in a separate notebook) is loaded
# when it is available.
# NOTE: pickle.load can execute arbitrary code; only load cache files you produced yourself.
try:
    with open("/kaggle/input/midi-lists/chopin_list.pkl", "rb") as cache_file:
        all_mids = pickle.load(cache_file)
except Exception as exc:
    # Cache missing or unreadable (e.g. pickled with a different music21 version): rebuild.
    # `Exception` instead of a bare except so KeyboardInterrupt still stops the cell.
    print(f"Cached corpus unavailable ({exc!r}); parsing {len(files)} MIDI files")
    all_mids = [converter.parse(path).chordify() for path in tqdm(files)]
len(all_mids)

In [None]:
def extract_notes(files, pad_len=32):
    """Flatten parsed scores into parallel note-token and duration lists.

    Every score is preceded by ``pad_len`` "START" tokens with duration 0, the same
    context the generator is seeded with.

    Args:
        files: iterable of parsed music21 scores. Despite the name these are not paths;
            the notebook passes the chordified scores in ``all_mids``.
        pad_len: number of START padding tokens inserted before each score.

    Returns:
        (notes, durations): two equal-length lists. A note token is the pitch with octave
        (``nameWithOctave``) for a single note, or the chord's pitches joined with ".".
        A duration is the element's ``duration.quarterLength``.

    Only ``note.Note`` and ``chord.Chord`` elements are tokenized. Rests and every
    other element type are skipped.
    """
    notes = []
    durations = []
    for score in tqdm(files):
        notes.extend(["START"] * pad_len)
        durations.extend([0] * pad_len)

        for element in score.flatten():
            # music21's Note and Chord are disjoint classes and a Note is never a rest,
            # so rests fall through to `continue` and are not tokenized.
            if isinstance(element, note.Note):
                token = element.nameWithOctave
            elif isinstance(element, chord.Chord):
                token = '.'.join(p.nameWithOctave for p in element.pitches)
            else:
                continue
            notes.append(token)
            durations.append(element.duration.quarterLength)
    return notes, durations

In [None]:
# Build the parallel note-token / duration streams for the whole corpus.
notes, durations = extract_notes(files=all_mids)
corpus_len = len(notes)
print("total notes in corpus: ", corpus_len)

In [None]:
# Peek at a slice just past the first piece's START padding.
(notes[30:40], durations[30:40])

In [None]:
def get_distinct(e):
    """Return the sorted distinct values of ``e`` together with how many there are.

    Args:
        e: iterable of hashable, mutually orderable items (a one-shot iterator is fine).

    Returns:
        tuple ``(sorted_unique_values, count)``.
    """
    # Build the set once: iterating `e` twice would report 0 for a one-shot iterator.
    distinct = set(e)
    return sorted(distinct), len(distinct)
# Vocabularies: sorted distinct note tokens and duration values, with their sizes.
notes_names, n_notes = get_distinct(notes)
duration_names, n_durations = get_distinct(durations)

In [None]:
(n_notes, n_durations)

In [None]:
def create_lookup(data):
    """Build forward and reverse index maps for a vocabulary.

    Args:
        data: iterable of hashable items, in the order their ids should be assigned
            (the notebook passes sorted, de-duplicated lists).

    Returns:
        (d2i, i2d): ``d2i`` maps item -> position and ``i2d`` maps position -> item.
    """
    # Materialize once so a one-shot iterator fills both maps, not just the first.
    items = list(data)
    d2i = {item: idx for idx, item in enumerate(items)}
    i2d = dict(enumerate(items))
    return d2i, i2d

# Token <-> id maps for both output heads.
note_to_int, int_to_note = create_lookup(data=notes_names)
duration_to_int, int_to_duration = create_lookup(data=duration_names)

In [None]:
# Persist both directions of each vocabulary so generated ids can be decoded in another session.
save = {
    "n2i": note_to_int,
    "i2n": int_to_note,
    "d2i": duration_to_int,
    "i2d": int_to_duration,
}
with open("lookups.pkl", "wb") as lookup_file:
    pickle.dump(save, lookup_file)

In [None]:
# Show only the first ten (token, id) pairs of the note vocabulary.
print('\nnote_to_int')
for pair in list(note_to_int.items())[:10]:
    print(pair)

In [None]:
# The duration vocabulary is small enough to print in full.
print('\nduration_to_int', duration_to_int, sep='\n')

In [None]:
class MusicDs(Dataset):
    """Sliding-window next-token dataset over parallel note and duration streams.

    Item ``i`` uses positions ``i .. i+seq_len-1`` as input and position ``i+seq_len``
    as the prediction target.

    Args:
        notes: list of note tokens (keys of ``n2i``).
        durations: list of duration values (keys of ``d2i``), aligned with ``notes``.
        n2i: note token -> integer id.
        d2i: duration value -> integer id.
        seq_len: length of the input window.

    Each item is ``(xn, xd, yn, yd)``:
        xn, xd: int64 tensors of shape (seq_len,) holding the input note/duration ids.
        yn: int64 one-hot tensor of shape (len(n2i),) for the next note.
        yd: int64 one-hot tensor of shape (len(d2i),) for the next duration.
    """

    def __init__(self, notes, durations, n2i, d2i, seq_len=32):
        self.notes = notes
        self.durations = durations
        self.n2i = n2i
        self.d2i = d2i
        self.seq_len = seq_len
        self.L_n = len(n2i)  # number of note classes
        self.L_d = len(d2i)  # number of duration classes

    def __len__(self):
        # One sample per position that has a full window plus a target after it. Clamp
        # at 0 so a stream shorter than the window gives an empty dataset, not a negative length.
        return max(0, len(self.notes) - self.seq_len)

    def __getitem__(self, i):
        end = i + self.seq_len

        # int64 is torch's canonical index dtype for nn.Embedding.
        xn = torch.tensor([self.n2i[k] for k in self.notes[i:end]], dtype=torch.long)
        xd = torch.tensor([self.d2i[k] for k in self.durations[i:end]], dtype=torch.long)

        yn = torch.tensor([self.n2i[self.notes[end]]], dtype=torch.long)
        yd = torch.tensor([self.d2i[self.durations[end]]], dtype=torch.long)

        # squeeze(0) drops only the leading length-1 dim, so the class axis survives
        # even when a vocabulary has a single entry.
        yn = F.one_hot(yn, num_classes=self.L_n).squeeze(0)
        yd = F.one_hot(yd, num_classes=self.L_d).squeeze(0)

        return xn, xd, yn, yd

In [None]:
# Full sliding-window dataset over the whole corpus (default window of 32 tokens).
ds = MusicDs(notes, durations, n2i=note_to_int, d2i=duration_to_int)

In [None]:
# Shapes of one sample, one per line: xn, xd, yn, yd.
print(*(part.shape for part in ds[5]), sep="\n")

In [None]:
# Dedicated fixed-seed generator so the 80/20 train/validation split is identical on every
# run, independent of how many random numbers earlier cells consumed.
train_ds, val_ds = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))

In [None]:
# random_split permutes the windows once; shuffle=True re-permutes the training set every
# epoch so batches are not replayed in the same order. Validation order is irrelevant.
# Pinned memory only speeds up host->GPU copies, so enable it only when CUDA is present.
pin = torch.cuda.is_available()
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, pin_memory=pin)
val_dl = DataLoader(val_ds, batch_size=64, pin_memory=pin)

In [None]:
# Print the tensor shapes of the first training batch only.
for xn, xd, yn, yd in train_dl:
    print(xn.shape, xd.shape, yn.shape, yd.shape)
    break

# Model Definition

In [None]:
class MusicModel(L.LightningModule):
    """Two-headed recurrent next-token model predicting the next note and the next duration.

    Note ids and duration ids are embedded separately and concatenated per time step. The
    result runs through stacked GRUs; the last time step goes through an MLP into two
    linear heads (note logits, duration logits).

    Args:
        n_notes: size of the note vocabulary.
        n_durations: size of the duration vocabulary.
        embed_size: embedding width of each of the two inputs.
        rnn_units: hidden size of the GRU layers and of the MLP output.
    """

    def __init__(self, n_notes, n_durations, embed_size=128, rnn_units=512):
        super().__init__()

        self.rnn_units = rnn_units

        self.ne = nn.Embedding(n_notes, embed_size)
        self.de = nn.Embedding(n_durations, embed_size)

        # NOTE: despite the "lstm" attribute names these are GRU layers. The names are kept
        # so state_dict keys (and therefore saved checkpoints) remain compatible.
        self.lstm1 = nn.GRU(2 * embed_size, self.rnn_units, batch_first=True)
        self.drop1 = nn.Dropout(0.3)
        self.lstm2 = nn.GRU(self.rnn_units, self.rnn_units, num_layers=2, batch_first=True)
        self.drop2 = nn.Dropout(0.3)
        self.seq = nn.Sequential(
            nn.Linear(self.rnn_units, self.rnn_units * 8),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.rnn_units * 8, self.rnn_units * 4),
            nn.ReLU(),
            nn.Linear(self.rnn_units * 4, self.rnn_units),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.out1 = nn.Linear(self.rnn_units, n_notes)      # note head
        self.out2 = nn.Linear(self.rnn_units, n_durations)  # duration head

        # Example (notes, durations) input of 2 windows x 32 steps, used by ModelSummary and
        # add_graph. Ids are drawn from each vocabulary's own range so they are always valid
        # embedding indices, even for a vocabulary with fewer than 10 entries.
        self.example_input_array = [
            torch.randint(0, n_notes, (2, 32)),
            torch.randint(0, n_durations, (2, 32)),
        ]
        self.save_hyperparameters()

    @staticmethod
    def concat(x1, x2):
        """Concatenate two tensors along the last (feature) dimension."""
        # A real method rather than a lambda attribute keeps the module picklable.
        return torch.cat([x1, x2], dim=-1)

    def forward(self, xn, xd):
        """Return ``(note_logits, duration_logits)`` for the step following each window.

        ``xn``/``xd`` are integer id tensors laid out batch-first, as the GRUs expect.
        """
        x = self.concat(self.ne(xn), self.de(xd))

        x, _ = self.lstm1(x)
        x = self.drop1(x)
        x, _ = self.lstm2(x)
        x = self.drop2(x)

        # Only the final time step feeds the prediction heads.
        x = self.seq(x[:, -1])
        return self.out1(x), self.out2(x)

    def _shared_step(self, batch):
        """Sum of the note and duration cross-entropy losses for one (xn, xd, yn, yd) batch."""
        xn, xd, yn, yd = batch
        yn_hat, yd_hat = self(xn, xd)
        # Targets arrive one-hot; as float tensors cross_entropy treats them as class probabilities.
        lossn = F.cross_entropy(yn_hat, yn.float())
        lossd = F.cross_entropy(yd_hat, yd.float())
        return lossn + lossd

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss)
        return loss

    def on_train_epoch_end(self):
        # Log the computation graph once, after the first epoch. Only TensorBoard-style
        # experiments expose add_graph; skip for other loggers or when logging is disabled.
        if self.current_epoch != 0 or self.logger is None:
            return
        experiment = self.logger.experiment
        if hasattr(experiment, "add_graph"):
            inputs = [v.to(self.device) for v in self.example_input_array]
            experiment.add_graph(self, inputs)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Reduce the LR when validation loss plateaus. `verbose=` is deprecated in recent
        # PyTorch releases and is therefore not passed.
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "val_loss",
            },
        }

# Training

In [None]:
# Output head sizes come from the vocabularies built above.
model = MusicModel(n_notes=n_notes, n_durations=n_durations)

In [None]:
# max_depth=-1 lists every nested submodule, not just the top level.
summary = ModelSummary(model, max_depth=-1)
summary

In [None]:
# Stop when the held-out loss stops improving. Monitoring the training loss (as before)
# cannot detect overfitting; "val_loss" is logged by MusicModel.validation_step and is
# also what the LR scheduler watches.
# NOTE(review): ReduceLROnPlateau's default patience is also 10 epochs, so training may
# stop right as the LR would be lowered; consider a larger patience here.
es = EarlyStopping(monitor="val_loss", min_delta=0.001, patience=10, verbose=True)

In [None]:
trainer = L.Trainer(
    default_root_dir="./MusicModel",     # logs and checkpoints are written here
    callbacks=[es],
    max_epochs=300,
    max_time=timedelta(hours=9),         # wall-clock cap, whichever limit is hit first
)

In [None]:
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)

# Predictions

In [None]:
# nn.Module.to moves parameters in place (and returns the same module); eval() disables dropout.
model.to(device)
model.eval()

In [None]:
@torch.no_grad()
def generator(note_count=100, seq_len=32, n_sampled=None):
    """Autoregressively generate ``note_count`` (note, duration) pairs with the global ``model``.

    Generation starts from a window of ``seq_len`` START tokens with duration 0, matching
    the padding inserted before every piece in the training data. The first ``n_sampled``
    steps sample from the softmax distributions; the remaining steps take the argmax.
    Assumes ``model`` is already on ``device`` and in eval mode.

    Args:
        note_count: number of tokens to generate.
        seq_len: length of the sliding context window (the notebook trains on 32).
        n_sampled: number of initial stochastic steps; defaults to ``note_count // 5``.

    Returns:
        (notes_pred, dur_preds): lists of decoded note tokens and duration values.
    """
    if n_sampled is None:
        n_sampled = note_count // 5

    # Context windows are kept on the model's device as (1, seq_len) int64 tensors.
    note_seed = torch.full((1, seq_len), note_to_int['START'], dtype=torch.long, device=device)
    dur_seed = torch.full((1, seq_len), duration_to_int[0], dtype=torch.long, device=device)

    notes_pred = []
    dur_preds = []
    for step in tqdm(range(note_count)):
        note_logits, dur_logits = model(note_seed, dur_seed)

        if step < n_sampled:
            # probabilistic choice
            note_idx = torch.multinomial(F.softmax(note_logits, dim=1), num_samples=1).item()
            dur_idx = torch.multinomial(F.softmax(dur_logits, dim=1), num_samples=1).item()
        else:
            # deterministic choice
            note_idx = note_logits.argmax().item()
            dur_idx = dur_logits.argmax().item()

        # `note_idx`/`dur_idx` (not `note`) so the music21 `note` module is not shadowed.
        notes_pred.append(int_to_note[note_idx])
        dur_preds.append(int_to_duration[dur_idx])

        # Slide the window: drop the oldest id and append the newly chosen one.
        note_seed = torch.cat([note_seed[:, 1:], note_seed.new_tensor([[note_idx]])], dim=1)
        dur_seed = torch.cat([dur_seed[:, 1:], dur_seed.new_tensor([[dur_idx]])], dim=1)

    return notes_pred, dur_preds

In [None]:
def convert(out_notes, out_durs):
    """Turn generated (token, duration) pairs into a music21 ``stream.Stream``.

    A token becomes a rest when its duration is 0 or when it is the "START" padding token.
    Otherwise a token containing "." becomes a chord of its dot-separated pitches, and
    anything else is parsed as a single note. Every element's duration is set to
    ``duration.Duration(dur)``.
    """
    melody = stream.Stream()
    for token, quarter_len in zip(out_notes, out_durs):
        if quarter_len == 0 or token == 'START':
            element = note.Rest()
        elif '.' not in token:
            element = note.Note(token)
        else:
            element = chord.Chord([note.Note(name) for name in token.split('.')])

        element.duration = duration.Duration(quarter_len)
        melody.append(element)

    return melody

In [None]:
N_MELODIES = 10   # how many melodies to generate
MELODY_LEN = 100  # tokens generated per melody

fs = FluidSynth()
for i in range(N_MELODIES):
    mel = convert(*generator(MELODY_LEN))
    try:
        # Notation rendering depends on an external renderer configured for music21
        # (e.g. MuseScore), so treat it as best-effort.
        mel.show()
    except Exception as exc:
        # Report the real cause instead of guessing; the MIDI/audio below is still produced.
        print(f"Could not render melody {i} as a score: {exc!r}")
    midi_path = f"test-{i}.mid"
    wav_path = f"test-{i}.wav"
    mel.write("midi", midi_path)
    fs.midi_to_audio(midi_path, wav_path)
    display(Audio(wav_path))