In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
# Input data and output locations for this run (relative to the working directory).
source_folder = "solo_generation_dataset_augmented_presplit"
folder = "dynamic_augmented_models/2enc_3rd"
destination_folder = f"{folder}/solo_generation_weights"  # checkpoints
generated_outputs = f"{folder}/generated_samples"         # generated MIDI files
dissimilar_interpolation = f"{folder}/interpolation"
vocab = f"{folder}/vocab"

In [2]:
# state = random.getstate()
# pickle.dump(state, open('./state.pkl', 'wb'))

import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor
# Restore the Python `random` module state saved by a previous run (see the
# commented-out dump at the top of this cell). Only `random` is restored here;
# torch / numpy RNGs are not seeded by this cell.
# NOTE: pickle.load can execute arbitrary code — only load a trusted state.pkl.
with open('./state.pkl', 'rb') as state_file:  # closes the handle (was leaked before)
    state = pickle.load(state_file)
random.setstate(state)

In [3]:
from pathlib import Path
# Create every output directory up front (no-op when it already exists).
for _out_dir in (
    destination_folder,
    generated_outputs,
    dissimilar_interpolation,
    vocab,
    *(f"{generated_outputs}/{part}" for part in ("intro", "outro", "solo", "predict")),
    *(f"{dissimilar_interpolation}/{part}" for part in ("intro", "outro", "predict")),
):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

In [4]:
# Token <-> event lookup tables; word2event is what utils.write_midi uses to turn
# token ids back into MIDI further down.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open('dictionary_augmented.pkl', 'rb') as dictionary_file:  # handle was leaked before
    event2word, word2event = pickle.load(dictionary_file)

In [5]:
# Prefer the second GPU (cuda:1) as before, but fall back to cuda:0 on machines with a
# single GPU instead of failing later with an invalid device ordinal.
if torch.cuda.is_available():
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    dev = "cpu"
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [6]:
# Fields

# Fields: all six columns share one configuration — lower-cased token sequences,
# batch-first tensors, lengths returned, each sequence wrapped in <sos> ... <eos>.
def _make_field():
    return Field(tokenize=None, lower=True, include_lengths=True, batch_first=True,
                 init_token="<sos>", eos_token="<eos>")


intro_field = _make_field()
intro_piano_field = _make_field()
outro_field = _make_field()
outro_piano_field = _make_field()
solo_field = _make_field()
solo_piano_field = _make_field()
fields = [
    ('intro', intro_field),
    ('intro_piano', intro_piano_field),
    ('outro', outro_field),
    ('outro_piano', outro_piano_field),
    ('solo', solo_field),
    ('solo_piano', solo_piano_field),
]

# TabularDataset
train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv',
                                           validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators: identical settings for every split; sort_key uses the intro length.
BATCH_SIZE = 8


def _make_iter(dataset):
    return BucketIterator(dataset, batch_size=BATCH_SIZE, sort_key=lambda x: len(x.intro),
                          device=device, sort=False, sort_within_batch=True)


train_iter, valid_iter, test_iter = (_make_iter(ds) for ds in (train, valid, test))

# Vocabulary: intro/outro/solo keep every training token; the piano fields drop
# tokens seen fewer than 3 times.
for _field, _min_freq in ((intro_field, 1), (intro_piano_field, 3),
                          (outro_field, 1), (outro_piano_field, 3),
                          (solo_field, 1), (solo_piano_field, 3)):
    _field.build_vocab(train, min_freq=_min_freq)

In [7]:
# Print the padded solo length (dim 1 of the batch-first tensor) of every test batch.
for batch_fields, _ in test_iter:
    ((intro, intro_len), (intro_piano, intro_piano_len),
     (outro, outro_len), (outro_piano, outro_piano_len),
     (solo, solo_len), (solo_piano, solo_piano_len)) = batch_fields
    print(solo.size(1))

272
353
509
335
326
279
253
522
281
444
619
325
319
414


In [8]:
import os
# Debug settings: force synchronous CUDA kernel launches and disable cuDNN.
# Both typically slow training down; consider removing them once debugging is done.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})
torch.backends.cudnn.enabled = False

In [9]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    """Two-encoder seq2seq Transformer: two source sequences -> one target sequence.

    All inputs are sequence-first token-id tensors: ``src`` is (src_len, N),
    ``src2`` is (src2_len, N) and ``trg`` is (trg_len, N). Each sequence gets a
    learned token embedding plus a learned absolute-position embedding (positions
    ``0 .. max_len-1``, so sequences longer than ``max_len`` cannot be embedded),
    followed by dropout. The three embedded sequences are passed to the project's
    ``kk.Transformer`` core and its output is projected to ``trg_vocab_size`` logits —
    presumably (trg_len, N, trg_vocab_size); TODO confirm the core's output layout.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        src2_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super(Transformer, self).__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        # The second source has its own vocabulary, so this table needs src2_vocab_size
        # rows; sizing it with src_vocab_size fails on any src2 id >= src_vocab_size.
        # (Checkpoints saved with the old sizing only load when both sizes are equal.)
        self.src2_word_embedding = nn.Embedding(src2_vocab_size, embedding_size)
        self.src2_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        # Project-local two-encoder core; arguments are presumably (d_model, nhead,
        # encoder layers, decoder layers, feed-forward expansion, dropout) — TODO confirm.
        self.transformer = kk.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        # Single pad id used to build the key-padding masks of BOTH sources.
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a (N, src_len) bool mask that is True where ``src`` (src_len, N) is padding."""
        src_mask = src.transpose(0, 1) == self.src_pad_idx
        return src_mask.to(self.device)

    def _positions(self, seq_len, batch_size):
        """Position ids ``0 .. seq_len-1`` repeated over the batch, shape (seq_len, batch_size)."""
        return (
            torch.arange(0, seq_len)
            .unsqueeze(1)
            .expand(seq_len, batch_size)
            .to(self.device)
        )

    def _embed(self, tokens, word_embedding, position_embedding):
        """Token + position embedding followed by dropout for a (seq_len, N) id tensor."""
        seq_len, batch_size = tokens.shape
        positions = self._positions(seq_len, batch_size)
        return self.dropout(word_embedding(tokens) + position_embedding(positions)).to(self.device)

    def forward(self, src, src2, trg):
        """Compute target-vocabulary logits for ``trg`` conditioned on ``src`` and ``src2``.

        Args:
            src: (src_len, N) token ids of the first source.
            src2: (src2_len, N) token ids of the second source.
            trg: (trg_len, N) decoder-input token ids.

        Returns:
            ``fc_out`` applied to the core's output (see class docstring for the shape).
        """
        embed_src = self._embed(src, self.src_word_embedding, self.src_position_embedding)
        embed_src2 = self._embed(src2, self.src2_word_embedding, self.src2_position_embedding)
        embed_trg = self._embed(trg, self.trg_word_embedding, self.trg_position_embedding)
        src_padding_mask = self.make_src_mask(src)
        src2_padding_mask = self.make_src_mask(src2)
        # Presumably a causal (subsequent-position) mask for the decoder — per its name.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg.shape[0]).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_src2,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            src2_key_padding_mask=src2_padding_mask,
            tgt_mask=trg_mask,
        )
        return self.fc_out(out)


In [10]:
src_vocab_size = len(intro_field.vocab)
src2_vocab_size = len(outro_field.vocab)
trg_vocab_size = len(solo_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 1200  # rows of each learned position table: sequences incl. <sos>/<eos> must fit
forward_expansion = 4
# The same pad id builds the padding masks of BOTH encoders, so read it from the
# vocabularies (instead of hard-coding 1) and make sure intro and outro agree.
src_pad_idx = intro_field.vocab.stoi[intro_field.pad_token]
assert outro_field.vocab.stoi[outro_field.pad_token] == src_pad_idx, "intro/outro <pad> ids differ"

model = Transformer(
    embedding_size,
    src_vocab_size,
    src2_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
model = model.to(device)

In [11]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter of ``m``, including those of nested submodules.

    Parameters whose name contains ``'weight'`` (this includes normalisation-layer
    weights) are drawn from a normal distribution with mean 0 and std 0.01; every
    other parameter (e.g. biases) is set to 0.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)


# init_weights already walks every nested parameter via named_parameters(), so call it
# once on the root. model.apply(init_weights) would call it on every submodule and
# re-draw deeper parameters once per ancestor (same distribution, wasted work).
init_weights(model)

optimizer = optim.Adam(model.parameters(), lr=2e-4)  # the non-augmented run used 3e-4


def count_parameters(model: nn.Module):
    """Return the total number of elements in ``model``'s parameters with requires_grad=True."""
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return sum(p.numel() for p in trainable)


print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth, filename="_checkpoint.pt"):
    """Write ``state`` to ``<destination_folder>/metrics.pt``, replacing the previous best.

    ``nth`` and ``filename`` are currently unused: the best checkpoint always goes to
    the same file.
    """
    print("=> Saving checkpoint")
    best_path = f"{destination_folder}/metrics.pt"
    torch.save(state, best_path)

def save_final_checkpoint(state, nth, filename="_checkpoint.pt"):
    """Write ``state`` to ``<destination_folder>/<nth><filename>``, e.g. ``.../20_checkpoint.pt``."""
    print("=> Saving checkpoint")
    torch.save(state, f"{destination_folder}/{nth}{filename}")


def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint dict.

    ``checkpoint`` must provide ``"model_state_dict"`` and ``"optimizer_state_dict"``;
    the model is restored first, then the optimizer.
    """
    print("=> Loading checkpoint")
    for target, key in ((model, "model_state_dict"), (optimizer, "optimizer_state_dict")):
        target.load_state_dict(checkpoint[key])

The model has 14,997,275 trainable parameters


In [12]:
# stoi: input a token (str), get its int index
# intro_field.vocab.stoi
# itos: input an int index, get the token (str)
# intro_field.vocab.itos[4]

In [13]:
# Padding positions in the target (solo) must not contribute to the loss. Read the pad
# id from the target vocabulary instead of hard-coding it (it was hard-coded as 1).
PAD_IDX = solo_field.vocab.stoi[solo_field.pad_token]

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [14]:
import math
import time


from typing import Iterable, Optional


def _batch_seq_first(batch_fields, device):
    """Unpack one torchtext batch into sequence-first ``(src, src2, trg)`` tensors.

    ``batch_fields`` is the 6-tuple of ``(tensor, lengths)`` pairs in field order
    (intro, intro_piano, outro, outro_piano, solo, solo_piano); only intro, outro and
    solo are used. The batch-first (N, L) tensors are transposed to (L, N) and moved to
    ``device`` when one is given.
    """
    ((intro, _), _, (outro, _), _, (solo, _), _) = batch_fields
    src, src2, trg = intro.transpose(1, 0), outro.transpose(1, 0), solo.transpose(1, 0)
    if device is not None:
        src, src2, trg = src.to(device), src2.to(device), trg.to(device)
    return src, src2, trg


def _model_device(model, device):
    """Return ``device`` if given, else the device of the model's first parameter (or None)."""
    if device is not None:
        return device
    first_param = next(model.parameters(), None)
    return None if first_param is None else first_param.device


def _teacher_forced_loss(model, src, src2, trg, criterion):
    """Feed ``trg[:-1]`` as decoder input and score the predictions against ``trg[1:]``."""
    output = model(src, src2, trg[:-1, :])
    output = output.reshape(-1, output.shape[-1])
    return criterion(output, trg[1:].reshape(-1))


def train(model: nn.Module,
          iterator: Iterable,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float,
          device: Optional[torch.device] = None) -> float:
    """Run one teacher-forced training epoch and return the mean per-batch loss.

    Args:
        model: called as ``model(src, src2, trg_in)`` on sequence-first tensors.
        iterator: sized iterable of torchtext batches (e.g. a BucketIterator) yielding
            ``(fields_6_tuple, _)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss over flattened logits / flattened target ids.
        clip: max gradient norm for ``clip_grad_norm_``.
        device: where to move each batch; defaults to the model's parameter device.
    """
    model.train()
    device = _model_device(model, device)
    epoch_loss = 0.0

    for batch_fields, _ in iterator:
        src, src2, trg = _batch_seq_first(batch_fields, device)

        optimizer.zero_grad()
        loss = _teacher_forced_loss(model, src, src2, trg, criterion)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: Iterable,
             criterion: nn.Module,
             device: Optional[torch.device] = None) -> float:
    """Return the mean per-batch loss over ``iterator`` with the model in eval mode.

    Decoding is still teacher-forced (the ground-truth prefix is the decoder input);
    only dropout is disabled and gradients are not tracked. Arguments are as in
    ``train``.
    """
    model.eval()
    device = _model_device(model, device)
    epoch_loss = 0.0

    with torch.no_grad():
        for batch_fields, _ in iterator:
            src, src2, trg = _batch_seq_first(batch_fields, device)
            epoch_loss += _teacher_forced_loss(model, src, src2, trg, criterion).item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: float,
               end_time: float):
    """Split the elapsed interval between two ``time.time()`` values into whole (minutes, seconds)."""
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs



In [15]:
def translate_sentence(model, sentence, sentence2, intro, outro, solo, device, max_length=1200):
    """Greedily decode a solo token sequence from an intro and an outro.

    Args:
        model: two-source model called as ``model(src, src2, trg)`` on sequence-first
            (L, 1) id tensors; its output is indexed as (trg_len, 1, vocab).
        sentence: space-separated intro tokens.
        sentence2: space-separated outro tokens.
        intro, outro, solo: fields providing ``init_token``, ``eos_token`` and ``vocab``
            (``stoi`` / ``itos``).
        device: device the input tensors are created on.
        max_length: maximum number of tokens generated after ``<sos>``.

    Returns:
        List of solo tokens (strings). It starts with ``<sos>`` (kept; callers filter
        special tokens) and ends with ``<eos>`` if one was produced within
        ``max_length`` steps.
    """
    # Greedy decoding must not use dropout; train() switches back to train mode itself.
    model.eval()

    def to_ids(text, field):
        # Lower-case like the fields (lower=True), then wrap in <sos> ... <eos>.
        tokens = [field.init_token] + [tok.lower() for tok in text.split(' ')] + [field.eos_token]
        return torch.LongTensor([field.vocab.stoi[tok] for tok in tokens]).unsqueeze(1).to(device)

    sentence_tensor = to_ids(sentence, intro)
    sentence2_tensor = to_ids(sentence2, outro)

    eos_idx = solo.vocab.stoi["<eos>"]
    outputs = [solo.vocab.stoi["<sos>"]]

    with torch.no_grad():
        for _ in range(max_length):
            # The whole prefix is re-fed each step; only the last position's logits are used.
            trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
            output = model(sentence_tensor, sentence2_tensor, trg_tensor)
            best_guess = output.argmax(2)[-1, :].item()
            outputs.append(best_guess)
            if best_guess == eos_idx:
                break

    return [solo.vocab.itos[idx] for idx in outputs]


In [16]:
# Raw validation strings, used for qualitative decoding checks (check_mode_collapse).
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro, val_solo, val_outro = (df_intro[col].values for col in ('intro', 'solo', 'outro'))
val_data = [
    {'intro': intro_seq, 'solo': solo_seq, 'outro': outro_seq}
    for intro_seq, solo_seq, outro_seq in zip(val_intro, val_solo, val_outro)
]
print(len(val_intro))

112


In [17]:
def check_mode_collapse(model, num_samples=3, max_len=1200):
    """Greedily decode the first ``num_samples`` validation examples and count repeats.

    Each decoded solo (special tokens removed, ids as ints) is printed. Returns how
    many decoded solos are identical to the previously decoded one; the training loop
    treats a result > 1 as mode collapse.

    Examples whose intro or outro (plus <sos>/<eos>) would not fit the ``max_len``
    learned positions are skipped instead of crashing the position embedding.
    """
    special_tokens = ('<pad>', '<sos>', '<eos>', '<unk>')
    count = 0
    previous = None
    for intro, outro in zip(val_intro[:num_samples], val_outro[:num_samples]):
        # Check each sequence's own token count (+2 for <sos>/<eos>), not the number of samples.
        if max(len(intro.split(' ')), len(outro.split(' '))) + 2 > max_len:
            continue
        translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field,
                                                 solo_field, device, max_length=max_len)
        translated_sentence = [int(tok) for tok in translated_sentence if tok not in special_tokens]
        print(translated_sentence)
        # Compare with the previous *decoded* sample so skipped examples don't misalign.
        if previous is not None and translated_sentence == previous:
            count += 1
        previous = translated_sentence
    return count


In [18]:
N_EPOCHS = 500
S_EPOCH = 0
CLIP = 1
train_loss_log = []
valid_loss_log = []
best_valid_loss = float('inf')


def _current_checkpoint(valid_loss):
    # Build a fresh dict every epoch: state_dict() tensors alias the live parameters, so
    # reusing the dict from the last improvement would pair the *current* weights with a
    # stale 'valid_loss' in the periodic/final checkpoints.
    return {'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'valid_loss': valid_loss}


for epoch in range(S_EPOCH, N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    train_loss_log.append(train_loss)
    valid_loss_log.append(valid_loss)
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    checkpoint = _current_checkpoint(valid_loss)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        save_best_checkpoint(checkpoint, N_EPOCHS)  # written to disk immediately (metrics.pt)
    # Periodic snapshots of the current model at epochs 0, 19, 20, 39, 40, ...
    if (epoch + 1) % 20 == 0 or epoch % 20 == 0:
        save_final_checkpoint(checkpoint, epoch)
    if (epoch + 1) % 25 == 0:
        if check_mode_collapse(model) > 1:
            print("model is mode collapsing")

# Weights after the last epoch (the best-validation weights are in metrics.pt).
save_final_checkpoint(checkpoint, N_EPOCHS)
# Note: this evaluates the final-epoch weights, not the best-validation ones.
test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 1m 1s
	Train Loss: 4.291 | Train PPL:  73.070
	 Val. Loss: 3.421 |  Val. PPL:  30.606
=> Saving checkpoint
=> Saving checkpoint
Epoch: 02 | Time: 1m 1s
	Train Loss: 3.022 | Train PPL:  20.531
	 Val. Loss: 2.845 |  Val. PPL:  17.201
=> Saving checkpoint
Epoch: 03 | Time: 1m 2s
	Train Loss: 2.697 | Train PPL:  14.837
	 Val. Loss: 2.637 |  Val. PPL:  13.970
=> Saving checkpoint
Epoch: 04 | Time: 1m 2s
	Train Loss: 2.523 | Train PPL:  12.462
	 Val. Loss: 2.480 |  Val. PPL:  11.935
=> Saving checkpoint
Epoch: 05 | Time: 1m 2s
	Train Loss: 2.357 | Train PPL:  10.554
	 Val. Loss: 2.340 |  Val. PPL:  10.378
=> Saving checkpoint
Epoch: 06 | Time: 1m 2s
	Train Loss: 2.200 | Train PPL:   9.024
	 Val. Loss: 2.156 |  Val. PPL:   8.640
=> Saving checkpoint
Epoch: 07 | Time: 1m 3s
	Train Loss: 2.041 | Train PPL:   7.702
	 Val. Loss: 2.058 |  Val. PPL:   7.829
=> Saving checkpoint
Epoch: 08 | Time: 1m 2s
	Train Loss: 1.947 | Train PPL:   7.005
	 Val. Loss: 2.030 |  Val. PPL:   7.615


[0, 1, 43, 114, 1, 53, 58, 31, 4, 53, 58, 31, 8, 53, 58, 31, 10, 53, 58, 34, 13, 53, 58, 34, 17, 53, 58, 34, 0, 1, 54, 58, 42, 10, 54, 58, 34, 16, 32, 58, 31, 17, 35, 58, 34, 0, 1, 51, 58, 28, 10, 54, 58, 41, 13, 54, 58, 41, 17, 35, 58, 34, 0, 1, 35, 58, 135, 0, 1, 35, 58, 31, 4, 54, 58, 37, 17, 51, 58, 34, 0, 1, 51, 58, 108, 0, 1, 53, 58, 108, 8, 35, 58, 31, 10, 51, 58, 31, 13, 54, 61, 34, 17, 51, 58, 34, 0, 1, 51, 58, 19]
[0, 1, 2, 175, 1, 53, 9, 15, 67, 32, 9, 22, 4, 51, 9, 52, 72, 32, 9, 80, 74, 51, 9, 22, 0, 1, 45, 9, 22, 67, 11, 9, 22, 4, 64, 9, 52, 72, 32, 9, 80, 74, 32, 9, 22, 0, 1, 51, 9, 22, 67, 32, 9, 22, 4, 32, 9, 52, 72, 51, 9, 66, 27, 11, 40, 52, 0, 67, 32, 40, 22, 4, 54, 40, 112]
Epoch: 51 | Time: 1m 3s
	Train Loss: 0.921 | Train PPL:   2.511
	 Val. Loss: 3.271 |  Val. PPL:  26.325
Epoch: 52 | Time: 1m 3s
	Train Loss: 0.905 | Train PPL:   2.472
	 Val. Loss: 3.290 |  Val. PPL:  26.841
Epoch: 53 | Time: 1m 3s
	Train Loss: 0.900 | Train PPL:   2.459
	 Val. Loss: 3.308 |  Va

[0, 1, 43, 117, 4, 123, 73, 34, 4, 123, 97, 34, 10, 45, 73, 31, 10, 45, 97, 31, 13, 45, 62, 31, 13, 45, 76, 31, 16, 63, 73, 31, 16, 63, 97, 31, 17, 71, 76, 31, 17, 71, 97, 31, 27, 49, 76, 34, 27, 49, 120, 34, 0, 4, 45, 73, 34, 4, 45, 97, 34, 10, 53, 73, 31, 10, 53, 97, 31, 13, 63, 62, 31, 13, 63, 76, 31, 16, 71, 73, 31, 16, 71, 97, 31, 17, 45, 76, 31, 17, 45, 120, 31, 27, 45, 73, 34, 27, 45, 97, 34, 0, 4, 63, 73, 34, 4, 63, 97, 34, 10, 45, 73, 31, 10, 45, 97, 31, 13, 63, 62, 31, 13, 63, 76, 31, 16, 63, 73, 31, 16, 63, 97, 31, 17, 71, 73, 31, 17, 71, 97, 31, 27, 53, 76, 34, 27, 53, 120, 34, 0, 4, 71, 73, 34, 4, 71, 97, 34, 10, 45, 73, 31, 10, 45, 97, 31, 13, 63, 62, 31, 13, 63, 76, 31, 16, 123, 73, 31, 16, 123, 97, 31, 17, 63, 20, 31, 17, 63, 120, 31, 90, 49, 73, 34, 90, 49, 97, 34]
[0, 1, 2, 89, 67, 51, 21, 66, 91, 53, 9, 15, 13, 51, 73, 22, 70, 51, 21, 66, 74, 51, 73, 22, 0, 1, 53, 9, 22, 67, 64, 9, 59, 74, 35, 9, 15, 0, 1, 54, 73, 15, 67, 35, 9, 22, 4, 53, 21, 77, 74, 53, 9, 15, 0, 1

[0, 1, 2, 95, 4, 11, 48, 31, 8, 11, 58, 31, 10, 38, 57, 31, 13, 38, 58, 31, 16, 5, 48, 34, 27, 25, 46, 52, 0, 1, 54, 48, 7, 4, 38, 100, 80, 16, 24, 128, 37, 0, 4, 38, 100, 34, 10, 32, 48, 34, 16, 32, 75, 36, 0, 4, 5, 46, 85, 16, 32, 75, 34, 27, 32, 46, 34, 0, 67, 11, 48, 83]
[0, 1, 43, 176, 70, 51, 107, 102, 78, 35, 62, 15, 78, 51, 58, 15, 78, 35, 61, 15, 90, 53, 62, 15, 90, 51, 73, 15, 90, 35, 61, 15, 74, 49, 62, 22, 74, 49, 58, 22, 74, 35, 73, 15, 0, 67, 53, 62, 15, 67, 53, 58, 22, 67, 53, 61, 15, 23, 53, 58, 66, 23, 53, 61, 66, 23, 35, 76, 66, 72, 51, 84, 52, 70, 51, 97, 108, 78, 123, 62, 15, 78, 53, 58, 15, 78, 35, 73, 15, 90, 53, 62, 15, 90, 35, 58, 15, 90, 54, 73, 22, 74, 53, 26, 22, 74, 53, 62, 15, 74, 53, 73, 15, 0, 67, 49, 62, 22, 67, 35, 58, 22, 67, 35, 73, 15, 23, 54, 9, 34, 23, 54, 73, 34, 23, 35, 61, 34, 70, 51, 97, 108, 78, 51, 58, 15, 78, 51, 73, 15, 90, 54, 58, 22, 90, 54, 73, 22, 74, 63, 62, 22, 74, 49, 58, 15, 0, 67, 71, 58, 22, 67, 53, 73, 15, 23, 35, 58, 34, 23, 54,

Epoch: 186 | Time: 1m 3s
	Train Loss: 0.375 | Train PPL:   1.455
	 Val. Loss: 5.337 |  Val. PPL: 207.955
Epoch: 187 | Time: 1m 3s
	Train Loss: 0.371 | Train PPL:   1.449
	 Val. Loss: 5.288 |  Val. PPL: 197.956
Epoch: 188 | Time: 1m 3s
	Train Loss: 0.367 | Train PPL:   1.444
	 Val. Loss: 5.354 |  Val. PPL: 211.426
Epoch: 189 | Time: 1m 3s
	Train Loss: 0.366 | Train PPL:   1.443
	 Val. Loss: 5.384 |  Val. PPL: 217.843
Epoch: 190 | Time: 1m 3s
	Train Loss: 0.363 | Train PPL:   1.438
	 Val. Loss: 5.368 |  Val. PPL: 214.468
Epoch: 191 | Time: 1m 3s
	Train Loss: 0.358 | Train PPL:   1.430
	 Val. Loss: 5.492 |  Val. PPL: 242.841
Epoch: 192 | Time: 1m 3s
	Train Loss: 0.357 | Train PPL:   1.429
	 Val. Loss: 5.370 |  Val. PPL: 214.939
Epoch: 193 | Time: 1m 3s
	Train Loss: 0.352 | Train PPL:   1.422
	 Val. Loss: 5.442 |  Val. PPL: 230.902
Epoch: 194 | Time: 1m 3s
	Train Loss: 0.349 | Train PPL:   1.418
	 Val. Loss: 5.429 |  Val. PPL: 227.913
Epoch: 195 | Time: 1m 3s
	Train Loss: 0.346 | Train PPL

[0, 1, 2, 159, 1, 49, 50, 41, 1, 54, 6, 34, 8, 2, 159, 72, 49, 50, 52, 72, 49, 30, 52, 13, 2, 159, 70, 32, 50, 15, 16, 32, 50, 22, 17, 2, 159, 17, 11, 30, 22, 90, 32, 50, 52, 0, 1, 2, 159, 1, 11, 6, 47, 8, 2, 159, 13, 2, 159, 70, 51, 6, 15, 16, 64, 50, 15, 78, 51, 50, 15, 17, 2, 159, 90, 32, 50, 52, 74, 51, 50, 19, 0, 1, 2, 159, 4, 11, 50, 22, 8, 2, 159, 8, 32, 6, 22, 72, 51, 50, 22, 10, 51, 50, 31, 13, 51, 50, 22, 70, 51, 50, 42, 78, 51, 50, 34, 17, 2, 159, 0, 1, 51, 50, 52, 8, 2, 159, 72, 51, 50, 135, 17, 2, 159, 0, 1, 2, 159, 8, 2, 159, 13, 51, 50, 31, 16, 51, 30, 31, 17, 2, 159, 27, 51, 50, 47, 0, 1, 2, 159, 67, 51, 50, 7, 4, 51, 50, 52, 8, 2, 159, 8, 2, 159, 13, 2, 159, 16, 51, 6, 41, 17, 2, 159, 17, 64, 50, 22, 90, 51, 50, 37, 0, 1, 2, 159, 1, 2, 159, 4, 51, 48, 15, 23, 35, 33, 22, 8, 2, 159, 8, 2, 159, 8, 51, 50, 7, 10, 54, 50, 22, 91, 51, 50, 22, 13, 64, 33, 31, 13, 2, 159, 16, 51, 50, 22, 78, 51, 50, 22, 17, 2, 159, 27, 51, 50, 7, 27, 51, 50, 22, 0, 1, 2, 159, 1, 2, 159, 1, 64

Epoch: 267 | Time: 1m 3s
	Train Loss: 0.183 | Train PPL:   1.201
	 Val. Loss: 6.099 |  Val. PPL: 445.459
Epoch: 268 | Time: 1m 4s
	Train Loss: 0.186 | Train PPL:   1.204
	 Val. Loss: 6.107 |  Val. PPL: 448.842
Epoch: 269 | Time: 1m 3s
	Train Loss: 0.183 | Train PPL:   1.201
	 Val. Loss: 6.039 |  Val. PPL: 419.353
Epoch: 270 | Time: 1m 4s
	Train Loss: 0.180 | Train PPL:   1.197
	 Val. Loss: 6.174 |  Val. PPL: 480.118
Epoch: 271 | Time: 1m 3s
	Train Loss: 0.179 | Train PPL:   1.196
	 Val. Loss: 6.092 |  Val. PPL: 442.455
Epoch: 272 | Time: 1m 3s
	Train Loss: 0.177 | Train PPL:   1.194
	 Val. Loss: 6.120 |  Val. PPL: 455.078
Epoch: 273 | Time: 1m 4s
	Train Loss: 0.178 | Train PPL:   1.195
	 Val. Loss: 6.116 |  Val. PPL: 453.248
Epoch: 274 | Time: 1m 3s
	Train Loss: 0.174 | Train PPL:   1.190
	 Val. Loss: 6.189 |  Val. PPL: 487.484
Epoch: 275 | Time: 1m 4s
	Train Loss: 0.176 | Train PPL:   1.193
	 Val. Loss: 6.146 |  Val. PPL: 466.710
[0, 1, 2, 175, 10, 25, 127, 22, 91, 11, 127, 22, 13, 5,

Epoch: 301 | Time: 1m 3s
	Train Loss: 0.144 | Train PPL:   1.155
	 Val. Loss: 6.451 |  Val. PPL: 633.057
=> Saving checkpoint
Epoch: 302 | Time: 1m 3s
	Train Loss: 0.143 | Train PPL:   1.154
	 Val. Loss: 6.430 |  Val. PPL: 620.417
Epoch: 303 | Time: 1m 3s
	Train Loss: 0.143 | Train PPL:   1.154
	 Val. Loss: 6.480 |  Val. PPL: 651.843
Epoch: 304 | Time: 1m 3s
	Train Loss: 0.141 | Train PPL:   1.151
	 Val. Loss: 6.503 |  Val. PPL: 666.956
Epoch: 305 | Time: 1m 3s
	Train Loss: 0.141 | Train PPL:   1.151
	 Val. Loss: 6.531 |  Val. PPL: 686.056
Epoch: 306 | Time: 1m 3s
	Train Loss: 0.142 | Train PPL:   1.152
	 Val. Loss: 6.427 |  Val. PPL: 618.374
Epoch: 307 | Time: 1m 4s
	Train Loss: 0.139 | Train PPL:   1.149
	 Val. Loss: 6.522 |  Val. PPL: 679.839
Epoch: 308 | Time: 1m 3s
	Train Loss: 0.138 | Train PPL:   1.148
	 Val. Loss: 6.483 |  Val. PPL: 653.699
Epoch: 309 | Time: 1m 3s
	Train Loss: 0.137 | Train PPL:   1.147
	 Val. Loss: 6.495 |  Val. PPL: 661.502
Epoch: 310 | Time: 1m 3s
	Train Lo

[0, 1, 43, 177, 23, 54, 12, 15, 8, 45, 12, 7, 72, 64, 86, 7, 10, 51, 86, 7, 91, 18, 96, 31, 70, 35, 96, 15, 78, 54, 98, 15, 17, 49, 39, 7, 90, 64, 107, 7, 27, 35, 39, 7, 74, 24, 39, 7, 0, 67, 53, 98, 31, 23, 35, 98, 108, 13, 32, 84, 15, 70, 64, 96, 15, 16, 54, 86, 22, 78, 32, 96, 66, 0, 72, 35, 96, 31, 91, 54, 96, 31, 70, 38, 82, 31, 78, 53, 96, 52, 27, 11, 96, 15, 74, 64, 84, 31, 0, 67, 53, 96, 31, 23, 35, 96, 56, 27, 51, 86, 22, 74, 32, 96, 15, 0, 1, 64, 84, 15, 67, 64, 96, 15, 4, 35, 86, 22, 23, 54, 86, 19, 72, 71, 96, 7, 91, 32, 86, 31, 70, 51, 96, 52, 78, 54, 86, 15, 90, 35, 86, 52, 74, 35, 86, 31, 0, 67, 54, 57, 31, 23, 54, 12, 31, 72, 49, 12, 31, 91, 35, 86, 31, 70, 35, 12, 52, 70, 35, 96, 52, 78, 35, 96, 31, 90, 54, 86, 31, 74, 35, 86, 31, 0, 67, 35, 12, 31, 23, 35, 57, 52, 72, 54, 12, 52, 70, 35, 57, 19, 78, 54, 12, 34, 74, 35, 57, 19]
[0, 1, 2, 159, 1, 49, 50, 41, 1, 54, 6, 34, 8, 2, 159, 72, 49, 33, 52, 72, 49, 30, 52, 13, 2, 159, 70, 32, 33, 15, 16, 32, 33, 22, 17, 2, 159, 

Epoch: 386 | Time: 1m 4s
	Train Loss: 0.090 | Train PPL:   1.094
	 Val. Loss: 7.039 |  Val. PPL: 1140.398
Epoch: 387 | Time: 1m 3s
	Train Loss: 0.089 | Train PPL:   1.093
	 Val. Loss: 7.082 |  Val. PPL: 1190.181
Epoch: 388 | Time: 1m 4s
	Train Loss: 0.089 | Train PPL:   1.094
	 Val. Loss: 7.033 |  Val. PPL: 1133.396
Epoch: 389 | Time: 1m 3s
	Train Loss: 0.088 | Train PPL:   1.092
	 Val. Loss: 6.966 |  Val. PPL: 1060.236
Epoch: 390 | Time: 1m 3s
	Train Loss: 0.089 | Train PPL:   1.093
	 Val. Loss: 7.052 |  Val. PPL: 1155.281
Epoch: 391 | Time: 1m 4s
	Train Loss: 0.087 | Train PPL:   1.091
	 Val. Loss: 7.071 |  Val. PPL: 1177.671
Epoch: 392 | Time: 1m 4s
	Train Loss: 0.089 | Train PPL:   1.093
	 Val. Loss: 7.126 |  Val. PPL: 1243.724
Epoch: 393 | Time: 1m 4s
	Train Loss: 0.086 | Train PPL:   1.090
	 Val. Loss: 7.147 |  Val. PPL: 1270.058
Epoch: 394 | Time: 1m 3s
	Train Loss: 0.087 | Train PPL:   1.091
	 Val. Loss: 7.117 |  Val. PPL: 1232.880
Epoch: 395 | Time: 1m 4s
	Train Loss: 0.084 | 

[0, 1, 2, 183, 1, 64, 12, 19, 4, 63, 98, 31, 23, 35, 9, 15, 8, 2, 183, 8, 54, 57, 19, 10, 71, 30, 31, 91, 123, 9, 15, 13, 2, 183, 13, 54, 12, 19, 16, 54, 50, 7, 78, 35, 9, 15, 17, 2, 183, 17, 54, 57, 7, 90, 32, 12, 7, 27, 32, 55, 31, 74, 35, 6, 15, 0, 1, 2, 183, 1, 54, 57, 41, 67, 32, 96, 136, 4, 64, 30, 31, 23, 64, 9, 15, 8, 2, 183, 8, 51, 50, 52, 10, 49, 50, 31, 91, 32, 9, 15, 13, 2, 183, 13, 54, 57, 52, 16, 38, 30, 7, 78, 64, 9, 15, 17, 2, 183, 17, 51, 50, 31, 90, 51, 6, 15, 27, 54, 50, 31, 74, 32, 9, 15, 0, 1, 2, 183, 1, 64, 57, 41, 1, 64, 96, 136, 4, 32, 30, 7, 23, 71, 9, 15, 8, 2, 183, 8, 53, 57, 31, 10, 51, 30, 7, 91, 51, 9, 15, 13, 64, 50, 31, 16, 32, 9, 31, 78, 32, 9, 31, 17, 53, 50, 31, 90, 63, 6, 31, 27, 11, 96, 15, 74, 64, 9, 52, 0, 1, 64, 9, 108, 0, 1, 53, 50, 7, 67, 45, 6, 7, 4, 53, 57, 31, 23, 51, 50, 7, 72, 63, 6, 7, 10, 64, 57, 7, 91, 64, 9, 15, 13, 2, 183, 70, 35, 62, 15, 16, 35, 50, 15, 78, 35, 48, 7, 17, 64, 50, 41, 17, 64, 50, 41, 17, 35, 50, 52, 27, 35, 50, 52, 74

Epoch: 469 | Time: 1m 4s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 7.495 |  Val. PPL: 1799.539
Epoch: 470 | Time: 1m 3s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 7.476 |  Val. PPL: 1765.674
Epoch: 471 | Time: 1m 4s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 7.473 |  Val. PPL: 1760.745
Epoch: 472 | Time: 1m 3s
	Train Loss: 0.063 | Train PPL:   1.065
	 Val. Loss: 7.488 |  Val. PPL: 1787.115
Epoch: 473 | Time: 1m 3s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 7.550 |  Val. PPL: 1900.891
Epoch: 474 | Time: 1m 3s
	Train Loss: 0.063 | Train PPL:   1.065
	 Val. Loss: 7.390 |  Val. PPL: 1620.504
Epoch: 475 | Time: 1m 4s
	Train Loss: 0.064 | Train PPL:   1.066
	 Val. Loss: 7.488 |  Val. PPL: 1785.596
[0, 1, 2, 154, 23, 64, 93, 31, 72, 11, 128, 31, 91, 79, 55, 42, 27, 5, 55, 15, 74, 11, 75, 41, 0, 10, 11, 131, 15, 91, 79, 55, 42, 17, 64, 104, 22, 90, 25, 46, 15, 27, 18, 75, 47, 0, 10, 11, 104, 15, 91, 11, 100, 22, 13, 53, 100, 22, 70, 11, 100, 22, 16, 51,

| Test Loss: 7.858 | Test PPL: 2585.684 |


In [19]:
# Persist the per-epoch loss curves inside the run folder. The path previously lacked a
# "/" (writing "<folder>train_loss_log.pkl" next to the folder), the second open() line
# was garbled, and the handles were closed manually.
with open(folder + "/train_loss_log.pkl", 'wb') as output:
    pickle.dump(train_loss_log, output)
with open(folder + "/valid_loss_log.pkl", 'wb') as output:
    pickle.dump(valid_loss_log, output)

In [None]:
## checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [None]:
# Fresh model with the training hyper-parameters, used to load a saved checkpoint.
# Transformer's arguments are positional, so they must follow __init__'s order exactly —
# including src2_vocab_size (omitting it shifts every later argument and leaves
# `device` unfilled -> TypeError).
best_model = Transformer(
    embedding_size,
    src_vocab_size,
    src2_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
).to(device)
# The lr given here is replaced when a saved optimizer state_dict is loaded into it.
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

In [39]:
# Load the checkpoint written after the last training epoch into the freshly built
# best_model (the `optimizer` above is bound to best_model's parameters), then make it
# the `model` that the evaluation / generation cells below use.
state = torch.load(destination_folder + f'/{N_EPOCHS}_checkpoint.pt', map_location=device)
load_checkpoint(state, best_model, optimizer)
model = best_model

=> Loading checkpoint


RuntimeError: Error(s) in loading state_dict for Transformer:
	Missing key(s) in state_dict: "src2_word_embedding.weight", "src2_position_embedding.weight", "transformer.encoder2.layers.0.self_attn.in_proj_weight", "transformer.encoder2.layers.0.self_attn.in_proj_bias", "transformer.encoder2.layers.0.self_attn.out_proj.weight", "transformer.encoder2.layers.0.self_attn.out_proj.bias", "transformer.encoder2.layers.0.linear1.weight", "transformer.encoder2.layers.0.linear1.bias", "transformer.encoder2.layers.0.linear2.weight", "transformer.encoder2.layers.0.linear2.bias", "transformer.encoder2.layers.0.norm1.weight", "transformer.encoder2.layers.0.norm1.bias", "transformer.encoder2.layers.0.norm2.weight", "transformer.encoder2.layers.0.norm2.bias", "transformer.encoder2.layers.1.self_attn.in_proj_weight", "transformer.encoder2.layers.1.self_attn.in_proj_bias", "transformer.encoder2.layers.1.self_attn.out_proj.weight", "transformer.encoder2.layers.1.self_attn.out_proj.bias", "transformer.encoder2.layers.1.linear1.weight", "transformer.encoder2.layers.1.linear1.bias", "transformer.encoder2.layers.1.linear2.weight", "transformer.encoder2.layers.1.linear2.bias", "transformer.encoder2.layers.1.norm1.weight", "transformer.encoder2.layers.1.norm1.bias", "transformer.encoder2.layers.1.norm2.weight", "transformer.encoder2.layers.1.norm2.bias", "transformer.encoder2.layers.2.self_attn.in_proj_weight", "transformer.encoder2.layers.2.self_attn.in_proj_bias", "transformer.encoder2.layers.2.self_attn.out_proj.weight", "transformer.encoder2.layers.2.self_attn.out_proj.bias", "transformer.encoder2.layers.2.linear1.weight", "transformer.encoder2.layers.2.linear1.bias", "transformer.encoder2.layers.2.linear2.weight", "transformer.encoder2.layers.2.linear2.bias", "transformer.encoder2.layers.2.norm1.weight", "transformer.encoder2.layers.2.norm1.bias", "transformer.encoder2.layers.2.norm2.weight", "transformer.encoder2.layers.2.norm2.bias", "transformer.encoder2.norm.weight", "transformer.encoder2.norm.bias", 
"transformer.decoder.layers.0.norm4.weight", "transformer.decoder.layers.0.norm4.bias", "transformer.decoder.layers.1.norm4.weight", "transformer.decoder.layers.1.norm4.bias", "transformer.decoder.layers.2.norm4.weight", "transformer.decoder.layers.2.norm4.bias". 

In [38]:
# Test-set perplexity of the currently loaded `model`.
test_loss = evaluate(model, test_iter, criterion)
test_ppl = math.exp(test_loss)
print(test_ppl)

266468.95956058794


In [20]:
# Redirect generated samples to a separate "_temp" folder and create its sub-folders.
generated_outputs = f"{folder}/generated_samples_temp"
for _part in ("intro", "outro", "solo", "predict"):
    Path(f"{generated_outputs}/{_part}").mkdir(parents=True, exist_ok=True)

In [None]:
# Raw test-set strings, decoded and written to MIDI in the next cell.
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro, test_solo, test_outro = (df_intro[col].values for col in ('intro', 'solo', 'outro'))
test_data = [
    {'intro': intro_seq, 'solo': solo_seq, 'outro': outro_seq}
    for intro_seq, solo_seq, outro_seq in zip(test_intro, test_solo, test_outro)
]
print(len(test_intro))

In [None]:
# Decode a solo for every test example and write intro / ground-truth solo / outro /
# prediction as MIDI files named <part>/<part><i>.mid (i = 0-based test index).
_special_tokens = ('<pad>', '<sos>', '<eos>', '<unk>')
for i, (intro, solo, outro) in enumerate(zip(test_intro, test_solo, test_outro)):
    list_intro = [int(x) for x in intro.split(' ')]
    list_solo = [int(x) for x in solo.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    # Inputs longer than the learned position table (max_len, incl. <sos>/<eos>) can't be embedded.
    if max(len(list_intro), len(list_outro)) + 2 > max_len:
        print(f"skipping {i}: intro/outro longer than max_len={max_len}")
        continue
    translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field,
                                             solo_field, device, max_length=1200)
    translated_sentence = [int(x) for x in translated_sentence if x not in _special_tokens]
    print(translated_sentence)
    for part, tokens in (("intro", list_intro), ("solo", list_solo),
                         ("outro", list_outro), ("predict", translated_sentence)):
        utils.write_midi(tokens, word2event, f"{generated_outputs}/{part}/{part}{i}.mid")
    print(i)
        


In [None]:
import mido


def _merge_with_outro(intro, middle, middle_name, outro_path, out_path):
    """Save ``intro`` + ``middle`` + the outro read from ``outro_path`` as one MIDI file.

    Each section stays on its own track. The middle section is delayed by the
    intro length, and the outro by intro + middle, by adjusting the delta time of
    the event at index 1 of each section's track 1. Section lengths are the summed
    delta times of ``note_on`` messages only.

    ``middle`` is modified in place: it is shifted and renamed to ``middle_name``.
    The outro is re-read from disk on every call, so repeated merges never stack
    shifts. The expected and applied outro offsets (in ticks) are printed.
    """
    intro_ticks = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    middle_ticks = sum(msg.time for msg in middle.tracks[1] if msg.type == "note_on")
    outro = mido.MidiFile(outro_path)
    original_outro_time = outro.tracks[1][1].time

    print(original_outro_time + middle_ticks + intro_ticks)
    middle.tracks[1][1].time += intro_ticks
    outro.tracks[1][1].time = original_outro_time + middle_ticks + intro_ticks
    print(outro.tracks[1][1].time)

    # Rename only after shifting. mido's ``MidiTrack.name`` setter inserts a
    # track_name message at index 0 when the track has none, which would change
    # which event ``tracks[1][1]`` refers to.
    middle.tracks[1].name = middle_name
    outro.tracks[1].name = "outro"

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + middle.tracks + outro.tracks
    merged_mid.save(out_path)


# Only the first few examples are merged for listening. Cap at the number of test
# examples so a short split cannot reference files that were never written.
n_merge = min(11, len(test_intro))
for i in range(n_merge):
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" + '/solo' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" + '/predict' + str(i) + '.mid')
    outro_path = generated_outputs + "/outro/" + '/outro' + str(i) + '.mid'
    intro.tracks[1].name = "intro"

    # Ground truth: intro + real solo + outro.
    _merge_with_outro(intro, solo, "solo", outro_path,
                      generated_outputs + '/merged' + str(i) + '.mid')
    # Model output: intro + predicted solo + a fresh copy of the outro.
    _merge_with_outro(intro, predict, "predict", outro_path,
                      generated_outputs + '/merged_predict' + str(i) + '.mid')

In [None]:
# dissimilar_interpolation
# Dissimilar interpolation: pair each test intro with the outro of a *different*
# example (DISSIMILAR_OFFSET positions later) and let the model generate the solo
# that bridges them.
DISSIMILAR_OFFSET = 3
n_examples = len(test_intro)
for i in range(n_examples):
    intro = test_intro[i]
    # Wrap around rather than falling back to the example's own outro. Otherwise the
    # last DISSIMILAR_OFFSET examples would silently be same-piece reconstructions.
    outro = test_outro[(i + DISSIMILAR_OFFSET) % n_examples]
    list_intro = [int(x) for x in intro.split(' ')]
    list_outro = [int(x) for x in outro.split(' ')]
    translated_sentence = translate_sentence(model, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)
    # Keep only real event tokens; markers and padding cannot be rendered to MIDI.
    translated_sentence = [int(x) for x in translated_sentence
                           if x not in ('<pad>', '<sos>', '<eos>', '<unk>')]
    print(translated_sentence)
    utils.write_midi(list_intro, word2event, dissimilar_interpolation + "/intro/" + "/intro" + str(i) + ".mid")
    utils.write_midi(list_outro, word2event, dissimilar_interpolation + "/outro/" + "/outro" + str(i) + ".mid")
    utils.write_midi(translated_sentence, word2event, dissimilar_interpolation + "/predict/" + "/predict" + str(i) + ".mid")
    print(i)


In [None]:
import mido

# Stitch each dissimilar pair (intro + generated bridge + foreign outro) into one
# multi-track MIDI file. Each section sits on its own track, so later sections are
# delayed by shifting the delta time of the event at index 1 of their track 1.
for i in range(len(test_intro)):
    intro = mido.MidiFile(dissimilar_interpolation + "/intro/" + '/intro' + str(i) + '.mid')
    outro = mido.MidiFile(dissimilar_interpolation + "/outro/" + '/outro' + str(i) + '.mid')
    predict = mido.MidiFile(dissimilar_interpolation + "/predict/" + '/predict' + str(i) + '.mid')

    # Section lengths in ticks, measured from note_on delta times only.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")
    original_outro_time = outro.tracks[1][1].time

    print(original_outro_time + total_predict_time + total_intro_time)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = original_outro_time + total_predict_time + total_intro_time
    print(outro.tracks[1][1].time)

    # Label the tracks only after shifting. mido's name setter inserts a track_name
    # message at index 0 when none exists, which would move the event at [1][1].
    intro.tracks[1].name = "intro"
    predict.tracks[1].name = "predict"
    outro.tracks[1].name = "outro"

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(dissimilar_interpolation + '/merged_predict' + str(i) + '.mid')

In [None]:
class BeamSearchNode(object):
    """One partial hypothesis in the beam / best-first search.

    Nodes form a linked list back to the start token through ``prev_node``.
    ``logp`` is the hypothesis' cumulative score, and ``length`` is the number of
    tokens on the path, start token included.
    """

    def __init__(self, prev_node, wid, logp, length):
        # prev_node is None for the root; wid is the token chosen at this step.
        self.prev_node, self.wid = prev_node, wid
        self.logp, self.length = logp, length

    def eval(self):
        """Length-normalised score: cumulative ``logp`` per generated token.

        The 1e-6 term keeps the root (length 1) from dividing by zero.
        """
        generated = self.length - 1 + 1e-6
        return self.logp / float(generated)
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def _beam_backtrack(node):
    """Return the token ids on the path from the root to ``node``, root first."""
    ids = []
    while node is not None:
        ids.append(node.wid.item())
        node = node.prev_node
    ids.reverse()
    return ids


def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    """Decode ``sentence`` with a length-normalised best-first beam search.

    Args:
        model: seq2seq model called as ``model(src, trg)``. Its output is indexed
            as ``(trg_len, batch, vocab)``, matching the greedy decoder's
            ``output.argmax(2)[-1, :]``. NOTE(review): ``translate_sentence`` passes
            both an intro and an outro to the two-encoder model; confirm this
            single-source call matches the model being decoded.
        sentence: space-separated source tokens.
        german: source field (uses ``init_token``, ``eos_token``, ``vocab.stoi``).
        english: target field (uses ``vocab.stoi`` / ``vocab.itos``).
        device: torch device for the input tensors.
        max_length: maximum number of tokens generated after ``<sos>``. Longer
            hypotheses are treated as finished (truncated).
        beam_width: children expanded per node, and finished hypotheses to collect.
        max_dec_steps: expansion budget before giving up.

    Returns:
        The best hypothesis as target-vocabulary strings. It starts with ``<sos>``
        and ends with ``<eos>`` when one was produced. Returns an empty list if the
        search yielded no hypothesis at all.
    """
    import heapq

    # Lower-case to match the vocabulary, then wrap in the source field's markers.
    tokens = [token.lower() for token in sentence.split(' ')]
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)
    src_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(src_indices).unsqueeze(1).to(device)

    eos_token = english.vocab.stoi["<eos>"]
    start = torch.LongTensor([english.vocab.stoi["<sos>"]]).to(device)

    root = BeamSearchNode(prev_node=None, wid=start, logp=0, length=1)
    nodes = [(-root.eval(), id(root), root)]
    end_nodes = []
    n_dec_steps = 0

    # Stop when the step budget is exhausted or there is nothing left to expand.
    while nodes and n_dec_steps <= max_dec_steps:
        score, _, n = heapq.heappop(nodes)

        finished = n.wid.item() == eos_token and n.prev_node is not None
        # length counts <sos>, so length - 1 tokens have been generated so far.
        truncated = n.length - 1 >= max_length
        if finished or truncated:
            end_nodes.append((score, id(n), n))
            if len(end_nodes) >= beam_width:
                break
            continue

        trg_tensor = torch.LongTensor(_beam_backtrack(n)).unsqueeze(1).to(device)
        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        # Score the distribution for the *next* token (last target position), and
        # normalise to log-probabilities so scores can be summed along a path.
        # log_softmax is idempotent, so this is safe even if the model already
        # returns log-probabilities.
        step_log_probs = output[-1, 0].log_softmax(dim=-1)
        topk_log_prob, topk_indexes = torch.topk(step_log_probs, beam_width)

        for new_k in range(beam_width):
            child = BeamSearchNode(prev_node=n,
                                   wid=topk_indexes[new_k].view(1),
                                   logp=n.logp + topk_log_prob[new_k].item(),
                                   length=n.length + 1)
            heapq.heappush(nodes, (-child.eval(), id(child), child))
        n_dec_steps += beam_width

    # No finished hypothesis: fall back to the best open (probably truncated) ones.
    if not end_nodes:
        end_nodes = heapq.nsmallest(beam_width, nodes)
    if not end_nodes:
        return []

    # Heap scores are negated, so the smallest score is the best hypothesis.
    best_node = min(end_nodes, key=lambda entry: entry[0])[2]
    return [english.vocab.itos[idx] for idx in _beam_backtrack(best_node)]


In [None]:
def save_vocab(vocab, path):
    """Pickle ``vocab`` to the file at ``path``.

    The file is opened in a ``with`` block, so it is closed even if pickling raises.
    """
    with open(path, 'wb') as output:
        pickle.dump(vocab, output)

In [None]:
# Persist each field's vocabulary as <vocab>/<section>_vocab.pkl.
for _name, _field in (('intro', intro_field), ('solo', solo_field), ('outro', outro_field)):
    save_vocab(_field.vocab, vocab + '/' + _name + '_vocab.pkl')

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):
    """Greedily decode ``sentence`` into target-vocabulary tokens.

    Args:
        model: seq2seq model called as ``model(src, trg)``. Its output is indexed
            as ``(trg_len, batch, vocab)``.
        sentence: iterable of raw source tokens (ints or strings). They are
            numericalised through ``german.vocab.stoi`` and wrapped in the field's
            init/eos tokens, as ``translate_sentence_beam`` does, rather than being
            fed to the model as if they were already vocabulary indices.
        german: source field (uses ``init_token``, ``eos_token``, ``vocab.stoi``).
        english: target field (uses ``vocab.stoi`` / ``vocab.itos``).
        device: torch device for the input tensors.
        max_length: maximum number of tokens generated after ``<sos>``.

    Returns:
        Target-vocabulary strings starting with ``<sos>``. The list ends with
        ``<eos>`` only if the model produced one within ``max_length`` steps.
    """
    tokens = [str(token).lower() for token in sentence]
    if german.init_token is not None:
        tokens.insert(0, german.init_token)
    if german.eos_token is not None:
        tokens.append(german.eos_token)
    src_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(src_indices).unsqueeze(1).to(device)

    eos_index = english.vocab.stoi["<eos>"]
    outputs = [english.vocab.stoi["<sos>"]]

    for _ in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)
        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        # Greedy choice for the next token (last target position).
        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)
        if best_guess == eos_index:
            break

    return [english.vocab.itos[idx] for idx in outputs]


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    """Corpus BLEU of greedy intro -> solo predictions against the reference solos.

    Args:
        data: iterable of torchtext examples exposing ``intro`` and ``solo`` token lists.
        model: seq2seq model passed to ``bleu_translate_sentence``.
        german: source (intro) field.
        english: target (solo) field.
        device: torch device used for decoding.

    Returns:
        The value of torchtext's ``bleu_score`` over all examples whose intro and
        solo are at most 1200 tokens long.
    """
    special_tokens = {'<pad>', '<sos>', '<eos>', '<unk>'}
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        fields = vars(example)
        src = [int(x) for x in fields["intro"]]
        # Compare as strings: predictions come back as vocabulary strings (itos).
        trg = [str(x) for x in fields["solo"]]

        # Skip pieces longer than the decoder's default max_length.
        if len(trg) > 1200 or len(src) > 1200:
            continue

        prediction = bleu_translate_sentence(model, src, german, english, device)
        # Drop markers wherever they occur. Slicing off the last element kept <sos>
        # and, when decoding stopped at max_length, removed a real token instead of <eos>.
        prediction = [tok for tok in prediction if tok not in special_tokens]

        # bleu_score expects a *list of references* for each candidate.
        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# Greedy decoding of up to 1200 tokens per example makes full-test-set BLEU slow,
# so only examples 1..9 of the test split are scored here.
score = bleu(test[1:10], model, german=intro_field, english=solo_field, device=device)
print("Bleu score {:.2f}".format(score * 100))

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
# Loss curves loaded from metrics.pt (presumably written by the training loop):
# train and validation loss against the logged global step.
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
for _label, _losses in (('Train', train_loss_list), ('Valid', valid_loss_list)):
    plt.plot(global_steps_list, _losses, label=_label)
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns