In [33]:
import os
import pandas as pd 
import torch.nn as nn
import torch
import torch.nn.functional as F
from music21 import *
import random

# Data Pre-Processing

### Transfroming the data into more readable input to the model

In [423]:
def get_pitch_class(note):
    return note % 12

def find_matching_octave_note(df):
    bass_line = df['note3'].values
    last_bass_note = bass_line[-1]
    return last_bass_note

def explore_for_lowest_tonic(df, pitch_class):
    bass_notes = df['note3'].values  
    matching_notes = [note for note in bass_notes if get_pitch_class(note) == pitch_class]
    if matching_notes:
        return min(matching_notes) 
    else:
        return bass_notes[0]

def detect_tonic(df):
    candidate_note = find_matching_octave_note(df)
    pitch_class = get_pitch_class(candidate_note)
    true_tonic_note = explore_for_lowest_tonic(df, pitch_class)
    return true_tonic_note

def is_major(df):
    chord_notes = df.iloc[0, :4].values
    unique_pitch_classes = set(note % 12 for note in chord_notes)
    
    for root in unique_pitch_classes:
        intervals = sorted((note - root) % 12 for note in unique_pitch_classes if note != root)
        if 4 in intervals and 7 in intervals:
            return True
        
        if 3 in intervals and 7 in intervals:
            return False
    
    return False

def key_transposition(df):
    tonic_note = detect_tonic(df)
    transpose_val = 48 if is_major(df) else 45
    transpose_val -= tonic_note 
    df = (df + transpose_val).clip(lower=0, upper=127)
    return df

def encode_song(song):
    result = []
    prev = {'note0': -1, 'note1': -1, 'note2': -1, 'note3': -1}
    result.append('START')
    
    for index, row in song.iterrows():
        for voice in ['note0', 'note1', 'note2', 'note3']:
            pitch = row[voice]
            previous_pitch = prev[voice]
            
            tied = 1 if pitch == previous_pitch else 0
            result.append((pitch,tied))
            prev[voice] = pitch
        result.append('|||')
    result.append('END')
    return result

In [427]:
folder_path = 'Data/'
test_set = []
train_set = []
validation_set = []
for dirname in os.listdir(folder_path):
    if dirname != '.DS_Store':
        for filename in os.listdir(folder_path + dirname):
            if filename != '.ipynb_checkpoints':
                df = pd.read_csv(folder_path + dirname + '/' + filename)
                transpose = key_transposition(df)
                song = encode_song(transpose)
                if dirname == 'test':
                    test_set.append(song)
                if dirname == 'train':
                    train_set.append(song)
                if dirname == 'valid':
                    validation_set.append(song)

In [403]:
class Model(nn.Module):
    def __init__(self, embedding_dim=128, hidden_dim=256, vocab_size=259, num_layers=2, droput_rate=.3):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.dropout = nn.Dropout(droput_rate)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.batch_size = 1

    def init_hidden_state(self):
        return (torch.zeros(self.num_layers, self.batch_size, self.hidden_dim),
                torch.zeros(self.num_layers, self.batch_size, self.hidden_dim))
    
    def forward(self, song, mask, hidden=None, teacher_forcing_ratio=None):
        if hidden is None:
            hidden = self.init_hidden_state()
            
        output = []
        x = song[0]
        
        for i in range(len(song)):
            x = self.embedding(torch.tensor(x, dtype=torch.long))
            x = x.unsqueeze(0).unsqueeze(0)
            curr, hidden = self.lstm(x, hidden)
            curr = self.dropout(curr)
            curr = self.fc(curr)
            
            if mask[i]:
                clamped_output = torch.zeros_like(curr)
                clamped_output[0, 0, song[i]] = 1.0
                output.append(clamped_output)
                curr = song[i]
            else: 
                if teacher_forcing_ratio and random.random() < teacher_forcing_ratio:
                    output.append(curr)
                    curr = song[i]
                else:
                    output.append(curr)
                    curr = curr.argmax(dim=2).item()
            x = curr
        output = torch.cat(output, dim=1).view(-1, self.fc.out_features)
        return output, hidden

In [435]:
def compute_mask_training(song):
    result = []
    result.append(True) # start
    for i in range(1, len(song) - 1):
        if i % 5 == 1 or i % 5 == 0:
            result.append(True)
        else:
            result.append(False)
    result.append(True) # end
    return result

def compute_mask_testing(song):
    result = []
    for i in range(len(song)):
        if i % 5 == 1 and i != len(song) - 1:
            result.append(True)
        else:
            result.append(False)
    return result

def embedding_dictionary():
    token_to_index = {("START"): 256, ("END"): 257, ("|||"): 258}
    for note in range(128):
        token_to_index[(note, 0)] = note
        token_to_index[(note, 1)] = note + 128
    return token_to_index

def harmonies_to_zero(song):
    result = []
    for i in range(len(song)):
        if i % 5 == 2 or i % 5 == 3 or i % 5 == 4:
            result.append(-1)
        else:
            result.append(song[i])
    return result
        
token_dictionary = embedding_dictionary()

def train_model(model, optimizer, criterion, num_epochs):
    for index, song in enumerate(train_set):
        model.train()
        melody_mask = compute_mask_training(song)
        embeded_song = [token_dictionary[token] for token in song]
        hidden = None
        for epoch in range(num_epochs):
            teacher_forcing_rate = max(0.5 * (1 - epoch / num_epochs), 0)
            optimizer.zero_grad()
            output, hidden = model(embeded_song, melody_mask, hidden, teacher_forcing_rate)
            loss = criterion(output, torch.tensor(embeded_song))
            loss.backward()
            hidden = tuple(h.detach() for h in hidden)
            optimizer.step()
            if (epoch + 1) % 10 == 0:
                print(f"Song {index + 1}, Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
                
        with torch.no_grad():
            model.eval()
            total_val_loss = 0
            for val_song in validation_set:
                val_embeded_song = [token_dictionary[token] for token in val_song]
                val_melody_mask = compute_mask_testing(val_song)
                val_input_song = harmonies_to_zero(val_embeded_song)
                val_output, _ = model(val_input_song, val_melody_mask)
                val_loss = criterion(val_output, torch.tensor(val_embeded_song))
                total_val_loss += val_loss.item()
            print(f'Validation loss: {total_val_loss / len(validation_set)}')

In [436]:
criterion = nn.CrossEntropyLoss()
num_epochs = 100

model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
train_model(model, optimizer, criterion, num_epochs)

Song 1, Epoch 10/100, Loss: 5.130405426025391
Song 1, Epoch 20/100, Loss: 5.092310428619385
Song 1, Epoch 30/100, Loss: 4.952627182006836
Song 1, Epoch 40/100, Loss: 4.466960430145264
Song 1, Epoch 50/100, Loss: 4.135542869567871
Song 1, Epoch 60/100, Loss: 4.002730369567871
Song 1, Epoch 70/100, Loss: 3.9401936531066895
Song 1, Epoch 80/100, Loss: 3.913753032684326
Song 1, Epoch 90/100, Loss: 3.8935158252716064
Song 1, Epoch 100/100, Loss: 3.896516799926758
Validation loss: 5.257776847252479
Song 2, Epoch 10/100, Loss: 4.452261447906494
Song 2, Epoch 20/100, Loss: 4.230161666870117
Song 2, Epoch 30/100, Loss: 4.100407123565674
Song 2, Epoch 40/100, Loss: 4.030731678009033
Song 2, Epoch 50/100, Loss: 3.964204788208008
Song 2, Epoch 60/100, Loss: 3.9149467945098877
Song 2, Epoch 70/100, Loss: 3.907349109649658
Song 2, Epoch 80/100, Loss: 3.8698604106903076
Song 2, Epoch 90/100, Loss: 3.8429019451141357
Song 2, Epoch 100/100, Loss: 3.825822353363037
Validation loss: 5.2677854758042555
So

Song 17, Epoch 40/100, Loss: 3.4278409481048584
Song 17, Epoch 50/100, Loss: 3.329632520675659
Song 17, Epoch 60/100, Loss: 3.303605318069458
Song 17, Epoch 70/100, Loss: 3.2180237770080566
Song 17, Epoch 80/100, Loss: 3.167788028717041
Song 17, Epoch 90/100, Loss: 3.10965633392334
Song 17, Epoch 100/100, Loss: 3.0891387462615967
Validation loss: 5.316176011012151
Song 18, Epoch 10/100, Loss: 3.4314444065093994
Song 18, Epoch 20/100, Loss: 3.2802908420562744
Song 18, Epoch 30/100, Loss: 3.217468500137329
Song 18, Epoch 40/100, Loss: 3.152625560760498
Song 18, Epoch 50/100, Loss: 3.126762628555298
Song 18, Epoch 60/100, Loss: 3.1050939559936523
Song 18, Epoch 70/100, Loss: 3.125760316848755
Song 18, Epoch 80/100, Loss: 3.092864513397217
Song 18, Epoch 90/100, Loss: 3.106015682220459
Song 18, Epoch 100/100, Loss: 3.0159785747528076
Validation loss: 5.313482443491618
Song 19, Epoch 10/100, Loss: 3.05900502204895
Song 19, Epoch 20/100, Loss: 2.9462759494781494
Song 19, Epoch 30/100, Loss: 

Song 33, Epoch 50/100, Loss: 2.5827622413635254
Song 33, Epoch 60/100, Loss: 2.5509445667266846
Song 33, Epoch 70/100, Loss: 2.5050785541534424
Song 33, Epoch 80/100, Loss: 2.4920217990875244
Song 33, Epoch 90/100, Loss: 2.476952314376831
Song 33, Epoch 100/100, Loss: 2.413698673248291
Validation loss: 5.204314134059808
Song 34, Epoch 10/100, Loss: 3.333207845687866
Song 34, Epoch 20/100, Loss: 3.1057770252227783
Song 34, Epoch 30/100, Loss: 2.9956772327423096
Song 34, Epoch 40/100, Loss: 2.943277359008789
Song 34, Epoch 50/100, Loss: 2.8764724731445312
Song 34, Epoch 60/100, Loss: 2.7910408973693848
Song 34, Epoch 70/100, Loss: 2.8050358295440674
Song 34, Epoch 80/100, Loss: 2.7748606204986572
Song 34, Epoch 90/100, Loss: 2.67505145072937
Song 34, Epoch 100/100, Loss: 2.6602509021759033
Validation loss: 5.113712151845296
Song 35, Epoch 10/100, Loss: 3.1089794635772705
Song 35, Epoch 20/100, Loss: 2.934392213821411
Song 35, Epoch 30/100, Loss: 2.8424549102783203
Song 35, Epoch 40/100, 

Song 49, Epoch 60/100, Loss: 2.564070463180542
Song 49, Epoch 70/100, Loss: 2.477813959121704
Song 49, Epoch 80/100, Loss: 2.415719747543335
Song 49, Epoch 90/100, Loss: 2.774338960647583
Song 49, Epoch 100/100, Loss: 2.5609681606292725
Validation loss: 5.451343780908829
Song 50, Epoch 10/100, Loss: 2.905886650085449
Song 50, Epoch 20/100, Loss: 2.746979236602783
Song 50, Epoch 30/100, Loss: 2.585930585861206
Song 50, Epoch 40/100, Loss: 2.5513999462127686
Song 50, Epoch 50/100, Loss: 2.570844888687134
Song 50, Epoch 60/100, Loss: 2.454329013824463
Song 50, Epoch 70/100, Loss: 2.415555953979492
Song 50, Epoch 80/100, Loss: 2.4026219844818115
Song 50, Epoch 90/100, Loss: 2.3212010860443115
Song 50, Epoch 100/100, Loss: 2.3447906970977783
Validation loss: 5.692408732878856
Song 51, Epoch 10/100, Loss: 2.9502484798431396
Song 51, Epoch 20/100, Loss: 2.7529594898223877
Song 51, Epoch 30/100, Loss: 2.641278028488159
Song 51, Epoch 40/100, Loss: 2.5414366722106934
Song 51, Epoch 50/100, Loss

Song 65, Epoch 70/100, Loss: 2.1273093223571777
Song 65, Epoch 80/100, Loss: 2.0783438682556152
Song 65, Epoch 90/100, Loss: 2.060067653656006
Song 65, Epoch 100/100, Loss: 2.0754189491271973
Validation loss: 5.3472137451171875
Song 66, Epoch 10/100, Loss: 3.1140754222869873
Song 66, Epoch 20/100, Loss: 2.948441505432129
Song 66, Epoch 30/100, Loss: 2.804424285888672
Song 66, Epoch 40/100, Loss: 2.812666416168213
Song 66, Epoch 50/100, Loss: 2.7902321815490723
Song 66, Epoch 60/100, Loss: 2.6508398056030273
Song 66, Epoch 70/100, Loss: 2.679304838180542
Song 66, Epoch 80/100, Loss: 2.5959298610687256
Song 66, Epoch 90/100, Loss: 2.677704334259033
Song 66, Epoch 100/100, Loss: 2.6936075687408447
Validation loss: 5.190690346253224
Song 67, Epoch 10/100, Loss: 2.7453958988189697
Song 67, Epoch 20/100, Loss: 2.6109609603881836
Song 67, Epoch 30/100, Loss: 2.4645724296569824
Song 67, Epoch 40/100, Loss: 2.39530873298645
Song 67, Epoch 50/100, Loss: 2.3206684589385986
Song 67, Epoch 60/100, 

Song 81, Epoch 80/100, Loss: 2.783277988433838
Song 81, Epoch 90/100, Loss: 2.650186538696289
Song 81, Epoch 100/100, Loss: 2.6408042907714844
Validation loss: 6.278602795723157
Song 82, Epoch 10/100, Loss: 2.865678548812866
Song 82, Epoch 20/100, Loss: 2.700515031814575
Song 82, Epoch 30/100, Loss: 2.5390501022338867
Song 82, Epoch 40/100, Loss: 2.529897689819336
Song 82, Epoch 50/100, Loss: 2.4613358974456787
Song 82, Epoch 60/100, Loss: 2.3003201484680176
Song 82, Epoch 70/100, Loss: 2.2904419898986816
Song 82, Epoch 80/100, Loss: 2.3382039070129395
Song 82, Epoch 90/100, Loss: 2.3599467277526855
Song 82, Epoch 100/100, Loss: 2.2247185707092285
Validation loss: 5.4277281027573805
Song 83, Epoch 10/100, Loss: 3.4736812114715576
Song 83, Epoch 20/100, Loss: 3.25213623046875
Song 83, Epoch 30/100, Loss: 3.0975759029388428
Song 83, Epoch 40/100, Loss: 3.0515809059143066
Song 83, Epoch 50/100, Loss: 2.9519479274749756
Song 83, Epoch 60/100, Loss: 2.94646954536438
Song 83, Epoch 70/100, L

Song 97, Epoch 90/100, Loss: 2.2260072231292725
Song 97, Epoch 100/100, Loss: 2.077014207839966
Validation loss: 5.202111941117507
Song 98, Epoch 10/100, Loss: 2.9174857139587402
Song 98, Epoch 20/100, Loss: 2.7222065925598145
Song 98, Epoch 30/100, Loss: 2.6156959533691406
Song 98, Epoch 40/100, Loss: 2.4559733867645264
Song 98, Epoch 50/100, Loss: 2.404783010482788
Song 98, Epoch 60/100, Loss: 2.2859296798706055
Song 98, Epoch 70/100, Loss: 2.2675702571868896
Song 98, Epoch 80/100, Loss: 2.166391611099243
Song 98, Epoch 90/100, Loss: 2.171147584915161
Song 98, Epoch 100/100, Loss: 2.119218587875366
Validation loss: 5.184339462182461
Song 99, Epoch 10/100, Loss: 2.6039583683013916
Song 99, Epoch 20/100, Loss: 2.3762056827545166
Song 99, Epoch 30/100, Loss: 2.3320000171661377
Song 99, Epoch 40/100, Loss: 2.2072043418884277
Song 99, Epoch 50/100, Loss: 2.1669564247131348
Song 99, Epoch 60/100, Loss: 2.141710042953491
Song 99, Epoch 70/100, Loss: 2.0739381313323975
Song 99, Epoch 80/100,

Song 113, Epoch 70/100, Loss: 2.281986713409424
Song 113, Epoch 80/100, Loss: 2.2576723098754883
Song 113, Epoch 90/100, Loss: 2.2269372940063477
Song 113, Epoch 100/100, Loss: 2.276031017303467
Validation loss: 5.555446453583547
Song 114, Epoch 10/100, Loss: 2.7253403663635254
Song 114, Epoch 20/100, Loss: 2.4726216793060303
Song 114, Epoch 30/100, Loss: 2.4512007236480713
Song 114, Epoch 40/100, Loss: 2.265734910964966
Song 114, Epoch 50/100, Loss: 2.230027437210083
Song 114, Epoch 60/100, Loss: 2.1512367725372314
Song 114, Epoch 70/100, Loss: 2.113414764404297
Song 114, Epoch 80/100, Loss: 2.1385786533355713
Song 114, Epoch 90/100, Loss: 2.1758341789245605
Song 114, Epoch 100/100, Loss: 2.0423777103424072
Validation loss: 5.425928702721229
Song 115, Epoch 10/100, Loss: 3.167015790939331
Song 115, Epoch 20/100, Loss: 2.9049220085144043
Song 115, Epoch 30/100, Loss: 2.801238775253296
Song 115, Epoch 40/100, Loss: 2.640796661376953
Song 115, Epoch 50/100, Loss: 2.6218273639678955
Song 

Song 129, Epoch 40/100, Loss: 2.1888844966888428
Song 129, Epoch 50/100, Loss: 2.119743824005127
Song 129, Epoch 60/100, Loss: 2.08927321434021
Song 129, Epoch 70/100, Loss: 2.0436596870422363
Song 129, Epoch 80/100, Loss: 2.0119948387145996
Song 129, Epoch 90/100, Loss: 2.0042057037353516
Song 129, Epoch 100/100, Loss: 1.9751107692718506
Validation loss: 5.584557997874724
Song 130, Epoch 10/100, Loss: 2.6023776531219482
Song 130, Epoch 20/100, Loss: 2.40226149559021
Song 130, Epoch 30/100, Loss: 2.2361466884613037
Song 130, Epoch 40/100, Loss: 2.169405698776245
Song 130, Epoch 50/100, Loss: 2.0933260917663574
Song 130, Epoch 60/100, Loss: 2.086728096008301
Song 130, Epoch 70/100, Loss: 2.0598251819610596
Song 130, Epoch 80/100, Loss: 2.016554117202759
Song 130, Epoch 90/100, Loss: 1.985365390777588
Song 130, Epoch 100/100, Loss: 1.991497278213501
Validation loss: 5.587650604737111
Song 131, Epoch 10/100, Loss: 2.7070236206054688
Song 131, Epoch 20/100, Loss: 2.5185306072235107
Song 13

Song 145, Epoch 20/100, Loss: 2.3763246536254883
Song 145, Epoch 30/100, Loss: 2.23172664642334
Song 145, Epoch 40/100, Loss: 2.1428284645080566
Song 145, Epoch 50/100, Loss: 2.073235273361206
Song 145, Epoch 60/100, Loss: 2.0821597576141357
Song 145, Epoch 70/100, Loss: 2.053102970123291
Song 145, Epoch 80/100, Loss: 2.0065369606018066
Song 145, Epoch 90/100, Loss: 2.0133254528045654
Song 145, Epoch 100/100, Loss: 2.048579216003418
Validation loss: 5.4538451463748245
Song 146, Epoch 10/100, Loss: 2.7531323432922363
Song 146, Epoch 20/100, Loss: 2.5405914783477783
Song 146, Epoch 30/100, Loss: 2.424776554107666
Song 146, Epoch 40/100, Loss: 2.396646499633789
Song 146, Epoch 50/100, Loss: 2.3872439861297607
Song 146, Epoch 60/100, Loss: 2.350383758544922
Song 146, Epoch 70/100, Loss: 2.375727891921997
Song 146, Epoch 80/100, Loss: 2.512242555618286
Song 146, Epoch 90/100, Loss: 2.707534074783325
Song 146, Epoch 100/100, Loss: 2.744354724884033
Validation loss: 5.298325734260755
Song 147

Song 160, Epoch 100/100, Loss: 1.9400337934494019
Validation loss: 5.391647791251158
Song 161, Epoch 10/100, Loss: 2.542086124420166
Song 161, Epoch 20/100, Loss: 2.3902854919433594
Song 161, Epoch 30/100, Loss: 2.2297585010528564
Song 161, Epoch 40/100, Loss: 2.141348361968994
Song 161, Epoch 50/100, Loss: 2.1003146171569824
Song 161, Epoch 60/100, Loss: 2.0909922122955322
Song 161, Epoch 70/100, Loss: 2.0662879943847656
Song 161, Epoch 80/100, Loss: 1.97566819190979
Song 161, Epoch 90/100, Loss: 1.9639302492141724
Song 161, Epoch 100/100, Loss: 1.9563082456588745
Validation loss: 5.7170843711266155
Song 162, Epoch 10/100, Loss: 2.9033761024475098
Song 162, Epoch 20/100, Loss: 2.6658692359924316
Song 162, Epoch 30/100, Loss: 2.6504812240600586
Song 162, Epoch 40/100, Loss: 2.383074998855591
Song 162, Epoch 50/100, Loss: 2.2969932556152344
Song 162, Epoch 60/100, Loss: 2.3392090797424316
Song 162, Epoch 70/100, Loss: 2.180025100708008
Song 162, Epoch 80/100, Loss: 2.2571539878845215
So

Song 176, Epoch 80/100, Loss: 2.0888662338256836
Song 176, Epoch 90/100, Loss: 2.1289010047912598
Song 176, Epoch 100/100, Loss: 2.039914131164551
Validation loss: 6.193681778051914
Song 177, Epoch 10/100, Loss: 2.7983832359313965
Song 177, Epoch 20/100, Loss: 2.608968734741211
Song 177, Epoch 30/100, Loss: 2.359121799468994
Song 177, Epoch 40/100, Loss: 2.2636873722076416
Song 177, Epoch 50/100, Loss: 2.2191624641418457
Song 177, Epoch 60/100, Loss: 2.146604061126709
Song 177, Epoch 70/100, Loss: 2.0900962352752686
Song 177, Epoch 80/100, Loss: 2.10478138923645
Song 177, Epoch 90/100, Loss: 2.2173988819122314
Song 177, Epoch 100/100, Loss: 2.136683225631714
Validation loss: 5.731310428717197
Song 178, Epoch 10/100, Loss: 2.7857909202575684
Song 178, Epoch 20/100, Loss: 2.4569222927093506
Song 178, Epoch 30/100, Loss: 2.4109537601470947
Song 178, Epoch 40/100, Loss: 2.2389206886291504
Song 178, Epoch 50/100, Loss: 2.218069314956665
Song 178, Epoch 60/100, Loss: 2.173999786376953
Song 1

Song 192, Epoch 60/100, Loss: 2.3523566722869873
Song 192, Epoch 70/100, Loss: 2.467066526412964
Song 192, Epoch 80/100, Loss: 2.2674787044525146
Song 192, Epoch 90/100, Loss: 2.555037498474121
Song 192, Epoch 100/100, Loss: 2.7735559940338135
Validation loss: 5.4408362951034155
Song 193, Epoch 10/100, Loss: 2.653865337371826
Song 193, Epoch 20/100, Loss: 2.4617395401000977
Song 193, Epoch 30/100, Loss: 2.4334638118743896
Song 193, Epoch 40/100, Loss: 2.3100688457489014
Song 193, Epoch 50/100, Loss: 2.212916612625122
Song 193, Epoch 60/100, Loss: 2.4801957607269287
Song 193, Epoch 70/100, Loss: 2.171039581298828
Song 193, Epoch 80/100, Loss: 2.197010040283203
Song 193, Epoch 90/100, Loss: 2.1300268173217773
Song 193, Epoch 100/100, Loss: 2.1116268634796143
Validation loss: 5.2234206444177875
Song 194, Epoch 10/100, Loss: 2.6457085609436035
Song 194, Epoch 20/100, Loss: 2.4680166244506836
Song 194, Epoch 30/100, Loss: 2.3112282752990723
Song 194, Epoch 40/100, Loss: 2.18766188621521
Son

Song 208, Epoch 30/100, Loss: 2.6462137699127197
Song 208, Epoch 40/100, Loss: 2.5025784969329834
Song 208, Epoch 50/100, Loss: 2.49867844581604
Song 208, Epoch 60/100, Loss: 2.394202947616577
Song 208, Epoch 70/100, Loss: 2.3592095375061035
Song 208, Epoch 80/100, Loss: 2.4920902252197266
Song 208, Epoch 90/100, Loss: 2.4346730709075928
Song 208, Epoch 100/100, Loss: 2.3932206630706787
Validation loss: 5.008588118430896
Song 209, Epoch 10/100, Loss: 2.913646936416626
Song 209, Epoch 20/100, Loss: 2.564267635345459
Song 209, Epoch 30/100, Loss: 2.4719350337982178
Song 209, Epoch 40/100, Loss: 2.3677451610565186
Song 209, Epoch 50/100, Loss: 2.375013589859009
Song 209, Epoch 60/100, Loss: 2.2880513668060303
Song 209, Epoch 70/100, Loss: 2.3416664600372314
Song 209, Epoch 80/100, Loss: 2.348865270614624
Song 209, Epoch 90/100, Loss: 2.3352348804473877
Song 209, Epoch 100/100, Loss: 2.32293438911438
Validation loss: 5.757574228140024
Song 210, Epoch 10/100, Loss: 2.6911661624908447
Song 2

Validation loss: 5.587624244200877
Song 224, Epoch 10/100, Loss: 3.184915065765381
Song 224, Epoch 20/100, Loss: 2.9878897666931152
Song 224, Epoch 30/100, Loss: 2.7788467407226562
Song 224, Epoch 40/100, Loss: 2.6576452255249023
Song 224, Epoch 50/100, Loss: 2.567413568496704
Song 224, Epoch 60/100, Loss: 2.465985059738159
Song 224, Epoch 70/100, Loss: 2.4660146236419678
Song 224, Epoch 80/100, Loss: 2.497713327407837
Song 224, Epoch 90/100, Loss: 2.3795852661132812
Song 224, Epoch 100/100, Loss: 2.412055730819702
Validation loss: 5.341046614524646
Song 225, Epoch 10/100, Loss: 2.839893341064453
Song 225, Epoch 20/100, Loss: 2.5959208011627197
Song 225, Epoch 30/100, Loss: 2.4522604942321777
Song 225, Epoch 40/100, Loss: 2.3382577896118164
Song 225, Epoch 50/100, Loss: 2.2320010662078857
Song 225, Epoch 60/100, Loss: 2.2378129959106445
Song 225, Epoch 70/100, Loss: 2.1780097484588623
Song 225, Epoch 80/100, Loss: 2.153515100479126
Song 225, Epoch 90/100, Loss: 2.144472122192383
Song 2

Song 239, Epoch 80/100, Loss: 1.9105100631713867
Song 239, Epoch 90/100, Loss: 1.9011367559432983
Song 239, Epoch 100/100, Loss: 1.873628854751587
Validation loss: 5.714274235260793
Song 240, Epoch 10/100, Loss: 2.9128429889678955
Song 240, Epoch 20/100, Loss: 2.640010356903076
Song 240, Epoch 30/100, Loss: 2.586324691772461
Song 240, Epoch 40/100, Loss: 2.495699882507324
Song 240, Epoch 50/100, Loss: 2.4285316467285156
Song 240, Epoch 60/100, Loss: 2.4126713275909424
Song 240, Epoch 70/100, Loss: 2.3372933864593506
Song 240, Epoch 80/100, Loss: 2.2759053707122803
Song 240, Epoch 90/100, Loss: 2.4483449459075928
Song 240, Epoch 100/100, Loss: 2.5394673347473145
Validation loss: 5.6626388720977
Song 241, Epoch 10/100, Loss: 2.628692150115967
Song 241, Epoch 20/100, Loss: 2.4824812412261963
Song 241, Epoch 30/100, Loss: 2.3601675033569336
Song 241, Epoch 40/100, Loss: 2.2238316535949707
Song 241, Epoch 50/100, Loss: 2.3166418075561523
Song 241, Epoch 60/100, Loss: 2.1246907711029053
Song

Song 255, Epoch 50/100, Loss: 1.8783860206604004
Song 255, Epoch 60/100, Loss: 1.8673394918441772
Song 255, Epoch 70/100, Loss: 1.8712520599365234
Song 255, Epoch 80/100, Loss: 1.8624600172042847
Song 255, Epoch 90/100, Loss: 1.858152985572815
Song 255, Epoch 100/100, Loss: 1.8570806980133057
Validation loss: 5.890730087573711
Song 256, Epoch 10/100, Loss: 3.1680221557617188
Song 256, Epoch 20/100, Loss: 2.758331298828125
Song 256, Epoch 30/100, Loss: 2.588280439376831
Song 256, Epoch 40/100, Loss: 2.482109546661377
Song 256, Epoch 50/100, Loss: 2.289813995361328
Song 256, Epoch 60/100, Loss: 2.235259771347046
Song 256, Epoch 70/100, Loss: 2.2005209922790527
Song 256, Epoch 80/100, Loss: 2.1649351119995117
Song 256, Epoch 90/100, Loss: 2.1305315494537354
Song 256, Epoch 100/100, Loss: 2.2444565296173096
Validation loss: 5.998667985965044
Song 257, Epoch 10/100, Loss: 2.580923318862915
Song 257, Epoch 20/100, Loss: 2.466132879257202
Song 257, Epoch 30/100, Loss: 2.3015735149383545
Song 

Song 271, Epoch 20/100, Loss: 2.175088882446289
Song 271, Epoch 30/100, Loss: 2.066060781478882
Song 271, Epoch 40/100, Loss: 2.0383925437927246
Song 271, Epoch 50/100, Loss: 2.0303826332092285
Song 271, Epoch 60/100, Loss: 2.075995445251465
Song 271, Epoch 70/100, Loss: 2.0377919673919678
Song 271, Epoch 80/100, Loss: 2.045344352722168
Song 271, Epoch 90/100, Loss: 1.9167253971099854
Song 271, Epoch 100/100, Loss: 1.8996590375900269
Validation loss: 5.662303912333953
Song 272, Epoch 10/100, Loss: 2.509274482727051
Song 272, Epoch 20/100, Loss: 2.186380386352539
Song 272, Epoch 30/100, Loss: 2.1256942749023438
Song 272, Epoch 40/100, Loss: 2.0039162635803223
Song 272, Epoch 50/100, Loss: 1.9449819326400757
Song 272, Epoch 60/100, Loss: 1.9256575107574463
Song 272, Epoch 70/100, Loss: 1.9106988906860352
Song 272, Epoch 80/100, Loss: 1.9008421897888184
Song 272, Epoch 90/100, Loss: 1.8875426054000854
Song 272, Epoch 100/100, Loss: 1.884566307067871
Validation loss: 6.411707743620261
Song

Song 286, Epoch 100/100, Loss: 2.092130422592163
Validation loss: 5.451711801382212
Song 287, Epoch 10/100, Loss: 2.769756317138672
Song 287, Epoch 20/100, Loss: 2.5983572006225586
Song 287, Epoch 30/100, Loss: 2.4788341522216797
Song 287, Epoch 40/100, Loss: 2.36449933052063
Song 287, Epoch 50/100, Loss: 2.325881004333496
Song 287, Epoch 60/100, Loss: 2.307889461517334
Song 287, Epoch 70/100, Loss: 2.2760701179504395
Song 287, Epoch 80/100, Loss: 2.259772539138794
Song 287, Epoch 90/100, Loss: 2.523972511291504
Song 287, Epoch 100/100, Loss: 2.671234607696533
Validation loss: 5.2324245403974485
Song 288, Epoch 10/100, Loss: 2.4846863746643066
Song 288, Epoch 20/100, Loss: 2.3378403186798096
Song 288, Epoch 30/100, Loss: 2.216465950012207
Song 288, Epoch 40/100, Loss: 2.067338466644287
Song 288, Epoch 50/100, Loss: 2.036956548690796
Song 288, Epoch 60/100, Loss: 2.003566026687622
Song 288, Epoch 70/100, Loss: 1.9592313766479492
Song 288, Epoch 80/100, Loss: 1.9414007663726807
Song 288,

Song 302, Epoch 70/100, Loss: 2.408453941345215
Song 302, Epoch 80/100, Loss: 2.257791757583618
Song 302, Epoch 90/100, Loss: 2.2383949756622314
Song 302, Epoch 100/100, Loss: 2.556452751159668
Validation loss: 5.389361295944605
Song 303, Epoch 10/100, Loss: 3.0978944301605225
Song 303, Epoch 20/100, Loss: 2.8127458095550537
Song 303, Epoch 30/100, Loss: 2.655290126800537
Song 303, Epoch 40/100, Loss: 2.49747633934021
Song 303, Epoch 50/100, Loss: 2.457646131515503
Song 303, Epoch 60/100, Loss: 2.303738832473755
Song 303, Epoch 70/100, Loss: 2.3819849491119385
Song 303, Epoch 80/100, Loss: 2.19207501411438
Song 303, Epoch 90/100, Loss: 2.1922624111175537
Song 303, Epoch 100/100, Loss: 2.1628386974334717
Validation loss: 5.182658672332764
Song 304, Epoch 10/100, Loss: 2.7588255405426025
Song 304, Epoch 20/100, Loss: 2.4888157844543457
Song 304, Epoch 30/100, Loss: 2.391141176223755
Song 304, Epoch 40/100, Loss: 2.3749725818634033
Song 304, Epoch 50/100, Loss: 2.213416337966919
Song 304,

In [430]:
import numpy as np
import torch
from music21 import stream, note

def midi_to_note(part):
    result = stream.Part()
    count = 1
    prev = round(part[0])
    for i in range(1, len(part)):
        curr = round(part[i])
        if curr == prev:
            count += 1
        else:
            result.append(note.Note(prev, quarterLength=count / 4))
            count = 1
        prev = curr
    result.append(note.Note(prev, quarterLength=count / 4))
    return result

def process_sequence(sequence, delimiter_token="|||"):
    index_to_token = {v: k for k, v in token_dictionary.items()}
    original_sequence = [index_to_token[embedded_value] for embedded_value in sequence]
    original_sequence = original_sequence[1:-1]
    melody, alto, tenor, bass = [], [], [], []
    for i in range(0, len(original_sequence), 5):
        melody.append(original_sequence[i][0])
        alto.append(original_sequence[i+1][0])
        tenor.append(original_sequence[i+2][0])
        bass.append(original_sequence[i+3][0])
    
    return melody, alto, tenor, bass

def output_to_sheet_music(result):
    result = torch.argmax(result, dim=-1)
    result = result.squeeze(0)
    melody_notes, alto_notes, tenor_notes, bass_notes = process_sequence(result.numpy())
    print(alto_notes)

    melody_part = midi_to_note(melody_notes)
    alto_part = midi_to_note(alto_notes)
    tenor_part = midi_to_note(tenor_notes)
    bass_part = midi_to_note(bass_notes)

    score = stream.Score()
    score.append(melody_part)
    score.append(alto_part)
    score.append(tenor_part)
    score.append(bass_part)

    score.show('midi')
    score.write('musicxml', 'output.xml')

In [438]:
test_song = test_set[0]
embedded_test = [token_dictionary[token] for token in test_song]
test_mask = compute_mask_testing(test_song)
input_embedded_test = harmonies_to_zero(embedded_test)

output, _ = model(input_embedded_test, test_mask)
output_to_sheet_music(output)

[64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 45, 45, 45, 45, 69, 69, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 69, 69, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 72, 72, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 67, 67, 67, 67, 67, 67, 64, 64, 65, 65, 65, 65, 65, 65, 60, 60, 60, 59, 59, 59, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 60, 60, 60, 60, 60, 60, 62, 62, 62, 62, 62, 62, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 59, 59, 59, 59, 59, 59, 59, 59, 64, 64, 64, 64, 62, 67, 67, 67, 67, 67, 67, 67, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]


In [None]:
import torch
import torch.nn as nn

# Define the model output size as 127 (for each note class)
# Assuming model output shape: (batch_size, seq_len, 3, 127)
output = torch.randn(5, 10, 3, 127)  # Random output for demonstration
print(output)

# Assume target is structured as note indices for alto, tenor, and bass
# Shape of target: (batch_size, seq_len, 3) - with each entry as an index in [0, 126]
target = torch.randint(0, 127, (5, 10, 3))  # Random target for demonstration
print(target)

# CrossEntropyLoss expects shape (N, C) for input and (N) for target, so reshape
criterion = nn.CrossEntropyLoss()

# Reshape output to (batch_size * seq_len * 3, 127) and target to (batch_size * seq_len * 3)
loss = criterion(output.view(-1, 127), target.view(-1))
print(loss)