In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import autoencodertransformer as kk
import pickle
# Dataset location and output layout; checkpoints and generated samples both live under `folder`.
source_folder = "solo_generation_dataset_fixed_augmented_autoencoder"
folder = "dynamic_fixed_augmented_models/autoencoder"
# Directory the save_*_checkpoint helpers write to.
destination_folder = folder + "/solo_generation_weights"
# Directory the generation cells write MIDI files to.
generated_outputs = folder +  "/generated_samples"
# Directory created in the next cell; no visible cell writes into it.
vocab = folder + "/vocab"

In [2]:
from pathlib import Path
# Create every output directory up front; directories that already exist are left as they are.
for _output_dir in (
    destination_folder,
    generated_outputs,
    vocab,
    generated_outputs+"/intro",
    generated_outputs+"/outro",
    generated_outputs+"/solo",
    generated_outputs+"/predict",
):
    Path(_output_dir).mkdir(parents=True, exist_ok=True)

In [3]:
# Load the two vocabulary mappings stored in the dictionary file. The context manager
# guarantees the file handle is closed once the pickle has been read.
# NOTE(review): pickle can execute arbitrary code while loading; only load a file you trust.
with open('dictionary_fixed_augmented.pkl', 'rb') as dictionary_file:
    event2word, word2event = pickle.load(dictionary_file)

In [4]:
# Prefer the second GPU (index 1) when it exists, fall back to the first GPU on single-GPU
# machines (where "cuda:1" would not be a valid device), and to the CPU when CUDA is absent.
if torch.cuda.is_available():
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    dev = "cpu"
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [5]:
# Field: only the "intro" column is read from the CSV files. Sequences are lower-cased,
# wrapped in <sos>/<eos>, and batches come back as (batch, seq_len) together with lengths.
intro_field = Field(
    tokenize=None,
    lower=True,
    include_lengths=True,
    batch_first=True,
    init_token="<sos>",
    eos_token="<eos>",
)
fields = [('intro', intro_field)]

# Datasets: one CSV per split, header row skipped.
train, valid, test = TabularDataset.splits(
    path=source_folder,
    train='train_torchtext.csv',
    validation='val_torchtext.csv',
    test='test_torchtext.csv',
    format='CSV',
    fields=fields,
    skip_header=True,
)

# Iterators: all three splits use the same bucketing configuration.
BATCH_SIZE = 8


def _intro_length(example):
    """Sort key: number of tokens in the example's intro sequence."""
    return len(example.intro)


def _bucket_iterator(dataset):
    """Build a BucketIterator over ``dataset`` with the notebook's shared settings."""
    return BucketIterator(dataset, batch_size=BATCH_SIZE, sort_key=_intro_length,
                          device=device, sort=False, sort_within_batch=True)


train_iter = _bucket_iterator(train)
valid_iter = _bucket_iterator(valid)
test_iter = _bucket_iterator(test)

# Vocabulary: built from the training split only, keeping every token that occurs at least once.
intro_field.build_vocab(train, min_freq=1)

In [6]:
# Display the index the vocabulary assigns to the token "0".
intro_field.vocab["0"]

6

In [7]:
# Print the (seq_len, batch) shape of every test batch; the iterator yields
# (batch, seq_len) tensors, hence the transpose.
for batch in test_iter:
    (intro, intro_len), _ = batch
    print(intro.transpose(0, 1).size())

torch.Size([148, 8])
torch.Size([110, 8])
torch.Size([62, 8])
torch.Size([235, 8])
torch.Size([102, 8])
torch.Size([114, 8])
torch.Size([78, 8])
torch.Size([102, 8])
torch.Size([94, 8])
torch.Size([134, 8])
torch.Size([190, 8])
torch.Size([165, 8])
torch.Size([86, 8])
torch.Size([122, 8])


In [8]:
import os
# Debugging switches: request synchronous CUDA kernel launches and disable cuDNN.
# NOTE(review): CUDA_LAUNCH_BLOCKING is presumably only honoured when set before CUDA is
# initialised, and earlier cells already touch the GPU — confirm it has any effect here.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [9]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

In [10]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    """Token-level sequence-to-sequence wrapper around ``kk.Transformer``.

    Learned word and position embeddings for source and target tokens are put
    in front of the ``kk.Transformer`` core, and a linear layer maps the core's
    output to target-vocabulary logits.  Token tensors are ``(seq_len, batch)``
    index tensors.

    ``forward`` takes a single source; ``interpolate``, ``interpolate_pad`` and
    ``interpolate_cat`` take two sources and delegate to the method of the same
    name on the core.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        """Build embeddings, the ``kk.Transformer`` core and the output projection.

        ``max_len`` is the size of both position-embedding tables, so sequences
        longer than ``max_len`` cannot be embedded.  ``src_pad_idx`` is the
        token index treated as padding by ``make_src_mask``.
        """
        super().__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = kk.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
            device=device
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a ``(N, src_len)`` boolean mask that is True at padding positions."""
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def _embed(self, tokens, word_embedding, position_embedding, batch_size):
        """Return dropout(word embedding + position embedding) for ``tokens``.

        ``batch_size`` is the width the position indices are expanded to; all
        callers pass the target's batch size.
        """
        seq_length, _ = tokens.shape
        # Build the position indices directly on the target device (no CPU -> device copy).
        positions = (
            torch.arange(seq_length, device=self.device)
            .unsqueeze(1)
            .expand(seq_length, batch_size)
        )
        return self.dropout(word_embedding(tokens) + position_embedding(positions))

    def _target_mask(self, trg_seq_length):
        """Return the core's square subsequent mask for a target of the given length."""
        return self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

    def forward(self, src, trg):
        """Return logits of shape ``(trg_len, N, trg_vocab_size)`` for one source."""
        trg_seq_length, batch_size = trg.shape

        embed_src = self._embed(
            src, self.src_word_embedding, self.src_position_embedding, batch_size
        )
        embed_trg = self._embed(
            trg, self.trg_word_embedding, self.trg_position_embedding, batch_size
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=self.make_src_mask(src),
            tgt_mask=self._target_mask(trg_seq_length),
        )
        return self.fc_out(out)

    def _interpolate_with(self, core_method, src, src2, trg):
        """Shared body of the ``interpolate*`` methods.

        Embeds both sources and the target (in that order, so dropout draws its
        random numbers in a fixed sequence), builds the masks, calls
        ``core_method`` and projects the result to vocabulary logits.
        """
        trg_seq_length, batch_size = trg.shape

        embed_src = self._embed(
            src, self.src_word_embedding, self.src_position_embedding, batch_size
        )
        embed_src2 = self._embed(
            src2, self.src_word_embedding, self.src_position_embedding, batch_size
        )
        embed_trg = self._embed(
            trg, self.trg_word_embedding, self.trg_position_embedding, batch_size
        )

        out = core_method(
            embed_src,
            embed_src2,
            embed_trg,
            src_key_padding_mask=self.make_src_mask(src),
            src2_key_padding_mask=self.make_src_mask(src2),
            tgt_mask=self._target_mask(trg_seq_length),
        )
        return self.fc_out(out)

    def interpolate(self, src, src2, trg):
        """Return logits from ``kk.Transformer.interpolate`` for two sources."""
        return self._interpolate_with(self.transformer.interpolate, src, src2, trg)

    def interpolate_pad(self, src, src2, trg):
        """Return logits from ``kk.Transformer.interpolate_pad`` for two sources."""
        return self._interpolate_with(self.transformer.interpolate_pad, src, src2, trg)

    def interpolate_cat(self, src, src2, trg):
        """Return logits from ``kk.Transformer.interpolate_cat`` for two sources."""
        return self._interpolate_with(self.transformer.interpolate_cat, src, src2, trg)


In [11]:
# Source and target share one vocabulary: the model reconstructs its own input.
src_vocab_size = len(intro_field.vocab)
trg_vocab_size = len(intro_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
# Size of the position-embedding tables; sequences longer than this cannot be embedded.
max_len = 1200
forward_expansion = 4
# Look the padding index up in the vocabulary instead of hard-coding it.
src_pad_idx = intro_field.vocab.stoi[intro_field.pad_token]

# Keyword arguments keep the eleven constructor parameters from being mixed up.
model = Transformer(
    embedding_size=embedding_size,
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_len=max_len,
    device=device,
)
model = model.to(device)

In [12]:
def init_weights(m: nn.Module):
    """Re-initialise, in place, every parameter reachable from ``m``.

    Parameters whose name contains ``'weight'`` are drawn from a normal
    distribution with mean 0 and standard deviation 0.01; all other parameters
    are set to zero.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)


# Module.apply calls init_weights on every submodule and on the model itself.
model.apply(init_weights)

# Adam over all model parameters with a fixed learning rate of 4e-5.
optimizer = optim.Adam(model.parameters(), lr=4e-5)


def count_parameters(model: nn.Module):
    """Return the number of scalar values held by the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the model size; the ',' format spec inserts thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    """Write ``state`` to ``<destination_folder>/metrics.pt``.

    The target path is fixed, so each call overwrites the previous file.
    ``nth`` and ``filename`` are accepted but not used.
    """
    print("=> Saving checkpoint")
    target_path = destination_folder + '/metrics.pt'
    torch.save(state, target_path)

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    """Write ``state`` to ``<destination_folder>/<nth><filename>``."""
    print("=> Saving checkpoint")
    target_path = destination_folder + "/" + str(nth)+filename
    torch.save(state, target_path)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint mapping.

    ``checkpoint`` must provide the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``; the model is restored first.
    """
    print("=> Loading checkpoint")
    restore_plan = (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    )
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])

The model has 11,118,357 trainable parameters


In [13]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [14]:
# Padding token index; the same value (1) is passed to the model as src_pad_idx.
PAD_IDX = 1

# Cross-entropy over vocabulary logits; targets equal to PAD_IDX do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [15]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):
    """Run one training epoch and return the mean per-batch loss.

    Every batch is used as both source and target: the model receives the full
    sequence as source and the sequence without its last step as decoder input,
    and is scored against the sequence without its first step.  Gradients are
    clipped to a total norm of ``clip`` before each optimizer step.

    ``iterator`` must yield ``((tokens, lengths), _)`` pairs with ``tokens``
    shaped ``(batch, seq_len)`` and must support ``len()``.
    """
    model.train()

    batch_losses = []
    for batch in iterator:
        (intro, intro_len), _ = batch
        # (batch, seq_len) -> (seq_len, batch), on the module-level `device`.
        tokens = intro.transpose(0, 1).to(device)
        decoder_input = tokens[:-1, :]
        targets = tokens[1:].reshape(-1)

        optimizer.zero_grad()
        logits = model(tokens, decoder_input)
        flat_logits = logits.view(-1, logits.shape[-1])

        loss = criterion(flat_logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):
    """Return the mean per-batch loss over ``iterator`` without updating the model.

    The model is switched to eval mode and gradients are disabled.  As in
    training, the decoder is fed the ground-truth sequence shifted by one step
    (teacher forcing stays on) and scored against the sequence without its
    first step.

    ``iterator`` must yield ``((tokens, lengths), _)`` pairs with ``tokens``
    shaped ``(batch, seq_len)`` and must support ``len()``.
    """
    model.eval()

    batch_losses = []
    with torch.no_grad():
        for batch in iterator:
            (intro, intro_len), _ = batch
            # (batch, seq_len) -> (seq_len, batch), on the module-level `device`.
            tokens = intro.transpose(0, 1).to(device)

            logits = model(tokens, tokens[:-1, :])
            flat_logits = logits.view(-1, logits.shape[-1])
            flat_targets = tokens[1:].reshape(-1)

            batch_losses.append(criterion(flat_logits, flat_targets).item())

    return sum(batch_losses) / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Returns ``(minutes, seconds)``; both values are truncated towards zero.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - 60 * whole_minutes)
    return whole_minutes, leftover_seconds



In [16]:
def translate_sentence(model, sentence, sentence2, sentence_vocab, device, max_length=1200):
    """Greedily decode a token sequence conditioned on two source sequences.

    Each source string is split on single spaces, lower-cased, wrapped in the
    field's ``init_token``/``eos_token`` and mapped to vocabulary indices.
    Decoding starts from ``<sos>`` and repeatedly appends the arg-max token of
    ``model.interpolate_pad`` until ``<eos>`` is produced or ``max_length``
    tokens have been generated.

    The model is put in eval mode for the duration of the call, so dropout
    cannot perturb the decode, and its previous train/eval mode is restored
    afterwards.

    Args:
        model: ``nn.Module`` exposing ``interpolate_pad(src, src2, trg)`` and
            returning logits indexed as ``(trg_len, batch, vocab)``.
        sentence: first source sequence, tokens separated by single spaces.
        sentence2: second source sequence, same format.
        sentence_vocab: field-like object providing ``init_token``,
            ``eos_token``, ``vocab.stoi`` and ``vocab.itos``.
        device: device the index tensors are created on.
        max_length: maximum number of decoding steps.

    Returns:
        The decoded tokens as strings, including the leading ``<sos>`` and,
        when it was generated, the trailing ``<eos>``.
    """
    stoi = sentence_vocab.vocab.stoi
    itos = sentence_vocab.vocab.itos

    def _encode(text):
        """Tokenise ``text``, add <sos>/<eos> and return a (len, 1) index tensor."""
        tokens = [token.lower() for token in text.split(' ')]
        tokens.insert(0, sentence_vocab.init_token)
        tokens.append(sentence_vocab.eos_token)
        return torch.LongTensor([stoi[token] for token in tokens]).unsqueeze(1).to(device)

    sentence_tensor = _encode(sentence)
    sentence2_tensor = _encode(sentence2)

    # Loop-invariant lookups, hoisted out of the decoding loop.
    eos_idx = stoi["<eos>"]
    outputs = [stoi["<sos>"]]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
                output = model.interpolate_pad(sentence_tensor, sentence2_tensor, trg_tensor)

                # Greedy choice for the most recent target position.
                best_guess = output.argmax(2)[-1, :].item()
                outputs.append(best_guess)

                if best_guess == eos_idx:
                    break
    finally:
        model.train(was_training)

    return [itos[idx] for idx in outputs]


In [17]:
def translate_sentence_padseq(model, sentence, sentence2, sentence_vocab, device, max_length=1200):
    """Greedily decode from two sources that are padded to a common length.

    Each source string is split on single spaces, lower-cased and prefixed with
    the field's ``init_token``.  The shorter sequence is then extended with
    ``pad_token`` until both have the same length, and only afterwards is
    ``eos_token`` appended to both (so padding sits before ``<eos>``).
    Decoding starts from ``<sos>`` and repeatedly appends the arg-max token of
    ``model.interpolate`` until ``<eos>`` is produced or ``max_length`` tokens
    have been generated.

    The model is put in eval mode for the duration of the call, so dropout
    cannot perturb the decode, and its previous train/eval mode is restored
    afterwards.

    Args:
        model: ``nn.Module`` exposing ``interpolate(src, src2, trg)`` and
            returning logits indexed as ``(trg_len, batch, vocab)``.
        sentence: first source sequence, tokens separated by single spaces.
        sentence2: second source sequence, same format.
        sentence_vocab: field-like object providing ``init_token``,
            ``eos_token``, ``pad_token``, ``vocab.stoi`` and ``vocab.itos``.
        device: device the index tensors are created on.
        max_length: maximum number of decoding steps.

    Returns:
        The decoded tokens as strings, including the leading ``<sos>`` and,
        when it was generated, the trailing ``<eos>``.
    """
    stoi = sentence_vocab.vocab.stoi
    itos = sentence_vocab.vocab.itos

    tokens = [sentence_vocab.init_token] + [token.lower() for token in sentence.split(' ')]
    tokens2 = [sentence_vocab.init_token] + [token.lower() for token in sentence2.split(' ')]

    # Pad the shorter sequence up to the longer one, then close both with <eos>.
    common_length = max(len(tokens), len(tokens2))
    tokens += [sentence_vocab.pad_token] * (common_length - len(tokens))
    tokens2 += [sentence_vocab.pad_token] * (common_length - len(tokens2))
    tokens.append(sentence_vocab.eos_token)
    tokens2.append(sentence_vocab.eos_token)

    sentence_tensor = torch.LongTensor([stoi[token] for token in tokens]).unsqueeze(1).to(device)
    sentence2_tensor = torch.LongTensor([stoi[token] for token in tokens2]).unsqueeze(1).to(device)

    # Loop-invariant lookups, hoisted out of the decoding loop.
    eos_idx = stoi["<eos>"]
    outputs = [stoi["<sos>"]]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
                output = model.interpolate(sentence_tensor, sentence2_tensor, trg_tensor)

                # Greedy choice for the most recent target position.
                best_guess = output.argmax(2)[-1, :].item()
                outputs.append(best_guess)

                if best_guess == eos_idx:
                    break
    finally:
        model.train(was_training)

    return [itos[idx] for idx in outputs]


In [18]:
# Load the validation split; only its "intro" column is used (solo/outro are not read here).
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro = df_intro['intro'].values
# One record per validation row, keyed by column name.
val_data = [{'intro': intro_tokens} for intro_tokens in val_intro]
print(len(val_intro))

336


In [19]:
def check_mode_collapse(model):
    """Count how often consecutive validation decodes come out identical.

    Decodes the first (up to) five validation intros with
    ``translate_sentence``, passing each intro as both source sequences and
    ``intro_field`` as the vocabulary field.  Every decoded index list is
    printed.  An intro is skipped when its token count plus the added
    ``<sos>``/``<eos>`` would exceed 1200 positions.

    Returns:
        The number of decodes that are equal to the decode before them.  A
        high count suggests the model emits the same output for different
        inputs.
    """
    max_length = 1200
    special_tokens = ('<pad>', '<sos>', '<eos>', '<unk>')

    count = 0
    previous = None
    for i in range(min(5, len(val_intro))):
        intro = val_intro[i]
        # Guard on this sample's length (plus <sos> and <eos>), not on the dataset size.
        if len(intro.split(' ')) + 2 > max_length:
            continue

        translated_sentence = translate_sentence(model, intro, intro, intro_field, device,
                                                 max_length=max_length)
        translated_sentence = [int(x) for x in translated_sentence if x not in special_tokens]
        print(translated_sentence)

        # Compare with the last decode actually produced, so skipped samples cannot
        # shift the comparison.
        if previous is not None and previous == translated_sentence:
            count += 1
        previous = translated_sentence
    return count


In [None]:
N_EPOCHS = 500
S_EPOCH = 0
# Maximum total gradient norm passed to train().
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(S_EPOCH, N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    # Build the snapshot every epoch so each saved file carries this epoch's validation
    # loss and optimizer state, and `checkpoint` is always bound even if the validation
    # loss never improves (e.g. when it is NaN).
    checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}

    # Best-so-far model: always written to the same file.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        save_best_checkpoint(checkpoint,N_EPOCHS)
    # Periodic snapshots at epoch indices 0, 19, 20, 39, 40, ...
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)
#     if (epoch+2) % 20 ==0:
#         if check_mode_collapse(model) > 2:
#             print("model is mode collapsing")
# State after the last epoch.
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 1m 7s
	Train Loss: 4.590 | Train PPL:  98.507
	 Val. Loss: 3.953 |  Val. PPL:  52.105
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 1m 9s
	Train Loss: 3.345 | Train PPL:  28.352
	 Val. Loss: 3.011 |  Val. PPL:  20.303
=> Saving checkpoint
Epoch: 03 | Time: 1m 9s
	Train Loss: 2.724 | Train PPL:  15.241
	 Val. Loss: 2.540 |  Val. PPL:  12.686
=> Saving checkpoint
Epoch: 04 | Time: 1m 9s
	Train Loss: 2.366 | Train PPL:  10.650
	 Val. Loss: 2.207 |  Val. PPL:   9.088
=> Saving checkpoint
Epoch: 05 | Time: 1m 9s
	Train Loss: 2.130 | Train PPL:   8.413
	 Val. Loss: 2.007 |  Val. PPL:   7.439
=> Saving checkpoint
Epoch: 06 | Time: 1m 9s
	Train Loss: 1.979 | Train PPL:   7.232
	 Val. Loss: 1.878 |  Val. PPL:   6.538
=> Saving checkpoint
Epoch: 07 | Time: 1m 9s
	Train Loss: 1.874 | Train PPL:   6.514
	 Val. Loss: 1.797 |  Val. PPL:   6.033
=> Saving checkpoint
Epoch: 08 | Time: 1m 9s
	Train Loss: 1.798 | Train PPL:   6.035
	 Val. Loss: 1.732 |  Val. PPL:   5.652


Epoch: 67 | Time: 1m 9s
	Train Loss: 0.324 | Train PPL:   1.383
	 Val. Loss: 0.393 |  Val. PPL:   1.481
=> Saving checkpoint
Epoch: 68 | Time: 1m 9s
	Train Loss: 0.287 | Train PPL:   1.332
	 Val. Loss: 0.378 |  Val. PPL:   1.459
=> Saving checkpoint
Epoch: 69 | Time: 1m 9s
	Train Loss: 0.257 | Train PPL:   1.293
	 Val. Loss: 0.329 |  Val. PPL:   1.390
=> Saving checkpoint
Epoch: 70 | Time: 1m 9s
	Train Loss: 0.231 | Train PPL:   1.260
	 Val. Loss: 0.334 |  Val. PPL:   1.396
Epoch: 71 | Time: 1m 9s
	Train Loss: 0.208 | Train PPL:   1.231
	 Val. Loss: 0.288 |  Val. PPL:   1.333
=> Saving checkpoint
Epoch: 72 | Time: 1m 9s
	Train Loss: 0.189 | Train PPL:   1.208
	 Val. Loss: 0.284 |  Val. PPL:   1.329
=> Saving checkpoint
Epoch: 73 | Time: 1m 9s
	Train Loss: 0.169 | Train PPL:   1.185
	 Val. Loss: 0.254 |  Val. PPL:   1.290
=> Saving checkpoint
Epoch: 74 | Time: 1m 9s
	Train Loss: 0.153 | Train PPL:   1.166
	 Val. Loss: 0.248 |  Val. PPL:   1.282
=> Saving checkpoint
Epoch: 75 | Time: 1m 

Epoch: 137 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.033 |  Val. PPL:   1.033
Epoch: 138 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.032 |  Val. PPL:   1.032
=> Saving checkpoint
Epoch: 139 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.036 |  Val. PPL:   1.037
Epoch: 140 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.032 |  Val. PPL:   1.033
=> Saving checkpoint
Epoch: 141 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.033 |  Val. PPL:   1.033
=> Saving checkpoint
Epoch: 142 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.032 |  Val. PPL:   1.032
=> Saving checkpoint
Epoch: 143 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.034 |  Val. PPL:   1.035
Epoch: 144 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss: 0.033 |  Val. PPL:   1.034
Epoch: 145 | Time: 1m 9s
	Train Loss: 0.002 | Train PPL:   1.002
	 Val. Loss

In [None]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
# A second, freshly initialised model with the same hyper-parameters as `model`.
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
# Use a separate name for this optimizer: rebinding the global `optimizer` here would make
# the later `load_checkpoint(state, model, optimizer)` restore optimizer state into an
# optimizer that does not own `model`'s parameters.
best_optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [20]:
# Load the snapshot saved for epoch index 199, mapping its tensors onto `device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state = torch.load(destination_folder + '/199_checkpoint.pt', map_location=device)
# List the entries stored in the checkpoint.
for key in state:
    print(key)
# Restores the weights into `model` and the optimizer state into whatever `optimizer`
# is bound to at this point.
load_checkpoint(state, model, optimizer)

model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint


In [21]:
# NOTE: despite the variable name, this is the loss on the validation iterator.
test_loss = evaluate(model, valid_iter, criterion)
# Perplexity: exponential of the mean loss.
print(math.exp(test_loss))

1.0276374162362931


In [22]:
# Load the test split with its three sections (intro, solo, outro) as parallel arrays.
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro = df_intro['intro'].values
test_solo = df_intro['solo'].values
test_outro = df_intro['outro'].values
# One record per test row, keyed by section name.
test_data = [
    {'intro': intro_tokens, 'solo': solo_tokens, 'outro': outro_tokens}
    for intro_tokens, solo_tokens, outro_tokens in zip(test_intro, test_solo, test_outro)
]
print(len(test_intro))

112


In [23]:
# For every test row: decode a sequence from the intro and outro, then write the three
# reference sections and the decoded sequence as separate MIDI files.
for i, (intro, solo, outro) in enumerate(zip(test_intro, test_solo, test_outro)):
    list_intro = list(map(int, intro.split(' ')))
    list_solo = list(map(int, solo.split(' ')))
    list_outro = list(map(int, outro.split(' ')))

    translated_sentence = translate_sentence_padseq(model, intro, outro, intro_field, device, max_length=1200)
    # Drop the special tokens and convert the remaining token strings to ints.
    translated_sentence = [
        int(x) for x in translated_sentence
        if x not in ('<pad>', '<sos>', '<eos>', '<unk>')
    ]
    print(translated_sentence)

    midi_jobs = (
        (list_intro, generated_outputs + "/intro/" + "/intro" + str(i)  + ".mid"),
        (list_solo, generated_outputs  + "/solo/" + "/solo" + str(i)  + ".mid"),
        (list_outro, generated_outputs + "/outro/" + "/outro" + str(i)  + ".mid"),
        (translated_sentence, generated_outputs + "/predict/" + "/predict" + str(i)  + ".mid"),
    )
    for words, midi_path in midi_jobs:
        utils.write_midi(words, word2event, midi_path)
    print(i)
    
        


[0, 1, 2, 122, 48, 73, 13, 14, 8, 2, 122, 8, 73, 34, 7, 15, 81, 41, 37, 17, 2, 122, 23, 2, 122, 0, 1, 2, 122, 4, 9, 39, 14, 48, 12, 41, 14, 8, 2, 122, 8, 81, 30, 7, 15, 73, 41, 14, 52, 81, 30, 7, 17, 2, 122, 36, 81, 47, 7, 23, 2, 122, 23, 81, 30, 10, 33, 81, 30, 14, 53, 73, 30, 14, 0, 1, 2, 122, 1, 73, 30, 10, 54, 73, 20, 14, 4, 81, 39, 14, 48, 12, 30, 14, 8, 81, 100, 14, 23, 81, 55, 10, 32, 81, 47, 16, 53, 73, 47, 14, 0, 1, 73, 47, 10, 54, 70, 20, 14, 48, 73, 20, 14, 8, 73, 20, 14, 15, 73, 13, 16, 52, 81, 39, 85, 0, 1, 81, 13, 10, 54, 81, 46, 14, 4, 9, 13, 14, 48, 81, 62, 45, 50, 81, 62, 14, 23, 12, 62, 14, 32, 75, 47, 10, 33, 73, 46, 14, 53, 70, 46, 10, 53, 73, 46, 14, 0, 1, 2, 122, 1, 81, 46, 16, 4, 12, 47, 10, 48, 81, 46, 10, 48, 81, 39, 14, 11, 12, 13, 14]
0
[0, 1, 2, 103, 48, 26, 41, 10, 8, 2, 103, 11, 12, 71, 16, 52, 12, 46, 10, 17, 2, 103, 36, 12, 41, 10, 18, 12, 41, 7, 23, 2, 103, 32, 73, 39, 16, 33, 81, 71, 14, 0, 1, 2, 103, 8, 2, 103, 11, 12, 43, 10, 15, 12, 71, 10, 52, 81, 

[0, 1, 58, 131, 1, 70, 34, 27, 4, 81, 41, 7, 15, 70, 41, 45, 33, 72, 39, 10, 0, 1, 81, 41, 10, 52, 70, 39, 7, 36, 70, 46, 10, 50, 73, 41, 10, 32, 12, 41, 7, 0, 54, 70, 46, 10, 48, 73, 46, 10, 11, 72, 46, 10, 52, 73, 46, 10, 33, 70, 39, 10, 0, 1, 72, 43, 85, 32, 70, 71, 85, 0, 1, 5, 34, 24, 54, 81, 41, 66, 8, 81, 77, 85, 53, 81, 41, 10]
17
[0, 1, 58, 148, 54, 70, 77, 27, 8, 70, 77, 7, 15, 73, 71, 7, 17, 73, 41, 85, 33, 73, 71, 16, 0, 1, 70, 65, 16, 4, 70, 46, 16, 8, 70, 77, 7, 15, 73, 71, 7, 17, 73, 77, 57, 0, 1, 73, 77, 16, 8, 70, 77, 16, 15, 72, 71, 16, 17, 81, 41, 16, 18, 73, 65, 16, 23, 70, 44, 10, 33, 70, 77, 16, 0, 1, 73, 71, 90, 0, 1, 70, 77, 24, 8, 73, 41, 10, 15, 73, 88, 16, 17, 5, 77, 82, 0, 23, 73, 88, 91]
18
[0, 1, 58, 136, 1, 73, 47, 51, 4, 70, 20, 7, 8, 73, 47, 24, 53, 73, 13, 16, 0, 54, 73, 30, 27, 52, 81, 39, 16, 36, 70, 13, 16, 50, 73, 43, 31, 53, 73, 20, 24, 0, 52, 12, 39, 10, 36, 73, 13, 10, 50, 81, 43, 24, 53, 73, 39, 10, 0, 48, 73, 43, 27, 52, 73, 110, 27, 50, 72, 3

[0, 1, 2, 3, 1, 73, 6, 14, 4, 70, 13, 14, 8, 73, 65, 10, 11, 72, 67, 14, 15, 73, 65, 10, 17, 73, 13, 14, 18, 73, 6, 16, 23, 70, 30, 14, 33, 73, 67, 45, 0, 1, 73, 30, 14, 54, 70, 67, 14, 4, 73, 65, 14, 48, 70, 67, 14, 8, 70, 65, 14, 11, 73, 77, 14, 15, 70, 6, 16, 17, 70, 13, 14, 36, 73, 22, 31, 0, 1, 70, 77, 14, 54, 70, 13, 14, 4, 70, 88, 14, 48, 73, 13, 14, 8, 70, 71, 14, 11, 70, 77, 10, 15, 70, 71, 10, 17, 73, 30, 14, 36, 70, 77, 7, 0, 1, 70, 30, 14, 54, 73, 30, 14, 4, 81, 30, 14, 48, 73, 65, 7, 8, 73, 67, 14, 11, 73, 30, 14, 15, 73, 67, 10, 17, 70, 65, 10, 36, 70, 13, 51, 0, 1, 70, 6, 14, 54, 73, 13, 14, 4, 73, 13, 14, 48, 73, 13, 14, 8, 70, 13, 14, 11, 73, 13, 10, 15, 73, 6, 10, 17, 73, 6, 10, 36, 73, 13, 85, 0, 54, 73, 13, 27, 23, 70, 13, 14, 54, 73, 71, 14]
35
[0, 1, 2, 80, 1, 73, 67, 10, 54, 73, 67, 14, 48, 81, 67, 10, 11, 70, 67, 10, 11, 73, 61, 10, 52, 72, 61, 10, 17, 70, 61, 10, 0, 1, 75, 30, 10, 54, 70, 64, 10, 48, 70, 30, 10, 8, 70, 30, 66, 11, 75, 67, 10, 52, 73, 65, 7, 17,

[0, 1, 58, 204, 1, 73, 41, 27, 4, 81, 62, 7, 8, 81, 64, 16, 15, 12, 62, 10, 17, 73, 64, 10, 53, 73, 62, 10, 0, 54, 73, 64, 10, 11, 73, 62, 10, 36, 81, 64, 10, 50, 70, 62, 16, 32, 81, 64, 16, 53, 73, 43, 16, 0, 54, 73, 62, 16, 48, 73, 64, 66, 52, 73, 62, 16, 17, 73, 64, 66, 53, 73, 49, 7, 33, 73, 62, 16, 0, 1, 81, 64, 51, 4, 73, 62, 16, 8, 73, 64, 7, 15, 73, 62, 16, 17, 73, 64, 66, 0, 1, 70, 62, 27, 4, 81, 49, 90, 0, 1, 70, 47, 27, 54, 73, 62, 24, 17, 73, 62, 16, 48, 73, 100, 7, 8, 73, 65, 7]
52
[0, 1, 2, 147, 15, 70, 64, 16, 17, 72, 64, 16, 18, 70, 61, 7, 23, 70, 64, 31, 53, 72, 61, 10, 0, 1, 72, 46, 27, 8, 72, 61, 7, 15, 70, 64, 16, 17, 70, 64, 7, 18, 72, 46, 7, 50, 75, 61, 16, 23, 72, 46, 7, 33, 70, 64, 82, 0, 4, 72, 62, 24, 17, 72, 61, 16, 18, 69, 99, 10, 17, 72, 61, 31, 50, 72, 46, 16, 23, 72, 46, 16, 32, 75, 99, 10, 33, 72, 62, 16, 0, 1, 72, 99, 16, 4, 72, 64, 10, 48, 72, 64, 31, 8, 72, 99, 14, 0, 1, 72, 46, 16, 4, 72, 99, 7, 15, 70, 99, 16, 15, 70, 64, 31, 36, 70, 99, 14, 50, 75,

[0, 1, 2, 107, 1, 26, 65, 10, 4, 9, 46, 16, 8, 35, 46, 66, 15, 35, 47, 16, 18, 40, 46, 7, 33, 35, 65, 16, 0, 1, 26, 46, 16, 4, 35, 39, 31, 15, 29, 65, 16, 0, 1, 19, 20, 16, 4, 9, 64, 16, 8, 26, 20, 66, 18, 40, 20, 16, 23, 40, 64, 16, 33, 35, 47, 16, 0, 4, 26, 62, 10, 8, 19, 62, 7, 15, 70, 49, 51, 33, 9, 49, 10, 0, 1, 19, 62, 16, 4, 29, 64, 16, 8, 35, 62, 27, 18, 5, 62, 10, 23, 9, 62, 16, 33, 5, 62, 10, 50, 9, 49, 21]
70
[0, 1, 2, 107, 1, 9, 25, 24, 4, 5, 22, 16, 8, 29, 20, 10, 15, 12, 20, 7, 17, 9, 22, 10, 18, 12, 22, 16, 23, 12, 25, 31, 0, 1, 12, 13, 27, 53, 29, 39, 7, 0, 54, 5, 13, 7, 15, 26, 22, 31, 17, 5, 20, 16, 36, 73, 22, 7, 53, 12, 55, 16, 0, 54, 5, 55, 66, 50, 9, 56, 16, 32, 9, 13, 7, 53, 29, 20, 16, 0, 54, 26, 20, 7, 48, 12, 22, 16, 11, 5, 55, 10, 52, 5, 20, 27, 50, 9, 25, 7]
71
[0, 1, 58, 101, 48, 29, 65, 7, 8, 58, 101, 15, 9, 6, 7, 52, 12, 61, 16, 17, 58, 101, 18, 12, 25, 7, 23, 58, 101, 23, 5, 64, 16, 33, 12, 61, 16, 0, 1, 58, 101, 1, 9, 6, 66, 8, 58, 101, 17, 58, 101, 23,

[0, 1, 2, 146, 1, 70, 99, 7, 4, 75, 100, 16, 8, 73, 99, 7, 52, 19, 46, 7, 36, 5, 61, 7, 50, 5, 64, 16, 32, 9, 61, 10, 33, 12, 46, 10, 53, 5, 61, 45, 0, 52, 40, 61, 16, 36, 5, 64, 16, 50, 40, 61, 16, 32, 19, 64, 16, 53, 26, 61, 16, 0, 54, 35, 46, 16, 48, 35, 61, 16, 11, 5, 64, 7, 52, 5, 99, 7, 36, 5, 64, 7, 32, 9, 47, 7, 53, 29, 61, 16, 33, 40, 43, 7, 53, 19, 46, 10, 53, 5, 99, 45, 0, 1, 5, 99, 7, 33, 29, 62, 31, 48, 9, 64, 16, 53, 12, 62, 7, 0, 54, 5, 99, 7, 48, 12, 100, 7, 11, 29, 61, 7, 52, 9, 100, 95, 50, 29, 62, 10, 23, 12, 25, 7, 33, 5, 99, 28]
87
[0, 1, 58, 185, 1, 9, 47, 45, 17, 5, 46, 51, 0, 1, 9, 41, 21, 17, 35, 46, 95, 0, 1, 26, 47, 66, 17, 40, 46, 104, 53, 5, 64, 28]
88
[0, 1, 58, 101, 1, 5, 13, 16, 48, 19, 67, 14, 8, 9, 67, 14, 15, 19, 67, 14, 17, 9, 67, 14, 18, 9, 30, 14, 23, 35, 67, 16, 33, 40, 39, 10, 0, 1, 5, 13, 10, 4, 5, 67, 10, 8, 35, 39, 16, 15, 35, 39, 14, 17, 35, 39, 10, 18, 29, 67, 10, 23, 26, 39, 10, 32, 73, 13, 10, 33, 12, 6, 24, 0, 15, 40, 39, 14, 17, 35, 39, 

[0, 1, 2, 80, 1, 29, 30, 14, 54, 35, 30, 14, 48, 35, 41, 14, 11, 35, 39, 14, 15, 35, 41, 16, 18, 35, 20, 14, 50, 35, 20, 10, 23, 35, 41, 10, 32, 9, 13, 10, 33, 35, 39, 16, 0, 1, 35, 39, 7, 4, 26, 13, 10, 8, 35, 39, 14, 11, 35, 41, 24, 17, 35, 47, 7, 23, 35, 47, 14, 32, 5, 13, 14, 33, 35, 39, 10, 0, 1, 35, 20, 10, 4, 5, 13, 14, 8, 5, 22, 14, 11, 29, 13, 16, 52, 35, 47, 24, 23, 40, 20, 14, 32, 19, 20, 14, 33, 35, 20, 10, 0, 1, 35, 47, 14, 54, 70, 39, 10, 8, 35, 47, 14, 11, 29, 20, 14, 15, 35, 47, 10, 17, 35, 39, 14, 36, 40, 30, 7, 32, 26, 39, 14, 33, 35, 43, 10, 53, 35, 34, 14, 0, 54, 9, 30, 14, 4, 35, 39, 10, 48, 29, 41, 14, 8, 35, 39, 14, 15, 35, 41, 10, 36, 35, 39, 14, 32, 35, 65, 14, 33, 5, 77, 14]
106
[0, 1, 2, 134, 54, 19, 65, 7, 4, 35, 41, 7, 52, 35, 46, 16, 36, 40, 65, 16, 50, 35, 43, 16, 32, 9, 43, 16, 53, 35, 43, 10, 0, 54, 35, 43, 10, 48, 19, 71, 16, 11, 19, 71, 16, 52, 35, 43, 16, 36, 9, 71, 16, 50, 26, 71, 16, 23, 26, 71, 16, 53, 9, 71, 16, 0, 54, 9, 41, 16, 11, 35, 41, 7, 5

In [24]:
import mido
# For every test sample, stitch the rendered MIDI segments into two multi-track files:
#   merged<i>.mid          -> intro + ground-truth solo + outro
#   merged_predict<i>.mid  -> intro + predicted solo    + outro
# Later segments are delayed by adding the summed note_on delta times of the earlier
# ones to the message at index 1 of their track 1
# (presumably the first timed event of that track -- TODO confirm against utils.write_midi).
for i in range(len(test_intro)):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')

    # Only note_on messages contribute to the accumulated tick counts.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = sum(msg.time for msg in solo.tracks[1] if msg.type == "note_on")
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    original_outro_time = 0 + outro.tracks[1][1].time

    # --- variant 1: ground-truth solo between intro and outro ---
    outro_start = original_outro_time + total_solo_time + total_intro_time
    print(outro_start)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = outro_start
    print(outro.tracks[1][1].time)

    # Track names are assigned only after the offsets above have been applied.
    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')

    # --- variant 2: predicted solo; reload the outro so its offset starts fresh ---
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')

    outro_start = original_outro_time + total_predict_time + total_intro_time
    print(outro_start)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = outro_start
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

25080
25080
20160
20160
29880
29880
21360
21360
23760
23760
20640
20640
35820
35820
19920
19920
22800
22800
18600
18600
20160
20160
21600
21600
25380
25380
19320
19320
25800
25800
21300
21300
25320
25320
19680
19680
40620
40620
18780
18780
25620
25620
19200
19200
27300
27300
16680
16680
23880
23880
18840
18840
20460
20460
18900
18900
23520
23520
19200
19200
39060
39060
18660
18660
22980
22980
15240
15240
24540
24540
19500
19500
25320
25320
22860
22860
45300
45300
19800
19800
31080
31080
18960
18960
21540
21540
19200
19200
26280
26280
17760
17760
24900
24900
19080
19080
24840
24840
21180
21180
25560
25560
19140
19140
28920
28920
18780
18780
26760
26760
19500
19500
27240
27240
19620
19620
17640
17640
17700
17700
24300
24300
21900
21900
24360
24360
19740
19740
17640
17640
21240
21240
25440
25440
19020
19020
42480
42480
17580
17580
17880
17880
21060
21060
24300
24300
19140
19140
40800
40800
22200
22200
24240
24240
18360
18360
26160
26160
21120
21120
22200
22200
16020
16020
22680
22680
2004

In [32]:
# Output tree for the interpolation experiment: one sub-directory per segment kind.
dissimilar_interpolation = folder +  "/interpolation_sum"
for segment_dir in ("/intro", "/outro", "/solo", "/predict"):
    # parents/exist_ok make the cell safe to re-run.
    Path(dissimilar_interpolation + segment_dir).mkdir(parents=True, exist_ok=True)

In [33]:
# dissimilar_interpolation
# Generate a bridge between each test intro and a *different* outro (the one three
# samples ahead), then write intro, outro and prediction as separate MIDI files.
for i in range(len(test_intro)):
    intro = test_intro[i]
    # Use the outro of sample i+3 while it exists; the last samples keep their own outro.
    outro = test_outro[i + 3] if i + 3 < len(test_intro) else test_outro[i]

    # Both sequences are stored as space-separated integer token ids.
    list_intro = [int(x) for x in intro.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]

    translated_sentence = translate_sentence(model, intro, outro, intro_field, device, max_length=1200)
    # Drop the special vocabulary tokens before converting back to integer ids.
    translated_sentence = [
        int(x)
        for x in translated_sentence
        if x not in ('<pad>', '<sos>', '<eos>', '<unk>')
    ]
    print(translated_sentence)

    utils.write_midi(list_intro, word2event, dissimilar_interpolation + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, dissimilar_interpolation + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, dissimilar_interpolation + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
        


[0, 1, 2, 161, 1, 64, 82, 7, 23, 51, 73, 7, 8, 2, 162, 8, 51, 75, 31, 10, 64, 75, 31, 13, 2, 162, 70, 11, 12, 15, 78, 64, 40, 22, 17, 2, 162, 17, 64, 33, 22, 90, 51, 30, 94, 0, 1, 64, 111, 15, 4, 53, 111, 7, 8, 2, 162, 1, 32, 111, 22, 67, 64, 6, 15, 4, 64, 6, 15, 23, 51, 26, 22, 8, 2, 162, 8, 51, 9, 22, 72, 64, 6, 31, 91, 2, 162, 13, 2, 115, 13, 64, 26, 36, 17, 2, 113, 27, 64, 12, 34, 0, 1, 2, 52, 1, 64, 73, 19, 4, 64, 76, 52, 8, 51, 6, 41, 13, 51, 62, 15, 8, 64, 6, 7, 10, 64, 26, 31, 13, 2, 162, 17, 64, 12, 7, 0, 10, 64, 73, 15, 4, 2, 162, 8, 64, 12, 31, 23, 51, 20, 7, 8, 64, 21, 7, 72, 64, 20, 22, 8, 2, 161, 10, 64, 73, 94, 0, 1, 2, 90, 64, 73, 15, 10, 51, 73, 87, 74, 64, 73, 15, 13, 64, 73, 73, 15, 16, 32, 12, 31, 17, 2, 161, 90, 64, 73, 7, 0, 8, 64, 12, 31, 10, 64, 9, 7, 13, 64, 73, 52, 70, 64, 12, 31, 0, 1, 64, 12, 19, 17, 51, 12, 7, 0, 1, 64, 6, 15, 4, 51, 6, 15, 10, 38, 62, 22, 10, 51, 6, 15, 13, 25, 62, 22, 67, 64, 73, 15, 8, 51, 73, 22, 8, 32, 76, 15, 10, 51, 6, 31, 13, 51, 6,

[0, 1, 43, 114, 1, 51, 30, 19, 91, 64, 33, 7, 78, 51, 6, 47, 74, 64, 9, 7, 0, 91, 53, 6, 7, 70, 35, 33, 7, 78, 51, 26, 7, 0, 4, 51, 6, 15, 8, 51, 6, 19, 78, 51, 26, 7, 74, 64, 73, 7, 0, 8, 43, 105, 16, 38, 29, 15, 78, 64, 75, 22, 17, 64, 26, 7, 27, 51, 6, 22, 74, 51, 6, 19, 0, 4, 51, 26, 7, 91, 51, 33, 102, 0, 1, 54, 33, 7, 4, 51, 26, 7, 8, 54, 6, 19, 91, 54, 6, 7, 16, 51, 33, 77, 0, 78, 64, 6, 182, 0, 91, 54, 33, 7, 74, 51, 33, 182, 0, 91, 54, 6, 7, 78, 51, 6, 182, 0, 91, 54, 6, 19, 78, 54, 6, 15, 78, 18, 57, 182, 0, 10, 51, 6, 19, 78, 18, 6, 19, 78, 54, 30, 19, 74, 25, 50, 37, 0, 10, 51, 58, 80]
11
[0, 1, 43, 161, 1, 64, 57, 7, 13, 64, 62, 7, 16, 64, 62, 41, 17, 64, 96, 7, 17, 51, 57, 15, 27, 64, 57, 7, 0, 1, 51, 62, 7, 4, 51, 58, 31, 8, 32, 57, 7, 10, 51, 57, 7, 13, 64, 58, 7, 0, 8, 64, 57, 7, 10, 64, 12, 15, 16, 64, 62, 7, 8, 51, 57, 7, 10, 51, 62, 7, 16, 51, 61, 19, 27, 64, 73, 19, 0, 8, 51, 20, 15, 10, 64, 61, 19, 0, 4, 64, 62, 7, 8, 51, 26, 7, 10, 51, 62, 19, 0, 16, 51, 62, 52, 

[0, 1, 2, 116, 1, 51, 55, 15, 67, 64, 48, 7, 23, 64, 48, 15, 91, 64, 86, 7, 13, 51, 55, 7, 70, 51, 33, 66, 0, 67, 64, 57, 22, 4, 64, 12, 22, 23, 51, 61, 22, 72, 64, 30, 7, 91, 32, 30, 7, 78, 51, 30, 7, 90, 32, 48, 7, 74, 64, 48, 31, 0, 1, 64, 55, 22, 4, 64, 48, 7, 23, 51, 48, 7, 8, 64, 48, 7, 91, 51, 111, 31, 70, 64, 62, 7, 78, 64, 33, 7, 90, 51, 33, 7, 74, 51, 7, 74, 32, 30, 7, 67, 54, 7, 23, 53, 48, 7, 72, 54, 50, 7, 67, 54, 48, 94, 72, 64, 48, 31, 0, 23, 54, 48, 41, 74, 35, 50, 41, 67, 51, 57, 41, 72, 91, 38, 6, 7, 13, 51, 48, 50, 15, 91, 51, 12, 7, 91, 51, 57, 22, 0, 72, 64, 50, 7, 70, 38, 48, 15, 70, 64, 6, 15, 70, 32, 50, 7, 70, 38, 48, 31, 7, 78, 64, 48, 22, 70, 11, 50, 31, 90, 54, 50, 31, 17, 54, 30, 31, 0, 72, 38, 6, 22, 70, 64, 50, 22]
24
[0, 1, 2, 139, 1, 64, 20, 15, 4, 32, 86, 7, 8, 38, 73, 15, 10, 64, 21, 15, 13, 64, 58, 7, 16, 38, 20, 19, 27, 32, 73, 7, 27, 32, 57, 7, 0, 67, 64, 62, 7, 4, 38, 73, 47, 27, 32, 73, 7, 0, 1, 64, 62, 7, 4, 32, 62, 7, 72, 32, 58, 7, 10, 11, 62,

[0, 1, 2, 150, 1, 51, 12, 31, 23, 51, 57, 15, 72, 54, 58, 15, 10, 51, 57, 15, 91, 51, 57, 7, 70, 51, 58, 7, 0, 1, 54, 12, 15, 90, 54, 12, 15, 27, 53, 12, 15, 0, 10, 53, 57, 31, 74, 64, 73, 31, 0, 72, 54, 62, 15, 13, 53, 57, 15, 0, 8, 64, 12, 15, 91, 54, 12, 15, 10, 54, 90, 51, 57, 15, 27, 53, 58, 15, 0, 74, 53, 12, 15, 4, 35, 12, 7, 10, 54, 57, 52, 0, 70, 51, 6, 7, 16, 64, 48, 7, 17, 63, 50, 52, 0, 1, 53, 48, 15, 4, 54, 50, 7, 23, 64, 48, 15, 10, 53, 48, 15, 8, 54, 50, 15, 10, 51, 6, 15, 0, 70, 51, 50, 15, 16, 72, 35, 48, 15, 0, 67, 51, 58, 31, 23, 51, 12, 15, 72, 51, 58, 15, 10, 54, 62, 15, 91, 35, 50, 22, 70, 54, 12, 31, 78, 54, 57, 15, 90, 51, 12, 15, 13, 64, 57, 15, 74, 53, 12, 15, 0, 1, 53, 61, 15, 23, 54, 61, 15, 72, 53, 46, 15, 10, 54, 12, 15, 91, 35, 50, 31, 70, 54, 50, 34, 78, 53, 12, 15, 17, 54, 57, 15, 8, 54, 6, 15, 74, 35, 50, 15, 70, 51, 82, 15, 0, 4, 54, 46, 136, 70, 54, 6, 15, 70, 53, 50, 15, 16, 54, 50, 7, 70, 54, 12, 15, 0, 8, 64, 43, 169, 70, 64, 6, 60, 4, 24, 50, 41]

[0, 1, 2, 69, 1, 38, 57, 7, 4, 64, 33, 7, 91, 64, 33, 7, 70, 51, 58, 7, 78, 51, 96, 41, 27, 64, 21, 41, 0, 23, 64, 86, 7, 72, 32, 21, 7, 10, 64, 12, 7, 70, 51, 26, 47, 27, 64, 73, 52, 0, 1, 64, 9, 66, 4, 64, 73, 41, 72, 64, 9, 7, 91, 51, 21, 7, 13, 51, 86, 7, 16, 64, 6, 52, 17, 64, 26, 19, 27, 64, 57, 41, 0, 4, 64, 73, 85, 0, 91, 64, 33, 7, 4, 64, 33, 7, 8, 64, 33, 7, 27, 64, 26, 66, 13, 51, 6, 7, 0, 1, 64, 9, 7, 4, 64, 73, 7, 70, 32, 9, 7, 10, 64, 6, 52, 13, 64, 26, 19, 27, 51, 86, 7, 0, 1, 51, 9, 7, 91, 64, 73, 7, 13, 51, 9, 22, 78, 51, 73, 22, 16, 54, 9, 7, 16, 64, 86, 31, 27, 64, 26, 19, 0, 1, 51, 6, 41, 4, 51, 86, 85]
51
[0, 1, 43, 161, 1, 64, 21, 15, 4, 64, 61, 15, 8, 64, 62, 47, 8, 51, 61, 7, 27, 32, 73, 22, 0, 1, 64, 9, 7, 4, 51, 20, 52, 10, 38, 73, 7, 17, 32, 73, 31, 27, 51, 9, 31, 0, 1, 64, 58, 15, 4, 38, 73, 41, 10, 64, 6, 15, 13, 64, 58, 7, 16, 32, 73, 7, 27, 51, 20, 15, 74, 64, 20, 19, 0, 4, 64, 73, 41, 10, 54, 9, 7, 16, 54, 26, 15, 17, 51, 76, 7, 27, 64, 20, 7, 0, 1, 51, 

[0, 1, 2, 181, 1, 24, 46, 15, 4, 79, 62, 7, 91, 32, 62, 34, 13, 79, 75, 15, 16, 18, 48, 52, 27, 38, 62, 7, 0, 1, 18, 62, 7, 23, 54, 58, 52, 8, 53, 46, 47, 0, 10, 46, 7, 13, 11, 46, 52, 17, 64, 33, 31, 27, 11, 33, 31, 0, 1, 11, 62, 7, 27, 79, 62, 7, 0, 1, 24, 75, 15, 8, 18, 75, 7, 10, 18, 62, 7, 13, 11, 46, 31, 17, 24, 75, 7, 27, 38, 46, 7, 0, 1, 38, 75, 15, 67, 38, 75, 15, 8, 11, 75, 7, 10, 38, 48, 31, 13, 64, 75, 47, 0, 4, 38, 75, 7, 8, 11, 46, 7, 10, 64, 75, 31, 13, 64, 100, 7, 16, 24, 75, 7, 17, 64, 46, 47, 0, 4, 79, 62, 15, 8, 64, 75, 15, 8, 38, 46, 31, 16, 79, 48, 47, 17, 18, 62, 7, 27, 79, 48, 15, 0, 10, 5, 48, 31, 13, 25, 46, 7, 17, 24, 48, 31, 27, 18, 62, 31, 74, 38, 75, 7]
67
[0, 1, 2, 181, 1, 5, 50, 7, 4, 24, 46, 7, 8, 38, 33, 7, 10, 38, 55, 31, 13, 38, 127, 15, 16, 51, 127, 7, 17, 64, 104, 52, 0, 1, 35, 93, 41, 10, 64, 33, 7, 8, 38, 104, 31, 10, 38, 26, 52, 16, 32, 62, 7, 17, 38, 33, 31, 17, 64, 62, 7, 27, 38, 46, 7, 0, 1, 38, 122, 66, 23, 11, 26, 31, 8, 64, 26, 31, 0, 4, 64

[0, 1, 2, 194, 1, 5, 9, 52, 4, 79, 86, 52, 70, 24, 76, 34, 27, 25, 9, 7, 0, 1, 25, 73, 15, 4, 25, 62, 7, 10, 79, 73, 7, 16, 25, 62, 15, 27, 49, 75, 15, 0, 4, 24, 26, 31, 0, 8, 51, 6, 15, 72, 53, 33, 52, 27, 14, 26, 52, 0, 1, 64, 62, 28, 4, 32, 26, 7, 8, 11, 48, 31, 10, 51, 111, 15, 16, 25, 46, 52, 10, 51, 62, 22, 16, 32, 9, 52, 0, 1, 38, 33, 31, 17, 64, 9, 15, 17, 38, 111, 15, 0, 1, 38, 100, 31]
80
[0, 1, 2, 142, 1, 25, 111, 22, 67, 64, 111, 15, 4, 38, 75, 15, 23, 64, 33, 22, 91, 11, 75, 22, 72, 24, 30, 7, 13, 11, 111, 22, 78, 38, 111, 22, 17, 18, 30, 15, 74, 51, 30, 22, 0, 1, 38, 100, 22, 67, 38, 111, 42, 0, 1, 32, 111, 22, 67, 38, 111, 22, 4, 14, 29, 22, 23, 64, 30, 15, 17, 14, 29, 31, 72, 38, 30, 31, 13, 11, 111, 31, 70, 38, 111, 7, 16, 11, 48, 7, 78, 11, 30, 7, 17, 24, 33, 7, 27, 5, 30, 15, 0, 17, 11, 30, 7, 27, 11, 30, 7, 23, 11, 33, 7, 8, 18, 26, 31, 0, 17, 79, 111, 15, 8, 32, 111, 19, 17, 79, 29, 15, 90, 79, 100, 7, 27, 32, 111, 15, 0]
81
[0, 1, 2, 155, 1, 25, 6, 22, 23, 25, 6, 

[0, 1, 43, 208, 1, 51, 55, 7, 4, 45, 48, 31, 8, 5, 62, 42, 13, 38, 58, 7, 17, 38, 55, 31, 0, 1, 38, 100, 7, 4, 32, 76, 31, 8, 11, 84, 41, 16, 49, 128, 31, 17, 2, 142, 8, 49, 84, 34, 13, 2, 142, 16, 64, 129, 34, 16, 51, 127, 7, 16, 11, 104, 52, 17, 11, 128, 31, 27, 11, 104, 52, 1, 2, 142, 1, 24, 62, 7, 4, 38, 58, 31, 8, 11, 48, 31, 8, 51, 62, 15, 10, 11, 128, 15, 16, 38, 26, 15, 8, 2, 142, 17, 25, 48, 68, 0, 1, 11, 61, 19, 10, 38, 57, 15, 0, 1, 11, 61, 15, 1, 51, 62, 7, 17, 38, 48, 15]
95
[0, 1, 43, 113, 1, 25, 55, 52, 8, 5, 55, 19, 16, 5, 33, 31, 17, 5, 75, 85, 0, 4, 24, 55, 31, 10, 79, 33, 15, 17, 64, 6, 52, 0, 1, 24, 62, 7, 8, 5, 50, 31, 16, 11, 48, 7, 27, 38, 46, 15, 0, 1, 18, 33, 7, 4, 25, 75, 15, 8, 79, 33, 34, 0, 17, 79, 57, 31, 4, 11, 33, 19, 8, 18, 48, 7, 10, 38, 48, 31, 13, 11, 33, 66, 0, 16, 38, 50, 15, 17, 11, 55, 7, 10, 25, 75, 28, 0, 8, 2, 3, 17, 11, 50, 19, 17, 11, 33, 7, 74, 25, 19, 0, 10, 55, 41, 0, 16, 11, 33, 31, 8, 2, 46, 41, 17, 2, 46, 7, 17, 25, 50, 19, 74, 25, 48,

[0, 1, 2, 197, 67, 11, 26, 52, 23, 25, 86, 66, 8, 2, 169, 13, 2, 114, 17, 79, 61, 7, 78, 79, 57, 22, 17, 11, 26, 22, 17, 2, 177, 90, 18, 9, 15, 74, 11, 73, 15, 0, 1, 24, 73, 15, 8, 79, 50, 15, 0, 1, 43, 176, 8, 11, 73, 15, 13, 2, 115, 17, 79, 57, 15, 90, 51, 84, 15, 0, 1, 43, 7, 10, 5, 62, 31, 8, 64, 9, 7, 90, 32, 73, 31, 74, 5, 9, 15, 0, 13, 11, 62, 31, 17, 38, 9, 15, 0, 1, 64, 73, 31, 8, 2, 158, 13, 25, 9, 41, 17, 38, 61, 15, 0, 17, 2, 114, 1, 2, 169, 17, 79, 61, 52, 74, 25, 21, 31, 0, 67, 79, 61, 47, 10, 38, 12, 15, 17, 2, 116, 72, 24, 21, 15, 91, 64, 62, 7, 13, 24, 73, 15, 70, 11, 48, 31, 17, 2, 148, 13, 38, 48, 15, 90, 24, 50, 22, 27, 11, 48, 31, 74, 18, 50, 47, 0, 1, 2, 148, 1, 25, 9, 22, 67, 79, 62, 7, 23, 32, 9, 66, 8, 79, 61, 7, 17, 25, 9, 22, 0, 13, 79, 21, 31, 17, 2, 186, 17, 25, 73, 22, 74, 79, 57, 52, 0, 13, 2, 169, 17, 79, 61, 31]
105
[0, 1, 2, 150, 1, 79, 50, 22, 67, 51, 50, 22, 4, 64, 6, 15, 23, 32, 9, 22, 8, 11, 9, 22, 72, 25, 26, 41, 0, 72, 64, 33, 22, 10, 25, 26, 22,

In [34]:
import mido
# Stitch intro + predicted bridge + outro of the interpolation experiment into one
# multi-track file per sample (merged_predict<i>.mid).
for i in range(len(test_intro)):
    intro = mido.MidiFile(dissimilar_interpolation + "/intro/" + '/intro' + str(i) + '.mid')
    outro = mido.MidiFile(dissimilar_interpolation + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(dissimilar_interpolation + "/predict/" +'/predict' + str(i) + '.mid')

    # Only note_on messages contribute to the accumulated tick counts.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = 0  # no ground-truth solo is loaded in this experiment
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    original_outro_time = 0 + outro.tracks[1][1].time

    # Delay the prediction by the intro length and the outro by intro + prediction,
    # by editing the message at index 1 of each track 1.
    outro_start = original_outro_time + total_predict_time + total_intro_time
    print(outro_start)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = outro_start
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(dissimilar_interpolation + '/merged_predict' + str(i) + '.mid')

36300
36300
36960
36960
26280
26280
43320
43320
25140
25140
19680
19680
37680
37680
34440
34440
24480
24480
46140
46140
27180
27180
38460
38460
32280
32280
33060
33060
31080
31080
45900
45900
11580
11580
34020
34020
25260
25260
64320
64320
37260
37260
24720
24720
27960
27960
33420
33420
26520
26520
34740
34740
25740
25740
43500
43500
35100
35100
18000
18000
33060
33060
25560
25560
13500
13500
30960
30960
59100
59100
21960
21960
39840
39840
46080
46080
34920
34920
28560
28560
21240
21240
25800
25800
27480
27480
19800
19800
26880
26880
21180
21180
22380
22380
30120
30120
38700
38700
24840
24840
27780
27780
29220
29220
27600
27600
28140
28140
27000
27000
26220
26220
23280
23280
24300
24300
25620
25620
30720
30720
22560
22560
27780
27780
23340
23340
25200
25200
23760
23760
30540
30540
37620
37620
34260
34260
25440
25440
21720
21720
35220
35220
18600
18600
24480
24480
36960
36960
20280
20280
52920
52920
40560
40560
25560
25560
24420
24420
46860
46860
21180
21180
17760
17760
39720
39720
3294

In [None]:
class BeamSearchNode:
    """A single step of a beam-search hypothesis, linked to its predecessor.

    Attributes are stored exactly as passed in:
        prev_node: the preceding node, or None for the start node.
        wid:       the token id chosen at this step.
        logp:      the accumulated score of the hypothesis up to this step.
        length:    the number of steps in the hypothesis, this one included.
    """

    def __init__(self, prev_node, wid, logp, length):
        self.prev_node, self.wid = prev_node, wid
        self.logp, self.length = logp, length

    def eval(self):
        """Return the accumulated score divided by (length - 1).

        The 1e-6 term keeps the division defined for a start node of length 1.
        """
        denominator = float(self.length - 1 + 1e-6)
        return self.logp / denominator
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    """Decode `sentence` with a best-first beam search and return the best hypothesis.

    Args:
        model: callable invoked as ``model(src, trg)`` with LongTensors of shape
            (src_len, 1) and (trg_len, 1). Its result is read as ``output[-1, 0]``,
            i.e. it is assumed to be (trg_len, 1, vocab), matching how the greedy
            decoder in this notebook indexes it -- TODO confirm for other models.
        sentence: source sequence as a space-separated token string.
        german: source field; provides ``init_token``, ``eos_token`` and ``vocab.stoi``.
        english: target field; provides ``vocab.stoi`` and ``vocab.itos``.
        device: device the input tensors are moved to.
        max_length: maximum number of generated tokens per hypothesis; a hypothesis
            that reaches it is closed without further expansion.
        beam_width: number of continuations registered per expanded node, and the
            number of finished hypotheses to collect before stopping.
        max_dec_steps: budget of registered continuation nodes; the search is
            abandoned once it is exceeded.

    Returns:
        The best hypothesis as a list of target-vocabulary token strings, including
        the leading "<sos>" (and the trailing "<eos>" when one was produced).
        An empty list if no hypothesis could be built.
    """
    def backtrace(node):
        # Walk the parent links up to the start node, then flip into decoding order.
        ids = [node.wid.item()]
        while node.prev_node is not None:
            node = node.prev_node
            ids.append(node.wid.item())
        return ids[::-1]

    # Lower-case the tokens (as the vocabulary is) and wrap them in <sos> ... <eos>.
    tokens = [token.lower() for token in sentence.split(' ')]
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]

    # Source tokens -> vocabulary indices -> (src_len, 1) tensor.
    text_to_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    # The search starts from a single node holding <sos>.
    start_node = BeamSearchNode(prev_node=None,
                                wid=torch.LongTensor([sos_token]).to(device),
                                logp=0,
                                length=1)

    # Min-heap keyed on the negated normalised score, so the best node pops first;
    # id(node) breaks ties so nodes themselves are never compared.
    nodes = []
    heappush(nodes, (-start_node.eval(), id(start_node), start_node))
    end_nodes = []
    n_dec_steps = 0

    while nodes:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break

        # Fetch the best node
        score, _, n = heappop(nodes)

        reached_eos = n.wid.item() == eos_token and n.prev_node is not None
        # n.length counts <sos> too, hence the strict comparison.
        reached_limit = n.length > max_length
        if reached_eos or reached_limit:
            end_nodes.append((score, id(n), n))
            # Stop once the required number of hypotheses is finished.
            if len(end_nodes) >= beam_width:
                break
            continue

        sequence = backtrace(n)

        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))
            # Score the NEXT token from the last decoder position (not the first),
            # normalised to log-probabilities so that summing along a path is
            # meaningful. log_softmax leaves already-normalised log-probs unchanged.
            step_log_probs = torch.log_softmax(output[-1, 0], dim=-1)

        # Register the top-k continuations of this node.
        topk_log_prob, topk_indexes = torch.topk(step_log_probs, beam_width)
        for new_k in range(beam_width):
            decoded_t = topk_indexes[new_k].view(1) # (1)
            logp = topk_log_prob[new_k].item() # float log probability val

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width

    # If nothing finished, fall back to the best open nodes (probably truncated).
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(min(beam_width, len(nodes)))]

    # Rebuild the token-id sequences, best score first.
    n_best_seq_list = [backtrace(n) for score, _id, n in sorted(end_nodes, key=lambda x: x[0])]
    if not n_best_seq_list:
        return []

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    """Pickle `vocab` to the file at `path`, overwriting any existing file.

    Args:
        vocab: the object to serialise (any picklable object).
        path: destination file path; opened in binary write mode.
    """
    # The context manager closes the file even if pickling fails part-way.
    with open(path, 'wb') as output:
        pickle.dump(vocab, output)

In [None]:
# Persist the vocabulary of each field with save_vocab.
# NOTE(review): vocab_folder is assigned here but not used by the calls below; the files
# are written under `vocab` (the directory path built in the notebook's first cell).
vocab_folder = "vocab/"
# NOTE(review): solo_field / outro_field must already exist in the kernel -- the
# outro_field definition visible in the Fields cell is commented out; confirm.
save_vocab(intro_field.vocab, vocab + '/intro_vocab.pkl')
save_vocab(solo_field.vocab, vocab  + '/solo_vocab.pkl')
save_vocab(outro_field.vocab, vocab + '/outro_vocab.pkl')

In [None]:
def load_vocab(path):
    """Return the object unpickled from the file at `path`.

    Args:
        path: file path, opened in binary read mode.

    Returns:
        Whatever object was pickled into the file.
    """
    # NOTE: unpickling can execute arbitrary code -- only load files you trust.
    with open(path, 'rb') as handle:
        return pickle.load(handle)

In [None]:
# Restore previously pickled vocabularies into the three fields via load_vocab.
# NOTE(review): this reads from "vocab_intro/", whereas the save cell in this notebook
# writes to `vocab + '/..._vocab.pkl'` -- confirm this is the intended directory.
vocab_folder = "vocab_intro/"
intro_field.vocab = load_vocab(vocab_folder + 'intro_vocab.pkl')
solo_field.vocab = load_vocab(vocab_folder + 'solo_vocab.pkl')
outro_field.vocab = load_vocab(vocab_folder + 'outro_vocab.pkl')

In [None]:
# Restore the intro vocabulary from its pickle file.
# NOTE(review): this cell was left unfinished (bare `with open()` and a dangling
# assignment); it is completed here following the neighbouring load cell -- confirm
# that "vocab/" is the directory the pickle actually lives in.
vocab_folder = "vocab/"
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(vocab_folder + 'intro_vocab.pkl', 'rb') as f:
    intro_field.vocab = pickle.load(f)

In [None]:
beam = [2, 59, 119, 13, 212, 59, 212, 59, 59, 75, 59, 59, 13, 119, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 119, 59, 59, 59, 59, 166, 59, 59, 59, 13, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 59, 59, 212, 212, 13, 59, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 14, 59, 59, 212, 59, 212, 212, 59, 59, 59, 68, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 68, 59, 59, 212, 59, 59, 13, 59, 59, 59, 59, 59, 59, 97, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 158, 59, 59, 13, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 13, 212, 13, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 158, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 13, 13, 59, 59, 13, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 13, 59, 166, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 86, 59, 212, 212, 212, 59, 59, 59, 59, 212, 59, 86, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 
13, 59, 59, 59, 212, 59, 212, 59, 59, 166, 59, 86, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 59, 59, 59, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 212, 212, 158, 59, 59, 13, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 158, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 212, 212, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 13, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 212, 86, 59, 59, 59, 158, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 212, 59, 13, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 68, 59, 13, 59, 59, 13, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 14, 59, 59, 59, 59, 59, 59, 13, 86, 59, 59, 59, 212, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 13, 59, 59, 59, 59, 68, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 212, 212, 13, 59, 166, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 212, 59, 59, 212, 59, 212, 59, 59, 13, 59, 59, 59, 59, 13, 59, 59, 14, 13, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 
59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 68, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 212, 59, 59, 13, 59, 59, 166, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 13, 59, 59, 212, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59]

# Vocabulary indices -> token strings.
translated_sentence1 = [solo_field.vocab.itos[token_id] for token_id in beam]
# Drop the special tokens; the remaining tokens are integer ids in string form.
translated_sentence = [
    int(token)
    for token in translated_sentence1
    if token not in ('<pad>', '<sos>', '<eos>', '<unk>')
]
utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict1.mid")

In [None]:
# Bare expression: shows the current value of translated_sentence as the cell output.
translated_sentence

In [None]:
# Print the event sequence (via word2event) of every test intro.
for i in range(len(test_intro)):
    # Take the i-th intro explicitly instead of relying on a `sentence` variable
    # left over in the kernel from another cell.
    sentence = test_intro[i]
    list_sentence = [int(x) for x in sentence.split(' ')]
    # Skip over-long sequences: the limit applies to this sequence's token count,
    # not to the number of test samples.
    if len(list_sentence) > 1200:
        continue
    remi = [word2event[x] for x in list_sentence]
    print(remi)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):
    """Greedily decode a target sequence for an already-indexed source sequence.

    Unlike a tokenising translate helper, `sentence` is fed to the model as-is:
    it must be a sequence of integers, and no <sos>/<eos> is added around it.

    Args:
        model: callable invoked as ``model(src, trg)`` with LongTensors of shape
            (src_len, 1) and (trg_len, 1); its result is reduced with ``argmax(2)``.
        sentence: sequence of integer source indices.
        german: unused; kept for signature compatibility.
        english: target field providing ``vocab.stoi`` and ``vocab.itos``.
        device: device the tensors are moved to.
        max_length: maximum number of decoding steps.

    Returns:
        List of target token strings, starting with "<sos>" and ending with
        "<eos>" when the model emitted it within `max_length` steps.
    """
    target_vocab = english.vocab
    eos_index = target_vocab.stoi["<eos>"]

    source = torch.LongTensor(sentence).unsqueeze(1).to(device)
    decoded = [target_vocab.stoi["<sos>"]]

    for _ in range(max_length):
        # Re-feed everything decoded so far as a (trg_len, 1) tensor.
        target_so_far = torch.LongTensor(decoded).unsqueeze(1).to(device)

        with torch.no_grad():
            scores = model(source, target_so_far)

        # The prediction at the last position is the next token.
        next_index = scores.argmax(dim=2)[-1].item()
        decoded.append(next_index)

        if next_index == eos_index:
            break

    return [target_vocab.itos[index] for index in decoded]


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    """Compute the corpus BLEU score of greedy predictions against the reference solos.

    Args:
        data: iterable of examples; ``vars(example)`` must expose "intro" and "solo"
            token sequences whose items are convertible with ``int``.
        model: model passed through to ``bleu_translate_sentence``.
        german: source field, passed through to ``bleu_translate_sentence``.
        english: target field, passed through to ``bleu_translate_sentence``.
        device: device used for decoding.

    Returns:
        The value returned by ``bleu_score`` over all examples whose source and
        target are at most 1200 tokens long.
    """
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        fields = vars(example)
        # NOTE(review): the integer tokens are fed to the model directly as indices
        # (no vocab.stoi lookup) -- confirm that is what the model was trained on.
        src = [int(x) for x in fields["intro"]]
        trg = [int(x) for x in fields["solo"]]

        # Skip examples that exceed the decoding limit.
        if len(trg) > 1200 or len(src) > 1200:
            continue

        prediction = bleu_translate_sentence(model, src, german, english, device)
        # Strip the decoder's framing tokens; <eos> is absent when decoding hit
        # max_length, so only remove it when it is really there.
        if prediction and prediction[0] == "<sos>":
            prediction = prediction[1:]
        if prediction and prediction[-1] == "<eos>":
            prediction = prediction[:-1]

        # bleu_score expects, per candidate, a LIST of reference token lists, and
        # tokens must be comparable with the (string) predicted tokens.
        targets.append([[str(token) for token in trg]])
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while
score = bleu(test[1:10], model, intro_field, solo_field, device)
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
# Plot the training and validation loss curves stored in metrics.pt.
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
for curve_label, curve_values in (('Train', train_loss_list), ('Valid', valid_loss_list)):
    # Both curves share the same x axis (global step).
    plt.plot(global_steps_list, curve_values, label=curve_label)
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns