In [6]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
source_folder = "solo_generation_dataset_fixed_augmented"
folder = "dynamic_fixed_augmented_models/outro_2nd"
destination_folder = folder + "/solo_generation_weights"
generated_outputs = folder +  "/generated_samples"
dissimilar_interpolation = folder + "/interpolation"
vocab = folder + "/vocab"

In [7]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

state = pickle.load(open('./state.pkl', 'rb'))
random.setstate(state)

In [8]:
from pathlib import Path
Path(destination_folder).mkdir(parents=True, exist_ok=True)
Path(generated_outputs).mkdir(parents=True, exist_ok=True)
Path(vocab).mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/intro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/outro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/predict").mkdir(parents=True, exist_ok=True)

In [9]:
event2word, word2event = pickle.load(open('dictionary_fixed_augmented.pkl', 'rb'))

In [10]:
if torch.cuda.is_available():  
    dev = "cuda:1" 
else:  
    dev = "cpu" 
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [11]:
# Fields

intro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
intro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
fields = [('intro', intro_field), ('intro_piano', intro_piano_field), \
          ('outro', outro_field), ('outro_piano', outro_piano_field), \
          ('solo', solo_field), ('solo_piano', solo_piano_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv', validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators
BATCH_SIZE = 8
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.outro),
                            device=device, sort=False, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                            device=device, sort=False, sort_within_batch=True)

# Vocabulary

intro_field.build_vocab(train, min_freq=1)
intro_piano_field.build_vocab(train, min_freq=1)
outro_field.build_vocab(train, min_freq=1)
outro_piano_field.build_vocab(train, min_freq=1)
solo_field.build_vocab(train, min_freq=1)
solo_piano_field.build_vocab(train, min_freq=1)

In [12]:
for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (test_iter):
    print(solo.transpose(1,0).size())

torch.Size([269, 8])
torch.Size([331, 8])
torch.Size([353, 8])
torch.Size([522, 8])
torch.Size([281, 8])
torch.Size([444, 8])
torch.Size([281, 8])
torch.Size([473, 8])
torch.Size([619, 8])
torch.Size([365, 8])
torch.Size([509, 8])
torch.Size([353, 8])
torch.Size([433, 8])
torch.Size([491, 8])


In [13]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [14]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

In [15]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [16]:
src_vocab_size = len(outro_field.vocab)
trg_vocab_size = len(solo_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 1200
forward_expansion = 4
src_pad_idx = 1 #english.vocab.stoi["<pad>"]

model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
model = model.to(device)

In [17]:
def init_weights(m: nn.Module):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

optimizer = optim.Adam(model.parameters(), lr=2e-4)#4e-5


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
#     torch.save(state, destination_folder + str(nth)+filename)
    torch.save(state, destination_folder + '/metrics.pt')

def save_final_checkpoint(state, nth,filename="_checkpoint.pt"):
    print("=> Saving checkpoint")
    torch.save(state, destination_folder + "/" + str(nth)+filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

The model has 11,088,655 trainable parameters


In [18]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [19]:
PAD_IDX = 1

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [20]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0

    #for _, (src, _,trg,_) in enumerate(iterator):
    for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
        src, trg = outro.transpose(1,0), solo.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:-1, :])
        
#         print(output.size())
#         print(trg.size())
        
        output = output.view(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        #for _, (src, _,trg,_) in enumerate(iterator):
        for ((intro, intro_len), (intro_piano, intro_piano_len),\
         (outro, outro_len),(outro_piano, outro_piano_len),\
         (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
            src, trg = outro.transpose(1,0), solo.transpose(1,0)
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg[:-1, :]) #turn off teacher forcing

            output = output.view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs



In [21]:
def translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break
    # print(outputs)
    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [22]:
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro = df_intro['intro'].values
val_solo = df_intro['solo'].values
val_outro = df_intro['outro'].values
val_data=[]
for i in range(len(val_intro)):
    temp_dict = {}
    temp_dict['intro'] = val_intro[i]
    temp_dict['solo'] = val_solo[i]
    temp_dict['outro'] = val_outro[i]
    val_data.append(temp_dict)
print(len(val_intro))

112


In [23]:
def check_mode_collapse(model):
    count = 0
    translations = []
    for i in range(3):
        if len(val_intro) > 1200:
            continue
        intro = val_intro[i]
        solo = val_solo[i]
        outro = val_outro[i]
        #print(intro)
        list_intro = [int(x) for x in intro.split(' ')]
        list_solo = [int(x) for x in solo.split(' ')]
        list_outro = [int(x) for x in outro.split(' ')]
        translated_sentence = translate_sentence(model, outro, outro_field, solo_field, device, max_length=1200)
        
        translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
        print(translated_sentence)
        translations.append(translated_sentence)
        if i > 0:
            if translations[i-1] == translations[i]:
                count += 1
    return count


In [24]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(S_EPOCH, N_EPOCHS):
    
    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
        save_best_checkpoint(checkpoint,N_EPOCHS)
    if (epoch+1) % 20 == 0 or (epoch) % 20 == 0:
        save_final_checkpoint(checkpoint,epoch)
    if (epoch+2) % 20 ==0:
        if check_mode_collapse(model) > 1:
            print("model is mode collapsing")
save_final_checkpoint(checkpoint,N_EPOCHS)
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 0m 38s
	Train Loss: 4.216 | Train PPL:  67.736
	 Val. Loss: 3.284 |  Val. PPL:  26.685
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 0m 38s
	Train Loss: 2.934 | Train PPL:  18.797
	 Val. Loss: 2.767 |  Val. PPL:  15.908
=> Saving checkpoint
Epoch: 03 | Time: 0m 38s
	Train Loss: 2.569 | Train PPL:  13.049
	 Val. Loss: 2.463 |  Val. PPL:  11.746
=> Saving checkpoint
Epoch: 04 | Time: 0m 39s
	Train Loss: 2.336 | Train PPL:  10.345
	 Val. Loss: 2.329 |  Val. PPL:  10.269
=> Saving checkpoint
Epoch: 05 | Time: 0m 39s
	Train Loss: 2.191 | Train PPL:   8.946
	 Val. Loss: 2.234 |  Val. PPL:   9.341
=> Saving checkpoint
Epoch: 06 | Time: 0m 38s
	Train Loss: 2.072 | Train PPL:   7.941
	 Val. Loss: 2.147 |  Val. PPL:   8.562
=> Saving checkpoint
Epoch: 07 | Time: 0m 39s
	Train Loss: 1.969 | Train PPL:   7.163
	 Val. Loss: 2.081 |  Val. PPL:   8.015
=> Saving checkpoint
Epoch: 08 | Time: 0m 39s
	Train Loss: 1.886 | Train PPL:   6.594
	 Val. Loss: 2.048 |  Val. PPL: 

Epoch: 25 | Time: 0m 39s
	Train Loss: 1.280 | Train PPL:   3.598
	 Val. Loss: 2.496 |  Val. PPL:  12.134
Epoch: 26 | Time: 0m 39s
	Train Loss: 1.253 | Train PPL:   3.500
	 Val. Loss: 2.526 |  Val. PPL:  12.500
Epoch: 27 | Time: 0m 39s
	Train Loss: 1.230 | Train PPL:   3.421
	 Val. Loss: 2.555 |  Val. PPL:  12.870
Epoch: 28 | Time: 0m 39s
	Train Loss: 1.204 | Train PPL:   3.334
	 Val. Loss: 2.635 |  Val. PPL:  13.943
Epoch: 29 | Time: 0m 39s
	Train Loss: 1.181 | Train PPL:   3.258
	 Val. Loss: 2.686 |  Val. PPL:  14.679
Epoch: 30 | Time: 0m 39s
	Train Loss: 1.156 | Train PPL:   3.177
	 Val. Loss: 2.714 |  Val. PPL:  15.086
Epoch: 31 | Time: 0m 39s
	Train Loss: 1.131 | Train PPL:   3.099
	 Val. Loss: 2.740 |  Val. PPL:  15.486
Epoch: 32 | Time: 0m 39s
	Train Loss: 1.108 | Train PPL:   3.029
	 Val. Loss: 2.774 |  Val. PPL:  16.028
Epoch: 33 | Time: 0m 39s
	Train Loss: 1.090 | Train PPL:   2.975
	 Val. Loss: 2.848 |  Val. PPL:  17.261
Epoch: 34 | Time: 0m 39s
	Train Loss: 1.066 | Train PPL

Epoch: 60 | Time: 0m 39s
	Train Loss: 0.689 | Train PPL:   1.992
	 Val. Loss: 3.906 |  Val. PPL:  49.698
=> Saving checkpoint
Epoch: 61 | Time: 0m 39s
	Train Loss: 0.680 | Train PPL:   1.974
	 Val. Loss: 3.992 |  Val. PPL:  54.156
=> Saving checkpoint
Epoch: 62 | Time: 0m 39s
	Train Loss: 0.676 | Train PPL:   1.966
	 Val. Loss: 3.935 |  Val. PPL:  51.183
Epoch: 63 | Time: 0m 39s
	Train Loss: 0.665 | Train PPL:   1.945
	 Val. Loss: 4.046 |  Val. PPL:  57.170
Epoch: 64 | Time: 0m 39s
	Train Loss: 0.657 | Train PPL:   1.929
	 Val. Loss: 4.056 |  Val. PPL:  57.765
Epoch: 65 | Time: 0m 39s
	Train Loss: 0.650 | Train PPL:   1.915
	 Val. Loss: 4.071 |  Val. PPL:  58.637
Epoch: 66 | Time: 0m 39s
	Train Loss: 0.642 | Train PPL:   1.900
	 Val. Loss: 4.077 |  Val. PPL:  58.979
Epoch: 67 | Time: 0m 39s
	Train Loss: 0.635 | Train PPL:   1.887
	 Val. Loss: 4.142 |  Val. PPL:  62.921
Epoch: 68 | Time: 0m 39s
	Train Loss: 0.626 | Train PPL:   1.870
	 Val. Loss: 4.144 |  Val. PPL:  63.055
Epoch: 69 | T

[0, 1, 2, 98, 54, 35, 77, 66, 11, 40, 77, 27, 36, 35, 44, 111, 0, 11, 35, 77, 10, 15, 29, 77, 14, 17, 35, 77, 111, 0, 11, 19, 77, 7, 52, 35, 77, 7, 36, 40, 84, 42, 0, 48, 40, 44, 16, 11, 40, 84, 7, 52, 35, 44, 7, 36, 35, 44, 42, 0, 11, 35, 77, 7, 52, 29, 77, 16, 36, 35, 44, 42, 0, 11, 19, 67, 7, 52, 35, 77, 7, 36, 9, 65, 121, 0, 11, 35, 77, 7, 52, 29, 44, 16, 36, 35, 84, 16, 50, 40, 93, 115, 0, 11, 35, 77, 7, 15, 40, 77, 7]
[0, 1, 58, 198, 1, 12, 22, 66, 8, 58, 198, 8, 73, 22, 27, 17, 58, 198, 17, 73, 22, 66, 23, 58, 198, 23, 12, 22, 7, 33, 81, 22, 7, 0, 1, 58, 198, 1, 73, 22, 24, 8, 58, 198, 8, 12, 22, 24, 17, 58, 198, 17, 81, 22, 42, 23, 58, 198, 0, 1, 58, 198, 8, 58, 198, 8, 73, 22, 16, 15, 70, 22, 7, 17, 58, 198, 17, 73, 22, 24, 23, 58, 198, 23, 75, 22, 14, 0, 1, 58, 198, 1, 70, 22, 24, 8, 58, 198, 8, 72, 22, 7, 17, 58, 198, 17, 73, 22, 27, 23, 58, 198, 23, 12, 22, 16, 33, 81, 22, 14, 0, 1, 58, 198, 1, 69, 22, 10, 1, 81, 22, 10, 54, 75, 22, 14, 4, 73, 22, 14, 4, 70, 22, 14, 48, 73,

Epoch: 130 | Time: 0m 40s
	Train Loss: 0.365 | Train PPL:   1.441
	 Val. Loss: 5.256 |  Val. PPL: 191.640
Epoch: 131 | Time: 0m 39s
	Train Loss: 0.361 | Train PPL:   1.435
	 Val. Loss: 5.295 |  Val. PPL: 199.282
Epoch: 132 | Time: 0m 39s
	Train Loss: 0.360 | Train PPL:   1.433
	 Val. Loss: 5.288 |  Val. PPL: 197.943
Epoch: 133 | Time: 0m 39s
	Train Loss: 0.357 | Train PPL:   1.429
	 Val. Loss: 5.315 |  Val. PPL: 203.429
Epoch: 134 | Time: 0m 39s
	Train Loss: 0.353 | Train PPL:   1.423
	 Val. Loss: 5.339 |  Val. PPL: 208.212
Epoch: 135 | Time: 0m 39s
	Train Loss: 0.351 | Train PPL:   1.420
	 Val. Loss: 5.358 |  Val. PPL: 212.331
Epoch: 136 | Time: 0m 39s
	Train Loss: 0.346 | Train PPL:   1.413
	 Val. Loss: 5.353 |  Val. PPL: 211.340
Epoch: 137 | Time: 0m 40s
	Train Loss: 0.346 | Train PPL:   1.414
	 Val. Loss: 5.346 |  Val. PPL: 209.780
Epoch: 138 | Time: 0m 40s
	Train Loss: 0.341 | Train PPL:   1.406
	 Val. Loss: 5.461 |  Val. PPL: 235.265
Epoch: 139 | Time: 0m 39s
	Train Loss: 0.337 |

Epoch: 166 | Time: 0m 39s
	Train Loss: 0.263 | Train PPL:   1.301
	 Val. Loss: 5.727 |  Val. PPL: 306.946
Epoch: 167 | Time: 0m 40s
	Train Loss: 0.261 | Train PPL:   1.299
	 Val. Loss: 5.679 |  Val. PPL: 292.652
Epoch: 168 | Time: 0m 39s
	Train Loss: 0.255 | Train PPL:   1.291
	 Val. Loss: 5.821 |  Val. PPL: 337.385
Epoch: 169 | Time: 0m 39s
	Train Loss: 0.253 | Train PPL:   1.288
	 Val. Loss: 5.894 |  Val. PPL: 362.816
Epoch: 170 | Time: 0m 39s
	Train Loss: 0.251 | Train PPL:   1.285
	 Val. Loss: 5.858 |  Val. PPL: 350.171
Epoch: 171 | Time: 0m 40s
	Train Loss: 0.249 | Train PPL:   1.283
	 Val. Loss: 5.800 |  Val. PPL: 330.140
Epoch: 172 | Time: 0m 39s
	Train Loss: 0.248 | Train PPL:   1.281
	 Val. Loss: 5.784 |  Val. PPL: 325.037
Epoch: 173 | Time: 0m 39s
	Train Loss: 0.244 | Train PPL:   1.276
	 Val. Loss: 5.780 |  Val. PPL: 323.727
Epoch: 174 | Time: 0m 39s
	Train Loss: 0.243 | Train PPL:   1.275
	 Val. Loss: 5.859 |  Val. PPL: 350.232
Epoch: 175 | Time: 0m 39s
	Train Loss: 0.237 |

Epoch: 209 | Time: 0m 40s
	Train Loss: 0.169 | Train PPL:   1.185
	 Val. Loss: 6.262 |  Val. PPL: 524.269
Epoch: 210 | Time: 0m 40s
	Train Loss: 0.168 | Train PPL:   1.183
	 Val. Loss: 6.261 |  Val. PPL: 523.900
Epoch: 211 | Time: 0m 39s
	Train Loss: 0.165 | Train PPL:   1.179
	 Val. Loss: 6.309 |  Val. PPL: 549.253
Epoch: 212 | Time: 0m 39s
	Train Loss: 0.164 | Train PPL:   1.179
	 Val. Loss: 6.343 |  Val. PPL: 568.716
Epoch: 213 | Time: 0m 39s
	Train Loss: 0.164 | Train PPL:   1.178
	 Val. Loss: 6.368 |  Val. PPL: 583.082
Epoch: 214 | Time: 0m 39s
	Train Loss: 0.161 | Train PPL:   1.174
	 Val. Loss: 6.336 |  Val. PPL: 564.451
Epoch: 215 | Time: 0m 39s
	Train Loss: 0.162 | Train PPL:   1.176
	 Val. Loss: 6.243 |  Val. PPL: 514.230
Epoch: 216 | Time: 0m 40s
	Train Loss: 0.158 | Train PPL:   1.171
	 Val. Loss: 6.413 |  Val. PPL: 610.021
Epoch: 217 | Time: 0m 40s
	Train Loss: 0.156 | Train PPL:   1.169
	 Val. Loss: 6.405 |  Val. PPL: 604.866
Epoch: 218 | Time: 0m 39s
	Train Loss: 0.157 |

Epoch: 250 | Time: 0m 40s
	Train Loss: 0.119 | Train PPL:   1.126
	 Val. Loss: 6.598 |  Val. PPL: 733.594
Epoch: 251 | Time: 0m 39s
	Train Loss: 0.118 | Train PPL:   1.125
	 Val. Loss: 6.626 |  Val. PPL: 754.468
Epoch: 252 | Time: 0m 39s
	Train Loss: 0.116 | Train PPL:   1.123
	 Val. Loss: 6.710 |  Val. PPL: 820.562
Epoch: 253 | Time: 0m 39s
	Train Loss: 0.115 | Train PPL:   1.122
	 Val. Loss: 6.712 |  Val. PPL: 822.333
Epoch: 254 | Time: 0m 39s
	Train Loss: 0.115 | Train PPL:   1.122
	 Val. Loss: 6.695 |  Val. PPL: 808.579
Epoch: 255 | Time: 0m 39s
	Train Loss: 0.114 | Train PPL:   1.120
	 Val. Loss: 6.701 |  Val. PPL: 813.329
Epoch: 256 | Time: 0m 40s
	Train Loss: 0.114 | Train PPL:   1.120
	 Val. Loss: 6.581 |  Val. PPL: 721.450
Epoch: 257 | Time: 0m 39s
	Train Loss: 0.112 | Train PPL:   1.119
	 Val. Loss: 6.755 |  Val. PPL: 858.009
Epoch: 258 | Time: 0m 39s
	Train Loss: 0.115 | Train PPL:   1.122
	 Val. Loss: 6.706 |  Val. PPL: 817.252
Epoch: 259 | Time: 0m 39s
	Train Loss: 0.112 |

[0, 1, 2, 122, 1, 12, 30, 21, 48, 12, 110, 21, 8, 2, 122, 17, 2, 122, 50, 9, 13, 121, 50, 9, 110, 121, 23, 2, 122, 0, 1, 2, 122, 8, 2, 122, 11, 9, 6, 7, 11, 9, 67, 7, 52, 9, 20, 7, 52, 9, 67, 7, 17, 2, 122, 36, 26, 6, 66, 36, 26, 88, 66, 23, 2, 122, 32, 81, 67, 24, 32, 81, 93, 24, 0, 1, 2, 122, 1, 81, 39, 24, 1, 81, 86, 24, 48, 81, 67, 114, 48, 81, 93, 114, 8, 2, 122, 17, 2, 122, 23, 2, 122, 23, 73, 13, 10, 32, 12, 13, 10, 33, 12, 110, 10, 33, 81, 6, 10, 53, 81, 110, 10, 0, 1, 2, 122, 1, 12, 110, 28, 54, 12, 6, 10, 4, 81, 88, 21, 54, 5, 110, 10, 48, 26, 6, 16, 48, 73, 88, 94, 18, 73, 30, 10, 50, 73, 34, 10, 23, 2, 122, 23, 2, 122, 23, 2, 122, 23, 2, 122, 33, 5, 43, 16, 33, 5, 110, 16, 0, 54, 9, 203, 10, 48, 12, 110, 10, 48, 5, 86, 10]
Epoch: 280 | Time: 0m 39s
	Train Loss: 0.095 | Train PPL:   1.100
	 Val. Loss: 6.844 |  Val. PPL: 938.494
=> Saving checkpoint
Epoch: 281 | Time: 0m 39s
	Train Loss: 0.096 | Train PPL:   1.100
	 Val. Loss: 6.848 |  Val. PPL: 941.660
=> Saving checkpoint
E

[0, 1, 2, 122, 1, 12, 39, 85, 15, 73, 41, 14, 52, 5, 77, 14, 17, 40, 77, 104, 33, 69, 88, 14, 53, 76, 77, 14, 0, 1, 5, 34, 104, 15, 72, 77, 14, 52, 79, 34, 14, 17, 70, 30, 28, 0, 1, 72, 30, 85, 15, 75, 34, 14, 52, 70, 77, 14, 17, 60, 77, 45, 50, 79, 34, 14, 23, 75, 30, 7, 33, 75, 67, 16, 0, 1, 73, 34, 85, 15, 69, 30, 14, 52, 75, 34, 10, 17, 79, 30, 14, 17, 79, 39, 85, 0, 1, 79, 13, 16, 4, 69, 34, 27, 15, 75, 43, 14, 52, 81, 34, 14, 17, 26, 30, 28, 0, 1, 75, 20, 7, 4, 73, 34, 27, 15, 70, 6, 10, 52, 69, 22, 10, 17, 72, 25, 66, 0, 1, 60, 88, 31, 48, 72, 77, 14, 8, 60, 88, 16, 15, 72, 110, 10, 17, 19, 88, 85, 33, 81, 110, 7, 0, 1, 81, 67, 31, 48, 79, 30, 14, 8, 26, 77, 16, 15, 70, 34, 16, 17, 70, 30, 66]
Epoch: 320 | Time: 0m 39s
	Train Loss: 0.076 | Train PPL:   1.079
	 Val. Loss: 7.046 |  Val. PPL: 1147.775
=> Saving checkpoint
Epoch: 321 | Time: 0m 39s
	Train Loss: 0.075 | Train PPL:   1.078
	 Val. Loss: 7.073 |  Val. PPL: 1179.327
=> Saving checkpoint
Epoch: 322 | Time: 0m 39s
	Train L

Epoch: 359 | Time: 0m 40s
	Train Loss: 0.061 | Train PPL:   1.062
	 Val. Loss: 7.295 |  Val. PPL: 1473.200
[0, 1, 2, 98, 54, 9, 64, 31, 48, 40, 46, 7, 11, 35, 39, 85, 50, 5, 67, 10, 23, 19, 39, 16, 33, 12, 67, 10, 53, 26, 46, 82, 0, 11, 40, 6, 51, 50, 26, 61, 14, 23, 26, 25, 45, 0, 15, 26, 64, 16, 17, 19, 65, 14, 36, 35, 65, 66, 33, 5, 6, 16, 0, 1, 5, 65, 14, 54, 26, 65, 10, 48, 9, 65, 14, 8, 9, 65, 10, 11, 5, 64, 16, 52, 12, 61, 10, 17, 5, 64, 16, 18, 12, 61, 10, 50, 73, 64, 10, 23, 5, 61, 10, 32, 12, 64, 16, 53, 12, 61, 10, 0, 1, 35, 46, 85, 11, 35, 71, 10, 52, 9, 71, 14, 17, 9, 71, 14, 36, 40, 71, 24, 32, 35, 77, 82, 0, 54, 19, 77, 14, 4, 19, 46, 31, 11, 35, 44, 16, 52, 35, 44, 10, 17, 9, 88, 10, 36, 35, 43, 66, 32, 35, 44, 14, 53, 35, 44, 10, 0, 1, 40, 88, 14, 54, 35, 43, 24, 11, 35, 44, 51, 50, 29, 71, 115]
[0, 1, 58, 166, 8, 58, 166, 17, 58, 166, 17, 9, 13, 24, 23, 58, 166, 23, 9, 13, 16, 33, 12, 34, 16, 0, 1, 58, 166, 1, 12, 77, 10, 4, 73, 88, 10, 8, 58, 166, 8, 73, 77, 14, 11, 

Epoch: 380 | Time: 0m 40s
	Train Loss: 0.057 | Train PPL:   1.059
	 Val. Loss: 7.304 |  Val. PPL: 1486.520
=> Saving checkpoint
Epoch: 381 | Time: 0m 40s
	Train Loss: 0.055 | Train PPL:   1.057
	 Val. Loss: 7.371 |  Val. PPL: 1589.894
=> Saving checkpoint
Epoch: 382 | Time: 0m 39s
	Train Loss: 0.056 | Train PPL:   1.057
	 Val. Loss: 7.394 |  Val. PPL: 1625.443
Epoch: 383 | Time: 0m 40s
	Train Loss: 0.056 | Train PPL:   1.057
	 Val. Loss: 7.393 |  Val. PPL: 1624.793
Epoch: 384 | Time: 0m 39s
	Train Loss: 0.056 | Train PPL:   1.057
	 Val. Loss: 7.452 |  Val. PPL: 1723.869
Epoch: 385 | Time: 0m 40s
	Train Loss: 0.055 | Train PPL:   1.057
	 Val. Loss: 7.439 |  Val. PPL: 1701.305
Epoch: 386 | Time: 0m 40s
	Train Loss: 0.054 | Train PPL:   1.056
	 Val. Loss: 7.347 |  Val. PPL: 1551.147
Epoch: 387 | Time: 0m 39s
	Train Loss: 0.055 | Train PPL:   1.056
	 Val. Loss: 7.350 |  Val. PPL: 1556.565
Epoch: 388 | Time: 0m 39s
	Train Loss: 0.055 | Train PPL:   1.056
	 Val. Loss: 7.390 |  Val. PPL: 1619

[0, 1, 58, 166, 8, 58, 166, 17, 58, 166, 17, 9, 13, 24, 23, 58, 166, 23, 9, 13, 16, 33, 12, 30, 16, 0, 1, 58, 166, 1, 12, 77, 10, 4, 73, 88, 10, 8, 58, 166, 8, 73, 77, 14, 11, 73, 88, 14, 15, 70, 77, 14, 15, 12, 30, 16, 17, 58, 166, 17, 9, 67, 7, 18, 12, 30, 10, 23, 58, 166, 23, 12, 77, 10, 33, 9, 88, 10, 0, 1, 58, 166, 1, 12, 110, 27, 8, 58, 166, 17, 58, 166, 17, 12, 84, 10, 18, 12, 110, 10, 23, 58, 166, 23, 81, 88, 10, 33, 73, 110, 10, 0, 1, 58, 166, 1, 9, 84, 27, 8, 58, 166, 8, 73, 110, 14, 11, 73, 84, 14, 15, 70, 110, 14, 17, 58, 166, 17, 9, 77, 16, 18, 12, 34, 10, 23, 58, 166, 23, 9, 30, 16, 33, 12, 77, 10, 0, 1, 58, 166, 1, 9, 34, 7, 4, 12, 13, 10, 8, 58, 166, 8, 12, 13, 24, 17, 58, 166, 17, 73, 30, 24, 17, 9, 77, 24, 23, 58, 166, 23, 81, 30, 10, 23, 12, 77, 10, 33, 70, 34, 10, 33, 73, 88, 10, 0, 1, 58, 166, 1, 70, 77, 7, 1, 70, 110, 7, 4, 9, 88, 10, 4, 9, 84, 10, 8, 58, 166, 8, 73, 77, 14, 8, 73, 110, 14, 11, 73, 88, 14, 11, 73, 84, 14, 15, 70, 77, 14, 15, 70, 110, 14, 15, 9, 30

Epoch: 442 | Time: 0m 39s
	Train Loss: 0.044 | Train PPL:   1.045
	 Val. Loss: 7.505 |  Val. PPL: 1817.010
Epoch: 443 | Time: 0m 40s
	Train Loss: 0.043 | Train PPL:   1.044
	 Val. Loss: 7.524 |  Val. PPL: 1852.460
Epoch: 444 | Time: 0m 39s
	Train Loss: 0.043 | Train PPL:   1.044
	 Val. Loss: 7.589 |  Val. PPL: 1975.509
Epoch: 445 | Time: 0m 39s
	Train Loss: 0.044 | Train PPL:   1.045
	 Val. Loss: 7.648 |  Val. PPL: 2096.537
Epoch: 446 | Time: 0m 39s
	Train Loss: 0.044 | Train PPL:   1.044
	 Val. Loss: 7.537 |  Val. PPL: 1876.922
Epoch: 447 | Time: 0m 39s
	Train Loss: 0.043 | Train PPL:   1.044
	 Val. Loss: 7.531 |  Val. PPL: 1864.478
Epoch: 448 | Time: 0m 39s
	Train Loss: 0.043 | Train PPL:   1.044
	 Val. Loss: 7.582 |  Val. PPL: 1961.905
Epoch: 449 | Time: 0m 39s
	Train Loss: 0.043 | Train PPL:   1.044
	 Val. Loss: 7.601 |  Val. PPL: 2000.026
Epoch: 450 | Time: 0m 40s
	Train Loss: 0.043 | Train PPL:   1.043
	 Val. Loss: 7.643 |  Val. PPL: 2086.664
Epoch: 451 | Time: 0m 40s
	Train Loss

[0, 1, 58, 166, 8, 58, 166, 17, 58, 166, 17, 9, 30, 24, 23, 58, 166, 23, 9, 30, 16, 33, 12, 34, 16, 0, 1, 58, 166, 1, 12, 77, 10, 4, 73, 88, 10, 8, 58, 166, 8, 73, 77, 14, 11, 73, 88, 14, 15, 70, 77, 14, 15, 12, 34, 16, 17, 58, 166, 17, 9, 30, 7, 18, 12, 34, 10, 23, 58, 166, 23, 12, 77, 10, 33, 9, 88, 10, 0, 1, 58, 166, 1, 12, 110, 27, 8, 58, 166, 17, 58, 166, 17, 12, 84, 10, 18, 12, 110, 10, 23, 58, 166, 23, 81, 88, 10, 33, 73, 110, 10, 0, 1, 58, 166, 1, 9, 84, 27, 8, 58, 166, 8, 73, 110, 14, 11, 73, 84, 14, 15, 70, 110, 14, 17, 58, 166, 17, 9, 77, 16, 18, 12, 34, 10, 23, 58, 166, 23, 9, 30, 16, 33, 12, 77, 10, 0, 1, 58, 166, 1, 9, 34, 7, 4, 12, 13, 10, 8, 58, 166, 8, 12, 13, 24, 17, 58, 166, 17, 73, 30, 24, 17, 9, 77, 24, 23, 58, 166, 23, 81, 30, 10, 23, 12, 77, 10, 33, 70, 34, 10, 33, 73, 88, 10, 0, 1, 58, 166, 1, 70, 77, 7, 1, 70, 110, 7, 4, 9, 88, 10, 4, 9, 84, 10, 8, 58, 166, 8, 73, 77, 14, 8, 73, 110, 14, 11, 73, 88, 14, 11, 73, 84, 14, 15, 70, 77, 14, 15, 70, 110, 14, 15, 9, 34

In [None]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [25]:
check_mode_collapse(model)

[0, 1, 2, 89, 67, 51, 61, 66, 91, 53, 73, 15, 13, 51, 57, 22, 70, 51, 61, 66, 74, 51, 57, 22, 0, 1, 53, 73, 22, 67, 64, 73, 59, 74, 35, 9, 15, 0, 1, 54, 57, 15, 67, 35, 73, 22, 4, 53, 96, 77, 74, 53, 57, 15, 0, 1, 63, 9, 15, 67, 14, 62, 103, 70, 54, 50, 137, 74, 35, 61, 15, 0, 1, 11, 61, 68]
[0, 1, 2, 142, 67, 24, 40, 15, 4, 11, 76, 47, 8, 2, 142, 91, 38, 9, 15, 13, 2, 142, 13, 32, 21, 15, 70, 38, 76, 15, 16, 14, 82, 34, 17, 2, 142, 74, 38, 73, 15, 0, 1, 2, 142, 1, 38, 61, 15, 67, 24, 82, 15, 4, 24, 40, 34, 8, 2, 142, 13, 2, 142, 13, 11, 76, 22, 70, 14, 40, 22, 16, 38, 97, 34, 17, 2, 142, 27, 32, 40, 22, 27, 11, 97, 52, 27, 38, 118, 52, 0, 1, 2, 142, 1, 32, 82, 15, 67, 54, 82, 7, 67, 38, 40, 85, 4, 32, 118, 77, 8, 2, 142, 13, 2, 142, 16, 38, 55, 108, 17, 2, 142, 17, 38, 33, 31, 17, 38, 21, 31, 27, 38, 26, 52, 27, 38, 20, 52, 0, 1, 2, 142, 1, 24, 50, 31, 1, 24, 96, 31, 4, 38, 26, 34, 4, 38, 20, 34, 8, 2, 142, 10, 38, 6, 7, 13, 38, 82, 7, 16, 11, 21, 31, 17, 38, 97, 7, 17, 2, 142, 27, 38

[0, 1, 2, 95, 67, 11, 33, 41, 8, 2, 95, 72, 5, 21, 60, 13, 2, 95, 17, 2, 95, 0, 1, 2, 95, 23, 32, 21, 22, 8, 2, 95, 8, 64, 73, 22, 72, 38, 57, 47, 13, 2, 95, 78, 51, 12, 15, 17, 2, 95, 17, 64, 9, 15, 90, 38, 50, 36, 0, 1, 2, 95, 8, 2, 95, 8, 32, 33, 22, 72, 64, 6, 42, 13, 2, 95, 78, 32, 9, 15, 17, 2, 95, 17, 32, 57, 15, 90, 11, 9, 41, 0, 1, 2, 95, 67, 5, 21, 41, 8, 2, 95, 72, 38, 50, 103, 13, 2, 95, 0, 67, 64, 33, 34, 67, 32, 26, 34, 72, 32, 9, 68, 72, 32, 21, 112, 0, 23, 54, 57, 15, 23, 64, 12, 15, 8, 51, 6, 15, 8, 32, 9, 15, 72, 51, 6, 28, 72, 32, 57, 28, 78, 64, 6, 15, 78, 64, 9, 15, 17, 64, 33, 15, 17, 54, 50, 22, 90, 51, 122, 28, 90, 38, 33, 28, 0, 8, 38, 92, 15, 8, 38, 33, 15, 72, 54, 111, 7, 72, 54, 6, 7, 10, 51, 29, 7, 10, 51, 9, 7, 91, 54, 55, 31, 91, 54, 57, 31, 70, 51, 29, 7, 70, 51, 9, 7, 16, 32, 111, 7, 16, 32, 6, 7, 78, 64, 122, 7, 78, 64, 30, 7, 90, 64, 55, 7, 90, 64, 57, 7, 27, 64, 30, 15, 27, 64, 12, 15, 74, 51, 33, 15, 74, 51, 21, 15, 0, 67, 64, 48, 15, 67, 64, 86, 15

0

In [None]:
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [18]:
state = torch.load(destination_folder + '/500_checkpoint.pt', map_location=device)
for key in state:
    print(key)
load_checkpoint(state, model, optimizer)

model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint


In [29]:
test_loss = evaluate(model, test_iter, criterion)
print(math.exp(test_loss))

2762.6974862252796


In [None]:
generated_outputs = folder +  "/generated_samples_500epochs"
Path(generated_outputs+"/intro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/outro").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/solo").mkdir(parents=True, exist_ok=True)
Path(generated_outputs+"/predict").mkdir(parents=True, exist_ok=True)

In [26]:
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro = df_intro['intro'].values
test_solo = df_intro['solo'].values
test_outro = df_intro['outro'].values
test_data=[]
for i in range(len(test_intro)):
    temp_dict = {}
    temp_dict['intro'] = test_intro[i]
    temp_dict['solo'] = test_solo[i]
    temp_dict['outro'] = test_outro[i]
    test_data.append(temp_dict)
print(len(test_intro))

112


In [27]:
for i in range(len(test_outro)):
    if len(test_outro) > 1200:
        continue
    intro = test_intro[i]
    solo = test_solo[i]
    outro = test_outro[i]
    #print(intro)
    list_intro = [int(x) for x in intro.split(' ')]
    list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence(model, outro, outro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    print(translated_sentence)
    utils.write_midi(list_intro, word2event, generated_outputs + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_solo, word2event, generated_outputs  + "/solo/" + "/solo" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, generated_outputs + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
        
        


[0, 1, 2, 122, 1, 69, 13, 94, 18, 75, 39, 10, 50, 79, 41, 10, 23, 70, 30, 16, 33, 69, 41, 7, 0, 1, 5, 34, 82, 15, 75, 30, 10, 52, 76, 41, 10, 17, 75, 13, 66, 32, 76, 39, 10, 33, 79, 41, 10, 53, 70, 30, 14, 0, 1, 12, 34, 16, 4, 69, 13, 57, 53, 79, 34, 14, 53, 19, 30, 14, 53, 73, 41, 14, 53, 75, 41, 14, 53, 73, 46, 14, 0, 1, 72, 47, 14, 1, 81, 22, 7, 4, 81, 61, 7, 8, 70, 39, 7, 11, 75, 13, 116]
0
[0, 1, 2, 103, 8, 2, 103, 17, 2, 103, 17, 12, 41, 10, 36, 81, 65, 16, 18, 81, 41, 10, 50, 5, 34, 31, 23, 2, 103, 32, 70, 65, 10, 33, 81, 43, 7, 0, 1, 2, 103, 1, 81, 77, 66, 8, 2, 103, 15, 81, 41, 10, 52, 81, 39, 10, 17, 2, 103, 17, 12, 41, 7, 18, 81, 20, 16, 50, 81, 39, 24, 23, 2, 103, 32, 81, 47, 10, 33, 73, 34, 7, 0, 1, 2, 103, 1, 81, 39, 45, 8, 2, 103, 15, 12, 77, 10, 52, 12, 71, 10, 17, 2, 103, 17, 12, 34, 7, 18, 73, 65, 16, 50, 70, 30, 24, 23, 2, 103, 32, 70, 65, 10, 33, 12, 43, 31, 0, 1, 2, 103, 1, 81, 77, 24, 4, 81, 46, 16, 48, 81, 39, 66, 8, 2, 103, 52, 81, 34, 16, 17, 2, 103, 17, 12, 71

[0, 1, 58, 68, 50, 73, 86, 66, 0, 1, 73, 86, 7, 4, 73, 159, 114, 0, 54, 73, 86, 7, 48, 81, 126, 10, 8, 81, 159, 7, 15, 70, 86, 10, 52, 73, 159, 10, 36, 72, 159, 10, 18, 73, 126, 109, 0, 1, 81, 86, 66, 8, 72, 44, 7, 15, 70, 86, 7, 17, 73, 86, 16, 18, 73, 86, 14, 50, 70, 86, 10, 23, 70, 86, 14, 32, 73, 86, 10, 33, 72, 44, 10, 53, 72, 86, 10, 0, 1, 73, 43, 7, 4, 70, 34, 7, 8, 81, 44, 7, 15, 73, 86, 7, 17, 73, 43, 7, 18, 73, 34, 10, 50, 73, 43, 31, 23, 12, 71, 16, 33, 73, 86, 27, 0, 1, 70, 86, 7, 4, 72, 43, 7, 8, 70, 88, 7, 15, 73, 86, 111, 17, 73, 86, 16, 18, 73, 44, 16, 23, 73, 86, 7, 33, 73, 86, 66, 0, 1, 73, 44, 10, 4, 12, 43, 10, 8, 73, 44, 16, 15, 81, 43, 16, 36, 73, 34, 16, 23, 73, 43, 10, 53, 73, 34, 16, 0, 54, 73, 34, 7, 48, 73, 41, 7, 8, 73, 39, 27, 8, 73, 39, 109]
9
[0, 1, 2, 87, 1, 73, 84, 21, 8, 2, 87, 17, 2, 87, 17, 70, 22, 14, 36, 70, 22, 10, 18, 73, 22, 14, 50, 70, 20, 10, 23, 2, 87, 23, 81, 61, 7, 33, 70, 64, 10, 53, 81, 61, 10, 0, 1, 2, 87, 1, 73, 6, 7, 4, 70, 61, 10, 48,

[0, 1, 58, 131, 1, 73, 41, 10, 8, 58, 131, 15, 73, 39, 31, 15, 5, 71, 31, 17, 58, 131, 36, 72, 39, 10, 36, 72, 71, 14, 23, 58, 131, 23, 72, 39, 10, 23, 73, 71, 10, 0, 1, 58, 131, 1, 70, 46, 16, 1, 79, 34, 16, 8, 58, 131, 8, 19, 67, 10, 11, 19, 39, 10, 15, 19, 39, 10, 52, 19, 71, 10, 17, 58, 131, 17, 19, 46, 10, 36, 19, 39, 10, 18, 19, 71, 10, 50, 19, 34, 10, 23, 58, 131, 23, 19, 43, 10, 32, 19, 44, 10, 33, 19, 44, 10, 53, 19, 86, 10, 0, 1, 58, 131, 1, 19, 159, 10, 54, 19, 125, 10, 4, 19, 123, 66, 8, 58, 131, 15, 19, 88, 10, 52, 19, 43, 10, 17, 19, 34, 10, 36, 19, 34, 10, 18, 19, 125, 10, 50, 19, 124, 10, 23, 19, 125, 10, 32, 19, 124, 10, 33, 19, 125, 7, 0, 1, 19, 126, 7, 8, 19, 125, 10, 11, 19, 132, 10, 15, 19, 125, 10, 52, 19, 86, 10, 17, 19, 126, 10, 36, 19, 125, 10, 18, 19, 126, 10, 50, 19, 93, 10, 23, 19, 125, 10, 32, 19, 125, 10, 33, 19, 125, 10, 53, 19, 132, 10, 0, 1, 19, 126, 7, 8, 19, 93, 10, 11, 19, 126, 10, 15, 19, 86, 7, 17, 19, 125, 10, 36, 19, 132, 10, 18, 19, 126, 7, 23, 

[0, 1, 2, 134, 1, 73, 65, 10, 54, 73, 41, 16, 48, 12, 67, 7, 8, 2, 134, 11, 81, 65, 45, 17, 2, 134, 50, 81, 46, 7, 23, 2, 134, 32, 70, 61, 7, 53, 73, 64, 7, 0, 54, 73, 61, 82, 52, 73, 46, 7, 36, 73, 61, 7, 50, 70, 46, 7, 32, 73, 65, 66, 0, 54, 12, 46, 7, 48, 12, 65, 7, 11, 73, 62, 51, 50, 73, 64, 7, 32, 73, 61, 7, 53, 73, 46, 7, 0, 54, 73, 64, 7, 48, 73, 61, 7, 11, 73, 46, 121, 0, 54, 81, 61, 7, 48, 12, 46, 7, 11, 81, 65, 51, 50, 73, 41, 7, 32, 5, 65, 57, 0, 48, 81, 41, 7, 11, 73, 71, 45, 50, 73, 67, 7, 32, 81, 65, 27, 0, 54, 73, 65, 16, 48, 72, 46, 7, 11, 73, 61, 45, 50, 73, 46, 7, 32, 70, 64, 7, 53, 73, 99, 31, 0, 54, 73, 62, 7, 48, 70, 64, 7, 11, 70, 61, 111]
23
[0, 1, 2, 185, 54, 5, 22, 31, 54, 5, 65, 31, 48, 5, 22, 16, 48, 72, 30, 16, 11, 12, 61, 24, 11, 12, 77, 24, 17, 81, 99, 51, 17, 81, 65, 51, 53, 26, 55, 16, 53, 26, 13, 16, 0, 54, 5, 25, 16, 54, 5, 65, 16, 48, 5, 61, 16, 48, 5, 77, 16, 11, 5, 22, 31, 11, 5, 30, 31, 17, 9, 25, 24, 17, 9, 67, 24, 50, 9, 130, 27, 50, 9, 61, 27, 

[0, 1, 2, 122, 1, 69, 46, 31, 4, 69, 46, 7, 8, 2, 122, 15, 69, 46, 10, 52, 75, 61, 31, 17, 2, 122, 17, 12, 64, 14, 36, 12, 46, 10, 18, 72, 46, 16, 23, 2, 122, 23, 9, 46, 10, 32, 73, 47, 10, 33, 75, 62, 16, 53, 73, 61, 10, 0, 1, 2, 122, 1, 72, 100, 24, 8, 2, 122, 8, 5, 62, 10, 11, 81, 46, 10, 15, 72, 65, 10, 52, 72, 46, 10, 17, 2, 122, 17, 72, 41, 94, 23, 2, 122, 23, 76, 71, 16, 32, 70, 41, 10, 33, 76, 41, 10, 53, 69, 71, 10, 0, 1, 2, 122, 1, 70, 77, 10, 54, 72, 43, 10, 4, 75, 44, 10, 4, 73, 71, 10, 48, 79, 41, 14, 8, 2, 122, 8, 73, 77, 16, 11, 70, 71, 16, 15, 75, 61, 14, 52, 70, 77, 10, 17, 2, 122, 17, 73, 71, 14, 36, 73, 65, 14, 18, 81, 46, 27, 23, 2, 122, 33, 79, 46, 14, 53, 81, 47, 10, 0, 1, 9, 46, 10, 54, 12, 71, 10, 54, 12, 61, 10, 4, 72, 77, 10, 48, 73, 71, 10, 11, 12, 67, 10, 15, 73, 71, 14, 15, 81, 62, 10, 52, 81, 41, 24, 17, 2, 122, 50, 70, 41, 10, 23, 2, 122, 53, 76, 67, 10, 23, 2, 122, 53, 70, 65, 10, 0, 1, 2, 122, 11, 81, 20, 10, 15, 69, 34, 10, 52, 5, 62, 10, 17, 2, 122, 2

[0, 1, 2, 3, 52, 76, 55, 66, 50, 75, 56, 10, 23, 139, 99, 82, 53, 139, 61, 14, 0, 1, 75, 22, 16, 4, 76, 20, 14, 48, 78, 22, 14, 8, 79, 20, 10, 11, 79, 22, 7, 15, 76, 20, 31, 15, 76, 22, 14, 52, 69, 49, 28, 0, 4, 76, 62, 14, 48, 78, 62, 14, 8, 26, 49, 14, 8, 76, 22, 7, 15, 76, 20, 7, 17, 79, 22, 7, 18, 139, 49, 10, 50, 76, 22, 31, 23, 75, 49, 82, 33, 79, 22, 10, 53, 78, 22, 14, 0, 1, 139, 22, 14, 54, 76, 62, 10, 4, 69, 55, 14, 48, 12, 49, 10, 8, 73, 22, 66, 17, 79, 62, 14, 36, 139, 22, 14, 18, 79, 20, 10, 18, 79, 22, 14, 23, 139, 6, 14, 32, 69, 13, 14, 33, 79, 65, 14, 33, 139, 13, 14, 53, 139, 6, 14, 0, 1, 75, 64, 14, 54, 139, 62, 14, 4, 69, 99, 82, 17, 79, 30, 14, 36, 75, 67, 10, 50, 139, 47, 66, 0, 1, 81, 13, 10, 54, 81, 46, 10, 4, 70, 47, 10, 48, 75, 13, 10, 8, 79, 47, 10, 11, 75, 39, 10, 15, 76, 65, 10, 52, 75, 13, 16, 17, 70, 71, 66, 53, 69, 41, 14, 18, 79, 39, 16, 23, 78, 20, 10, 32, 69, 39, 10, 33, 75, 41, 10, 53, 69, 6, 10, 0, 1, 75, 30, 10, 54, 76, 39, 10, 4, 139, 46, 10, 4, 74

[0, 1, 2, 135, 4, 81, 39, 31, 8, 81, 41, 66, 17, 73, 13, 14, 18, 73, 41, 14, 50, 73, 39, 10, 23, 73, 13, 10, 32, 73, 13, 14, 33, 81, 41, 14, 53, 73, 39, 14, 0, 1, 73, 13, 10, 4, 81, 46, 14, 8, 73, 13, 66, 18, 81, 46, 10, 50, 72, 13, 10, 23, 81, 39, 10, 33, 72, 20, 14, 53, 73, 47, 10, 0, 1, 73, 20, 66, 52, 73, 20, 10, 17, 73, 39, 10, 18, 81, 39, 7, 23, 81, 30, 10, 32, 73, 41, 10, 33, 73, 46, 14, 33, 73, 41, 10, 53, 73, 41, 14, 0, 1, 81, 46, 10, 54, 73, 39, 10, 4, 81, 41, 14, 48, 73, 39, 10, 8, 73, 39, 14, 11, 73, 41, 14, 15, 81, 41, 10, 52, 73, 39, 10, 17, 73, 41, 10, 36, 73, 34, 10, 18, 70, 34, 10, 50, 73, 43, 10, 23, 73, 43, 14, 32, 73, 34, 14, 33, 73, 159, 14, 53, 73, 39, 10, 0, 1, 73, 46, 14, 4, 12, 44, 14, 48, 81, 47, 14, 8, 81, 41, 14, 11, 73, 13, 14, 11, 73, 47, 14, 52, 70, 41, 14, 17, 73, 43, 14, 18, 12, 44, 14, 50, 70, 34, 14, 23, 73, 44, 14, 23, 81, 43, 14, 32, 73, 44, 14, 32, 81, 44, 14, 33, 81, 41, 14, 33, 70, 34, 14, 53, 81, 44, 14, 0, 1, 73, 41, 10, 4, 81, 41, 14, 54, 73, 

[0, 1, 2, 107, 8, 12, 47, 66, 18, 72, 46, 31, 23, 72, 47, 7, 33, 72, 39, 16, 0, 1, 72, 43, 66, 8, 72, 30, 42, 0, 8, 72, 34, 82, 18, 72, 44, 10, 23, 72, 41, 31, 33, 72, 43, 7, 0, 1, 72, 39, 7, 4, 72, 41, 7, 8, 72, 39, 90, 0, 8, 72, 46, 85, 18, 72, 39, 7, 23, 72, 41, 7, 33, 72, 34, 16, 0, 1, 72, 43, 16, 4, 72, 44, 7, 15, 72, 43, 115, 0, 8, 72, 34, 82, 18, 72, 39, 10, 23, 72, 43, 7, 33, 72, 39, 7, 0, 1, 72, 41, 16, 4, 72, 34, 16, 48, 72, 34, 37]
48
[0, 1, 58, 171, 54, 139, 71, 66, 11, 139, 65, 66, 50, 81, 39, 28, 0, 11, 73, 43, 45, 50, 79, 34, 45, 0, 54, 72, 71, 82, 11, 75, 41, 109, 0, 11, 35, 86, 51, 50, 35, 84, 82, 53, 40, 43, 82, 0, 54, 5, 86, 14, 11, 5, 86, 10, 36, 9, 86, 14, 32, 26, 86, 10, 0, 54, 5, 47, 10, 4, 72, 39, 16, 48, 73, 41, 10, 8, 9, 71, 16, 11, 5, 43, 45, 50, 9, 34, 31, 53, 40, 71, 24, 0, 4, 26, 41, 7]
49
[0, 1, 2, 68, 8, 12, 132, 111, 0, 1, 12, 93, 66, 8, 81, 132, 114, 0, 8, 73, 84, 42, 8, 81, 132, 111, 0, 1, 73, 93, 7, 4, 73, 84, 7, 8, 5, 88, 105, 8, 5, 132, 105]
50
[0,

[0, 1, 58, 101, 48, 70, 34, 14, 48, 81, 86, 14, 8, 58, 101, 11, 12, 34, 10, 52, 12, 34, 10, 17, 58, 101, 17, 12, 30, 14, 36, 5, 34, 16, 50, 12, 125, 14, 23, 58, 101, 23, 9, 132, 14, 32, 12, 125, 14, 33, 5, 132, 14, 53, 26, 86, 10, 0, 1, 58, 101, 54, 26, 86, 24, 8, 58, 101, 11, 5, 13, 10, 52, 5, 13, 14, 17, 58, 101, 17, 9, 6, 14, 36, 9, 13, 14, 50, 26, 110, 14, 23, 58, 101, 23, 60, 88, 14, 32, 26, 110, 14, 33, 9, 88, 14, 53, 9, 30, 10, 0, 1, 58, 101, 54, 9, 34, 24, 8, 58, 101, 11, 12, 39, 7, 52, 12, 67, 45, 52, 9, 93, 24, 17, 58, 101, 50, 5, 93, 10, 23, 58, 101, 32, 9, 132, 27, 0, 1, 58, 101, 54, 5, 86, 10, 48, 5, 93, 10, 8, 58, 101, 11, 12, 125, 91, 52, 73, 39, 14, 17, 81, 39, 14, 36, 81, 39, 14, 53, 73, 39, 14, 0, 1, 81, 39, 14, 54, 81, 39, 14, 4, 81, 39, 14, 48, 70, 13, 16]
58
[0, 1, 2, 152, 54, 69, 41, 14, 54, 76, 34, 14, 4, 74, 41, 14, 4, 69, 34, 14, 4, 75, 34, 14, 48, 79, 41, 14, 48, 78, 34, 14, 8, 78, 41, 14, 8, 75, 34, 14, 11, 69, 41, 82, 36, 12, 39, 10, 18, 12, 13, 10, 50, 9, 4

[0, 1, 2, 162, 4, 35, 77, 14, 8, 35, 77, 14, 15, 35, 110, 10, 17, 26, 84, 14, 18, 35, 132, 7, 23, 35, 93, 7, 0, 8, 29, 84, 10, 11, 26, 110, 10, 15, 12, 84, 10, 52, 26, 110, 10, 17, 19, 77, 27, 0, 1, 9, 67, 10, 54, 5, 77, 14, 4, 26, 110, 10, 8, 35, 77, 66, 0, 1, 29, 88, 16, 4, 5, 84, 10, 8, 35, 132, 7, 15, 35, 93, 51, 0, 54, 81, 132, 14, 4, 29, 132, 16, 8, 29, 93, 16, 15, 35, 132, 16, 17, 35, 170, 7, 18, 35, 170, 16, 23, 35, 170, 7, 33, 40, 170, 109, 0, 18, 40, 170, 10, 50, 9, 132, 10, 23, 5, 84, 16, 32, 9, 77, 10, 32, 35, 93, 150]
66
[0, 1, 2, 135, 4, 35, 46, 14, 48, 40, 46, 14, 8, 35, 65, 14, 11, 9, 65, 14, 15, 35, 46, 14, 52, 9, 46, 14, 17, 19, 61, 14, 36, 26, 61, 14, 18, 35, 46, 14, 50, 19, 61, 14, 50, 40, 46, 14, 23, 5, 61, 14, 32, 9, 61, 14, 33, 40, 62, 14, 53, 5, 62, 14, 0, 1, 9, 99, 14, 54, 26, 99, 14, 4, 29, 100, 14, 48, 5, 99, 14, 8, 35, 62, 10, 11, 35, 62, 10, 52, 40, 62, 14, 17, 29, 62, 14, 18, 26, 64, 10, 50, 9, 62, 10, 23, 12, 99, 10, 32, 35, 62, 10, 33, 5, 99, 10, 53, 40,

[0, 1, 58, 101, 52, 5, 43, 7, 36, 19, 43, 121, 0, 48, 5, 20, 31, 11, 26, 44, 66, 36, 26, 110, 28, 53, 40, 159, 7, 0, 54, 29, 86, 66, 11, 40, 110, 45, 50, 9, 44, 7, 32, 35, 86, 109, 0, 11, 35, 159, 14, 52, 40, 159, 31, 36, 35, 86, 42, 0, 54, 5, 20, 66, 11, 26, 13, 109, 32, 26, 41, 7, 53, 40, 34, 7, 0, 54, 35, 41, 109, 36, 26, 47, 66, 32, 19, 13, 45, 0, 48, 19, 39, 7, 11, 5, 46, 66, 36, 9, 47, 21, 0, 54, 29, 41, 66, 11, 35, 34, 94, 32, 9, 41, 66, 0, 54, 19, 39, 85, 11, 26, 47, 66, 36, 29, 13, 66, 32, 35, 13, 168, 0, 48, 12, 41, 7, 11, 9, 41, 7, 52, 70, 34, 31, 36, 73, 13, 31, 53, 5, 62, 10, 53, 73, 62, 10]
73
[0, 1, 58, 101, 1, 12, 77, 14, 54, 81, 77, 10, 4, 26, 71, 14, 48, 12, 77, 14, 8, 58, 101, 8, 81, 67, 14, 11, 12, 67, 14, 15, 5, 65, 14, 52, 12, 67, 14, 17, 58, 101, 17, 26, 65, 14, 36, 81, 65, 14, 18, 5, 46, 14, 50, 81, 65, 14, 23, 58, 101, 23, 73, 46, 14, 32, 81, 46, 14, 33, 12, 64, 14, 53, 5, 64, 14, 0, 1, 58, 101, 1, 12, 25, 66, 8, 58, 101, 8, 9, 126, 66, 17, 58, 101, 17, 9, 93, 

[0, 1, 2, 87, 11, 12, 192, 14, 11, 5, 130, 10, 15, 81, 164, 10, 15, 12, 55, 10, 52, 81, 56, 10, 52, 5, 55, 10, 17, 26, 56, 16, 17, 29, 99, 14, 36, 35, 130, 109, 36, 40, 25, 57, 0, 54, 40, 55, 85, 11, 9, 192, 14, 11, 19, 130, 10, 15, 81, 192, 10, 15, 73, 56, 10, 52, 26, 130, 10, 52, 5, 55, 10, 36, 29, 56, 16, 36, 29, 99, 10, 50, 29, 130, 85, 50, 19, 99, 51, 0, 11, 40, 25, 10, 15, 81, 22, 10, 52, 5, 61, 10, 17, 26, 61, 14, 36, 19, 67, 57, 53, 9, 39, 10, 0, 1, 19, 67, 10, 54, 5, 30, 85, 48, 40, 77, 7, 11, 26, 30, 10, 15, 73, 13, 10, 52, 26, 34, 10, 17, 5, 30, 14, 17, 19, 30, 51]
81
[0, 1, 2, 160, 8, 2, 160, 17, 2, 160, 23, 2, 160, 23, 81, 43, 10, 32, 72, 44, 10, 33, 69, 84, 10, 53, 75, 86, 10, 0, 1, 2, 160, 1, 81, 159, 45, 8, 2, 160, 52, 139, 86, 14, 17, 2, 160, 17, 70, 84, 10, 17, 70, 43, 82, 23, 2, 160, 0, 1, 2, 160, 1, 60, 34, 16, 54, 81, 71, 10, 48, 73, 41, 16, 8, 2, 160, 8, 60, 71, 16, 11, 81, 41, 10, 15, 5, 39, 16, 17, 2, 160, 17, 12, 41, 85, 23, 2, 160, 53, 74, 39, 14, 0, 1, 2, 160

[0, 1, 58, 171, 54, 70, 61, 14, 4, 75, 22, 7, 8, 72, 61, 16, 15, 12, 6, 109, 15, 73, 88, 16, 52, 73, 77, 10, 17, 69, 34, 10, 18, 9, 88, 16, 0, 4, 72, 61, 7, 8, 79, 88, 10, 15, 81, 110, 10, 52, 73, 77, 14, 17, 69, 6, 31, 17, 12, 30, 14, 18, 79, 65, 16, 18, 9, 88, 31, 23, 70, 34, 16, 33, 12, 67, 10, 0, 1, 12, 30, 16, 54, 69, 65, 10, 4, 70, 13, 16, 4, 70, 93, 31, 8, 75, 65, 16, 8, 75, 110, 10, 15, 29, 67, 154, 15, 70, 93, 16, 52, 69, 84, 10, 17, 79, 88, 16, 18, 60, 132, 66, 0, 1, 75, 88, 7, 4, 73, 132, 16, 8, 81, 93, 16, 15, 79, 84, 10, 52, 69, 110, 16, 17, 75, 67, 31, 17, 79, 88, 10, 18, 72, 67, 7, 18, 75, 84, 7, 23, 75, 34, 7, 33, 73, 77, 7, 0, 1, 70, 34, 7, 4, 70, 30, 7, 8, 75, 67, 7, 15, 12, 30, 154, 15, 70, 110, 16, 52, 139, 88, 10, 17, 75, 77, 10, 18, 60, 110, 37, 0, 17, 73, 30, 7, 18, 73, 67, 7, 23, 72, 13, 31, 33, 9, 67, 31, 0, 1, 73, 65, 7, 4, 73, 13, 7, 8, 73, 6, 10, 15, 75, 6, 14, 15, 75, 88, 14, 17, 72, 13, 10, 17, 72, 110, 10, 18, 5, 65, 14, 18, 5, 84, 14, 23, 60, 6, 10, 23, 

[0, 1, 58, 171, 1, 73, 46, 7, 4, 9, 65, 31, 8, 19, 71, 82, 18, 19, 71, 7, 23, 40, 77, 24, 0, 1, 40, 77, 66, 8, 29, 71, 45, 0, 8, 29, 71, 24, 17, 9, 71, 16, 18, 5, 65, 7, 23, 5, 67, 16, 33, 29, 67, 10, 53, 9, 41, 10, 0, 1, 19, 67, 7, 4, 81, 65, 7, 8, 9, 46, 28, 33, 5, 64, 7, 0, 1, 12, 61, 7, 4, 5, 46, 31, 8, 26, 65, 82, 18, 12, 71, 7, 23, 5, 64, 45, 0, 8, 29, 65, 24, 17, 9, 65, 16, 18, 19, 6, 16, 23, 9, 61, 51, 0, 8, 9, 65, 16, 15, 5, 64, 24, 18, 12, 61, 7, 23, 81, 65, 7, 33, 73, 6, 7, 0, 1, 81, 61, 66, 8, 73, 64, 133, 15, 12, 71, 16, 17, 73, 71, 7, 18, 81, 77, 7, 0, 1, 81, 126, 7]
98
[0, 1, 58, 166, 4, 5, 67, 7, 8, 81, 30, 45, 18, 73, 67, 7, 23, 73, 67, 45, 0, 4, 73, 13, 7, 8, 12, 6, 7, 15, 73, 13, 7, 17, 73, 13, 7, 18, 73, 25, 7, 23, 73, 61, 66, 0, 1, 73, 25, 7, 4, 72, 22, 7, 8, 70, 20, 45, 18, 73, 25, 7, 23, 75, 25, 7, 33, 75, 22, 7, 0, 1, 70, 25, 7, 4, 70, 25, 7, 48, 72, 22, 91]
99
[0, 1, 2, 87, 54, 5, 6, 10, 4, 5, 88, 10, 48, 9, 110, 10, 8, 2, 87, 8, 81, 30, 10, 11, 5, 30, 10, 15, 

[0, 1, 2, 68, 1, 9, 6, 10, 54, 73, 13, 10, 4, 9, 39, 16, 48, 70, 67, 10, 8, 60, 39, 115, 0, 1, 12, 6, 14, 1, 72, 39, 16, 4, 60, 13, 10, 48, 12, 6, 14, 8, 9, 13, 7, 15, 73, 6, 14, 52, 12, 6, 104, 0, 4, 9, 6, 10, 48, 26, 47, 10, 8, 9, 20, 7, 15, 76, 47, 10, 52, 12, 6, 31, 18, 9, 13, 16, 23, 60, 6, 82, 0, 4, 9, 13, 10, 48, 60, 39, 14, 8, 26, 67, 16, 15, 12, 39, 10, 52, 60, 13, 24, 18, 70, 6, 10, 50, 73, 13, 82, 0, 1, 81, 6, 10, 54, 5, 13, 14, 4, 5, 39, 10, 48, 73, 67, 10, 8, 19, 39, 109, 0, 1, 26, 39, 16, 4, 9, 67, 16, 48, 19, 30, 16, 8, 26, 34, 104, 0, 4, 12, 30, 10, 48, 5, 67, 14, 8, 60, 39, 7, 15, 12, 13, 14, 52, 9, 13, 7, 18, 9, 6, 16, 23, 9, 47, 31, 53, 9, 20, 31, 0, 4, 12, 47, 7, 8, 12, 6, 133]
109
[0, 1, 2, 83, 4, 9, 13, 7, 8, 12, 61, 16, 15, 12, 6, 7, 17, 75, 25, 51, 18, 12, 61, 28, 0, 4, 9, 41, 7, 8, 73, 61, 16, 15, 9, 64, 66, 18, 9, 61, 7, 23, 73, 62, 16, 33, 12, 6, 7, 0, 1, 12, 61, 7, 4, 26, 6, 24, 15, 9, 22, 27, 18, 12, 47, 27, 33, 12, 46, 24, 0, 54, 12, 65, 183]
110
[0, 1, 2,

In [28]:
import mido
for i in range(11):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')
    total_intro_time = 0
    total_solo_time = 0
    total_predict_time = 0
    for msg in intro.tracks[1]:
        if msg.type == "note_on":
            total_intro_time += msg.time
    for msg in solo.tracks[1]:
        if msg.type == "note_on":
            total_solo_time += msg.time
    for msg in predict.tracks[1]:
        if msg.type == "note_on":
            total_predict_time += msg.time
            
    original_outro_time = 0 + outro.tracks[1][1].time
    
    print(original_outro_time + total_solo_time + total_intro_time)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_solo_time + total_intro_time
    print(outro.tracks[1][1].time)
    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')
    
    
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    
    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)
    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

25080
25080
17520
17520
29880
29880
26280
26280
23760
23760
26100
26100
35820
35820
32880
32880
22800
22800
25800
25800
20160
20160
38760
38760
25380
25380
22500
22500
25800
25800
23400
23400
25320
25320
18540
18540
40620
40620
24840
24840
25620
25620
20520
20520


In [None]:
class BeamSearchNode(object):
    def __init__(self, prev_node, wid, logp, length):
        self.prev_node = prev_node
        self.wid = wid
        self.logp = logp
        self.length = length

    def eval(self):
        return self.logp / float(self.length - 1 + 1e-6)
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    
    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]
    sos_token = english.vocab.stoi["<sos>"]
    
    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    n_best_list = []
    
     
    #trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

    #first token as input
    trg_tensor = torch.LongTensor(outputs).to(device)
    
    end_nodes = []

    #starting node
    node = BeamSearchNode(prev_node=None, wid=trg_tensor, logp=0, length=1)

    nodes = []

    heappush(nodes, (-node.eval(), id(node), node))
    n_dec_steps = 0

    while True:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break
        
        # Fetch the best node
        #print([n[2].wid for n in nodes])
        score, _, n = heappop(nodes)
        decoder_input = n.wid
        
        if n.wid.item() == eos_token and n.prev_node is not None:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= beam_width:
                break
            else:
                continue
   
        sequence = [n.wid.item()]
        a = n
        while a.prev_node is not None:
            a = a.prev_node
            sequence.append(a.wid.item())
        sequence = sequence[::-1] # reverse
        
        #print(sequence)
        
        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))
        
        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(output, beam_width)
        #print(topk_indexes)
        #print(topk_log_prob)
        # Then, register new top-k nodes
        for new_k in range(beam_width):
            decoded_t = topk_indexes[0][0][new_k].view(1) # (1)
            logp = topk_log_prob[0][0][new_k].item() # float log probability val

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width
        #print(n_dec_steps)
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        end_nodes = [heappop(nodes) for _ in range(beam_width)]

    # Construct sequences from end_nodes
    n_best_seq_list = []
    for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
        sequence = [n.wid.item()]
        # back trace from end node
        while n.prev_node is not None:
            n = n.prev_node
            sequence.append(n.wid.item())
        sequence = sequence[::-1] # reverse

        n_best_seq_list.append(sequence)


    # return n_best_seq_list

    translated_sentence = [english.vocab.itos[idx] for idx in n_best_seq_list[0]]

    # remove start token
    return translated_sentence


In [None]:
def save_vocab(vocab, path):
    output = open(path, 'wb')
    pickle.dump(vocab, output)
    output.close()

In [None]:
vocab_folder = "vocab_outro/"
save_vocab(intro_field.vocab, vocab + 'intro_vocab.pkl')
save_vocab(solo_field.vocab, vocab + 'solo_vocab.pkl')
save_vocab(outro_field.vocab, vocab + 'outro_vocab.pkl')

In [None]:
for i in range(len(test_intro)):
    if len(test_intro) > 1200:
        continue
    sentence = test_intro[i]
    #print(sentence)
    list_sentence = [int(x) for x in sentence.split(' ')]
    #print(list_sentence)
    translated_sentence = translate_sentence_beam(model, sentence, intro_field, solo_field, device, max_length=1200)
    #print(translated_sentence)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']
    #print(translated_sentence)
    utils.write_midi(list_sentence, word2event, generated_outputs + "/src" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict_beam" + str(i)  + ".mid")
    print(i)
    if i == 10:
        break

In [None]:
beam = [2, 59, 119, 13, 212, 59, 212, 59, 59, 75, 59, 59, 13, 119, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 119, 59, 59, 59, 59, 166, 59, 59, 59, 13, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 59, 59, 212, 212, 13, 59, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 14, 59, 59, 212, 59, 212, 212, 59, 59, 59, 68, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 68, 59, 59, 212, 59, 59, 13, 59, 59, 59, 59, 59, 59, 97, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 158, 59, 59, 13, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 13, 212, 13, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 158, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 13, 13, 59, 59, 13, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 13, 59, 166, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 86, 59, 212, 212, 212, 59, 59, 59, 59, 212, 59, 86, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 212, 59, 59, 166, 59, 86, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 59, 59, 59, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 212, 212, 158, 59, 59, 13, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 158, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 212, 212, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 13, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 212, 86, 59, 59, 59, 158, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 212, 59, 13, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 68, 59, 13, 59, 59, 13, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 14, 59, 59, 59, 59, 59, 59, 13, 86, 59, 59, 59, 212, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 13, 59, 59, 59, 59, 68, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 212, 212, 13, 59, 166, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 212, 59, 59, 212, 59, 212, 59, 59, 13, 59, 59, 59, 59, 13, 59, 59, 14, 13, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 68, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 212, 59, 59, 13, 59, 59, 166, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 13, 59, 59, 212, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59]

translated_sentence1 = [solo_field.vocab.itos[idx] for idx in beam]
translated_sentence = [int(x) for x in translated_sentence1 if x != '<pad>' and x != '<sos>' and x != '<eos>' and x != '<unk>']    
utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict1.mid")

In [None]:
translated_sentence

In [None]:
for i in range(len(test_intro)):
    if len(test_intro) > 1200:
        continue
    list_sentence = [int(x) for x in sentence.split(' ')]
    remi = [word2event[x] for x in list_sentence]
    print(remi)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    #tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
    #tokens.insert(0, german.init_token)
    #tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    #text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(sentence).unsqueeze(1).to(device)

    outputs = [english.vocab.stoi["<sos>"]]
    
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == english.vocab.stoi["<eos>"]:
            break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        #print( vars(example))
        src = vars(example)["intro"]
        trg = vars(example)["solo"]
        
        src = [int(x) for x in src]
        trg = [int(x) for x in trg]
        
        if len(trg) > 1200 or len(src) > 1200:
            continue
        
        prediction = bleu_translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append(trg)
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# running on entire test data takes a while
score = bleu(test[1:10], model, intro_field, solo_field, device)
print(f"Bleu score {score * 100:.2f}")

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show() 

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns