In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
source_folder = "solo_generation_dataset_fixed"
folder = "dynamic_fixed_models/outro"
destination_folder = folder + "/solo_generation_weights"
generated_outputs = folder +  "/generated_samples"
dissimilar_interpolation = folder + "/interpolation"
vocab = folder + "/vocab"

In [2]:
from pathlib import Path
Path(destination_folder).mkdir(parents=True, exist_ok=True)
Path(generated_outputs).mkdir(parents=True, exist_ok=True)
Path(vocab).mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/intro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/outro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/predict").mkdir(parents=True, exist_ok=True)

In [3]:
event2word, word2event = pickle.load(open('dictionary_fixed.pkl', 'rb'))

In [4]:
if torch.cuda.is_available():  
    dev = "cuda:1" 
else:  
    dev = "cpu" 
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [5]:
# Fields

intro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
intro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
fields = [('intro', intro_field), ('intro_piano', intro_piano_field), \
          ('outro', outro_field), ('outro_piano', outro_piano_field), \
          ('solo', solo_field), ('solo_piano', solo_piano_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv', validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators
BATCH_SIZE = 8
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.outro),
                            device=device, sort=False, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)

# Vocabulary

intro_field.build_vocab(train, min_freq=1)
intro_piano_field.build_vocab(train, min_freq=1)
outro_field.build_vocab(train, min_freq=1)
outro_piano_field.build_vocab(train, min_freq=1)
solo_field.build_vocab(train, min_freq=1)
solo_piano_field.build_vocab(train, min_freq=1)

In [6]:
for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (test_iter):
    print(solo.transpose(1,0).size())

torch.Size([522, 8])
torch.Size([353, 8])
torch.Size([281, 8])
torch.Size([331, 8])
torch.Size([509, 8])
torch.Size([353, 8])
torch.Size([473, 8])
torch.Size([319, 8])
torch.Size([325, 8])
torch.Size([491, 8])
torch.Size([619, 8])
torch.Size([272, 8])
torch.Size([294, 8])
torch.Size([365, 8])


In [7]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [8]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

In [9]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [10]:
src_vocab_size = len(outro_field.vocab)
trg_vocab_size = len(solo_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 1200
forward_expansion = 4
src_pad_idx = 1 #english.vocab.stoi["<pad>"]

model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
model = model.to(device)

In [11]:
def init_weights(m: nn.Module):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

optimizer = optim.Adam(model.parameters(), lr=2e-4)#4e-5


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
#     torch.save(state, destination_folder + str(nth)+filename)
    torch.save(state, destination_folder + '/metrics.pt')

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
    torch.save(state, destination_folder + "/" + str(nth)+filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

The model has 11,073,284 trainable parameters


In [12]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [13]:
PAD_IDX = 1

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [14]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0

    #for _, (src, _,trg,_) in enumerate(iterator):
    for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
        src, trg = outro.transpose(1,0), solo.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:-1, :])
        
#         print(output.size())
#         print(trg.size())
        
        output = output.view(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        #for _, (src, _,trg,_) in enumerate(iterator):
        for ((intro, intro_len), (intro_piano, intro_piano_len),\
         (outro, outro_len),(outro_piano, outro_piano_len),\
         (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
            src, trg = outro.transpose(1,0), solo.transpose(1,0)
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg[:-1, :]) #turn off teacher forcing

            output = output.view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs



In [15]:
def translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break
    # print(outputs)
    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [16]:
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro = df_intro['intro'].values
val_solo = df_intro['solo'].values
val_outro = df_intro['outro'].values
val_data=[]
for i in range(len(val_intro)):
    temp_dict = {}
    temp_dict['intro'] = val_intro[i]
    temp_dict['solo'] = val_solo[i]
    temp_dict['outro'] = val_outro[i]
    val_data.append(temp_dict)
print(len(val_intro))

112


In [17]:
def check_mode_collapse(model):
    count = 0
    translations = []
    for i in range(10):
        if len(val_intro) > 1200:
            continue
        intro = val_intro[i]
        solo = val_solo[i]
        outro = val_outro[i]
        #print(intro)
        list_intro = [int(x) for x in intro.split(' ')]
        list_solo = [int(x) for x in solo.split(' ')]
        list_outro = [int(x) for x in outro.split(' ')]
        translated_sentence = translate_sentence(model, outro, outro_field, solo_field, device, max_length=1200)
        
        translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
        print(translated_sentence)
        translations.append(translated_sentence)
        if i > 0:
            if translations[i-1] == translations[i]:
                count += 1
    return count


In [None]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(S_EPOCH, N_EPOCHS):
    
    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
        save_best_checkpoint(checkpoint,N_EPOCHS)
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)
    if (epoch+2) % 20 ==0:
        if check_mode_collapse(model) > 1:
            print("model is mode collapsing")
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 0m 6s
	Train Loss: 5.343 | Train PPL: 209.055
	 Val. Loss: 5.134 |  Val. PPL: 169.686
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 0m 6s
	Train Loss: 4.812 | Train PPL: 123.012
	 Val. Loss: 4.595 |  Val. PPL:  98.948
=> Saving checkpoint
Epoch: 03 | Time: 0m 6s
	Train Loss: 4.392 | Train PPL:  80.810
	 Val. Loss: 4.294 |  Val. PPL:  73.264
=> Saving checkpoint
Epoch: 04 | Time: 0m 6s
	Train Loss: 4.068 | Train PPL:  58.442
	 Val. Loss: 3.950 |  Val. PPL:  51.937
=> Saving checkpoint
Epoch: 05 | Time: 0m 6s
	Train Loss: 3.715 | Train PPL:  41.079
	 Val. Loss: 3.626 |  Val. PPL:  37.568
=> Saving checkpoint
Epoch: 06 | Time: 0m 6s
	Train Loss: 3.435 | Train PPL:  31.017
	 Val. Loss: 3.395 |  Val. PPL:  29.819
=> Saving checkpoint
Epoch: 07 | Time: 0m 6s
	Train Loss: 3.222 | Train PPL:  25.088
	 Val. Loss: 3.214 |  Val. PPL:  24.876
=> Saving checkpoint
Epoch: 08 | Time: 0m 6s
	Train Loss: 3.079 | Train PPL:  21.745
	 Val. Loss: 3.111 |  Val. PPL:  22.443


[0, 1, 50, 186, 1, 57, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 55, 77, 9, 1, 55, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 6

[0, 1, 50, 186, 1, 50, 186, 1, 50, 50, 186, 1, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5

[0, 1, 50, 186, 1, 57, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 55, 77, 9, 1, 55, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 62, 77, 9, 1, 6

[0, 1, 50, 186, 1, 50, 186, 1, 50, 50, 186, 1, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 5

Epoch: 20 | Time: 0m 6s
	Train Loss: 2.427 | Train PPL:  11.327
	 Val. Loss: 2.500 |  Val. PPL:  12.185
=> Saving checkpoint
=> Saving checkpoint
Epoch: 21 | Time: 0m 6s
	Train Loss: 2.392 | Train PPL:  10.933
	 Val. Loss: 2.465 |  Val. PPL:  11.766
=> Saving checkpoint
=> Saving checkpoint
Epoch: 22 | Time: 0m 6s
	Train Loss: 2.357 | Train PPL:  10.556
	 Val. Loss: 2.445 |  Val. PPL:  11.527
=> Saving checkpoint
Epoch: 23 | Time: 0m 6s
	Train Loss: 2.322 | Train PPL:  10.198
	 Val. Loss: 2.415 |  Val. PPL:  11.192
=> Saving checkpoint
Epoch: 24 | Time: 0m 6s
	Train Loss: 2.292 | Train PPL:   9.892
	 Val. Loss: 2.396 |  Val. PPL:  10.984
=> Saving checkpoint
Epoch: 25 | Time: 0m 6s
	Train Loss: 2.266 | Train PPL:   9.643
	 Val. Loss: 2.394 |  Val. PPL:  10.961
=> Saving checkpoint
Epoch: 26 | Time: 0m 6s
	Train Loss: 2.242 | Train PPL:   9.411
	 Val. Loss: 2.359 |  Val. PPL:  10.576
=> Saving checkpoint
Epoch: 27 | Time: 0m 6s
	Train Loss: 2.214 | Train PPL:   9.149
	 Val. Loss: 2.354 

[0, 1, 50, 171, 1, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 49, 55, 77, 44, 49, 55, 77, 44, 10, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 4

[0, 1, 50, 171, 1, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 42, 55, 52, 9, 42, 55, 52, 9, 10, 55, 52, 9, 43, 55, 52, 9, 28, 55, 52, 9, 31, 55, 52, 9, 21, 55, 52, 9, 16, 55, 52, 9, 36, 55, 52, 9, 0, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 12, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55, 52, 9, 37, 55

[0, 1, 50, 171, 1, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 49, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 10, 55, 77, 44, 13, 55, 77, 44, 13, 55, 77, 44, 15, 55, 77, 44, 16, 55, 77, 44, 24, 55, 77, 44, 24, 55, 77, 44, 0, 1, 55, 77, 44, 7, 55, 77, 44, 7, 55, 77, 44, 7, 55,

[0, 1, 50, 127, 1, 4, 33, 9, 37, 4, 33, 9, 7, 4, 33, 9, 42, 4, 33, 9, 10, 4, 54, 9, 43, 4, 33, 9, 49, 4, 33, 9, 28, 4, 33, 9, 13, 4, 33, 9, 31, 4, 33, 9, 15, 4, 33, 9, 16, 4, 33, 9, 23, 4, 33, 9, 24, 4, 33, 9, 36, 4, 33, 9, 0, 1, 4, 33, 9, 37, 4, 33, 9, 37, 4, 33, 9, 7, 4, 33, 9, 42, 4, 33, 9, 10, 4, 33, 9, 43, 4, 54, 9, 49, 4, 33, 9, 28, 4, 54, 9, 13, 4, 33, 9, 31, 4, 54, 9, 15, 4, 33, 9, 21, 4, 54, 9, 16, 4, 33, 9, 23, 4, 54, 9, 24, 4, 33, 9, 36, 4, 33, 9, 0, 1, 4, 33, 9, 37, 4, 54, 9, 37, 4, 33, 9, 7, 4, 33, 9, 42, 4, 33, 9, 10, 4, 33, 9, 43, 4, 54, 9, 49, 4, 33, 9, 28, 4, 54, 9, 13, 4, 54, 9, 31, 4, 33, 9, 15, 4, 54, 9, 21, 4, 33, 9, 16, 4, 54, 9, 23, 4, 54, 9, 24, 4, 54, 9, 36, 4, 54, 9, 0, 1, 4, 54, 9, 37, 4, 54, 9, 37, 4, 54, 9, 7, 4, 54, 9, 42, 4, 54, 9, 10, 4, 54, 9, 10, 4, 54, 44, 43, 4, 33, 9, 43, 4, 33, 9, 49, 4, 54, 9, 49, 4, 54, 9, 28, 4, 33, 9, 28, 4, 54, 9, 13, 4, 54, 9, 31, 4, 54, 9, 31, 4, 54, 9, 15, 4, 33, 9, 21, 4, 54, 9, 23, 4, 33, 9, 23, 4, 33, 9, 24, 4, 54, 9, 36

[0, 1, 50, 127, 1, 60, 77, 9, 37, 60, 77, 9, 37, 60, 77, 9, 7, 60, 77, 9, 42, 60, 77, 9, 10, 60, 77, 9, 43, 60, 77, 9, 49, 60, 77, 9, 28, 60, 77, 9, 13, 60, 77, 9, 31, 60, 77, 9, 15, 60, 77, 9, 21, 60, 77, 9, 16, 60, 77, 9, 23, 60, 77, 9, 23, 60, 77, 9, 24, 60, 77, 9, 36, 60, 77, 9, 0, 1, 60, 77, 9, 37, 60, 77, 9, 7, 60, 77, 9, 42, 60, 77, 9, 10, 60, 77, 9, 43, 60, 77, 9, 49, 60, 77, 9, 28, 60, 77, 9, 13, 60, 77, 9, 31, 60, 77, 9, 15, 60, 77, 9, 21, 60, 77, 9, 16, 60, 77, 9, 23, 60, 77, 9, 24, 60, 77, 9, 36, 60, 77, 9, 0, 1, 60, 77, 9, 37, 60, 77, 9, 7, 60, 77, 9, 42, 60, 77, 9, 10, 60, 77, 9, 10, 60, 77, 9, 43, 60, 77, 9, 43, 60, 77, 9, 49, 60, 77, 9, 28, 60, 77, 9, 13, 60, 77, 9, 31, 60, 77, 9, 15, 60, 77, 9, 21, 60, 77, 9, 23, 60, 77, 9, 24, 60, 77, 9, 36, 60, 77, 9, 0, 1, 60, 77, 9, 37, 60, 77, 9, 7, 60, 77, 9, 42, 60, 77, 9, 42, 60, 77, 9, 10, 60, 77, 9, 43, 60, 77, 9, 43, 60, 77, 9, 43, 60, 77, 9, 49, 60, 77, 9, 28, 60, 77, 9, 13, 60, 77, 9, 31, 60, 77, 9, 31, 60, 77, 9, 15, 60, 

Epoch: 60 | Time: 0m 6s
	Train Loss: 1.717 | Train PPL:   5.566
	 Val. Loss: 2.197 |  Val. PPL:   8.998
=> Saving checkpoint
Epoch: 61 | Time: 0m 6s
	Train Loss: 1.707 | Train PPL:   5.514
	 Val. Loss: 2.192 |  Val. PPL:   8.956
=> Saving checkpoint
Epoch: 62 | Time: 0m 6s
	Train Loss: 1.697 | Train PPL:   5.460
	 Val. Loss: 2.201 |  Val. PPL:   9.036
Epoch: 63 | Time: 0m 6s
	Train Loss: 1.689 | Train PPL:   5.413
	 Val. Loss: 2.187 |  Val. PPL:   8.906
Epoch: 64 | Time: 0m 6s
	Train Loss: 1.678 | Train PPL:   5.353
	 Val. Loss: 2.199 |  Val. PPL:   9.017
Epoch: 65 | Time: 0m 6s
	Train Loss: 1.667 | Train PPL:   5.296
	 Val. Loss: 2.208 |  Val. PPL:   9.101
Epoch: 66 | Time: 0m 6s
	Train Loss: 1.660 | Train PPL:   5.261
	 Val. Loss: 2.216 |  Val. PPL:   9.171
Epoch: 67 | Time: 0m 6s
	Train Loss: 1.652 | Train PPL:   5.217
	 Val. Loss: 2.232 |  Val. PPL:   9.314
Epoch: 68 | Time: 0m 6s
	Train Loss: 1.639 | Train PPL:   5.149
	 Val. Loss: 2.243 |  Val. PPL:   9.417
Epoch: 69 | Time: 0m 6

[0, 1, 50, 127, 37, 4, 32, 38, 43, 4, 32, 38, 28, 4, 32, 38, 31, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 28, 4, 32, 38, 31, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 13, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 28, 4, 32, 38, 31, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 43, 4, 32, 38, 28, 4, 32, 38, 31, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 43, 4, 32, 38, 31, 4, 32, 38, 23, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 43, 4, 32, 38, 43, 4, 32, 38, 28, 4, 32, 38, 31, 4, 32, 38, 0, 37, 4, 32, 38, 43, 4, 32, 38, 43, 4, 32, 38, 28, 4, 32, 38, 31, 4, 32, 38, 31, 4, 32, 38]
[0, 1, 2, 137, 10, 2, 137, 13, 2, 137, 21, 55, 58, 9, 16, 2, 137, 0, 1, 2, 137, 37, 55, 58, 9, 7, 55, 58, 9, 10, 2, 137, 13, 2, 137, 16, 2, 137, 23, 55, 58, 9, 24, 55, 58, 9, 36, 55, 58, 9, 0, 1, 2, 137, 37, 55, 58, 9, 7, 55, 58, 44, 10, 2, 137, 13, 2, 137, 15, 55, 58, 9, 21, 55, 58, 9, 16, 2, 137, 24, 55, 58, 9, 36, 55, 58, 9, 0, 1, 2, 137, 7, 5

[0, 1, 50, 102, 1, 63, 58, 9, 37, 94, 58, 9, 7, 60, 58, 9, 42, 61, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 77, 9, 28, 60, 58, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58, 9, 24, 60, 58, 9, 36, 60, 58, 9, 0, 1, 60, 58, 9, 37, 60, 58, 9, 7, 60, 58, 9, 42, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 28, 60, 58, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58, 9, 24, 60, 58, 9, 36, 60, 58, 9, 0, 1, 60, 58, 9, 37, 60, 58, 9, 7, 60, 58, 9, 42, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 58, 9, 28, 60, 58, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58, 9, 24, 60, 58, 9, 36, 60, 58, 9, 0, 1, 60, 58, 9, 7, 60, 58, 9, 7, 60, 58, 9, 42, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 58, 9, 28, 60, 58, 9, 28, 60, 58, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58, 9, 23, 60, 58, 9, 24, 60, 58, 9, 36, 60, 5

[0, 1, 50, 102, 1, 60, 58, 9, 7, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 58, 9, 28, 60, 58, 9, 13, 60, 98, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58, 9, 24, 60, 58, 9, 36, 60, 58, 9, 0, 1, 60, 58, 9, 37, 60, 58, 9, 7, 60, 58, 9, 42, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 58, 9, 28, 60, 58, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58, 9, 24, 60, 58, 9, 36, 60, 58, 9, 0, 1, 60, 58, 9, 7, 60, 58, 9, 7, 60, 58, 9, 42, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 58, 9, 28, 60, 58, 9, 13, 60, 58, 9, 31, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 24, 60, 58, 9, 24, 60, 58, 9, 36, 60, 58, 9, 0, 1, 60, 58, 9, 7, 60, 58, 9, 7, 60, 58, 9, 7, 60, 58, 9, 42, 60, 58, 9, 10, 60, 58, 9, 10, 60, 58, 9, 43, 60, 58, 9, 49, 60, 58, 9, 49, 60, 58, 9, 28, 60, 58, 9, 13, 60, 58, 9, 13, 60, 58, 9, 13, 60, 58, 9, 15, 60, 58, 9, 15, 60, 58, 9, 21, 60, 58, 9, 16, 60, 58, 9, 23, 60, 58,

[0, 1, 50, 171, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 98, 44, 13, 60, 98, 44, 15, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 98, 44, 13, 60, 98, 44, 15, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 98, 44, 13, 60, 98, 44, 15, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 98, 44, 13, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 98, 44, 13, 60, 98, 44, 15, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 13, 60, 98, 44, 13, 60, 98, 44, 15, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 98, 44, 13, 60, 98, 44, 15, 60, 98, 44, 16, 60, 98, 44, 24, 60, 98, 44, 0, 1, 60, 98, 44, 7, 60, 98, 44, 10, 60, 98, 44, 49, 60, 5, 44, 13, 60, 98, 44, 13, 60, 105, 44, 16,

[0, 1, 50, 171, 10, 50, 171, 10, 50, 171, 10, 50, 171, 10, 40, 32, 44, 49, 40, 32, 6, 13, 50, 171, 13, 40, 32, 9, 31, 40, 32, 6, 15, 40, 32, 9, 21, 40, 32, 9, 16, 50, 171, 16, 40, 32, 9, 23, 40, 32, 9, 24, 40, 32, 9, 36, 40, 32, 9, 0, 1, 50, 171, 1, 50, 171, 1, 50, 171, 1, 50, 171, 1, 40, 32, 9, 37, 40, 32, 9, 7, 40, 32, 9, 42, 40, 32, 9, 10, 40, 32, 9, 10, 50, 171, 10, 50, 171, 10, 50, 171, 10, 40, 32, 9, 43, 40, 32, 9, 49, 40, 32, 9, 28, 40, 32, 9, 13, 50, 171, 16, 50, 171, 16, 40, 32, 9, 23, 40, 32, 9, 24, 40, 32, 9, 36, 40, 32, 9, 0, 1, 50, 171, 1, 40, 32, 9, 37, 40, 32, 9, 7, 40, 32, 9, 42, 40, 32, 9, 10, 50, 171, 10, 50, 171, 10, 50, 171, 10, 40, 32, 9, 43, 40, 32, 9, 49, 40, 65, 9, 49, 40, 65, 9, 13, 50, 171, 13, 50, 171, 16, 50, 171, 16, 50, 171, 16, 50, 171, 16, 50, 171, 16, 50, 171, 16, 50, 171, 16, 50, 171, 16, 50, 171, 16, 40, 32, 9, 23, 60, 32, 9, 23, 40, 32, 9, 24, 40, 32, 9, 36, 40, 32, 9, 36, 40, 32, 9, 0, 1, 50, 171, 1, 50, 171, 1, 50, 171, 1, 50, 171, 1, 50, 171, 1, 5

[0, 1, 2, 127, 43, 57, 69, 44, 31, 57, 53, 44, 21, 57, 69, 44, 23, 60, 69, 48, 0, 43, 61, 53, 44, 31, 55, 69, 44, 21, 60, 53, 44, 23, 55, 69, 44, 36, 55, 69, 44, 0, 37, 55, 69, 44, 42, 60, 53, 44, 43, 55, 69, 44, 28, 55, 53, 44, 31, 55, 69, 44, 21, 55, 69, 44, 23, 55, 69, 44, 36, 55, 68, 44, 0, 37, 62, 52, 85, 0, 37, 62, 69, 44, 42, 55, 68, 44, 43, 55, 52, 44, 28, 62, 32, 44, 31, 55, 32, 44, 21, 55, 52, 44, 23, 55, 52, 44, 36, 55, 52, 44, 0, 37, 57, 52, 44, 42, 55, 52, 44, 43, 55, 52, 44, 28, 66, 52, 44, 31, 55, 77, 44, 21, 55, 52, 44, 23, 62, 52, 44, 36, 55, 52, 44, 0, 37, 55, 52, 44, 42, 55, 52, 44, 43, 55, 52, 44, 28, 55, 52, 44, 31, 55, 32, 44, 21, 55, 77, 85, 0, 23, 55, 52, 44, 23, 55, 52, 44, 36, 55, 52, 9, 0, 37, 62, 77, 44, 37, 55, 52, 44, 42, 55, 52, 44, 43, 55, 52, 44, 23, 55, 52, 44]
[0, 1, 50, 127, 43, 4, 5, 44, 28, 4, 58, 44, 31, 4, 58, 44, 21, 4, 53, 44, 23, 4, 53, 44, 36, 4, 53, 44, 0, 37, 40, 53, 44, 37, 40, 53, 44, 42, 40, 53, 44, 43, 40, 53, 44, 28, 40, 53, 44, 31, 40

[0, 1, 2, 137, 10, 2, 137, 13, 2, 137, 21, 62, 52, 96, 16, 2, 137, 0, 1, 2, 137, 37, 62, 77, 27, 10, 2, 137, 28, 62, 52, 27, 13, 2, 137, 21, 95, 77, 38, 16, 2, 137, 0, 1, 2, 137, 37, 62, 52, 38, 10, 2, 137, 28, 95, 69, 44, 31, 95, 70, 27, 0, 37, 62, 77, 44, 42, 55, 77, 38, 28, 62, 77, 38, 13, 55, 52, 38, 0, 1, 2, 137, 37, 62, 52, 38, 10, 2, 137, 49, 66, 77, 44, 13, 2, 137, 21, 55, 52, 38, 16, 2, 137]
[0, 1, 50, 127, 43, 4, 72, 9, 49, 4, 106, 6, 28, 4, 72, 9, 31, 4, 106, 20, 15, 4, 106, 20, 15, 4, 72, 20, 21, 4, 72, 20, 23, 4, 69, 20, 24, 4, 106, 20, 36, 4, 72, 20, 0, 37, 4, 108, 20, 7, 4, 108, 20, 42, 4, 106, 20, 10, 4, 106, 20, 49, 4, 106, 20, 28, 4, 72, 20, 13, 39, 69, 20, 31, 4, 69, 20, 15, 39, 69, 20, 15, 39, 69, 6, 21, 39, 69, 6, 16, 4, 69, 20, 23, 4, 53, 20, 23, 4, 69, 20, 24, 39, 69, 20, 36, 4, 53, 20, 36, 39, 69, 20, 0, 1, 4, 69, 20, 37, 22, 69, 20, 7, 4, 53, 20, 42, 4, 69, 20, 42, 4, 53, 38, 28, 4, 53, 47, 21, 17, 69, 47, 23, 17, 69, 20, 24, 17, 69, 20, 36, 39, 53, 20, 36, 39,

[0, 1, 50, 186, 1, 60, 58, 56, 10, 50, 186, 49, 61, 98, 44, 13, 50, 186, 13, 94, 90, 44, 15, 94, 90, 47, 16, 50, 186, 16, 94, 98, 44, 24, 35, 58, 44, 0, 1, 50, 186, 1, 61, 90, 25, 10, 50, 186, 13, 50, 186, 16, 50, 186, 24, 94, 90, 56, 0, 1, 50, 186, 10, 50, 186, 13, 50, 186, 13, 50, 186, 13, 50, 186, 31, 61, 58, 44, 21, 35, 90, 44, 16, 50, 186, 16, 50, 186, 16, 50, 186, 24, 57, 58, 44, 0, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 10, 50, 186, 13, 50, 186, 13, 50, 186, 13, 50, 186, 16, 50, 186, 16, 50, 186, 16, 50, 186, 24, 57, 58, 9, 36, 61, 58, 44, 0, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 10, 50, 186, 10, 50, 186, 10, 50, 186, 49, 35, 58, 44, 28, 57, 53, 56, 13, 50, 186, 13, 50, 186, 13, 50, 186, 13, 50, 186, 13, 50, 186, 16, 50, 186, 16, 50, 186, 16, 50, 186, 16, 50, 186, 0, 1, 50, 186, 24, 57, 58, 9, 36, 114, 8, 9, 37, 4, 14, 20, 1, 50, 186, 1, 50, 186, 1, 50, 186, 1, 50, 186, 10, 50, 186, 49, 57, 30, 56, 13

Epoch: 160 | Time: 0m 6s
	Train Loss: 1.027 | Train PPL:   2.794
	 Val. Loss: 2.972 |  Val. PPL:  19.525
=> Saving checkpoint
Epoch: 161 | Time: 0m 6s
	Train Loss: 1.018 | Train PPL:   2.768
	 Val. Loss: 3.007 |  Val. PPL:  20.233
=> Saving checkpoint
Epoch: 162 | Time: 0m 6s
	Train Loss: 1.013 | Train PPL:   2.754
	 Val. Loss: 2.972 |  Val. PPL:  19.524
Epoch: 163 | Time: 0m 6s
	Train Loss: 1.011 | Train PPL:   2.748
	 Val. Loss: 3.005 |  Val. PPL:  20.191
Epoch: 164 | Time: 0m 6s
	Train Loss: 1.011 | Train PPL:   2.749
	 Val. Loss: 3.039 |  Val. PPL:  20.874
Epoch: 165 | Time: 0m 6s
	Train Loss: 1.005 | Train PPL:   2.732
	 Val. Loss: 3.008 |  Val. PPL:  20.238
Epoch: 166 | Time: 0m 6s
	Train Loss: 0.995 | Train PPL:   2.704
	 Val. Loss: 3.041 |  Val. PPL:  20.927
Epoch: 167 | Time: 0m 6s
	Train Loss: 0.988 | Train PPL:   2.687
	 Val. Loss: 3.050 |  Val. PPL:  21.109
Epoch: 168 | Time: 0m 6s
	Train Loss: 0.981 | Train PPL:   2.668
	 Val. Loss: 3.079 |  Val. PPL:  21.740
Epoch: 169 | 

[0, 1, 50, 51, 1, 55, 8, 9, 37, 94, 26, 9, 7, 66, 26, 9, 42, 94, 18, 9, 10, 94, 11, 46, 43, 66, 18, 9, 28, 94, 11, 9, 13, 94, 30, 9, 31, 61, 18, 38, 16, 55, 30, 9, 23, 63, 11, 9, 24, 61, 18, 9, 36, 61, 11, 9, 0, 1, 63, 11, 46, 37, 61, 18, 44, 42, 61, 18, 46, 43, 61, 11, 9, 49, 61, 11, 38, 31, 61, 18, 44, 21, 61, 29, 9, 23, 61, 30, 9, 23, 61, 11, 9, 24, 61, 78, 9, 36, 61, 30, 9, 0, 1, 66, 29, 9, 37, 60, 30, 9, 7, 61, 29, 9, 42, 61, 30, 9, 10, 35, 30, 9, 43, 35, 29, 9, 49, 66, 30, 9, 28, 61, 30, 9, 13, 94, 30, 9, 31, 55, 30, 9, 15, 94, 30, 9, 21, 35, 29, 9, 16, 60, 30, 9, 23, 60, 29, 9, 24, 61, 41, 9, 36, 61, 30, 9, 0, 1, 61, 30, 12, 37, 94, 30, 9, 7, 61, 41, 9, 42, 61, 30, 9, 10, 35, 33, 44, 43, 66, 29, 9, 49, 61, 11, 9, 28, 61, 29, 25, 21, 94, 18, 12, 28, 94, 30, 9, 13, 60, 29, 9, 31, 61, 29, 9, 15, 94, 30, 9, 21, 35, 29, 9, 16, 66, 29, 9, 23, 61, 30, 9, 24, 94, 29, 9, 36, 61, 18, 9, 36, 61, 30, 9]
[0, 1, 2, 137, 10, 2, 137, 10, 55, 105, 38, 10, 55, 90, 46, 49, 55, 87, 44, 13, 2, 137, 

[0, 1, 50, 186, 10, 35, 111, 9, 43, 62, 84, 9, 49, 62, 14, 75, 24, 57, 14, 9, 36, 62, 5, 6, 0, 1, 35, 87, 9, 37, 35, 105, 9, 7, 35, 14, 38, 49, 66, 84, 9, 28, 62, 14, 9, 13, 35, 111, 9, 31, 66, 84, 9, 15, 66, 84, 9, 21, 57, 14, 75, 0, 1, 94, 14, 9, 37, 94, 84, 9, 7, 62, 5, 9, 10, 57, 87, 9, 43, 66, 105, 9, 49, 35, 14, 9, 28, 62, 105, 9, 13, 57, 5, 9, 31, 35, 14, 9, 15, 57, 5, 9, 21, 57, 32, 9, 16, 57, 32, 9, 24, 57, 52, 9, 36, 57, 32, 9, 0, 1, 35, 52, 9, 37, 35, 32, 9, 7, 62, 32, 9, 42, 57, 33, 38, 43, 57, 84, 9, 49, 62, 69, 9, 28, 57, 33, 9, 13, 57, 53, 9, 15, 35, 69, 9, 21, 35, 32, 9, 16, 57, 68, 9, 23, 35, 32, 9, 24, 57, 32, 9, 36, 35, 68, 9, 0, 1, 35, 32, 9, 7, 62, 52, 9, 10, 57, 32, 9, 43, 35, 32, 9, 49, 57, 77, 9, 28, 57, 32, 9, 28, 35, 52, 9, 13, 61, 32, 9, 31, 57, 69, 9, 31, 35, 53, 9, 15, 57, 32, 9, 15, 57, 53, 9, 21, 35, 53, 9, 21, 35, 54, 9, 23, 35, 72, 9, 24, 35, 69, 9, 36, 57, 65, 9, 0, 1, 35, 106, 44, 7, 62, 72, 38, 10, 50, 186, 49, 35, 72, 9, 28, 35, 53, 9, 13, 35, 72, 9

[0, 1, 50, 172, 10, 61, 69, 9, 43, 94, 53, 9, 49, 55, 32, 75, 0, 49, 94, 53, 9, 28, 94, 69, 9, 13, 94, 53, 20, 31, 61, 32, 27, 0, 37, 61, 69, 9, 43, 61, 32, 47, 49, 61, 32, 20, 28, 94, 53, 103, 36, 61, 32, 38, 0, 42, 61, 52, 9, 43, 61, 77, 9, 49, 61, 32, 9, 28, 61, 52, 9, 13, 61, 32, 9, 15, 61, 52, 103, 0, 43, 61, 77, 9, 49, 81, 32, 9, 28, 61, 32, 9, 28, 61, 52, 9, 31, 60, 52, 9, 21, 61, 32, 9, 16, 61, 77, 9, 23, 61, 52, 9, 24, 61, 77, 9, 36, 61, 32, 9, 0, 1, 61, 52, 9, 37, 61, 52, 9, 7, 61, 32, 9, 42, 94, 52, 9, 43, 61, 32, 9, 49, 61, 32, 9, 28, 94, 52, 9, 28, 94, 32, 44, 31, 61, 77, 9, 15, 61, 77, 9, 21, 61, 52, 9, 16, 61, 32, 9, 23, 61, 29, 9, 24, 61, 32, 9, 36, 61, 52, 9, 0, 1, 61, 77, 9, 37, 61, 52, 9, 7, 61, 77, 9, 42, 61, 77, 9, 43, 94, 77, 9, 49, 61, 29, 9, 28, 94, 77, 9, 28, 94, 58, 9, 13, 94, 98, 9, 31, 61, 32, 9, 15, 61, 52, 9, 15, 61, 77, 9, 21, 61, 32, 9, 23, 61, 32, 9, 24, 61, 77, 9, 36, 61, 52, 9, 0, 37, 61, 32, 9, 7, 61, 77, 9, 42, 61, 32, 9, 43, 61, 32, 44, 43, 94, 52,

Epoch: 209 | Time: 0m 6s
	Train Loss: 0.785 | Train PPL:   2.192
	 Val. Loss: 3.560 |  Val. PPL:  35.158
Epoch: 210 | Time: 0m 6s
	Train Loss: 0.792 | Train PPL:   2.209
	 Val. Loss: 3.563 |  Val. PPL:  35.286
Epoch: 211 | Time: 0m 6s
	Train Loss: 0.775 | Train PPL:   2.170
	 Val. Loss: 3.512 |  Val. PPL:  33.515
Epoch: 212 | Time: 0m 6s
	Train Loss: 0.770 | Train PPL:   2.159
	 Val. Loss: 3.573 |  Val. PPL:  35.612
Epoch: 213 | Time: 0m 6s
	Train Loss: 0.769 | Train PPL:   2.158
	 Val. Loss: 3.554 |  Val. PPL:  34.969
Epoch: 214 | Time: 0m 6s
	Train Loss: 0.762 | Train PPL:   2.142
	 Val. Loss: 3.598 |  Val. PPL:  36.524
Epoch: 215 | Time: 0m 6s
	Train Loss: 0.768 | Train PPL:   2.155
	 Val. Loss: 3.599 |  Val. PPL:  36.577
Epoch: 216 | Time: 0m 6s
	Train Loss: 0.755 | Train PPL:   2.127
	 Val. Loss: 3.592 |  Val. PPL:  36.295
Epoch: 217 | Time: 0m 6s
	Train Loss: 0.757 | Train PPL:   2.131
	 Val. Loss: 3.598 |  Val. PPL:  36.516
Epoch: 218 | Time: 0m 6s
	Train Loss: 0.749 | Train PPL

[0, 1, 50, 67, 10, 50, 67, 49, 62, 92, 44, 13, 50, 67, 13, 62, 88, 44, 15, 62, 8, 6, 16, 50, 67, 16, 62, 8, 44, 24, 62, 8, 44, 0, 1, 50, 67, 1, 62, 8, 44, 7, 62, 26, 44, 10, 50, 67, 10, 62, 8, 44, 49, 62, 11, 44, 13, 50, 67, 13, 62, 11, 44, 15, 62, 11, 38, 16, 50, 67, 16, 62, 8, 44, 24, 62, 14, 38, 0, 1, 50, 67, 1, 62, 5, 44, 7, 62, 11, 44, 10, 50, 67, 10, 35, 8, 44, 49, 62, 5, 44, 13, 50, 67, 13, 50, 67, 13, 62, 11, 44, 15, 62, 30, 6, 16, 35, 11, 45, 10, 50, 67, 16, 50, 67, 0, 1, 62, 8, 44, 7, 62, 11, 44, 10, 62, 26, 99]
[0, 1, 50, 186, 1, 60, 58, 6, 7, 55, 53, 47, 10, 50, 186, 49, 61, 69, 96, 13, 50, 186, 16, 50, 186, 24, 61, 53, 9, 36, 60, 32, 75, 0, 37, 94, 69, 12, 10, 50, 186, 10, 50, 186, 49, 61, 58, 9, 28, 61, 53, 56, 16, 50, 186, 24, 61, 58, 6, 36, 61, 32, 20, 36, 60, 30, 9, 0, 1, 50, 186, 1, 50, 186, 37, 61, 32, 46, 10, 50, 186, 49, 61, 58, 9, 28, 94, 30, 9, 28, 61, 32, 9, 13, 94, 52, 9, 13, 61, 77, 9, 31, 61, 32, 56, 16, 50, 186, 23, 61, 52, 9, 24, 94, 32, 56, 0, 1, 50, 186, 

Epoch: 220 | Time: 0m 6s
	Train Loss: 0.741 | Train PPL:   2.098
	 Val. Loss: 3.625 |  Val. PPL:  37.519
=> Saving checkpoint
Epoch: 221 | Time: 0m 6s
	Train Loss: 0.746 | Train PPL:   2.108
	 Val. Loss: 3.631 |  Val. PPL:  37.756
=> Saving checkpoint
Epoch: 222 | Time: 0m 6s
	Train Loss: 0.740 | Train PPL:   2.097
	 Val. Loss: 3.700 |  Val. PPL:  40.438
Epoch: 223 | Time: 0m 6s
	Train Loss: 0.730 | Train PPL:   2.076
	 Val. Loss: 3.685 |  Val. PPL:  39.854
Epoch: 224 | Time: 0m 6s
	Train Loss: 0.727 | Train PPL:   2.068
	 Val. Loss: 3.638 |  Val. PPL:  38.023
Epoch: 225 | Time: 0m 6s
	Train Loss: 0.722 | Train PPL:   2.059
	 Val. Loss: 3.675 |  Val. PPL:  39.440
Epoch: 226 | Time: 0m 6s
	Train Loss: 0.715 | Train PPL:   2.043
	 Val. Loss: 3.704 |  Val. PPL:  40.604
Epoch: 227 | Time: 0m 6s
	Train Loss: 0.706 | Train PPL:   2.026
	 Val. Loss: 3.744 |  Val. PPL:  42.247
Epoch: 228 | Time: 0m 6s
	Train Loss: 0.702 | Train PPL:   2.017
	 Val. Loss: 3.813 |  Val. PPL:  45.294
Epoch: 229 | 

[0, 1, 50, 186, 1, 60, 29, 56, 10, 50, 186, 49, 66, 18, 44, 13, 50, 186, 13, 61, 11, 20, 31, 61, 11, 6, 21, 61, 5, 9, 16, 94, 5, 9, 23, 61, 11, 9, 24, 61, 5, 9, 0, 1, 63, 5, 9, 37, 61, 11, 6, 42, 94, 5, 9, 10, 94, 5, 9, 49, 61, 11, 9, 49, 61, 11, 9, 28, 61, 26, 9, 13, 50, 186, 15, 66, 8, 9, 21, 66, 11, 9, 16, 61, 5, 9, 23, 61, 18, 9, 24, 61, 11, 9, 36, 66, 11, 9, 0, 1, 94, 8, 9, 37, 61, 5, 9, 7, 66, 11, 9, 42, 94, 5, 9, 10, 94, 18, 9, 43, 61, 5, 9, 49, 61, 32, 9, 28, 94, 29, 9, 13, 114, 11, 9, 15, 61, 29, 9, 21, 66, 29, 9, 16, 94, 29, 9, 23, 60, 18, 56, 24, 94, 29, 9, 36, 66, 29, 9, 0, 1, 94, 30, 56, 49, 61, 52, 9, 28, 61, 32, 9, 28, 94, 41, 9, 13, 50, 186, 13, 66, 41, 9, 15, 63, 29, 9, 21, 61, 30, 9, 16, 61, 77, 9, 23, 61, 32, 9, 24, 61, 18, 9, 36, 61, 77, 9, 0, 1, 94, 41, 9, 7, 61, 34, 9, 42, 61, 129, 9, 42, 61, 11, 9, 10, 61, 130, 9, 49, 61, 5, 9, 28, 61, 18, 9, 13, 94, 18, 9, 13, 61, 18, 20, 31, 61, 73, 48, 21, 61, 52, 9, 23, 61, 32, 48, 16, 61, 52, 9, 16, 50, 186, 23, 66, 11, 9, 2

Epoch: 240 | Time: 0m 6s
	Train Loss: 0.665 | Train PPL:   1.944
	 Val. Loss: 3.838 |  Val. PPL:  46.423
=> Saving checkpoint
Epoch: 241 | Time: 0m 6s
	Train Loss: 0.669 | Train PPL:   1.952
	 Val. Loss: 3.855 |  Val. PPL:  47.233
=> Saving checkpoint
Epoch: 242 | Time: 0m 6s
	Train Loss: 0.658 | Train PPL:   1.932
	 Val. Loss: 3.895 |  Val. PPL:  49.161
Epoch: 243 | Time: 0m 6s
	Train Loss: 0.654 | Train PPL:   1.924
	 Val. Loss: 3.903 |  Val. PPL:  49.531
Epoch: 244 | Time: 0m 6s
	Train Loss: 0.649 | Train PPL:   1.914
	 Val. Loss: 3.900 |  Val. PPL:  49.410
Epoch: 245 | Time: 0m 6s
	Train Loss: 0.645 | Train PPL:   1.905
	 Val. Loss: 3.951 |  Val. PPL:  51.977
Epoch: 246 | Time: 0m 6s
	Train Loss: 0.641 | Train PPL:   1.898
	 Val. Loss: 3.955 |  Val. PPL:  52.204
Epoch: 247 | Time: 0m 6s
	Train Loss: 0.636 | Train PPL:   1.890
	 Val. Loss: 3.906 |  Val. PPL:  49.695
Epoch: 248 | Time: 0m 6s
	Train Loss: 0.639 | Train PPL:   1.895
	 Val. Loss: 3.888 |  Val. PPL:  48.806
Epoch: 249 | 

[0, 1, 50, 79, 1, 60, 58, 44, 7, 60, 29, 47, 10, 50, 79, 43, 60, 58, 82, 13, 50, 79, 31, 60, 18, 9, 15, 60, 58, 9, 16, 50, 79, 23, 66, 58, 44, 36, 60, 5, 9, 0, 1, 50, 79, 1, 55, 58, 9, 37, 60, 29, 9, 7, 60, 18, 44, 42, 60, 5, 44, 10, 50, 79, 43, 60, 18, 56, 10, 50, 79, 13, 60, 11, 20, 15, 60, 5, 20, 15, 60, 18, 20, 21, 60, 58, 9, 16, 50, 79, 24, 60, 18, 75, 0, 1, 50, 79, 37, 60, 5, 75, 10, 50, 79, 28, 60, 18, 9, 13, 50, 79, 13, 60, 11, 20, 31, 60, 58, 9, 15, 60, 29, 20, 15, 60, 29, 20, 21, 60, 18, 6, 16, 50, 79, 16, 60, 58, 9, 24, 60, 29, 20, 24, 60, 58, 6, 0, 1, 50, 79, 1, 35, 18, 9, 37, 60, 11, 20, 37, 60, 14, 20, 7, 60, 11, 38, 10, 60, 14, 9, 10, 55, 8, 6, 10, 35, 105, 9, 43, 60, 8, 9, 49, 60, 14, 44, 13, 50, 79, 13, 55, 105, 9, 31, 60, 5, 9, 15, 60, 14, 9, 15, 60, 29, 9, 21, 60, 105, 9, 16, 50, 79, 16, 50, 79]
[0, 1, 50, 145, 1, 94, 58, 9, 37, 60, 98, 9, 7, 60, 5, 46, 43, 55, 98, 120, 36, 60, 98, 20, 0, 1, 81, 58, 20, 37, 57, 98, 20, 7, 55, 69, 20, 42, 94, 53, 20, 10, 60, 69, 20, 4

[0, 1, 50, 171, 1, 61, 58, 44, 7, 61, 77, 44, 10, 61, 58, 44, 49, 61, 77, 6, 13, 60, 52, 44, 15, 61, 58, 44, 16, 61, 77, 44, 24, 61, 29, 38, 0, 1, 61, 52, 44, 7, 66, 58, 44, 10, 61, 98, 6, 49, 61, 98, 6, 13, 66, 58, 9, 13, 66, 5, 9, 31, 60, 32, 46, 15, 61, 29, 44, 16, 61, 58, 44, 24, 66, 32, 44, 0, 1, 61, 58, 44, 7, 61, 32, 44, 7, 61, 98, 44, 10, 61, 29, 44, 49, 61, 58, 44, 13, 61, 32, 44, 15, 66, 58, 44, 16, 61, 32, 44, 24, 61, 77, 38, 0, 1, 94, 98, 44, 7, 61, 58, 44, 10, 66, 58, 44, 49, 61, 5, 44, 13, 94, 98, 47, 16, 61, 32, 47, 16, 61, 32, 44, 24, 61, 52, 44, 0, 1, 61, 58, 44, 7, 61, 32, 44, 10, 66, 5, 44, 49, 66, 32, 44, 49, 61, 98, 44, 13, 66, 105, 44, 16, 61, 58, 44, 24, 61, 32, 44, 0, 1, 61, 58, 44, 7, 66, 58, 44, 10, 61, 77, 44, 49, 61, 58, 44, 13, 60, 58, 38, 16, 61, 58, 44, 16, 66, 58, 44, 16, 61, 98, 38, 0, 1, 61, 5, 44, 7, 61, 58, 44, 7, 61, 98, 44, 10, 61, 58, 44, 10, 61, 98, 44, 49, 61, 98, 38, 49, 61, 58, 44, 13, 66, 5, 44, 16, 66, 32, 44, 16, 61, 32, 44, 24, 61, 5, 44, 

[0, 1, 50, 79, 37, 4, 53, 6, 42, 22, 106, 6, 43, 57, 53, 12, 28, 57, 106, 6, 31, 22, 106, 20, 15, 57, 53, 48, 36, 40, 72, 20, 0, 37, 22, 106, 12, 0, 37, 22, 53, 74, 37, 62, 69, 20, 7, 57, 68, 20, 42, 62, 69, 20, 10, 57, 69, 120, 36, 57, 69, 6, 0, 1, 57, 68, 20, 37, 22, 53, 20, 37, 62, 69, 6, 42, 62, 53, 20, 10, 57, 69, 20, 43, 57, 68, 20, 28, 57, 52, 20, 13, 55, 69, 20, 31, 40, 53, 20, 15, 57, 69, 20, 21, 22, 69, 20, 21, 57, 68, 20, 21, 22, 53, 20, 16, 57, 69, 20, 23, 62, 53, 20, 36, 22, 69, 20, 0, 1, 57, 52, 9, 37, 57, 53, 9, 7, 40, 53, 9, 42, 40, 53, 9, 10, 57, 69, 96, 0, 1, 57, 53, 20, 37, 57, 53, 20, 37, 62, 69, 20, 7, 40, 53, 20, 42, 40, 69, 20, 42, 57, 58, 20, 42, 57, 53, 46, 28, 57, 69, 12, 31, 40, 53, 20, 15, 62, 69, 20, 21, 57, 77, 9, 21, 22, 72, 20, 16, 22, 69, 47, 0, 1, 57, 53, 20, 37, 60, 69, 20, 7, 62, 68, 46, 28, 57, 52, 9, 42, 57, 69, 20, 42, 40, 53, 20, 31, 40, 69, 9, 15, 57, 53, 6, 23, 22, 53, 20, 15, 22, 69, 20, 21, 22, 52, 46, 24, 55, 68, 20, 24, 4, 69, 20, 0, 1, 62,

[0, 1, 50, 186, 1, 60, 29, 75, 49, 60, 18, 44, 13, 60, 78, 44, 15, 60, 18, 44, 16, 60, 11, 56, 0, 1, 60, 26, 44, 7, 60, 11, 44, 10, 60, 78, 9, 49, 60, 18, 44, 13, 60, 11, 44, 15, 60, 18, 44, 16, 60, 29, 9, 23, 60, 18, 44, 24, 60, 78, 9, 36, 60, 11, 44, 0, 1, 60, 18, 44, 1, 60, 18, 44, 7, 60, 11, 44, 10, 66, 11, 44, 49, 35, 78, 44, 13, 60, 29, 44, 15, 60, 18, 44, 16, 60, 11, 44, 24, 60, 18, 44, 0, 1, 35, 29, 44, 7, 60, 18, 44, 10, 60, 5, 44, 49, 60, 18, 44, 13, 60, 11, 38, 24, 60, 26, 44, 0, 1, 60, 18, 44, 7, 35, 11, 38, 49, 61, 29, 38, 49, 66, 18, 44, 13, 60, 58, 44, 15, 60, 11, 38, 15, 60, 18, 44, 16, 60, 5, 44, 24, 60, 18, 44, 24, 60, 11, 44, 0, 1, 60, 26, 38, 37, 60, 11, 38, 10, 60, 11, 44, 49, 60, 11, 44, 13, 60, 14, 38, 16, 35, 26, 44, 24, 60, 8, 44, 0, 1, 60, 11, 44, 1, 60, 29, 38, 10, 60, 5, 44, 49, 60, 11, 9, 28, 61, 5, 9, 13, 60, 11, 9, 15, 61, 18, 44, 16, 60, 5, 44, 24, 60, 5, 44, 0, 1, 60, 26, 9, 37, 35, 26, 44, 7, 61, 11, 9, 42, 60, 18, 44, 49, 61, 11, 44, 13, 60, 5, 44, 15

Epoch: 318 | Time: 0m 6s
	Train Loss: 0.444 | Train PPL:   1.559
	 Val. Loss: 4.569 |  Val. PPL:  96.430
Epoch: 319 | Time: 0m 6s
	Train Loss: 0.437 | Train PPL:   1.548
	 Val. Loss: 4.597 |  Val. PPL:  99.204
[0, 1, 50, 104, 37, 57, 8, 47, 10, 50, 104, 43, 4, 30, 59, 13, 50, 104, 16, 50, 104, 0, 1, 50, 104, 42, 55, 29, 20, 10, 50, 104, 10, 35, 58, 20, 43, 62, 18, 56, 13, 50, 104, 21, 57, 18, 9, 16, 50, 104, 16, 55, 18, 9, 23, 62, 58, 75, 0, 1, 50, 104, 10, 50, 104, 10, 17, 11, 56, 13, 50, 104, 43, 55, 18, 9, 43, 55, 8, 47, 13, 50, 104, 13, 55, 58, 9, 31, 62, 11, 47, 16, 50, 104, 16, 35, 90, 9, 23, 55, 8, 20, 24, 62, 26, 56, 0, 10, 55, 18, 96, 0, 1, 50, 104, 10, 50, 104, 43, 55, 58, 6, 28, 55, 30, 96, 13, 50, 104, 0, 1, 50, 104, 49, 62, 18, 9, 13, 50, 104, 13, 50, 104, 31, 4, 58, 56, 31, 4, 18, 56, 16, 62, 29, 9, 24, 62, 18, 9, 36, 62, 8, 9, 0, 1, 17, 30, 27, 49, 95, 30, 38, 10, 50, 104, 49, 40, 30, 6, 28, 55, 41, 6, 13, 50, 104, 31, 17, 18, 9, 31, 40, 68, 9, 15, 57, 52, 44, 21, 40, 33

[0, 1, 50, 186, 1, 60, 29, 46, 7, 60, 58, 6, 10, 66, 11, 44, 49, 66, 105, 44, 13, 66, 14, 38, 15, 60, 8, 44, 16, 35, 52, 6, 24, 60, 14, 9, 36, 66, 29, 6, 0, 1, 61, 11, 44, 7, 61, 11, 6, 10, 66, 33, 44, 49, 61, 11, 6, 49, 61, 52, 6, 13, 60, 29, 46, 15, 66, 52, 6, 16, 60, 29, 12, 24, 66, 11, 6, 0, 1, 66, 11, 44, 7, 66, 33, 46, 7, 66, 11, 6, 42, 66, 11, 6, 10, 66, 11, 6, 43, 66, 33, 6, 49, 35, 11, 9, 28, 35, 33, 9, 13, 66, 11, 6, 15, 66, 29, 6, 15, 66, 58, 44, 16, 94, 105, 6]
[0, 1, 50, 150, 37, 61, 111, 44, 42, 61, 87, 46, 43, 61, 105, 46, 28, 61, 98, 46, 28, 61, 78, 46, 31, 66, 98, 44, 21, 66, 90, 44, 21, 61, 78, 38, 23, 61, 98, 44, 36, 61, 90, 44, 36, 61, 90, 38, 0, 37, 61, 98, 46, 42, 66, 90, 47, 42, 61, 98, 20, 28, 81, 26, 85, 28, 61, 78, 44, 31, 61, 77, 44, 21, 66, 90, 44, 23, 61, 98, 44, 36, 61, 26, 44, 0, 37, 61, 105, 44, 42, 94, 26, 44, 43, 61, 88, 46, 28, 94, 87, 25, 28, 61, 78, 44, 31, 61, 98, 44, 21, 66, 90, 44, 23, 94, 111, 44, 23, 66, 98, 44, 36, 61, 98, 44, 36, 61, 78, 85, 

[0, 1, 50, 145, 1, 94, 33, 9, 37, 66, 58, 9, 7, 66, 53, 20, 42, 94, 98, 120, 36, 60, 53, 6, 0, 1, 81, 58, 27, 49, 61, 32, 27, 15, 61, 65, 6, 16, 94, 65, 27, 0, 7, 81, 106, 75, 10, 60, 53, 75, 15, 61, 65, 96, 0, 1, 81, 58, 27, 49, 66, 53, 6, 49, 66, 53, 44, 13, 63, 53, 75, 15, 61, 53, 9, 21, 66, 58, 9, 16, 115, 53, 6, 24, 60, 72, 6, 0, 1, 63, 53, 6, 7, 94, 54, 9, 10, 61, 33, 6, 49, 66, 53, 9, 28, 114, 53, 9, 13, 94, 53, 9, 31, 60, 54, 27, 0, 1, 94, 52, 6, 7, 60, 69, 6, 10, 66, 77, 6, 49, 60, 72, 75, 13, 61, 106, 44, 15, 81, 54, 9, 21, 63, 106, 46, 24, 60, 72, 20, 0, 1, 60, 72, 9, 37, 63, 53, 9, 7, 60, 107, 6, 10, 66, 68, 9, 10, 60, 65, 6, 49, 60, 106, 6, 13, 60, 53, 27, 15, 55, 65, 6, 16, 114, 52, 6, 24, 94, 65, 6, 0, 1, 60, 72, 44, 7, 61, 106, 6, 42, 61, 65, 6, 10, 61, 108, 6, 43, 60, 72, 6, 49, 61, 65, 6]
[0, 1, 50, 127, 1, 61, 29, 20, 37, 66, 52, 44, 42, 60, 32, 9, 10, 61, 29, 44, 43, 61, 52, 44, 31, 60, 32, 44, 21, 66, 52, 44, 23, 66, 29, 38, 0, 37, 66, 52, 44, 42, 60, 53, 6, 43, 60

[0, 1, 2, 137, 1, 60, 29, 46, 10, 60, 30, 56, 15, 60, 78, 9, 21, 60, 18, 9, 16, 60, 11, 47, 0, 7, 60, 26, 44, 10, 60, 29, 75, 15, 60, 78, 20, 21, 60, 18, 75, 0, 7, 60, 11, 38, 49, 60, 11, 44, 13, 60, 18, 75, 15, 60, 78, 9, 16, 60, 29, 44, 24, 60, 18, 44, 0, 1, 60, 29, 6, 7, 60, 11, 46, 49, 60, 92, 9, 28, 60, 14, 9, 13, 60, 92, 9, 31, 60, 14, 6, 15, 60, 26, 9, 21, 60, 11, 6, 16, 60, 11, 9, 23, 60, 11, 6, 36, 60, 18, 6, 0, 1, 60, 11, 9, 7, 60, 18, 9, 37, 60, 11, 9, 7, 60, 11, 6, 10, 60, 18, 75, 49, 60, 11, 9, 28, 60, 11, 20, 13, 60, 11, 9, 15, 60, 18, 9, 21, 60, 11, 9, 16, 60, 11, 6, 24, 60, 11, 6, 0, 1, 60, 11, 9, 7, 60, 11, 6, 7, 60, 11, 9, 42, 60, 29, 48]
[0, 1, 50, 145, 1, 94, 58, 9, 37, 55, 98, 20, 7, 60, 5, 46, 43, 55, 98, 120, 36, 60, 98, 20, 0, 1, 81, 58, 20, 37, 57, 98, 20, 7, 55, 69, 6, 10, 94, 58, 12, 28, 55, 77, 20, 13, 55, 77, 20, 31, 55, 77, 20, 15, 61, 58, 20, 21, 63, 77, 27, 0, 1, 62, 58, 20, 37, 94, 98, 96, 28, 60, 98, 96, 0, 1, 55, 58, 12, 10, 55, 98, 20, 49, 66, 98, 19

[0, 1, 50, 192, 1, 55, 8, 44, 7, 35, 26, 44, 10, 35, 90, 38, 13, 55, 90, 44, 15, 66, 18, 44, 16, 35, 90, 44, 24, 60, 26, 44, 0, 1, 60, 8, 44, 7, 35, 92, 44, 10, 60, 8, 25, 13, 35, 88, 56, 15, 66, 129, 44, 16, 60, 92, 6, 24, 60, 92, 44, 0, 1, 60, 14, 44, 7, 66, 169, 44, 10, 66, 129, 38, 13, 60, 92, 44, 15, 66, 8, 44, 16, 60, 88, 112, 0, 7, 35, 8, 44, 10, 35, 8, 44, 49, 60, 8, 44, 13, 60, 26, 38, 16, 35, 8, 44, 24, 60, 26, 44, 0, 1, 60, 8, 44, 7, 60, 26, 44, 10, 60, 90, 38, 13, 60, 11, 44, 49, 60, 18, 44, 15, 60, 26, 44, 15, 60, 8, 44, 16, 55, 8, 44, 24, 60, 18, 44, 0, 1, 55, 26, 44, 7, 60, 8, 44, 10, 35, 92, 38, 13, 60, 26, 44, 15, 60, 8, 44, 16, 60, 26, 38, 0, 1, 35, 90, 44, 7, 60, 8, 44, 10, 60, 26, 38, 13, 60, 18, 44, 15, 35, 18, 44, 15, 66, 90, 44, 16, 60, 11, 44, 16, 60, 18, 44, 24, 60, 26, 44, 0, 1, 60, 26, 38, 1, 60, 26, 44, 7, 35, 90, 44, 10, 60, 8, 44, 10, 60, 88, 56, 13, 35, 18, 9, 31, 60, 26, 44, 15, 35, 87, 44, 16, 35, 8, 44, 24, 60, 78, 56, 0, 1, 61, 90, 44, 7, 35, 90, 6, 1

Epoch: 374 | Time: 0m 6s
	Train Loss: 0.366 | Train PPL:   1.442
	 Val. Loss: 4.897 |  Val. PPL: 133.920
Epoch: 375 | Time: 0m 6s
	Train Loss: 0.358 | Train PPL:   1.431
	 Val. Loss: 4.841 |  Val. PPL: 126.641
Epoch: 376 | Time: 0m 6s
	Train Loss: 0.353 | Train PPL:   1.424
	 Val. Loss: 4.877 |  Val. PPL: 131.236
Epoch: 377 | Time: 0m 6s
	Train Loss: 0.348 | Train PPL:   1.416
	 Val. Loss: 4.879 |  Val. PPL: 131.445
Epoch: 378 | Time: 0m 6s
	Train Loss: 0.363 | Train PPL:   1.437
	 Val. Loss: 4.918 |  Val. PPL: 136.793
Epoch: 379 | Time: 0m 6s
	Train Loss: 0.350 | Train PPL:   1.419
	 Val. Loss: 4.926 |  Val. PPL: 137.765
[0, 1, 50, 104, 37, 62, 41, 6, 42, 55, 33, 20, 10, 50, 104, 10, 55, 33, 93, 13, 50, 104, 16, 50, 104, 36, 62, 68, 9, 0, 1, 50, 104, 1, 55, 41, 20, 37, 62, 68, 75, 10, 50, 104, 28, 55, 34, 44, 13, 50, 104, 31, 62, 68, 9, 15, 62, 41, 9, 21, 57, 78, 9, 16, 50, 104, 16, 55, 41, 9, 23, 55, 30, 6, 23, 55, 33, 56, 0, 1, 55, 34, 75, 10, 55, 33, 44, 28, 62, 70, 44, 21, 35, 34,

[0, 1, 50, 192, 37, 57, 33, 12, 10, 60, 32, 9, 43, 55, 33, 82, 23, 60, 52, 6, 0, 37, 66, 30, 6, 42, 35, 11, 20, 10, 60, 32, 20, 43, 35, 32, 20, 49, 60, 58, 20, 28, 35, 5, 20, 13, 35, 29, 20, 31, 60, 5, 9, 15, 60, 5, 9, 21, 66, 105, 20, 16, 60, 5, 20, 23, 35, 11, 20, 24, 60, 58, 9, 36, 62, 58, 20, 0, 1, 60, 58, 9, 37, 35, 98, 9, 7, 35, 29, 9, 42, 60, 58, 9, 10, 60, 58, 9, 10, 60, 52, 9, 43, 62, 32, 20, 49, 60, 32, 20, 28, 35, 53, 20, 13, 57, 54, 20, 31, 35, 64, 47, 21, 60, 72, 47, 23, 60, 54, 47, 0, 1, 60, 65, 46, 43, 35, 65, 38, 13, 35, 106, 38, 23, 60, 64, 47, 0, 42, 60, 54, 12, 49, 60, 54, 6, 28, 35, 53, 47, 28, 35, 32, 6, 13, 35, 65, 20, 31, 60, 65, 9, 15, 35, 107, 9, 21, 60, 65, 9, 16, 57, 65, 9, 23, 35, 72, 9, 23, 35, 53, 6, 24, 60, 30, 38, 0, 1, 60, 58, 20, 37, 60, 64, 20, 37, 60, 33, 20, 7, 60, 72, 20, 42, 60, 72, 20, 10, 60, 53, 47, 13, 62, 53, 9, 13, 35, 58, 9, 31, 60, 64, 38, 31, 35, 53, 38, 23, 57, 54, 9, 15, 35, 29, 9, 21, 55, 53, 9, 16, 35, 18, 12, 16, 35, 54, 9, 23, 62, 5

[0, 1, 50, 192, 37, 57, 33, 12, 10, 60, 32, 9, 43, 55, 33, 82, 23, 60, 52, 20, 24, 60, 32, 20, 36, 57, 33, 20, 0, 1, 60, 53, 9, 37, 62, 54, 46, 10, 66, 53, 20, 43, 35, 54, 27, 23, 62, 33, 6, 36, 60, 53, 20, 0, 1, 60, 54, 20, 37, 60, 106, 46, 10, 55, 54, 20, 43, 35, 106, 47, 21, 55, 54, 6, 23, 55, 106, 9, 36, 55, 65, 9, 0, 37, 55, 64, 6, 42, 60, 106, 20, 10, 60, 64, 20, 43, 62, 54, 44, 28, 35, 53, 75, 0, 37, 66, 58, 93, 23, 60, 52, 38, 0, 37, 60, 58, 83, 0, 37, 35, 11, 38, 43, 35, 52, 47, 31, 60, 69, 47, 23, 35, 53, 75, 0, 43, 55, 33, 9, 49, 94, 69, 9, 28, 35, 53, 9, 13, 35, 69, 20, 31, 35, 33, 9, 15, 35, 69, 9, 21, 60, 68, 9, 21, 60, 41, 9]
[0, 1, 50, 104, 37, 57, 8, 47, 43, 55, 88, 9, 49, 81, 11, 75, 24, 81, 11, 9, 36, 55, 18, 9, 0, 1, 55, 78, 9, 37, 66, 11, 75, 15, 63, 18, 38, 23, 81, 11, 9, 24, 81, 18, 20, 36, 94, 78, 9, 0, 37, 61, 11, 9, 7, 81, 18, 12, 43, 35, 90, 96, 36, 94, 18, 6, 0, 37, 61, 26, 75, 49, 94, 11, 44, 13, 55, 18, 9, 31, 60, 11, 9, 15, 94, 26, 9, 21, 66, 11, 9, 23, 6

Epoch: 402 | Time: 0m 6s
	Train Loss: 0.325 | Train PPL:   1.384
	 Val. Loss: 5.093 |  Val. PPL: 162.886
Epoch: 403 | Time: 0m 6s
	Train Loss: 0.329 | Train PPL:   1.389
	 Val. Loss: 5.128 |  Val. PPL: 168.660
Epoch: 404 | Time: 0m 6s
	Train Loss: 0.320 | Train PPL:   1.377
	 Val. Loss: 4.955 |  Val. PPL: 141.948
Epoch: 405 | Time: 0m 6s
	Train Loss: 0.317 | Train PPL:   1.372
	 Val. Loss: 5.071 |  Val. PPL: 159.282
Epoch: 406 | Time: 0m 6s
	Train Loss: 0.321 | Train PPL:   1.379
	 Val. Loss: 5.060 |  Val. PPL: 157.547
Epoch: 407 | Time: 0m 6s
	Train Loss: 0.318 | Train PPL:   1.374
	 Val. Loss: 5.040 |  Val. PPL: 154.394
Epoch: 408 | Time: 0m 6s
	Train Loss: 0.319 | Train PPL:   1.376
	 Val. Loss: 5.100 |  Val. PPL: 164.031
Epoch: 409 | Time: 0m 6s
	Train Loss: 0.315 | Train PPL:   1.371
	 Val. Loss: 5.167 |  Val. PPL: 175.445
Epoch: 410 | Time: 0m 6s
	Train Loss: 0.317 | Train PPL:   1.373
	 Val. Loss: 5.066 |  Val. PPL: 158.468
Epoch: 411 | Time: 0m 6s
	Train Loss: 0.315 | Train PPL

[0, 1, 50, 153, 37, 55, 98, 47, 10, 50, 153, 10, 62, 105, 12, 13, 50, 153, 13, 50, 153, 31, 66, 98, 20, 15, 55, 58, 12, 16, 50, 153, 24, 60, 52, 9, 0, 1, 94, 105, 12, 42, 60, 87, 6, 10, 50, 153, 10, 35, 105, 25, 13, 50, 153, 15, 57, 105, 56, 16, 50, 153, 16, 50, 153, 16, 50, 153, 24, 60, 5, 6, 36, 60, 105, 6, 0, 1, 57, 87, 6, 7, 57, 98, 6, 7, 57, 98, 6, 10, 50, 153, 10, 35, 105, 19, 13, 35, 105, 12, 13, 50, 153, 15, 57, 105, 44, 16, 35, 90, 12, 24, 35, 98, 19, 0, 1, 62, 105, 44, 7, 62, 105, 25, 10, 50, 153, 10, 50, 153, 10, 35, 98, 44, 49, 62, 87, 19, 13, 50, 153, 21, 153, 16, 55, 105, 19, 16, 35, 111, 44, 24, 62, 105, 44, 36, 60, 78, 9, 0, 1, 50, 153, 1, 62, 98, 20, 37, 60, 88, 9, 7, 62, 87, 6, 42, 60, 87, 6, 10, 50, 153, 10, 50, 153, 43, 66, 105, 46, 49, 60, 78, 6, 13, 50, 153, 13, 50, 153, 15, 60, 98, 19, 16, 50, 153, 16, 60, 105, 9, 23, 60, 105, 9, 24, 60, 98, 9, 36, 55, 52, 44, 0, 1, 50, 153, 1, 62, 105, 44, 7, 62, 77, 44, 10, 62, 77, 6, 43, 55, 98, 44, 49, 60, 98, 46, 31, 66, 90,

[0, 1, 50, 186, 1, 60, 54, 75, 49, 60, 70, 44, 13, 60, 53, 56, 15, 60, 69, 44, 16, 94, 34, 44, 24, 60, 69, 9, 36, 66, 68, 9, 0, 1, 66, 53, 9, 37, 66, 34, 82, 49, 66, 53, 9, 28, 61, 33, 9, 13, 66, 34, 9, 31, 61, 33, 44, 21, 66, 34, 9, 16, 66, 34, 9, 24, 35, 70, 9, 36, 66, 68, 9, 0, 1, 61, 33, 9, 37, 61, 34, 9, 7, 61, 33, 9, 42, 94, 68, 9, 10, 66, 64, 9, 43, 61, 70, 9, 49, 61, 33, 9, 28, 66, 34, 9, 13, 66, 34, 9, 13, 66, 64, 38, 16, 66, 33, 9, 23, 66, 34, 9, 24, 66, 70, 9, 36, 66, 64, 9, 0, 1, 60, 34, 9, 37, 61, 70, 9, 7, 35, 53, 9, 42, 61, 33, 6, 42, 63, 70, 9, 10, 60, 53, 9, 43, 94, 34, 9, 49, 94, 34, 9, 49, 60, 34, 9, 28, 60, 53, 9, 13, 66, 70, 9, 31, 61, 70, 9, 15, 61, 34, 20, 21, 63, 33, 9, 16, 94, 64, 20, 23, 66, 53, 9, 23, 61, 70, 82, 24, 61, 34, 9, 36, 66, 68, 9, 0, 1, 60, 69, 20, 37, 66, 34, 9, 7, 66, 72, 9, 42, 61, 70, 9, 10, 61, 53, 9, 43, 61, 64, 9, 49, 66, 64, 38, 49, 61, 30, 9, 28, 94, 64, 38, 13, 94, 73, 74, 15, 66, 73, 9, 16, 60, 73, 9, 16, 61, 33, 9, 16, 61, 30, 9, 23, 6

Epoch: 440 | Time: 0m 6s
	Train Loss: 0.287 | Train PPL:   1.332
	 Val. Loss: 5.191 |  Val. PPL: 179.696
=> Saving checkpoint
Epoch: 441 | Time: 0m 6s
	Train Loss: 0.282 | Train PPL:   1.326
	 Val. Loss: 5.276 |  Val. PPL: 195.556
=> Saving checkpoint
Epoch: 442 | Time: 0m 6s
	Train Loss: 0.282 | Train PPL:   1.325
	 Val. Loss: 5.180 |  Val. PPL: 177.721
Epoch: 443 | Time: 0m 6s
	Train Loss: 0.287 | Train PPL:   1.333
	 Val. Loss: 5.297 |  Val. PPL: 199.743
Epoch: 444 | Time: 0m 6s
	Train Loss: 0.282 | Train PPL:   1.325
	 Val. Loss: 5.206 |  Val. PPL: 182.418
Epoch: 445 | Time: 0m 6s
	Train Loss: 0.280 | Train PPL:   1.323
	 Val. Loss: 5.280 |  Val. PPL: 196.284
Epoch: 446 | Time: 0m 6s
	Train Loss: 0.276 | Train PPL:   1.317
	 Val. Loss: 5.293 |  Val. PPL: 199.006
Epoch: 447 | Time: 0m 6s
	Train Loss: 0.269 | Train PPL:   1.309
	 Val. Loss: 5.277 |  Val. PPL: 195.757
Epoch: 448 | Time: 0m 6s
	Train Loss: 0.280 | Train PPL:   1.323
	 Val. Loss: 5.365 |  Val. PPL: 213.707
Epoch: 449 | 

[0, 1, 2, 174, 1, 95, 54, 56, 1, 55, 132, 20, 49, 39, 34, 44, 13, 95, 54, 12, 16, 95, 54, 44, 24, 39, 64, 44, 0, 1, 95, 65, 12, 10, 95, 65, 6, 49, 95, 64, 44, 13, 95, 54, 82, 0, 1, 95, 54, 25, 49, 95, 64, 44, 13, 95, 54, 12, 15, 95, 64, 9, 16, 95, 34, 44, 24, 40, 33, 44, 0, 1, 95, 54, 47, 42, 4, 64, 9, 10, 17, 54, 6, 49, 95, 34, 44, 13, 17, 33, 19, 0, 7, 95, 54, 9, 49, 95, 54, 20, 13, 95, 64, 38, 16, 95, 54, 46, 0, 1, 95, 65, 9, 7, 4, 65, 9, 42, 62, 64, 9, 10, 95, 54, 6, 49, 95, 34, 44, 13, 95, 33, 56, 0, 7, 95, 54, 46, 49, 95, 54, 6, 13, 95, 34, 47, 16, 95, 33, 44, 23, 39, 30, 44]
[0, 1, 50, 134, 1, 55, 98, 56, 7, 55, 98, 19, 0, 7, 35, 105, 9, 42, 94, 98, 19, 0, 1, 62, 90, 47, 10, 61, 105, 46, 15, 63, 90, 27, 15, 66, 98, 19, 16, 35, 98, 6, 24, 94, 58, 47, 0, 1, 55, 90, 47, 7, 62, 98, 44, 10, 57, 77, 101, 0, 7, 55, 77, 96, 0, 7, 55, 52, 96, 0, 7, 55, 77, 47, 10, 55, 52, 12, 15, 57, 52, 12, 15, 40, 77, 44, 16, 55, 98, 46, 0, 7, 55, 98, 44, 10, 55, 77, 44, 49, 55, 98, 44, 13, 55, 78, 44,

[0, 1, 50, 186, 1, 66, 58, 47, 7, 60, 58, 6, 10, 66, 52, 56, 15, 66, 33, 44, 16, 66, 53, 56, 24, 61, 58, 6, 0, 1, 61, 11, 44, 7, 66, 105, 6, 7, 66, 52, 6, 10, 66, 8, 103, 10, 61, 30, 27, 15, 66, 33, 44, 16, 60, 53, 56, 24, 61, 30, 6, 0, 1, 94, 18, 6, 7, 94, 11, 44, 7, 66, 53, 6, 10, 61, 14, 48, 10, 66, 54, 75, 15, 66, 64, 6, 16, 94, 98, 82, 16, 66, 106, 47, 0, 1, 66, 106, 44, 7, 66, 54, 44, 10, 61, 11, 48, 10, 61, 53, 71, 15, 94, 5, 6, 16, 66, 105, 44, 24, 61, 5, 46, 0, 1, 66, 105, 38, 1, 66, 54, 44, 7, 61, 32, 44, 10, 61, 29, 45, 10, 60, 52, 75, 15, 60, 58, 6, 16, 66, 30, 46, 24, 61, 8, 44, 0, 1, 61, 105, 44, 1, 66, 33, 6, 7, 66, 8, 44, 7, 60, 32, 6, 10, 94, 98, 100, 10, 66, 30, 6, 49, 66, 29, 44, 13, 61, 14, 38, 16, 60, 30, 6, 24, 61, 29, 38, 0, 1, 60, 32, 44, 7, 61, 30, 6, 10, 61, 32, 6, 49, 61, 18, 19]
[0, 1, 50, 134, 10, 50, 134, 43, 39, 30, 48, 13, 50, 134, 21, 57, 29, 6, 16, 50, 134, 23, 57, 30, 82, 0, 1, 50, 134, 10, 50, 134, 43, 40, 30, 27, 13, 50, 134, 15, 17, 30, 9, 21, 62, 

In [None]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [25]:
check_mode_collapse(model)

[0, 1, 2, 89, 67, 51, 61, 66, 91, 53, 73, 15, 13, 51, 57, 22, 70, 51, 61, 66, 74, 51, 57, 22, 0, 1, 53, 73, 22, 67, 64, 73, 59, 74, 35, 9, 15, 0, 1, 54, 57, 15, 67, 35, 73, 22, 4, 53, 96, 77, 74, 53, 57, 15, 0, 1, 63, 9, 15, 67, 14, 62, 103, 70, 54, 50, 137, 74, 35, 61, 15, 0, 1, 11, 61, 68]
[0, 1, 2, 142, 67, 24, 40, 15, 4, 11, 76, 47, 8, 2, 142, 91, 38, 9, 15, 13, 2, 142, 13, 32, 21, 15, 70, 38, 76, 15, 16, 14, 82, 34, 17, 2, 142, 74, 38, 73, 15, 0, 1, 2, 142, 1, 38, 61, 15, 67, 24, 82, 15, 4, 24, 40, 34, 8, 2, 142, 13, 2, 142, 13, 11, 76, 22, 70, 14, 40, 22, 16, 38, 97, 34, 17, 2, 142, 27, 32, 40, 22, 27, 11, 97, 52, 27, 38, 118, 52, 0, 1, 2, 142, 1, 32, 82, 15, 67, 54, 82, 7, 67, 38, 40, 85, 4, 32, 118, 77, 8, 2, 142, 13, 2, 142, 16, 38, 55, 108, 17, 2, 142, 17, 38, 33, 31, 17, 38, 21, 31, 27, 38, 26, 52, 27, 38, 20, 52, 0, 1, 2, 142, 1, 24, 50, 31, 1, 24, 96, 31, 4, 38, 26, 34, 4, 38, 20, 34, 8, 2, 142, 10, 38, 6, 7, 13, 38, 82, 7, 16, 11, 21, 31, 17, 38, 97, 7, 17, 2, 142, 27, 38

[0, 1, 2, 95, 67, 11, 33, 41, 8, 2, 95, 72, 5, 21, 60, 13, 2, 95, 17, 2, 95, 0, 1, 2, 95, 23, 32, 21, 22, 8, 2, 95, 8, 64, 73, 22, 72, 38, 57, 47, 13, 2, 95, 78, 51, 12, 15, 17, 2, 95, 17, 64, 9, 15, 90, 38, 50, 36, 0, 1, 2, 95, 8, 2, 95, 8, 32, 33, 22, 72, 64, 6, 42, 13, 2, 95, 78, 32, 9, 15, 17, 2, 95, 17, 32, 57, 15, 90, 11, 9, 41, 0, 1, 2, 95, 67, 5, 21, 41, 8, 2, 95, 72, 38, 50, 103, 13, 2, 95, 0, 67, 64, 33, 34, 67, 32, 26, 34, 72, 32, 9, 68, 72, 32, 21, 112, 0, 23, 54, 57, 15, 23, 64, 12, 15, 8, 51, 6, 15, 8, 32, 9, 15, 72, 51, 6, 28, 72, 32, 57, 28, 78, 64, 6, 15, 78, 64, 9, 15, 17, 64, 33, 15, 17, 54, 50, 22, 90, 51, 122, 28, 90, 38, 33, 28, 0, 8, 38, 92, 15, 8, 38, 33, 15, 72, 54, 111, 7, 72, 54, 6, 7, 10, 51, 29, 7, 10, 51, 9, 7, 91, 54, 55, 31, 91, 54, 57, 31, 70, 51, 29, 7, 70, 51, 9, 7, 16, 32, 111, 7, 16, 32, 6, 7, 78, 64, 122, 7, 78, 64, 30, 7, 90, 64, 55, 7, 90, 64, 57, 7, 27, 64, 30, 15, 27, 64, 12, 15, 74, 51, 33, 15, 74, 51, 21, 15, 0, 67, 64, 48, 15, 67, 64, 86, 15

0

In [None]:
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [18]:
state = torch.load(destination_folder + '/500_checkpoint.pt', map_location=device)
for key in state:
    print(key)
load_checkpoint(state, model, optimizer)

model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint


In [41]:
test_loss = evaluate(model, test_iter, criterion)
print(math.exp(test_loss))

298.5861798366139


In [None]:
generated_outputs = folder +  "/generated_samples_500epochs"
Path(generated_outputs+"/intro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/outro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/predict").mkdir(parents=True, exist_ok=True)

In [42]:
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro = df_intro['intro'].values
test_solo = df_intro['solo'].values
test_outro = df_intro['outro'].values
test_data=[]
for i in range(len(test_intro)):
    temp_dict = {}
    temp_dict['intro'] = test_intro[i]
    temp_dict['solo'] = test_solo[i]
    temp_dict['outro'] = test_outro[i]
    test_data.append(temp_dict)
print(len(test_intro))

112


In [43]:
for i in range(len(test_outro)):
    if len(test_outro) > 1200:
        continue
    intro = test_intro[i]
    solo = test_solo[i]
    outro = test_outro[i]
    #print(intro)
    list_intro = [int(x) for x in intro.split(' ')]
    list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence(model, outro, outro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    print(translated_sentence)
    utils.write_midi(list_intro, word2event, generated_outputs + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_solo, word2event, generated_outputs  + "/solo/" + "/solo" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, generated_outputs + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
        
        


[0, 1, 50, 188, 4, 58, 88, 22, 8, 56, 88, 47, 23, 56, 25, 39, 28, 56, 25, 39, 0, 8, 56, 25, 39, 20, 37, 88, 43, 21, 56, 89, 43, 23, 56, 93, 39, 28, 37, 141, 22, 0, 4, 56, 93, 22, 8, 56, 16, 24, 21, 56, 88, 22, 23, 56, 89, 39, 28, 58, 93, 43, 0, 4, 56, 89, 43, 8, 37, 88, 47, 23, 56, 25, 26, 28, 56, 88, 43, 0, 4, 63, 25, 43, 8, 56, 91, 24, 21, 56, 98, 7, 33, 56, 10, 7, 23, 56, 98, 43, 36, 56, 91, 43, 28, 56, 25, 43, 0, 4, 56, 89, 43, 8, 56, 89, 39, 21, 56, 89, 39, 21, 56, 93, 43, 23, 56, 89, 39, 28, 37, 133, 43, 0, 4, 37, 16, 39, 28, 56, 111, 43, 0, 4, 37, 93, 22]
0
[0, 1, 50, 178, 4, 114, 42, 7, 14, 67, 59, 7, 8, 82, 29, 7, 18, 63, 32, 7, 19, 62, 42, 7, 49, 82, 10, 7, 20, 64, 31, 22, 30, 63, 32, 7, 21, 62, 42, 7, 33, 67, 10, 7, 23, 67, 31, 7, 53, 41, 32, 7, 36, 114, 42, 7, 38, 37, 59, 7, 28, 95, 29, 22, 0, 1, 5, 32, 7, 4, 114, 42, 22, 14, 61, 6, 7, 8, 62, 31, 22, 18, 63, 32, 7, 19, 62, 42, 22, 49, 82, 6, 22, 20, 114, 31, 22, 30, 41, 32, 7, 21, 64, 42, 7, 33, 56, 6, 7, 23, 95, 29, 7, 53,

[0, 1, 50, 158, 1, 37, 79, 116, 33, 61, 69, 7, 23, 63, 69, 7, 53, 67, 42, 22, 38, 64, 54, 43, 0, 1, 63, 72, 26, 49, 95, 42, 7, 20, 114, 54, 7, 30, 95, 79, 39, 36, 114, 69, 7, 38, 82, 54, 7, 28, 67, 42, 13, 0, 1, 56, 72, 22, 14, 64, 79, 48, 28, 82, 72, 13, 28, 5, 42, 13, 28, 61, 54, 13, 28, 95, 69, 13, 28, 61, 10, 13, 0, 1, 62, 98, 13, 1, 37, 25, 43, 14, 37, 6, 43, 18, 67, 69, 43, 19, 95, 79, 94]
8
[0, 1, 2, 205, 1, 136, 25, 39, 18, 2, 205, 18, 2, 205, 20, 115, 10, 43, 30, 2, 205, 21, 189, 10, 24, 53, 82, 25, 43, 36, 56, 6, 43, 28, 37, 6, 43, 0, 1, 2, 205, 4, 37, 31, 43, 18, 2, 205, 20, 37, 29, 43, 30, 2, 205, 21, 82, 25, 77, 53, 2, 205, 36, 37, 12, 43, 28, 56, 16, 22, 0, 1, 2, 205, 4, 37, 16, 22, 8, 56, 42, 43, 18, 2, 205, 19, 37, 25, 84, 53, 2, 205, 43, 0, 1, 2, 205, 18, 2, 205, 20, 56, 10, 27, 53, 2, 205, 20, 67, 79, 43, 30, 2, 205, 21, 56, 10, 43, 23, 63, 10, 39, 53, 2, 205, 36, 56, 10, 43, 28, 37, 31, 43, 0, 1, 2, 205, 4, 58, 29, 24, 18, 2, 205, 20, 56, 10, 43, 30, 2, 205, 21, 37, 

[0, 1, 50, 194, 4, 58, 34, 27, 18, 61, 32, 7, 19, 56, 34, 83, 36, 61, 54, 13, 38, 61, 32, 13, 28, 58, 34, 13, 0, 1, 61, 52, 7, 4, 37, 55, 45, 18, 67, 52, 13, 19, 37, 55, 26, 36, 63, 34, 22, 28, 61, 52, 13, 0, 1, 61, 55, 13, 4, 61, 106, 45, 18, 56, 55, 13, 19, 37, 106, 46, 23, 56, 55, 22, 36, 56, 106, 7, 28, 56, 66, 7, 0, 4, 56, 65, 22, 8, 61, 106, 13, 18, 61, 65, 13, 19, 63, 55, 43, 20, 37, 52, 77, 0, 4, 67, 59, 94, 36, 61, 54, 39, 0, 4, 61, 29, 84, 0, 4, 37, 6, 39, 19, 37, 29, 46, 21, 61, 17, 46, 36, 37, 59, 39, 0, 19, 56, 34, 7, 49, 95, 70, 7, 20, 37, 34, 7, 30, 37, 32, 13, 21, 37, 34, 7, 33, 37, 32, 7, 23, 61, 54, 7, 23, 61, 31, 7]
20
[0, 1, 2, 216, 18, 67, 89, 87, 18, 62, 25, 78, 0, 14, 95, 89, 45, 18, 67, 25, 102, 18, 62, 181, 102, 0, 14, 37, 88, 43, 18, 61, 98, 24, 33, 67, 6, 24, 0, 14, 37, 91, 43, 18, 37, 25, 43, 49, 37, 89, 76, 0, 14, 37, 89, 43, 18, 37, 42, 57, 33, 37, 16, 43, 53, 61, 79, 43, 38, 61, 79, 116, 0, 14, 37, 89, 112, 0, 14, 37, 88, 43, 18, 61, 89, 112]
21
[0, 1, 50

[0, 1, 50, 188, 4, 58, 88, 22, 8, 56, 88, 47, 23, 56, 25, 39, 28, 56, 91, 39, 0, 8, 56, 25, 39, 20, 37, 88, 43, 21, 56, 89, 43, 23, 56, 93, 39, 28, 37, 141, 22, 0, 4, 56, 93, 22, 8, 56, 16, 24, 21, 56, 88, 22, 23, 56, 89, 39, 28, 58, 93, 43, 0, 4, 56, 89, 43, 8, 37, 88, 47, 23, 56, 25, 26, 28, 56, 88, 43, 0, 4, 63, 25, 43, 8, 56, 91, 24, 21, 56, 98, 7, 33, 56, 10, 7, 23, 56, 98, 43, 36, 56, 91, 43, 28, 56, 25, 43, 0, 4, 56, 89, 43, 8, 56, 88, 60, 28, 37, 89, 43, 0, 4, 37, 88, 43, 8, 37, 16, 48, 21, 56, 88, 43, 23, 61, 89, 39, 28, 37, 111, 43, 0, 4, 37, 93, 22]
30
[0, 1, 2, 127, 19, 62, 70, 43, 19, 62, 108, 43, 20, 82, 32, 43, 20, 82, 66, 43, 21, 62, 32, 39, 21, 62, 66, 39, 36, 37, 72, 47, 36, 37, 107, 47, 0, 19, 82, 98, 43, 19, 82, 70, 43, 20, 67, 59, 43, 20, 67, 52, 43, 21, 62, 98, 76, 21, 62, 70, 76, 0, 19, 56, 59, 43, 19, 56, 52, 43, 20, 62, 79, 43, 20, 62, 71, 43, 21, 56, 59, 86, 21, 56, 52, 86, 0, 8, 114, 85, 43, 8, 114, 59, 43, 19, 61, 98, 43, 19, 61, 70, 43, 20, 63, 91, 43, 20, 

[0, 1, 2, 137, 14, 115, 31, 39, 14, 115, 65, 39, 49, 82, 31, 43, 49, 82, 65, 43, 30, 82, 79, 43, 30, 82, 71, 43, 33, 114, 31, 43, 33, 114, 65, 43, 53, 136, 42, 43, 53, 136, 75, 43, 38, 64, 34, 39, 38, 64, 128, 39, 0, 14, 82, 31, 39, 14, 82, 65, 39, 49, 95, 31, 43, 49, 95, 65, 43, 30, 114, 79, 43, 30, 114, 71, 43, 33, 136, 31, 43, 33, 136, 65, 43, 53, 82, 34, 43, 53, 82, 128, 43, 38, 82, 42, 39, 38, 82, 75, 39, 0, 14, 114, 31, 39, 14, 114, 65, 39, 49, 82, 31, 43, 49, 82, 65, 43, 30, 114, 79, 43, 30, 114, 71, 43, 33, 114, 31, 43, 33, 114, 65, 43, 53, 136, 42, 43, 53, 136, 75, 43, 38, 95, 34, 39, 38, 95, 128, 39, 0, 14, 136, 31, 39, 14, 136, 65, 39, 49, 82, 31, 43, 49, 82, 65, 43, 30, 114, 79, 43, 30, 114, 71, 43, 33, 115, 31, 43, 33, 115, 65, 43, 53, 114, 34, 43, 53, 114, 128, 43, 36, 64, 42, 39, 36, 64, 75, 39]
39
[0, 1, 2, 223, 18, 2, 223, 30, 2, 223, 30, 58, 17, 39, 30, 58, 32, 39, 53, 2, 223, 53, 63, 98, 39, 53, 63, 70, 39, 0, 1, 2, 223, 1, 63, 59, 39, 1, 63, 52, 39, 18, 2, 223, 18, 

[0, 1, 50, 194, 4, 58, 34, 27, 18, 61, 32, 7, 19, 56, 34, 83, 36, 61, 54, 13, 38, 61, 32, 13, 28, 58, 34, 13, 0, 1, 61, 52, 7, 4, 37, 55, 45, 18, 67, 52, 13, 19, 37, 55, 26, 36, 63, 34, 22, 28, 61, 52, 13, 0, 1, 61, 55, 13, 4, 61, 106, 45, 18, 56, 55, 13, 19, 37, 106, 46, 23, 56, 55, 22, 36, 56, 106, 7, 28, 56, 66, 7, 0, 4, 56, 65, 22, 8, 61, 106, 13, 18, 61, 65, 13, 19, 63, 55, 43, 20, 37, 52, 77, 0, 4, 67, 59, 94, 36, 61, 54, 39, 0, 4, 61, 29, 84, 0, 4, 37, 6, 39, 19, 37, 29, 46, 21, 61, 17, 46, 36, 37, 59, 39, 0, 19, 56, 34, 7, 49, 95, 70, 7, 20, 37, 34, 7, 30, 37, 32, 13, 21, 37, 34, 7, 33, 37, 32, 7, 23, 61, 54, 7, 23, 61, 31, 7]
47
[0, 1, 50, 121, 1, 67, 34, 7, 4, 114, 31, 7, 14, 95, 34, 57, 49, 61, 52, 7, 20, 95, 34, 22, 30, 67, 31, 22, 33, 67, 35, 39, 38, 64, 34, 43, 0, 1, 95, 31, 104, 53, 154, 29, 45, 28, 64, 31, 45, 0, 1, 67, 34, 22, 14, 67, 42, 43, 18, 95, 31, 45, 49, 67, 35, 22, 30, 5, 65, 13, 21, 67, 34, 7, 33, 34, 7, 33, 67, 35, 43, 53, 67, 52, 7, 38, 67, 31, 22, 0, 1, 67

[0, 1, 50, 194, 4, 58, 34, 27, 18, 61, 32, 7, 19, 56, 34, 83, 36, 61, 54, 13, 38, 61, 32, 13, 28, 58, 34, 13, 0, 1, 61, 52, 7, 4, 37, 55, 45, 18, 67, 52, 13, 19, 37, 55, 26, 36, 63, 34, 22, 28, 61, 52, 13, 0, 1, 61, 55, 13, 4, 61, 106, 45, 18, 56, 55, 13, 19, 37, 106, 46, 23, 56, 55, 22, 36, 56, 106, 7, 28, 56, 66, 7, 0, 4, 56, 65, 22, 8, 61, 106, 13, 18, 61, 65, 13, 19, 63, 55, 43, 20, 37, 52, 77, 0, 4, 67, 59, 94, 36, 61, 54, 39, 0, 4, 61, 29, 84, 0, 4, 37, 6, 39, 19, 37, 29, 46, 21, 61, 17, 46, 36, 37, 59, 39, 0, 19, 56, 34, 7, 49, 95, 70, 7, 20, 37, 34, 7, 30, 37, 32, 13, 21, 37, 34, 7, 33, 37, 32, 7, 23, 61, 54, 7, 23, 61, 31, 7]
54
[0, 1, 2, 223, 18, 2, 223, 30, 2, 223, 30, 58, 17, 39, 30, 58, 32, 39, 53, 2, 223, 53, 63, 98, 39, 53, 63, 70, 39, 0, 1, 2, 223, 1, 63, 59, 39, 1, 63, 52, 39, 18, 2, 223, 18, 56, 29, 43, 18, 56, 55, 43, 49, 61, 98, 43, 49, 61, 70, 43, 30, 2, 223, 30, 63, 59, 84, 30, 63, 52, 84, 53, 2, 223, 0, 1, 2, 223, 18, 2, 223, 30, 2, 223, 30, 63, 17, 39, 30, 63, 3

[0, 1, 50, 105, 4, 58, 16, 46, 18, 50, 105, 19, 5, 31, 60, 30, 50, 105, 53, 50, 105, 0, 1, 50, 105, 8, 56, 29, 13, 18, 50, 105, 18, 37, 59, 13, 19, 63, 10, 57, 30, 50, 105, 23, 61, 10, 7, 53, 50, 105, 53, 37, 91, 7, 36, 63, 16, 48, 0, 1, 50, 105, 18, 50, 105, 18, 56, 16, 13, 19, 37, 6, 77, 30, 50, 105, 23, 56, 10, 7, 53, 50, 105, 53, 56, 59, 7, 36, 58, 10, 46, 0, 1, 50, 105, 4, 5, 31, 46, 18, 50, 105, 19, 63, 10, 73, 30, 50, 105, 0, 4, 37, 16, 39, 4, 56, 17, 39, 19, 56, 59, 87, 19, 56, 31, 122, 0, 8, 67, 10, 7, 8, 37, 29, 7, 18, 61, 6, 7, 18, 56, 59, 7, 19, 61, 91, 24, 19, 56, 10, 24, 23, 37, 91, 7, 23, 37, 10, 7, 53, 37, 16, 7, 53, 67, 91, 13, 36, 61, 93, 24, 36, 63, 16, 24, 0, 18, 63, 132, 7, 18, 63, 16, 7, 19, 67, 129, 22, 19, 67, 6, 22, 49, 61, 93, 22, 49, 61, 10, 22, 20, 67, 85, 43, 20, 67, 59, 43, 21, 61, 93, 22, 21, 61, 10, 22, 33, 56, 129, 22, 33, 56, 6, 22, 23, 37, 130, 22, 23, 37, 17, 22, 36, 37, 85, 22, 36, 37, 59, 22, 38, 37, 12, 7, 38, 37, 29, 7, 28, 61, 16, 7, 28, 61, 31,

[0, 1, 2, 3, 1, 56, 42, 13, 4, 37, 42, 7, 14, 41, 31, 13, 8, 56, 42, 13, 18, 2, 3, 18, 37, 79, 13, 19, 56, 79, 13, 49, 63, 10, 13, 20, 56, 79, 13, 30, 2, 3, 30, 41, 10, 13, 21, 37, 10, 13, 33, 63, 6, 13, 23, 37, 10, 13, 53, 2, 3, 53, 61, 6, 13, 36, 37, 6, 13, 38, 56, 25, 13, 28, 63, 16, 13, 0, 1, 2, 3, 1, 56, 89, 39, 18, 2, 3, 18, 58, 65, 39, 30, 2, 3, 30, 58, 71, 57, 33, 58, 79, 22, 53, 2, 3, 53, 58, 10, 22, 38, 58, 6, 22, 0, 1, 2, 3, 1, 58, 10, 24, 18, 2, 3, 18, 63, 71, 46, 30, 2, 3, 30, 56, 35, 47, 33, 58, 6, 13, 23, 56, 10, 7, 23, 63, 6, 13, 53, 2, 3, 53, 37, 25, 22, 38, 41, 16, 22, 0, 1, 2, 3, 1, 37, 25, 73, 14, 58, 34, 43, 18, 2, 3, 18, 63, 42, 43, 49, 56, 31, 43, 30, 2, 3, 30, 37, 42, 47, 53, 2, 3, 0, 1, 2, 3, 1, 56, 42, 7, 14, 56, 42, 7, 8, 56, 31, 13, 18, 2, 3, 18, 63, 42, 22, 49, 56, 75, 13, 20, 58, 65, 13, 30, 2, 3, 30, 56, 75, 13, 21, 63, 65, 13, 33, 41, 71, 7, 53, 2, 3, 53, 41, 71, 27, 0, 1, 2, 3, 1, 63, 10, 7, 14, 63, 10, 13, 8, 58, 6, 13, 18, 2, 3, 18, 58, 10, 13, 49, 41

[0, 1, 50, 174, 1, 58, 69, 27, 18, 15, 69, 22, 49, 5, 70, 22, 30, 15, 72, 43, 33, 41, 72, 13, 23, 58, 70, 13, 53, 9, 72, 22, 38, 15, 42, 22, 0, 1, 41, 69, 45, 8, 63, 42, 13, 18, 15, 69, 7, 49, 41, 79, 7, 30, 41, 98, 45, 53, 15, 69, 22, 38, 58, 79, 43, 0, 1, 58, 98, 46, 18, 58, 69, 7, 49, 15, 79, 22, 30, 58, 91, 22, 53, 58, 25, 22, 38, 58, 79, 26, 0, 18, 58, 98, 48, 0, 1, 9, 69, 27, 18, 15, 69, 22, 49, 58, 70, 22, 30, 58, 72, 7, 33, 58, 72, 13, 23, 63, 70, 13, 53, 15, 72, 22, 38, 15, 42, 22, 0, 1, 41, 69, 43, 8, 63, 42, 13, 18, 58, 69, 22, 49, 58, 79, 7, 30, 15, 98, 27, 53, 58, 69, 22, 38, 58, 79, 22, 0, 1, 15, 98, 46, 18, 58, 69, 22, 49, 15, 98, 22, 30, 63, 91, 43, 38, 58, 91, 22, 0, 1, 58, 25, 22, 14, 58, 88, 104]
84
[0, 1, 2, 196, 1, 40, 79, 43, 14, 40, 69, 22, 18, 40, 42, 7, 49, 96, 72, 46, 33, 58, 79, 7, 53, 58, 79, 7, 38, 58, 79, 13, 28, 41, 69, 7, 0, 1, 58, 42, 7, 14, 58, 69, 43, 18, 40, 72, 102, 0, 49, 5, 79, 7, 30, 58, 79, 7, 33, 58, 79, 13, 33, 56, 69, 7, 53, 41, 42, 43, 38, 5

[0, 1, 50, 155, 8, 95, 52, 46, 20, 95, 106, 45, 23, 95, 107, 27, 28, 95, 66, 87, 0, 20, 95, 74, 84, 0, 20, 95, 52, 39, 23, 95, 74, 45, 28, 95, 106, 47, 0, 20, 95, 52, 76, 28, 95, 66, 11, 0, 20, 95, 107, 77, 28, 95, 66, 7, 0, 4, 95, 66, 22]
96
[0, 1, 50, 105, 4, 96, 35, 39, 19, 40, 71, 46, 21, 96, 74, 76, 0, 19, 96, 35, 7, 49, 9, 71, 13, 30, 96, 71, 76, 0, 19, 5, 71, 43, 20, 96, 74, 43, 21, 40, 75, 60, 0, 8, 40, 74, 22, 19, 40, 75, 43, 20, 96, 74, 43, 21, 96, 71, 60, 0, 19, 96, 35, 43, 20, 9, 71, 22, 21, 96, 74, 60, 0, 19, 5, 72, 43, 20, 96, 35, 43, 21, 58, 42, 73, 0, 19, 96, 35, 43, 20, 9, 71, 22, 21, 96, 74, 22, 23, 40, 75, 86, 0, 19, 96, 72, 43, 49, 40, 35, 43]
97
[0, 1, 2, 171, 53, 9, 42, 7, 38, 9, 42, 7, 0, 14, 9, 42, 7, 18, 9, 54, 39, 30, 9, 69, 22, 33, 5, 79, 22, 53, 9, 69, 39, 0, 18, 9, 98, 39, 53, 5, 79, 7, 38, 5, 79, 45, 0, 14, 5, 98, 22, 18, 5, 79, 22, 49, 15, 69, 7, 30, 9, 54, 22, 33, 41, 72, 7, 23, 58, 54, 13, 53, 5, 69, 27, 0, 53, 9, 79, 77, 0, 14, 15, 98, 43, 18, 41, 91, 

[0, 1, 50, 121, 4, 61, 42, 22, 8, 61, 42, 7, 18, 50, 121, 18, 61, 42, 7, 19, 37, 42, 22, 20, 61, 42, 22, 30, 50, 121, 30, 61, 42, 22, 21, 56, 42, 46, 53, 50, 121, 36, 61, 79, 22, 28, 63, 42, 22, 0, 1, 50, 121, 4, 61, 42, 22, 8, 61, 42, 22, 18, 50, 121, 19, 37, 72, 7, 49, 61, 42, 7, 20, 63, 31, 7, 30, 50, 121, 30, 61, 79, 13, 21, 37, 79, 39, 53, 50, 121, 36, 67, 25, 43, 28, 67, 6, 22, 0, 1, 50, 121, 4, 67, 25, 22, 8, 61, 91, 22, 18, 50, 121, 19, 37, 91, 22, 20, 61, 25, 22, 30, 50, 121, 21, 61, 6, 46, 53, 50, 121, 36, 61, 79, 22, 28, 37, 25, 13, 0, 1, 50, 121, 1, 61, 89, 7, 4, 61, 79, 22, 8, 61, 79, 22, 18, 50, 121, 19, 56, 31, 13, 49, 61, 79, 13, 20, 61, 79, 13, 30, 50, 121, 30, 67, 91, 13, 21, 61, 91, 46, 53, 50, 121, 36, 56, 71, 43, 28, 37, 74, 43, 0, 1, 50, 121, 4, 61, 75, 39, 18, 50, 121, 30, 50, 121]
107
[0, 1, 50, 121, 1, 67, 34, 7, 4, 114, 31, 7, 14, 95, 34, 57, 49, 61, 52, 7, 20, 95, 34, 22, 30, 67, 31, 22, 33, 67, 35, 46, 38, 67, 34, 43, 0, 1, 67, 31, 104, 53, 154, 34, 45, 0, 1

In [44]:
import mido
for i in range(11):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')
    total_intro_time = 0
    total_solo_time = 0
    total_predict_time = 0
    for msg in intro.tracks[1]:
        if msg.type == "note_on":
            total_intro_time += msg.time
    for msg in solo.tracks[1]:
        if msg.type == "note_on":
            total_solo_time += msg.time
    for msg in predict.tracks[1]:
        if msg.type == "note_on":
            total_predict_time += msg.time
            
    original_outro_time = 0 + outro.tracks[1][1].time
    
    print(original_outro_time + total_solo_time + total_intro_time)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_solo_time + total_intro_time
    print(outro.tracks[1][1].time)
    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')
    
    
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    
    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

30480
30480
29040
29040
40200
40200
34620
34620
27720
27720
22200
22200
53340
53340
43440
43440
25800
25800
31680
31680
20400
20400
25560
25560
30900
30900
24060
24060
31680
31680
34500
34500
30720
30720
23040
23040
61500
61500
58140
58140
31380
31380
24000
24000


In [None]:
class BeamSearchNode(object):
    def __init__(self, prev_node, wid, logp, length):
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    
    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]
    
    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    n_best_list = []
    
     
    #trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

    #first token as input
    trg_tensor = torch.LongTensor(outputs).to(device)
    
    end_nodes = []

    #starting node
    node = BeamSearchNode(prev_node=None, wid=trg_tensor, logp=0, length=1)

    nodes = []

    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0

    while True:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break
        
        # Fetch the best node
        #print([n[2].wid for n in nodes])
        score, _, n = heappop(nodes)
        decoder_input = n.wid
        
        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= beam_width:
                break
            else:
                continue
   
        sequence = [n.wid.item()]
        a = n
        while a.prev_node is not None:
            a = a.prev_node
            sequence.append(a.wid.item())
        sequence = sequence[::-1] # reverse
        
        #print(sequence)
        
        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))
        
        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(output, beam_width)
        #print(topk_indexes)
        #print(topk_log_prob)
        # Then, register new top-k nodes
        for new_k in range(beam_width):
            decoded_t = topk_indexes[0][0][new_k].view(1) # (1)
            logp = topk_log_prob[0][0][new_k].item() # float log probability val

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width
        #print(n_dec_steps)
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(beam_width)]

    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)


    # return n_best_seq_list

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]

    # remove start token
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    output = open(path, 'wb')
    pickle.dump(vocab, output)
    output.close()

In [None]:
vocab_folder = "vocab_outro/"
save_vocab(intro_field.vocab, vocab + 'intro_vocab.pkl')
save_vocab(solo_field.vocab, vocab + 'solo_vocab.pkl')
save_vocab(outro_field.vocab, vocab + 'outro_vocab.pkl')

In [None]:
for i in range(len(test_intro)):
    if len(test_intro) > 1200:
        continue
    sentence = test_intro[i]
    #print(sentence)
    list_sentence = [int(x) for x in sentence.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence_beam(model, sentence, intro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    #print(translated_sentence)
    utils.write_midi(list_sentence, word2event, generated_outputs + "/src" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict_beam" + str(i)  + ".mid")
    print(i)
    if i == 10:
        break

In [None]:
beam = [2, 59, 119, 13, 212, 59, 212, 59, 59, 75, 59, 59, 13, 119, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 119, 59, 59, 59, 59, 166, 59, 59, 59, 13, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 59, 59, 212, 212, 13, 59, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 14, 59, 59, 212, 59, 212, 212, 59, 59, 59, 68, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 68, 59, 59, 212, 59, 59, 13, 59, 59, 59, 59, 59, 59, 97, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 158, 59, 59, 13, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 13, 212, 13, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 158, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 13, 13, 59, 59, 13, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 13, 59, 166, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 86, 59, 212, 212, 212, 59, 59, 59, 59, 212, 59, 86, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 212, 59, 59, 166, 59, 86, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 59, 59, 59, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 212, 212, 158, 59, 59, 13, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 158, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 212, 212, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 13, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 212, 86, 59, 59, 59, 158, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 212, 59, 13, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 68, 59, 13, 59, 59, 13, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 14, 59, 59, 59, 59, 59, 59, 13, 86, 59, 59, 59, 212, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 13, 59, 59, 59, 59, 68, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 212, 212, 13, 59, 166, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 212, 59, 59, 212, 59, 212, 59, 59, 13, 59, 59, 59, 59, 13, 59, 59, 14, 13, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 68, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 212, 59, 59, 13, 59, 59, 166, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 13, 59, 59, 212, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59]

translated_sentence1 = [solo_field.vocab.itos[idx] for idx in beam]
translated_sentence = [int(x) for x in translated_sentence1 if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']    
utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict1.mid")

In [None]:
translated_sentence

In [None]:
for i in range(len(test_intro)):
    if len(test_intro) > 1200:
        continue
    list_sentence = [int(x) for x in sentence.split(' ')]
    remi = [word2event[x] for x in list_sentence]
    print(remi)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    #tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    #tokens.insert(0, german.init_token)
    #tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    #text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(sentence).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        #print( vars(example))
        src = vars(example)["intro"]
        trg = vars(example)["solo"]
        
        src = [int(x) for x in src]
        trg = [int(x) for x in trg]
        
        if len(trg) > 1200 or len(src) > 1200:
            continue
        
        prediction = bleu_translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append(trg)
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while
score = bleu(test[1:10], model, intro_field, solo_field, device)
print(f"Bleu score {score * 100:.2f}")

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show() 

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns