In [1]:
import glob
import pickle
import numpy as np
from music21 import converter, instrument, note, chord, stream
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pygame

path = 'model_weights.pt'

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
def get_notes():
    notes = []
    for file in glob.glob("classical_midi_data/*.mid"):
        midi = converter.parse(file)

        print("Parsing %s" % file)

        notes_to_parse = None

        try:
            s2 = instrument.partitionByInstrument(midi)
            notes_to_parse = s2.parts[0].recurse() 
        except: # file has notes in a flat structure
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))

    with open('data/notes', 'wb') as filepath:
        pickle.dump(notes, filepath)

    return notes

In [4]:
notes = get_notes()
note_to_int = dict((note, number) for number, note in enumerate(sorted(set(notes))))



Parsing classical_midi_data/haydn_43_1.mid




Parsing classical_midi_data/muss_2.mid




Parsing classical_midi_data/waldstein_1.mid




Parsing classical_midi_data/schumm-1.mid




Parsing classical_midi_data/chpn_op23.mid




Parsing classical_midi_data/chpn-p19.mid
Parsing classical_midi_data/chpn_op7_2.mid
Parsing classical_midi_data/beethoven_opus90_2.mid
Parsing classical_midi_data/chpn-p18.mid
Parsing classical_midi_data/chpn-p24.mid




Parsing classical_midi_data/mendel_op19_1.mid




Parsing classical_midi_data/muss_3.mid
Parsing classical_midi_data/haydn_43_2.mid
Parsing classical_midi_data/muss_1.mid




Parsing classical_midi_data/burg_sylphen.mid
Parsing classical_midi_data/waldstein_2.mid




Parsing classical_midi_data/schumm-2.mid




Parsing classical_midi_data/burg_quelle.mid




Parsing classical_midi_data/mendel_op19_3.mid




Parsing classical_midi_data/schub_d960_4.mid
Parsing classical_midi_data/chpn_op7_1.mid
Parsing classical_midi_data/schum_abegg.mid




Parsing classical_midi_data/gra_esp_4.mid
Parsing classical_midi_data/beethoven_opus90_1.mid
Parsing classical_midi_data/mendel_op19_2.mid
Parsing classical_midi_data/schumm-3.mid




Parsing classical_midi_data/waldstein_3.mid




Parsing classical_midi_data/haydn_8_4.mid
Parsing classical_midi_data/haydn_43_3.mid
Parsing classical_midi_data/muss_4.mid
Parsing classical_midi_data/scn68_10.mid




Parsing classical_midi_data/scn15_10.mid
Parsing classical_midi_data/mendel_op19_6.mid




Parsing classical_midi_data/schub_d960_1.mid
Parsing classical_midi_data/chpn-p23.mid
Parsing classical_midi_data/chpn-p9.mid




Parsing classical_midi_data/burg_perlen.mid




Parsing classical_midi_data/scn16_8.mid
Parsing classical_midi_data/chpn-p8.mid
Parsing classical_midi_data/chpn-p22.mid




Parsing classical_midi_data/liz_rhap09.mid
Parsing classical_midi_data/schumm-6.mid
Parsing classical_midi_data/haydn_8_1.mid




Parsing classical_midi_data/scn15_11.mid




Parsing classical_midi_data/muss_5.mid
Parsing classical_midi_data/muss_7.mid
Parsing classical_midi_data/scn15_13.mid
Parsing classical_midi_data/haydn_8_3.mid
Parsing classical_midi_data/mendel_op19_5.mid
Parsing classical_midi_data/schumm-4.mid
Parsing classical_midi_data/chpn-p20.mid
Parsing classical_midi_data/schub_d960_2.mid
Parsing classical_midi_data/alb_se8.mid
Parsing classical_midi_data/gra_esp_2.mid
Parsing classical_midi_data/liz_et_trans8.mid
Parsing classical_midi_data/gra_esp_3.mid
Parsing classical_midi_data/schub_d960_3.mid
Parsing classical_midi_data/chpn-p21.mid
Parsing classical_midi_data/schumm-5.mid
Parsing classical_midi_data/mendel_op19_4.mid
Parsing classical_midi_data/haydn_8_2.mid
Parsing classical_midi_data/scn15_12.mid




Parsing classical_midi_data/scn68_12.mid
Parsing classical_midi_data/muss_6.mid
Parsing classical_midi_data/ty_november.mid




Parsing classical_midi_data/chp_op18.mid
Parsing classical_midi_data/chpn_op35_4.mid




Parsing classical_midi_data/mz_570_1.mid
Parsing classical_midi_data/chpn_op33_2.mid
Parsing classical_midi_data/chp_op31.mid




Parsing classical_midi_data/deb_prel.mid
Parsing classical_midi_data/ty_juni.mid
Parsing classical_midi_data/mz_570_2.mid
Parsing classical_midi_data/chpn_op25_e4.mid




Parsing classical_midi_data/burg_erwachen.mid
Parsing classical_midi_data/mz_570_3.mid




Parsing classical_midi_data/schuim-4.mid
Parsing classical_midi_data/mendel_op62_3.mid
Parsing classical_midi_data/ty_april.mid
Parsing classical_midi_data/chpn_op35_2.mid
Parsing classical_midi_data/scn15_8.mid
Parsing classical_midi_data/beethoven_opus10_2.mid
Parsing classical_midi_data/haydn_9_3.mid




Parsing classical_midi_data/chpn_op25_e1.mid




Parsing classical_midi_data/liz_donjuan.mid
Parsing classical_midi_data/chpn_op33_4.mid
Parsing classical_midi_data/chpn_op53.mid




Parsing classical_midi_data/hay_40_1.mid
Parsing classical_midi_data/haydn_9_2.mid
Parsing classical_midi_data/beethoven_opus10_3.mid
Parsing classical_midi_data/ty_maerz.mid




Parsing classical_midi_data/elise.mid
Parsing classical_midi_data/scn15_9.mid
Parsing classical_midi_data/schuim-1.mid
Parsing classical_midi_data/chpn_op35_3.mid
Parsing classical_midi_data/mendel_op62_4.mid
Parsing classical_midi_data/chpn_op35_1.mid
Parsing classical_midi_data/schuim-3.mid
Parsing classical_midi_data/beethoven_opus10_1.mid
Parsing classical_midi_data/chpn_op25_e2.mid
Parsing classical_midi_data/hay_40_2.mid
Parsing classical_midi_data/chpn_op25_e3.mid
Parsing classical_midi_data/haydn_9_1.mid
Parsing classical_midi_data/schuim-2.mid
Parsing classical_midi_data/mendel_op62_5.mid




Parsing classical_midi_data/appass_2.mid




Parsing classical_midi_data/scn15_7.mid
Parsing classical_midi_data/mendel_op53_5.mid
Parsing classical_midi_data/mz_545_1.mid
Parsing classical_midi_data/haydn_7_3.mid




Parsing classical_midi_data/mz_332_3.mid
Parsing classical_midi_data/burg_spinnerlied.mid




Parsing classical_midi_data/mz_330_1.mid
Parsing classical_midi_data/mz_332_2.mid
Parsing classical_midi_data/haydn_7_2.mid




Parsing classical_midi_data/chpn_op10_e12.mid
Parsing classical_midi_data/scn15_6.mid
Parsing classical_midi_data/appass_3.mid
Parsing classical_midi_data/appass_1.mid
Parsing classical_midi_data/scn15_4.mid
Parsing classical_midi_data/schubert_D850_4.mid
Parsing classical_midi_data/mendel_op30_5.mid
Parsing classical_midi_data/mz_545_2.mid
Parsing classical_midi_data/mz_330_2.mid
Parsing classical_midi_data/schubert_D935_4.mid
Parsing classical_midi_data/mz_330_3.mid
Parsing classical_midi_data/haydn_7_1.mid
Parsing classical_midi_data/mz_332_1.mid
Parsing classical_midi_data/mz_545_3.mid
Parsing classical_midi_data/mendel_op30_4.mid




Parsing classical_midi_data/chpn_op10_e05.mid
Parsing classical_midi_data/schub_d760_4.mid




Parsing classical_midi_data/scn15_5.mid




Parsing classical_midi_data/scn15_1.mid
Parsing classical_midi_data/beethoven_les_adieux_3.mid
Parsing classical_midi_data/schubert_D850_1.mid
Parsing classical_midi_data/chpn_op10_e01.mid




Parsing classical_midi_data/pathetique_1.mid
Parsing classical_midi_data/schubert_D935_1.mid
Parsing classical_midi_data/chpn_op66.mid
Parsing classical_midi_data/mendel_op30_1.mid




Parsing classical_midi_data/mond_1.mid
Parsing classical_midi_data/schub_d760_1.mid
Parsing classical_midi_data/burg_geschwindigkeit.mid
Parsing classical_midi_data/beethoven_les_adieux_2.mid
Parsing classical_midi_data/ty_september.mid
Parsing classical_midi_data/liz_liebestraum.mid
Parsing classical_midi_data/scn15_2.mid
Parsing classical_midi_data/schubert_D850_2.mid
Parsing classical_midi_data/schub_d760_3.mid




Parsing classical_midi_data/mond_3.mid
Parsing classical_midi_data/mendel_op30_3.mid




Parsing classical_midi_data/pathetique_2.mid
Parsing classical_midi_data/schubert_D935_2.mid
Parsing classical_midi_data/schubert_D935_3.mid




Parsing classical_midi_data/pathetique_3.mid
Parsing classical_midi_data/mendel_op30_2.mid




Parsing classical_midi_data/mond_2.mid
Parsing classical_midi_data/schubert_D850_3.mid
Parsing classical_midi_data/schub_d760_2.mid
Parsing classical_midi_data/beethoven_les_adieux_1.mid
Parsing classical_midi_data/scn15_3.mid
Parsing classical_midi_data/burg_gewitter.mid
Parsing classical_midi_data/schu_143_2.mid
Parsing classical_midi_data/haydn_33_1.mid
Parsing classical_midi_data/mz_331_2.mid
Parsing classical_midi_data/beethoven_opus22_1.mid
Parsing classical_midi_data/alb_esp3.mid
Parsing classical_midi_data/chpn-p10.mid
Parsing classical_midi_data/chpn_op25_e12.mid
Parsing classical_midi_data/alb_se4.mid
Parsing classical_midi_data/chpn-p6.mid
Parsing classical_midi_data/scn16_6.mid
Parsing classical_midi_data/liz_et2.mid
Parsing classical_midi_data/ty_mai.mid
Parsing classical_midi_data/liz_et_trans5.mid
Parsing classical_midi_data/liz_et_trans4.mid
Parsing classical_midi_data/liz_et3.mid
Parsing classical_midi_data/scn16_7.mid
Parsing classical_midi_data/chpn-p7.mid
Parsing cl



Parsing classical_midi_data/muss_8.mid
Parsing classical_midi_data/schu_143_1.mid
Parsing classical_midi_data/haydn_33_2.mid
Parsing classical_midi_data/mz_331_1.mid
Parsing classical_midi_data/mz_333_3.mid
Parsing classical_midi_data/beethoven_opus22_2.mid
Parsing classical_midi_data/liz_rhap10.mid
Parsing classical_midi_data/chpn-p13.mid
Parsing classical_midi_data/debussy_cc_4.mid




Parsing classical_midi_data/chpn_op25_e11.mid
Parsing classical_midi_data/chpn-p5.mid
Parsing classical_midi_data/alb_se7.mid




Parsing classical_midi_data/burg_agitato.mid
Parsing classical_midi_data/scn16_5.mid
Parsing classical_midi_data/liz_et1.mid
Parsing classical_midi_data/beethoven_hammerklavier_4.mid
Parsing classical_midi_data/scn16_4.mid
Parsing classical_midi_data/alb_se6.mid
Parsing classical_midi_data/chpn-p4.mid
Parsing classical_midi_data/deb_menu.mid
Parsing classical_midi_data/chpn-p12.mid
Parsing classical_midi_data/alb_esp1.mid
Parsing classical_midi_data/mz_333_2.mid
Parsing classical_midi_data/beethoven_opus22_3.mid
Parsing classical_midi_data/haydn_33_3.mid
Parsing classical_midi_data/alb_esp5.mid
Parsing classical_midi_data/liz_rhap15.mid
Parsing classical_midi_data/debussy_cc_1.mid
Parsing classical_midi_data/chpn-p16.mid
Parsing classical_midi_data/alb_se2.mid
Parsing classical_midi_data/haydn_35_1.mid
Parsing classical_midi_data/liz_et4.mid
Parsing classical_midi_data/chpn_op27_2.mid
Parsing classical_midi_data/mz_311_1.mid
Parsing classical_midi_data/beethoven_hammerklavier_1.mid
Par

In [125]:
class MusicDataset(Dataset):
    def __init__(self, notes, sequence_length, n_vocab):
        self.sequence_length = sequence_length
        self.n_vocab = n_vocab

        self.note_to_int = dict((note, number) for number, note in enumerate(sorted(set(notes))))
        self.int_to_note = dict((number, note) for number, note in enumerate(sorted(set(notes))))

        network_input = []
        network_output = []

        for i in range(0, len(notes) - sequence_length - 1):
            sequence_in = notes[i:i + sequence_length]
            sequence_out = notes[i + 1 : i + sequence_length + 1]
            network_input.append([self.note_to_int[char] for char in sequence_in])
            network_output.append([self.note_to_int[char] for char in sequence_out])
        
        self.network_input = torch.tensor(network_input, dtype=torch.long)
        self.network_output = torch.tensor(network_output, dtype=torch.long)
        
        
        self.network_input = nn.functional.one_hot(self.network_input, num_classes=n_vocab)
        self.network_input = self.network_input.to(torch.float32)
        self.network_output = nn.functional.one_hot(self.network_output, num_classes=n_vocab) #label
        self.network_output = self.network_output.to(torch.float32)
       
        

    def __len__(self):
        return len(self.network_input)
        
    
    def __getitem__(self, idx):
        return self.network_input[idx], self.network_output[idx]

class MusicLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MusicLSTM, self).__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, dropout=0.3)
        #self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        #self.lstm3 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        l = 512
        self.bn1 = nn.BatchNorm1d(l)
        self.bn2 = nn.BatchNorm1d(l)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(input_size, output_size)
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        out, hidden = self.lstm1(x)
        #print("ot,", out.size())
        out = self.bn1(out)
        out = self.dropout(out)
        out = self.fc1(out)
        out = torch.relu(out)
        out = self.bn2(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = self.softmax(out)
        return out

In [35]:
def train(model, dataloader, loss_fn, optimizer):
    """ train the neural network """
    path = 'model_weights_new.pt'
    num_epochs = 100
    for epoch in range(num_epochs):
        running_loss = 0.0
        min = 600
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model.forward(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss}")
        if(epoch % 10 == 0 and loss < min):
            torch.save(model, path)
            min  = loss



In [4]:
sequence_length = 10
n_vocab = 222

In [38]:
n_vocab = len(set(notes))
print(n_vocab)
dataset = MusicDataset(notes, sequence_length, n_vocab)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
int_to_note_dict = dataset.int_to_note


222


In [None]:
model = MusicLSTM(input_size=n_vocab, hidden_size=512, output_size=n_vocab)
lossfn = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.0011)
optimizer1 = optim.Adam(model.parameters(), lr=0.0001)
train(model, dataloader, lossfn, optimizer)

In [3]:
def generate_start_input(sequence_len):
    np.random.seed()
    rand = np.random.choice(np.arange(0, 157), size=sequence_len, replace=False)
    seq_list = []
    #print(notes[rand[0]])
    #print(note_to_int)
    for i in range(10): #len(rand)
        seq_list.append(rand[i])
    seq_tens = torch.tensor(seq_list)
    seq_tens = torch.eye(n_vocab)[seq_tens]
    seq_tens = dataset.network_input[rand[2],:]
    return seq_tens

In [2]:
def predict_with_model():
    #path = 'model_weights.pt'
    path =  'model_weights_new.pt'
    model = torch.load(path)
    model.eval()
    num_notes = 5
    start_seq = generate_start_input(sequence_length).unsqueeze(0)
    pred_seq = start_seq.clone()
    #out = model.forward(start_seq)
    #print(out.size())
    #print(out)
    #print("numnote:", num_notes)
    #print("s", ou.size())

    for _ in range(num_notes):
        out = model.forward(start_seq)
        #print("iut", start_seq.squeeze(0).size())
        #print("iut", out[-1][-2].unsqueeze(0).size())
        start_seq = torch.cat((start_seq, out), 0)
        # start_seq = torch.cat((start_seq.squeeze(0), out[-1][-2].unsqueeze(0)), 0)
        # start_seq = torch.cat((start_seq.squeeze(0), out[-1][-1].unsqueeze(0)), 0)
        # start_seq=start_seq.unsqueeze(0)
    # for _ in range(num_notes):
    #     out = model.forward(start_seq)
    #     start_seq = torch.cat((start_seq, out), 0)

        #print("i", _)
        #pred_seq = torch.cat((pred_seq, out), 0)
        #print("sthart seq", start_seq.size())
    return start_seq

In [4]:
# def prob_to_note(seq_tens):
#     seq = []
#     rows, _ = seq_tens.size()
#     for i in range(rows):
#         row = seq_tens[i, :] 
#         print("item", int_to_note_dict[torch.argmax(row).item()])
#         seq.append(int_to_note_dict[torch.argmax(row).item()])
#         print(seq)
#     #print(seq)
#    return seq
# def prob_to_note(seq_tens):
#     seq = []
#     rows,_ = seq_tens.size()
#     print("rows", rows)
#     for i in range(rows):
#         row = seq_tens[i, :] 
#         #print("item", int_to_note_dict[torch.argmax(row).item()])
#         # seq.append(int_to_note_dict[torch.argmax(row).item()])
#         if(i > 0):
#             print("seq", seq[i-2], seq[i - 1],int_to_note_dict[torch.argmax(row).item()])
#             print("tens", torch.argmax(row).item())
#             if seq[i - 2] == seq[i - 1] and seq[i - 1] == int_to_note_dict[torch.argmax(row).item()]:
#                 #rand = np.random.choice(np.arange(0, 157), size=1, replace=False)
#                 #print("otherh", int_to_note_dict[rand.item()])
#                 seq.append(int_to_note_dict[torch.topk(row, k=2)[1][1].item()])
#             else:
#                 seq.append(int_to_note_dict[torch.argmax(row).item()])
#         else:
#             seq.append(int_to_note_dict[torch.argmax(row).item()])

#             print("aooendd", int_to_note_dict[torch.argmax(row).item()])
        
#         #print(seq)
#     #print(seq)
#     return seq
def prob_to_note(seq_tens):
    seq = []
    rows = seq_tens.size()[0]
    vals = [i for i in range(n_vocab)]
    for i in range(rows):
        
        row = seq_tens[i, :][0] 
        #print("row", row.size())
        #row  = torch.topk(row, k=5)[1]
        #top_5 = torch.topk(row, k=8)[1]
        #print("top5", top_5)
        #prob_dist = row[top_5].tolist()
        prob_dist = row.tolist()
        #print("e", prob_dist)

        #prob_dist = row.tolist()
        #total = sum(prob_dist)

        #normalization issues :skull:
        #prob_dist= [val / total for val in prob_dist]
        prob_dist = np.round(prob_dist, decimals=10)  
        prob_dist = prob_dist / sum(prob_dist)
        #prob_dist[-1] += 1 - sum(prob_dist)
        #print("tt",total)
        if(sum(prob_dist) != 1.0):
            #print("probdistsad", sum(prob_dist),prob_dist)
            rand = np.random.choice(vals)
            rand = int_to_note_dict[rand]
            seq.append(rand)
            # if seq[i - 2] == seq[i - 1] and seq[i - 1] == int_to_note_dict[torch.argmax(row).item()]:
            #     seq.append(int_to_note_dict[torch.topk(row, k=2)[1][1].item()])
            # else:
            #     seq.append(int_to_note_dict[torch.argmax(row).item()])
       
            continue
        else:    
        
            #vals = [int_to_note_dict[top_5[idx].item()] for idx in range(8)]
            #print("sisiei", prob_dist)
            rand = np.random.choice(vals, p = prob_dist)
            rand = int_to_note_dict[rand]
            seq.append(rand)
            #print(int_to_note_dict[rand])
            #seq.append(rand)
    return seq


In [5]:
def make_midi(prediction):
    output = []
    offset = 0
    for n in prediction:
        if('.' in n or n.isdigit()):
            chord_list = n.split('.')
            notes = []
            for curr in chord_list:
                newnote = note.Note(int(curr))
                newnote.storedInstrument = instrument.Piano()
                notes.append(newnote)
            newchord = chord.Chord(notes)
            newchord.offset =offset
            output.append(newchord)
        elif (n != ''):
            newnote = note.Note(n)
            newnote.offset = offset
            newnote.storedInstrument = instrument.Piano()
            output.append(newnote)
        offset += 0.5
    midistream = stream.Stream(output)
    return midistream

In [6]:
## make the prediciton
seq = predict_with_model()

AttributeError: Can't get attribute 'MusicLSTM' on <module '__main__'>

In [7]:
#print(prob_to_note(seq))
prediction  = prob_to_note(seq)
midistream = make_midi(prediction)
midistream.write('midi', fp='pred_test4.mid')

NameError: name 'seq' is not defined

In [8]:
def play_midi(midi_file):
    pygame.mixer.init()
    pygame.mixer.music.load(midi_file)
    pygame.mixer.music.play()
    while pygame.mixer.music.get_busy():
        continue

In [9]:
##end up converging to 1 note when using argmax:
play_midi('pred.mid')
#play_midi('pred1.mid')
#play_midi('pred2.mid')

In [7]:
#probability distribution

#play_midi('pred3.mid')
play_midi('pred3.5.mid')

In [10]:
##current
#play_midi('pred_test.mid')
#play_midi('pred_test1.mid')
#play_midi('pred_test2.mid')
play_midi('pred_test3.mid') # <3 this one sounds better than some of the others :)
#play_midi('pred_test4.mid')
#play_midi('pred_test5.mid')