In [9]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
source_folder = "solo_generation_dataset_augmented_mag"
folder = "dynamic_mag_models/2enc_2nd"
destination_folder = folder + "/solo_generation_weights"
generated_outputs = folder +  "/generated_samples"
dissimilar_interpolation = folder + "/interpolation"
vocab = folder + "/vocab"

In [10]:
from pathlib import Path
Path(destination_folder).mkdir(parents=True, exist_ok=True)
Path(generated_outputs).mkdir(parents=True, exist_ok=True)
Path(dissimilar_interpolation).mkdir(parents=True, exist_ok=True)
Path(vocab).mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/main").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/piano").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/piano_predict").mkdir(parents=True, exist_ok=True)
Path(dissimilar_interpolation+"/intro").mkdir(parents=True, exist_ok=True)
Path(dissimilar_interpolation+"/outro").mkdir(parents=True, exist_ok=True)
Path(dissimilar_interpolation+"/predict").mkdir(parents=True, exist_ok=True)

In [11]:
event2word, word2event = pickle.load(open('dictionary_augmented.pkl', 'rb'))

In [12]:
if torch.cuda.is_available():  
    dev = "cuda:1" 
else:  
    dev = "cpu" 
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [13]:
# Fields

main_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
fields = [('main', main_field), ('piano', piano_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv', validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators
BATCH_SIZE = 1
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.main),
                            device=device, sort=False, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.main),
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.main),
                            device=device, sort=False, sort_within_batch=True)

# Vocabulary

main_field.build_vocab(train, min_freq=1)
piano_field.build_vocab(train, min_freq=1)

In [6]:
big = []
for ((main, main_len), (piano, piano_len)), _ in (train_iter):
    #print(intro.transpose(1,0).size(0))
    print(piano_len.cpu().item())

1823
2171
964
1720
1643
906
1580
709
1511
1662
1920
1187
1004
631
1520
1442
967
606
1095
1828
810
819
1180
1284
1187
1999
1420
1644
1508
1026
2006
1199
1252
1319
1447
986
2587
828
1960
1798
1434
1020
1026
1321
1389
3668
1846
1284
439
1690
2627
1968
1236
1723
1580
1351
355
929
939
1502
1948
1304
1763
638
439
3023
541
1375
2504
1656
1509
1303
1350
1524
1014
1546
940
1319
2120
739
1429
1438
638
1447
1113
1250
1309
1070
1558
541
577
784
1876
1080
858
1750
819
2006
739
2639
1684
1111
914
2619
1462
1216
2146
2094
672
986
1509
1776
1071
868
556
3528
1004
1928
979
1856
1855
1598
1502
1166
1264
1520
1993
969
1014
1211
858
1720
1442
1789
1743
1960
1111
1082
1322
1439
1558
2627
1059
638
3163
1422
891
1631
2023
1374
509
726
854
1020
1194
868
2293
931
1765
1935
1163
800
546
1350
2032
1026
1715
1250
2619
1366
1534
1171
1024
1337
467
1443
1379
1014
1303
1765
541
898
1026
1765
1802
1731
930
657
1935
1343
895
1963
2067
509
825
637
439
1978
1180
1654
1690
1216
546
2278
831
1547
1647
1116
2006
946
1569
7

1026
1654
828
825
1442
1176
2892
1580
843
979
1098
1420
1776
1371
1086
1559
1180
828
1724
589
1690
636
1463
1098
1890
530
1799
1509
970
970
1798
1643
1799
1022
1746
1429
967
838
2000
1558
1468
1956
705
1248
1508
939
2017
923
1022
931
2171
1731
878
550
2031
2792
1751
1383
883
1774
1690
1235
1335
1941
869
1314
839
1304
1105
931
1543
1321
1791
1313
1113
759
1122
633
1070
1968
633
949
1603
798
1015
1503
804
1422
1846
1211
1098
2429
804
837
3827
1569
1791
1355
1511
955
2619
1746
1319
1993
1113
1765
928
1520
1932
787
1830
2088
1720
837
955
1211
1508
2927
1086
2094
804
2151
2120
2379
2429
923
1082
2699
1999
1187
1329
606
1765
1080
729
714
1383
1439
2429
1434
2146
3044
1030
2398
1538
2067
1376
974
1502
589
1463
1539
3783
1690
822
1998
1770
1059
1647
757
836
2792
814
1835
986
1684
1577
396
530
1122
845
970
1383
822
1376
1776
1439
1236
1594
804
986
1174
834
1509
1720
530
1750
1422
1978
1231
1770
1166
1799
906
2151
637
869
1572
1141
1515
2067
1374
2112
1429
1178
2041
1647
886
1610
1967
1116
746
9

In [14]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [15]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [16]:
src_vocab_size = len(main_field.vocab)
trg_vocab_size = len(piano_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 3000
forward_expansion = 4
src_pad_idx = 1 #english.vocab.stoi["<pad>"]

model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
model = model.to(device)


In [17]:
def init_weights(m: nn.Module):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

optimizer = optim.Adam(model.parameters(), lr=1e-5) #non augmented 3e-4


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
#     torch.save(state, destination_folder + str(nth)+filename)
    torch.save(state, destination_folder + '/metrics.pt')

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
    torch.save(state, destination_folder + "/" + str(nth)+filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

The model has 12,969,244 trainable parameters


In [18]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [19]:
PAD_IDX = 1

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [20]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0

    #for _, (src, _,trg,_) in enumerate(iterator):
    for ((main, main_len), (piano, piano_len)), _ in (iterator):
        if piano_len.cpu().item()>=3000:
            continue
        src, trg = main.transpose(1,0), piano.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:-1, :])
        
#         print(output.size())
#         print(trg.size())
        
        output = output.view(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)
        loss = criterion(output, trg)
#         print(torch.isfinite(trg).all().cpu().item())
#         print(torch.isfinite(output).all().cpu().item())
#         print(torch.isfinite(loss).all().cpu().item())
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        #for _, (src, _,trg,_) in enumerate(iterator):
        for ((main, main_len), (piano, piano_len)), _ in (iterator):
            if piano_len.cpu().item()>=3000:
                continue
            src, trg = main.transpose(1,0), piano.transpose(1,0)
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg[:-1, :]) #turn off teacher forcing

            output = output.view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs



In [21]:
def translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break
    # print(outputs)
    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [14]:
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_main = df_intro['main'].values
val_piano = df_intro['piano'].values
val_data=[]
for i in range(len(val_main)):
    temp_dict = {}
    temp_dict['main'] = val_main[i]
    temp_dict['piano'] = val_piano[i]
    val_data.append(temp_dict)
print(len(val_piano))

112


In [15]:
def check_mode_collapse(model):
    count = 0
    translations = []
    for i in range(3):
        main = val_main[i]
        piano = val_piano[i]
        #print(intro)
        list_main = [int(x) for x in main.split(' ')]
        list_piano = [int(x) for x in piano.split(' ')]
        translated_sentence = translate_sentence(model, main, main_field, piano_field, device, max_length=1200)
        
        translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
        print(translated_sentence)
        translations.append(translated_sentence)
        if i > 0:
            if translations[i-1] == translations[i]:
                count += 1
    return count


In [19]:
N_EPOCHS = 2000
S_EPOCH = 0
CLIP = 1

train_loss_log = []
valid_loss_log = []
best_valid_loss = float('inf')
#torch.autograd.set_detect_anomaly(True)
#model = nn.DataParallel(model, device_ids=[0,1]).to(device)
for epoch in range(S_EPOCH, N_EPOCHS):
    
    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)
    
    
    train_loss_log.append(train_loss)
    valid_loss_log.append(valid_loss)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
        save_best_checkpoint(checkpoint,N_EPOCHS)
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)
    if (epoch+1) % 25 ==0:
        if check_mode_collapse(model) > 1:
            print("model is mode collapsing")
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 1982 | Time: 5m 40s
	Train Loss: 0.091 | Train PPL:   1.096
	 Val. Loss: 8.081 |  Val. PPL: 3231.108
=> Saving checkpoint
Epoch: 1983 | Time: 5m 41s
	Train Loss: 0.091 | Train PPL:   1.096
	 Val. Loss: 8.037 |  Val. PPL: 3092.142
=> Saving checkpoint
Epoch: 1984 | Time: 5m 41s
	Train Loss: 0.092 | Train PPL:   1.096
	 Val. Loss: 8.047 |  Val. PPL: 3123.171
Epoch: 1985 | Time: 5m 40s
	Train Loss: 0.091 | Train PPL:   1.096
	 Val. Loss: 8.065 |  Val. PPL: 3181.441
Epoch: 1986 | Time: 5m 41s
	Train Loss: 0.091 | Train PPL:   1.095
	 Val. Loss: 8.093 |  Val. PPL: 3270.876
Epoch: 1987 | Time: 5m 40s
	Train Loss: 0.091 | Train PPL:   1.095
	 Val. Loss: 8.084 |  Val. PPL: 3240.840
Epoch: 1988 | Time: 5m 41s
	Train Loss: 0.091 | Train PPL:   1.096
	 Val. Loss: 8.035 |  Val. PPL: 3087.370
=> Saving checkpoint
Epoch: 1989 | Time: 5m 41s
	Train Loss: 0.091 | Train PPL:   1.095
	 Val. Loss: 8.058 |  Val. PPL: 3159.673
Epoch: 1990 | Time: 5m 41s
	Train Loss: 0.090 | Train PPL:   1.095
	 Val.

[0, 1, 2, 162, 1, 54, 30, 7, 4, 54, 26, 7, 4, 64, 6, 31, 8, 53, 6, 31, 10, 35, 33, 31, 10, 64, 6, 7, 10, 64, 9, 7, 91, 63, 33, 31, 13, 53, 50, 31, 13, 32, 6, 7, 16, 54, 12, 7, 17, 53, 6, 7, 27, 54, 50, 7, 27, 51, 57, 31, 0, 1, 32, 30, 15, 1, 53, 50, 31, 23, 35, 6, 52, 8, 51, 57, 31, 8, 63, 30, 41, 10, 35, 33, 31, 10, 49, 50, 7, 13, 53, 6, 7, 13, 35, 33, 7, 16, 54, 6, 7, 78, 123, 48, 52, 17, 45, 6, 31, 90, 49, 33, 15, 27, 35, 6, 22, 27, 35, 57, 7, 0, 1, 53, 33, 15, 1, 53, 50, 7, 4, 71, 50, 31, 23, 54, 9, 15, 8, 51, 30, 31, 8, 63, 50, 15, 10, 49, 57, 7, 10, 54, 50, 7, 10, 54, 57, 7, 10, 53, 30, 15, 13, 35, 50, 15, 13, 63, 9, 7, 16, 53, 98, 7, 16, 53, 30, 7, 78, 71, 12, 7, 17, 65, 12, 7, 90, 51, 111, 15, 74, 35, 33, 15, 0, 1, 51, 21, 7, 1, 53, 33, 15, 1, 53, 6, 15, 4, 35, 57, 15, 23, 53, 33, 15, 23, 49, 6, 31, 8, 53, 57, 7, 10, 53, 33, 52, 10, 51, 6, 15, 91, 54, 21, 7, 13, 53, 50, 22, 16, 35, 6, 15, 16, 53, 21, 7, 78, 35, 111, 60, 17, 53, 33, 52, 90, 35, 6, 19, 90, 54, 33, 41, 27, 45, 6, 

In [18]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
output = open(folder + "train_loss_log.pkl", 'wb')
pickle.dump(train_loss_log, output)
output.close()

output = open(folder + "valid_loss_log.pkl", 'wb')
pickle.dump(valid_loss_log, output)
output.close()

In [19]:
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [None]:
state = torch.load(destination_folder + '/1000_checkpoint.pt', map_location=device)
load_checkpoint(state, model, optimizer)

In [None]:
test_loss = evaluate(model, test_iter, criterion)
print(math.exp(test_loss))

In [None]:
generated_outputs = folder +  "/generated_samples_1000epochs"
Path(generated_outputs+"/main").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/piano").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/piano_predict").mkdir(parents=True, exist_ok=True)

In [None]:

df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_main = df_intro['main'].values
test_piano = df_intro['piano'].values
test_data=[]
for i in range(len(val_main)):
    temp_dict = {}
    temp_dict['main'] = test_main[i]
    temp_dict['piano'] = test_piano[i]
    test_data.append(temp_dict)
print(len(test_piano))

In [None]:
for i in range(0,len(test_piano)):
#     if len(test_intro) > 1200:
#         continue
    main = test_main[i]
    piano = test_piano[i]
    list_main = [int(x) for x in main.split(' ')]
    list_piano = [int(x) for x in piano.split(' ')]
    
    translated_sentence = translate_sentence(model, main, main_field, piano_field, device, max_length=3000)
        #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    print(translated_sentence)
    utils.write_midi(list_main, word2event, generated_outputs + "/main/" + "/main" + str(i)  + ".mid")
    utils.write_midi(list_piano, word2event, generated_outputs  + "/piano/" + "/piano" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/piano_predict/" + "/piano_predict" + str(i)  + ".mid")
    print(i)
#     if i == 10:
#         break
        


In [None]:
import mido
for i in range(11):
    piano = mido.MidiFile(generated_outputs + "/piano/" + '/piano' + str(i) + '.mid')
    main = mido.MidiFile(generated_outputs + "/main/" +'/main' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/piano_predict/" +'/piano_predict' + str(i) + '.mid')

    piano.tracks[1].name = "piano"
    main.tracks[1].name = "main"
    predict.tracks[1].name = "piano_predict"
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = main.ticks_per_beat
    merged_mid.tracks = piano.tracks + main.tracks 
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')
    
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = main.ticks_per_beat
    merged_mid.tracks = predict.tracks + main.tracks 
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

In [26]:
# dissimilar_interpolation
for i in range(0,len(test_intro)):
#     if len(test_intro) > 1200:
#         continue
    intro = test_intro[i]
    #solo = test_solo[i]
    if i + 3 < (len(test_intro)):
        outro = test_outro[i+3]
    else:
        outro = test_outro[i]
    #print(intro)
    #print(outro)
    list_intro = [int(x) for x in intro.split(' ')]
    #list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    print(translated_sentence)
    utils.write_midi(list_intro, word2event, dissimilar_interpolation + "/intro/" + "/intro" + str(i)  + ".mid")
    #utils.write_midi(list_solo, word2event, generated_outputs  + "/solo/" + "/solo" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, dissimilar_interpolation + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, dissimilar_interpolation + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
#     if i == 10:
#         break
        


NameError: name 'test_intro' is not defined

In [None]:
import mido
for i in range(len(test_intro)):
    intro = mido.MidiFile(dissimilar_interpolation + "/intro/" + '/intro' + str(i) + '.mid')
    outro = mido.MidiFile(dissimilar_interpolation + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(dissimilar_interpolation + "/predict/" +'/predict' + str(i) + '.mid')
    total_intro_time = 0
    total_solo_time = 0
    total_predict_time = 0
    for msg in intro.tracks[1]:
        if msg.type == "note_on":
            total_intro_time += msg.time
    for msg in predict.tracks[1]:
        if msg.type == "note_on":
            total_predict_time += msg.time
            
    original_outro_time = 0 + outro.tracks[1][1].time
    
    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(dissimilar_interpolation + '/merged_predict' + str(i) + '.mid')

In [None]:
class BeamSearchNode(object):
    def __init__(self, prev_node, wid, logp, length):
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    
    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]
    
    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    n_best_list = []
    
     
    #trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

    #first token as input
    trg_tensor = torch.LongTensor(outputs).to(device)
    
    end_nodes = []

    #starting node
    node = BeamSearchNode(prev_node=None, wid=trg_tensor, logp=0, length=1)

    nodes = []

    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0

    while True:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break
        
        # Fetch the best node
        #print([n[2].wid for n in nodes])
        score, _, n = heappop(nodes)
        decoder_input = n.wid
        
        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= beam_width:
                break
            else:
                continue
   
        sequence = [n.wid.item()]
        a = n
        while a.prev_node is not None:
            a = a.prev_node
            sequence.append(a.wid.item())
        sequence = sequence[::-1] # reverse
        
        #print(sequence)
        
        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))
        
        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(output, beam_width)
        #print(topk_indexes)
        #print(topk_log_prob)
        # Then, register new top-k nodes
        for new_k in range(beam_width):
            decoded_t = topk_indexes[0][0][new_k].view(1) # (1)
            logp = topk_log_prob[0][0][new_k].item() # float log probability val

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width
        #print(n_dec_steps)
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(beam_width)]

    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)


    # return n_best_seq_list

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]

    # remove start token
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    output = open(path, 'wb')
    pickle.dump(vocab, output)
    output.close()

In [None]:
save_vocab(intro_field.vocab, vocab + '/intro_vocab.pkl')
save_vocab(solo_field.vocab, vocab + '/solo_vocab.pkl')
save_vocab(outro_field.vocab, vocab + '/outro_vocab.pkl')

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    #tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    #tokens.insert(0, german.init_token)
    #tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    #text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(sentence).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        #print( vars(example))
        src = vars(example)["intro"]
        trg = vars(example)["solo"]
        
        src = [int(x) for x in src]
        trg = [int(x) for x in trg]
        
        if len(trg) > 1200 or len(src) > 1200:
            continue
        
        prediction = bleu_translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append(trg)
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while
score = bleu(test[1:10], model, intro_field, solo_field, device)
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show() 

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns