In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
# Input dataset directory and the root under which every artefact of this
# two-encoder experiment is written.
source_folder = "solo_generation_dataset_augmented_presplit"
folder = "dynamic_augmented_models/2enc"

# Sub-directories of `folder`: checkpoints, generated samples,
# interpolation outputs and vocabulary files.
destination_folder = f"{folder}/solo_generation_weights"
generated_outputs = f"{folder}/generated_samples"
dissimilar_interpolation = f"{folder}/interpolation"
vocab = f"{folder}/vocab"

In [2]:
from pathlib import Path

# Create the whole output tree up front (idempotent thanks to exist_ok=True).
# The order matches the original cell: top-level folders first, then the
# per-part folders for generated samples and for interpolation outputs.
for _output_dir in (
    destination_folder,
    generated_outputs,
    dissimilar_interpolation,
    vocab,
    generated_outputs+"/intro",
    generated_outputs+"/outro",
    generated_outputs+"/solo",
    generated_outputs+"/predict",
    dissimilar_interpolation+"/intro",
    dissimilar_interpolation+"/outro",
    dissimilar_interpolation+"/predict",
):
    Path(_output_dir).mkdir(parents=True, exist_ok=True)

In [3]:
# Load the (event2word, word2event) pair stored in the pickle.
# The file is opened in a `with` block so the handle is closed deterministically
# (the previous bare `open(...)` left it to the garbage collector).
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('dictionary_augmented.pkl', 'rb') as dictionary_file:
    event2word, word2event = pickle.load(dictionary_file)

In [4]:
# Pick the compute device.  The notebook prefers the second GPU ("cuda:1"),
# but that index only exists when at least two GPUs are visible; on a
# single-GPU machine fall back to "cuda:0" instead of failing later when the
# first tensor is moved to a non-existent device.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    dev = f"cuda:{gpu_index}"
else:
    dev = "cpu"
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [5]:
# Fields
# All six CSV columns share one Field configuration, so build them through a
# single factory instead of repeating the constructor call.
def _make_sequence_field():
    """Return a Field configured like every column of this dataset.

    Sequences are batch-first, come with their lengths, and are wrapped in
    "<sos>" / "<eos>" markers.
    """
    return Field(tokenize=None, lower=True, include_lengths=True, batch_first=True,
                 init_token="<sos>", eos_token="<eos>")


intro_field = _make_sequence_field()
intro_piano_field = _make_sequence_field()
outro_field = _make_sequence_field()
outro_piano_field = _make_sequence_field()
solo_field = _make_sequence_field()
solo_piano_field = _make_sequence_field()

# (column name, Field) pairs in CSV column order.
fields = [
    ('intro', intro_field),
    ('intro_piano', intro_piano_field),
    ('outro', outro_field),
    ('outro_piano', outro_piano_field),
    ('solo', solo_field),
    ('solo_piano', solo_piano_field),
]

# TabularDataset
# NOTE(review): the names `train` and `test` bound here are datasets; `train`
# is rebound to a function by the later `def train(...)` cell.
train, valid, test = TabularDataset.splits(
    path=source_folder,
    train='train_torchtext.csv',
    validation='val_torchtext.csv',
    test='test_torchtext.csv',
    format='CSV',
    fields=fields,
    skip_header=True,
)

# Iterators
BATCH_SIZE = 8


def _make_iterator(dataset):
    """Wrap `dataset` in a BucketIterator keyed on intro length."""
    return BucketIterator(dataset, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                          device=device, sort=False, sort_within_batch=True)


train_iter = _make_iterator(train)
valid_iter = _make_iterator(valid)
test_iter = _make_iterator(test)

# Vocabulary
# Built from the training split only; the piano columns use a higher
# minimum frequency (3) than the melody columns (1).
for _field, _min_freq in (
    (intro_field, 1),
    (intro_piano_field, 3),
    (outro_field, 1),
    (outro_piano_field, 3),
    (solo_field, 1),
    (solo_piano_field, 3),
):
    _field.build_vocab(train, min_freq=_min_freq)

In [6]:
# Sanity check: print the padded solo length (dimension 1 of the batch-first
# tensor) of every test batch.  Only the solo tensor is needed, so the other
# five (tensor, lengths) pairs and the target slot are discarded.
for (_, _, _, _, (solo_batch, _), _), _ in test_iter:
    print(solo_batch.size(1))

325
319
414
253
509
335
522
281
272
619
326
353
279
444


In [7]:
# Debugging switches for CUDA execution.
import os
# NOTE(review): CUDA_LAUNCH_BLOCKING is normally read when CUDA is initialised;
# an earlier cell already iterates `test_iter` on `device`, so on a GPU run this
# assignment may come too late to take effect -- confirm, or move it before the
# first CUDA use.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Disable cuDNN for the remainder of the session.
torch.backends.cudnn.enabled=False

In [8]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

In [9]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    """Sequence-to-sequence model with two source inputs and one target.

    Each of the three token streams (``src``, ``src2``, ``trg``) gets its own
    learned word embedding and learned position embedding.  The summed
    embeddings are passed to the project-local ``kk.Transformer``
    (``twoencodertransformer``) and its output is projected to target-vocabulary
    logits by ``fc_out``.

    Args:
        embedding_size: Width of every embedding and of the inner transformer.
        src_vocab_size: Vocabulary size of the first source stream.
        src2_vocab_size: Vocabulary size of the second source stream.
        trg_vocab_size: Vocabulary size of the target stream.
        src_pad_idx: Token id treated as padding in BOTH source streams.
        num_heads, num_encoder_layers, num_decoder_layers, forward_expansion,
        dropout: Forwarded positionally to ``kk.Transformer``; ``dropout`` is
            also applied to the summed embeddings.
        max_len: Number of rows in each position-embedding table, i.e. the
            longest sequence that can be embedded.
        device: Device on which position indices and masks are created.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        src2_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        # Bug fix: this table used to be sized with `src_vocab_size`, so any
        # second-source token id >= src_vocab_size would index out of range.
        # NOTE: checkpoints saved with the old sizing only load if both source
        # vocabularies happen to have the same size.
        self.src2_word_embedding = nn.Embedding(src2_vocab_size, embedding_size)
        self.src2_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = kk.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a boolean padding mask of shape (N, src_len) for ``src``.

        ``src`` is (src_len, N); entries equal to ``src_pad_idx`` are True.
        """
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def _positions(self, seq_length, batch_size):
        """Return a (seq_length, batch_size) tensor whose rows are 0..seq_length-1."""
        return (
            torch.arange(0, seq_length)
            .unsqueeze(1)
            .expand(seq_length, batch_size)
            .to(self.device)
        )

    def forward(self, src, src2, trg):
        """Compute target-vocabulary logits for ``trg`` given both sources.

        Args:
            src: Token ids of the first source, shape (src_len, N).
            src2: Token ids of the second source, shape (src2_len, N).
            trg: Token ids of the (shifted) target, shape (trg_len, N).

        Returns:
            ``fc_out`` applied to the inner transformer's output; presumably
            shaped (trg_len, N, trg_vocab_size) -- depends on ``kk.Transformer``.
        """
        src_seq_length, N = src.shape
        src2_seq_length, N = src2.shape
        trg_seq_length, N = trg.shape

        src_positions = self._positions(src_seq_length, N)
        src2_positions = self._positions(src2_seq_length, N)
        trg_positions = self._positions(trg_seq_length, N)

        # Word + position embeddings, with dropout, for each stream.
        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        ).to(self.device)
        embed_src2 = self.dropout(
            (self.src2_word_embedding(src2) + self.src2_position_embedding(src2_positions))
        ).to(self.device)
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        ).to(self.device)

        # Both sources are masked with the same pad index.
        src_padding_mask = self.make_src_mask(src)
        src2_padding_mask = self.make_src_mask(src2)
        # Square subsequent mask over the target positions.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_src2,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            src2_key_padding_mask=src2_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [10]:
# Vocabulary sizes come from the fields built on the training split.
src_vocab_size = len(intro_field.vocab)
src2_vocab_size = len(outro_field.vocab)
trg_vocab_size = len(solo_field.vocab)

# Architecture hyper-parameters.
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
forward_expansion = 4
dropout = 0.10
max_len = 1200  # rows in each position-embedding table
src_pad_idx = 1 #english.vocab.stoi["<pad>"]

# Arguments are passed by keyword so the mapping to the constructor is
# explicit; the module is moved to `device` straight away.
model = Transformer(
    embedding_size=embedding_size,
    src_vocab_size=src_vocab_size,
    src2_vocab_size=src2_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_len=max_len,
    device=device,
).to(device)

In [11]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter reachable from ``m`` in place.

    Parameters whose name contains ``'weight'`` are drawn from N(0, 0.01);
    all remaining parameters (e.g. biases) are set to zero.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)


# `init_weights` already walks every parameter of the module it receives
# (named_parameters() is recursive), so one direct call initialises the whole
# model.  `model.apply(init_weights)` invoked it once per sub-module, which
# re-initialised the same tensors repeatedly before the final, identical pass.
init_weights(model)

optimizer = optim.Adam(model.parameters(), lr=2e-4) #non augmented 3e-4


def count_parameters(model: nn.Module):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total


# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    """Write ``state`` to ``<destination_folder>/metrics.pt``, overwriting it.

    ``nth`` and ``filename`` are accepted for signature parity with
    ``save_final_checkpoint`` but are not used: the best checkpoint always
    goes to the same fixed file.
    """
    print("=> Saving checkpoint")
    best_path = f"{destination_folder}/metrics.pt"
    torch.save(state, best_path)

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    """Write ``state`` to ``<destination_folder>/<nth><filename>``.

    With the default ``filename`` this yields e.g. ``.../500_checkpoint.pt``.
    """
    print("=> Saving checkpoint")
    target_path = f"{destination_folder}/{nth}{filename}"
    torch.save(state, target_path)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint dict.

    ``checkpoint`` must provide the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``; the model is restored first.
    """
    print("=> Loading checkpoint")
    restore_plan = (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    )
    for target, state_key in restore_plan:
        target.load_state_dict(checkpoint[state_key])

The model has 14,997,275 trainable parameters


In [12]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input int get token/str
# intro_field.vocab.itos[4]

In [13]:
# Token id excluded from the loss.
# NOTE(review): hard-coded; assumed to equal solo_field.vocab.stoi["<pad>"] -- confirm.
PAD_IDX = 1

# Cross-entropy over target tokens; positions whose target is PAD_IDX are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [14]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        model: Called as ``model(src, src2, trg_in)`` with time-major
            ``(seq_len, N)`` id tensors; must return logits whose last
            dimension is the target vocabulary.
        iterator: Sized iterable of batches.  Each batch unpacks to
            ``(inputs, _)`` where ``inputs`` holds six ``(tensor, lengths)``
            pairs in the order intro, intro_piano, outro, outro_piano, solo,
            solo_piano, with batch-first tensors.  (The notebook passes a
            torchtext ``BucketIterator``, despite the annotation.)
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss taking ``(logits_2d, targets_1d)``.
        clip: Max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Sum of per-batch losses divided by ``len(iterator)``.
    """
    model.train()

    # Use the device the model actually lives on instead of relying on a
    # notebook-global `device` that may be stale or undefined.
    model_device = next(model.parameters()).device

    epoch_loss = 0

    # Only the intro, outro and solo tensors are used; lengths and the piano
    # columns are discarded.
    for ((intro, _), _, (outro, _), _, (solo, _), _), _ in iterator:
        # Batch-first -> time-major, as the model expects.
        src = intro.transpose(1, 0).to(model_device)
        src2 = outro.transpose(1, 0).to(model_device)
        trg = solo.transpose(1, 0).to(model_device)

        optimizer.zero_grad()
        # Teacher forcing: feed the target without its last token ...
        output = model(src, src2, trg[:-1, :])

        # ... and score against the target without its first token.
        # reshape (not view) so non-contiguous logits are handled too.
        output = output.reshape(-1, output.shape[-1])
        target = trg[1:].reshape(-1)
        loss = criterion(output, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):
    """Return the mean batch loss of ``model`` over ``iterator`` without training.

    The model is switched to eval mode and gradients are disabled.  Scoring is
    still teacher-forced: the decoder receives the ground-truth target shifted
    by one position, exactly as in ``train``.

    Args:
        model: Called as ``model(src, src2, trg_in)`` with time-major tensors.
        iterator: Sized iterable of batches laid out as in ``train`` (six
            ``(tensor, lengths)`` pairs plus a target slot).
        criterion: Loss taking ``(logits_2d, targets_1d)``.

    Returns:
        Sum of per-batch losses divided by ``len(iterator)``.
    """
    model.eval()

    # Use the model's own device rather than a notebook-global `device`.
    model_device = next(model.parameters()).device

    epoch_loss = 0

    with torch.no_grad():
        for ((intro, _), _, (outro, _), _, (solo, _), _), _ in iterator:
            # Batch-first -> time-major.
            src = intro.transpose(1, 0).to(model_device)
            src2 = outro.transpose(1, 0).to(model_device)
            trg = solo.transpose(1, 0).to(model_device)

            # Ground-truth prefix in, next-token targets out.
            output = model(src, src2, trg[:-1, :])

            output = output.reshape(-1, output.shape[-1])
            target = trg[1:].reshape(-1)

            loss = criterion(output, target)

            epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    """Split the span ``end_time - start_time`` (seconds) into whole minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated toward zero.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - (whole_minutes * 60))
    return whole_minutes, leftover_seconds



In [15]:
def translate_sentence(model, sentence, sentence2, intro, outro, solo, device, max_length=1200):
    """Greedily decode a target token sequence from two source strings.

    Args:
        model: Called as ``model(src, src2, trg)`` with ``(seq_len, 1)`` id
            tensors; must return logits indexed ``[time, batch, vocab]``.
        sentence: Space-separated tokens of the first source.
        sentence2: Space-separated tokens of the second source.
        intro, outro, solo: Field-like objects for the first source, second
            source and target.  ``init_token``, ``eos_token`` and
            ``vocab.stoi`` are read from the source fields; ``vocab.stoi`` and
            ``vocab.itos`` from the target field.
        device: Device on which the id tensors are created.
        max_length: Maximum number of decoding steps.

    Returns:
        List of target token strings.  It always starts with ``"<sos>"`` and
        ends with ``"<eos>"`` only if the model produced it within
        ``max_length`` steps.  The start token is NOT stripped.
    """

    def _encode(text, field):
        # Lower-case the tokens, wrap them in the field's start/end markers
        # and return a (seq_len, 1) id tensor on `device`.
        tokens = [token.lower() for token in text.split(' ')]
        tokens.insert(0, field.init_token)
        tokens.append(field.eos_token)
        indices = [field.vocab.stoi[token] for token in tokens]
        return torch.LongTensor(indices).unsqueeze(1).to(device)

    sentence_tensor = _encode(sentence, intro)
    sentence2_tensor = _encode(sentence2, outro)

    eos_index = solo.vocab.stoi["<eos>"]
    outputs = [solo.vocab.stoi["<sos>"]]

    # Decode in eval mode so dropout cannot perturb the greedy choice, then
    # restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_length):
            trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

            with torch.no_grad():
                output = model(sentence_tensor, sentence2_tensor, trg_tensor)

            # Most likely token at the last time step.
            best_guess = output.argmax(2)[-1, :].item()
            outputs.append(best_guess)

            if best_guess == eos_index:
                break
    finally:
        model.train(was_training)

    translated_sentence = [solo.vocab.itos[idx] for idx in outputs]

    return translated_sentence


In [16]:
# Raw validation rows, kept as strings for free-running generation checks.
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro, val_solo, val_outro = (
    df_intro[column].values for column in ('intro', 'solo', 'outro')
)

# One {'intro', 'solo', 'outro'} record per CSV row.
val_data = [
    {'intro': intro_text, 'solo': solo_text, 'outro': outro_text}
    for intro_text, solo_text, outro_text in zip(val_intro, val_solo, val_outro)
]
print(len(val_intro))

112


In [17]:
def check_mode_collapse(model):
    """Decode the first validation examples and count identical neighbours.

    Up to three validation rows are decoded greedily with
    ``translate_sentence``; special tokens are dropped and the rest are
    converted to ints and printed.  The return value is the number of decoded
    sequences equal to the previously decoded one (0..2).

    Fixes over the previous version:
      * The length guard compared ``len(val_intro)`` (number of rows) with
        1200; it now checks the token count of the individual sources.
      * Skipping a row no longer desynchronises the comparison, which used to
        index ``translations`` by row number.
      * Fewer than three validation rows no longer raises ``IndexError``.
    """
    special_tokens = ('<pad>', '<sos>', '<eos>', '<unk>')
    count = 0
    translations = []
    for i in range(min(3, len(val_intro))):
        intro = val_intro[i]
        outro = val_outro[i]
        # The position-embedding tables are built with max_len = 1200 and
        # translate_sentence adds <sos>/<eos>, hence the "+ 2".
        # NOTE(review): the intended threshold is inferred from the original
        # `> 1200` guard -- confirm.
        if len(intro.split(' ')) + 2 > 1200 or len(outro.split(' ')) + 2 > 1200:
            continue
        translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)

        translated_sentence = [int(x) for x in translated_sentence if x not in special_tokens]
        print(translated_sentence)
        # Compare against the most recently decoded sequence, if any.
        if translations and translations[-1] == translated_sentence:
            count += 1
        translations.append(translated_sentence)
    return count


In [18]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1
best_valid_loss = float('inf')
#torch.autograd.set_detect_anomaly(True)
#model = nn.DataParallel(model, device_ids=[0,1]).to(device)


def _make_checkpoint(current_valid_loss):
    """Bundle the model/optimizer state as of now with the given validation loss."""
    return {'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'valid_loss': current_valid_loss}


for epoch in range(S_EPOCH, N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    # Best-so-far checkpoint (always written to the same file).
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint = _make_checkpoint(valid_loss)
        save_best_checkpoint(checkpoint,N_EPOCHS)

    # Periodic checkpoint.  A fresh dict is built here: the previous code
    # reused the dict created at the last validation improvement, whose
    # `valid_loss` (and optimizer snapshot) no longer described this epoch,
    # while state_dict() tensors reference the live parameters.
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        checkpoint = _make_checkpoint(valid_loss)
        save_final_checkpoint(checkpoint,epoch)

    # Every 25 epochs, check whether free-running generation has collapsed.
    if (epoch+1) % 25 ==0:
        if check_mode_collapse(model) > 1:
            print("model is mode collapsing")

# Final checkpoint describing the model as it is after the last epoch.
checkpoint = _make_checkpoint(valid_loss)
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 1m 1s
	Train Loss: 4.279 | Train PPL:  72.184
	 Val. Loss: 3.425 |  Val. PPL:  30.728
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 1m 3s
	Train Loss: 3.018 | Train PPL:  20.445
	 Val. Loss: 2.869 |  Val. PPL:  17.624
=> Saving checkpoint
Epoch: 03 | Time: 1m 4s
	Train Loss: 2.688 | Train PPL:  14.703
	 Val. Loss: 2.574 |  Val. PPL:  13.122
=> Saving checkpoint
Epoch: 04 | Time: 1m 4s
	Train Loss: 2.421 | Train PPL:  11.262
	 Val. Loss: 2.359 |  Val. PPL:  10.583
=> Saving checkpoint
Epoch: 05 | Time: 1m 3s
	Train Loss: 2.248 | Train PPL:   9.468
	 Val. Loss: 2.252 |  Val. PPL:   9.510
=> Saving checkpoint
Epoch: 06 | Time: 1m 4s
	Train Loss: 2.119 | Train PPL:   8.323
	 Val. Loss: 2.164 |  Val. PPL:   8.707
=> Saving checkpoint
Epoch: 07 | Time: 1m 3s
	Train Loss: 2.011 | Train PPL:   7.469
	 Val. Loss: 2.105 |  Val. PPL:   8.209
=> Saving checkpoint
Epoch: 08 | Time: 1m 3s
	Train Loss: 1.927 | Train PPL:   6.868
	 Val. Loss: 2.055 |  Val. PPL:   7.806


Epoch: 40 | Time: 1m 4s
	Train Loss: 1.001 | Train PPL:   2.721
	 Val. Loss: 2.982 |  Val. PPL:  19.724
=> Saving checkpoint
Epoch: 41 | Time: 1m 3s
	Train Loss: 0.990 | Train PPL:   2.691
	 Val. Loss: 3.072 |  Val. PPL:  21.575
=> Saving checkpoint
Epoch: 42 | Time: 1m 4s
	Train Loss: 0.968 | Train PPL:   2.634
	 Val. Loss: 3.038 |  Val. PPL:  20.867
Epoch: 43 | Time: 1m 4s
	Train Loss: 0.951 | Train PPL:   2.589
	 Val. Loss: 3.092 |  Val. PPL:  22.018
Epoch: 44 | Time: 1m 4s
	Train Loss: 0.935 | Train PPL:   2.548
	 Val. Loss: 3.108 |  Val. PPL:  22.385
Epoch: 45 | Time: 1m 3s
	Train Loss: 0.921 | Train PPL:   2.511
	 Val. Loss: 3.192 |  Val. PPL:  24.339
Epoch: 46 | Time: 1m 4s
	Train Loss: 0.905 | Train PPL:   2.471
	 Val. Loss: 3.191 |  Val. PPL:  24.323
Epoch: 47 | Time: 1m 3s
	Train Loss: 0.891 | Train PPL:   2.438
	 Val. Loss: 3.254 |  Val. PPL:  25.894
Epoch: 48 | Time: 1m 3s
	Train Loss: 0.875 | Train PPL:   2.398
	 Val. Loss: 3.253 |  Val. PPL:  25.861
Epoch: 49 | Time: 1m 3

Epoch: 88 | Time: 1m 4s
	Train Loss: 0.553 | Train PPL:   1.738
	 Val. Loss: 4.265 |  Val. PPL:  71.176
Epoch: 89 | Time: 1m 3s
	Train Loss: 0.547 | Train PPL:   1.729
	 Val. Loss: 4.433 |  Val. PPL:  84.195
Epoch: 90 | Time: 1m 3s
	Train Loss: 0.543 | Train PPL:   1.722
	 Val. Loss: 4.350 |  Val. PPL:  77.478
Epoch: 91 | Time: 1m 3s
	Train Loss: 0.539 | Train PPL:   1.715
	 Val. Loss: 4.430 |  Val. PPL:  83.941
Epoch: 92 | Time: 1m 3s
	Train Loss: 0.533 | Train PPL:   1.704
	 Val. Loss: 4.461 |  Val. PPL:  86.596
Epoch: 93 | Time: 1m 4s
	Train Loss: 0.530 | Train PPL:   1.698
	 Val. Loss: 4.493 |  Val. PPL:  89.403
Epoch: 94 | Time: 1m 3s
	Train Loss: 0.525 | Train PPL:   1.691
	 Val. Loss: 4.442 |  Val. PPL:  84.986
Epoch: 95 | Time: 1m 3s
	Train Loss: 0.520 | Train PPL:   1.683
	 Val. Loss: 4.566 |  Val. PPL:  96.138
Epoch: 96 | Time: 1m 3s
	Train Loss: 0.518 | Train PPL:   1.679
	 Val. Loss: 4.545 |  Val. PPL:  94.159
Epoch: 97 | Time: 1m 3s
	Train Loss: 0.511 | Train PPL:   1.667


Epoch: 126 | Time: 1m 3s
	Train Loss: 0.423 | Train PPL:   1.526
	 Val. Loss: 4.849 |  Val. PPL: 127.558
Epoch: 127 | Time: 1m 3s
	Train Loss: 0.420 | Train PPL:   1.522
	 Val. Loss: 4.924 |  Val. PPL: 137.577
Epoch: 128 | Time: 1m 3s
	Train Loss: 0.421 | Train PPL:   1.523
	 Val. Loss: 4.883 |  Val. PPL: 132.047
Epoch: 129 | Time: 1m 3s
	Train Loss: 0.416 | Train PPL:   1.516
	 Val. Loss: 4.993 |  Val. PPL: 147.421
Epoch: 130 | Time: 1m 4s
	Train Loss: 0.412 | Train PPL:   1.509
	 Val. Loss: 4.936 |  Val. PPL: 139.200
Epoch: 131 | Time: 1m 3s
	Train Loss: 0.411 | Train PPL:   1.508
	 Val. Loss: 5.035 |  Val. PPL: 153.670
Epoch: 132 | Time: 1m 3s
	Train Loss: 0.407 | Train PPL:   1.503
	 Val. Loss: 4.933 |  Val. PPL: 138.730
Epoch: 133 | Time: 1m 3s
	Train Loss: 0.406 | Train PPL:   1.501
	 Val. Loss: 5.006 |  Val. PPL: 149.340
Epoch: 134 | Time: 1m 3s
	Train Loss: 0.401 | Train PPL:   1.493
	 Val. Loss: 5.091 |  Val. PPL: 162.586
Epoch: 135 | Time: 1m 4s
	Train Loss: 0.401 | Train PPL

[0, 1, 2, 126, 8, 2, 126, 13, 2, 126, 13, 32, 6, 15, 70, 64, 50, 7, 16, 64, 6, 15, 78, 38, 57, 52, 17, 2, 126, 90, 54, 50, 15, 27, 64, 12, 31, 0, 1, 2, 126, 1, 64, 57, 34, 8, 2, 126, 10, 64, 6, 15, 91, 64, 50, 15, 13, 2, 126, 13, 32, 6, 31, 16, 64, 55, 7, 78, 64, 50, 19, 17, 2, 126, 90, 64, 55, 15, 27, 51, 57, 31, 0, 1, 2, 126, 1, 64, 6, 28, 8, 2, 126, 10, 32, 57, 15, 91, 32, 57, 15, 13, 2, 126, 13, 32, 12, 31, 16, 51, 50, 7, 78, 54, 57, 19, 17, 2, 126, 90, 54, 48, 15, 27, 32, 57, 52, 0, 1, 2, 126, 1, 64, 57, 19, 4, 64, 33, 7, 23, 64, 50, 34, 8, 2, 126, 91, 64, 57, 7, 13, 2, 126, 13, 32, 57, 31, 16, 64, 33, 15, 78, 54, 6, 60, 17, 2, 126, 0, 1, 2, 126, 8, 2, 126, 10, 51, 50, 22, 10, 54, 57, 22, 91, 51, 6, 15, 91, 64, 21, 22, 13, 2, 126, 13, 51, 57, 42, 13, 64, 21, 42, 17, 2, 126, 27, 35, 50, 15, 27, 64, 12, 15, 74, 53, 50, 15, 74, 64, 57, 22, 0, 1, 2, 126, 1, 35, 50, 42, 1, 51, 12, 42, 8, 2, 126, 10, 54, 33, 31, 10, 11, 57, 7, 13, 2, 126, 13, 35, 50, 47, 13, 32, 57, 42, 27, 35, 30, 31, 

Epoch: 221 | Time: 1m 4s
	Train Loss: 0.231 | Train PPL:   1.259
	 Val. Loss: 5.761 |  Val. PPL: 317.614
=> Saving checkpoint
Epoch: 222 | Time: 1m 4s
	Train Loss: 0.230 | Train PPL:   1.259
	 Val. Loss: 5.691 |  Val. PPL: 296.209
Epoch: 223 | Time: 1m 4s
	Train Loss: 0.229 | Train PPL:   1.257
	 Val. Loss: 5.896 |  Val. PPL: 363.629
Epoch: 224 | Time: 1m 4s
	Train Loss: 0.227 | Train PPL:   1.255
	 Val. Loss: 5.764 |  Val. PPL: 318.719
Epoch: 225 | Time: 1m 4s
	Train Loss: 0.224 | Train PPL:   1.251
	 Val. Loss: 5.781 |  Val. PPL: 324.062
[0, 1, 2, 95, 67, 79, 61, 34, 72, 18, 86, 41, 70, 79, 76, 87, 0, 72, 79, 86, 15, 10, 25, 61, 22, 13, 79, 86, 87, 0, 72, 5, 76, 31, 91, 79, 84, 31, 70, 18, 98, 60, 0, 23, 18, 84, 7, 72, 18, 98, 31, 91, 79, 76, 31, 70, 79, 86, 60, 0, 72, 79, 61, 31, 91, 25, 86, 7, 70, 79, 76, 60, 0, 72, 5, 57, 31, 91, 79, 61, 31, 70, 11, 58, 103, 0, 72, 79, 61, 31, 91, 25, 86, 7, 70, 79, 76, 7, 78, 18, 84, 56, 0, 72, 79, 61, 31, 10, 18, 86, 31]
[0, 1, 43, 105, 10, 38, 

Epoch: 255 | Time: 1m 3s
	Train Loss: 0.180 | Train PPL:   1.197
	 Val. Loss: 5.954 |  Val. PPL: 385.414
Epoch: 256 | Time: 1m 4s
	Train Loss: 0.177 | Train PPL:   1.194
	 Val. Loss: 6.038 |  Val. PPL: 418.931
Epoch: 257 | Time: 1m 4s
	Train Loss: 0.176 | Train PPL:   1.192
	 Val. Loss: 6.005 |  Val. PPL: 405.643
Epoch: 258 | Time: 1m 4s
	Train Loss: 0.174 | Train PPL:   1.191
	 Val. Loss: 6.010 |  Val. PPL: 407.606
Epoch: 259 | Time: 1m 4s
	Train Loss: 0.171 | Train PPL:   1.187
	 Val. Loss: 6.056 |  Val. PPL: 426.562
Epoch: 260 | Time: 1m 4s
	Train Loss: 0.171 | Train PPL:   1.187
	 Val. Loss: 6.090 |  Val. PPL: 441.428
=> Saving checkpoint
Epoch: 261 | Time: 1m 4s
	Train Loss: 0.169 | Train PPL:   1.184
	 Val. Loss: 6.107 |  Val. PPL: 448.894
=> Saving checkpoint
Epoch: 262 | Time: 1m 4s
	Train Loss: 0.168 | Train PPL:   1.183
	 Val. Loss: 6.137 |  Val. PPL: 462.812
Epoch: 263 | Time: 1m 3s
	Train Loss: 0.168 | Train PPL:   1.183
	 Val. Loss: 6.081 |  Val. PPL: 437.566
Epoch: 264 | 

[0, 1, 2, 197, 67, 18, 86, 15, 4, 54, 61, 15, 23, 25, 57, 31, 72, 32, 58, 66, 72, 18, 86, 22, 10, 11, 84, 22, 91, 38, 84, 22, 13, 38, 84, 22, 70, 18, 84, 22, 16, 11, 76, 22, 78, 38, 76, 22, 17, 32, 86, 22, 90, 38, 86, 22, 27, 11, 61, 22, 74, 38, 61, 22, 0, 1, 11, 57, 22, 67, 25, 57, 22, 4, 11, 58, 22, 23, 38, 58, 22, 8, 38, 58, 22, 72, 11, 58, 22, 10, 11, 58, 22, 91, 5, 84, 22, 13, 38, 84, 22, 70, 18, 84, 22, 16, 25, 76, 22, 78, 38, 86, 22, 17, 5, 61, 22, 90, 5, 61, 22, 27, 11, 73, 22, 74, 38, 57, 22, 0, 1, 38, 58, 22, 67, 11, 58, 22, 4, 11, 58, 22, 23, 38, 58, 22, 8, 5, 62, 22, 72, 11, 62, 22, 10, 79, 61, 15, 13, 32, 73, 22, 70, 11, 57, 22, 16, 25, 58, 22, 78, 24, 58, 22, 17, 38, 48, 22, 90, 79, 57, 22, 27, 38, 58, 22, 74, 32, 58, 22, 0, 1, 32, 62, 22, 67, 11, 62, 22, 4, 5, 58, 22, 23, 11, 58, 22, 8, 11, 62, 22, 72, 25, 58, 22, 72, 18, 76, 52, 10, 5, 58, 22, 91, 32, 62, 7, 91, 51, 58, 22, 13, 38, 62, 22, 70, 11, 62, 22, 70, 24, 84, 52, 90, 25, 76, 31, 74, 38, 62, 31, 0, 67, 79, 84, 52

Epoch: 338 | Time: 1m 4s
	Train Loss: 0.099 | Train PPL:   1.104
	 Val. Loss: 6.509 |  Val. PPL: 671.047
Epoch: 339 | Time: 1m 4s
	Train Loss: 0.099 | Train PPL:   1.104
	 Val. Loss: 6.480 |  Val. PPL: 652.217
Epoch: 340 | Time: 1m 3s
	Train Loss: 0.098 | Train PPL:   1.103
	 Val. Loss: 6.598 |  Val. PPL: 733.675
=> Saving checkpoint
Epoch: 341 | Time: 1m 3s
	Train Loss: 0.098 | Train PPL:   1.103
	 Val. Loss: 6.585 |  Val. PPL: 724.271
=> Saving checkpoint
Epoch: 342 | Time: 1m 4s
	Train Loss: 0.097 | Train PPL:   1.102
	 Val. Loss: 6.615 |  Val. PPL: 746.544
Epoch: 343 | Time: 1m 4s
	Train Loss: 0.097 | Train PPL:   1.102
	 Val. Loss: 6.634 |  Val. PPL: 760.183
Epoch: 344 | Time: 1m 4s
	Train Loss: 0.096 | Train PPL:   1.101
	 Val. Loss: 6.625 |  Val. PPL: 753.885
Epoch: 345 | Time: 1m 3s
	Train Loss: 0.096 | Train PPL:   1.101
	 Val. Loss: 6.580 |  Val. PPL: 720.622
Epoch: 346 | Time: 1m 3s
	Train Loss: 0.094 | Train PPL:   1.099
	 Val. Loss: 6.622 |  Val. PPL: 751.325
Epoch: 347 | 

Epoch: 376 | Time: 1m 3s
	Train Loss: 0.080 | Train PPL:   1.083
	 Val. Loss: 6.768 |  Val. PPL: 869.949
Epoch: 377 | Time: 1m 4s
	Train Loss: 0.078 | Train PPL:   1.082
	 Val. Loss: 6.746 |  Val. PPL: 850.393
Epoch: 378 | Time: 1m 3s
	Train Loss: 0.078 | Train PPL:   1.081
	 Val. Loss: 6.842 |  Val. PPL: 936.600
Epoch: 379 | Time: 1m 4s
	Train Loss: 0.079 | Train PPL:   1.082
	 Val. Loss: 6.785 |  Val. PPL: 884.786
Epoch: 380 | Time: 1m 3s
	Train Loss: 0.078 | Train PPL:   1.081
	 Val. Loss: 6.841 |  Val. PPL: 935.625
=> Saving checkpoint
Epoch: 381 | Time: 1m 3s
	Train Loss: 0.079 | Train PPL:   1.082
	 Val. Loss: 6.797 |  Val. PPL: 895.440
=> Saving checkpoint
Epoch: 382 | Time: 1m 3s
	Train Loss: 0.078 | Train PPL:   1.081
	 Val. Loss: 6.816 |  Val. PPL: 912.015
Epoch: 383 | Time: 1m 4s
	Train Loss: 0.078 | Train PPL:   1.081
	 Val. Loss: 6.936 |  Val. PPL: 1028.534
Epoch: 384 | Time: 1m 3s
	Train Loss: 0.076 | Train PPL:   1.079
	 Val. Loss: 6.888 |  Val. PPL: 980.440
Epoch: 385 |

Epoch: 422 | Time: 1m 4s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 6.994 |  Val. PPL: 1089.581
Epoch: 423 | Time: 1m 4s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 6.907 |  Val. PPL: 999.118
Epoch: 424 | Time: 1m 4s
	Train Loss: 0.063 | Train PPL:   1.065
	 Val. Loss: 7.019 |  Val. PPL: 1117.151
Epoch: 425 | Time: 1m 4s
	Train Loss: 0.063 | Train PPL:   1.065
	 Val. Loss: 6.986 |  Val. PPL: 1081.581
[0, 1, 2, 95, 67, 11, 46, 41, 8, 18, 75, 31, 72, 79, 6, 87, 0, 72, 79, 6, 15, 10, 25, 58, 22, 13, 79, 62, 87, 0, 72, 5, 62, 31, 91, 79, 48, 85, 0, 72, 64, 46, 15, 10, 18, 55, 7, 91, 79, 48, 85, 0, 23, 79, 57, 60, 0, 23, 18, 75, 7, 72, 11, 46, 31, 91, 24, 104, 7, 70, 79, 58, 7, 70, 11, 75, 56, 0, 72, 25, 100, 15]
[0, 1, 2, 183, 8, 2, 183, 13, 2, 183, 70, 51, 20, 15, 16, 51, 96, 15, 78, 51, 86, 15, 17, 2, 183, 90, 64, 61, 15, 27, 64, 57, 15, 74, 64, 12, 15, 0, 1, 2, 183, 67, 64, 57, 42, 8, 2, 183, 13, 2, 183, 70, 51, 86, 15, 16, 51, 61, 15, 78, 51, 86, 7, 17, 2, 183, 90, 6

Epoch: 451 | Time: 1m 4s
	Train Loss: 0.056 | Train PPL:   1.058
	 Val. Loss: 7.099 |  Val. PPL: 1210.847
Epoch: 452 | Time: 1m 3s
	Train Loss: 0.056 | Train PPL:   1.058
	 Val. Loss: 7.011 |  Val. PPL: 1108.536
Epoch: 453 | Time: 1m 4s
	Train Loss: 0.056 | Train PPL:   1.057
	 Val. Loss: 7.164 |  Val. PPL: 1292.211
Epoch: 454 | Time: 1m 4s
	Train Loss: 0.056 | Train PPL:   1.057
	 Val. Loss: 7.082 |  Val. PPL: 1189.984
Epoch: 455 | Time: 1m 3s
	Train Loss: 0.055 | Train PPL:   1.057
	 Val. Loss: 7.113 |  Val. PPL: 1227.392
Epoch: 456 | Time: 1m 4s
	Train Loss: 0.055 | Train PPL:   1.057
	 Val. Loss: 7.215 |  Val. PPL: 1359.416
Epoch: 457 | Time: 1m 3s
	Train Loss: 0.055 | Train PPL:   1.056
	 Val. Loss: 7.184 |  Val. PPL: 1318.739
Epoch: 458 | Time: 1m 4s
	Train Loss: 0.055 | Train PPL:   1.057
	 Val. Loss: 7.169 |  Val. PPL: 1298.896
Epoch: 459 | Time: 1m 4s
	Train Loss: 0.054 | Train PPL:   1.056
	 Val. Loss: 7.183 |  Val. PPL: 1316.690
Epoch: 460 | Time: 1m 4s
	Train Loss: 0.055 | 

[0, 1, 2, 197, 67, 18, 57, 15, 4, 54, 58, 15, 23, 25, 62, 31, 72, 32, 48, 66, 72, 18, 57, 22, 10, 11, 86, 22, 91, 38, 86, 22, 13, 38, 86, 22, 70, 18, 86, 22, 16, 11, 61, 22, 78, 38, 61, 22, 17, 32, 57, 22, 90, 38, 57, 22, 27, 11, 58, 22, 74, 38, 58, 22, 0, 1, 11, 62, 22, 67, 25, 62, 22, 4, 11, 48, 22, 23, 38, 48, 22, 8, 38, 48, 22, 72, 11, 48, 22, 10, 11, 48, 22, 91, 5, 86, 22, 13, 38, 86, 22, 70, 18, 86, 22, 16, 25, 61, 22, 78, 38, 61, 22, 17, 5, 57, 22, 90, 5, 57, 22, 27, 11, 58, 22, 74, 38, 58, 22, 0, 1, 38, 62, 22, 67, 11, 62, 22, 4, 11, 62, 22, 23, 38, 62, 22, 8, 5, 48, 22, 72, 11, 48, 22, 10, 79, 57, 15, 13, 32, 58, 22, 70, 11, 58, 22, 16, 25, 62, 22, 78, 24, 62, 22, 17, 38, 48, 22, 90, 79, 58, 22, 27, 38, 62, 22, 74, 32, 62, 22, 0, 1, 32, 48, 22, 67, 11, 48, 22, 4, 5, 62, 22, 23, 11, 62, 22, 8, 11, 48, 22, 72, 25, 58, 22, 72, 18, 86, 52, 10, 5, 62, 22, 91, 32, 48, 7, 91, 51, 62, 22, 13, 38, 48, 22, 70, 11, 48, 22, 70, 24, 86, 52, 90, 25, 86, 31, 74, 38, 48, 31, 0, 67, 79, 86, 52

In [None]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
# Fresh model with the same architecture as `model`, to receive saved weights.
# Bug fix: `src2_vocab_size` was missing from the argument list, so the call
# supplied 11 of the 12 required positional arguments and raised TypeError.
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    src2_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
# NOTE(review): this rebinds the notebook-wide `optimizer` to `best_model`'s
# parameters (with a different learning rate than the training cell).
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [54]:
# Load the checkpoint file named after N_EPOCHS (500) onto `device`.
state = torch.load(destination_folder + '/500_checkpoint.pt', map_location=device)
# NOTE(review): the weights are restored into `model`, not into the
# `best_model` constructed in the preceding cell -- confirm this is intended.
load_checkpoint(state, model, optimizer)

=> Loading checkpoint


In [55]:
# Mean teacher-forced loss on the test split, reported as perplexity (exp of the loss).
test_loss = evaluate(model, test_iter, criterion)
print(math.exp(test_loss))

2075.0795034486036


In [None]:
# Redirect generated samples to a separate folder and make sure its four
# per-part sub-folders exist.
generated_outputs = f"{folder}/generated_samples_2000epochs"
for _part_dir in ("/intro", "/outro", "/solo", "/predict"):
    Path(generated_outputs+_part_dir).mkdir(parents=True, exist_ok=True)

In [52]:
# Raw test rows, kept as strings for free-running generation.
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro, test_solo, test_outro = (
    df_intro[column].values for column in ('intro', 'solo', 'outro')
)

# One {'intro', 'solo', 'outro'} record per CSV row.
test_data = [
    {'intro': intro_text, 'solo': solo_text, 'outro': outro_text}
    for intro_text, solo_text, outro_text in zip(test_intro, test_solo, test_outro)
]
print(len(test_intro))

112


In [56]:
# For every test row: decode a solo from (intro, outro), then write the
# reference intro / solo / outro and the prediction as MIDI files.
for i in range(len(test_intro)):
    intro, solo, outro = test_intro[i], test_solo[i], test_outro[i]

    # Reference sequences as integer event ids.
    list_intro = [int(x) for x in intro.split(' ')]
    list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]

    translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)
    # Drop special tokens before converting the prediction to event ids.
    translated_sentence = [
        int(x) for x in translated_sentence
        if x not in ('<pad>', '<sos>', '<eos>', '<unk>')
    ]
    print(translated_sentence)

    # Each part goes to <generated_outputs>/<part>//<part><i>.mid
    for part, events in (
        ("intro", list_intro),
        ("solo", list_solo),
        ("outro", list_outro),
        ("predict", translated_sentence),
    ):
        utils.write_midi(events, word2event, f"{generated_outputs}/{part}//{part}{i}.mid")
    print(i)
        


[0, 1, 2, 162, 23, 32, 62, 47, 70, 24, 58, 15, 16, 38, 9, 15, 78, 64, 62, 41, 74, 32, 62, 31, 0, 67, 32, 26, 7, 23, 64, 46, 15, 8, 32, 75, 15, 10, 38, 33, 34, 16, 38, 75, 37, 78, 64, 46, 22, 17, 64, 26, 22, 90, 32, 62, 7, 0, 1, 24, 62, 15, 67, 32, 9, 15, 4, 32, 73, 47, 13, 38, 61, 41, 70, 24, 21, 34, 0, 1, 24, 73, 22, 67, 32, 9, 15, 4, 32, 73, 22, 23, 32, 61, 22, 8, 32, 73, 22, 72, 51, 9, 22, 10, 64, 62, 15, 91, 32, 9, 22, 13, 32, 73, 22, 70, 54, 9, 22, 16, 54, 73, 22, 78, 54, 9, 22, 17, 32, 62, 22, 90, 32, 62, 22, 27, 54, 9, 22, 74, 38, 62, 22, 0, 1, 51, 9, 22, 67, 32, 73, 22, 67, 54, 9, 22, 4, 32, 9, 22, 23, 32, 9, 15, 8, 64, 86, 19, 91, 51, 20, 22, 13, 32, 61, 22, 70, 53, 9, 22, 70, 32, 20, 22, 16, 32, 20, 22, 78, 32, 76, 22, 17, 32, 84, 22, 17, 64, 76, 22, 90, 32, 20, 22, 27, 32, 61, 22, 27, 51, 73, 22, 0, 1, 11, 61, 22, 1, 32, 20, 22, 1, 51, 21, 22, 67, 64, 61, 22, 67, 32, 61, 22, 67, 32, 9, 22, 4, 54, 73, 22, 23, 32, 9, 22, 8, 32, 62, 15, 91, 11, 26, 22, 13, 32, 62, 22, 70, 5, 46

[0, 1, 2, 147, 67, 54, 21, 47, 67, 38, 20, 52, 23, 64, 76, 31, 72, 32, 40, 59, 91, 54, 180, 31, 13, 35, 131, 15, 70, 54, 92, 15, 16, 54, 75, 31, 78, 35, 33, 31, 17, 35, 62, 28, 90, 51, 9, 42, 74, 32, 76, 31, 0, 67, 51, 20, 31, 23, 51, 76, 52, 72, 64, 20, 31, 91, 35, 152, 34, 91, 51, 21, 31, 13, 71, 93, 41, 70, 53, 29, 31, 70, 51, 73, 31, 16, 49, 46, 31, 78, 53, 26, 52, 78, 64, 21, 31, 17, 35, 9, 31, 90, 51, 76, 52, 74, 35, 40, 52, 0, 67, 64, 20, 34, 72, 54, 76, 15, 10, 64, 20, 15, 91, 64, 76, 77, 74, 51, 20, 31, 0, 67, 51, 21, 31, 23, 64, 73, 31, 72, 64, 9, 31, 91, 51, 73, 31, 70, 64, 21, 31, 78, 51, 73, 15, 17, 51, 9, 15, 90, 64, 62, 42, 0, 67, 51, 20, 34, 72, 35, 180, 42, 72, 51, 76, 31, 91, 45, 92, 19, 91, 64, 40, 34, 13, 45, 111, 19, 70, 35, 75, 31, 16, 49, 33, 59, 78, 51, 76, 31, 90, 65, 6, 42, 90, 51, 40, 31, 74, 32, 97, 15, 0, 1, 51, 40, 15, 67, 32, 76, 31, 23, 51, 20, 15, 8, 51, 21, 15, 72, 51, 20, 31, 91, 64, 76, 31, 70, 51, 40, 31, 78, 53, 62, 7, 78, 51, 76, 15, 17, 49, 26, 1

[0, 1, 43, 124, 8, 43, 124, 10, 18, 62, 34, 13, 43, 124, 16, 5, 48, 34, 17, 43, 124, 27, 5, 26, 34, 0, 1, 43, 124, 8, 43, 124, 13, 43, 124, 17, 43, 124, 17, 35, 48, 41, 0, 1, 43, 124, 1, 43, 124, 8, 43, 124, 8, 43, 124, 13, 43, 124, 13, 43, 124, 16, 35, 62, 31, 17, 43, 124, 17, 43, 124, 27, 35, 26, 34, 0, 1, 43, 124, 8, 43, 124, 13, 43, 124, 16, 35, 62, 31, 17, 43, 124, 17, 35, 62, 31, 27, 35, 26, 34, 0, 1, 43, 124, 8, 43, 124, 8, 35, 62, 28, 13, 43, 124, 16, 35, 46, 34, 17, 43, 124, 27, 35, 75, 108, 0, 1, 43, 124, 8, 43, 124, 10, 35, 76, 41, 13, 43, 124, 16, 54, 76, 34, 17, 43, 124, 27, 35, 20, 31, 0, 1, 43, 124, 1, 53, 73, 7, 4, 35, 73, 31, 8, 43, 124, 8, 35, 62, 31, 10, 35, 73, 66, 13, 43, 124, 16, 35, 61, 34, 17, 43, 124, 27, 53, 73, 56, 0, 1, 43, 124, 8, 43, 124, 13, 43, 124, 13, 43, 124, 13, 43, 124, 16, 35, 73, 34, 17, 43, 124, 17, 35, 73, 66, 0, 1, 43, 124, 8, 43, 124, 10, 35, 73, 31, 13, 43, 124, 16, 35, 58, 31, 17, 43, 124, 17, 43, 124, 0, 1, 43, 124, 8, 43, 124, 8, 43, 124, 

[0, 1, 2, 139, 4, 11, 97, 22, 23, 64, 97, 22, 8, 14, 120, 52, 91, 32, 97, 31, 16, 5, 97, 22, 78, 5, 118, 22, 17, 5, 120, 47, 0, 4, 14, 97, 22, 23, 11, 40, 22, 8, 38, 118, 52, 91, 14, 40, 31, 16, 38, 97, 22, 78, 14, 40, 22, 17, 5, 20, 66, 0, 4, 14, 40, 22, 23, 14, 97, 22, 8, 14, 97, 19, 91, 51, 40, 22, 13, 25, 97, 7, 16, 14, 97, 22, 78, 25, 120, 22, 17, 18, 97, 66, 0, 4, 11, 40, 22, 23, 14, 120, 22, 8, 5, 106, 7, 10, 11, 120, 22, 91, 5, 118, 7, 70, 11, 97, 22, 16, 11, 118, 7, 17, 25, 120, 66, 0, 8, 25, 20, 47, 8, 25, 120, 47, 16, 18, 76, 22, 16, 18, 106, 22, 78, 5, 20, 22, 78, 5, 120, 22, 17, 25, 73, 66, 17, 25, 97, 66, 0, 4, 5, 73, 22, 4, 5, 97, 22, 23, 25, 9, 22, 23, 25, 40, 22, 8, 25, 73, 52, 8, 25, 97, 52, 91, 14, 21, 22, 91, 14, 118, 22, 13, 25, 73, 7, 13, 25, 97, 7, 16, 18, 9, 22, 16, 18, 40, 22, 78, 14, 73, 22, 78, 14, 97, 22, 17, 18, 26, 66, 17, 18, 20, 66, 0, 8, 18, 73, 66, 8, 18, 97, 66, 16, 25, 9, 15, 16, 25, 40, 15, 78, 11, 73, 22, 78, 11, 97, 22, 17, 5, 9, 19, 17, 5, 40, 19

[0, 1, 2, 115, 74, 11, 96, 7, 0, 67, 11, 96, 31, 23, 38, 86, 31, 72, 11, 96, 31, 91, 11, 86, 7, 70, 24, 12, 22, 16, 11, 57, 47, 74, 32, 12, 15, 0, 1, 54, 57, 15, 67, 32, 12, 22, 23, 38, 86, 41, 72, 64, 12, 15, 10, 64, 12, 34, 70, 5, 86, 31, 78, 11, 12, 34, 74, 32, 96, 34, 0, 67, 32, 57, 7, 23, 54, 12, 31, 72, 51, 57, 42, 0, 67, 11, 96, 52, 72, 11, 50, 31, 91, 11, 12, 41, 70, 32, 86, 31, 78, 38, 96, 31, 90, 38, 86, 31, 74, 32, 12, 31, 0, 67, 32, 86, 7, 23, 32, 96, 15, 8, 54, 86, 15, 72, 32, 96, 7, 91, 11, 57, 15, 13, 38, 86, 7, 70, 38, 40, 28, 74, 32, 98, 87, 0, 67, 38, 82, 31, 23, 38, 86, 7, 23, 38, 96, 15, 72, 11, 86, 31, 91, 11, 86, 31, 70, 38, 96, 31, 78, 38, 86, 85, 0, 67, 32, 96, 31, 23, 32, 86, 31, 72, 11, 12, 31, 91, 32, 57, 31, 70, 32, 6, 59]
35
[0, 1, 2, 202, 8, 51, 55, 7, 8, 32, 50, 52, 10, 51, 55, 19, 10, 35, 57, 7, 91, 32, 96, 19, 16, 64, 55, 7, 16, 51, 57, 31, 17, 64, 29, 7, 17, 64, 50, 52, 27, 64, 29, 19, 27, 64, 9, 7, 74, 32, 96, 34, 0, 4, 64, 29, 31, 4, 64, 9, 7, 8, 64,

[0, 1, 2, 89, 67, 14, 21, 41, 8, 2, 89, 72, 14, 12, 15, 10, 11, 12, 15, 91, 5, 12, 28, 13, 2, 89, 17, 2, 89, 90, 5, 57, 31, 74, 25, 12, 7, 0, 1, 2, 89, 67, 25, 96, 34, 8, 2, 89, 72, 25, 98, 7, 91, 18, 82, 28, 13, 2, 89, 17, 2, 89, 90, 11, 12, 15, 27, 32, 86, 22, 74, 5, 39, 52, 0, 1, 2, 89, 4, 24, 39, 52, 8, 2, 89, 72, 14, 98, 22, 10, 51, 39, 22, 91, 32, 98, 34, 13, 2, 89, 78, 14, 30, 7, 17, 2, 89, 90, 14, 12, 7, 74, 5, 57, 37, 0, 1, 2, 89, 8, 2, 89, 91, 14, 57, 31, 13, 2, 89, 70, 51, 6, 22, 16, 24, 57, 15, 78, 24, 12, 7, 17, 2, 89, 90, 51, 50, 15, 27, 45, 33, 22, 74, 32, 50, 77, 0, 1, 2, 89, 8, 2, 89, 13, 2, 89, 17, 2, 89, 90, 51, 30, 15, 27, 11, 33, 15, 74, 11, 50, 15, 0, 1, 2, 89, 1, 38, 50, 15, 67, 11, 50, 15, 4, 24, 50, 22, 23, 11, 6, 15, 8, 2, 89, 8, 24, 57, 31, 10, 14, 57, 15, 91, 14, 12, 41, 13, 2, 89, 78, 25, 12, 31, 17, 2, 89, 90, 18, 50, 80, 0, 1, 2, 89, 8, 2, 89, 13, 2, 89, 78, 18, 58, 7, 17, 2, 89, 90, 25, 61, 42, 0, 1, 2, 89, 23, 5, 12, 15, 8, 2, 89, 8, 11, 61, 15, 72, 14,

[0, 1, 43, 161, 1, 38, 9, 31, 4, 45, 57, 15, 8, 63, 9, 31, 10, 64, 50, 52, 13, 38, 33, 125, 0, 1, 64, 9, 31, 4, 53, 57, 52, 8, 54, 9, 52, 10, 35, 6, 52, 13, 54, 26, 85, 0, 1, 11, 57, 31, 4, 32, 12, 52, 8, 38, 57, 52, 10, 35, 9, 31, 13, 38, 6, 85, 0, 1, 32, 9, 31, 4, 35, 57, 52, 8, 51, 9, 31, 10, 53, 6, 7, 13, 32, 6, 36, 17, 32, 30, 19, 0, 1, 38, 30, 125, 10, 11, 50, 42, 0, 1, 38, 33, 42, 8, 11, 62, 41, 13, 11, 21, 66, 17, 11, 12, 41, 0, 1, 38, 12, 66, 8, 38, 73, 41, 13, 32, 73, 19, 17, 11, 9, 41, 0, 1, 38, 9, 34, 8, 64, 6, 66, 13, 64, 62, 19, 17, 11, 48, 41, 0, 1, 38, 26, 15, 4, 11, 26, 68, 4, 32, 6, 31, 8, 32, 62, 31, 10, 32, 73, 34, 16, 32, 62, 31, 78, 201, 26, 34, 17, 32, 33, 31, 27, 32, 9, 34, 0, 4, 32, 33, 31, 8, 32, 30, 31, 10, 32, 6, 34, 16, 32, 33, 31, 17, 32, 29, 31, 27, 32, 75, 31, 0, 4, 32, 6, 31, 8, 32, 62, 31, 10, 32, 73, 34, 16, 32, 62, 31, 17, 32, 33, 31, 27, 32, 9, 34, 0, 4, 32, 33, 31, 8, 32, 30, 31, 10, 32, 6, 34, 16, 32, 33, 31, 17, 32, 29, 31, 90, 32, 75, 31]
52
[0,

[0, 1, 43, 170, 67, 38, 73, 47, 91, 38, 61, 7, 70, 11, 86, 77, 0, 67, 11, 61, 41, 72, 11, 61, 7, 10, 11, 73, 22, 91, 64, 61, 22, 13, 35, 73, 15, 70, 32, 58, 36, 0, 67, 32, 61, 7, 23, 38, 61, 41, 10, 38, 73, 7, 13, 11, 58, 52, 17, 24, 48, 52, 27, 14, 58, 7, 0, 67, 14, 62, 112, 0, 67, 38, 73, 41, 72, 38, 73, 7, 91, 38, 61, 7, 70, 38, 86, 37, 0, 67, 24, 84, 52, 8, 14, 84, 52, 91, 24, 76, 7, 70, 38, 86, 37, 0, 67, 24, 61, 19, 72, 11, 61, 7, 91, 11, 86, 7, 70, 38, 61, 15, 78, 11, 86, 15, 90, 14, 62, 7, 74, 38, 58, 7, 0, 67, 38, 73, 136, 23, 32, 62, 31, 72, 64, 46, 31, 91, 51, 48, 31, 13, 64, 75, 37]
62
[0, 1, 2, 89, 23, 35, 20, 34, 91, 35, 82, 52, 70, 49, 39, 15, 16, 35, 40, 15, 78, 35, 82, 77, 0, 23, 53, 20, 34, 91, 35, 82, 52, 70, 35, 40, 15, 16, 53, 39, 15, 78, 53, 82, 37, 0, 23, 35, 39, 34, 91, 35, 118, 31, 70, 53, 120, 52, 78, 35, 118, 34, 74, 53, 39, 66, 0, 23, 54, 118, 31, 72, 53, 39, 31, 91, 53, 118, 31, 70, 49, 120, 31, 78, 35, 39, 77, 0, 23, 35, 82, 34, 91, 35, 20, 31, 70, 35, 82,

[0, 1, 43, 105, 91, 5, 57, 22, 13, 14, 61, 22, 70, 25, 86, 22, 16, 5, 96, 22, 78, 5, 84, 22, 90, 5, 86, 15, 74, 25, 61, 7, 0, 67, 5, 96, 28, 78, 5, 96, 7, 90, 5, 21, 15, 74, 5, 57, 15, 0, 67, 25, 86, 66, 91, 5, 57, 22, 13, 5, 58, 22, 70, 5, 86, 22, 16, 5, 96, 22, 78, 14, 84, 22, 90, 5, 86, 7, 74, 5, 58, 15, 0, 67, 5, 96, 19, 72, 5, 57, 15, 91, 5, 96, 15, 70, 25, 86, 77, 0, 4, 5, 57, 22, 23, 5, 57, 7, 72, 25, 58, 22, 91, 5, 48, 22, 91, 14, 50, 15, 91, 5, 58, 22, 13, 5, 50, 22, 13, 14, 57, 22, 70, 18, 48, 31, 70, 25, 86, 22, 16, 5, 86, 22, 78, 5, 96, 22, 90, 5, 86, 15, 74, 25, 61, 7, 0, 67, 5, 96, 28, 78, 5, 96, 7, 90, 5, 21, 15, 74, 5, 58, 15, 0, 67, 25, 86, 66, 91, 5, 58, 22, 13, 5, 57, 22, 70, 5, 86, 22, 16, 5, 96, 22, 78, 14, 84, 22, 90, 5, 86, 7, 74, 5, 57, 15, 0, 67, 5, 96, 19, 72, 5, 58, 15, 91, 5, 96, 15, 13, 25, 86, 77]
72
[0, 1, 43, 105, 91, 25, 9, 31, 70, 5, 9, 103, 0, 23, 25, 9, 52, 72, 25, 73, 34, 70, 25, 21, 37, 0, 72, 18, 21, 34, 70, 25, 20, 34, 70, 18, 82, 34, 70, 25, 76,

[0, 1, 2, 155, 72, 11, 33, 7, 10, 24, 26, 59, 74, 18, 73, 7, 0, 67, 11, 62, 31, 23, 32, 26, 31, 72, 25, 33, 37, 78, 18, 9, 7, 90, 24, 6, 34, 0, 67, 5, 62, 41, 23, 25, 6, 15, 8, 38, 73, 42, 70, 11, 30, 37, 0, 67, 5, 26, 31, 23, 38, 62, 15, 8, 25, 9, 66, 70, 32, 75, 80, 74, 5, 73, 31, 0, 67, 18, 21, 31, 23, 79, 76, 52, 72, 79, 26, 103, 72, 11, 20, 34, 74, 79, 9, 31, 0, 67, 5, 62, 31, 23, 24, 26, 31, 72, 18, 33, 59, 70, 11, 73, 31, 78, 5, 40, 31, 90, 79, 6, 77, 90, 79, 82, 85, 0, 72, 25, 6, 52, 91, 64, 50, 7, 70, 24, 30, 31, 78, 5, 33, 41, 74, 11, 111, 7, 0, 67, 11, 55, 31, 23, 79, 30, 68, 74, 11, 55, 52, 0, 67, 5, 30, 7, 23, 38, 50, 31, 72, 79, 33, 22, 72, 79, 33, 22, 91, 79, 33, 22, 70, 18, 75, 22, 78, 25, 33, 22, 90, 25, 33, 22, 74, 5, 33, 22, 0, 67, 25, 33, 22, 23, 79, 33, 22, 23, 38, 62, 22, 8, 35, 62, 22, 72, 5, 26, 37, 78, 18, 48, 102, 0, 13, 35, 111, 15, 70, 24, 33, 7, 16, 35, 111, 15, 16, 64, 29, 80]
82
[0, 1, 2, 194, 1, 2, 194, 8, 38, 100, 31, 10, 38, 104, 31, 13, 2, 194, 13, 11

[0, 1, 2, 115, 4, 18, 92, 31, 72, 25, 122, 7, 13, 25, 29, 52, 17, 5, 29, 15, 27, 5, 92, 15, 0, 1, 38, 122, 15, 4, 14, 55, 15, 8, 11, 29, 15, 10, 5, 100, 15, 13, 11, 29, 15, 16, 18, 92, 52, 27, 18, 92, 52, 0, 4, 18, 92, 31, 72, 14, 122, 31, 13, 25, 29, 52, 17, 18, 92, 15, 27, 54, 122, 15, 0, 1, 25, 111, 15, 4, 24, 93, 15, 8, 14, 127, 15, 10, 24, 152, 52, 16, 5, 92, 19, 17, 45, 21, 15, 27, 18, 29, 52, 27, 11, 21, 22, 0, 1, 11, 21, 15, 4, 18, 92, 31, 4, 14, 21, 15, 8, 25, 61, 22, 72, 25, 122, 7, 10, 11, 12, 15, 13, 25, 29, 52, 13, 24, 73, 15, 17, 5, 29, 15, 17, 45, 21, 15, 27, 5, 92, 15, 27, 11, 21, 22, 0, 1, 38, 122, 15, 1, 11, 21, 15, 4, 14, 55, 15, 4, 14, 21, 15, 8, 11, 29, 15, 8, 25, 61, 22, 10, 5, 100, 15, 10, 11, 12, 15, 13, 11, 29, 15, 13, 24, 73, 15, 16, 18, 92, 52, 17, 45, 21, 15, 27, 18, 92, 52, 27, 11, 21, 22, 0, 1, 11, 21, 15, 4, 18, 92, 31, 4, 14, 21, 15, 8, 25, 61, 22, 72, 14, 122, 31, 10, 11, 12, 15, 13, 25, 29, 52, 13, 24, 73, 15, 17, 14, 93, 52, 17, 11, 9, 15, 27, 14, 29,

[0, 1, 2, 188, 1, 5, 62, 22, 67, 5, 58, 22, 67, 5, 9, 22, 4, 5, 58, 22, 4, 5, 62, 22, 23, 5, 9, 22, 23, 5, 58, 22, 8, 5, 9, 22, 8, 5, 57, 22, 72, 5, 73, 22, 72, 5, 12, 22, 10, 5, 61, 22, 10, 5, 21, 22, 91, 5, 86, 22, 91, 5, 20, 34, 78, 5, 20, 52, 27, 18, 61, 22, 74, 25, 20, 7, 0, 67, 14, 76, 22, 4, 18, 20, 22, 23, 18, 61, 7, 72, 25, 73, 7, 91, 5, 20, 7, 70, 5, 76, 22, 16, 14, 20, 22, 78, 18, 61, 7, 90, 14, 73, 7, 74, 14, 58, 36, 0, 23, 24, 107, 7, 72, 14, 97, 7, 91, 14, 26, 77, 91, 24, 84, 34, 90, 14, 97, 15, 74, 14, 48, 37, 74, 25, 76, 7, 0, 67, 5, 20, 7, 23, 5, 61, 41, 91, 5, 20, 22, 13, 32, 20, 22, 70, 25, 86, 7, 78, 5, 61, 7, 90, 25, 73, 7, 74, 24, 20, 22, 0, 1, 24, 20, 22, 67, 18, 86, 22, 4, 24, 86, 22, 23, 5, 61, 7, 72, 5, 73, 7]
101
[0, 1, 2, 188, 1, 25, 61, 66, 8, 2, 188, 10, 25, 86, 22, 91, 24, 61, 22, 13, 2, 188, 13, 14, 61, 7, 16, 14, 86, 15, 17, 2, 188, 17, 25, 96, 52, 0, 1, 2, 188, 1, 14, 84, 7, 4, 24, 98, 15, 8, 2, 188, 8, 24, 98, 7, 10, 11, 84, 7, 13, 32, 98, 19, 17, 11,

[0, 1, 2, 183, 67, 64, 6, 19, 4, 53, 50, 31, 23, 35, 9, 52, 10, 51, 6, 136, 91, 51, 6, 41, 74, 51, 9, 7, 0, 67, 35, 6, 31, 23, 35, 6, 7, 8, 54, 9, 47, 13, 54, 33, 15, 70, 54, 50, 31, 78, 54, 33, 7, 90, 35, 6, 15, 74, 35, 9, 66, 0, 67, 53, 33, 52, 23, 53, 33, 52, 72, 54, 50, 7, 91, 53, 9, 7, 70, 54, 6, 41, 78, 54, 33, 52, 17, 54, 9, 31, 74, 54, 50, 15, 0, 1, 54, 21, 41, 67, 35, 12, 19, 23, 54, 21, 31, 72, 53, 33, 31, 91, 53, 30, 7, 70, 53, 33, 31, 78, 64, 6, 15, 17, 54, 9, 15, 90, 51, 33, 31, 74, 51, 30, 31, 0, 67, 51, 33, 7, 4, 54, 33, 7, 23, 54, 9, 15, 8, 54, 20, 7, 72, 51, 82, 15, 91, 35, 33, 31, 70, 54, 21, 7, 78, 51, 20, 15, 17, 51, 82, 52, 74, 32, 40, 15, 0, 67, 54, 21, 15, 4, 54, 12, 31, 23, 53, 82, 31, 10, 49, 9, 7, 91, 35, 73, 31, 70, 51, 9, 41, 78, 54, 12, 22, 17, 54, 21, 47, 74, 54, 12, 15, 0, 1, 54, 21, 15, 67, 51, 20, 52, 23, 51, 82, 31, 72, 54, 82, 31, 91, 35, 21, 85, 78, 54, 12, 7, 17, 54, 20, 7, 90, 51, 21, 7, 74, 54, 21, 7, 0, 91, 51, 40, 52, 70, 32, 40, 19, 74, 54, 21,

In [57]:
import mido

# For samples 0..10, build two three-track MIDI files under `generated_outputs`:
#   merged<i>.mid          -> intro + reference solo + outro
#   merged_predict<i>.mid  -> intro + predicted solo + outro
# A part is moved later in time by raising the `time` of message [1] of its
# track 1 by the summed `note_on` times of the parts placed before it.
# NOTE(review): only `note_on` message times are summed; presumably that equals
# the full part length for files written by utils.write_midi -- confirm.
for i in range(11):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')

    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = sum(msg.time for msg in solo.tracks[1] if msg.type == "note_on")
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    # Remember the outro's own start time before it gets overwritten below.
    original_outro_time = outro.tracks[1][1].time

    for part, part_name in ((intro, "intro"), (solo, "solo"), (outro, "outro"), (predict, "predict")):
        part.tracks[1].name = part_name

    # --- intro + reference solo + outro ---
    reference_outro_start = original_outro_time + total_solo_time + total_intro_time
    print(reference_outro_start)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = reference_outro_start
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')

    # --- intro + predicted solo + outro ---
    # Reload the outro: the copy above already had its start time rewritten.
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')

    predicted_outro_start = original_outro_time + total_predict_time + total_intro_time
    print(predicted_outro_start)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = predicted_outro_start
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

30480
30480
33720
33720
40200
40200
34200
34200
27720
27720
23520
23520
53340
53340
44280
44280
25800
25800
27600
27600
20400
20400
24060
24060
30900
30900
23760
23760
31680
31680
33420
33420
30720
30720
31560
31560
61500
61500
45780
45780
31380
31380
26280
26280


In [58]:
# Dissimilar interpolation: generate a solo from sample i's intro and the outro
# of sample i + 3, then write the intro, the borrowed outro and the prediction
# as MIDI files under `dissimilar_interpolation`.
# NOTE(review): the last three samples have no i + 3 partner and fall back to
# their own outro, so those three are not actually "dissimilar" -- confirm
# this fallback is intended.
special_tokens = ('<pad>', '<sos>', '<eos>', '<unk>')
sample_count = len(test_intro)
for i in range(sample_count):
    intro = test_intro[i]
    outro_index = i + 3 if i + 3 < sample_count else i
    outro = test_outro[outro_index]

    # Both parts are space-separated integer strings.
    list_intro = list(map(int, intro.split(' ')))
    list_outro = list(map(int, outro.split(' ')))

    # translate_sentence receives the raw strings, not the integer lists.
    decoded_tokens = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)
    # Drop the special markers so every remaining token can be cast to int.
    translated_sentence = [int(tok) for tok in decoded_tokens if tok not in special_tokens]
    print(translated_sentence)

    utils.write_midi(list_intro, word2event, dissimilar_interpolation + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, dissimilar_interpolation + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, dissimilar_interpolation + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
        


[0, 1, 2, 162, 23, 32, 62, 47, 70, 24, 58, 15, 16, 38, 9, 15, 78, 64, 62, 41, 74, 32, 62, 31, 0, 67, 32, 26, 7, 23, 64, 46, 15, 8, 32, 75, 15, 10, 38, 33, 34, 16, 38, 75, 37, 78, 64, 46, 22, 17, 64, 26, 22, 90, 32, 62, 7, 0, 1, 24, 62, 15, 67, 32, 9, 15, 4, 32, 73, 47, 13, 38, 61, 41, 70, 24, 21, 34, 0, 1, 24, 73, 22, 67, 32, 9, 15, 4, 32, 73, 22, 23, 32, 61, 22, 8, 32, 73, 22, 72, 51, 9, 22, 10, 64, 62, 15, 91, 32, 9, 22, 13, 32, 73, 22, 70, 54, 9, 22, 16, 54, 73, 22, 78, 54, 9, 22, 17, 32, 62, 22, 90, 32, 62, 22, 27, 54, 9, 22, 74, 38, 62, 22, 0, 1, 51, 9, 22, 67, 32, 73, 22, 67, 54, 9, 22, 4, 32, 9, 22, 23, 32, 9, 15, 8, 64, 86, 19, 91, 51, 20, 22, 13, 32, 61, 22, 70, 53, 9, 22, 70, 32, 20, 22, 16, 32, 20, 22, 78, 32, 76, 22, 17, 32, 84, 22, 17, 64, 76, 22, 90, 32, 20, 22, 27, 32, 61, 22, 27, 51, 73, 22, 0, 1, 11, 61, 22, 1, 32, 20, 22, 1, 51, 21, 22, 67, 64, 61, 22, 67, 32, 61, 22, 67, 32, 9, 22, 4, 54, 73, 22, 23, 32, 9, 22, 8, 32, 62, 15, 91, 11, 26, 22, 13, 32, 62, 22, 70, 5, 46

[0, 1, 2, 147, 67, 54, 21, 47, 67, 38, 20, 52, 23, 64, 76, 31, 72, 32, 40, 59, 91, 54, 180, 31, 13, 35, 131, 15, 70, 54, 92, 15, 16, 54, 75, 31, 78, 35, 33, 31, 17, 35, 62, 28, 90, 51, 9, 42, 74, 32, 76, 31, 0, 67, 51, 20, 31, 23, 51, 76, 52, 72, 64, 20, 31, 91, 35, 152, 34, 91, 51, 21, 31, 13, 71, 93, 41, 70, 53, 29, 31, 70, 51, 73, 31, 16, 49, 46, 31, 78, 53, 26, 52, 78, 64, 21, 31, 17, 35, 9, 31, 90, 51, 76, 52, 74, 35, 40, 52, 0, 67, 64, 20, 34, 72, 54, 76, 15, 10, 64, 20, 15, 91, 64, 76, 77, 74, 51, 20, 31, 0, 67, 51, 21, 31, 23, 64, 73, 31, 72, 64, 9, 31, 91, 51, 73, 31, 70, 64, 21, 31, 78, 51, 73, 15, 17, 51, 9, 15, 90, 64, 62, 42, 0, 67, 51, 20, 34, 72, 35, 180, 42, 72, 51, 76, 31, 91, 45, 92, 19, 91, 64, 40, 34, 13, 45, 111, 19, 70, 35, 75, 31, 16, 49, 33, 59, 78, 51, 76, 31, 90, 65, 6, 42, 90, 51, 40, 31, 74, 32, 97, 15, 0, 1, 51, 40, 15, 67, 32, 76, 31, 23, 51, 20, 15, 8, 51, 21, 15, 72, 51, 20, 31, 91, 64, 76, 31, 70, 51, 40, 31, 78, 53, 62, 7, 78, 51, 76, 15, 17, 49, 26, 1

[0, 1, 43, 124, 8, 43, 124, 10, 18, 62, 34, 13, 43, 124, 16, 5, 48, 34, 17, 43, 124, 27, 5, 26, 34, 0, 1, 43, 124, 8, 43, 124, 13, 43, 124, 17, 43, 124, 17, 35, 48, 41, 0, 1, 43, 124, 1, 43, 124, 8, 43, 124, 8, 43, 124, 13, 43, 124, 13, 43, 124, 16, 35, 62, 31, 17, 43, 124, 17, 43, 124, 27, 35, 26, 34, 0, 1, 43, 124, 8, 43, 124, 13, 43, 124, 16, 35, 62, 31, 17, 43, 124, 17, 35, 62, 31, 27, 35, 26, 34, 0, 1, 43, 124, 8, 43, 124, 8, 35, 62, 28, 13, 43, 124, 16, 35, 46, 34, 17, 43, 124, 27, 35, 75, 108, 0, 1, 43, 124, 8, 43, 124, 10, 35, 76, 41, 13, 43, 124, 16, 54, 76, 34, 17, 43, 124, 27, 35, 20, 31, 0, 1, 43, 124, 1, 53, 73, 7, 4, 35, 73, 31, 8, 43, 124, 8, 35, 62, 31, 10, 35, 73, 66, 13, 43, 124, 16, 35, 61, 34, 17, 43, 124, 27, 53, 73, 56, 0, 1, 43, 124, 8, 43, 124, 13, 43, 124, 13, 43, 124, 13, 43, 124, 16, 35, 73, 34, 17, 43, 124, 17, 35, 73, 66, 0, 1, 43, 124, 8, 43, 124, 10, 35, 73, 31, 13, 43, 124, 16, 35, 58, 31, 17, 43, 124, 17, 43, 124, 0, 1, 43, 124, 8, 43, 124, 8, 43, 124, 

[0, 1, 2, 139, 4, 11, 97, 22, 23, 64, 97, 22, 8, 14, 120, 52, 91, 32, 97, 31, 16, 5, 97, 22, 78, 5, 118, 22, 17, 5, 120, 47, 0, 4, 14, 97, 22, 23, 11, 40, 22, 8, 38, 118, 52, 91, 14, 40, 31, 16, 38, 97, 22, 78, 14, 40, 22, 17, 5, 20, 66, 0, 4, 14, 40, 22, 23, 14, 97, 22, 8, 14, 97, 19, 91, 51, 40, 22, 13, 25, 97, 7, 16, 14, 97, 22, 78, 25, 120, 22, 17, 18, 97, 66, 0, 4, 11, 40, 22, 23, 14, 120, 22, 8, 5, 106, 7, 10, 11, 120, 22, 91, 5, 118, 7, 70, 11, 97, 22, 16, 11, 118, 7, 17, 25, 120, 66, 0, 8, 25, 20, 47, 8, 25, 120, 47, 16, 18, 76, 22, 16, 18, 106, 22, 78, 5, 20, 22, 78, 5, 120, 22, 17, 25, 73, 66, 17, 25, 97, 66, 0, 4, 5, 73, 22, 4, 5, 97, 22, 23, 25, 9, 22, 23, 25, 40, 22, 8, 25, 73, 52, 8, 25, 97, 52, 91, 14, 21, 22, 91, 14, 118, 22, 13, 25, 73, 7, 13, 25, 97, 7, 16, 18, 9, 22, 16, 18, 40, 22, 78, 14, 73, 22, 78, 14, 97, 22, 17, 18, 26, 66, 17, 18, 20, 66, 0, 8, 18, 73, 66, 8, 18, 97, 66, 16, 25, 9, 15, 16, 25, 40, 15, 78, 11, 73, 22, 78, 11, 97, 22, 17, 5, 9, 19, 17, 5, 40, 19

[0, 1, 2, 115, 74, 11, 96, 7, 0, 67, 11, 96, 31, 23, 38, 86, 31, 72, 11, 96, 31, 91, 11, 86, 7, 70, 24, 12, 22, 16, 11, 57, 47, 74, 32, 12, 15, 0, 1, 54, 57, 15, 67, 32, 12, 22, 23, 38, 86, 41, 72, 64, 12, 15, 10, 64, 12, 34, 70, 5, 86, 31, 78, 11, 12, 34, 74, 32, 96, 34, 0, 67, 32, 57, 7, 23, 54, 12, 31, 72, 51, 57, 42, 0, 67, 11, 96, 52, 72, 11, 50, 31, 91, 11, 12, 41, 70, 32, 86, 31, 78, 38, 96, 31, 90, 38, 86, 31, 74, 32, 12, 31, 0, 67, 32, 86, 7, 23, 32, 96, 15, 8, 54, 86, 15, 72, 32, 96, 7, 91, 11, 57, 15, 13, 38, 86, 7, 70, 38, 40, 28, 74, 32, 98, 87, 0, 67, 38, 82, 31, 23, 38, 86, 7, 23, 38, 96, 15, 72, 11, 86, 31, 91, 11, 86, 31, 70, 38, 96, 31, 78, 38, 86, 85, 0, 67, 32, 96, 31, 23, 32, 86, 31, 72, 11, 12, 31, 91, 32, 57, 31, 70, 32, 6, 59]
35
[0, 1, 2, 202, 8, 51, 55, 7, 8, 32, 50, 52, 10, 51, 55, 19, 10, 35, 57, 7, 91, 32, 96, 19, 16, 64, 55, 7, 16, 51, 57, 31, 17, 64, 29, 7, 17, 64, 50, 52, 27, 64, 29, 19, 27, 64, 9, 7, 74, 32, 96, 34, 0, 4, 64, 29, 31, 4, 64, 9, 7, 8, 64,

[0, 1, 2, 89, 67, 14, 21, 41, 8, 2, 89, 72, 14, 12, 15, 10, 11, 12, 15, 91, 5, 12, 28, 13, 2, 89, 17, 2, 89, 90, 5, 57, 31, 74, 25, 12, 7, 0, 1, 2, 89, 67, 25, 96, 34, 8, 2, 89, 72, 25, 98, 7, 91, 18, 82, 28, 13, 2, 89, 17, 2, 89, 90, 11, 12, 15, 27, 32, 86, 22, 74, 5, 39, 52, 0, 1, 2, 89, 4, 24, 39, 52, 8, 2, 89, 72, 14, 98, 22, 10, 51, 39, 22, 91, 32, 98, 34, 13, 2, 89, 78, 14, 30, 7, 17, 2, 89, 90, 14, 12, 7, 74, 5, 57, 37, 0, 1, 2, 89, 8, 2, 89, 91, 14, 57, 31, 13, 2, 89, 70, 51, 6, 22, 16, 24, 57, 15, 78, 24, 12, 7, 17, 2, 89, 90, 51, 50, 15, 27, 45, 33, 22, 74, 32, 50, 77, 0, 1, 2, 89, 8, 2, 89, 13, 2, 89, 17, 2, 89, 90, 51, 30, 15, 27, 11, 33, 15, 74, 11, 50, 15, 0, 1, 2, 89, 1, 38, 50, 15, 67, 11, 50, 15, 4, 24, 50, 22, 23, 11, 6, 15, 8, 2, 89, 8, 24, 57, 31, 10, 14, 57, 15, 91, 14, 12, 41, 13, 2, 89, 78, 25, 12, 31, 17, 2, 89, 90, 18, 50, 80, 0, 1, 2, 89, 8, 2, 89, 13, 2, 89, 78, 18, 58, 7, 17, 2, 89, 90, 25, 61, 42, 0, 1, 2, 89, 23, 5, 12, 15, 8, 2, 89, 8, 11, 61, 15, 72, 14,

[0, 1, 43, 161, 1, 38, 9, 31, 4, 45, 57, 15, 8, 63, 9, 31, 10, 64, 50, 52, 13, 38, 33, 125, 0, 1, 64, 9, 31, 4, 53, 57, 52, 8, 54, 9, 52, 10, 35, 6, 52, 13, 54, 26, 85, 0, 1, 11, 57, 31, 4, 32, 12, 52, 8, 38, 57, 52, 10, 35, 9, 31, 13, 38, 6, 85, 0, 1, 32, 9, 31, 4, 35, 57, 52, 8, 51, 9, 31, 10, 53, 6, 7, 13, 32, 6, 36, 17, 32, 30, 19, 0, 1, 38, 30, 125, 10, 11, 50, 42, 0, 1, 38, 33, 42, 8, 11, 62, 41, 13, 11, 21, 66, 17, 11, 12, 41, 0, 1, 38, 12, 66, 8, 38, 73, 41, 13, 32, 73, 19, 17, 11, 9, 41, 0, 1, 38, 9, 34, 8, 64, 6, 66, 13, 64, 62, 19, 17, 11, 48, 41, 0, 1, 38, 26, 15, 4, 11, 26, 68, 4, 32, 6, 31, 8, 32, 62, 31, 10, 32, 73, 34, 16, 32, 62, 31, 78, 201, 26, 34, 17, 32, 33, 31, 27, 32, 9, 34, 0, 4, 32, 33, 31, 8, 32, 30, 31, 10, 32, 6, 34, 16, 32, 33, 31, 17, 32, 29, 31, 27, 32, 75, 31, 0, 4, 32, 6, 31, 8, 32, 62, 31, 10, 32, 73, 34, 16, 32, 62, 31, 17, 32, 33, 31, 27, 32, 9, 34, 0, 4, 32, 33, 31, 8, 32, 30, 31, 10, 32, 6, 34, 16, 32, 33, 31, 17, 32, 29, 31, 90, 32, 75, 31]
52
[0,

[0, 1, 43, 170, 67, 38, 73, 47, 91, 38, 61, 7, 70, 11, 86, 77, 0, 67, 11, 61, 41, 72, 11, 61, 7, 10, 11, 73, 22, 91, 64, 61, 22, 13, 35, 73, 15, 70, 32, 58, 36, 0, 67, 32, 61, 7, 23, 38, 61, 41, 10, 38, 73, 7, 13, 11, 58, 52, 17, 24, 48, 52, 27, 14, 58, 7, 0, 67, 14, 62, 112, 0, 67, 38, 73, 41, 72, 38, 73, 7, 91, 38, 61, 7, 70, 38, 86, 37, 0, 67, 24, 84, 52, 8, 14, 84, 52, 91, 24, 76, 7, 70, 38, 86, 37, 0, 67, 24, 61, 19, 72, 11, 61, 7, 91, 11, 86, 7, 70, 38, 61, 15, 78, 11, 86, 15, 90, 14, 62, 7, 74, 38, 58, 7, 0, 67, 38, 73, 136, 23, 32, 62, 31, 72, 64, 46, 31, 91, 51, 48, 31, 13, 64, 75, 37]
62
[0, 1, 2, 89, 23, 35, 20, 34, 91, 35, 82, 52, 70, 49, 39, 15, 16, 35, 40, 15, 78, 35, 82, 77, 0, 23, 53, 20, 34, 91, 35, 82, 52, 70, 35, 40, 15, 16, 53, 39, 15, 78, 53, 82, 37, 0, 23, 35, 39, 34, 91, 35, 118, 31, 70, 53, 120, 52, 78, 35, 118, 34, 74, 53, 39, 66, 0, 23, 54, 118, 31, 72, 53, 39, 31, 91, 53, 118, 31, 70, 49, 120, 31, 78, 35, 39, 77, 0, 23, 35, 82, 34, 91, 35, 20, 31, 70, 35, 82,

[0, 1, 43, 105, 91, 5, 57, 22, 13, 14, 61, 22, 70, 25, 86, 22, 16, 5, 96, 22, 78, 5, 84, 22, 90, 5, 86, 15, 74, 25, 61, 7, 0, 67, 5, 96, 28, 78, 5, 96, 7, 90, 5, 21, 15, 74, 5, 57, 15, 0, 67, 25, 86, 66, 91, 5, 57, 22, 13, 5, 58, 22, 70, 5, 86, 22, 16, 5, 96, 22, 78, 14, 84, 22, 90, 5, 86, 7, 74, 5, 58, 15, 0, 67, 5, 96, 19, 72, 5, 57, 15, 91, 5, 96, 15, 70, 25, 86, 77, 0, 4, 5, 57, 22, 23, 5, 57, 7, 72, 25, 58, 22, 91, 5, 48, 22, 91, 14, 50, 15, 91, 5, 58, 22, 13, 5, 50, 22, 13, 14, 57, 22, 70, 18, 48, 31, 70, 25, 86, 22, 16, 5, 86, 22, 78, 5, 96, 22, 90, 5, 86, 15, 74, 25, 61, 7, 0, 67, 5, 96, 28, 78, 5, 96, 7, 90, 5, 21, 15, 74, 5, 58, 15, 0, 67, 25, 86, 66, 91, 5, 58, 22, 13, 5, 57, 22, 70, 5, 86, 22, 16, 5, 96, 22, 78, 14, 84, 22, 90, 5, 86, 7, 74, 5, 57, 15, 0, 67, 5, 96, 19, 72, 5, 58, 15, 91, 5, 96, 15, 13, 25, 86, 77]
72
[0, 1, 43, 105, 91, 25, 9, 31, 70, 5, 9, 103, 0, 23, 25, 9, 52, 72, 25, 73, 34, 70, 25, 21, 37, 0, 72, 18, 21, 34, 70, 25, 20, 34, 70, 18, 82, 34, 70, 25, 76,

[0, 1, 2, 155, 72, 11, 33, 7, 10, 24, 26, 59, 74, 18, 73, 7, 0, 67, 11, 62, 31, 23, 32, 26, 31, 72, 25, 33, 37, 78, 18, 9, 7, 90, 24, 6, 34, 0, 67, 5, 62, 41, 23, 25, 6, 15, 8, 38, 73, 42, 70, 11, 30, 37, 0, 67, 5, 26, 31, 23, 38, 62, 15, 8, 25, 9, 66, 70, 32, 75, 80, 74, 5, 73, 31, 0, 67, 18, 21, 31, 23, 79, 76, 52, 72, 79, 26, 103, 72, 11, 20, 34, 74, 79, 9, 31, 0, 67, 5, 62, 31, 23, 24, 26, 31, 72, 18, 33, 59, 70, 11, 73, 31, 78, 5, 40, 31, 90, 79, 6, 77, 90, 79, 82, 85, 0, 72, 25, 6, 52, 91, 64, 50, 7, 70, 24, 30, 31, 78, 5, 33, 41, 74, 11, 111, 7, 0, 67, 11, 55, 31, 23, 79, 30, 68, 74, 11, 55, 52, 0, 67, 5, 30, 7, 23, 38, 50, 31, 72, 79, 33, 22, 72, 79, 33, 22, 91, 79, 33, 22, 70, 18, 75, 22, 78, 25, 33, 22, 90, 25, 33, 22, 74, 5, 33, 22, 0, 67, 25, 33, 22, 23, 79, 33, 22, 23, 38, 62, 22, 8, 35, 62, 22, 72, 5, 26, 37, 78, 18, 48, 102, 0, 13, 35, 111, 15, 70, 24, 33, 7, 16, 35, 111, 15, 16, 64, 29, 80]
82
[0, 1, 2, 194, 1, 2, 194, 8, 38, 100, 31, 10, 38, 104, 31, 13, 2, 194, 13, 11

[0, 1, 2, 115, 4, 18, 92, 31, 72, 25, 122, 7, 13, 25, 29, 52, 17, 5, 29, 15, 27, 5, 92, 15, 0, 1, 38, 122, 15, 4, 14, 55, 15, 8, 11, 29, 15, 10, 5, 100, 15, 13, 11, 29, 15, 16, 18, 92, 52, 27, 18, 92, 52, 0, 4, 18, 92, 31, 72, 14, 122, 31, 13, 25, 29, 52, 17, 18, 92, 15, 27, 54, 122, 15, 0, 1, 25, 111, 15, 4, 24, 93, 15, 8, 14, 127, 15, 10, 24, 152, 52, 16, 5, 92, 19, 17, 45, 21, 15, 27, 18, 29, 52, 27, 11, 21, 22, 0, 1, 11, 21, 15, 4, 18, 92, 31, 4, 14, 21, 15, 8, 25, 61, 22, 72, 25, 122, 7, 10, 11, 12, 15, 13, 25, 29, 52, 13, 24, 73, 15, 17, 5, 29, 15, 17, 45, 21, 15, 27, 5, 92, 15, 27, 11, 21, 22, 0, 1, 38, 122, 15, 1, 11, 21, 15, 4, 14, 55, 15, 4, 14, 21, 15, 8, 11, 29, 15, 8, 25, 61, 22, 10, 5, 100, 15, 10, 11, 12, 15, 13, 11, 29, 15, 13, 24, 73, 15, 16, 18, 92, 52, 17, 45, 21, 15, 27, 18, 92, 52, 27, 11, 21, 22, 0, 1, 11, 21, 15, 4, 18, 92, 31, 4, 14, 21, 15, 8, 25, 61, 22, 72, 14, 122, 31, 10, 11, 12, 15, 13, 25, 29, 52, 13, 24, 73, 15, 17, 14, 93, 52, 17, 11, 9, 15, 27, 14, 29,

[0, 1, 2, 188, 1, 5, 62, 22, 67, 5, 58, 22, 67, 5, 9, 22, 4, 5, 58, 22, 4, 5, 62, 22, 23, 5, 9, 22, 23, 5, 58, 22, 8, 5, 9, 22, 8, 5, 57, 22, 72, 5, 73, 22, 72, 5, 12, 22, 10, 5, 61, 22, 10, 5, 21, 22, 91, 5, 86, 22, 91, 5, 20, 34, 78, 5, 20, 52, 27, 18, 61, 22, 74, 25, 20, 7, 0, 67, 14, 76, 22, 4, 18, 20, 22, 23, 18, 61, 7, 72, 25, 73, 7, 91, 5, 20, 7, 70, 5, 76, 22, 16, 14, 20, 22, 78, 18, 61, 7, 90, 14, 73, 7, 74, 14, 58, 36, 0, 23, 24, 107, 7, 72, 14, 97, 7, 91, 14, 26, 77, 91, 24, 84, 34, 90, 14, 97, 15, 74, 14, 48, 37, 74, 25, 76, 7, 0, 67, 5, 20, 7, 23, 5, 61, 41, 91, 5, 20, 22, 13, 32, 20, 22, 70, 25, 86, 7, 78, 5, 61, 7, 90, 25, 73, 7, 74, 24, 20, 22, 0, 1, 24, 20, 22, 67, 18, 86, 22, 4, 24, 86, 22, 23, 5, 61, 7, 72, 5, 73, 7]
101
[0, 1, 2, 188, 1, 25, 61, 66, 8, 2, 188, 10, 25, 86, 22, 91, 24, 61, 22, 13, 2, 188, 13, 14, 61, 7, 16, 14, 86, 15, 17, 2, 188, 17, 25, 96, 52, 0, 1, 2, 188, 1, 14, 84, 7, 4, 24, 98, 15, 8, 2, 188, 8, 24, 98, 7, 10, 11, 84, 7, 13, 32, 98, 19, 17, 11,

[0, 1, 2, 183, 67, 64, 6, 19, 4, 53, 50, 31, 23, 35, 9, 52, 10, 51, 6, 136, 91, 51, 6, 41, 74, 51, 9, 7, 0, 67, 35, 6, 31, 23, 35, 6, 7, 8, 54, 9, 47, 13, 54, 33, 15, 70, 54, 50, 31, 78, 54, 33, 7, 90, 35, 6, 15, 74, 35, 9, 66, 0, 67, 53, 33, 52, 23, 53, 33, 52, 72, 54, 50, 7, 91, 53, 9, 7, 70, 54, 6, 41, 78, 54, 33, 52, 17, 54, 9, 31, 74, 54, 50, 15, 0, 1, 54, 21, 41, 67, 35, 12, 19, 23, 54, 21, 31, 72, 53, 33, 31, 91, 53, 30, 7, 70, 53, 33, 31, 78, 64, 6, 15, 17, 54, 9, 15, 90, 51, 33, 31, 74, 51, 30, 31, 0, 67, 51, 33, 7, 4, 54, 33, 7, 23, 54, 9, 15, 8, 54, 20, 7, 72, 51, 82, 15, 91, 35, 33, 31, 70, 54, 21, 7, 78, 51, 20, 15, 17, 51, 82, 52, 74, 32, 40, 15, 0, 67, 54, 21, 15, 4, 54, 12, 31, 23, 53, 82, 31, 10, 49, 9, 7, 91, 35, 73, 31, 70, 51, 9, 41, 78, 54, 12, 22, 17, 54, 21, 47, 74, 54, 12, 15, 0, 1, 54, 21, 15, 67, 51, 20, 52, 23, 51, 82, 31, 72, 54, 82, 31, 91, 35, 21, 85, 78, 54, 12, 7, 17, 54, 20, 7, 90, 51, 21, 7, 74, 54, 21, 7, 0, 91, 51, 40, 52, 70, 32, 40, 19, 74, 54, 21,

In [59]:
import mido

# For every test sample, combine the intro, the predicted solo and the outro
# written under `dissimilar_interpolation` into one three-track file
# `merged_predict<i>.mid`.
# A part is moved later in time by raising the `time` of message [1] of its
# track 1 by the summed `note_on` times of the parts placed before it.
# NOTE(review): only `note_on` message times are summed; presumably that equals
# the full part length for files written by utils.write_midi -- confirm.
for i in range(len(test_intro)):
    intro = mido.MidiFile(dissimilar_interpolation + "/intro/" + '/intro' + str(i) + '.mid')
    outro = mido.MidiFile(dissimilar_interpolation + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(dissimilar_interpolation + "/predict/" +'/predict' + str(i) + '.mid')

    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = 0  # not read anywhere in this cell
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    # Remember the outro's own start time before it gets overwritten below.
    original_outro_time = outro.tracks[1][1].time

    predicted_outro_start = original_outro_time + total_predict_time + total_intro_time
    print(predicted_outro_start)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = predicted_outro_start
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(dissimilar_interpolation + '/merged_predict' + str(i) + '.mid')

33720
33720
34200
34200
23520
23520
44280
44280
27600
27600
24060
24060
23760
23760
33420
33420
31560
31560
45780
45780
26280
26280
30540
30540
30780
30780
25500
25500
30000
30000
54420
54420
18660
18660
30420
30420
41940
41940
49440
49440
35520
35520
27660
27660
36840
36840
30540
30540
29220
29220
31260
31260
36960
36960
28320
28320
41160
41160
14940
14940
30780
30780
27660
27660
22020
22020
30480
30480
46560
46560
24180
24180
31500
31500
45960
45960
32460
32460
38880
38880
23040
23040
28200
28200
33720
33720
22800
22800
24840
24840
22920
22920
32220
32220
46800
46800
23760
23760
24420
24420
19740
19740
30060
30060
36720
36720
29280
29280
30960
30960
28620
28620
39900
39900
26760
26760
40140
40140
27420
27420
33600
33600
35520
35520
30240
30240
29280
29280
17940
17940
31200
31200
32940
32940
31800
31800
27000
27000
24960
24960
30240
30240
23760
23760
23820
23820
32580
32580
16200
16200
46740
46740
31080
31080
32400
32400
38220
38220
41280
41280
18540
18540
24720
24720
36960
36960
2214

In [None]:
class BeamSearchNode(object):
    """A single step of a beam-search hypothesis, linked to the step before it.

    Stores the four constructor arguments unchanged:
    ``prev_node`` (the preceding node, or ``None`` for the first node),
    ``wid`` (the token chosen at this step), ``logp`` (the accumulated score
    of the hypothesis up to and including this step) and ``length`` (the
    number of nodes in the hypothesis so far).
    """

    def __init__(self, prev_node, wid, logp, length):
        self.prev_node, self.wid = prev_node, wid
        self.logp, self.length = logp, length

    def eval(self):
        """Return the accumulated score divided by ``length - 1``.

        A small epsilon is added to the divisor so that a node with
        ``length == 1`` does not divide by zero.
        """
        normaliser = float(self.length - 1 + 1e-6)
        return self.logp / normaliser
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def _beam_token_ids(node):
    """Follow ``prev_node`` links back to the start node and return the token ids in decoding order."""
    token_ids = [node.wid.item()]
    while node.prev_node is not None:
        node = node.prev_node
        token_ids.append(node.wid.item())
    return token_ids[::-1]


def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    """Decode ``sentence`` with a best-first beam search and return the best hypothesis.

    Args:
        model: callable invoked as ``model(src, trg)`` with two ``LongTensor``
            columns of shape ``(length, 1)``. Its output is indexed as
            ``output[-1, 0]``, i.e. it is assumed to be
            ``(target_len, batch, vocab)`` like the output consumed by the
            greedy decoder in this notebook (``output.argmax(2)[-1, :]``).
        sentence: source tokens as one string separated by single spaces.
        german: source field; ``init_token``, ``eos_token`` and ``vocab.stoi``
            are used.
        english: target field; ``vocab.stoi`` and ``vocab.itos`` are used.
        device: device the input tensors are moved to.
        max_length: maximum number of tokens generated after ``<sos>``; a
            hypothesis that reaches it is kept as a (truncated) result instead
            of being expanded further.
        beam_width: number of candidates added per expansion and number of
            finished hypotheses that ends the search.
        max_dec_steps: budget on registered candidates before giving up.

    Returns:
        The best hypothesis as a list of target tokens. The leading ``<sos>``
        and, when it was generated, the trailing ``<eos>`` are included.

    Raises:
        ValueError: if ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    # Lower-case the whitespace-separated source tokens (the vocab is lower case)
    # and wrap them in the source field's start / end markers.
    tokens = [token.lower() for token in sentence.split(' ')]
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]

    # Source indices as a (src_len, 1) column.
    text_to_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    # The search starts from a single hypothesis holding only <sos>.
    node = BeamSearchNode(prev_node=None,
                          wid=torch.LongTensor([sos_token]).to(device),
                          logp=0,
                          length=1)

    # Min-heap keyed on the negated score; id() breaks ties so that the node
    # objects themselves are never compared.
    nodes = []
    heappush(nodes, (-node.eval(), id(node), node))
    end_nodes = []
    n_dec_steps = 0

    while nodes:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break

        # Fetch the best node
        score, _, n = heappop(nodes)

        reached_eos = n.prev_node is not None and n.wid.item() == eos_token
        # n.length counts <sos>, so "> max_length" means max_length generated tokens.
        hit_length_cap = n.length > max_length
        if reached_eos or hit_length_cap:
            end_nodes.append((score, id(n), n))
            # Stop once enough hypotheses have finished
            if len(end_nodes) >= beam_width:
                break
            continue

        trg_tensor = torch.LongTensor(_beam_token_ids(n)).unsqueeze(1).to(device)
        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        # The next-token distribution is the one predicted at the LAST target
        # position; reading position 0 would propose the same tokens for every
        # prefix. log_softmax turns raw scores into log-probabilities so they
        # can be summed along a hypothesis (it leaves values that already are
        # log-probabilities unchanged).
        step_log_probs = torch.log_softmax(output[-1, 0], dim=-1)
        n_candidates = min(beam_width, step_log_probs.size(-1))
        topk_log_prob, topk_indexes = torch.topk(step_log_probs, n_candidates)

        # Register the new top-k nodes
        for new_k in range(n_candidates):
            decoded_t = topk_indexes[new_k].view(1)  # (1)
            logp = topk_log_prob[new_k].item()  # float log probability
            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp + logp,
                                  length=n.length + 1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width

    # If nothing finished, fall back to the best open hypotheses (they are
    # truncated); never pop more entries than the heap holds.
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(min(beam_width, len(nodes)))]

    # Best (lowest negated score) first
    n_best_seq_list = [_beam_token_ids(n)
                       for score, _id, n in sorted(end_nodes, key=lambda x: x[0])]

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]

    # The start token is intentionally kept, matching the greedy decoder's output.
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    """Pickle ``vocab`` to the file at ``path`` (opened in binary write mode).

    The file is opened with a context manager so the handle is closed even
    when ``pickle.dump`` raises.
    """
    with open(path, 'wb') as output:
        pickle.dump(vocab, output)

In [None]:
# Persist the vocabulary of each field next to the model artefacts.
_vocab_exports = (
    (intro_field, vocab + '/intro_vocab.pkl'),
    (solo_field, vocab + '/solo_vocab.pkl'),
    (outro_field, vocab + '/outro_vocab.pkl'),
)
for _vocab_field, _vocab_path in _vocab_exports:
    save_vocab(_vocab_field.vocab, _vocab_path)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):
    """Greedily decode a target sequence for an already index-encoded source.

    ``sentence`` is handed directly to ``torch.LongTensor``; no tokenisation,
    vocabulary lookup or start/end markers are applied here. ``german`` is
    accepted for signature parity but is not used.

    Returns the decoded target tokens as a list that starts with ``<sos>`` and
    ends with ``<eos>`` when one was produced within ``max_length`` steps.
    """
    # Source indices as a column tensor on the requested device.
    src_column = torch.LongTensor(sentence).unsqueeze(1).to(device)

    decoded = [english.vocab.stoi["<sos>"]]

    # Exactly one id is appended per pass, so this runs at most max_length times.
    while len(decoded) <= max_length:
        trg_column = torch.LongTensor(decoded).unsqueeze(1).to(device)

        with torch.no_grad():
            scores = model(src_column, trg_column)

        # Highest-scoring entry at the last target position.
        next_id = scores.argmax(2)[-1, :].item()
        decoded.append(next_id)

        if next_id == english.vocab.stoi["<eos>"]:
            break

    # The start token is kept in the returned list.
    return [english.vocab.itos[token_id] for token_id in decoded]


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device, max_length=1200):
    """Compute corpus BLEU of greedy decodes against the reference solos.

    Args:
        data: iterable of examples whose ``intro`` and ``solo`` attributes are
            token sequences.
        model: passed through to ``bleu_translate_sentence``.
        german: source field; its ``vocab.stoi`` and start/end markers are used
            to index-encode the source the same way ``translate_sentence_beam``
            does.
        english: target field, passed through to ``bleu_translate_sentence``.
        device: device the decoder runs on.
        max_length: examples whose source or target has more tokens than this
            are skipped; it is also the decoding length cap.

    Returns:
        The value returned by ``bleu_score`` for the collected pairs.
    """
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        fields = vars(example)
        # Keep tokens as strings: predictions come back as vocab strings, so an
        # int-typed reference could never match them.
        src_tokens = [str(tok) for tok in fields["intro"]]
        trg_tokens = [str(tok) for tok in fields["solo"]]

        if len(trg_tokens) > max_length or len(src_tokens) > max_length:
            continue

        # Wrap the source in the field's markers (when it defines them) and map
        # it through the vocabulary, so the model receives vocab indices rather
        # than the raw token values.
        if german.init_token is not None:
            src_tokens = [german.init_token] + src_tokens
        if german.eos_token is not None:
            src_tokens = src_tokens + [german.eos_token]
        src_indices = [german.vocab.stoi[tok] for tok in src_tokens]

        prediction = bleu_translate_sentence(model, src_indices, german, english, device,
                                             max_length=max_length)
        # Strip the markers only where they are actually present; the decoder
        # may stop at the length cap without emitting <eos>.
        if prediction and prediction[0] == "<sos>":
            prediction = prediction[1:]
        if prediction and prediction[-1] == "<eos>":
            prediction = prediction[:-1]

        # bleu_score expects, per candidate, a LIST of reference token lists.
        targets.append([trg_tokens])
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while, so only a small slice is scored here
# NOTE(review): test[1:10] selects examples 1..9 and skips example 0 - confirm that is intended
score = bleu(test[1:10], model, intro_field, solo_field, device)
# bleu() returns the bleu_score result; it is reported here multiplied by 100
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
# Reload the saved training curves and draw train / validation loss on one set of axes.
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
_loss_ax = plt.gca()
for _loss_values, _curve_label in ((train_loss_list, 'Train'), (valid_loss_list, 'Valid')):
    _loss_ax.plot(global_steps_list, _loss_values, label=_curve_label)
_loss_ax.set_xlabel('Global Steps')
_loss_ax.set_ylabel('Loss')
_loss_ax.legend()
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns