In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
# Input dataset directory and the output tree of this experiment.
# Every output directory is a child of `folder`.
source_folder = "solo_generation_dataset_augmented_presplit"
folder = "dynamic_augmented_models/2enc_2nd"
destination_folder = f"{folder}/solo_generation_weights"
generated_outputs = f"{folder}/generated_samples"
dissimilar_interpolation = f"{folder}/interpolation"
vocab = f"{folder}/vocab"

In [2]:
# state = random.getstate()
# pickle.dump(state, open('./state.pkl', 'wb'))

import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor
# Restore a previously pickled state of Python's `random` module so that code
# drawing from `random` behaves the same from run to run.
# The file is opened in a `with` block so the handle is closed deterministically.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# state.pkl that was produced locally.
with open('./state.pkl', 'rb') as state_file:
    state = pickle.load(state_file)
random.setstate(state)

In [3]:
from pathlib import Path
# Create the whole output tree up front; existing directories are left alone.
for base_dir in (destination_folder, generated_outputs, dissimilar_interpolation, vocab):
    Path(base_dir).mkdir(parents=True, exist_ok=True)

# Per-category sub-folders for generated samples ...
for sub_dir in ("/intro", "/outro", "/solo", "/predict"):
    Path(generated_outputs + sub_dir).mkdir(parents=True, exist_ok=True)

# ... and for the interpolation experiments.
for sub_dir in ("/intro", "/outro", "/predict"):
    Path(dissimilar_interpolation + sub_dir).mkdir(parents=True, exist_ok=True)

In [4]:
# The pickle holds a pair that is unpacked as (event2word, word2event);
# word2event is later handed to utils.write_midi.
# The file is opened in a `with` block so the handle is closed deterministically.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# dictionary file from a trusted source.
with open('dictionary_augmented.pkl', 'rb') as dictionary_file:
    event2word, word2event = pickle.load(dictionary_file)

In [5]:
# Choose the compute device: prefer the second GPU ("cuda:1") when the host
# has one, otherwise the only GPU ("cuda:0"), otherwise the CPU. Asking for
# "cuda:1" unconditionally is unusable on single-GPU machines.
if torch.cuda.is_available():
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    dev = "cpu"
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [6]:
# Fields: six identically configured fields, one per CSV column. Values are
# lower-cased, wrapped in <sos>/<eos>, batched batch-first and returned
# together with their lengths.
(intro_field, intro_piano_field,
 outro_field, outro_piano_field,
 solo_field, solo_piano_field) = (
    Field(tokenize=None, lower=True, include_lengths=True, batch_first=True,
          init_token="<sos>", eos_token="<eos>")
    for _ in range(6)
)
fields = [
    ('intro', intro_field),
    ('intro_piano', intro_piano_field),
    ('outro', outro_field),
    ('outro_piano', outro_piano_field),
    ('solo', solo_field),
    ('solo_piano', solo_piano_field),
]

# TabularDataset: the three pre-split CSV files, header row skipped.
train, valid, test = TabularDataset.splits(
    path=source_folder,
    train='train_torchtext.csv',
    validation='val_torchtext.csv',
    test='test_torchtext.csv',
    format='CSV',
    fields=fields,
    skip_header=True,
)

# Iterators: one BucketIterator per split, keyed on intro length and sorted
# only within each batch.
BATCH_SIZE = 8
train_iter, valid_iter, test_iter = (
    BucketIterator(split, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                   device=device, sort=False, sort_within_batch=True)
    for split in (train, valid, test)
)

# Vocabulary: built from the training split only; the piano fields drop
# tokens seen fewer than three times.
for field, min_freq in (
    (intro_field, 1),
    (intro_piano_field, 3),
    (outro_field, 1),
    (outro_piano_field, 3),
    (solo_field, 1),
    (solo_piano_field, 3),
):
    field.build_vocab(train, min_freq=min_freq)

In [7]:
# Print the size of dimension 1 of the solo tensor for every test batch
# (the fields are batch-first, so this is the padded sequence length).
for batch_inputs, _ in test_iter:
    _, _, _, _, (solo, solo_len), _ = batch_inputs
    print(solo.size(1))

272
353
509
335
326
279
253
522
281
444
619
325
319
414


In [8]:
import os
# Ask CUDA for synchronous kernel launches (a debugging aid).
# NOTE(review): this variable is presumably read when CUDA is initialised; an
# earlier cell already iterates batches placed on `device`, so setting it here
# may be too late to take effect. Confirm, and consider moving it to the first cell.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Disable cuDNN for all subsequent torch operations.
torch.backends.cudnn.enabled=False

In [9]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    """Sequence-to-sequence model with two source inputs, built on ``kk.Transformer``.

    Each of the three token streams (``src``, ``src2``, ``trg``) has its own
    learned word embedding and learned positional embedding. The summed
    embeddings go through dropout, are handed to ``kk.Transformer``, and the
    result is projected onto the target vocabulary by ``fc_out``.

    Args:
        embedding_size: Width of every embedding and of the projection input.
        src_vocab_size: Number of rows of the first source's word embedding.
        src2_vocab_size: Number of rows of the second source's word embedding.
        trg_vocab_size: Number of rows of the target word embedding and the
            output width of ``fc_out``.
        src_pad_idx: Token id treated as padding when masking either source.
        num_heads: Forwarded to ``kk.Transformer``.
        num_encoder_layers: Forwarded to ``kk.Transformer``.
        num_decoder_layers: Forwarded to ``kk.Transformer``.
        forward_expansion: Forwarded to ``kk.Transformer``.
        dropout: Forwarded to ``kk.Transformer`` and used for embedding dropout.
        max_len: Number of positions in each positional embedding table.
        device: Device on which position indices and masks are placed.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        src2_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        # The second source is indexed with its own vocabulary, so the table
        # must be sized by src2_vocab_size. Sizing it with src_vocab_size makes
        # lookups fail whenever the second vocabulary is the larger one.
        # NOTE: checkpoints saved with a differently sized table will not load
        # if the two vocabulary sizes differ.
        self.src2_word_embedding = nn.Embedding(src2_vocab_size, embedding_size)
        self.src2_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = kk.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a boolean (N, src_len) mask that is True at padding tokens.

        Args:
            src: Token ids laid out as (src_len, N).
        """
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def _positions(self, seq_length, batch_size):
        """Return a (seq_length, batch_size) tensor whose row i is filled with i."""
        return (
            torch.arange(0, seq_length)
            .unsqueeze(1)
            .expand(seq_length, batch_size)
            .to(self.device)
        )

    def forward(self, src, src2, trg):
        """Run both sources and the (shifted) target through the model.

        Args:
            src: First source token ids, shape (src_len, N).
            src2: Second source token ids, shape (src2_len, N).
            trg: Target token ids fed to the decoder, shape (trg_len, N).

        Returns:
            The ``fc_out`` projection of the transformer output; the last
            dimension has size ``trg_vocab_size`` (the leading dimensions are
            whatever ``kk.Transformer`` returns, presumably (trg_len, N)).

        Raises:
            ValueError: If any sequence is longer than the positional tables.
        """
        src_seq_length, N = src.shape
        src2_seq_length, N = src2.shape
        trg_seq_length, N = trg.shape

        # Fail with a readable message instead of an out-of-range embedding
        # lookup (which on CUDA surfaces as an opaque device-side assert).
        max_positions = self.src_position_embedding.num_embeddings
        longest = max(src_seq_length, src2_seq_length, trg_seq_length)
        if longest > max_positions:
            raise ValueError(
                f"sequence length {longest} exceeds max_len {max_positions}"
            )

        src_positions = self._positions(src_seq_length, N)
        src2_positions = self._positions(src2_seq_length, N)
        trg_positions = self._positions(trg_seq_length, N)

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        ).to(self.device)
        embed_src2 = self.dropout(
            (self.src2_word_embedding(src2) + self.src2_position_embedding(src2_positions))
        ).to(self.device)
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        ).to(self.device)

        # Both sources are masked with the same padding index.
        src_padding_mask = self.make_src_mask(src)
        src2_padding_mask = self.make_src_mask(src2)
        # Mask produced by kk.Transformer.generate_square_subsequent_mask,
        # passed as the decoder's tgt_mask.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_src2,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            src2_key_padding_mask=src2_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [10]:
# Vocabulary sizes are taken from the fields built on the training split.
src_vocab_size, src2_vocab_size, trg_vocab_size = (
    len(intro_field.vocab),
    len(outro_field.vocab),
    len(solo_field.vocab),
)

# Architecture hyper-parameters.
embedding_size, num_heads = 512, 8
num_encoder_layers = num_decoder_layers = 3
dropout = 0.10
max_len = 1200
forward_expansion = 4
src_pad_idx = 1  # presumably the "<pad>" index of the torchtext vocabularies; confirm

# Arguments are passed by keyword so they cannot be bound to the wrong parameter.
model = Transformer(
    embedding_size=embedding_size,
    src_vocab_size=src_vocab_size,
    src2_vocab_size=src2_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_len=max_len,
    device=device,
).to(device)

In [11]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter of `m` in place.

    Parameters whose name contains ``'weight'`` are drawn from N(0, 0.01);
    every other parameter is set to zero.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)


# Module.apply calls init_weights on `model` and on each of its submodules.
model.apply(init_weights)

# Adam over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=2e-4) #non augmented 3e-4


def count_parameters(model: nn.Module):
    """Return the number of scalar elements in parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the trainable-parameter count with thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    """Write `state` to the single best-so-far checkpoint file.

    The target is always ``destination_folder + '/metrics.pt'``, so each call
    overwrites the previous one. `nth` and `filename` are accepted but not
    used by this function.
    """
    print("=> Saving checkpoint")
    best_path = destination_folder + '/metrics.pt'
    torch.save(state, best_path)

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    """Write `state` to ``<destination_folder>/<nth><filename>``."""
    print("=> Saving checkpoint")
    target_path = "".join((destination_folder, "/", str(nth), filename))
    torch.save(state, target_path)


def load_checkpoint(checkpoint, model, optimizer):
    """Load model and optimizer state from a checkpoint dictionary.

    `checkpoint` must provide the keys ``"model_state_dict"`` and
    ``"optimizer_state_dict"``; the model is restored first, then the optimizer.
    """
    print("=> Loading checkpoint")
    restore_plan = (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    )
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])

The model has 14,997,275 trainable parameters


In [12]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [13]:
# Target index excluded from the loss (presumably the "<pad>" token id of the
# torchtext vocabularies; confirm against solo_field.vocab.stoi).
PAD_IDX = 1

# Cross-entropy over target tokens; positions equal to PAD_IDX are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [14]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):
    """Run one training epoch and return the mean per-batch loss.

    Each batch yields six (tensor, lengths) pairs; only the intro, outro and
    solo tensors are used. The decoder is fed the solo tokens without the
    last position and scored against the solo tokens without the first.
    Gradients are clipped to norm `clip` before every optimizer step.
    Tensors are moved to the module-level `device`.
    """
    model.train()

    running_loss = 0

    for batch_inputs, _ in iterator:
        (intro, _), (_, _), (outro, _), (_, _), (solo, _), (_, _) = batch_inputs

        # Batches arrive batch-first; the model unpacks shapes as (seq_len, N).
        src = intro.transpose(1, 0).to(device)
        src2 = outro.transpose(1, 0).to(device)
        trg = solo.transpose(1, 0).to(device)

        optimizer.zero_grad()
        logits = model(src, src2, trg[:-1, :])

        # Flatten predictions and shifted targets for the criterion.
        flat_logits = logits.view(-1, logits.shape[-1])
        flat_targets = trg[1:].reshape(-1)
        loss = criterion(flat_logits, flat_targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.cpu().detach().item()

    return running_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):
    """Return the mean per-batch loss over `iterator` without updating the model.

    The model is switched to eval mode and gradients are disabled. As in
    training, the decoder still receives the ground-truth solo tokens
    (all but the last position) and is scored against the shifted targets.
    Tensors are moved to the module-level `device`.
    """
    model.eval()

    running_loss = 0

    with torch.no_grad():
        for batch_inputs, _ in iterator:
            (intro, _), (_, _), (outro, _), (_, _), (solo, _), (_, _) = batch_inputs

            # Batches arrive batch-first; the model unpacks shapes as (seq_len, N).
            src = intro.transpose(1, 0).to(device)
            src2 = outro.transpose(1, 0).to(device)
            trg = solo.transpose(1, 0).to(device)

            logits = model(src, src2, trg[:-1, :])

            flat_logits = logits.view(-1, logits.shape[-1])
            flat_targets = trg[1:].reshape(-1)

            batch_loss = criterion(flat_logits, flat_targets)
            running_loss += batch_loss.cpu().detach().item()

    return running_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    """Split the time between two timestamps (in seconds) into whole (minutes, seconds)."""
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed - whole_minutes * 60)
    return (whole_minutes, leftover_seconds)



In [15]:
def translate_sentence(model, sentence, sentence2, intro, outro, solo, device, max_length=1200):
    """Greedily decode a target sequence from two space-separated token strings.

    Args:
        model: Module called as ``model(src, src2, trg)`` with (seq_len, 1)
            index tensors; the argmax over dimension 2 of its output at the
            last position selects the next token.
        sentence: First source as a space-separated string of tokens.
        sentence2: Second source as a space-separated string of tokens.
        intro: Field for `sentence` (provides init/eos tokens and ``vocab.stoi``).
        outro: Field for `sentence2`.
        solo: Field for the target (provides ``vocab.stoi`` and ``vocab.itos``).
        device: Device on which the index tensors are placed.
        max_length: Maximum number of tokens generated after ``<sos>``.

    Returns:
        The decoded target tokens as strings. The list starts with ``"<sos>"``
        and ends with ``"<eos>"`` if that token was produced within
        `max_length` steps; special tokens are NOT stripped.

    The model is put in eval mode while decoding (so dropout cannot perturb
    the greedy choice) and its previous train/eval mode is restored afterwards.
    """
    def _encode(text, field):
        # Lower-case the tokens, wrap them in the field's start/end markers and
        # map them to a (seq_len, 1) column of vocabulary indices.
        tokens = [token.lower() for token in text.split(' ')]
        tokens.insert(0, field.init_token)
        tokens.append(field.eos_token)
        indices = [field.vocab.stoi[token] for token in tokens]
        return torch.LongTensor(indices).unsqueeze(1).to(device)

    sentence_tensor = _encode(sentence, intro)
    sentence2_tensor = _encode(sentence2, outro)

    # Looked up once instead of on every decoding step.
    eos_idx = solo.vocab.stoi["<eos>"]
    outputs = [solo.vocab.stoi["<sos>"]]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_length):
                trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
                output = model(sentence_tensor, sentence2_tensor, trg_tensor)

                # Most likely token at the last decoded position.
                best_guess = output.argmax(2)[-1, :].item()
                outputs.append(best_guess)

                if best_guess == eos_idx:
                    break
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    return [solo.vocab.itos[idx] for idx in outputs]


In [16]:
# Read the validation split again as raw strings (one token string per cell).
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro, val_solo, val_outro = (
    df_intro[column].values for column in ('intro', 'solo', 'outro')
)

# One record per row holding the three token strings.
val_data = [
    {'intro': intro_text, 'solo': solo_text, 'outro': outro_text}
    for intro_text, solo_text, outro_text in zip(val_intro, val_solo, val_outro)
]
print(len(val_intro))

112


In [17]:
def check_mode_collapse(model, n_samples=3, max_length=1200):
    """Count consecutive validation samples that decode to the same sequence.

    The first `n_samples` validation examples (module-level `val_intro` /
    `val_outro`) are decoded with `translate_sentence`; special tokens are
    dropped and the rest converted to ints. Each decoded sequence is printed.

    Args:
        model: Model passed through to `translate_sentence`.
        n_samples: How many validation examples to decode (capped at the
            number available).
        max_length: Decoding limit, also used as the longest accepted source.

    Returns:
        The number of decoded sequences that are identical to the previously
        decoded one.
    """
    special_tokens = ('<pad>', '<sos>', '<eos>', '<unk>')

    def _too_long(text):
        # translate_sentence adds <sos> and <eos>, hence the +2. Assumes the
        # model's positional tables hold `max_length` positions; confirm.
        return len(text.split(' ')) + 2 > max_length

    count = 0
    previous = None
    for i in range(min(n_samples, len(val_intro))):
        intro = val_intro[i]
        outro = val_outro[i]
        # Skip samples that are too long, judged by the sample's own token
        # count rather than by the size of the validation set.
        if _too_long(intro) or _too_long(outro):
            continue

        translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=max_length)
        translated_sentence = [int(x) for x in translated_sentence if x not in special_tokens]
        print(translated_sentence)

        # Compare with the last sequence actually decoded, so a skipped sample
        # can never cause an out-of-range lookup.
        if previous is not None and previous == translated_sentence:
            count += 1
        previous = translated_sentence
    return count


In [18]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1
best_valid_loss = float('inf')
#torch.autograd.set_detect_anomaly(True)
#model = nn.DataParallel(model, device_ids=[0,1]).to(device)
for epoch in range(S_EPOCH, N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    # Snapshot the CURRENT state every epoch. Reusing the dictionary from the
    # last improvement would label periodic/final checkpoints with a stale
    # validation loss, and would be undefined if no epoch ever improved.
    checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}

    # Best-so-far checkpoint (always written to the same file).
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        save_best_checkpoint(checkpoint,N_EPOCHS)

    # Periodic checkpoints: zero-based epochs 0, 19, 20, 39, 40, ...
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)

    # Every 25 epochs, decode a few validation samples and warn when more than
    # one consecutive pair is identical.
    if (epoch+1) % 25 ==0:
        if check_mode_collapse(model) > 1:
            print("model is mode collapsing")

# Final state after the last epoch, saved as "<N_EPOCHS>_checkpoint.pt".
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 1m 0s
	Train Loss: 4.249 | Train PPL:  70.016
	 Val. Loss: 3.355 |  Val. PPL:  28.652
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 1m 1s
	Train Loss: 2.996 | Train PPL:  20.004
	 Val. Loss: 2.851 |  Val. PPL:  17.305
=> Saving checkpoint
Epoch: 03 | Time: 1m 2s
	Train Loss: 2.691 | Train PPL:  14.746
	 Val. Loss: 2.619 |  Val. PPL:  13.728
=> Saving checkpoint
Epoch: 04 | Time: 1m 2s
	Train Loss: 2.474 | Train PPL:  11.876
	 Val. Loss: 2.399 |  Val. PPL:  11.008
=> Saving checkpoint
Epoch: 05 | Time: 1m 2s
	Train Loss: 2.283 | Train PPL:   9.802
	 Val. Loss: 2.279 |  Val. PPL:   9.768
=> Saving checkpoint
Epoch: 06 | Time: 1m 2s
	Train Loss: 2.157 | Train PPL:   8.649
	 Val. Loss: 2.195 |  Val. PPL:   8.980
=> Saving checkpoint
Epoch: 07 | Time: 1m 3s
	Train Loss: 2.050 | Train PPL:   7.767
	 Val. Loss: 2.111 |  Val. PPL:   8.259
=> Saving checkpoint
Epoch: 08 | Time: 1m 3s
	Train Loss: 1.964 | Train PPL:   7.130
	 Val. Loss: 2.065 |  Val. PPL:   7.889


Epoch: 48 | Time: 1m 3s
	Train Loss: 0.928 | Train PPL:   2.529
	 Val. Loss: 3.221 |  Val. PPL:  25.050
Epoch: 49 | Time: 1m 3s
	Train Loss: 0.915 | Train PPL:   2.497
	 Val. Loss: 3.215 |  Val. PPL:  24.915
Epoch: 50 | Time: 1m 3s
	Train Loss: 0.900 | Train PPL:   2.459
	 Val. Loss: 3.273 |  Val. PPL:  26.384
[0, 1, 2, 95, 67, 79, 96, 34, 72, 79, 98, 41, 70, 79, 98, 34, 90, 79, 98, 31, 74, 25, 98, 31, 0, 67, 79, 98, 34, 0, 67, 5, 98, 28, 91, 79, 98, 31, 70, 79, 98, 31, 78, 32, 98, 31, 90, 18, 98, 31, 74, 79, 98, 31, 0, 67, 38, 96, 56, 78, 18, 96, 28, 0, 67, 38, 96, 15, 4, 32, 96, 83]
[0, 1, 43, 117, 4, 123, 12, 34, 4, 123, 39, 34, 10, 45, 12, 31, 10, 45, 39, 31, 13, 45, 6, 31, 13, 45, 82, 31, 16, 63, 12, 31, 16, 63, 39, 31, 17, 71, 12, 31, 17, 71, 39, 31, 27, 49, 82, 34, 27, 49, 140, 34, 0, 4, 45, 12, 34, 4, 45, 39, 34, 10, 53, 12, 31, 10, 53, 39, 31, 13, 63, 6, 31, 13, 63, 82, 31, 16, 71, 6, 31, 16, 71, 39, 31, 17, 45, 82, 31, 17, 45, 39, 31, 27, 45, 12, 34, 27, 45, 39, 34, 0, 4, 63,

Epoch: 86 | Time: 1m 3s
	Train Loss: 0.596 | Train PPL:   1.815
	 Val. Loss: 4.258 |  Val. PPL:  70.668
Epoch: 87 | Time: 1m 3s
	Train Loss: 0.588 | Train PPL:   1.801
	 Val. Loss: 4.215 |  Val. PPL:  67.693
Epoch: 88 | Time: 1m 3s
	Train Loss: 0.587 | Train PPL:   1.799
	 Val. Loss: 4.362 |  Val. PPL:  78.384
Epoch: 89 | Time: 1m 3s
	Train Loss: 0.576 | Train PPL:   1.780
	 Val. Loss: 4.343 |  Val. PPL:  76.968
Epoch: 90 | Time: 1m 3s
	Train Loss: 0.575 | Train PPL:   1.777
	 Val. Loss: 4.349 |  Val. PPL:  77.410
Epoch: 91 | Time: 1m 3s
	Train Loss: 0.569 | Train PPL:   1.766
	 Val. Loss: 4.335 |  Val. PPL:  76.317
Epoch: 92 | Time: 1m 3s
	Train Loss: 0.567 | Train PPL:   1.764
	 Val. Loss: 4.431 |  Val. PPL:  84.001
Epoch: 93 | Time: 1m 3s
	Train Loss: 0.561 | Train PPL:   1.752
	 Val. Loss: 4.410 |  Val. PPL:  82.273
Epoch: 94 | Time: 1m 3s
	Train Loss: 0.560 | Train PPL:   1.751
	 Val. Loss: 4.477 |  Val. PPL:  87.977
Epoch: 95 | Time: 1m 4s
	Train Loss: 0.555 | Train PPL:   1.742


[0, 1, 2, 162, 1, 49, 6, 80, 16, 53, 9, 15, 78, 45, 12, 15, 17, 54, 21, 7, 27, 49, 21, 31, 0, 1, 38, 96, 66, 10, 53, 21, 15, 91, 63, 21, 15, 13, 53, 12, 34, 90, 63, 21, 15, 27, 45, 21, 15, 74, 54, 96, 22, 0, 1, 32, 82, 7, 4, 49, 12, 36, 74, 45, 82, 22, 74, 5, 96, 22, 74, 51, 21, 22, 74, 53, 21, 22, 74, 51, 9, 22, 0, 1, 35, 6, 22, 1, 64, 33, 31, 4, 64, 33, 31, 8, 54, 12, 31, 72, 53, 9, 102]
Epoch: 126 | Time: 1m 3s
	Train Loss: 0.450 | Train PPL:   1.568
	 Val. Loss: 4.996 |  Val. PPL: 147.813
Epoch: 127 | Time: 1m 3s
	Train Loss: 0.448 | Train PPL:   1.565
	 Val. Loss: 4.915 |  Val. PPL: 136.308
Epoch: 128 | Time: 1m 2s
	Train Loss: 0.450 | Train PPL:   1.568
	 Val. Loss: 4.951 |  Val. PPL: 141.246
Epoch: 129 | Time: 1m 3s
	Train Loss: 0.446 | Train PPL:   1.562
	 Val. Loss: 4.961 |  Val. PPL: 142.793
Epoch: 130 | Time: 1m 3s
	Train Loss: 0.442 | Train PPL:   1.556
	 Val. Loss: 5.034 |  Val. PPL: 153.560
Epoch: 131 | Time: 1m 3s
	Train Loss: 0.437 | Train PPL:   1.548
	 Val. Loss: 5.02

[0, 1, 2, 162, 1, 49, 57, 80, 16, 53, 12, 15, 78, 45, 21, 15, 17, 54, 21, 7, 27, 49, 21, 31, 0, 1, 38, 96, 66, 10, 53, 21, 15, 91, 63, 21, 15, 13, 53, 57, 34, 90, 63, 12, 15, 27, 45, 21, 15, 74, 54, 21, 22, 0, 1, 32, 96, 7, 4, 49, 57, 36, 74, 45, 96, 22, 74, 5, 21, 22, 74, 51, 12, 22, 74, 53, 57, 22, 74, 51, 6, 22, 0, 1, 35, 50, 22, 1, 64, 33, 31, 4, 64, 50, 31, 8, 54, 57, 31, 72, 53, 6, 102]
Epoch: 176 | Time: 1m 3s
	Train Loss: 0.344 | Train PPL:   1.411
	 Val. Loss: 5.451 |  Val. PPL: 233.010
Epoch: 177 | Time: 1m 2s
	Train Loss: 0.344 | Train PPL:   1.410
	 Val. Loss: 5.435 |  Val. PPL: 229.405
Epoch: 178 | Time: 1m 3s
	Train Loss: 0.342 | Train PPL:   1.407
	 Val. Loss: 5.455 |  Val. PPL: 233.873
Epoch: 179 | Time: 1m 3s
	Train Loss: 0.342 | Train PPL:   1.408
	 Val. Loss: 5.584 |  Val. PPL: 266.049
Epoch: 180 | Time: 1m 3s
	Train Loss: 0.340 | Train PPL:   1.404
	 Val. Loss: 5.592 |  Val. PPL: 268.332
=> Saving checkpoint
Epoch: 181 | Time: 1m 4s
	Train Loss: 0.338 | Train PPL:  

[0, 1, 2, 162, 1, 49, 50, 80, 16, 53, 6, 15, 78, 45, 9, 15, 17, 54, 57, 7, 27, 49, 9, 31, 0, 1, 38, 12, 66, 10, 53, 57, 15, 91, 63, 9, 15, 13, 53, 50, 34, 90, 63, 6, 15, 27, 45, 9, 15, 74, 54, 57, 22, 0, 1, 32, 12, 7, 4, 49, 50, 36, 74, 45, 12, 22, 74, 5, 57, 22, 74, 51, 9, 22, 74, 53, 6, 22, 74, 51, 50, 22, 0, 1, 35, 33, 22, 1, 64, 55, 31, 4, 64, 30, 31, 8, 54, 6, 31, 72, 53, 50, 102]
Epoch: 226 | Time: 1m 3s
	Train Loss: 0.270 | Train PPL:   1.310
	 Val. Loss: 5.943 |  Val. PPL: 381.129
Epoch: 227 | Time: 1m 3s
	Train Loss: 0.267 | Train PPL:   1.307
	 Val. Loss: 5.973 |  Val. PPL: 392.609
Epoch: 228 | Time: 1m 3s
	Train Loss: 0.265 | Train PPL:   1.303
	 Val. Loss: 6.101 |  Val. PPL: 446.260
Epoch: 229 | Time: 1m 3s
	Train Loss: 0.265 | Train PPL:   1.303
	 Val. Loss: 5.876 |  Val. PPL: 356.543
Epoch: 230 | Time: 1m 3s
	Train Loss: 0.263 | Train PPL:   1.301
	 Val. Loss: 6.063 |  Val. PPL: 429.634
Epoch: 231 | Time: 1m 4s
	Train Loss: 0.262 | Train PPL:   1.300
	 Val. Loss: 6.023 | 

[0, 1, 43, 183, 8, 43, 183, 13, 43, 183, 70, 51, 84, 15, 16, 51, 84, 15, 78, 51, 86, 15, 17, 43, 183, 90, 64, 86, 15, 27, 64, 61, 15, 74, 64, 86, 15, 0, 1, 2, 183, 67, 64, 12, 42, 8, 2, 183, 13, 43, 183, 70, 51, 86, 15, 16, 51, 86, 15, 78, 51, 86, 7, 17, 2, 183, 90, 64, 48, 15, 27, 51, 58, 15, 74, 51, 86, 15, 0, 1, 2, 183, 67, 64, 57, 47, 8, 2, 183, 10, 51, 86, 15, 91, 64, 57, 15, 13, 2, 183, 70, 51, 57, 15, 16, 51, 57, 15, 78, 51, 57, 15, 17, 2, 183, 90, 51, 58, 15, 27, 51, 57, 15, 74, 64, 86, 15, 0, 1, 2, 183, 67, 51, 58, 15, 4, 51, 58, 15, 23, 64, 58, 15, 8, 2, 183, 72, 51, 57, 7, 10, 64, 12, 15, 91, 64, 84, 15, 13, 2, 183, 70, 51, 82, 15, 16, 51, 86, 15, 78, 35, 82, 15, 17, 2, 183, 90, 51, 98, 15, 27, 64, 84, 15, 74, 64, 82, 15, 0, 1, 2, 183, 67, 54, 58, 7, 67, 64, 86, 108, 4, 53, 55, 15, 23, 35, 6, 7, 8, 2, 183, 72, 35, 58, 7, 10, 64, 57, 7, 91, 54, 62, 7, 13, 2, 183, 70, 35, 58, 7, 16, 54, 55, 7, 78, 54, 6, 7, 17, 2, 183, 90, 35, 58, 19, 27, 51, 57, 7, 74, 35, 6, 7, 0, 1, 2, 183,

Epoch: 301 | Time: 1m 3s
	Train Loss: 0.164 | Train PPL:   1.178
	 Val. Loss: 6.471 |  Val. PPL: 646.335
=> Saving checkpoint
Epoch: 302 | Time: 1m 3s
	Train Loss: 0.162 | Train PPL:   1.176
	 Val. Loss: 6.658 |  Val. PPL: 779.092
Epoch: 303 | Time: 1m 3s
	Train Loss: 0.163 | Train PPL:   1.177
	 Val. Loss: 6.575 |  Val. PPL: 716.811
Epoch: 304 | Time: 1m 3s
	Train Loss: 0.162 | Train PPL:   1.175
	 Val. Loss: 6.582 |  Val. PPL: 721.799
Epoch: 305 | Time: 1m 3s
	Train Loss: 0.159 | Train PPL:   1.172
	 Val. Loss: 6.668 |  Val. PPL: 786.855
Epoch: 306 | Time: 1m 3s
	Train Loss: 0.158 | Train PPL:   1.171
	 Val. Loss: 6.539 |  Val. PPL: 691.443
Epoch: 307 | Time: 1m 4s
	Train Loss: 0.157 | Train PPL:   1.170
	 Val. Loss: 6.570 |  Val. PPL: 713.260
Epoch: 308 | Time: 1m 3s
	Train Loss: 0.156 | Train PPL:   1.169
	 Val. Loss: 6.590 |  Val. PPL: 728.011
Epoch: 309 | Time: 1m 4s
	Train Loss: 0.154 | Train PPL:   1.167
	 Val. Loss: 6.627 |  Val. PPL: 755.463
Epoch: 310 | Time: 1m 3s
	Train Lo

[0, 1, 2, 162, 1, 49, 50, 80, 16, 53, 6, 15, 78, 45, 9, 15, 17, 54, 57, 7, 27, 49, 9, 31, 0, 1, 38, 12, 66, 10, 53, 57, 15, 91, 63, 6, 15, 13, 53, 50, 34, 90, 63, 6, 15, 27, 45, 9, 15, 74, 54, 57, 22, 0, 1, 32, 12, 7, 4, 49, 48, 36, 74, 45, 57, 22, 74, 5, 57, 22, 74, 51, 6, 22, 74, 53, 50, 22, 74, 51, 26, 22, 0, 1, 35, 33, 22, 1, 64, 55, 31, 4, 64, 46, 31, 8, 54, 6, 31, 72, 53, 50, 102]
Epoch: 351 | Time: 1m 3s
	Train Loss: 0.106 | Train PPL:   1.112
	 Val. Loss: 6.825 |  Val. PPL: 920.878
Epoch: 352 | Time: 1m 4s
	Train Loss: 0.106 | Train PPL:   1.112
	 Val. Loss: 6.916 |  Val. PPL: 1007.851
Epoch: 353 | Time: 1m 4s
	Train Loss: 0.103 | Train PPL:   1.108
	 Val. Loss: 6.892 |  Val. PPL: 984.624
Epoch: 354 | Time: 1m 3s
	Train Loss: 0.104 | Train PPL:   1.109
	 Val. Loss: 6.819 |  Val. PPL: 914.867
Epoch: 355 | Time: 1m 4s
	Train Loss: 0.102 | Train PPL:   1.108
	 Val. Loss: 6.863 |  Val. PPL: 956.504
Epoch: 356 | Time: 1m 4s
	Train Loss: 0.101 | Train PPL:   1.106
	 Val. Loss: 6.934 

[0, 1, 43, 116, 23, 38, 58, 7, 72, 38, 57, 15, 91, 38, 12, 7, 70, 64, 86, 15, 78, 64, 12, 15, 90, 24, 57, 34, 0, 67, 64, 58, 7, 23, 32, 50, 7, 72, 32, 48, 15, 91, 38, 50, 52, 78, 24, 50, 47, 0, 23, 32, 58, 31, 72, 32, 57, 31, 91, 38, 12, 31, 70, 32, 86, 31, 78, 32, 96, 31, 90, 38, 57, 34, 0, 67, 32, 58, 7, 23, 32, 50, 7, 72, 32, 48, 31, 91, 32, 50, 52, 78, 11, 50, 37, 0, 23, 38, 58, 7, 72, 38, 57, 15, 91, 38, 12, 7, 70, 64, 86, 15, 78, 64, 12, 15, 90, 24, 57, 34, 0, 67, 64, 58, 7, 23, 32, 50, 7, 72, 32, 48, 15, 91, 38, 50, 52, 78, 24, 50, 47, 0, 23, 32, 58, 31, 72, 32, 12, 31, 91, 38, 12, 31, 70, 32, 58, 31, 78, 32, 57, 34]
[0, 1, 2, 162, 1, 32, 6, 47, 10, 51, 9, 22, 91, 38, 57, 22, 13, 18, 57, 77, 27, 49, 12, 22, 74, 63, 57, 22, 0, 1, 38, 9, 77, 10, 35, 58, 22, 91, 45, 9, 22, 13, 54, 6, 37, 0, 1, 35, 6, 47, 10, 53, 9, 22, 91, 54, 57, 22, 13, 14, 57, 28, 78, 45, 9, 22, 17, 53, 6, 31, 27, 53, 50, 7, 0, 1, 51, 6, 47, 10, 49, 50, 22, 91, 53, 6, 15, 13, 45, 50, 22, 13, 45, 33, 47, 0, 1, 45

Epoch: 449 | Time: 1m 3s
	Train Loss: 0.056 | Train PPL:   1.058
	 Val. Loss: 7.460 |  Val. PPL: 1737.520
Epoch: 450 | Time: 1m 3s
	Train Loss: 0.058 | Train PPL:   1.060
	 Val. Loss: 7.334 |  Val. PPL: 1531.493
[0, 1, 2, 95, 8, 2, 95, 8, 25, 48, 7, 10, 79, 62, 7, 13, 2, 95, 13, 11, 48, 7, 16, 79, 57, 31, 17, 2, 95, 17, 79, 58, 34, 0, 1, 2, 95, 8, 2, 95, 8, 24, 46, 7, 10, 79, 50, 7, 13, 2, 95, 13, 11, 46, 7, 16, 25, 58, 31, 17, 2, 95, 17, 25, 62, 34, 0, 1, 2, 95, 4, 79, 73, 7, 8, 2, 95, 8, 25, 57, 7, 10, 25, 62, 7, 13, 2, 95, 13, 24, 48, 7, 16, 79, 62, 19, 17, 2, 95, 27, 24, 62, 41, 0, 1, 2, 95, 4, 18, 58, 19, 8, 2, 95, 10, 79, 61, 7, 13, 2, 95, 13, 79, 57, 36, 17, 2, 95, 0, 1, 2, 95, 8, 2, 95, 8, 38, 48, 15, 10, 79, 62, 7, 13, 2, 155, 13, 11, 48, 15, 16, 79, 57, 31, 17, 2, 155, 17, 79, 58, 41, 0, 1, 2, 155, 8, 2, 155, 8, 24, 46, 15, 10, 79, 50, 7, 13, 2, 155, 13, 11, 46, 15, 16, 18, 58, 31, 17, 2, 155, 17, 79, 62, 47, 0, 1, 2, 155, 4, 79, 73, 7, 8, 2, 155, 8, 79, 57, 7, 10, 79, 62, 31

[0, 1, 2, 162, 1, 51, 50, 15, 67, 51, 12, 31, 23, 51, 9, 28, 23, 51, 12, 36, 10, 49, 12, 22, 91, 63, 12, 31, 70, 63, 12, 22, 70, 63, 21, 85, 0, 10, 53, 12, 31, 91, 45, 12, 15, 13, 64, 21, 15, 16, 35, 96, 15, 78, 54, 96, 37, 0, 23, 45, 12, 19, 10, 53, 50, 15, 91, 53, 12, 31, 70, 63, 50, 7, 70, 35, 12, 31, 78, 35, 50, 31, 90, 54, 12, 52, 74, 35, 178, 34, 0, 67, 45, 6, 36, 67, 54, 9, 52, 72, 51, 12, 7, 91, 45, 12, 31, 91, 54, 57, 7, 70, 35, 12, 7, 70, 54, 9, 19, 90, 63, 12, 7, 74, 45, 96, 19, 0, 67, 45, 12, 31, 23, 53, 9, 31, 72, 35, 58, 31, 91, 49, 50, 7, 70, 54, 6, 31, 70, 54, 58, 19, 78, 35, 50, 7, 90, 35, 6, 28, 27, 35, 57, 31, 74, 35, 6, 31, 0, 67, 35, 12, 31, 23, 35, 9, 7, 72, 63, 12, 19, 91, 49, 12, 31, 91, 54, 21, 31, 70, 35, 57, 31, 78, 35, 12, 31, 78, 51, 57, 34, 74, 35, 30, 34, 0, 67, 54, 33, 15, 23, 54, 12, 19, 91, 54, 21, 19, 70, 54, 12, 15, 78, 51, 33, 19, 90, 53, 57, 19, 74, 35, 12, 31, 74, 35, 50, 52]
Epoch: 476 | Time: 1m 3s
	Train Loss: 0.050 | Train PPL:   1.052
	 Val. 

In [None]:
## checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
# Fresh model with the same hyper-parameters as `model`. Arguments are passed
# by keyword and include src2_vocab_size: leaving it out of a positional call
# shifts every later argument by one and raises a TypeError.
best_model = Transformer(
    embedding_size=embedding_size,
    src_vocab_size=src_vocab_size,
    src2_vocab_size=src2_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    forward_expansion=forward_expansion,
    dropout=dropout,
    max_len=max_len,
    device=device,
).to(device)
# NOTE(review): this rebinds the notebook-wide `optimizer` to best_model's
# parameters with a different learning rate; confirm that is intended.
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [None]:
# Restore the end-of-training checkpoint into `model` and `optimizer`
# (tensors are mapped onto `device` while loading).
checkpoint_path = destination_folder + '/500_checkpoint.pt'
state = torch.load(checkpoint_path, map_location=device)
load_checkpoint(state, model, optimizer)

In [19]:
# Perplexity on the test split: exp of the mean per-batch loss.
test_loss = evaluate(model, test_iter, criterion)
test_perplexity = math.exp(test_loss)
print(test_perplexity)

2429.232212428656


In [None]:
# Point sample generation at a separate output tree and create its sub-folders.
generated_outputs = folder +  "/generated_samples_600epochs"
for sub_dir in ("/intro", "/outro", "/solo", "/predict"):
    Path(generated_outputs + sub_dir).mkdir(parents=True, exist_ok=True)

In [20]:
# Read the test split as raw strings (one token string per cell).
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro, test_solo, test_outro = (
    df_intro[column].values for column in ('intro', 'solo', 'outro')
)

# One record per row holding the three token strings.
test_data = [
    {'intro': intro_text, 'solo': solo_text, 'outro': outro_text}
    for intro_text, solo_text, outro_text in zip(test_intro, test_solo, test_outro)
]
print(len(test_intro))

112


In [21]:
# For every test example: decode a solo from (intro, outro) and write the
# intro, reference solo, outro and prediction as MIDI files.
for i, (intro, solo, outro) in enumerate(zip(test_intro, test_solo, test_outro)):
    # Token strings -> integer ids for the MIDI writer.
    list_intro, list_solo, list_outro = (
        [int(token) for token in text.split(' ')] for text in (intro, solo, outro)
    )

    translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)
    # Drop special tokens before converting the prediction to ids.
    translated_sentence = [
        int(x) for x in translated_sentence
        if x not in ('<pad>', '<sos>', '<eos>', '<unk>')
    ]
    print(translated_sentence)

    for stem, token_ids in (
        ("intro", list_intro),
        ("solo", list_solo),
        ("outro", list_outro),
        ("predict", translated_sentence),
    ):
        # Path layout (including the doubled slash) is kept exactly as before.
        utils.write_midi(token_ids, word2event, generated_outputs + "/" + stem + "/" + "/" + stem + str(i) + ".mid")
    print(i)
        


[0, 1, 43, 162, 23, 32, 9, 85, 23, 32, 20, 85, 8, 43, 162, 13, 43, 162, 78, 11, 26, 103, 78, 11, 82, 103, 17, 43, 162, 0, 1, 43, 162, 8, 43, 162, 72, 11, 29, 31, 72, 11, 21, 31, 91, 11, 111, 31, 91, 11, 26, 31, 13, 43, 162, 70, 24, 29, 34, 70, 24, 9, 34, 17, 43, 162, 90, 64, 26, 19, 90, 64, 82, 19, 0, 1, 43, 162, 1, 64, 62, 19, 1, 64, 40, 19, 23, 64, 26, 135, 23, 64, 82, 135, 8, 43, 162, 13, 43, 162, 17, 43, 162, 0, 1, 43, 162, 67, 32, 26, 15, 67, 32, 20, 15, 4, 24, 26, 15, 4, 24, 20, 15, 23, 38, 6, 37, 23, 38, 82, 37, 8, 43, 162, 13, 43, 162, 78, 14, 33, 85, 78, 14, 21, 85, 17, 43, 162, 0, 1, 43, 162, 23, 24, 46, 85, 23, 24, 61, 85, 78, 64, 26, 31, 78, 64, 20, 31, 90, 32, 46, 85, 90, 32, 61, 85, 0, 72, 64, 33, 31, 72, 64, 20, 31, 91, 54, 61, 31, 70, 53, 73, 15, 16, 51, 9, 15, 78, 54, 73, 34, 74, 64, 9, 31, 0, 67, 32, 73, 133]
0
[0, 1, 2, 126, 8, 2, 126, 91, 64, 73, 85, 13, 2, 126, 78, 54, 26, 31, 17, 2, 126, 90, 64, 73, 15, 27, 51, 73, 15, 74, 54, 26, 15, 74, 64, 73, 15, 0, 1, 2, 126,

[0, 1, 2, 142, 8, 2, 142, 13, 2, 142, 70, 35, 96, 15, 16, 64, 50, 22, 78, 64, 96, 15, 17, 2, 142, 17, 2, 142, 27, 64, 50, 15, 27, 64, 6, 15, 74, 32, 12, 15, 0, 1, 64, 50, 15, 1, 32, 12, 15, 67, 51, 57, 15, 4, 64, 50, 15, 23, 35, 12, 15, 8, 2, 142, 8, 64, 55, 15, 72, 64, 30, 15, 10, 64, 50, 15, 91, 35, 48, 15, 13, 2, 142, 13, 51, 55, 7, 16, 54, 46, 15, 78, 64, 30, 15, 17, 2, 142, 17, 51, 48, 7, 27, 64, 50, 15, 74, 49, 30, 22, 0, 1, 2, 142, 1, 51, 48, 15, 67, 53, 50, 15, 4, 64, 48, 22, 4, 64, 50, 15, 23, 64, 30, 22, 8, 2, 142, 8, 51, 50, 7, 72, 54, 111, 7, 10, 54, 75, 7, 91, 54, 30, 15, 13, 2, 142, 13, 51, 100, 22, 70, 51, 100, 15, 16, 51, 55, 15, 78, 35, 30, 15, 17, 2, 142, 17, 64, 30, 15, 27, 51, 55, 7, 0, 1, 2, 142, 1, 35, 48, 7, 67, 51, 55, 7, 4, 35, 46, 7, 23, 35, 55, 15, 8, 2, 142, 8, 53, 48, 15, 72, 51, 55, 15, 10, 35, 48, 7, 91, 54, 55, 15, 13, 2, 142, 13, 54, 30, 15, 70, 35, 33, 15, 16, 54, 48, 15, 78, 11, 61, 15, 17, 2, 142, 17, 51, 12, 7, 90, 64, 50, 15, 27, 51, 48, 15, 74, 64

[0, 1, 43, 206, 27, 14, 104, 22, 74, 5, 100, 15, 0, 1, 25, 29, 7, 4, 11, 75, 31, 8, 25, 46, 7, 10, 14, 73, 7, 13, 11, 26, 52, 78, 5, 9, 52, 27, 5, 26, 42, 0, 10, 18, 9, 22, 91, 24, 73, 22, 13, 25, 9, 31, 78, 5, 73, 52, 27, 25, 62, 7, 0, 1, 25, 21, 31, 23, 14, 73, 31, 10, 5, 20, 31, 13, 5, 21, 47, 27, 14, 73, 22, 74, 25, 9, 22, 0, 1, 5, 26, 37, 78, 5, 62, 22, 17, 14, 61, 7, 27, 11, 62, 15, 74, 11, 61, 31, 0, 4, 32, 62, 15, 23, 38, 73, 19, 10, 14, 61, 7, 13, 14, 61, 31, 78, 14, 76, 52, 27, 5, 86, 31, 0, 1, 25, 76, 66, 10, 25, 61, 22, 91, 5, 76, 22, 13, 25, 76, 59, 0, 4, 25, 76, 7, 8, 14, 86, 15, 72, 24, 61, 7]
18
[0, 1, 43, 206, 8, 32, 6, 19, 13, 38, 6, 7, 16, 32, 62, 15, 78, 64, 6, 22, 17, 32, 33, 34, 0, 1, 32, 33, 31, 4, 32, 33, 31, 8, 32, 9, 15, 8, 64, 12, 15, 10, 64, 6, 15, 10, 64, 73, 15, 13, 64, 9, 15, 13, 51, 12, 7, 16, 51, 73, 7, 16, 64, 21, 7, 17, 64, 9, 31, 17, 64, 12, 31, 27, 54, 6, 7, 27, 64, 73, 7, 0, 1, 51, 62, 7, 1, 64, 9, 7, 4, 51, 26, 7, 4, 51, 6, 7, 8, 64, 33, 15, 8, 64

[0, 1, 2, 113, 8, 2, 113, 8, 38, 86, 31, 91, 14, 61, 31, 13, 2, 113, 70, 51, 61, 80, 17, 2, 113, 17, 53, 98, 15, 90, 45, 84, 15, 27, 53, 96, 15, 74, 63, 76, 15, 0, 1, 2, 113, 1, 53, 98, 15, 67, 63, 84, 15, 4, 51, 48, 7, 4, 71, 96, 15, 23, 45, 50, 22, 23, 53, 76, 15, 8, 2, 113, 8, 51, 62, 31, 8, 35, 98, 22, 91, 35, 50, 31, 13, 2, 113, 16, 38, 48, 80, 17, 2, 113, 17, 63, 98, 7, 90, 63, 86, 15, 27, 63, 96, 15, 74, 49, 76, 22, 0, 1, 2, 113, 1, 54, 98, 15, 67, 63, 86, 15, 4, 63, 96, 15, 23, 63, 76, 22, 8, 2, 113, 8, 64, 46, 15, 8, 51, 98, 7, 72, 45, 30, 15, 10, 53, 46, 22, 91, 11, 48, 15, 13, 2, 113, 13, 53, 46, 15, 70, 45, 48, 22, 16, 11, 50, 15, 78, 63, 48, 15, 17, 2, 113, 17, 53, 50, 15, 90, 5, 62, 15, 27, 49, 50, 15, 74, 49, 62, 22, 0, 1, 2, 113, 1, 63, 58, 22, 67, 49, 62, 15, 4, 54, 58, 15, 23, 51, 57, 15, 8, 2, 113, 8, 53, 58, 15, 72, 45, 57, 15, 10, 35, 12, 22, 91, 54, 57, 15, 13, 2, 113, 13, 45, 12, 15, 70, 51, 61, 22, 16, 63, 12, 15, 78, 35, 61, 15, 17, 2, 113, 17, 51, 96, 77, 0, 8

[0, 1, 2, 3, 17, 53, 12, 31, 27, 51, 86, 34, 0, 4, 45, 12, 31, 8, 45, 57, 31, 10, 49, 12, 31, 13, 35, 50, 56, 0, 4, 51, 86, 31, 8, 35, 57, 7, 10, 35, 12, 31, 13, 14, 46, 42, 0, 8, 11, 58, 31, 10, 32, 57, 31, 13, 51, 50, 42, 27, 35, 58, 7, 0, 1, 38, 57, 34, 8, 14, 61, 19, 13, 51, 12, 59, 0, 1, 53, 46, 41, 4, 45, 50, 7, 8, 35, 57, 52, 91, 49, 58, 15, 13, 54, 50, 66, 0, 4, 49, 12, 31, 8, 32, 86, 31, 10, 63, 12, 31, 13, 53, 57, 42, 27, 45, 48, 56, 0, 1, 53, 57, 47, 10, 35, 12, 34, 16, 210, 46, 137]
36
[0, 1, 43, 110, 78, 123, 46, 31, 90, 63, 48, 7, 27, 71, 50, 37, 0, 72, 71, 100, 7, 10, 71, 58, 36, 90, 45, 26, 15, 27, 63, 61, 52, 0, 67, 71, 73, 31, 23, 63, 9, 19, 10, 71, 58, 31, 13, 45, 62, 7, 70, 71, 26, 31, 78, 63, 62, 52, 27, 49, 62, 19, 0, 67, 45, 58, 94, 78, 71, 73, 19, 90, 45, 26, 136, 0, 90, 65, 100, 15, 27, 123, 46, 36, 0, 8, 65, 48, 15, 72, 123, 55, 15]
37
[0, 1, 43, 202, 78, 51, 122, 34, 0, 1, 51, 30, 125, 17, 51, 50, 34, 0, 1, 51, 6, 207, 0, 13, 54, 55, 34, 17, 64, 30, 34, 0, 1,

[0, 1, 43, 208, 91, 53, 46, 31, 70, 5, 62, 28, 74, 64, 62, 28, 0, 72, 53, 48, 52, 91, 49, 62, 7, 70, 32, 58, 135, 0, 72, 11, 62, 41, 70, 38, 50, 28, 74, 11, 55, 219]
45
[0, 1, 2, 114, 67, 64, 48, 102, 8, 2, 114, 13, 2, 114, 70, 32, 62, 19, 78, 32, 62, 15, 78, 32, 62, 15, 17, 2, 114, 90, 51, 48, 15, 27, 35, 62, 42, 0, 1, 2, 114, 67, 64, 58, 15, 4, 51, 48, 15, 23, 51, 48, 47, 70, 51, 46, 60, 70, 51, 48, 15, 70, 51, 48, 15, 16, 51, 86, 15, 78, 54, 46, 28, 0, 67, 54, 48, 31, 67, 64, 61, 85, 8, 2, 114, 78, 51, 58, 15, 78, 51, 57, 15, 17, 2, 114, 90, 64, 6, 15, 27, 51, 58, 15, 74, 49, 62, 15, 0, 1, 2, 114, 67, 32, 73, 15, 67, 32, 86, 31, 67, 32, 76, 15, 67, 51, 62, 15, 4, 51, 48, 15, 23, 35, 50, 28, 70, 51, 86, 15, 78, 35, 58, 15, 78, 51, 84, 15, 17, 2, 114, 67, 64, 48, 15, 4, 64, 48, 15, 23, 54, 46, 28, 8, 51, 86, 34, 91, 35, 76, 34, 70, 38, 76, 28, 70, 51, 76, 34, 70, 32, 86, 15, 16, 35, 48, 31, 78, 64, 86, 31, 78, 64, 76, 15, 17, 51, 86, 15, 17, 64, 76, 15, 90, 64, 76, 15, 27, 51, 86, 15,

[0, 1, 43, 161, 1, 38, 9, 31, 4, 45, 57, 15, 8, 63, 9, 31, 10, 64, 50, 52, 13, 38, 33, 125, 0, 1, 64, 9, 31, 4, 53, 57, 52, 8, 54, 9, 52, 10, 35, 6, 52, 13, 54, 26, 85, 0, 1, 11, 57, 31, 4, 32, 12, 52, 8, 38, 57, 52, 10, 35, 9, 31, 13, 38, 6, 85, 0, 1, 32, 9, 31, 4, 35, 57, 52, 8, 51, 9, 31, 10, 53, 6, 7, 13, 32, 6, 36, 17, 32, 30, 19, 0, 1, 38, 30, 125, 10, 11, 50, 42, 0, 1, 38, 33, 42, 8, 11, 62, 41, 13, 11, 21, 66, 17, 11, 12, 41, 0, 1, 38, 12, 66, 8, 38, 73, 41, 13, 32, 73, 19, 17, 11, 9, 41, 0, 1, 38, 9, 34, 8, 64, 6, 66, 13, 64, 62, 19, 17, 11, 48, 41, 0, 1, 38, 26, 15, 4, 11, 26, 68, 4, 32, 6, 31, 8, 32, 62, 31, 10, 32, 73, 34, 16, 32, 62, 31, 78, 201, 26, 34, 17, 32, 33, 31, 27, 32, 9, 34, 0, 4, 32, 33, 31, 8, 32, 30, 31, 10, 32, 6, 34, 16, 32, 33, 31, 17, 32, 29, 31, 27, 32, 75, 31, 0, 4, 32, 6, 31, 8, 32, 62, 31, 10, 32, 73, 34, 16, 32, 62, 31, 17, 32, 33, 31, 27, 32, 9, 34, 0, 4, 32, 33, 31, 8, 32, 30, 31, 10, 32, 6, 34, 16, 32, 33, 31, 17, 32, 29, 31, 90, 32, 75, 31]
52
[0,

[0, 1, 43, 170, 67, 49, 104, 15, 23, 51, 75, 15, 72, 38, 55, 15, 91, 49, 100, 15, 70, 54, 104, 22, 78, 24, 100, 192, 0, 23, 53, 75, 52, 72, 45, 46, 31, 91, 53, 100, 52, 70, 35, 104, 31, 78, 51, 128, 85, 0, 72, 49, 48, 22, 72, 51, 58, 22, 70, 51, 46, 66, 70, 38, 73, 66, 0, 1, 53, 100, 31, 67, 49, 104, 15, 23, 64, 26, 15, 72, 51, 46, 31, 91, 45, 46, 31, 70, 35, 75, 52, 78, 51, 46, 225, 0, 27, 49, 100, 52, 0, 67, 53, 104, 15, 23, 51, 46, 31, 72, 35, 48, 52, 91, 35, 100, 52, 13, 35, 104, 31]
62
[0, 1, 2, 89, 67, 14, 9, 41, 8, 2, 89, 72, 14, 57, 15, 10, 11, 9, 15, 91, 5, 57, 28, 13, 2, 89, 17, 2, 89, 90, 5, 9, 31, 74, 25, 57, 7, 0, 1, 2, 89, 67, 25, 21, 34, 8, 2, 89, 72, 25, 40, 7, 91, 18, 82, 28, 13, 2, 89, 17, 2, 89, 90, 11, 12, 15, 27, 32, 21, 22, 74, 5, 39, 52, 0, 1, 2, 89, 4, 24, 98, 52, 8, 2, 89, 72, 14, 40, 22, 10, 51, 82, 22, 91, 32, 40, 34, 13, 2, 89, 78, 14, 30, 7, 17, 2, 89, 90, 14, 12, 7, 74, 5, 9, 37, 0, 1, 2, 89, 8, 2, 89, 91, 14, 9, 31, 13, 2, 89, 70, 51, 6, 22, 16, 24, 9, 15

[0, 1, 2, 89, 67, 5, 61, 94, 90, 25, 20, 31, 74, 25, 73, 31, 0, 67, 25, 20, 34, 72, 25, 73, 31, 91, 5, 9, 80, 90, 25, 61, 7, 74, 25, 20, 7, 0, 1, 25, 76, 60, 17, 14, 20, 7, 27, 25, 76, 31, 0, 67, 14, 20, 60, 90, 24, 29, 7, 90, 24, 9, 7, 74, 14, 75, 31, 74, 14, 20, 31, 0, 67, 24, 46, 31, 67, 24, 76, 31, 23, 24, 75, 7, 23, 24, 20, 7, 72, 11, 75, 7, 72, 11, 20, 7, 91, 24, 29, 7, 91, 24, 21, 7, 70, 14, 29, 41, 70, 14, 9, 41, 90, 5, 73, 41, 90, 5, 97, 41, 0, 67, 14, 26, 28, 67, 14, 40, 28, 91, 5, 73, 31, 91, 5, 97, 31, 70, 24, 26, 85, 70, 24, 40, 85, 0, 67, 5, 75, 60, 67, 5, 73, 60, 90, 24, 9, 7, 90, 24, 20, 7, 74, 11, 46, 31, 74, 11, 76, 31, 0, 1, 14, 26, 136, 1, 14, 73, 136]
70
[0, 1, 2, 89, 67, 5, 21, 94, 78, 25, 96, 31, 90, 25, 12, 31, 74, 14, 21, 31, 0, 67, 25, 12, 31, 23, 25, 57, 80, 90, 25, 21, 7, 74, 25, 96, 7, 0, 1, 25, 82, 60, 17, 14, 96, 7, 27, 25, 21, 31, 0, 67, 14, 12, 60, 90, 24, 55, 7, 90, 24, 57, 7, 74, 14, 30, 31, 74, 14, 12, 31, 0, 67, 24, 33, 31, 67, 24, 21, 31, 23, 24, 3

[0, 1, 2, 3, 4, 5, 48, 7, 8, 5, 50, 7, 10, 11, 58, 7, 13, 14, 58, 15, 16, 14, 50, 7, 17, 18, 58, 31, 27, 11, 48, 31, 0, 1, 14, 46, 31, 4, 14, 55, 31, 8, 5, 46, 52, 10, 25, 48, 31, 13, 25, 48, 15, 16, 25, 46, 31, 17, 5, 55, 31, 27, 11, 55, 28, 0, 4, 14, 100, 31, 8, 18, 100, 7, 10, 14, 100, 7, 13, 24, 122, 7, 16, 14, 46, 31, 17, 11, 46, 31, 27, 14, 48, 31, 0, 1, 11, 46, 31, 4, 14, 48, 31, 8, 5, 46, 52, 10, 24, 55, 31, 13, 24, 30, 34, 16, 11, 48, 31, 17, 14, 48, 31, 27, 14, 48, 31, 0, 1, 11, 55, 31, 4, 11, 100, 34, 10, 11, 122, 31, 13, 14, 100, 31, 16, 32, 55, 31, 17, 11, 100, 31, 27, 14, 48, 31, 0, 1, 32, 100, 31, 1, 11, 100, 31, 4, 32, 46, 31, 8, 11, 48, 31, 10, 11, 100, 31, 13, 11, 100, 31, 16, 24, 55, 31, 17, 14, 100, 31, 27, 11, 100, 31, 0, 1, 11, 48, 31, 4, 11, 48, 31, 8, 11, 55, 31, 10, 11, 46, 31, 13, 5, 46, 31, 17, 38, 48, 31, 27, 11, 55, 31, 0, 1, 5, 55, 31, 4, 18, 100, 31, 4, 24, 128, 31, 4, 11, 48, 31, 8, 14, 55, 31, 10, 5, 46, 56]
78
[0, 1, 43, 114, 70, 32, 129, 68, 0, 70, 25

[0, 1, 43, 124, 78, 24, 111, 15, 90, 54, 30, 15, 74, 64, 33, 52, 0, 23, 53, 50, 31, 72, 63, 6, 31, 91, 45, 50, 52, 70, 45, 6, 52, 78, 32, 6, 31, 90, 54, 6, 102, 0, 72, 71, 33, 31, 91, 51, 6, 22, 78, 35, 6, 15, 90, 63, 57, 15, 74, 63, 6, 15, 0, 67, 53, 6, 7, 23, 49, 50, 31, 72, 35, 6, 59, 0, 78, 32, 6, 15, 90, 64, 6, 15, 74, 38, 6, 19, 0, 23, 54, 50, 31, 72, 53, 6, 31, 91, 53, 57, 7, 70, 49, 12, 31, 78, 24, 6, 31, 90, 54, 6, 94, 0, 72, 49, 30, 7, 10, 63, 33, 15, 91, 54, 48, 15, 70, 53, 6, 31, 78, 45, 50, 52, 90, 49, 48, 52, 74, 45, 33, 52, 0, 67, 45, 30, 52, 23, 45, 55, 52, 72, 54, 30, 192]
85
[0, 1, 2, 159, 91, 5, 26, 15, 13, 5, 26, 15, 70, 25, 50, 31, 78, 5, 26, 31, 90, 5, 29, 15, 27, 14, 29, 31, 0, 67, 5, 29, 31, 23, 5, 29, 31, 72, 5, 111, 15, 10, 5, 29, 15, 91, 5, 111, 15, 13, 24, 29, 15, 70, 14, 29, 15, 70, 14, 29, 15, 16, 11, 29, 15, 78, 14, 29, 31, 17, 5, 29, 15, 90, 14, 29, 15, 27, 25, 29, 31, 0, 1, 11, 29, 31, 4, 25, 29, 31, 8, 5, 29, 15, 72, 11, 111, 15, 10, 38, 29, 31, 13, 24

[0, 1, 2, 113, 67, 25, 75, 7, 23, 79, 46, 103, 90, 79, 46, 15, 74, 18, 46, 7, 0, 67, 18, 75, 15, 4, 64, 29, 15, 23, 79, 75, 56, 0, 67, 18, 104, 15, 4, 54, 29, 15, 23, 18, 100, 28, 74, 11, 93, 41, 0, 23, 18, 127, 34, 70, 11, 122, 15, 16, 11, 100, 15, 78, 11, 46, 15, 17, 38, 75, 15, 90, 25, 100, 15, 74, 25, 75, 15, 0, 1, 54, 29, 15, 67, 25, 100, 15, 4, 5, 104, 15, 23, 11, 29, 34, 74, 5, 29, 7, 0, 67, 25, 75, 7, 4, 5, 46, 15, 23, 79, 75, 103, 74, 38, 29, 52, 0, 67, 63, 100, 15, 4, 53, 29, 15, 23, 38, 100, 77, 74, 25, 93, 34, 0, 4, 79, 104, 80]
96
[0, 1, 2, 113, 1, 79, 58, 41, 8, 79, 86, 31, 91, 24, 61, 15, 13, 25, 57, 28, 0, 1, 25, 58, 19, 8, 79, 57, 19, 91, 79, 50, 15, 13, 38, 48, 28, 0, 1, 79, 50, 19, 8, 79, 48, 85, 0, 1, 79, 48, 19, 8, 79, 50, 19, 13, 11, 46, 19, 91, 79, 48, 15, 13, 79, 46, 31, 13, 25, 48, 136, 0, 1, 79, 100, 41, 8, 18, 100, 31, 91, 79, 46, 15, 13, 5, 46, 52, 17, 79, 100, 15, 74, 79, 55, 15, 0, 1, 54, 100, 15, 67, 79, 55, 56, 0, 1, 25, 55, 19, 8, 5, 46, 31, 10, 32, 46,

[0, 1, 2, 158, 91, 11, 50, 15, 13, 14, 12, 15, 70, 14, 12, 22, 16, 25, 86, 31, 17, 5, 96, 15, 90, 25, 12, 7, 74, 11, 12, 22, 0, 1, 14, 12, 22, 67, 38, 12, 22, 4, 18, 12, 15, 23, 25, 86, 15, 8, 25, 58, 15, 72, 25, 6, 80, 0, 91, 5, 50, 15, 13, 25, 12, 15, 70, 11, 12, 22, 16, 25, 86, 7, 17, 25, 96, 22, 90, 18, 86, 7, 74, 14, 96, 7, 0, 67, 32, 96, 22, 4, 5, 86, 7, 8, 14, 96, 22, 72, 25, 86, 7, 91, 24, 96, 7, 70, 24, 96, 52, 0, 23, 14, 50, 22, 8, 25, 57, 22, 72, 5, 12, 22, 10, 25, 57, 7, 13, 5, 61, 22, 70, 5, 57, 22, 16, 5, 61, 7, 17, 5, 96, 22, 90, 14, 96, 31, 0, 10, 18, 12, 7, 13, 5, 12, 22, 70, 18, 12, 31, 78, 18, 86, 7, 90, 5, 86, 19, 0, 23, 25, 86, 22, 8, 14, 86, 22, 72, 25, 96, 22, 10, 14, 86, 66, 0, 23, 14, 12, 22, 8, 25, 86, 22, 72, 5, 61, 22, 10, 25, 84, 87]
105
[0, 1, 2, 150, 90, 5, 50, 28, 90, 5, 82, 28, 0, 23, 14, 33, 31, 23, 14, 21, 31, 72, 24, 33, 31, 72, 24, 21, 31, 91, 25, 33, 31, 91, 25, 21, 31, 70, 5, 50, 31, 70, 5, 82, 31, 78, 5, 9, 31, 78, 5, 40, 31, 90, 5, 111, 137, 90,

In [22]:
import mido

# For examples 0-10, combine the per-part MIDI files into two multi-track files:
#   merged<i>.mid          -> intro + reference solo + outro
#   merged_predict<i>.mid  -> intro + generated solo + outro
for i in range(11):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" + '/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" + '/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" + '/predict' + str(i) + '.mid')

    # Sum of the delta times carried by the note_on messages of each part.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = sum(msg.time for msg in solo.tracks[1] if msg.type == "note_on")
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    # NOTE(review): tracks[1][1] is presumably the first timed event of the part;
    # confirm against what utils.write_midi emits.
    original_outro_time = outro.tracks[1][1].time

    # --- intro + reference solo + outro ---
    print(original_outro_time + total_solo_time + total_intro_time)
    # Delay the solo by the intro's duration, and the outro by intro + solo.
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_solo_time + total_intro_time
    print(outro.tracks[1][1].time)

    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')

    # --- intro + generated solo + outro ---
    # Reload the outro: the copy above already had its start time shifted.
    outro = mido.MidiFile(generated_outputs + "/outro/" + '/outro' + str(i) + '.mid')

    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

30480
30480
29100
29100
40200
40200
36600
36600
27720
27720
30240
30240
53340
53340
43140
43140
25800
25800
28800
28800
20400
20400
24060
24060
30900
30900
25800
25800
31680
31680
32280
32280
30720
30720
30660
30660
61500
61500
53820
53820
31380
31380
26280
26280


In [None]:
# dissimilar_interpolation
# Pair every intro with the outro of the example three positions further on and
# let the model generate the solo that connects the two.
for i in range(len(test_intro)):
    intro = test_intro[i]
    # Fall back to the example's own outro when i + 3 runs past the end of the split.
    outro = test_outro[i + 3] if i + 3 < len(test_intro) else test_outro[i]

    # Each entry is a string of space-separated integer tokens.
    list_intro = list(map(int, intro.split(' ')))
    list_outro = list(map(int, outro.split(' ')))

    translated_sentence = translate_sentence(
        model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200
    )
    # Drop the special vocabulary tokens before converting the rest back to ints.
    translated_sentence = [
        int(token)
        for token in translated_sentence
        if token not in ('<pad>', '<sos>', '<eos>', '<unk>')
    ]
    print(translated_sentence)

    utils.write_midi(list_intro, word2event,
                     dissimilar_interpolation + "/intro/" + "/intro" + str(i) + ".mid")
    utils.write_midi(list_outro, word2event,
                     dissimilar_interpolation + "/outro/" + "/outro" + str(i) + ".mid")
    utils.write_midi(translated_sentence, word2event,
                     dissimilar_interpolation + "/predict/" + "/predict" + str(i) + ".mid")
    print(i)
        


In [None]:
import mido

# Combine each interpolation example's intro, generated solo and outro into a
# single multi-track file merged_predict<i>.mid.
for i in range(len(test_intro)):
    intro = mido.MidiFile(dissimilar_interpolation + "/intro/" + '/intro' + str(i) + '.mid')
    outro = mido.MidiFile(dissimilar_interpolation + "/outro/" + '/outro' + str(i) + '.mid')
    predict = mido.MidiFile(dissimilar_interpolation + "/predict/" + '/predict' + str(i) + '.mid')

    # Sum of the delta times carried by the note_on messages of each part.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = 0  # not used below
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    # NOTE(review): tracks[1][1] is presumably the first timed event of the part;
    # confirm against what utils.write_midi emits.
    original_outro_time = outro.tracks[1][1].time

    print(original_outro_time + total_predict_time + total_intro_time)
    # Delay the generated solo by the intro's duration, and the outro by intro + solo.
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(dissimilar_interpolation + '/merged_predict' + str(i) + '.mid')

In [None]:
class BeamSearchNode(object):
    """One hypothesis step in a beam search, linked back to its predecessor.

    Attributes are stored exactly as passed in:
        prev_node: Node this one extends, or None for the start node.
        wid: Token id held by this node.
        logp: Score accumulated along the path ending at this node.
        length: Number of nodes on that path, this one included.
    """

    def __init__(self, prev_node, wid, logp, length):
        self.prev_node, self.wid = prev_node, wid
        self.logp, self.length = logp, length

    def eval(self):
        """Return the accumulated score divided by the number of steps taken."""
        # The 1e-6 keeps the division defined for the start node (length == 1).
        steps = float(self.length - 1 + 1e-6)
        return self.logp / steps
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    """Decode ``sentence`` with a best-first beam search and return the best hypothesis.

    Args:
        model: Callable invoked as ``model(src, trg)``. Its output is assumed to be
            laid out as ``(target_len, batch, vocab)``, the same layout the greedy
            decoder relies on when it reads ``output.argmax(2)[-1, :]``.
        sentence: Space-separated source tokens.
        german: Source field; ``init_token``, ``eos_token`` and ``vocab.stoi`` are used.
        english: Target field; ``vocab.stoi`` and ``vocab.itos`` are used.
        device: Device the index tensors are moved to.
        max_length: Accepted for signature compatibility; not enforced here --
            the search is bounded by ``max_dec_steps`` instead.
        beam_width: Number of continuations registered per expanded hypothesis,
            and number of finished hypotheses collected before stopping.
        max_dec_steps: Budget of registered candidates after which the search stops.

    Returns:
        The target tokens (strings) of the best-scoring hypothesis, including the
        leading "<sos>" and, when the hypothesis finished, the trailing "<eos>".

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    # Tokenise the source, wrap it in <sos>/<eos> and map it to vocabulary indices.
    tokens = [token.lower() for token in sentence.split(' ')]
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)
    text_to_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]

    def _path_to(node):
        """Return the token ids from the start node up to ``node``."""
        ids = []
        while node is not None:
            ids.append(node.wid.item())
            node = node.prev_node
        return ids[::-1]

    # The heap is a min-heap, so hypotheses are keyed by the negated
    # length-normalised score; id() breaks ties without comparing nodes.
    node = BeamSearchNode(prev_node=None,
                          wid=torch.LongTensor([sos_token]).to(device),
                          logp=0,
                          length=1)
    nodes = []
    heappush(nodes, (-node.eval(), id(node), node))
    end_nodes = []
    n_dec_steps = 0

    # Stop when the budget is spent or there is nothing left to expand
    # (popping from an empty heap would raise IndexError).
    while nodes and n_dec_steps <= max_dec_steps:
        score, _, n = heappop(nodes)

        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            if len(end_nodes) >= beam_width:
                break
            continue

        trg_tensor = torch.LongTensor(_path_to(n)).unsqueeze(1).to(device)
        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        # The next token is predicted at the LAST target position. Normalise to
        # log-probabilities so the scores summed along a path are comparable.
        log_probs = torch.log_softmax(output[-1, 0], dim=-1)
        k = min(beam_width, log_probs.size(-1))
        topk_log_prob, topk_indexes = torch.topk(log_probs, k)

        for new_k in range(k):
            node = BeamSearchNode(prev_node=n,
                                  wid=topk_indexes[new_k].view(1),
                                  logp=n.logp + topk_log_prob[new_k].item(),
                                  length=n.length + 1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width

    # No hypothesis reached <eos>: fall back to the best unfinished ones
    # (they are truncated), taking no more than the heap actually holds.
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(min(beam_width, len(nodes)))]

    # Lowest key == highest normalised score; min() keeps the first on ties.
    best_node = min(end_nodes, key=lambda entry: entry[0])[2]
    return [english.vocab.itos[idx] for idx in _path_to(best_node)]


In [None]:
def save_vocab(vocab, path):
    """Pickle ``vocab`` to the file at ``path``, overwriting any existing file.

    Args:
        vocab: Object to serialise with ``pickle.dump``.
        path: Destination file path; opened in binary write mode.
    """
    # The context manager closes the file even if pickling raises.
    with open(path, 'wb') as output:
        pickle.dump(vocab, output)

In [None]:
# Persist the vocabulary of each field under the vocab folder.
for _field, _suffix in (
    (intro_field, '/intro_vocab.pkl'),
    (solo_field, '/solo_vocab.pkl'),
    (outro_field, '/outro_vocab.pkl'),
):
    save_vocab(_field.vocab, vocab + _suffix)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):
    """Greedily decode an already-numericalised source sequence.

    Args:
        model: Callable invoked as ``model(src, trg)``; the argmax over dimension 2
            of its output at the last target position is taken as the next token.
        sentence: Sequence of integer indices; it is fed to the model as-is, with
            no tokenisation or vocabulary lookup.
        german: Unused; kept so the signature matches the other decoders.
        english: Target field; ``vocab.stoi`` and ``vocab.itos`` are used.
        device: Device the index tensors are moved to.
        max_length: Maximum number of tokens to generate.

    Returns:
        The decoded target tokens as strings, starting with "<sos>" and ending
        with "<eos>" when the model produced it within ``max_length`` steps.
    """
    target_vocab = english.vocab
    source = torch.LongTensor(sentence).unsqueeze(1).to(device)

    decoded_ids = [target_vocab.stoi["<sos>"]]
    for _ in range(max_length):
        # Re-feed everything decoded so far and keep the newest prediction.
        so_far = torch.LongTensor(decoded_ids).unsqueeze(1).to(device)
        with torch.no_grad():
            scores = model(source, so_far)

        next_id = scores.argmax(2)[-1, :].item()
        decoded_ids.append(next_id)
        if next_id == target_vocab.stoi["<eos>"]:
            break

    return [target_vocab.itos[token_id] for token_id in decoded_ids]


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    """Compute the corpus BLEU score of greedy decodings against the reference solos.

    Args:
        data: Iterable of examples exposing ``intro`` and ``solo`` attributes, each
            a sequence of tokens convertible with ``int``.
        model: Model forwarded to ``bleu_translate_sentence``.
        german: Source field forwarded to ``bleu_translate_sentence``.
        english: Target field forwarded to ``bleu_translate_sentence``.
        device: Device forwarded to ``bleu_translate_sentence``.

    Returns:
        The value returned by ``bleu_score`` for the collected corpus.
    """
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        fields = vars(example)
        src = [int(x) for x in fields["intro"]]
        trg = [int(x) for x in fields["solo"]]

        # Skip examples that are too long to decode.
        if len(trg) > 1200 or len(src) > 1200:
            continue

        # NOTE(review): the integer token values are handed to the model directly as
        # vocabulary indices (bleu_translate_sentence does no stoi lookup) -- confirm
        # this matches the indices the model was trained on.
        prediction = bleu_translate_sentence(model, src, german, english, device)
        # Strip the sequence markers. Filtering (rather than slicing off the last
        # item) also removes the leading <sos> and does not lose a real token when
        # decoding stopped at the length cap without emitting <eos>.
        prediction = [token for token in prediction if token not in ("<sos>", "<eos>")]

        # bleu_score expects, per candidate, a LIST of reference translations, and
        # its tokens must be comparable with the predicted (string) tokens.
        targets.append([[str(token) for token in trg]])
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# Score only a small slice: decoding the entire test split takes a while.
score = bleu(
    test[1:10],
    model,
    german=intro_field,
    english=solo_field,
    device=device,
)
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
# Plot the stored training and validation loss against the global step.
train_loss_list, valid_loss_list, global_steps_list = load_metrics(
    destination_folder + '/metrics.pt'
)
for _label, _losses in (('Train', train_loss_list), ('Valid', valid_loss_list)):
    plt.plot(global_steps_list, _losses, label=_label)
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns