In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
source_folder = "solo_generation_dataset_fixed_augmented"
folder = "dynamic_fixed_augmented_models/intro"
destination_folder = folder + "/solo_generation_weights"
generated_outputs = folder +  "/generated_samples"
vocab = folder + "/vocab"

In [3]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

state = pickle.load(open('./state.pkl', 'rb'))
random.setstate(state)

In [4]:
from pathlib import Path
Path(destination_folder).mkdir(parents=True, exist_ok=True)
Path(generated_outputs).mkdir(parents=True, exist_ok=True)
Path(vocab).mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/intro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/outro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/predict").mkdir(parents=True, exist_ok=True)

In [5]:
event2word, word2event = pickle.load(open('dictionary_fixed_augmented.pkl', 'rb'))

In [6]:
if torch.cuda.is_available():  
    dev = "cuda:1" 
else:  
    dev = "cpu" 
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [7]:
# Fields

intro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
intro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
fields = [('intro', intro_field), ('intro_piano', intro_piano_field), \
          ('outro', outro_field), ('outro_piano', outro_piano_field), \
          ('solo', solo_field), ('solo_piano', solo_piano_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv', validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators
BATCH_SIZE = 8
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)

# Vocabulary

intro_field.build_vocab(train, min_freq=1)
intro_piano_field.build_vocab(train, min_freq=1)
outro_field.build_vocab(train, min_freq=1)
outro_piano_field.build_vocab(train, min_freq=1)
solo_field.build_vocab(train, min_freq=1)
solo_piano_field.build_vocab(train, min_freq=1)

In [8]:
for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (test_iter):
    print(solo.transpose(1,0).size())

torch.Size([269, 8])
torch.Size([331, 8])
torch.Size([353, 8])
torch.Size([522, 8])
torch.Size([325, 8])
torch.Size([444, 8])
torch.Size([281, 8])
torch.Size([473, 8])
torch.Size([619, 8])
torch.Size([365, 8])
torch.Size([509, 8])
torch.Size([353, 8])
torch.Size([433, 8])
torch.Size([491, 8])


In [9]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [10]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [11]:
src_vocab_size = len(intro_field.vocab)
trg_vocab_size = len(solo_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 1200
forward_expansion = 4
src_pad_idx = 1 #english.vocab.stoi["<pad>"]

model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
model = model.to(device)

In [12]:
def init_weights(m: nn.Module):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

optimizer = optim.Adam(model.parameters(), lr=2e-4) #4e-5


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
#     torch.save(state, destination_folder + str(nth)+filename)
    torch.save(state, destination_folder + '/metrics.pt')

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
    torch.save(state, destination_folder + "/" + str(nth)+filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

The model has 11,076,868 trainable parameters


In [13]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [14]:
PAD_IDX = 1

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [15]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0

    #for _, (src, _,trg,_) in enumerate(iterator):
    for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
        src, trg = intro.transpose(1,0), solo.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:-1, :])
        
#         print(output.size())
#         print(trg.size())
        
        output = output.view(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        #for _, (src, _,trg,_) in enumerate(iterator):
        for ((intro, intro_len), (intro_piano, intro_piano_len),\
         (outro, outro_len),(outro_piano, outro_piano_len),\
         (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
            src, trg = intro.transpose(1,0), solo.transpose(1,0)
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg[:-1, :]) #turn off teacher forcing

            output = output.view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs



In [16]:
def translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break
    # print(outputs)
    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [17]:
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro = df_intro['intro'].values
val_solo = df_intro['solo'].values
val_outro = df_intro['outro'].values
val_data=[]
for i in range(len(val_intro)):
    temp_dict = {}
    temp_dict['intro'] = val_intro[i]
    temp_dict['solo'] = val_solo[i]
    temp_dict['outro'] = val_outro[i]
    val_data.append(temp_dict)
print(len(val_intro))

112


In [18]:
def check_mode_collapse(model):
    count = 0
    translations = []
    for i in range(5):
        if len(val_intro) > 1200:
            continue
        intro = val_intro[i]
        solo = val_solo[i]
        outro = val_outro[i]
        #print(intro)
        list_intro = [int(x) for x in intro.split(' ')]
        list_solo = [int(x) for x in solo.split(' ')]
        list_outro = [int(x) for x in outro.split(' ')]
        translated_sentence = translate_sentence(model, intro, intro_field, solo_field, device, max_length=1200)
        
        translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
        print(translated_sentence)
        translations.append(translated_sentence)
        if i > 0:
            if translations[i-1] == translations[i]:
                count += 1
    return count


In [19]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(S_EPOCH, N_EPOCHS):
    
    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
        save_best_checkpoint(checkpoint,N_EPOCHS)
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)
    if (epoch+2) % 20 ==0:
        if check_mode_collapse(model) > 2:
            print("model is mode collapsing")
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 0m 6s
	Train Loss: 5.339 | Train PPL: 208.299
	 Val. Loss: 5.128 |  Val. PPL: 168.641
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 0m 6s
	Train Loss: 4.809 | Train PPL: 122.553
	 Val. Loss: 4.594 |  Val. PPL:  98.849
=> Saving checkpoint
Epoch: 03 | Time: 0m 6s
	Train Loss: 4.389 | Train PPL:  80.531
	 Val. Loss: 4.301 |  Val. PPL:  73.757
=> Saving checkpoint
Epoch: 04 | Time: 0m 6s
	Train Loss: 4.072 | Train PPL:  58.687
	 Val. Loss: 3.968 |  Val. PPL:  52.880
=> Saving checkpoint
Epoch: 05 | Time: 0m 6s
	Train Loss: 3.732 | Train PPL:  41.748
	 Val. Loss: 3.650 |  Val. PPL:  38.482
=> Saving checkpoint
Epoch: 06 | Time: 0m 6s
	Train Loss: 3.443 | Train PPL:  31.291
	 Val. Loss: 3.413 |  Val. PPL:  30.357
=> Saving checkpoint
Epoch: 07 | Time: 0m 6s
	Train Loss: 3.236 | Train PPL:  25.436
	 Val. Loss: 3.239 |  Val. PPL:  25.513
=> Saving checkpoint
Epoch: 08 | Time: 0m 6s
	Train Loss: 3.099 | Train PPL:  22.171
	 Val. Loss: 3.132 |  Val. PPL:  22.919


[0, 1, 50, 186, 10, 50, 186, 13, 50, 186, 13, 50, 186, 13, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5

[0, 1, 50, 186, 10, 50, 186, 13, 50, 186, 13, 50, 186, 13, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5

[0, 1, 50, 186, 1, 4, 58, 9, 37, 4, 58, 9, 37, 4, 58, 9, 7, 4, 58, 9, 42, 4, 58, 9, 42, 4, 58, 9, 10, 4, 58, 9, 10, 4, 58, 9, 43, 4, 58, 9, 28, 4, 58, 9, 13, 4, 58, 9, 31, 4, 58, 9, 21, 4, 58, 9, 16, 4, 58, 9, 23, 4, 58, 9, 36, 4, 58, 9, 0, 1, 4, 58, 9, 37, 4, 58, 9, 37, 4, 58, 9, 7, 4, 58, 9, 42, 4, 58, 9, 42, 4, 58, 9, 43, 4, 58, 9, 43, 4, 58, 9, 28, 4, 58, 9, 28, 4, 58, 9, 13, 4, 58, 9, 31, 4, 58, 9, 21, 4, 58, 9, 21, 4, 58, 9, 16, 4, 58, 9, 23, 4, 58, 9, 23, 4, 58, 9, 36, 4, 58, 9, 0, 1, 4, 58, 9, 37, 4, 58, 9, 37, 4, 58, 9, 42, 4, 58, 9, 42, 4, 58, 9, 10, 4, 58, 9, 43, 4, 58, 9, 43, 4, 58, 9, 28, 4, 58, 9, 28, 4, 58, 9, 13, 4, 58, 9, 31, 4, 58, 9, 21, 4, 58, 9, 21, 4, 58, 9, 16, 4, 58, 9, 23, 4, 58, 9, 23, 4, 58, 9, 36, 4, 58, 9, 0, 1, 4, 58, 9, 37, 4, 58, 9, 37, 4, 58, 9, 37, 4, 58, 9, 7, 4, 58, 9, 42, 4, 58, 9, 42, 4, 58, 9, 43, 4, 58, 9, 10, 4, 58, 9, 28, 4, 58, 9, 28, 4, 58, 9, 13, 4, 58, 9, 31, 4, 58, 9, 13, 4, 58, 9, 31, 4, 58, 9, 21, 4, 58, 9, 21, 4, 58, 9, 16, 4, 58, 9, 23

Epoch: 40 | Time: 0m 6s
	Train Loss: 1.929 | Train PPL:   6.886
	 Val. Loss: 2.145 |  Val. PPL:   8.540
=> Saving checkpoint
Epoch: 41 | Time: 0m 6s
	Train Loss: 1.914 | Train PPL:   6.778
	 Val. Loss: 2.133 |  Val. PPL:   8.439
=> Saving checkpoint
Epoch: 42 | Time: 0m 6s
	Train Loss: 1.899 | Train PPL:   6.682
	 Val. Loss: 2.116 |  Val. PPL:   8.301
=> Saving checkpoint
Epoch: 43 | Time: 0m 6s
	Train Loss: 1.879 | Train PPL:   6.549
	 Val. Loss: 2.118 |  Val. PPL:   8.312
Epoch: 44 | Time: 0m 6s
	Train Loss: 1.869 | Train PPL:   6.483
	 Val. Loss: 2.107 |  Val. PPL:   8.221
=> Saving checkpoint
Epoch: 45 | Time: 0m 6s
	Train Loss: 1.852 | Train PPL:   6.372
	 Val. Loss: 2.120 |  Val. PPL:   8.329
Epoch: 46 | Time: 0m 6s
	Train Loss: 1.841 | Train PPL:   6.300
	 Val. Loss: 2.102 |  Val. PPL:   8.179
=> Saving checkpoint
Epoch: 47 | Time: 0m 6s
	Train Loss: 1.827 | Train PPL:   6.213
	 Val. Loss: 2.090 |  Val. PPL:   8.084
=> Saving checkpoint
Epoch: 48 | Time: 0m 6s
	Train Loss: 1.816

[0, 1, 2, 3, 10, 2, 3, 13, 2, 3, 31, 57, 18, 20, 15, 57, 18, 20, 21, 57, 18, 20, 16, 2, 3, 16, 55, 18, 20, 23, 55, 18, 20, 24, 55, 18, 20, 36, 55, 18, 20, 0, 1, 2, 3, 37, 55, 18, 20, 7, 55, 18, 20, 42, 55, 18, 20, 10, 2, 3, 43, 55, 18, 20, 49, 55, 18, 20, 28, 55, 18, 20, 13, 2, 3, 13, 2, 3, 16, 2, 3, 16, 2, 3, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 37, 55, 18, 20, 42, 55, 18, 20, 10, 2, 3, 43, 55, 18, 20, 49, 55, 18, 20, 28, 55, 18, 20, 28, 55, 18, 20, 13, 2, 3, 13, 2, 3, 16, 2, 3, 16, 2, 3, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 37, 55, 18, 20, 7, 55, 18, 20, 42, 55, 18, 20, 10, 2, 3, 10, 2, 3, 10, 2, 3, 10, 2, 3, 13, 2, 3, 13, 2, 3, 13, 2, 3, 16, 2, 3, 16, 2, 3, 16, 2, 3, 0, 1, 2, 3, 10, 2, 3, 10, 2, 3, 13, 2, 3, 16, 2, 3, 16, 2, 3, 0, 1, 2, 3, 10, 2, 3, 10, 2, 3, 13, 2, 3, 13, 2, 3, 13, 2, 3, 13, 2, 3, 16, 2, 3, 16, 2, 3, 16, 2, 3, 16, 2, 3, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 10, 2, 3, 10, 2, 3, 13, 2, 3, 10, 2, 3, 13, 2, 3, 16, 2, 3, 16, 2, 3, 1, 2, 3, 1, 2, 3, 10, 2, 3, 10, 2, 3, 10, 2

[0, 1, 2, 216, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 16, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 10, 55, 30, 38, 13, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 30, 38, 1, 55, 30, 38, 13, 55, 30, 38, 16, 55, 30, 38, 16, 55, 30, 38, 16, 55, 30, 38, 0, 1, 5

[0, 1, 50, 186, 1, 63, 78, 20, 37, 114, 78, 20, 7, 63, 78, 20, 42, 63, 78, 20, 10, 50, 186, 10, 35, 78, 20, 43, 94, 78, 20, 49, 61, 78, 20, 28, 94, 78, 20, 13, 50, 186, 13, 55, 78, 20, 31, 63, 78, 20, 15, 63, 78, 20, 21, 94, 78, 20, 21, 94, 78, 20, 16, 94, 78, 20, 23, 94, 78, 20, 23, 94, 78, 20, 24, 94, 78, 20, 36, 94, 78, 20, 0, 1, 94, 78, 20, 37, 55, 78, 20, 7, 55, 78, 20, 42, 94, 78, 20, 42, 94, 78, 20, 10, 94, 78, 20, 43, 94, 78, 20, 49, 94, 78, 20, 28, 94, 78, 20, 13, 94, 78, 20, 31, 94, 78, 20, 15, 63, 78, 20, 21, 94, 78, 20, 21, 57, 78, 20, 16, 94, 78, 20, 23, 94, 78, 20, 23, 94, 78, 20, 23, 94, 78, 20, 24, 94, 78, 20, 36, 94, 78, 20, 0, 1, 94, 78, 20, 37, 94, 78, 20, 37, 94, 78, 20, 7, 94, 78, 20, 42, 94, 78, 20, 10, 55, 78, 20, 10, 94, 78, 20, 43, 94, 78, 20, 43, 94, 78, 20, 49, 94, 78, 20, 49, 94, 78, 20, 28, 94, 78, 20, 28, 94, 78, 20, 13, 94, 70, 20, 13, 94, 70, 20, 31, 94, 70, 20, 31, 94, 70, 20, 15, 94, 70, 20, 15, 94, 70, 20, 21, 94, 70, 20, 21, 57, 70, 20, 16, 94, 70, 2

[0, 1, 2, 127, 1, 4, 29, 20, 37, 4, 29, 20, 37, 4, 29, 20, 7, 4, 29, 20, 7, 4, 29, 20, 42, 4, 29, 20, 42, 4, 30, 20, 42, 4, 52, 20, 10, 4, 32, 20, 43, 4, 32, 20, 49, 4, 32, 20, 49, 4, 33, 20, 28, 4, 32, 20, 13, 4, 32, 20, 31, 4, 32, 20, 15, 4, 32, 20, 21, 4, 32, 20, 21, 4, 32, 38, 0, 37, 4, 32, 38, 42, 4, 32, 38, 43, 4, 32, 38, 31, 4, 32, 38, 21, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 21, 4, 32, 38, 21, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 31, 4, 32, 38, 21, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 21, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 21, 4, 32, 38, 23, 4, 65, 38, 0, 37, 4, 65, 38, 43, 4, 65, 38, 43, 4, 32, 38, 31, 4, 65, 38, 21, 4, 32, 38, 21, 4, 65, 38, 0, 37, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 64, 38, 43, 4, 32, 38, 31, 4, 65, 38, 31, 4, 65, 38, 23, 4, 65, 38, 0, 37, 4, 32, 38, 0, 37, 4, 65, 7, 4, 32, 38, 43, 4, 32, 38, 43, 4, 65, 38, 43, 4, 58, 38, 31, 4, 53, 38, 31, 4, 65, 38, 21, 4, 65, 38, 23, 4, 53, 38, 23, 4, 32, 38, 23, 4, 65, 38, 23,

Epoch: 126 | Time: 0m 6s
	Train Loss: 1.221 | Train PPL:   3.392
	 Val. Loss: 2.590 |  Val. PPL:  13.335
Epoch: 127 | Time: 0m 6s
	Train Loss: 1.222 | Train PPL:   3.395
	 Val. Loss: 2.587 |  Val. PPL:  13.288
Epoch: 128 | Time: 0m 6s
	Train Loss: 1.213 | Train PPL:   3.364
	 Val. Loss: 2.581 |  Val. PPL:  13.205
Epoch: 129 | Time: 0m 6s
	Train Loss: 1.212 | Train PPL:   3.362
	 Val. Loss: 2.587 |  Val. PPL:  13.292
Epoch: 130 | Time: 0m 6s
	Train Loss: 1.199 | Train PPL:   3.316
	 Val. Loss: 2.596 |  Val. PPL:  13.414
Epoch: 131 | Time: 0m 6s
	Train Loss: 1.197 | Train PPL:   3.310
	 Val. Loss: 2.640 |  Val. PPL:  14.013
Epoch: 132 | Time: 0m 6s
	Train Loss: 1.191 | Train PPL:   3.291
	 Val. Loss: 2.605 |  Val. PPL:  13.528
Epoch: 133 | Time: 0m 6s
	Train Loss: 1.186 | Train PPL:   3.275
	 Val. Loss: 2.644 |  Val. PPL:  14.063
Epoch: 134 | Time: 0m 6s
	Train Loss: 1.181 | Train PPL:   3.259
	 Val. Loss: 2.623 |  Val. PPL:  13.779
Epoch: 135 | Time: 0m 6s
	Train Loss: 1.164 | Train PPL

[0, 1, 50, 51, 37, 62, 33, 44, 42, 55, 33, 25, 49, 55, 33, 25, 15, 55, 33, 25, 24, 35, 33, 44, 0, 37, 55, 32, 44, 7, 55, 33, 25, 49, 55, 33, 25, 24, 35, 33, 38, 24, 35, 33, 38, 0, 7, 55, 33, 38, 49, 55, 33, 38, 49, 55, 33, 38, 49, 55, 32, 38, 31, 55, 33, 38, 15, 35, 33, 38, 24, 55, 33, 38, 24, 35, 32, 38, 0, 7, 55, 33, 38, 49, 55, 33, 19, 49, 35, 33, 19, 24, 35, 33, 38, 24, 35, 33, 38, 24, 55, 64, 38, 0, 7, 35, 65, 38, 7, 35, 65, 38, 7, 55, 128, 38, 49, 35, 64, 38, 49, 55, 65, 38, 15, 55, 65, 38, 49, 55, 64, 38, 15, 35, 53, 38, 24, 35, 33, 44, 0, 7, 55, 128, 38, 49, 55, 33, 38, 49, 35, 33, 38, 15, 55, 128, 38, 49, 55, 128, 38, 15, 55, 65, 38, 23, 35, 128, 38, 24, 35, 65, 38, 24, 55, 33, 38, 0, 7, 55, 64, 38, 49, 35, 64, 38, 49, 55, 64, 38, 49, 55, 33, 38, 15, 60, 128, 38, 24, 35, 65, 38, 24, 55, 65, 38, 0, 7, 55, 33, 38, 49, 35, 65, 38, 49, 60, 128, 38, 49, 35, 128, 38, 49, 35, 65, 38, 49, 55, 64, 38, 15, 60, 64, 38, 15, 55, 33, 38, 15, 55, 65, 38, 15, 55, 65, 38]
Epoch: 140 | Time: 0m

[0, 1, 50, 51, 37, 62, 33, 25, 10, 50, 51, 43, 35, 32, 9, 49, 35, 32, 9, 28, 35, 32, 9, 13, 50, 51, 13, 35, 30, 9, 31, 60, 32, 9, 31, 35, 30, 25, 16, 50, 51, 16, 60, 30, 9, 23, 35, 32, 9, 24, 60, 32, 9, 36, 35, 32, 9, 0, 1, 35, 32, 9, 37, 60, 29, 9, 37, 35, 30, 9, 7, 62, 29, 9, 42, 35, 32, 44, 43, 35, 30, 9, 28, 62, 32, 9, 13, 60, 32, 9, 31, 60, 33, 9, 31, 60, 32, 9, 15, 35, 30, 9, 21, 35, 32, 9, 16, 35, 32, 9, 23, 66, 32, 9, 24, 60, 33, 9, 36, 35, 34, 9, 0, 1, 35, 33, 25, 37, 35, 32, 25, 10, 35, 33, 9, 43, 35, 32, 6, 28, 35, 30, 9, 13, 35, 33, 25, 13, 35, 33, 9, 31, 35, 32, 9, 15, 35, 32, 9, 21, 35, 33, 9, 21, 35, 34, 9, 16, 35, 53, 9, 16, 35, 34, 9, 23, 35, 54, 38, 0, 1, 35, 65, 9, 37, 35, 65, 9, 7, 35, 128, 9, 42, 35, 65, 9, 10, 35, 64, 9, 10, 50, 205, 13, 35, 53, 9, 43, 60, 54, 9, 31, 60, 65, 9, 15, 35, 33, 9, 21, 66, 33, 9, 16, 35, 65, 9, 23, 60, 64, 9, 24, 60, 64, 9, 36, 60, 65, 9, 36, 35, 54, 83]
Epoch: 160 | Time: 0m 6s
	Train Loss: 1.052 | Train PPL:   2.864
	 Val. Loss: 2.878

[0, 1, 2, 127, 1, 55, 41, 44, 7, 35, 78, 44, 10, 60, 78, 44, 49, 60, 78, 44, 13, 60, 78, 44, 15, 61, 90, 44, 16, 66, 18, 38, 0, 1, 66, 78, 44, 7, 60, 78, 44, 10, 66, 78, 44, 49, 66, 78, 44, 13, 60, 70, 44, 15, 66, 18, 212, 1, 66, 70, 44, 16, 60, 78, 44, 49, 66, 70, 44, 13, 60, 78, 44, 13, 66, 70, 44, 15, 66, 77, 44, 16, 81, 78, 44, 16, 66, 70, 44, 24, 66, 77, 44, 0, 1, 66, 78, 44, 1, 66, 70, 44, 7, 66, 78, 44, 7, 66, 70, 44, 10, 66, 78, 38, 10, 66, 70, 38, 13, 66, 78, 38, 13, 60, 70, 38, 16, 66, 70, 38, 16, 60, 78, 38, 0, 1, 35, 70, 38, 10, 60, 70, 38, 13, 60, 78, 44, 7, 35, 70, 44, 49, 66, 70, 44, 13, 35, 70, 44, 15, 60, 70, 44, 15, 66, 68, 44, 16, 61, 70, 44, 16, 35, 72, 44, 24, 66, 70, 44, 0, 1, 55, 72, 44, 1, 35, 107, 44, 1, 66, 41, 44, 1, 55, 73, 44, 1, 60, 73, 38, 10, 66, 70, 48, 10, 66, 73, 44, 10, 60, 107, 44, 49, 66, 73, 44, 49, 66, 73, 44, 13, 55, 128, 44, 13, 35, 107, 44, 15, 66, 73, 44, 15, 66, 73, 48, 16, 66, 107, 44, 16, 66, 73, 44, 24, 66, 73, 44, 24, 66, 107, 44, 0, 1, 

Epoch: 204 | Time: 0m 6s
	Train Loss: 0.870 | Train PPL:   2.388
	 Val. Loss: 3.299 |  Val. PPL:  27.075
Epoch: 205 | Time: 0m 6s
	Train Loss: 0.851 | Train PPL:   2.341
	 Val. Loss: 3.413 |  Val. PPL:  30.351
Epoch: 206 | Time: 0m 6s
	Train Loss: 0.854 | Train PPL:   2.349
	 Val. Loss: 3.409 |  Val. PPL:  30.238
Epoch: 207 | Time: 0m 6s
	Train Loss: 0.839 | Train PPL:   2.313
	 Val. Loss: 3.364 |  Val. PPL:  28.906
Epoch: 208 | Time: 0m 6s
	Train Loss: 0.841 | Train PPL:   2.319
	 Val. Loss: 3.377 |  Val. PPL:  29.270
Epoch: 209 | Time: 0m 6s
	Train Loss: 0.853 | Train PPL:   2.346
	 Val. Loss: 3.387 |  Val. PPL:  29.578
Epoch: 210 | Time: 0m 6s
	Train Loss: 0.838 | Train PPL:   2.312
	 Val. Loss: 3.361 |  Val. PPL:  28.824
Epoch: 211 | Time: 0m 6s
	Train Loss: 0.829 | Train PPL:   2.291
	 Val. Loss: 3.424 |  Val. PPL:  30.698
Epoch: 212 | Time: 0m 6s
	Train Loss: 0.842 | Train PPL:   2.322
	 Val. Loss: 3.370 |  Val. PPL:  29.072
Epoch: 213 | Time: 0m 6s
	Train Loss: 0.837 | Train PPL

Epoch: 220 | Time: 0m 6s
	Train Loss: 0.794 | Train PPL:   2.212
	 Val. Loss: 3.459 |  Val. PPL:  31.771
=> Saving checkpoint
Epoch: 221 | Time: 0m 6s
	Train Loss: 0.802 | Train PPL:   2.230
	 Val. Loss: 3.440 |  Val. PPL:  31.180
=> Saving checkpoint
Epoch: 222 | Time: 0m 6s
	Train Loss: 0.794 | Train PPL:   2.212
	 Val. Loss: 3.471 |  Val. PPL:  32.163
Epoch: 223 | Time: 0m 6s
	Train Loss: 0.798 | Train PPL:   2.220
	 Val. Loss: 3.483 |  Val. PPL:  32.543
Epoch: 224 | Time: 0m 6s
	Train Loss: 0.793 | Train PPL:   2.210
	 Val. Loss: 3.504 |  Val. PPL:  33.236
Epoch: 225 | Time: 0m 6s
	Train Loss: 0.790 | Train PPL:   2.203
	 Val. Loss: 3.501 |  Val. PPL:  33.150
Epoch: 226 | Time: 0m 6s
	Train Loss: 0.785 | Train PPL:   2.193
	 Val. Loss: 3.538 |  Val. PPL:  34.385
Epoch: 227 | Time: 0m 6s
	Train Loss: 0.778 | Train PPL:   2.177
	 Val. Loss: 3.512 |  Val. PPL:  33.525
Epoch: 228 | Time: 0m 6s
	Train Loss: 0.776 | Train PPL:   2.172
	 Val. Loss: 3.493 |  Val. PPL:  32.897
Epoch: 229 | 

Epoch: 240 | Time: 0m 6s
	Train Loss: 0.756 | Train PPL:   2.130
	 Val. Loss: 3.616 |  Val. PPL:  37.200
=> Saving checkpoint
Epoch: 241 | Time: 0m 6s
	Train Loss: 0.737 | Train PPL:   2.089
	 Val. Loss: 3.682 |  Val. PPL:  39.711
=> Saving checkpoint
Epoch: 242 | Time: 0m 6s
	Train Loss: 0.730 | Train PPL:   2.074
	 Val. Loss: 3.642 |  Val. PPL:  38.165
Epoch: 243 | Time: 0m 6s
	Train Loss: 0.734 | Train PPL:   2.083
	 Val. Loss: 3.673 |  Val. PPL:  39.383
Epoch: 244 | Time: 0m 6s
	Train Loss: 0.729 | Train PPL:   2.074
	 Val. Loss: 3.685 |  Val. PPL:  39.849
Epoch: 245 | Time: 0m 6s
	Train Loss: 0.737 | Train PPL:   2.089
	 Val. Loss: 3.662 |  Val. PPL:  38.928
Epoch: 246 | Time: 0m 6s
	Train Loss: 0.734 | Train PPL:   2.083
	 Val. Loss: 3.645 |  Val. PPL:  38.274
Epoch: 247 | Time: 0m 6s
	Train Loss: 0.715 | Train PPL:   2.044
	 Val. Loss: 3.671 |  Val. PPL:  39.291
Epoch: 248 | Time: 0m 6s
	Train Loss: 0.708 | Train PPL:   2.029
	 Val. Loss: 3.637 |  Val. PPL:  37.991
Epoch: 249 | 

[0, 1, 50, 104, 37, 57, 90, 46, 42, 39, 18, 44, 43, 95, 30, 56, 21, 62, 77, 9, 16, 62, 78, 56, 0, 42, 95, 18, 44, 43, 95, 30, 9, 49, 95, 41, 25, 21, 17, 30, 9, 16, 62, 90, 75, 0, 10, 57, 78, 20, 43, 17, 77, 20, 49, 95, 18, 75, 16, 95, 11, 27, 0, 1, 62, 90, 44, 42, 62, 90, 44, 49, 95, 18, 44, 13, 95, 78, 9, 31, 57, 26, 6, 15, 39, 58, 47, 16, 62, 26, 6, 23, 57, 11, 44, 0, 1, 95, 11, 44, 37, 62, 78, 44, 10, 39, 5, 25, 15, 57, 90, 44, 16, 40, 78, 44, 24, 62, 26, 44, 0, 1, 39, 58, 44, 37, 62, 98, 44, 42, 57, 53, 75, 21, 39, 41, 44, 23, 95, 78, 44, 36, 55, 78, 44, 0, 1, 55, 98, 44, 37, 57, 52, 38, 43, 62, 41, 44, 28, 57, 78, 48]
[0, 1, 50, 172, 10, 50, 172, 13, 50, 172, 16, 50, 172, 16, 35, 69, 9, 23, 61, 53, 9, 24, 63, 72, 9, 36, 94, 69, 9, 0, 1, 50, 172, 1, 35, 69, 56, 10, 50, 172, 28, 136, 65, 9, 13, 50, 172, 13, 66, 106, 9, 15, 66, 72, 9, 16, 50, 172, 24, 35, 106, 9, 36, 94, 65, 9, 0, 1, 50, 172, 1, 60, 52, 74, 10, 50, 172, 13, 50, 172, 16, 50, 172, 16, 35, 52, 120, 0, 1, 50, 172, 1, 55,

Epoch: 291 | Time: 0m 6s
	Train Loss: 0.620 | Train PPL:   1.859
	 Val. Loss: 4.086 |  Val. PPL:  59.507
Epoch: 292 | Time: 0m 6s
	Train Loss: 0.603 | Train PPL:   1.828
	 Val. Loss: 3.987 |  Val. PPL:  53.884
Epoch: 293 | Time: 0m 6s
	Train Loss: 0.584 | Train PPL:   1.792
	 Val. Loss: 4.057 |  Val. PPL:  57.817
Epoch: 294 | Time: 0m 6s
	Train Loss: 0.587 | Train PPL:   1.798
	 Val. Loss: 4.104 |  Val. PPL:  60.585
Epoch: 295 | Time: 0m 6s
	Train Loss: 0.606 | Train PPL:   1.834
	 Val. Loss: 4.033 |  Val. PPL:  56.439
Epoch: 296 | Time: 0m 6s
	Train Loss: 0.585 | Train PPL:   1.794
	 Val. Loss: 4.083 |  Val. PPL:  59.302
Epoch: 297 | Time: 0m 6s
	Train Loss: 0.591 | Train PPL:   1.806
	 Val. Loss: 3.989 |  Val. PPL:  53.978
Epoch: 298 | Time: 0m 6s
	Train Loss: 0.583 | Train PPL:   1.791
	 Val. Loss: 4.110 |  Val. PPL:  60.928
Epoch: 299 | Time: 0m 6s
	Train Loss: 0.575 | Train PPL:   1.777
	 Val. Loss: 4.117 |  Val. PPL:  61.402
[0, 1, 50, 186, 1, 35, 68, 56, 7, 57, 69, 47, 10, 57, 6

[0, 1, 2, 216, 1, 55, 30, 38, 10, 55, 41, 47, 13, 55, 30, 38, 16, 55, 78, 48, 0, 1, 55, 30, 38, 10, 55, 78, 38, 13, 55, 18, 38, 16, 55, 18, 38, 0, 1, 55, 29, 47, 10, 55, 18, 44, 49, 55, 78, 19, 0, 10, 55, 30, 38, 49, 55, 41, 82, 13, 55, 30, 38, 16, 55, 30, 38, 0, 1, 55, 26, 96, 49, 35, 18, 25, 16, 55, 18, 38, 0, 1, 55, 18, 38, 10, 55, 30, 48, 13, 55, 30, 48, 0, 1, 55, 41, 48, 13, 55, 18, 38, 16, 55, 34, 48, 0, 10, 55, 30, 38, 13, 55, 30, 38, 16, 55, 41, 38, 16, 55, 30, 38, 0, 1, 55, 29, 38, 10, 55, 30, 38]
[0, 1, 2, 216, 1, 55, 30, 38, 10, 55, 41, 47, 13, 55, 30, 38, 16, 55, 78, 48, 0, 1, 55, 30, 38, 10, 55, 78, 38, 13, 55, 18, 38, 16, 55, 18, 38, 0, 1, 55, 29, 47, 10, 55, 11, 100, 0, 1, 55, 90, 48, 13, 55, 18, 74, 0, 1, 55, 34, 74, 16, 55, 18, 38, 16, 55, 34, 38, 16, 55, 33, 38, 0, 1, 55, 30, 48, 49, 55, 30, 48, 0, 1, 55, 33, 38, 10, 55, 41, 38, 13, 55, 41, 38, 16, 55, 30, 38, 0, 1, 55, 78, 38, 10, 55, 18, 38, 28, 35, 30, 38, 13, 55, 30, 38, 0, 1, 35, 78, 38]
Epoch: 320 | Time: 0m 6s


Epoch: 341 | Time: 0m 6s
	Train Loss: 0.494 | Train PPL:   1.638
	 Val. Loss: 4.313 |  Val. PPL:  74.668
=> Saving checkpoint
Epoch: 342 | Time: 0m 6s
	Train Loss: 0.513 | Train PPL:   1.670
	 Val. Loss: 4.330 |  Val. PPL:  75.918
Epoch: 343 | Time: 0m 6s
	Train Loss: 0.493 | Train PPL:   1.637
	 Val. Loss: 4.360 |  Val. PPL:  78.231
Epoch: 344 | Time: 0m 6s
	Train Loss: 0.472 | Train PPL:   1.604
	 Val. Loss: 4.382 |  Val. PPL:  80.020
Epoch: 345 | Time: 0m 6s
	Train Loss: 0.499 | Train PPL:   1.646
	 Val. Loss: 4.423 |  Val. PPL:  83.378
Epoch: 346 | Time: 0m 6s
	Train Loss: 0.489 | Train PPL:   1.630
	 Val. Loss: 4.369 |  Val. PPL:  78.929
Epoch: 347 | Time: 0m 6s
	Train Loss: 0.497 | Train PPL:   1.644
	 Val. Loss: 4.370 |  Val. PPL:  79.066
Epoch: 348 | Time: 0m 6s
	Train Loss: 0.493 | Train PPL:   1.637
	 Val. Loss: 4.416 |  Val. PPL:  82.753
Epoch: 349 | Time: 0m 6s
	Train Loss: 0.474 | Train PPL:   1.606
	 Val. Loss: 4.368 |  Val. PPL:  78.885
Epoch: 350 | Time: 0m 6s
	Train Lo

[0, 1, 50, 104, 37, 57, 90, 46, 42, 39, 18, 44, 43, 95, 30, 56, 21, 62, 77, 9, 16, 4, 30, 6, 24, 55, 77, 9, 36, 40, 18, 27, 0, 43, 39, 58, 75, 21, 40, 90, 20, 16, 40, 90, 25, 0, 49, 40, 90, 6, 13, 4, 78, 20, 31, 95, 78, 38, 24, 62, 90, 6, 0, 1, 62, 78, 20, 37, 40, 78, 9, 42, 62, 78, 20, 10, 57, 78, 9, 43, 62, 26, 6, 28, 55, 90, 9, 13, 62, 26, 6, 15, 55, 90, 9, 21, 60, 26, 6, 16, 62, 90, 9, 23, 55, 26, 6, 36, 55, 90, 9, 0, 1, 39, 18, 56, 43, 95, 41, 9, 28, 57, 41, 20, 13, 57, 41, 20, 31, 39, 41, 12, 23, 95, 68, 27, 0, 37, 4, 68, 20, 7, 4, 18, 46, 43, 95, 34, 6, 28, 95, 34, 9, 13, 57, 69, 20, 31, 95, 33, 38, 23, 95, 34, 20, 36, 95, 34, 9, 0, 1, 39, 69, 20, 37, 95, 33, 12, 43, 95, 34, 75, 21, 17, 41, 85]
[0, 1, 50, 79, 37, 35, 90, 44, 42, 35, 78, 56, 15, 66, 105, 9, 15, 66, 90, 20, 21, 35, 58, 47, 36, 35, 58, 44, 0, 37, 35, 90, 44, 42, 35, 78, 38, 28, 57, 70, 38, 21, 57, 53, 38, 36, 62, 69, 44, 0, 37, 35, 68, 46, 42, 55, 69, 74, 36, 62, 68, 44, 0, 37, 55, 69, 44, 42, 62, 52, 103, 21, 57, 

Epoch: 400 | Time: 0m 6s
	Train Loss: 0.418 | Train PPL:   1.518
	 Val. Loss: 4.641 |  Val. PPL: 103.677
=> Saving checkpoint
Epoch: 401 | Time: 0m 6s
	Train Loss: 0.426 | Train PPL:   1.531
	 Val. Loss: 4.733 |  Val. PPL: 113.610
=> Saving checkpoint
Epoch: 402 | Time: 0m 6s
	Train Loss: 0.408 | Train PPL:   1.503
	 Val. Loss: 4.743 |  Val. PPL: 114.730
Epoch: 403 | Time: 0m 6s
	Train Loss: 0.396 | Train PPL:   1.486
	 Val. Loss: 4.714 |  Val. PPL: 111.505
Epoch: 404 | Time: 0m 6s
	Train Loss: 0.412 | Train PPL:   1.510
	 Val. Loss: 4.681 |  Val. PPL: 107.919
Epoch: 405 | Time: 0m 6s
	Train Loss: 0.410 | Train PPL:   1.507
	 Val. Loss: 4.714 |  Val. PPL: 111.510
Epoch: 406 | Time: 0m 6s
	Train Loss: 0.408 | Train PPL:   1.504
	 Val. Loss: 4.753 |  Val. PPL: 115.956
Epoch: 407 | Time: 0m 6s
	Train Loss: 0.392 | Train PPL:   1.481
	 Val. Loss: 4.758 |  Val. PPL: 116.534
Epoch: 408 | Time: 0m 6s
	Train Loss: 0.407 | Train PPL:   1.502
	 Val. Loss: 4.754 |  Val. PPL: 116.058
Epoch: 409 | 

Epoch: 422 | Time: 0m 6s
	Train Loss: 0.375 | Train PPL:   1.454
	 Val. Loss: 4.836 |  Val. PPL: 126.001
Epoch: 423 | Time: 0m 6s
	Train Loss: 0.367 | Train PPL:   1.444
	 Val. Loss: 4.776 |  Val. PPL: 118.688
Epoch: 424 | Time: 0m 6s
	Train Loss: 0.368 | Train PPL:   1.444
	 Val. Loss: 4.818 |  Val. PPL: 123.778
Epoch: 425 | Time: 0m 6s
	Train Loss: 0.378 | Train PPL:   1.460
	 Val. Loss: 4.752 |  Val. PPL: 115.817
Epoch: 426 | Time: 0m 6s
	Train Loss: 0.375 | Train PPL:   1.454
	 Val. Loss: 4.811 |  Val. PPL: 122.902
Epoch: 427 | Time: 0m 6s
	Train Loss: 0.381 | Train PPL:   1.463
	 Val. Loss: 4.850 |  Val. PPL: 127.701
Epoch: 428 | Time: 0m 6s
	Train Loss: 0.371 | Train PPL:   1.449
	 Val. Loss: 4.776 |  Val. PPL: 118.633
Epoch: 429 | Time: 0m 6s
	Train Loss: 0.372 | Train PPL:   1.451
	 Val. Loss: 4.835 |  Val. PPL: 125.800
Epoch: 430 | Time: 0m 6s
	Train Loss: 0.369 | Train PPL:   1.446
	 Val. Loss: 4.806 |  Val. PPL: 122.272
Epoch: 431 | Time: 0m 6s
	Train Loss: 0.371 | Train PPL

Epoch: 451 | Time: 0m 6s
	Train Loss: 0.362 | Train PPL:   1.436
	 Val. Loss: 4.912 |  Val. PPL: 135.972
Epoch: 452 | Time: 0m 6s
	Train Loss: 0.332 | Train PPL:   1.393
	 Val. Loss: 4.936 |  Val. PPL: 139.180
Epoch: 453 | Time: 0m 6s
	Train Loss: 0.335 | Train PPL:   1.398
	 Val. Loss: 4.941 |  Val. PPL: 139.913
Epoch: 454 | Time: 0m 6s
	Train Loss: 0.338 | Train PPL:   1.402
	 Val. Loss: 4.890 |  Val. PPL: 132.999
Epoch: 455 | Time: 0m 6s
	Train Loss: 0.337 | Train PPL:   1.400
	 Val. Loss: 4.939 |  Val. PPL: 139.677
Epoch: 456 | Time: 0m 6s
	Train Loss: 0.343 | Train PPL:   1.409
	 Val. Loss: 5.000 |  Val. PPL: 148.361
Epoch: 457 | Time: 0m 6s
	Train Loss: 0.342 | Train PPL:   1.408
	 Val. Loss: 4.978 |  Val. PPL: 145.248
Epoch: 458 | Time: 0m 6s
	Train Loss: 0.342 | Train PPL:   1.407
	 Val. Loss: 4.958 |  Val. PPL: 142.374
Epoch: 459 | Time: 0m 6s
	Train Loss: 0.354 | Train PPL:   1.425
	 Val. Loss: 5.070 |  Val. PPL: 159.156
[0, 1, 50, 104, 37, 57, 8, 47, 10, 50, 104, 43, 4, 30, 

[0, 1, 50, 104, 1, 95, 33, 46, 42, 40, 34, 20, 42, 17, 33, 20, 10, 50, 104, 10, 35, 41, 6, 49, 4, 30, 6, 13, 50, 104, 13, 57, 78, 56, 16, 50, 104, 0, 1, 50, 104, 1, 95, 34, 56, 10, 50, 104, 49, 95, 33, 9, 28, 40, 34, 9, 13, 50, 104, 13, 95, 68, 44, 16, 95, 64, 38, 0, 1, 95, 70, 38, 10, 95, 30, 9, 43, 4, 33, 20, 49, 95, 34, 9, 28, 17, 70, 9, 13, 95, 72, 44, 21, 95, 106, 46, 24, 17, 70, 44, 0, 1, 4, 41, 56, 1, 95, 73, 47, 49, 95, 33, 9, 28, 57, 70, 9, 13, 95, 34, 27, 23, 55, 70, 20, 24, 35, 34, 20, 24, 62, 33, 6, 36, 39, 34, 9, 0, 1, 55, 70, 46, 7, 40, 41, 46, 10, 95, 41, 44, 49, 39, 30, 44, 13, 17, 78, 38, 24, 95, 33, 6, 0, 1, 95, 68, 47, 49, 95, 33, 6, 13, 95, 34, 47, 16, 17, 30, 38, 36, 4, 41, 86]
[0, 1, 50, 135, 37, 22, 52, 9, 7, 22, 52, 9, 42, 57, 32, 20, 42, 40, 52, 20, 10, 50, 135, 43, 22, 52, 9, 49, 57, 32, 20, 28, 4, 32, 6, 13, 50, 135, 16, 50, 135, 23, 57, 52, 38, 0, 1, 50, 135, 37, 40, 69, 47, 10, 50, 135, 10, 57, 32, 20, 28, 22, 52, 20, 13, 50, 135, 21, 22, 52, 9, 16, 50, 135

[0, 1, 50, 135, 37, 4, 34, 45, 21, 17, 54, 44, 23, 17, 33, 44, 36, 22, 34, 44, 0, 37, 17, 32, 27, 37, 17, 32, 120, 23, 17, 34, 6, 36, 17, 54, 6, 0, 1, 17, 64, 59, 16, 22, 54, 44, 24, 17, 34, 44, 0, 37, 22, 33, 59, 23, 40, 5, 6, 23, 40, 32, 6, 36, 22, 11, 44, 36, 22, 33, 44, 0, 37, 40, 18, 44, 37, 40, 34, 44, 42, 40, 11, 6, 42, 40, 33, 6, 43, 57, 11, 6, 43, 57, 33, 6, 28, 40, 5, 6, 28, 40, 32, 6, 31, 22, 5, 47, 31, 22, 32, 47, 23, 4, 30, 47, 23, 4, 64, 47, 0, 37, 22, 29, 25, 37, 22, 54, 25, 28, 4, 30, 44, 28, 4, 64, 44, 31, 40, 29, 48, 31, 40, 54, 48, 0, 37, 4, 11, 59, 37, 4, 33, 59, 23, 40, 29, 6, 23, 40, 54, 6, 36, 57, 18, 44, 36, 57, 34, 44, 0, 1, 22, 11, 112, 1, 22, 33, 112]
[0, 1, 50, 79, 37, 4, 29, 56, 10, 50, 79, 28, 4, 58, 20, 13, 50, 79, 13, 39, 98, 20, 31, 22, 11, 6, 21, 39, 98, 6, 16, 50, 79, 23, 4, 11, 44, 36, 4, 5, 9, 0, 1, 50, 79, 1, 4, 105, 9, 37, 4, 5, 99, 10, 50, 79, 28, 61, 69, 9, 13, 50, 79, 13, 60, 53, 9, 31, 55, 54, 44, 21, 35, 69, 44, 16, 50, 79, 23, 55, 53, 44, 36

In [None]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [17]:
state = torch.load(destination_folder + '/500_checkpoint.pt', map_location=device)
for key in state:
    print(key)
load_checkpoint(state, model, optimizer)

model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint


In [23]:
test_loss = evaluate(model, test_iter, criterion)
print(math.exp(test_loss))

183.57825508342776


In [21]:
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro = df_intro['intro'].values
test_solo = df_intro['solo'].values
test_outro = df_intro['outro'].values
test_data=[]
for i in range(len(test_intro)):
    temp_dict = {}
    temp_dict['intro'] = test_intro[i]
    temp_dict['solo'] = test_solo[i]
    temp_dict['outro'] = test_outro[i]
    test_data.append(temp_dict)
print(len(test_intro))

112


In [22]:
for i in range(0,len(test_intro)):
    intro = test_intro[i]
    solo = test_solo[i]
    outro = test_outro[i]
    #print(intro)
    list_intro = [int(x) for x in intro.split(' ')]
    list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence(model, intro, intro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    print(translated_sentence)
    utils.write_midi(list_intro, word2event, generated_outputs + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_solo, word2event, generated_outputs  + "/solo/" + "/solo" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, generated_outputs + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
    
        


[0, 1, 50, 186, 1, 63, 54, 75, 10, 50, 186, 49, 61, 34, 9, 28, 94, 54, 82, 13, 50, 186, 16, 50, 186, 23, 81, 64, 9, 24, 81, 73, 9, 36, 136, 65, 75, 0, 1, 50, 186, 10, 50, 186, 49, 60, 34, 20, 28, 61, 34, 27, 13, 50, 186, 16, 50, 186, 24, 35, 54, 44, 0, 1, 50, 186, 1, 35, 64, 56, 42, 61, 65, 12, 10, 50, 186, 13, 50, 186, 13, 94, 128, 12, 21, 61, 54, 56, 16, 50, 186, 36, 35, 34, 12, 0, 1, 50, 186, 10, 50, 186, 13, 50, 186, 16, 50, 186, 23, 61, 18, 9, 23, 61, 34, 9, 24, 63, 58, 9, 24, 63, 53, 9, 36, 61, 29, 56, 36, 61, 54, 56, 0, 1, 50, 186, 42, 81, 32, 44, 42, 94, 32, 44, 42, 81, 65, 44, 10, 50, 186, 13, 50, 186, 13, 94, 33, 9, 13, 94, 128, 9, 16, 50, 186, 23, 63, 54, 9, 24, 114, 30, 9, 24, 94, 64, 9, 36, 114, 41, 9, 36, 81, 73, 9, 0, 1, 50, 186, 1, 61, 32, 6, 1, 114, 65, 12, 1, 114, 213, 12, 42, 35, 65, 46, 42, 35, 213, 46, 10, 50, 186, 49, 61, 125, 44, 49, 61, 214, 44, 13, 50, 186, 13, 35, 73, 47, 13, 35, 190, 47, 16, 50, 186, 24, 35, 65, 9, 24, 35, 213, 9, 36, 66, 64, 12, 36, 66, 189,

[0, 1, 50, 127, 1, 61, 69, 44, 7, 61, 108, 44, 10, 61, 32, 44, 49, 61, 65, 44, 13, 55, 65, 38, 15, 61, 107, 38, 16, 81, 32, 38, 0, 1, 61, 65, 38, 31, 81, 98, 44, 15, 81, 69, 44, 16, 61, 32, 44, 16, 81, 52, 44, 24, 61, 32, 44, 0, 1, 61, 32, 44, 7, 61, 32, 44, 10, 61, 53, 44, 49, 61, 58, 44, 28, 61, 98, 44, 13, 94, 98, 85, 0, 1, 94, 77, 9, 42, 81, 32, 44, 49, 61, 52, 44, 13, 55, 32, 44, 15, 61, 69, 44, 23, 66, 32, 44, 24, 57, 53, 27, 0, 1, 61, 69, 44, 7, 61, 32, 44, 10, 61, 32, 38, 49, 63, 69, 38, 13, 94, 69, 56, 31, 61, 32, 44, 21, 62, 69, 44, 16, 61, 32, 44, 23, 35, 69, 44, 24, 62, 32, 44, 36, 61, 32, 44, 0, 1, 61, 32, 44, 1, 55, 53, 27, 49, 66, 52, 44, 7, 61, 32, 44, 10, 61, 32, 44, 49, 66, 69, 44, 13, 61, 32, 44, 15, 66, 69, 44, 15, 61, 32, 44, 16, 61, 53, 44, 16, 61, 77, 46, 36, 61, 53, 19, 0, 10, 61, 32, 44, 49, 61, 32, 38, 28, 61, 32, 44, 13, 61, 52, 44, 15, 61, 52, 44, 16, 61, 77, 44, 24, 61, 32, 44, 36, 61, 32, 44, 24, 61, 58, 46]
8
[0, 1, 2, 137, 7, 115, 30, 38, 7, 115, 64, 38,

[0, 1, 2, 137, 7, 115, 30, 38, 7, 115, 64, 38, 49, 81, 30, 44, 49, 81, 64, 44, 13, 81, 78, 44, 13, 81, 70, 44, 15, 114, 30, 44, 15, 114, 64, 44, 16, 136, 41, 44, 16, 136, 73, 44, 24, 63, 33, 38, 24, 63, 128, 38, 0, 7, 81, 30, 38, 7, 81, 64, 38, 49, 94, 30, 44, 49, 94, 64, 44, 13, 114, 78, 44, 13, 114, 70, 44, 15, 136, 30, 44, 15, 136, 64, 44, 16, 81, 33, 44, 16, 81, 128, 44, 24, 81, 41, 38, 24, 81, 73, 38, 0, 7, 114, 30, 38, 7, 114, 64, 38, 49, 81, 30, 44, 49, 81, 64, 44, 13, 114, 78, 44, 13, 114, 70, 44, 15, 114, 30, 44, 15, 114, 64, 44, 16, 136, 41, 44, 16, 136, 73, 44, 24, 94, 33, 38, 24, 94, 128, 38, 0, 7, 136, 30, 38, 7, 136, 64, 38, 49, 81, 30, 44, 49, 81, 64, 44, 13, 114, 78, 44, 13, 114, 70, 44, 15, 115, 30, 44, 15, 115, 64, 44, 16, 114, 33, 44, 16, 114, 128, 44, 23, 63, 41, 38, 23, 63, 73, 38]
17
[0, 1, 2, 137, 7, 115, 30, 38, 7, 115, 64, 38, 49, 81, 30, 44, 49, 81, 64, 44, 13, 81, 78, 44, 13, 81, 70, 44, 15, 114, 30, 44, 15, 114, 64, 44, 16, 136, 41, 44, 16, 136, 73, 44, 24, 

[0, 1, 50, 104, 37, 57, 8, 47, 10, 50, 104, 43, 4, 30, 59, 13, 50, 104, 16, 50, 104, 0, 1, 50, 104, 42, 55, 29, 20, 10, 50, 104, 10, 35, 18, 56, 13, 50, 104, 21, 35, 18, 56, 16, 35, 18, 9, 16, 50, 104, 16, 35, 90, 9, 23, 57, 18, 9, 24, 57, 11, 75, 0, 1, 50, 104, 37, 57, 8, 20, 42, 60, 18, 46, 10, 50, 104, 10, 35, 11, 75, 15, 35, 18, 46, 24, 61, 29, 9, 36, 55, 30, 20, 0, 1, 55, 41, 9, 37, 35, 8, 20, 7, 35, 18, 20, 42, 35, 29, 20, 10, 35, 18, 38, 28, 55, 30, 20, 13, 55, 30, 6, 15, 55, 30, 6, 21, 55, 30, 6, 23, 35, 29, 9, 24, 60, 18, 46, 36, 55, 41, 46, 0, 37, 35, 11, 6, 7, 35, 33, 56, 10, 35, 5, 6, 43, 35, 18, 46, 13, 60, 30, 20, 31, 60, 18, 12, 16, 50, 104, 23, 35, 18, 44, 36, 35, 8, 46, 36, 60, 30, 46, 0, 37, 35, 26, 44, 42, 35, 18, 44, 10, 35, 8, 44, 43, 35, 18, 46, 49, 60, 18, 44, 28, 94, 30, 6, 31, 60, 8, 12, 16, 50, 104, 21, 60, 78, 47, 23, 60, 18, 46, 23, 55, 34, 44, 36, 35, 41, 44, 36, 60, 26, 25]
26
[0, 1, 50, 186, 1, 35, 52, 56, 49, 94, 52, 20, 28, 62, 68, 20, 13, 39, 41, 19, 2

[0, 1, 50, 127, 1, 22, 32, 20, 37, 22, 68, 20, 7, 22, 69, 20, 42, 17, 53, 75, 10, 50, 127, 13, 50, 127, 31, 22, 69, 20, 15, 17, 68, 20, 21, 39, 53, 6, 16, 50, 127, 23, 22, 72, 20, 24, 39, 53, 12, 0, 1, 50, 127, 37, 4, 69, 20, 7, 17, 68, 20, 42, 39, 69, 56, 10, 50, 127, 13, 50, 127, 31, 4, 69, 6, 21, 4, 68, 6, 16, 50, 127, 23, 40, 32, 20, 24, 22, 52, 44, 0, 1, 50, 127, 1, 22, 52, 20, 37, 57, 77, 20, 7, 22, 58, 20, 42, 57, 78, 6, 10, 50, 127, 43, 22, 53, 20, 49, 39, 69, 38, 13, 50, 127, 15, 22, 69, 20, 21, 17, 53, 20, 16, 50, 127, 16, 22, 70, 20, 23, 4, 72, 20, 24, 39, 72, 46, 0, 1, 50, 127, 37, 4, 70, 20, 7, 22, 72, 20, 42, 4, 106, 48, 10, 50, 127, 16, 17, 53, 27, 0, 37, 4, 69, 20, 7, 22, 68, 20, 42, 17, 53, 27, 13, 4, 69, 20, 31, 17, 68, 20]
35
[0, 1, 50, 104, 37, 95, 34, 38, 43, 39, 70, 47, 31, 95, 72, 74, 0, 43, 95, 34, 9, 49, 17, 70, 20, 13, 95, 70, 74, 0, 43, 4, 70, 44, 28, 95, 72, 44, 31, 39, 73, 59, 0, 42, 39, 72, 44, 43, 39, 73, 44, 28, 95, 72, 44, 31, 95, 70, 59, 0, 43, 95, 34,

[0, 1, 50, 186, 1, 60, 69, 48, 10, 50, 186, 13, 50, 186, 13, 66, 26, 20, 31, 66, 26, 9, 15, 60, 90, 20, 21, 66, 98, 9, 16, 50, 186, 16, 35, 18, 44, 24, 66, 98, 9, 36, 35, 18, 9, 0, 1, 50, 186, 1, 60, 78, 44, 7, 66, 18, 9, 42, 66, 78, 9, 10, 50, 186, 10, 35, 29, 9, 43, 35, 78, 9, 49, 35, 18, 9, 28, 61, 11, 9, 13, 50, 186, 13, 60, 18, 6, 15, 66, 11, 9, 21, 35, 90, 9, 16, 50, 186, 16, 60, 11, 6, 24, 35, 90, 9, 36, 63, 26, 20, 0, 1, 50, 186, 1, 60, 11, 9, 37, 94, 18, 9, 7, 35, 11, 20, 7, 35, 90, 9, 42, 35, 26, 20, 10, 50, 186, 10, 60, 90, 6, 43, 66, 88, 6, 49, 66, 105, 6, 28, 66, 88, 9, 13, 50, 186, 13, 60, 14, 20, 31, 60, 14, 9, 15, 60, 8, 9, 21, 61, 26, 9, 16, 50, 186, 16, 35, 90, 9, 24, 60, 92, 6, 0, 1, 50, 186, 1, 61, 11, 6, 37, 60, 8, 6, 7, 61, 11, 6, 42, 61, 8, 9, 10, 50, 186, 10, 94, 11, 9, 43, 60, 8, 9, 49, 61, 11, 6, 28, 66, 8, 9, 13, 50, 186, 13, 66, 26, 9, 31, 61, 90, 9, 15, 66, 11, 9, 21, 57, 29, 9, 16, 50, 186, 16, 60, 78, 6, 23, 35, 18, 9, 24, 60, 11, 9, 36, 35, 90, 20, 0, 1,

[0, 1, 50, 127, 1, 22, 33, 20, 37, 22, 33, 20, 7, 22, 33, 20, 42, 17, 34, 20, 10, 50, 127, 13, 50, 127, 31, 22, 33, 20, 15, 40, 68, 20, 21, 39, 34, 20, 16, 50, 127, 23, 40, 41, 20, 24, 39, 54, 20, 24, 40, 70, 9, 0, 1, 50, 127, 37, 17, 33, 6, 42, 39, 34, 9, 10, 50, 127, 13, 50, 127, 31, 4, 33, 6, 15, 39, 64, 20, 15, 40, 33, 20, 21, 39, 34, 20, 16, 39, 70, 19, 0, 1, 4, 33, 9, 7, 57, 34, 48, 31, 22, 33, 20, 21, 4, 33, 20, 16, 50, 127, 49, 57, 33, 20, 23, 62, 33, 20, 36, 187, 33, 20, 0, 1, 50, 127, 1, 50, 127, 49, 57, 32, 20, 28, 22, 33, 20, 13, 55, 68, 20, 31, 22, 33, 20, 15, 57, 41, 20, 21, 57, 29, 20, 16, 94, 30, 20, 23, 57, 30, 20, 23, 57, 29, 20, 24, 57, 29, 20, 24, 57, 29, 20, 36, 57, 78, 20, 36, 39, 29, 20, 0, 1, 22, 78, 20, 37, 22, 78, 20, 7, 57, 78, 20, 42, 40, 78, 20]
57
[0, 1, 50, 127, 1, 22, 33, 20, 37, 22, 33, 20, 7, 22, 33, 20, 42, 17, 34, 20, 10, 50, 127, 13, 50, 127, 31, 22, 33, 20, 15, 40, 68, 20, 21, 39, 34, 20, 16, 50, 127, 23, 40, 41, 20, 24, 39, 54, 20, 24, 40, 70, 9, 

[0, 1, 50, 186, 1, 40, 111, 9, 1, 40, 98, 9, 37, 57, 84, 9, 37, 57, 58, 9, 7, 17, 88, 9, 7, 17, 78, 9, 42, 4, 178, 9, 42, 4, 105, 9, 10, 50, 186, 10, 22, 111, 9, 10, 22, 98, 9, 43, 22, 84, 9, 43, 22, 87, 9, 49, 57, 88, 9, 49, 57, 78, 9, 28, 17, 14, 19, 28, 17, 29, 38, 13, 50, 186, 21, 39, 98, 9, 16, 50, 186, 16, 17, 58, 9, 23, 60, 78, 20, 23, 55, 77, 20, 24, 61, 52, 20, 24, 57, 41, 9, 36, 22, 68, 9, 36, 66, 87, 48, 36, 40, 69, 25, 0, 1, 50, 186, 10, 50, 186, 43, 22, 78, 44, 28, 60, 90, 25, 28, 17, 58, 9, 13, 50, 186, 13, 4, 77, 9, 31, 4, 58, 56, 16, 50, 186, 23, 114, 26, 44, 36, 57, 105, 25, 36, 4, 68, 25, 0, 1, 50, 186, 10, 50, 186, 43, 55, 87, 44, 43, 39, 58, 44, 28, 57, 88, 96, 28, 17, 98, 9, 13, 50, 186, 13, 40, 78, 9, 31, 39, 98, 12, 16, 50, 186, 16, 4, 78, 9, 23, 62, 77, 9, 24, 4, 52, 9, 36, 60, 87, 25, 36, 4, 41, 25, 0, 1, 50, 186, 10, 50, 186, 43, 35, 88, 44, 43, 39, 98, 44, 28, 57, 84, 82, 28, 22, 11, 9, 13, 50, 186, 13, 22, 58, 9, 31, 17, 11, 38, 16, 50, 186, 23, 22, 58, 44, 

[0, 1, 50, 186, 1, 40, 111, 9, 1, 40, 98, 9, 37, 57, 84, 9, 37, 57, 58, 9, 7, 17, 88, 9, 7, 17, 78, 9, 42, 4, 26, 19, 43, 57, 52, 20, 49, 17, 78, 9, 28, 22, 29, 38, 13, 17, 77, 9, 31, 39, 5, 20, 13, 17, 41, 12, 16, 17, 30, 12, 23, 4, 18, 12, 0, 1, 4, 11, 9, 37, 17, 32, 20, 7, 17, 18, 20, 42, 4, 11, 20, 10, 4, 18, 25, 13, 50, 186, 31, 57, 5, 20, 15, 57, 26, 20, 21, 39, 5, 38, 16, 50, 186, 23, 55, 8, 20, 23, 55, 87, 20, 24, 35, 26, 9, 36, 55, 11, 20, 36, 55, 18, 9, 0, 1, 40, 11, 20, 37, 95, 18, 44, 42, 57, 26, 48, 31, 40, 8, 25, 23, 55, 18, 48, 0, 37, 57, 11, 48, 49, 39, 18, 9, 31, 22, 26, 6, 21, 17, 33, 9, 23, 40, 8, 20, 23, 55, 18, 12, 0, 42, 40, 8, 20, 42, 57, 8, 56, 43, 55, 8, 20, 28, 40, 92, 27, 21, 57, 18, 20, 16, 50, 186, 16, 22, 11, 20, 23, 60, 18, 20, 23, 35, 11, 20, 24, 57, 18, 20, 36, 22, 18, 20, 0, 1, 39, 18, 20, 37, 40, 18, 20, 37, 57, 11, 20, 42, 62, 78, 12, 10, 57, 5, 20, 43, 22, 18, 20, 43, 40, 11, 20, 49, 62, 18, 20, 28, 40, 18, 20, 31, 40, 11, 20, 13, 22, 5, 20, 31, 22,

[0, 1, 50, 127, 1, 4, 77, 9, 1, 4, 72, 9, 37, 4, 58, 20, 37, 4, 53, 20, 7, 4, 98, 20, 7, 4, 69, 20, 42, 4, 90, 20, 42, 4, 68, 20, 10, 4, 98, 38, 10, 4, 69, 38, 13, 17, 52, 6, 13, 17, 106, 6, 15, 17, 77, 12, 15, 17, 72, 12, 23, 17, 77, 20, 23, 17, 72, 20, 24, 17, 77, 6, 24, 17, 72, 6, 0, 1, 4, 52, 20, 1, 4, 106, 20, 37, 4, 41, 20, 37, 4, 73, 20, 7, 17, 68, 46, 7, 17, 107, 46, 43, 4, 77, 20, 43, 4, 72, 20, 49, 4, 78, 38, 49, 4, 70, 38, 15, 17, 41, 12, 15, 17, 73, 12, 23, 4, 77, 9, 23, 4, 72, 9, 24, 17, 41, 20, 24, 17, 73, 20, 36, 22, 68, 20, 36, 22, 107, 20, 0, 1, 17, 41, 44, 1, 17, 73, 20, 37, 4, 78, 96, 1, 39, 70, 96, 1, 17, 33, 20, 37, 4, 32, 9, 7, 4, 53, 38, 49, 4, 70, 9, 28, 4, 78, 9, 13, 39, 58, 9, 31, 4, 68, 20, 15, 4, 107, 48, 24, 4, 77, 9, 36, 4, 30, 9, 0, 1, 4, 32, 9, 37, 4, 69, 12, 1, 4, 33, 9, 37, 4, 53, 19, 28, 4, 34, 20, 13, 4, 53, 19, 13, 4, 98, 9, 31, 4, 41, 9, 31, 4, 78, 20, 15, 4, 69, 12, 24, 17, 33, 9, 24, 4, 68, 75, 0, 1, 4, 68, 44, 37, 4, 32, 9, 37, 4, 52, 9, 7, 4, 3

[0, 1, 50, 104, 37, 39, 18, 44, 7, 22, 78, 56, 31, 39, 78, 9, 15, 39, 18, 20, 15, 39, 78, 20, 21, 40, 77, 9, 16, 40, 78, 9, 23, 57, 78, 20, 24, 57, 78, 20, 24, 60, 41, 9, 36, 57, 78, 9, 0, 1, 40, 77, 38, 13, 40, 41, 6, 15, 4, 34, 20, 21, 40, 41, 20, 16, 39, 34, 20, 24, 57, 78, 9, 36, 4, 70, 20, 0, 1, 39, 30, 9, 1, 39, 34, 27, 43, 4, 77, 9, 49, 4, 78, 9, 28, 4, 30, 9, 13, 57, 78, 75, 13, 39, 34, 20, 31, 39, 34, 44, 16, 81, 41, 20, 23, 4, 68, 75, 16, 94, 18, 20, 24, 4, 34, 46, 0, 1, 39, 34, 56, 37, 40, 34, 9, 7, 39, 41, 20, 7, 39, 34, 20, 42, 39, 33, 6, 43, 4, 34, 20, 43, 39, 34, 20, 49, 39, 41, 20, 28, 4, 34, 9, 13, 39, 78, 9, 13, 39, 70, 20, 31, 4, 34, 9, 31, 4, 34, 6, 15, 39, 78, 9, 21, 39, 34, 9, 16, 39, 77, 9, 23, 4, 64, 9, 23, 4, 73, 12, 0, 1, 39, 73, 9, 37, 4, 72, 75, 10, 39, 34, 9, 7, 4, 72, 9, 42, 39, 41, 47, 43, 4, 128, 9, 43, 39, 64, 20, 28, 39, 34, 20, 13, 39, 73, 56, 31, 4, 78, 9, 31, 39, 73, 56, 31, 39, 73, 9, 15, 39, 73, 9, 21, 4, 30, 12, 16, 39, 29, 9, 23, 4, 78, 9, 24, 2

[0, 1, 50, 186, 21, 4, 58, 27, 21, 4, 53, 27, 36, 39, 77, 38, 36, 39, 72, 38, 0, 42, 57, 68, 9, 43, 17, 69, 9, 49, 4, 53, 9, 28, 95, 72, 82, 23, 95, 70, 9, 24, 57, 53, 6, 36, 95, 32, 82, 0, 28, 95, 72, 56, 23, 95, 70, 6, 24, 57, 53, 9, 36, 95, 68, 19, 0, 28, 95, 72, 56, 23, 95, 70, 9, 24, 35, 53, 9, 36, 4, 68, 96, 0, 43, 57, 70, 9, 49, 95, 70, 9, 28, 57, 70, 6, 31, 57, 70, 44, 21, 57, 53, 6, 23, 57, 53, 9, 24, 62, 70, 9, 36, 62, 68, 25, 0, 43, 17, 53, 6, 49, 55, 69, 6, 13, 62, 41, 6, 31, 57, 68, 9, 15, 55, 41, 6, 21, 17, 5, 9, 24, 62, 90, 6, 23, 95, 11, 6, 36, 17, 5, 9, 36, 35, 98, 9, 0, 43, 57, 58, 6, 49, 57, 77, 6, 28, 4, 78, 44, 13, 40, 105, 44, 31, 57, 98, 6, 16, 57, 98, 6, 24, 57, 78, 6, 36, 17, 58, 9]
98
[0, 1, 50, 104, 37, 95, 34, 38, 43, 39, 70, 47, 31, 95, 72, 74, 0, 43, 95, 34, 9, 49, 17, 70, 20, 13, 95, 70, 74, 0, 43, 4, 70, 44, 28, 95, 72, 44, 31, 39, 73, 59, 0, 42, 39, 72, 44, 43, 39, 73, 44, 28, 95, 72, 44, 31, 95, 70, 59, 0, 43, 95, 72, 44, 28, 17, 70, 6, 31, 95, 72, 59,

[0, 1, 50, 104, 37, 95, 34, 38, 43, 39, 70, 47, 31, 95, 72, 74, 0, 43, 95, 34, 9, 49, 17, 70, 20, 13, 95, 70, 74, 0, 43, 4, 70, 44, 28, 95, 72, 44, 31, 39, 73, 59, 0, 42, 39, 72, 44, 43, 39, 73, 44, 28, 95, 72, 44, 31, 95, 70, 59, 0, 43, 95, 72, 44, 28, 17, 70, 6, 31, 95, 72, 59, 0, 43, 4, 70, 44, 28, 95, 72, 6, 31, 57, 72, 74, 0, 43, 95, 34, 44, 28, 17, 70, 6, 31, 95, 34, 44, 21, 39, 73, 85, 0, 43, 95, 68, 44, 49, 39, 34, 44]
106
[0, 1, 50, 186, 1, 40, 111, 9, 1, 40, 98, 9, 37, 57, 84, 9, 37, 57, 78, 9, 7, 17, 88, 9, 7, 17, 78, 9, 42, 4, 26, 9, 42, 4, 41, 9, 10, 50, 186, 10, 22, 90, 9, 10, 22, 98, 9, 43, 22, 90, 9, 43, 22, 90, 9, 49, 57, 26, 20, 49, 57, 11, 6, 13, 17, 18, 20, 31, 17, 18, 103, 21, 39, 78, 20, 16, 50, 186, 23, 39, 90, 9, 24, 57, 78, 9, 24, 4, 18, 25, 0, 10, 50, 186, 49, 40, 77, 9, 28, 57, 26, 25, 0, 1, 4, 98, 25, 15, 62, 90, 9, 21, 40, 26, 9, 16, 4, 90, 9, 23, 22, 18, 12, 0, 1, 114, 90, 9, 37, 55, 78, 9, 7, 57, 78, 9, 42, 39, 77, 9, 10, 60, 18, 9, 43, 55, 18, 9, 49, 55,

In [24]:
import mido
for i in range(11):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')
    total_intro_time = 0
    total_solo_time = 0
    total_predict_time = 0
    for msg in intro.tracks[1]:
        if msg.type == "note_on":
            total_intro_time += msg.time
    for msg in solo.tracks[1]:
        if msg.type == "note_on":
            total_solo_time += msg.time
    for msg in predict.tracks[1]:
        if msg.type == "note_on":
            total_predict_time += msg.time
            
    original_outro_time = 0 + outro.tracks[1][1].time
    
    print(original_outro_time + total_solo_time + total_intro_time)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_solo_time + total_intro_time
    print(outro.tracks[1][1].time)
    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')
    
    
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    
    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

25080
25080
24420
24420
29880
29880
19680
19680
23760
23760
38040
38040
35820
35820
40440
40440
22800
22800
14220
14220
20160
20160
24000
24000
25380
25380
21240
21240
25800
25800
23400
23400
25320
25320
23580
23580
40620
40620
17700
17700
25620
25620
18780
18780


In [None]:
class BeamSearchNode(object):
    def __init__(self, prev_node, wid, logp, length):
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    
    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]
    
    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    n_best_list = []
    
     
    #trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

    #first token as input
    trg_tensor = torch.LongTensor(outputs).to(device)
    
    end_nodes = []

    #starting node
    node = BeamSearchNode(prev_node=None, wid=trg_tensor, logp=0, length=1)

    nodes = []

    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0

    while True:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break
        
        # Fetch the best node
        #print([n[2].wid for n in nodes])
        score, _, n = heappop(nodes)
        decoder_input = n.wid
        
        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= beam_width:
                break
            else:
                continue
   
        sequence = [n.wid.item()]
        a = n
        while a.prev_node is not None:
            a = a.prev_node
            sequence.append(a.wid.item())
        sequence = sequence[::-1] # reverse
        
        #print(sequence)
        
        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))
        
        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(output, beam_width)
        #print(topk_indexes)
        #print(topk_log_prob)
        # Then, register new top-k nodes
        for new_k in range(beam_width):
            decoded_t = topk_indexes[0][0][new_k].view(1) # (1)
            logp = topk_log_prob[0][0][new_k].item() # float log probability val

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width
        #print(n_dec_steps)
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(beam_width)]

    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)


    # return n_best_seq_list

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]

    # remove start token
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    output = open(path, 'wb')
    pickle.dump(vocab, output)
    output.close()

In [None]:
vocab_folder = "vocab/"
save_vocab(intro_field.vocab, vocab + '/intro_vocab.pkl')
save_vocab(solo_field.vocab, vocab  + '/solo_vocab.pkl')
save_vocab(outro_field.vocab, vocab + '/outro_vocab.pkl')

In [None]:
def load_vocab(path):
    with open(path, 'rb') as f:
        x = pickle.load(f)
    return x

In [None]:
vocab_folder = "vocab_intro/"
intro_field.vocab = load_vocab(vocab_folder + 'intro_vocab.pkl')
solo_field.vocab = load_vocab(vocab_folder + 'solo_vocab.pkl')
outro_field.vocab = load_vocab(vocab_folder + 'outro_vocab.pkl')

In [None]:
vocab_folder = "vocab/"
with open()
intro_field.vocab = 

In [None]:
beam = [2, 59, 119, 13, 212, 59, 212, 59, 59, 75, 59, 59, 13, 119, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 119, 59, 59, 59, 59, 166, 59, 59, 59, 13, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 59, 59, 212, 212, 13, 59, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 14, 59, 59, 212, 59, 212, 212, 59, 59, 59, 68, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 68, 59, 59, 212, 59, 59, 13, 59, 59, 59, 59, 59, 59, 97, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 158, 59, 59, 13, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 13, 212, 13, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 158, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 13, 13, 59, 59, 13, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 13, 59, 166, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 86, 59, 212, 212, 212, 59, 59, 59, 59, 212, 59, 86, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 212, 59, 59, 166, 59, 86, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 59, 59, 59, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 212, 212, 158, 59, 59, 13, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 158, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 212, 212, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 13, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 212, 86, 59, 59, 59, 158, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 212, 59, 13, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 68, 59, 13, 59, 59, 13, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 14, 59, 59, 59, 59, 59, 59, 13, 86, 59, 59, 59, 212, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 13, 59, 59, 59, 59, 68, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 212, 212, 13, 59, 166, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 212, 59, 59, 212, 59, 212, 59, 59, 13, 59, 59, 59, 59, 13, 59, 59, 14, 13, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 68, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 212, 59, 59, 13, 59, 59, 166, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 13, 59, 59, 212, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59]

translated_sentence1 = [solo_field.vocab.itos[idx] for idx in beam]
translated_sentence = [int(x) for x in translated_sentence1 if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']    
utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict1.mid")

In [None]:
translated_sentence

In [None]:
for i in range(len(test_intro)):
    if len(test_intro) > 1200:
        continue
    list_sentence = [int(x) for x in sentence.split(' ')]
    remi = [word2event[x] for x in list_sentence]
    print(remi)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    #tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    #tokens.insert(0, german.init_token)
    #tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    #text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(sentence).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        #print( vars(example))
        src = vars(example)["intro"]
        trg = vars(example)["solo"]
        
        src = [int(x) for x in src]
        trg = [int(x) for x in trg]
        
        if len(trg) > 1200 or len(src) > 1200:
            continue
        
        prediction = bleu_translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append(trg)
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while
score = bleu(test[1:10], model, intro_field, solo_field, device)
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show() 

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns