In [48]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
source_folder = "solo_generation_dataset_dynamic_alphabetical_split"
folder = "dynamic_alphabetical_models/intro_3rd"
destination_folder = folder + "/solo_generation_weights"
generated_outputs = folder +  "/generated_samples"
vocab = folder + "/vocab"

In [49]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

state = pickle.load(open('./state.pkl', 'rb'))
random.setstate(state)

In [50]:
from pathlib import Path
Path(destination_folder).mkdir(parents=True, exist_ok=True)
Path(generated_outputs).mkdir(parents=True, exist_ok=True)
Path(vocab).mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/intro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/outro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/predict").mkdir(parents=True, exist_ok=True)

In [51]:
event2word, word2event = pickle.load(open('dictionary_fixed.pkl', 'rb'))

In [52]:
if torch.cuda.is_available():  
    dev = "cuda:1" 
else:  
    dev = "cpu" 
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [53]:
# Fields

intro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
intro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
fields = [('intro', intro_field), ('intro_piano', intro_piano_field), \
          ('outro', outro_field), ('outro_piano', outro_piano_field), \
          ('solo', solo_field), ('solo_piano', solo_piano_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv', validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators
BATCH_SIZE = 8
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)

# Vocabulary

intro_field.build_vocab(train, min_freq=1)
intro_piano_field.build_vocab(train, min_freq=1)
outro_field.build_vocab(train, min_freq=1)
outro_piano_field.build_vocab(train, min_freq=1)
solo_field.build_vocab(train, min_freq=1)
solo_piano_field.build_vocab(train, min_freq=1)

In [54]:
for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (test_iter):
    print(solo.transpose(1,0).size())

torch.Size([272, 8])
torch.Size([353, 8])
torch.Size([509, 8])
torch.Size([335, 8])
torch.Size([326, 8])
torch.Size([279, 8])
torch.Size([253, 8])
torch.Size([522, 8])
torch.Size([281, 8])
torch.Size([444, 8])
torch.Size([619, 8])
torch.Size([325, 8])
torch.Size([319, 8])
torch.Size([414, 8])


In [55]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [56]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [57]:
src_vocab_size = len(intro_field.vocab)
trg_vocab_size = len(solo_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 1200
forward_expansion = 4
src_pad_idx = 1 #english.vocab.stoi["<pad>"]

model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
model = model.to(device)

In [58]:
def init_weights(m: nn.Module):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

optimizer = optim.Adam(model.parameters(), lr=2e-4) #4e-5


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
#     torch.save(state, destination_folder + str(nth)+filename)
    torch.save(state, destination_folder + '/metrics.pt')

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
    torch.save(state, destination_folder + "/" + str(nth)+filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

The model has 11,079,428 trainable parameters


In [59]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [60]:
PAD_IDX = 1

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [61]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0

    #for _, (src, _,trg,_) in enumerate(iterator):
    for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
        src, trg = intro.transpose(1,0), solo.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:-1, :])
        
#         print(output.size())
#         print(trg.size())
        
        output = output.view(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        #for _, (src, _,trg,_) in enumerate(iterator):
        for ((intro, intro_len), (intro_piano, intro_piano_len),\
         (outro, outro_len),(outro_piano, outro_piano_len),\
         (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
            src, trg = intro.transpose(1,0), solo.transpose(1,0)
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg[:-1, :]) #turn off teacher forcing

            output = output.view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs



In [62]:
def translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break
    # print(outputs)
    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [63]:
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro = df_intro['intro'].values
val_solo = df_intro['solo'].values
val_outro = df_intro['outro'].values
val_data=[]
for i in range(len(val_intro)):
    temp_dict = {}
    temp_dict['intro'] = val_intro[i]
    temp_dict['solo'] = val_solo[i]
    temp_dict['outro'] = val_outro[i]
    val_data.append(temp_dict)
print(len(val_intro))

112


In [64]:
def check_mode_collapse(model):
    count = 0
    translations = []
    for i in range(5):
        if len(val_intro) > 1200:
            continue
        intro = val_intro[i]
        solo = val_solo[i]
        outro = val_outro[i]
        #print(intro)
        list_intro = [int(x) for x in intro.split(' ')]
        list_solo = [int(x) for x in solo.split(' ')]
        list_outro = [int(x) for x in outro.split(' ')]
        translated_sentence = translate_sentence(model, intro, intro_field, solo_field, device, max_length=1200)
        
        translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
        print(translated_sentence)
        translations.append(translated_sentence)
        if i > 0:
            if translations[i-1] == translations[i]:
                count += 1
    return count


In [65]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(S_EPOCH, N_EPOCHS):
    
    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
        save_best_checkpoint(checkpoint,N_EPOCHS)
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)
    if (epoch+2) % 20 ==0:
        if check_mode_collapse(model) > 2:
            print("model is mode collapsing")
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 0m 6s
	Train Loss: 5.340 | Train PPL: 208.453
	 Val. Loss: 5.128 |  Val. PPL: 168.674
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 0m 6s
	Train Loss: 4.805 | Train PPL: 122.129
	 Val. Loss: 4.589 |  Val. PPL:  98.400
=> Saving checkpoint
Epoch: 03 | Time: 0m 6s
	Train Loss: 4.375 | Train PPL:  79.414
	 Val. Loss: 4.270 |  Val. PPL:  71.547
=> Saving checkpoint
Epoch: 04 | Time: 0m 6s
	Train Loss: 4.042 | Train PPL:  56.958
	 Val. Loss: 3.932 |  Val. PPL:  51.008
=> Saving checkpoint
Epoch: 05 | Time: 0m 6s
	Train Loss: 3.708 | Train PPL:  40.753
	 Val. Loss: 3.612 |  Val. PPL:  37.050
=> Saving checkpoint
Epoch: 06 | Time: 0m 6s
	Train Loss: 3.423 | Train PPL:  30.649
	 Val. Loss: 3.395 |  Val. PPL:  29.810
=> Saving checkpoint
Epoch: 07 | Time: 0m 6s
	Train Loss: 3.230 | Train PPL:  25.280
	 Val. Loss: 3.230 |  Val. PPL:  25.283
=> Saving checkpoint
Epoch: 08 | Time: 0m 6s
	Train Loss: 3.093 | Train PPL:  22.047
	 Val. Loss: 3.123 |  Val. PPL:  22.723


[0, 1, 50, 188, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 14, 63, 59, 7, 14, 63, 59, 7, 1, 63, 59, 7, 14, 63, 59, 7, 14, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 14, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 14, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 1, 63, 59, 7, 14, 63, 59, 7, 14, 63, 59, 7, 18, 63, 59, 7, 14, 63, 59, 7, 18, 63, 59, 7, 14, 63, 59, 7, 18, 63, 59, 7, 18, 63, 59, 7, 18, 63, 59, 7, 14, 63, 59, 7, 18, 63, 59, 7, 1, 63, 59, 7, 18, 63, 59, 7, 

Epoch: 20 | Time: 0m 6s
	Train Loss: 2.434 | Train PPL:  11.404
	 Val. Loss: 2.458 |  Val. PPL:  11.684
=> Saving checkpoint
=> Saving checkpoint
Epoch: 21 | Time: 0m 6s
	Train Loss: 2.393 | Train PPL:  10.950
	 Val. Loss: 2.426 |  Val. PPL:  11.312
=> Saving checkpoint
=> Saving checkpoint
Epoch: 22 | Time: 0m 6s
	Train Loss: 2.361 | Train PPL:  10.607
	 Val. Loss: 2.401 |  Val. PPL:  11.036
=> Saving checkpoint
Epoch: 23 | Time: 0m 6s
	Train Loss: 2.325 | Train PPL:  10.223
	 Val. Loss: 2.379 |  Val. PPL:  10.798
=> Saving checkpoint
Epoch: 24 | Time: 0m 6s
	Train Loss: 2.293 | Train PPL:   9.901
	 Val. Loss: 2.345 |  Val. PPL:  10.437
=> Saving checkpoint
Epoch: 25 | Time: 0m 6s
	Train Loss: 2.263 | Train PPL:   9.609
	 Val. Loss: 2.330 |  Val. PPL:  10.281
=> Saving checkpoint
Epoch: 26 | Time: 0m 6s
	Train Loss: 2.229 | Train PPL:   9.289
	 Val. Loss: 2.295 |  Val. PPL:   9.920
=> Saving checkpoint
Epoch: 27 | Time: 0m 6s
	Train Loss: 2.200 | Train PPL:   9.023
	 Val. Loss: 2.273 

[0, 1, 50, 174, 1, 56, 54, 43, 14, 56, 54, 43, 18, 56, 54, 43, 49, 56, 54, 43, 30, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 18, 56, 54, 43, 49, 56, 54, 43, 30, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 18, 56, 54, 43, 49, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 14, 56, 54, 43, 18, 56, 54, 43, 49, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 18, 56, 54, 43, 49, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 14, 56, 54, 43, 49, 56, 54, 43, 49, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 18, 56, 54, 43, 49, 56, 54, 43, 33, 56, 54, 43, 53, 56, 54, 43, 38, 56, 54, 43, 0, 1, 56, 54, 43, 14, 56, 54, 43, 14, 56, 54, 43]
Epoch: 40 | Time: 0m 6s
	Train Loss: 1.917 | Train PPL:   6

[0, 1, 50, 188, 4, 5, 10, 7, 8, 5, 10, 7, 18, 5, 10, 7, 19, 5, 10, 7, 49, 5, 10, 7, 20, 5, 10, 7, 30, 5, 10, 7, 21, 5, 10, 7, 23, 5, 10, 7, 53, 5, 10, 7, 36, 5, 10, 7, 38, 5, 10, 7, 28, 5, 10, 7, 0, 1, 5, 10, 7, 4, 5, 10, 7, 14, 5, 10, 7, 8, 5, 10, 7, 18, 5, 10, 7, 19, 5, 10, 7, 49, 5, 10, 7, 20, 5, 10, 7, 30, 5, 10, 7, 21, 5, 10, 7, 33, 5, 10, 7, 23, 5, 10, 7, 53, 5, 10, 7, 36, 5, 10, 7, 38, 5, 10, 7, 28, 5, 10, 7, 0, 1, 5, 10, 7, 4, 5, 10, 7, 4, 5, 10, 7, 14, 5, 10, 7, 8, 5, 10, 7, 18, 5, 10, 7, 18, 5, 10, 7, 19, 5, 10, 7, 49, 5, 10, 7, 20, 5, 10, 7, 30, 5, 10, 7, 21, 5, 10, 7, 33, 5, 10, 7, 23, 5, 10, 7, 53, 5, 10, 7, 36, 5, 10, 7, 36, 5, 10, 7, 38, 5, 10, 7, 28, 5, 10, 7, 28, 5, 10, 7, 0, 1, 5, 10, 7, 4, 5, 10, 7, 4, 5, 10, 7, 14, 5, 10, 7, 8, 5, 10, 7, 18, 5, 10, 7, 19, 5, 10, 7, 49, 5, 10, 7, 20, 5, 10, 7, 21, 5, 10, 7, 21, 5, 10, 7, 33, 5, 10, 7, 23, 5, 10, 7, 53, 5, 10, 7, 36, 5, 10, 7, 38, 5, 10, 7, 28, 5, 10, 7, 0, 1, 5, 10, 7, 4, 5, 10, 7, 4, 5, 10, 7, 14, 5, 10, 7, 8, 5, 10

[0, 1, 2, 137, 14, 61, 54, 43, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 30, 2, 137, 30, 2, 137, 33, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 38, 61, 54, 43, 28, 61, 54, 43, 0, 1, 2, 137, 1, 2, 137, 1, 2, 137, 1, 2, 137, 14, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 19, 61, 54, 43, 49, 2, 137, 30, 2, 137, 30, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 38, 2, 137, 0, 1, 2, 137, 1, 2, 137, 1, 2, 137, 14, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 30, 2, 137, 30, 2, 137, 30, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 38, 2, 137, 0, 1, 2, 137, 1, 2, 137, 1, 2, 137, 1, 2, 137, 1, 2, 137, 14, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 18, 2, 137, 30, 2, 137, 30, 2, 137, 30, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 53, 2, 137, 

[0, 1, 50, 121, 4, 67, 42, 22, 8, 67, 42, 22, 19, 67, 42, 22, 20, 67, 42, 22, 21, 67, 42, 22, 23, 67, 42, 22, 36, 67, 75, 22, 28, 67, 75, 22, 0, 4, 67, 75, 22, 8, 67, 65, 22, 19, 67, 75, 22, 20, 67, 65, 22, 21, 67, 71, 22, 23, 67, 71, 22, 36, 67, 71, 22, 28, 67, 75, 22, 0, 4, 67, 71, 22, 8, 67, 71, 22, 19, 67, 35, 22, 20, 67, 71, 22, 21, 67, 35, 22, 23, 67, 35, 22, 36, 67, 71, 22, 28, 67, 71, 22, 0, 4, 67, 35, 22, 8, 67, 35, 22, 19, 67, 71, 22, 20, 67, 35, 22, 21, 67, 71, 22, 23, 67, 71, 22, 36, 67, 71, 22, 28, 67, 35, 22, 28, 67, 42, 22, 0, 4, 67, 71, 22, 8, 67, 35, 22, 8, 67, 71, 22, 19, 67, 71, 22, 20, 67, 71, 22, 21, 67, 35, 22, 23, 67, 71, 22, 36, 67, 71, 22, 28, 67, 71, 22, 0, 4, 67, 71, 22, 8, 67, 71, 22, 8, 67, 35, 22, 19, 67, 71, 22, 19, 67, 71, 22, 20, 67, 71, 22, 21, 67, 35, 22, 23, 67, 35, 22, 36, 67, 35, 22, 28, 67, 35, 22, 28, 67, 71, 22, 0, 4, 67, 35, 22, 8, 67, 71, 22, 8, 67, 71, 22, 19, 67, 71, 22, 20, 67, 71, 22, 21, 67, 35, 22, 21, 67, 71, 22, 23, 67, 71, 22, 23, 67,

Epoch: 109 | Time: 0m 6s
	Train Loss: 1.304 | Train PPL:   3.684
	 Val. Loss: 2.431 |  Val. PPL:  11.368
Epoch: 110 | Time: 0m 6s
	Train Loss: 1.298 | Train PPL:   3.663
	 Val. Loss: 2.414 |  Val. PPL:  11.184
Epoch: 111 | Time: 0m 6s
	Train Loss: 1.290 | Train PPL:   3.634
	 Val. Loss: 2.447 |  Val. PPL:  11.550
Epoch: 112 | Time: 0m 6s
	Train Loss: 1.289 | Train PPL:   3.628
	 Val. Loss: 2.437 |  Val. PPL:  11.436
Epoch: 113 | Time: 0m 6s
	Train Loss: 1.275 | Train PPL:   3.578
	 Val. Loss: 2.438 |  Val. PPL:  11.449
Epoch: 114 | Time: 0m 6s
	Train Loss: 1.269 | Train PPL:   3.556
	 Val. Loss: 2.477 |  Val. PPL:  11.902
Epoch: 115 | Time: 0m 6s
	Train Loss: 1.265 | Train PPL:   3.544
	 Val. Loss: 2.474 |  Val. PPL:  11.867
Epoch: 116 | Time: 0m 6s
	Train Loss: 1.259 | Train PPL:   3.520
	 Val. Loss: 2.470 |  Val. PPL:  11.828
Epoch: 117 | Time: 0m 6s
	Train Loss: 1.253 | Train PPL:   3.499
	 Val. Loss: 2.536 |  Val. PPL:  12.624
Epoch: 118 | Time: 0m 6s
	Train Loss: 1.246 | Train PPL

Epoch: 120 | Time: 0m 6s
	Train Loss: 1.237 | Train PPL:   3.444
	 Val. Loss: 2.499 |  Val. PPL:  12.176
=> Saving checkpoint
Epoch: 121 | Time: 0m 6s
	Train Loss: 1.223 | Train PPL:   3.398
	 Val. Loss: 2.518 |  Val. PPL:  12.398
=> Saving checkpoint
Epoch: 122 | Time: 0m 6s
	Train Loss: 1.220 | Train PPL:   3.389
	 Val. Loss: 2.567 |  Val. PPL:  13.025
Epoch: 123 | Time: 0m 6s
	Train Loss: 1.209 | Train PPL:   3.349
	 Val. Loss: 2.512 |  Val. PPL:  12.325
Epoch: 124 | Time: 0m 6s
	Train Loss: 1.209 | Train PPL:   3.350
	 Val. Loss: 2.549 |  Val. PPL:  12.789
Epoch: 125 | Time: 0m 6s
	Train Loss: 1.205 | Train PPL:   3.338
	 Val. Loss: 2.574 |  Val. PPL:  13.117
Epoch: 126 | Time: 0m 6s
	Train Loss: 1.193 | Train PPL:   3.298
	 Val. Loss: 2.585 |  Val. PPL:  13.265
Epoch: 127 | Time: 0m 6s
	Train Loss: 1.187 | Train PPL:   3.276
	 Val. Loss: 2.578 |  Val. PPL:  13.167
Epoch: 128 | Time: 0m 6s
	Train Loss: 1.180 | Train PPL:   3.254
	 Val. Loss: 2.581 |  Val. PPL:  13.206
Epoch: 129 | 

Epoch: 149 | Time: 0m 6s
	Train Loss: 1.050 | Train PPL:   2.859
	 Val. Loss: 2.813 |  Val. PPL:  16.656
Epoch: 150 | Time: 0m 6s
	Train Loss: 1.048 | Train PPL:   2.853
	 Val. Loss: 2.809 |  Val. PPL:  16.590
Epoch: 151 | Time: 0m 6s
	Train Loss: 1.037 | Train PPL:   2.821
	 Val. Loss: 2.821 |  Val. PPL:  16.798
Epoch: 152 | Time: 0m 6s
	Train Loss: 1.033 | Train PPL:   2.810
	 Val. Loss: 2.880 |  Val. PPL:  17.814
Epoch: 153 | Time: 0m 6s
	Train Loss: 1.027 | Train PPL:   2.792
	 Val. Loss: 2.874 |  Val. PPL:  17.708
Epoch: 154 | Time: 0m 6s
	Train Loss: 1.026 | Train PPL:   2.791
	 Val. Loss: 2.871 |  Val. PPL:  17.652
Epoch: 155 | Time: 0m 6s
	Train Loss: 1.021 | Train PPL:   2.776
	 Val. Loss: 2.908 |  Val. PPL:  18.317
Epoch: 156 | Time: 0m 6s
	Train Loss: 1.015 | Train PPL:   2.758
	 Val. Loss: 2.914 |  Val. PPL:  18.436
Epoch: 157 | Time: 0m 6s
	Train Loss: 1.012 | Train PPL:   2.750
	 Val. Loss: 2.914 |  Val. PPL:  18.428
Epoch: 158 | Time: 0m 6s
	Train Loss: 1.008 | Train PPL

Epoch: 172 | Time: 0m 6s
	Train Loss: 0.936 | Train PPL:   2.551
	 Val. Loss: 3.114 |  Val. PPL:  22.518
Epoch: 173 | Time: 0m 6s
	Train Loss: 0.928 | Train PPL:   2.529
	 Val. Loss: 3.098 |  Val. PPL:  22.144
Epoch: 174 | Time: 0m 6s
	Train Loss: 0.923 | Train PPL:   2.517
	 Val. Loss: 3.129 |  Val. PPL:  22.845
Epoch: 175 | Time: 0m 6s
	Train Loss: 0.926 | Train PPL:   2.526
	 Val. Loss: 3.107 |  Val. PPL:  22.344
Epoch: 176 | Time: 0m 6s
	Train Loss: 0.914 | Train PPL:   2.493
	 Val. Loss: 3.078 |  Val. PPL:  21.710
Epoch: 177 | Time: 0m 6s
	Train Loss: 0.900 | Train PPL:   2.460
	 Val. Loss: 3.122 |  Val. PPL:  22.700
Epoch: 178 | Time: 0m 6s
	Train Loss: 0.898 | Train PPL:   2.455
	 Val. Loss: 3.210 |  Val. PPL:  24.784
Epoch: 179 | Time: 0m 6s
	Train Loss: 0.892 | Train PPL:   2.439
	 Val. Loss: 3.184 |  Val. PPL:  24.144
[0, 1, 50, 188, 18, 61, 25, 43, 49, 37, 91, 43, 30, 61, 25, 43, 33, 37, 10, 43, 53, 37, 79, 43, 38, 61, 25, 43, 0, 1, 61, 91, 43, 14, 63, 25, 43, 18, 61, 25, 43

[0, 1, 50, 105, 4, 58, 91, 43, 8, 40, 10, 43, 19, 9, 10, 57, 23, 40, 10, 7, 53, 5, 10, 13, 36, 41, 79, 46, 0, 4, 9, 10, 22, 8, 41, 79, 39, 23, 41, 79, 26, 0, 4, 9, 6, 46, 20, 41, 79, 13, 53, 63, 79, 13, 36, 41, 79, 27, 0, 4, 41, 79, 13, 14, 58, 10, 13, 8, 9, 10, 13, 18, 63, 79, 13, 19, 63, 79, 13, 20, 63, 10, 13, 20, 63, 10, 77, 23, 5, 6, 13, 53, 58, 16, 13, 36, 63, 25, 13, 38, 63, 6, 13, 28, 63, 16, 13, 28, 63, 25, 22, 0, 1, 40, 16, 13, 1, 63, 25, 84]
[0, 1, 50, 135, 4, 61, 79, 39, 20, 95, 98, 26, 21, 61, 59, 7, 33, 61, 98, 7, 23, 61, 98, 7, 36, 61, 59, 7, 38, 61, 98, 7, 28, 61, 98, 7, 0, 4, 61, 98, 7, 8, 62, 91, 26, 20, 62, 59, 7, 21, 61, 98, 7, 33, 61, 91, 45, 23, 62, 59, 7, 53, 61, 88, 43, 38, 61, 98, 45, 28, 61, 99, 22, 0, 1, 61, 88, 7, 4, 62, 99, 7, 14, 61, 89, 26, 20, 95, 88, 43, 21, 61, 88, 43, 23, 61, 79, 26, 28, 61, 98, 7, 0, 1, 61, 98, 7, 4, 61, 59, 7, 14, 61, 59, 116, 20, 95, 69, 77, 38, 61, 74, 45, 28, 62, 74, 7, 0, 1, 61, 69, 7, 4, 64, 98, 57, 4, 61, 74, 77, 20, 62, 59, 7

Epoch: 211 | Time: 0m 6s
	Train Loss: 0.771 | Train PPL:   2.161
	 Val. Loss: 3.502 |  Val. PPL:  33.168
Epoch: 212 | Time: 0m 6s
	Train Loss: 0.757 | Train PPL:   2.133
	 Val. Loss: 3.575 |  Val. PPL:  35.708
Epoch: 213 | Time: 0m 6s
	Train Loss: 0.753 | Train PPL:   2.123
	 Val. Loss: 3.516 |  Val. PPL:  33.646
Epoch: 214 | Time: 0m 6s
	Train Loss: 0.753 | Train PPL:   2.122
	 Val. Loss: 3.558 |  Val. PPL:  35.108
Epoch: 215 | Time: 0m 6s
	Train Loss: 0.739 | Train PPL:   2.094
	 Val. Loss: 3.566 |  Val. PPL:  35.358
Epoch: 216 | Time: 0m 6s
	Train Loss: 0.750 | Train PPL:   2.117
	 Val. Loss: 3.563 |  Val. PPL:  35.271
Epoch: 217 | Time: 0m 6s
	Train Loss: 0.745 | Train PPL:   2.107
	 Val. Loss: 3.585 |  Val. PPL:  36.047
Epoch: 218 | Time: 0m 6s
	Train Loss: 0.746 | Train PPL:   2.108
	 Val. Loss: 3.579 |  Val. PPL:  35.844
Epoch: 219 | Time: 0m 6s
	Train Loss: 0.735 | Train PPL:   2.085
	 Val. Loss: 3.521 |  Val. PPL:  33.828
[0, 1, 50, 105, 4, 58, 91, 43, 8, 40, 10, 43, 19, 96, 3

Epoch: 227 | Time: 0m 6s
	Train Loss: 0.709 | Train PPL:   2.031
	 Val. Loss: 3.686 |  Val. PPL:  39.871
Epoch: 228 | Time: 0m 6s
	Train Loss: 0.699 | Train PPL:   2.012
	 Val. Loss: 3.679 |  Val. PPL:  39.621
Epoch: 229 | Time: 0m 6s
	Train Loss: 0.688 | Train PPL:   1.989
	 Val. Loss: 3.770 |  Val. PPL:  43.370
Epoch: 230 | Time: 0m 6s
	Train Loss: 0.693 | Train PPL:   2.000
	 Val. Loss: 3.737 |  Val. PPL:  41.965
Epoch: 231 | Time: 0m 6s
	Train Loss: 0.686 | Train PPL:   1.986
	 Val. Loss: 3.728 |  Val. PPL:  41.607
Epoch: 232 | Time: 0m 6s
	Train Loss: 0.681 | Train PPL:   1.977
	 Val. Loss: 3.733 |  Val. PPL:  41.799
Epoch: 233 | Time: 0m 6s
	Train Loss: 0.690 | Train PPL:   1.994
	 Val. Loss: 3.734 |  Val. PPL:  41.856
Epoch: 234 | Time: 0m 6s
	Train Loss: 0.687 | Train PPL:   1.988
	 Val. Loss: 3.742 |  Val. PPL:  42.176
Epoch: 235 | Time: 0m 6s
	Train Loss: 0.680 | Train PPL:   1.973
	 Val. Loss: 3.775 |  Val. PPL:  43.599
Epoch: 236 | Time: 0m 6s
	Train Loss: 0.687 | Train PPL

[0, 1, 50, 135, 4, 61, 72, 26, 20, 95, 54, 7, 30, 61, 42, 13, 21, 61, 72, 26, 28, 61, 42, 13, 0, 1, 95, 54, 13, 4, 37, 54, 104, 28, 62, 54, 7, 0, 1, 67, 42, 7, 4, 62, 72, 13, 14, 95, 71, 11, 28, 95, 42, 7, 0, 1, 114, 54, 7, 4, 15, 69, 73, 21, 67, 42, 78, 28, 62, 72, 7, 0, 1, 58, 72, 87]
[0, 1, 50, 174, 18, 50, 174, 18, 63, 69, 43, 20, 15, 79, 43, 30, 50, 174, 33, 41, 59, 116, 53, 50, 174, 53, 95, 74, 7, 36, 82, 72, 7, 38, 95, 70, 7, 28, 114, 35, 7, 0, 1, 50, 174, 1, 95, 52, 7, 4, 114, 72, 7, 14, 61, 91, 22, 14, 136, 70, 7, 8, 114, 72, 7, 8, 95, 35, 7, 18, 50, 174, 18, 61, 70, 7, 18, 61, 74, 7, 20, 62, 72, 7, 30, 50, 174, 33, 63, 71, 13, 33, 56, 72, 7, 23, 114, 32, 7, 53, 50, 174, 53, 114, 70, 13, 36, 114, 52, 13, 38, 95, 54, 7, 28, 82, 72, 7, 0, 1, 50, 174, 1, 50, 174, 1, 50, 174, 1, 95, 72, 7, 4, 82, 54, 7, 8, 82, 32, 7, 18, 50, 174, 18, 50, 174, 18, 58, 32, 7, 19, 64, 72, 7, 49, 61, 52, 7, 20, 114, 72, 13, 30, 50, 174, 30, 50, 174, 53, 50, 174, 53, 63, 74, 13, 53, 63, 74, 13, 36, 114

[0, 1, 2, 137, 14, 115, 31, 39, 14, 115, 65, 39, 49, 82, 31, 43, 30, 82, 65, 43, 30, 82, 79, 43, 33, 82, 71, 43, 53, 114, 31, 43, 38, 114, 65, 43, 0, 14, 82, 31, 43, 14, 82, 65, 43, 18, 114, 79, 43, 49, 136, 71, 43, 30, 114, 31, 43, 33, 136, 71, 43, 53, 136, 79, 43, 38, 82, 34, 43, 0, 14, 114, 71, 43, 14, 82, 34, 43, 18, 114, 128, 39, 49, 64, 107, 43, 30, 114, 79, 43, 33, 114, 71, 39, 38, 114, 31, 39, 38, 64, 75, 43, 0, 14, 82, 31, 43, 18, 136, 31, 43, 49, 136, 65, 43, 30, 114, 79, 43, 33, 136, 31, 43, 53, 136, 31, 43, 38, 136, 34, 43, 38, 114, 79, 43, 0, 1, 136, 34, 39, 1, 136, 75, 43, 14, 114, 71, 43, 14, 136, 71, 43, 18, 136, 34, 39, 49, 114, 71, 43, 49, 136, 31, 43, 30, 114, 71, 43, 33, 136, 31, 43, 33, 114, 79, 43, 53, 136, 31, 43, 38, 82, 34, 43, 38, 136, 34, 43, 0, 1, 114, 34, 43, 14, 82, 42, 43, 14, 114, 31, 43, 14, 82, 29, 43, 14, 136, 31, 43, 18, 136, 34, 43, 49, 82, 32, 43, 49, 114, 6, 43, 30, 82, 34, 39, 33, 115, 31, 43, 53, 136, 34, 43, 53, 82, 34, 43, 38, 136, 34, 43, 38,

Epoch: 301 | Time: 0m 6s
	Train Loss: 0.519 | Train PPL:   1.680
	 Val. Loss: 4.295 |  Val. PPL:  73.332
=> Saving checkpoint
Epoch: 302 | Time: 0m 6s
	Train Loss: 0.512 | Train PPL:   1.668
	 Val. Loss: 4.364 |  Val. PPL:  78.580
Epoch: 303 | Time: 0m 6s
	Train Loss: 0.510 | Train PPL:   1.666
	 Val. Loss: 4.256 |  Val. PPL:  70.509
Epoch: 304 | Time: 0m 6s
	Train Loss: 0.510 | Train PPL:   1.666
	 Val. Loss: 4.314 |  Val. PPL:  74.728
Epoch: 305 | Time: 0m 6s
	Train Loss: 0.502 | Train PPL:   1.652
	 Val. Loss: 4.339 |  Val. PPL:  76.602
Epoch: 306 | Time: 0m 6s
	Train Loss: 0.503 | Train PPL:   1.653
	 Val. Loss: 4.395 |  Val. PPL:  81.060
Epoch: 307 | Time: 0m 6s
	Train Loss: 0.501 | Train PPL:   1.651
	 Val. Loss: 4.342 |  Val. PPL:  76.824
Epoch: 308 | Time: 0m 6s
	Train Loss: 0.497 | Train PPL:   1.644
	 Val. Loss: 4.394 |  Val. PPL:  80.924
Epoch: 309 | Time: 0m 6s
	Train Loss: 0.495 | Train PPL:   1.640
	 Val. Loss: 4.328 |  Val. PPL:  75.785
Epoch: 310 | Time: 0m 6s
	Train Lo

[0, 1, 50, 135, 4, 61, 72, 26, 20, 95, 54, 7, 30, 61, 42, 13, 21, 61, 72, 26, 28, 61, 42, 13, 0, 1, 95, 54, 13, 4, 37, 54, 13, 14, 62, 54, 13, 8, 95, 54, 7, 18, 67, 69, 13, 19, 37, 79, 13, 49, 67, 79, 13, 20, 62, 54, 13, 30, 56, 69, 7, 21, 61, 42, 13, 33, 56, 79, 13, 23, 61, 69, 13, 53, 62, 54, 7, 36, 67, 69, 7, 38, 67, 79, 7, 28, 61, 79, 13, 0, 1, 61, 98, 116, 49, 67, 79, 13, 20, 61, 69, 7, 30, 58, 54, 7, 21, 67, 70, 13, 38, 67, 72, 13, 33, 67, 70, 7, 28, 67, 42, 13, 0, 1, 61, 72, 26, 4, 61, 79, 22, 8, 62, 54, 22, 19, 37, 71, 22, 20, 67, 54, 7, 30, 67, 54, 7, 30, 67, 42, 57, 36, 67, 79, 7, 38, 67, 54, 13, 28, 56, 69, 7, 0, 1, 56, 69, 26, 21, 67, 69, 7, 36, 56, 69, 7, 38, 58, 42, 22, 0, 1, 67, 79, 26, 49, 67, 79, 27, 20, 58, 69, 7, 21, 58, 54, 13, 33, 58, 42, 13, 23, 67, 69, 7, 36, 58, 54, 7, 38, 56, 54, 7, 28, 58, 69, 7, 0, 1, 62, 54, 13, 4, 15, 54, 7, 4, 67, 69, 13, 4, 62, 79, 11, 4, 58, 79, 39, 49, 58, 72, 13, 20, 61, 79, 22, 20, 58, 29, 13, 20, 62, 79, 11, 36, 58, 71, 13, 30, 56, 6

Epoch: 360 | Time: 0m 6s
	Train Loss: 0.407 | Train PPL:   1.503
	 Val. Loss: 4.734 |  Val. PPL: 113.762
=> Saving checkpoint
Epoch: 361 | Time: 0m 6s
	Train Loss: 0.418 | Train PPL:   1.519
	 Val. Loss: 4.695 |  Val. PPL: 109.388
=> Saving checkpoint
Epoch: 362 | Time: 0m 6s
	Train Loss: 0.418 | Train PPL:   1.519
	 Val. Loss: 4.670 |  Val. PPL: 106.665
Epoch: 363 | Time: 0m 6s
	Train Loss: 0.420 | Train PPL:   1.521
	 Val. Loss: 4.698 |  Val. PPL: 109.674
Epoch: 364 | Time: 0m 6s
	Train Loss: 0.417 | Train PPL:   1.518
	 Val. Loss: 4.689 |  Val. PPL: 108.724
Epoch: 365 | Time: 0m 6s
	Train Loss: 0.407 | Train PPL:   1.503
	 Val. Loss: 4.718 |  Val. PPL: 111.967
Epoch: 366 | Time: 0m 6s
	Train Loss: 0.413 | Train PPL:   1.512
	 Val. Loss: 4.675 |  Val. PPL: 107.279
Epoch: 367 | Time: 0m 6s
	Train Loss: 0.412 | Train PPL:   1.510
	 Val. Loss: 4.726 |  Val. PPL: 112.807
Epoch: 368 | Time: 0m 6s
	Train Loss: 0.400 | Train PPL:   1.492
	 Val. Loss: 4.780 |  Val. PPL: 119.108
Epoch: 369 | 

Epoch: 397 | Time: 0m 6s
	Train Loss: 0.380 | Train PPL:   1.462
	 Val. Loss: 4.937 |  Val. PPL: 139.348
Epoch: 398 | Time: 0m 6s
	Train Loss: 0.359 | Train PPL:   1.432
	 Val. Loss: 4.887 |  Val. PPL: 132.520
Epoch: 399 | Time: 0m 6s
	Train Loss: 0.361 | Train PPL:   1.435
	 Val. Loss: 4.960 |  Val. PPL: 142.622
[0, 1, 50, 105, 4, 58, 91, 45, 8, 40, 10, 43, 19, 96, 31, 57, 23, 63, 69, 7, 53, 5, 31, 22, 38, 56, 69, 7, 28, 41, 10, 26, 0, 19, 40, 59, 77, 23, 41, 91, 13, 53, 41, 91, 24, 0, 49, 41, 91, 22, 30, 5, 79, 13, 21, 96, 79, 39, 38, 63, 91, 22, 0, 1, 63, 79, 13, 4, 41, 79, 7, 8, 58, 79, 13, 18, 58, 79, 7, 19, 63, 25, 22, 20, 56, 91, 7, 30, 63, 25, 22, 33, 56, 91, 7, 23, 61, 25, 7, 53, 63, 91, 7, 36, 56, 25, 22, 28, 56, 91, 7, 0, 1, 96, 10, 57, 19, 96, 42, 7, 20, 58, 42, 13, 30, 58, 42, 13, 21, 40, 42, 27, 36, 96, 72, 26, 0, 4, 5, 72, 13, 14, 5, 10, 45, 19, 96, 35, 22, 20, 96, 35, 7, 30, 58, 70, 7, 21, 96, 34, 39, 36, 96, 35, 13, 28, 96, 35, 7, 0, 1, 40, 70, 13, 4, 96, 34, 27, 19, 9

[0, 1, 50, 135, 4, 5, 35, 44, 23, 9, 55, 43, 36, 9, 34, 43, 28, 15, 35, 43, 0, 4, 9, 34, 43, 8, 9, 32, 116, 36, 9, 35, 22, 28, 9, 55, 22, 0, 1, 9, 65, 60, 53, 15, 55, 22, 38, 9, 35, 43, 0, 4, 15, 34, 60, 36, 41, 17, 22, 36, 41, 32, 22, 28, 15, 6, 43, 28, 15, 34, 43, 0, 4, 41, 10, 43, 4, 41, 35, 43, 8, 41, 6, 22, 8, 41, 34, 22, 19, 58, 6, 22, 19, 58, 34, 22, 20, 41, 17, 22, 20, 41, 32, 22, 21, 15, 17, 46, 21, 15, 54, 46, 36, 5, 69, 46, 36, 5, 65, 46, 0, 4, 15, 29, 24, 4, 5, 55, 24, 20, 5, 31, 43, 20, 5, 65, 43, 21, 41, 29, 47, 21, 41, 55, 47, 0, 4, 5, 6, 60, 4, 5, 34, 60, 36, 41, 29, 22, 36, 41, 55, 22, 28, 58, 10, 43, 28, 58, 35, 43, 0, 1, 15, 6, 112, 1, 15, 34, 112]
[0, 1, 50, 135, 4, 5, 35, 44, 23, 9, 55, 43, 36, 9, 34, 43, 28, 9, 35, 43, 0, 4, 9, 34, 43, 8, 9, 32, 116, 36, 9, 35, 22, 28, 9, 55, 22, 0, 1, 9, 65, 60, 53, 15, 55, 22, 38, 9, 35, 43, 0, 4, 15, 34, 60, 36, 41, 17, 22, 36, 41, 32, 22, 28, 15, 6, 43, 28, 15, 34, 43, 0, 4, 41, 10, 43, 4, 41, 35, 43, 8, 41, 6, 22, 8, 41, 34, 

[0, 1, 2, 137, 14, 115, 31, 39, 14, 115, 65, 39, 49, 82, 31, 43, 49, 82, 65, 43, 30, 82, 79, 43, 30, 82, 71, 43, 33, 114, 31, 43, 33, 114, 65, 43, 53, 136, 42, 43, 53, 136, 75, 43, 38, 64, 34, 39, 38, 64, 128, 39, 0, 14, 82, 31, 39, 14, 82, 65, 39, 49, 95, 31, 43, 49, 95, 65, 43, 30, 114, 79, 43, 30, 114, 71, 43, 33, 136, 31, 43, 33, 136, 65, 43, 53, 82, 34, 43, 53, 82, 128, 43, 38, 82, 42, 39, 38, 82, 75, 39, 0, 14, 114, 31, 39, 14, 114, 65, 39, 49, 82, 31, 43, 49, 82, 65, 43, 30, 114, 79, 43, 30, 114, 71, 43, 33, 114, 31, 43, 33, 114, 65, 43, 53, 136, 42, 43, 53, 136, 75, 43, 38, 95, 34, 39, 38, 95, 128, 39, 0, 14, 136, 31, 39, 14, 136, 65, 39, 49, 82, 31, 43, 49, 82, 65, 43, 30, 114, 79, 43, 30, 114, 71, 43, 33, 115, 31, 43, 33, 115, 65, 43, 53, 114, 34, 43, 53, 114, 128, 43, 36, 64, 42, 39, 36, 64, 75, 39]
Epoch: 440 | Time: 0m 6s
	Train Loss: 0.335 | Train PPL:   1.397
	 Val. Loss: 5.059 |  Val. PPL: 157.490
=> Saving checkpoint
Epoch: 441 | Time: 0m 6s
	Train Loss: 0.318 | Train 

Epoch: 475 | Time: 0m 6s
	Train Loss: 0.296 | Train PPL:   1.344
	 Val. Loss: 5.215 |  Val. PPL: 184.024
Epoch: 476 | Time: 0m 6s
	Train Loss: 0.293 | Train PPL:   1.340
	 Val. Loss: 5.219 |  Val. PPL: 184.660
Epoch: 477 | Time: 0m 6s
	Train Loss: 0.296 | Train PPL:   1.344
	 Val. Loss: 5.209 |  Val. PPL: 182.985
Epoch: 478 | Time: 0m 6s
	Train Loss: 0.295 | Train PPL:   1.343
	 Val. Loss: 5.183 |  Val. PPL: 178.294
Epoch: 479 | Time: 0m 6s
	Train Loss: 0.285 | Train PPL:   1.330
	 Val. Loss: 5.291 |  Val. PPL: 198.456
[0, 1, 50, 105, 4, 58, 91, 45, 8, 40, 10, 43, 19, 96, 31, 57, 23, 63, 69, 7, 53, 5, 31, 22, 38, 56, 69, 7, 28, 41, 10, 26, 0, 19, 40, 59, 77, 23, 41, 91, 13, 53, 41, 91, 24, 0, 49, 41, 91, 22, 30, 5, 79, 13, 21, 96, 79, 39, 38, 63, 91, 22, 0, 1, 63, 79, 13, 4, 41, 79, 7, 8, 58, 79, 13, 18, 58, 79, 7, 19, 63, 25, 22, 20, 56, 91, 7, 30, 63, 25, 22, 33, 56, 91, 7, 23, 61, 25, 7, 53, 63, 91, 7, 36, 56, 25, 22, 28, 56, 91, 7, 0, 1, 96, 10, 57, 19, 96, 42, 7, 20, 58, 42, 13, 3

Epoch: 500 | Time: 0m 6s
	Train Loss: 0.277 | Train PPL:   1.319
	 Val. Loss: 5.381 |  Val. PPL: 217.236
=> Saving checkpoint
=> Saving checkpoint
| Test Loss: 5.552 | Test PPL: 257.804 |


In [None]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [18]:
state = torch.load(destination_folder + '/500_checkpoint.pt', map_location=device)
for key in state:
    print(key)
load_checkpoint(state, model, optimizer)

model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint


In [66]:
test_loss = evaluate(model, test_iter, criterion)
print(math.exp(test_loss))

257.7541290805818


In [39]:
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro = df_intro['intro'].values
test_solo = df_intro['solo'].values
test_outro = df_intro['outro'].values
test_data=[]
for i in range(len(test_intro)):
    temp_dict = {}
    temp_dict['intro'] = test_intro[i]
    temp_dict['solo'] = test_solo[i]
    temp_dict['outro'] = test_outro[i]
    test_data.append(temp_dict)
print(len(test_intro))

112


In [40]:
for i in range(0,len(test_intro)):
    intro = test_intro[i]
    solo = test_solo[i]
    outro = test_outro[i]
    #print(intro)
    list_intro = [int(x) for x in intro.split(' ')]
    list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence(model, intro, intro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    print(translated_sentence)
    utils.write_midi(list_intro, word2event, generated_outputs + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_solo, word2event, generated_outputs  + "/solo/" + "/solo" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, generated_outputs + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
    
        


[0, 1, 50, 51, 1, 58, 32, 45, 8, 58, 32, 13, 18, 58, 54, 7, 18, 41, 32, 7, 49, 63, 70, 45, 21, 58, 29, 7, 33, 63, 54, 7, 23, 41, 32, 7, 53, 63, 42, 27, 28, 56, 42, 13, 0, 1, 63, 54, 13, 1, 63, 42, 45, 8, 56, 42, 13, 18, 58, 54, 13, 18, 63, 42, 13, 19, 63, 70, 22, 21, 41, 54, 7, 33, 63, 70, 7, 23, 63, 52, 13, 53, 63, 55, 22, 38, 41, 55, 22, 0, 1, 15, 55, 13, 4, 67, 55, 7, 14, 56, 52, 7, 14, 63, 54, 7, 8, 37, 70, 22, 19, 63, 72, 45, 30, 63, 32, 27, 53, 63, 70, 7, 36, 9, 70, 13, 38, 63, 34, 11, 0, 14, 63, 54, 7, 8, 5, 32, 7, 18, 5, 70, 7, 19, 63, 72, 7, 49, 58, 52, 7, 20, 63, 54, 7, 30, 63, 32, 7, 21, 5, 32, 7, 33, 58, 54, 7, 23, 63, 34, 39]
0
[0, 1, 50, 188, 4, 37, 29, 24, 18, 37, 79, 43, 20, 61, 79, 43, 21, 37, 10, 43, 23, 37, 6, 43, 36, 37, 79, 43, 28, 37, 91, 43, 0, 4, 37, 25, 43, 8, 61, 6, 24, 23, 61, 6, 43, 36, 37, 25, 43, 28, 61, 88, 43, 28, 37, 79, 43, 0, 4, 61, 25, 45, 20, 37, 12, 43, 21, 37, 25, 86, 23, 37, 12, 7, 36, 61, 6, 43, 28, 37, 10, 7, 0, 4, 37, 25, 7, 4, 61, 6, 43, 8, 3

[0, 1, 50, 204, 1, 56, 79, 27, 8, 41, 69, 7, 18, 63, 42, 43, 49, 56, 72, 43, 30, 63, 69, 57, 38, 56, 79, 7, 28, 63, 98, 7, 0, 1, 56, 91, 24, 49, 63, 98, 43, 30, 63, 88, 83, 28, 37, 79, 13, 28, 37, 98, 13, 0, 1, 56, 91, 46, 8, 56, 98, 7, 18, 58, 79, 43, 49, 63, 98, 22, 30, 63, 42, 39, 53, 63, 54, 39, 0, 1, 63, 69, 77, 49, 58, 79, 45, 30, 15, 69, 43, 33, 63, 98, 57, 28, 82, 79, 13, 28, 37, 98, 13, 0, 1, 56, 91, 27, 8, 56, 98, 7, 18, 63, 79, 43, 49, 63, 98, 43, 30, 41, 69, 46, 53, 15, 79, 39, 0, 1, 58, 98, 27, 49, 63, 91, 22, 30, 63, 98, 77, 28, 61, 79, 13, 28, 37, 98, 13, 28, 37, 91, 13, 0, 1, 56, 25, 27, 8, 37, 91, 7, 18, 56, 98, 45, 49, 63, 79, 43, 30, 63, 25, 45, 33, 56, 91, 43, 53, 58, 88, 45, 28, 61, 25, 13, 28, 37, 88, 13]
7
[0, 1, 50, 204, 1, 56, 79, 27, 8, 41, 69, 7, 18, 63, 42, 43, 49, 56, 72, 43, 30, 63, 69, 57, 38, 56, 79, 7, 28, 63, 98, 7, 0, 1, 56, 91, 24, 49, 63, 98, 43, 30, 63, 88, 83, 28, 37, 79, 13, 28, 37, 98, 13, 0, 1, 56, 91, 46, 8, 56, 98, 7, 18, 58, 79, 43, 49, 63, 

[0, 1, 2, 137, 14, 41, 91, 22, 18, 2, 137, 18, 96, 31, 26, 30, 2, 137, 30, 37, 91, 43, 33, 61, 31, 39, 53, 2, 137, 38, 37, 91, 43, 0, 1, 2, 137, 1, 37, 31, 26, 18, 2, 137, 18, 56, 91, 43, 49, 63, 31, 46, 30, 2, 137, 33, 40, 91, 22, 53, 2, 137, 53, 9, 31, 57, 0, 1, 2, 137, 1, 41, 91, 76, 18, 2, 137, 30, 2, 137, 53, 2, 137]
17
[0, 1, 2, 123, 1, 15, 34, 43, 18, 9, 35, 43, 30, 40, 65, 7, 33, 40, 65, 7, 53, 40, 65, 7, 38, 9, 65, 7, 0, 1, 40, 65, 22, 8, 9, 71, 43, 49, 9, 34, 7, 30, 40, 55, 22, 38, 9, 55, 22, 0, 1, 9, 55, 27, 18, 9, 71, 43, 49, 9, 71, 13, 30, 40, 35, 22, 33, 5, 35, 7, 53, 5, 35, 7, 38, 9, 35, 22, 0, 1, 9, 35, 7, 14, 5, 71, 22, 18, 5, 34, 22, 49, 9, 71, 48, 0, 18, 15, 34, 22, 49, 5, 34, 7, 30, 9, 71, 22, 33, 15, 34, 7, 53, 9, 34, 22, 38, 5, 72, 7, 0, 1, 9, 34, 7, 18, 15, 34, 22, 49, 15, 34, 22, 30, 40, 65, 7, 33, 9, 55, 7, 38, 9, 55, 22, 0, 1, 15, 34, 43, 18, 9, 35, 43, 30, 40, 65, 7, 33, 40, 65, 7, 53, 40, 65, 7]
18
[0, 1, 50, 204, 8, 62, 6, 86, 8, 67, 59, 86, 36, 67, 6, 22, 

[0, 1, 50, 80, 1, 63, 54, 22, 14, 58, 29, 13, 8, 58, 59, 43, 49, 64, 54, 13, 49, 63, 34, 7, 30, 58, 34, 83, 0, 1, 41, 34, 7, 4, 37, 32, 7, 14, 63, 54, 7, 8, 41, 29, 13, 18, 63, 34, 43, 49, 63, 32, 83, 33, 63, 106, 22, 53, 58, 65, 46, 0, 1, 56, 59, 22, 1, 63, 52, 22, 14, 63, 6, 22, 14, 56, 34, 43, 18, 63, 99, 22, 49, 63, 17, 45, 23, 56, 17, 13, 53, 63, 6, 7, 36, 56, 17, 7, 38, 37, 31, 43, 0, 1, 41, 29, 22, 14, 63, 59, 22, 18, 58, 31, 7, 19, 58, 29, 13, 49, 63, 59, 45, 21, 37, 99, 7, 33, 37, 17, 7, 23, 56, 6, 13, 53, 58, 79, 7, 36, 56, 29, 7, 38, 63, 54, 7, 28, 41, 29, 13, 28, 63, 32, 43]
30
[0, 1, 50, 80, 1, 63, 54, 22, 14, 58, 29, 13, 8, 58, 59, 43, 49, 64, 54, 13, 49, 63, 34, 7, 30, 58, 34, 83, 0, 1, 41, 34, 7, 4, 37, 32, 7, 14, 63, 54, 7, 8, 41, 29, 13, 18, 63, 34, 43, 49, 63, 32, 83, 33, 63, 106, 22, 53, 58, 65, 46, 0, 1, 56, 59, 22, 1, 63, 52, 22, 14, 63, 6, 22, 14, 56, 34, 43, 18, 63, 99, 22, 49, 63, 17, 45, 23, 56, 17, 13, 53, 63, 6, 7, 36, 56, 17, 7, 38, 37, 31, 43, 0, 1, 41, 29

[0, 1, 50, 204, 8, 62, 6, 86, 8, 67, 59, 86, 36, 67, 6, 22, 36, 61, 59, 22, 28, 67, 6, 22, 28, 61, 59, 22, 0, 4, 62, 99, 43, 4, 61, 6, 43, 8, 67, 6, 22, 8, 61, 59, 22, 19, 95, 99, 47, 19, 67, 6, 47, 36, 62, 6, 22, 36, 67, 59, 22, 28, 67, 6, 22, 28, 61, 59, 22, 0, 4, 62, 99, 22, 4, 62, 6, 22, 8, 67, 6, 47, 8, 67, 69, 47, 23, 67, 6, 7, 23, 67, 59, 7, 36, 61, 6, 77, 36, 61, 59, 77, 0, 8, 95, 99, 84, 8, 62, 6, 84]
41
[0, 1, 50, 147, 1, 95, 59, 7, 4, 56, 98, 13, 14, 61, 17, 45, 19, 56, 98, 116, 28, 61, 98, 13, 0, 1, 82, 59, 13, 4, 58, 98, 13, 14, 37, 17, 45, 19, 56, 98, 116, 28, 56, 69, 13, 0, 1, 61, 54, 13, 4, 56, 69, 13, 14, 56, 59, 45, 19, 61, 69, 26, 38, 58, 52, 45, 0, 4, 56, 70, 13, 14, 67, 70, 122]
42
[0, 1, 2, 223, 18, 2, 223, 30, 2, 223, 30, 58, 17, 39, 30, 58, 32, 39, 53, 2, 223, 53, 63, 98, 39, 53, 63, 70, 39, 0, 1, 2, 223, 1, 63, 59, 39, 1, 63, 52, 39, 18, 2, 223, 18, 56, 29, 43, 18, 56, 55, 43, 49, 61, 98, 43, 49, 61, 70, 43, 30, 2, 223, 30, 63, 59, 84, 30, 63, 52, 84, 53, 2, 22

[0, 1, 2, 223, 1, 56, 17, 39, 18, 2, 223, 18, 61, 98, 46, 30, 2, 223, 30, 61, 59, 39, 53, 2, 223, 53, 56, 98, 43, 38, 37, 6, 43, 0, 1, 2, 223, 1, 61, 17, 27, 18, 2, 223, 18, 56, 98, 27, 30, 2, 223, 30, 37, 59, 60, 53, 2, 223, 0, 1, 2, 223, 18, 2, 223, 18, 61, 98, 22, 49, 67, 6, 43, 30, 2, 223, 30, 61, 17, 27, 53, 2, 223, 53, 95, 98, 13, 0, 1, 2, 223, 1, 67, 98, 27, 18, 2, 223, 18, 62, 6, 43, 30, 2, 223, 30, 61, 17, 46, 53, 2, 223, 53, 56, 17, 22, 38, 37, 98, 13, 0, 1, 2, 223, 1, 64, 17, 7, 1, 37, 98, 7, 4, 95, 6, 13, 14, 61, 6, 13, 14, 67, 98, 13, 8, 61, 59, 7, 18, 2, 223, 18, 61, 17, 43, 18, 61, 29, 13, 19, 61, 31, 7, 49, 61, 54, 13, 20, 61, 42, 7, 30, 2, 223, 30, 61, 17, 27, 30, 61, 32, 46, 53, 2, 223, 53, 63, 98, 27, 0, 1, 2, 223, 1, 37, 59, 45, 18, 2, 223, 18, 67, 98, 22, 49, 63, 17, 7, 30, 2, 223, 30, 2, 223, 53, 2, 223, 53, 2, 223, 53, 67, 98, 27, 0, 1, 37, 17, 46, 18, 2, 223, 18, 2, 223, 18, 2, 223, 18, 2, 223, 30, 2, 223, 30, 2, 223, 53, 37, 98, 26, 53, 2, 223, 53, 2, 223, 53, 

[0, 1, 50, 188, 4, 58, 88, 22, 8, 56, 88, 47, 23, 56, 25, 39, 28, 56, 91, 39, 0, 8, 56, 25, 39, 20, 37, 88, 43, 21, 56, 89, 43, 23, 56, 93, 39, 28, 37, 141, 22, 0, 4, 56, 93, 22, 8, 56, 16, 24, 21, 56, 88, 22, 23, 56, 89, 39, 28, 58, 93, 43, 0, 4, 56, 89, 43, 8, 37, 88, 47, 23, 56, 25, 26, 28, 56, 88, 43, 0, 4, 63, 25, 43, 8, 56, 91, 24, 21, 56, 98, 7, 33, 56, 10, 7, 23, 56, 98, 43, 36, 56, 91, 43, 28, 56, 25, 43, 0, 4, 56, 89, 43, 8, 56, 88, 60, 28, 37, 89, 43, 0, 4, 37, 88, 43, 8, 37, 16, 48, 21, 56, 88, 43, 23, 61, 89, 39, 28, 37, 111, 43, 0, 4, 37, 93, 22]
67
[0, 1, 50, 51, 1, 58, 32, 45, 8, 58, 32, 13, 18, 58, 54, 7, 18, 41, 32, 7, 49, 63, 70, 45, 21, 58, 29, 7, 33, 63, 54, 7, 23, 41, 32, 7, 53, 63, 42, 27, 28, 56, 42, 13, 0, 1, 63, 42, 13, 1, 63, 42, 45, 8, 56, 42, 13, 18, 58, 54, 13, 18, 63, 42, 13, 19, 63, 70, 22, 21, 41, 54, 7, 33, 63, 70, 7, 23, 63, 52, 13, 53, 63, 55, 22, 38, 41, 55, 22, 0, 1, 15, 55, 13, 4, 67, 55, 7, 14, 56, 52, 7, 14, 63, 54, 7, 8, 37, 70, 22, 19, 63, 72,

[0, 1, 50, 51, 1, 5, 79, 7, 4, 5, 69, 27, 18, 5, 69, 7, 19, 5, 79, 7, 49, 9, 10, 7, 20, 15, 10, 7, 30, 15, 79, 7, 21, 5, 10, 24, 28, 15, 79, 7, 0, 1, 41, 42, 7, 4, 41, 72, 27, 18, 58, 25, 7, 19, 5, 42, 7, 49, 58, 72, 7, 20, 5, 10, 7, 30, 58, 10, 7, 21, 9, 69, 24, 28, 15, 10, 7, 0, 1, 15, 79, 7, 4, 5, 69, 27, 18, 9, 69, 7, 19, 58, 42, 7, 49, 9, 69, 7, 20, 5, 79, 7, 30, 15, 10, 7, 21, 15, 10, 27, 53, 5, 25, 7, 36, 9, 42, 7, 38, 41, 69, 7, 28, 5, 79, 43, 0, 4, 5, 10, 27, 18, 15, 10, 7, 19, 5, 10, 7, 49, 56, 10, 7, 20, 5, 79, 43, 21, 5, 10, 11, 28, 15, 31, 7, 0, 1, 5, 42, 7, 4, 5, 72, 27, 18, 5, 42, 7, 19, 5, 42, 7, 49, 9, 72, 7, 20, 5, 42, 7, 30, 15, 31, 7, 21, 5, 31, 24, 28, 15, 72, 7, 0, 1, 41, 35, 7, 4, 41, 71, 27, 18, 15, 71, 7, 4, 15, 35, 7, 14, 63, 71, 7, 8, 5, 10, 7, 18, 15, 35, 7, 19, 5, 10, 7, 49, 15, 31, 7, 20, 5, 42, 7, 30, 5, 42, 7, 21, 5, 31, 7, 33, 5, 10, 7, 23, 41, 42, 7, 53, 5, 10, 7, 28, 5, 10, 7, 0, 1, 37, 16, 7, 4, 5, 10, 7, 4, 5, 6, 7, 8, 5, 79, 7, 18, 15, 6, 7, 19, 5,

[0, 1, 50, 158, 1, 37, 59, 7, 4, 62, 98, 43, 4, 61, 29, 22, 8, 37, 59, 43, 19, 62, 98, 7, 20, 62, 88, 48, 20, 62, 98, 48, 0, 4, 67, 99, 45, 4, 62, 6, 45, 18, 62, 98, 7, 19, 37, 59, 7, 20, 62, 98, 13, 20, 67, 54, 13, 30, 62, 98, 57, 30, 37, 29, 57, 36, 61, 12, 7, 36, 61, 29, 13, 38, 61, 85, 13, 38, 61, 59, 13, 28, 67, 85, 7, 28, 67, 59, 7, 0, 1, 61, 111, 13, 1, 61, 98, 13, 4, 67, 111, 47, 4, 67, 98, 46, 8, 62, 32, 7, 18, 61, 70, 7, 19, 61, 29, 22, 20, 67, 59, 7, 30, 61, 98, 13, 21, 67, 17, 22, 21, 82, 29, 22, 23, 67, 17, 27, 23, 62, 10, 27, 28, 61, 29, 7, 0, 1, 37, 69, 22, 4, 67, 59, 22, 8, 67, 99, 22, 8, 67, 59, 7, 18, 67, 98, 13, 19, 114, 12, 43, 20, 62, 17, 22, 21, 62, 99, 43, 23, 67, 98, 43, 36, 62, 25, 47, 21, 62, 98, 47, 0, 4, 62, 98, 39, 19, 37, 89, 13, 19, 37, 79, 13, 49, 67, 88, 13, 49, 67, 69, 13, 49, 67, 99, 13, 49, 67, 54, 13, 20, 67, 25, 13, 20, 67, 42, 13]
84
[0, 1, 50, 105, 4, 58, 91, 45, 8, 63, 79, 43, 19, 96, 69, 57, 23, 63, 69, 7, 53, 5, 79, 7, 36, 41, 69, 7, 38, 63, 9

[0, 1, 50, 174, 1, 96, 31, 22, 14, 63, 42, 22, 18, 96, 72, 46, 30, 9, 42, 43, 33, 41, 31, 43, 53, 63, 42, 39, 0, 18, 96, 72, 27, 30, 63, 42, 43, 33, 41, 31, 43, 53, 56, 79, 24, 0, 18, 5, 72, 46, 30, 96, 31, 39, 53, 58, 79, 46, 0, 1, 40, 10, 43, 14, 9, 31, 43, 18, 9, 79, 112]
92
[0, 1, 50, 80, 1, 63, 54, 22, 14, 58, 29, 13, 8, 58, 59, 43, 49, 64, 54, 13, 49, 63, 34, 7, 30, 58, 34, 83, 0, 1, 41, 34, 7, 4, 37, 32, 7, 14, 63, 54, 7, 8, 41, 29, 13, 18, 63, 34, 43, 49, 63, 32, 83, 33, 63, 106, 22, 53, 58, 65, 46, 0, 1, 56, 59, 22, 1, 63, 52, 22, 14, 63, 6, 22, 14, 56, 34, 43, 18, 63, 99, 22, 49, 63, 17, 45, 23, 56, 17, 13, 53, 63, 6, 7, 36, 56, 17, 7, 38, 37, 31, 43, 0, 1, 41, 29, 22, 14, 63, 59, 22, 18, 58, 31, 7, 19, 58, 29, 13, 49, 63, 59, 45, 21, 37, 99, 7, 33, 37, 17, 7, 23, 56, 6, 13, 53, 58, 79, 7, 36, 56, 29, 7, 38, 63, 54, 7, 28, 41, 29, 13, 28, 63, 32, 43]
93
[0, 1, 50, 155, 8, 95, 52, 46, 20, 95, 106, 45, 23, 95, 107, 27, 28, 95, 66, 87, 0, 20, 95, 74, 84, 0, 20, 95, 52, 39, 23, 9

[0, 1, 50, 51, 1, 5, 79, 7, 4, 5, 69, 27, 18, 5, 69, 7, 19, 5, 79, 7, 49, 9, 69, 7, 20, 15, 79, 7, 30, 15, 10, 7, 21, 5, 10, 24, 28, 15, 69, 7, 0, 1, 41, 42, 7, 4, 41, 72, 27, 18, 15, 72, 7, 19, 5, 42, 7, 49, 5, 42, 7, 20, 5, 42, 7, 30, 58, 10, 7, 21, 9, 69, 24, 28, 15, 10, 7, 0, 1, 15, 79, 7, 4, 5, 69, 27, 18, 9, 69, 7, 19, 5, 42, 7, 49, 9, 69, 7, 20, 5, 79, 7, 30, 9, 10, 7, 21, 15, 10, 27, 53, 5, 25, 7, 36, 9, 42, 7, 38, 41, 69, 7, 28, 5, 79, 43, 0, 4, 5, 10, 27, 18, 15, 10, 7, 19, 5, 98, 7, 49, 41, 10, 7, 20, 5, 79, 43, 21, 5, 10, 11, 28, 15, 31, 7, 0, 1, 5, 42, 7, 4, 5, 72, 27, 18, 5, 42, 7, 19, 5, 42, 7, 49, 9, 72, 7, 20, 5, 42, 7, 30, 15, 31, 7, 21, 5, 31, 24, 28, 15, 72, 7, 0, 1, 41, 35, 7, 4, 41, 71, 27, 18, 15, 71, 7, 4, 15, 35, 7, 14, 5, 71, 7, 8, 15, 65, 7, 18, 15, 71, 7, 19, 5, 10, 7, 49, 15, 31, 7, 20, 5, 10, 7, 30, 5, 42, 7, 21, 5, 35, 7, 33, 5, 10, 7, 23, 5, 35, 7, 53, 5, 10, 7, 28, 5, 10, 7, 0, 1, 15, 35, 7, 4, 5, 10, 7, 4, 5, 10, 7, 8, 5, 79, 7, 18, 15, 6, 7, 19, 5, 10

In [41]:
import mido
for i in range(11):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')
    total_intro_time = 0
    total_solo_time = 0
    total_predict_time = 0
    for msg in intro.tracks[1]:
        if msg.type == "note_on":
            total_intro_time += msg.time
    for msg in solo.tracks[1]:
        if msg.type == "note_on":
            total_solo_time += msg.time
    for msg in predict.tracks[1]:
        if msg.type == "note_on":
            total_predict_time += msg.time
            
    original_outro_time = 0 + outro.tracks[1][1].time
    
    print(original_outro_time + total_solo_time + total_intro_time)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_solo_time + total_intro_time
    print(outro.tracks[1][1].time)
    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')
    
    
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    
    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

2041
2041
1861
1861
4801
4801
33301
33301
12481
12481
12481
12481
13381
13381
13381
13381
2641
2641
7201
7201
9361
9361
4561
4561
11581
11581
1
1
26760
26760
15540
15540
34560
34560
25140
25140
13381
13381
13381
13381
22140
22140
10020
10020


In [None]:
class BeamSearchNode(object):
    def __init__(self, prev_node, wid, logp, length):
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    
    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]
    
    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    n_best_list = []
    
     
    #trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

    #first token as input
    trg_tensor = torch.LongTensor(outputs).to(device)
    
    end_nodes = []

    #starting node
    node = BeamSearchNode(prev_node=None, wid=trg_tensor, logp=0, length=1)

    nodes = []

    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0

    while True:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break
        
        # Fetch the best node
        #print([n[2].wid for n in nodes])
        score, _, n = heappop(nodes)
        decoder_input = n.wid
        
        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= beam_width:
                break
            else:
                continue
   
        sequence = [n.wid.item()]
        a = n
        while a.prev_node is not None:
            a = a.prev_node
            sequence.append(a.wid.item())
        sequence = sequence[::-1] # reverse
        
        #print(sequence)
        
        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))
        
        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(output, beam_width)
        #print(topk_indexes)
        #print(topk_log_prob)
        # Then, register new top-k nodes
        for new_k in range(beam_width):
            decoded_t = topk_indexes[0][0][new_k].view(1) # (1)
            logp = topk_log_prob[0][0][new_k].item() # float log probability val

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width
        #print(n_dec_steps)
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(beam_width)]

    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)


    # return n_best_seq_list

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]

    # remove start token
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    output = open(path, 'wb')
    pickle.dump(vocab, output)
    output.close()

In [None]:
vocab_folder = "vocab/"
save_vocab(intro_field.vocab, vocab + '/intro_vocab.pkl')
save_vocab(solo_field.vocab, vocab  + '/solo_vocab.pkl')
save_vocab(outro_field.vocab, vocab + '/outro_vocab.pkl')

In [None]:
def load_vocab(path):
    with open(path, 'rb') as f:
        x = pickle.load(f)
    return x

In [None]:
vocab_folder = "vocab_intro/"
intro_field.vocab = load_vocab(vocab_folder + 'intro_vocab.pkl')
solo_field.vocab = load_vocab(vocab_folder + 'solo_vocab.pkl')
outro_field.vocab = load_vocab(vocab_folder + 'outro_vocab.pkl')

In [None]:
vocab_folder = "vocab/"
with open()
intro_field.vocab = 

In [None]:
beam = [2, 59, 119, 13, 212, 59, 212, 59, 59, 75, 59, 59, 13, 119, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 119, 59, 59, 59, 59, 166, 59, 59, 59, 13, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 59, 59, 212, 212, 13, 59, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 14, 59, 59, 212, 59, 212, 212, 59, 59, 59, 68, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 68, 59, 59, 212, 59, 59, 13, 59, 59, 59, 59, 59, 59, 97, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 158, 59, 59, 13, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 13, 212, 13, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 158, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 13, 13, 59, 59, 13, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 13, 59, 166, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 86, 59, 212, 212, 212, 59, 59, 59, 59, 212, 59, 86, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 212, 59, 59, 166, 59, 86, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 59, 59, 59, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 212, 212, 158, 59, 59, 13, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 158, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 212, 212, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 13, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 212, 86, 59, 59, 59, 158, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 212, 59, 13, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 68, 59, 13, 59, 59, 13, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 14, 59, 59, 59, 59, 59, 59, 13, 86, 59, 59, 59, 212, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 13, 59, 59, 59, 59, 68, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 212, 212, 13, 59, 166, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 212, 59, 59, 212, 59, 212, 59, 59, 13, 59, 59, 59, 59, 13, 59, 59, 14, 13, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 68, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 212, 59, 59, 13, 59, 59, 166, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 13, 59, 59, 212, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59]

translated_sentence1 = [solo_field.vocab.itos[idx] for idx in beam]
translated_sentence = [int(x) for x in translated_sentence1 if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']    
utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict1.mid")

In [None]:
translated_sentence

In [None]:
for i in range(len(test_intro)):
    if len(test_intro) > 1200:
        continue
    list_sentence = [int(x) for x in sentence.split(' ')]
    remi = [word2event[x] for x in list_sentence]
    print(remi)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    #tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    #tokens.insert(0, german.init_token)
    #tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    #text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(sentence).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        #print( vars(example))
        src = vars(example)["intro"]
        trg = vars(example)["solo"]
        
        src = [int(x) for x in src]
        trg = [int(x) for x in trg]
        
        if len(trg) > 1200 or len(src) > 1200:
            continue
        
        prediction = bleu_translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append(trg)
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while
score = bleu(test[1:10], model, intro_field, solo_field, device)
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show() 

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns