In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import twoencodertransformer as kk
import pickle
# --- Input locations --------------------------------------------------------
# Directory that holds the train/val/test CSV splits read further down.
source_folder = "solo_generation_dataset_augmented_presplit"
# Directories the pretrained checkpoints are loaded from.
weights_solo = "dynamic_augmented_models/2nec/solo_generation_weights"
weights_mag = "dynamic_mag_models/2nec/solo_generation_weights"

# --- Output locations (all nested under `folder`) ---------------------------
folder = "interpolation_mag"
destination_folder = f"{folder}/solo_generation_weights"
generated_outputs = f"{folder}/generated_samples"
dissimilar_interpolation = f"{folder}/interpolation"
vocab = f"{folder}/vocab"

In [2]:
from pathlib import Path
# Create every output directory up front. `parents=True` builds missing
# ancestors and `exist_ok=True` makes re-running this cell harmless.
_output_dirs = [destination_folder, generated_outputs, dissimilar_interpolation, vocab]
_output_dirs += [generated_outputs + "/" + _sub for _sub in ("intro", "outro", "solo", "predict")]
# The interpolation tree has no "solo" sub-directory.
_output_dirs += [dissimilar_interpolation + "/" + _sub for _sub in ("intro", "outro", "predict")]

for _dir in _output_dirs:
    Path(_dir).mkdir(parents=True, exist_ok=True)

In [3]:
# Load the two token mappings stored together in the pickle file.
# NOTE: pickle.load can execute arbitrary code while unpickling, so only load
# dictionary files that come from a trusted source.
# The context manager closes the file handle, which the previous bare
# `open(...)` call never did.
with open('dictionary_augmented.pkl', 'rb') as _dictionary_file:
    event2word, word2event = pickle.load(_dictionary_file)

In [4]:
# Select the compute device: CUDA device index 1 when CUDA is usable,
# otherwise the CPU. Both the name and the torch.device object are printed.
dev = "cuda:1" if torch.cuda.is_available() else "cpu"
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [9]:
# Fields
# All six CSV columns use the same Field configuration: no custom tokenizer,
# lower-casing, lengths returned alongside the data, batch-first tensors and
# explicit <sos>/<eos> markers.
_field_options = dict(
    tokenize=None,
    lower=True,
    include_lengths=True,
    batch_first=True,
    init_token="<sos>",
    eos_token="<eos>",
)
intro_field = Field(**_field_options)
intro_piano_field = Field(**_field_options)
outro_field = Field(**_field_options)
outro_piano_field = Field(**_field_options)
solo_field = Field(**_field_options)
solo_piano_field = Field(**_field_options)
fields = [
    ('intro', intro_field),
    ('intro_piano', intro_piano_field),
    ('outro', outro_field),
    ('outro_piano', outro_piano_field),
    ('solo', solo_field),
    ('solo_piano', solo_piano_field),
]

# TabularDataset
# The three splits are read from CSV files inside `source_folder`; the header
# row of each file is skipped.
train, valid, test = TabularDataset.splits(
    path=source_folder,
    train='train_torchtext.csv',
    validation='val_torchtext.csv',
    test='test_torchtext.csv',
    format='CSV',
    fields=fields,
    skip_header=True,
)

# Iterators
BATCH_SIZE = 8


def _intro_length(example):
    """Sort key shared by the iterators: number of tokens in the intro column."""
    return len(example.intro)


# Only the training iterator sorts the whole dataset (sort=True); all three
# sort the examples inside each batch.
train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=_intro_length,
                            device=device, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=_intro_length,
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=_intro_length,
                           device=device, sort=False, sort_within_batch=True)

# Vocabulary
# Each field builds its own vocabulary from the training split, keeping every
# token that occurs at least once.
for _field in (intro_field, intro_piano_field, outro_field,
               outro_piano_field, solo_field, solo_piano_field):
    _field.build_vocab(train, min_freq=1)

In [None]:
# Fields
# Two-column variant of the dataset setup ('main' and 'piano').
# NOTE(review): this cell rebinds `fields`, `train`/`valid`/`test`, `BATCH_SIZE`
# and the three iterators that the six-field setup cell also defines, so the
# result depends on which of the two cells ran last — confirm the intended order.
_pair_field_options = dict(
    tokenize=None,
    lower=True,
    include_lengths=True,
    batch_first=True,
    init_token="<sos>",
    eos_token="<eos>",
)
main_field = Field(**_pair_field_options)
piano_field = Field(**_pair_field_options)
fields = [('main', main_field), ('piano', piano_field)]

# TabularDataset
# Same CSV files as the six-field setup, read with the two-field layout.
train, valid, test = TabularDataset.splits(
    path=source_folder,
    train='train_torchtext.csv',
    validation='val_torchtext.csv',
    test='test_torchtext.csv',
    format='CSV',
    fields=fields,
    skip_header=True,
)

# Iterators
BATCH_SIZE = 1


def _main_length(example):
    """Sort key shared by the iterators: number of tokens in the main column."""
    return len(example.main)


train_iter = BucketIterator(train, batch_size=BATCH_SIZE, sort_key=_main_length,
                            device=device, sort=False, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=BATCH_SIZE, sort_key=_main_length,
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=BATCH_SIZE, sort_key=_main_length,
                           device=device, sort=False, sort_within_batch=True)

# Vocabulary
# Both vocabularies come from the training split; every token is kept.
for _field in (main_field, piano_field):
    _field.build_vocab(train, min_freq=1)

In [5]:
import os
# Debugging switches: ask CUDA for blocking (synchronous) kernel launches and
# turn the cuDNN backend off for all following torch operations.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cudnn.enabled = False

In [6]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

In [7]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class Transformer(nn.Module):
    """Sequence-to-sequence model that conditions on two source sequences.

    Both sources and the target get a token embedding plus a learned position
    embedding. The three embedded sequences are handed to ``kk.Transformer``
    (the project-local two-encoder transformer) and its output is projected
    onto the target vocabulary by a linear layer.

    Args:
        embedding_size: Width of every embedding and of the transformer.
        src_vocab_size: Number of token ids of the first source.
        src2_vocab_size: Number of token ids of the second source.
        trg_vocab_size: Number of token ids of the target (size of the logits).
        src_pad_idx: Token id treated as padding in both sources.
        num_heads: Forwarded to ``kk.Transformer``.
        num_encoder_layers: Forwarded to ``kk.Transformer``.
        num_decoder_layers: Forwarded to ``kk.Transformer``.
        forward_expansion: Forwarded to ``kk.Transformer``.
        dropout: Dropout probability applied to the embeddings and forwarded
            to ``kk.Transformer``.
        max_len: Number of rows of the position embeddings, i.e. the longest
            sequence that can be embedded.
        device: Device that position ids and masks are moved to.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        src2_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        super().__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        # Fix: the second source has its own vocabulary, so its embedding must
        # be sized with `src2_vocab_size`. It used to be sized with
        # `src_vocab_size`, which ignored the argument and made any src2 token
        # id >= src_vocab_size an out-of-range lookup.
        # NOTE: a checkpoint saved with the old sizing only loads when both
        # source vocabularies have the same size.
        self.src2_word_embedding = nn.Embedding(src2_vocab_size, embedding_size)
        self.src2_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.transformer = kk.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a boolean padding mask for a source batch.

        Args:
            src: Token ids laid out as (src_len, N).

        Returns:
            Boolean tensor of shape (N, src_len), True where the token equals
            ``src_pad_idx``, moved to ``self.device``.
        """
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def _position_ids(self, seq_length, batch_size):
        """Return a (seq_length, batch_size) tensor whose entry [p, b] is p."""
        return (
            torch.arange(0, seq_length)
            .unsqueeze(1)
            .expand(seq_length, batch_size)
            .to(self.device)
        )

    def forward(self, src, src2, trg):
        """Compute target-vocabulary logits for a batch.

        Args:
            src: First source token ids, shape (src_len, N).
            src2: Second source token ids, shape (src2_len, N).
            trg: Target token ids produced so far, shape (trg_len, N).

        Returns:
            The output of ``kk.Transformer`` passed through ``fc_out``; the
            last dimension has size ``trg_vocab_size`` (the leading dimensions
            are presumably (trg_len, N) — confirm against ``kk.Transformer``).
        """
        src_seq_length, N = src.shape
        src2_seq_length, N = src2.shape
        trg_seq_length, N = trg.shape

        src_positions = self._position_ids(src_seq_length, N)
        src2_positions = self._position_ids(src2_seq_length, N)
        trg_positions = self._position_ids(trg_seq_length, N)

        # Token embedding + position embedding, followed by dropout.
        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        ).to(self.device)
        embed_src2 = self.dropout(
            (self.src2_word_embedding(src2) + self.src2_position_embedding(src2_positions))
        ).to(self.device)
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        ).to(self.device)

        # Both sources are masked with the same padding index.
        src_padding_mask = self.make_src_mask(src)
        src2_padding_mask = self.make_src_mask(src2)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_src2,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            src2_key_padding_mask=src2_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [8]:
#https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/seq2seq_transformer/seq2seq_transformer.py
class MAG_Transformer(nn.Module):
    """Single-source sequence-to-sequence model built on ``nn.Transformer``.

    Source and target each get a token embedding plus a learned position
    embedding; the transformer output is projected onto the target vocabulary.

    Args:
        embedding_size: Width of the embeddings and ``d_model`` of the transformer.
        src_vocab_size: Number of source token ids.
        trg_vocab_size: Number of target token ids (size of the logits).
        src_pad_idx: Token id treated as padding in the source.
        num_heads: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        forward_expansion: Passed as the fifth positional argument of
            ``nn.Transformer`` (see the note in ``__init__``).
        dropout: Dropout probability for the embeddings and the transformer.
        max_len: Number of rows of the position embeddings, i.e. the longest
            sequence that can be embedded.
        device: Device that position ids and masks are moved to.
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expansion,
        dropout,
        max_len,
        device,
    ):
        # Fix: this used to be `super(Transformer, self).__init__()`. This class
        # does not derive from `Transformer`, so that call raised TypeError and
        # the model could never be constructed.
        super().__init__()
        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_position_embedding = nn.Embedding(max_len, embedding_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)
        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        # NOTE(review): the fifth positional argument of nn.Transformer is
        # `dim_feedforward`, so `forward_expansion` is used directly as the
        # feed-forward width rather than as a multiplier of `embedding_size`.
        # Left unchanged so existing checkpoints keep their shapes — confirm
        # that this is intended.
        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expansion,
            dropout,
        )
        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.src_pad_idx = src_pad_idx

    def make_src_mask(self, src):
        """Return a boolean padding mask for a source batch.

        Args:
            src: Token ids laid out as (src_len, N).

        Returns:
            Boolean tensor of shape (N, src_len), True where the token equals
            ``src_pad_idx``, moved to ``self.device``.
        """
        src_mask = src.transpose(0, 1) == self.src_pad_idx

        # (N, src_len)
        return src_mask.to(self.device)

    def forward(self, src, trg):
        """Compute target-vocabulary logits for a batch.

        Args:
            src: Source token ids, shape (src_len, N).
            trg: Target token ids produced so far, shape (trg_len, N).

        Returns:
            Logits of shape (trg_len, N, trg_vocab_size).
        """
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape

        # Position ids: entry [p, b] is p for every batch column b.
        src_positions = (
            torch.arange(0, src_seq_length)
            .unsqueeze(1)
            .expand(src_seq_length, N)
            .to(self.device)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length)
            .unsqueeze(1)
            .expand(trg_seq_length, N)
            .to(self.device)
        )

        # Token embedding + position embedding, followed by dropout.
        embed_src = self.dropout(
            (self.src_word_embedding(src) + self.src_position_embedding(src_positions))
        )
        embed_trg = self.dropout(
            (self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions))
        )

        src_padding_mask = self.make_src_mask(src)
        # Causal mask so each target position only attends to earlier positions.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(
            self.device
        )

        out = self.transformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=src_padding_mask,
            tgt_mask=trg_mask,
        )
        out = self.fc_out(out)
        return out


In [11]:
# ---- Two-encoder model: first source = intro, second source = outro, target = solo ----
src_vocab_size = len(intro_field.vocab)
src2_vocab_size = len(outro_field.vocab)
trg_vocab_size = len(solo_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 1200
forward_expansion = 4
src_pad_idx = 1  # presumably the "<pad>" index of the torchtext vocabularies — TODO confirm

solo_model = Transformer(
    embedding_size,
    src_vocab_size,
    src2_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
# Fix: move the model that was just built. The previous `model.to(device)`
# read a name this cell never defines and threw away the new instance.
solo_model = solo_model.to(device)

# ---- Single-encoder model: source = main, target = piano ----
# Only the vocabulary sizes and `max_len` differ from the settings above.
src_vocab_size = len(main_field.vocab)
trg_vocab_size = len(piano_field.vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3
num_decoder_layers = 3
dropout = 0.10
max_len = 3000
forward_expansion = 4
src_pad_idx = 1  # presumably the "<pad>" index of the torchtext vocabularies — TODO confirm

mag_model = MAG_Transformer(
    embedding_size,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    forward_expansion,
    dropout,
    max_len,
    device,
)
# Fix: same as above — keep the MAG model instead of rebinding to `model`.
mag_model = mag_model.to(device)


RuntimeError: CUDA error: out of memory

In [17]:
def init_weights(m: nn.Module):
    """Re-initialise every parameter reachable from `m` in place.

    Parameters whose name contains 'weight' are drawn from a normal
    distribution (mean 0, std 0.01); all other parameters are set to 0.
    """
    for param_name, tensor in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(tensor.data, 0)
            continue
        nn.init.normal_(tensor.data, mean=0, std=0.01)


# NOTE(review): `model` is not assigned in any visible cell (the instantiation
# cell creates `solo_model` and `mag_model`), so this relies on a `model` left
# over in the kernel — confirm which model is meant.
# Re-initialise the parameters of `model` and all of its sub-modules.
model.apply(init_weights)

# Adam optimizer over the parameters of `model`, learning rate 4e-5.
optimizer = optim.Adam(model.parameters(), lr=4e-5)


def count_parameters(model: nn.Module):
    """Return the total number of elements over all trainable parameters of `model`.

    A parameter counts as trainable when its `requires_grad` flag is set.
    """
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total


# Report the trainable-parameter count of `model` with thousands separators.
# NOTE(review): `model` is not assigned in any visible cell — confirm its origin.
print(f'The model has {count_parameters(model):,} trainable parameters')


def save_best_checkpoint(state, nth, filename="_checkpoint.pt"):
    """Serialise `state` to ``<destination_folder>/metrics.pt``.

    `nth` and `filename` are accepted for signature parity with
    `save_final_checkpoint` but are not used: the target path is fixed, so
    every call overwrites the same file.
    """
    print("=> Saving checkpoint")
    target_path = destination_folder + '/metrics.pt'
    torch.save(state, target_path)

def save_final_checkpoint(state, nth, filename="_checkpoint.pt"):
    """Serialise `state` to ``<destination_folder>/<nth><filename>``.

    With the default `filename`, `nth=500` produces ``500_checkpoint.pt``.
    """
    print("=> Saving checkpoint")
    target_path = destination_folder + "/" + str(nth) + filename
    torch.save(state, target_path)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore `model` and `optimizer` from a checkpoint mapping.

    `checkpoint["model_state_dict"]` is loaded into `model` first, then
    `checkpoint["optimizer_state_dict"]` into `optimizer`.
    """
    print("=> Loading checkpoint")
    restore_plan = (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    )
    for target, state_key in restore_plan:
        target.load_state_dict(checkpoint[state_key])

The model has 11,094,799 trainable parameters


In [18]:
# stoi: maps a token string to its integer index
# intro_field.vocab.stoi
# itos: maps an integer index back to its token string
# intro_field.vocab.itos[4]

In [19]:
# Cross-entropy loss in which targets equal to PAD_IDX are ignored, so they
# contribute nothing to the loss value or to the gradients.
PAD_IDX = 1
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [21]:
def translate_sentence(model, sentence, sentence2, intro, outro, solo, device, max_length=1200):
    """Greedily decode a target sequence from two space-separated source strings.

    Args:
        model: Callable invoked as ``model(src, src2, trg)`` that returns logits
            indexed as (trg_len, batch, vocab).
        sentence: First source, tokens separated by single spaces.
        sentence2: Second source, tokens separated by single spaces.
        intro: Field for `sentence` (supplies init/eos tokens and `vocab.stoi`).
        outro: Field for `sentence2` (same role as `intro`).
        solo: Target field (supplies `vocab.stoi` and `vocab.itos`).
        device: Device the index tensors are moved to.
        max_length: Maximum number of decoding steps.

    Returns:
        The decoded tokens as strings. The leading "<sos>" is kept, and so is
        "<eos>" when decoding stopped on it.
    """
    def _to_column(text, field):
        # Lower-case the tokens, wrap them in the field's init/eos markers,
        # map them to indices and shape the result as a (len, 1) column.
        pieces = [field.init_token]
        pieces.extend(piece.lower() for piece in text.split(' '))
        pieces.append(field.eos_token)
        indices = [field.vocab.stoi[piece] for piece in pieces]
        return torch.LongTensor(indices).unsqueeze(1).to(device)

    sentence_tensor = _to_column(sentence, intro)
    sentence2_tensor = _to_column(sentence2, outro)

    # Decoding starts from <sos>; each step appends the most likely next token.
    outputs = [solo.vocab.stoi["<sos>"]]
    for _ in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, sentence2_tensor, trg_tensor)

        # Prediction for the last target position of the single batch element.
        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == solo.vocab.stoi["<eos>"]:
            break

    return [solo.vocab.itos[idx] for idx in outputs]


In [22]:
# Load the validation split and keep, per row, the intro / solo / outro columns
# both as separate arrays and as a list of dicts.
df_intro = pd.read_csv(source_folder + '/val_torchtext.csv')
val_intro, val_solo, val_outro = (
    df_intro[column].values for column in ('intro', 'solo', 'outro')
)
val_data = [
    {'intro': intro_row, 'solo': solo_row, 'outro': outro_row}
    for intro_row, solo_row, outro_row in zip(val_intro, val_solo, val_outro)
]
print(len(val_intro))

112


In [22]:
# checkpoint = {'model_state_dict': model.state_dict(),
#                   'optimizer_state_dict': optimizer.state_dict(),
#                   'valid_loss': valid_loss}
# save_checkpoint(destination_folder + checkpoint,N_EPOCHS)

In [29]:
# NOTE(review): `outro_transformer` and `intro_transformer` are not defined in
# any visible cell — confirm where they are created.
# The second assignment rebinds `optimizer`, so the Adam instance built for
# `outro_transformer` is discarded and only the `intro_transformer` one remains.
optimizer = optim.Adam(outro_transformer.parameters(), lr=0.001)
optimizer = optim.Adam(intro_transformer.parameters(), lr=0.001)

In [50]:
# Restore both pretrained networks: for each checkpoint, print its top-level
# keys and then load it into the matching network.
# NOTE(review): `solo_transformer` / `mag_transformer` are not assigned in any
# visible cell, and the same `optimizer` receives the optimizer state of both
# checkpoints — confirm that this is intended.
for _checkpoint_path, _network in (
    (weights_solo + '/500_checkpoint.pt', solo_transformer),
    (weights_mag + '/2000_checkpoint.pt', mag_transformer),
):
    state = torch.load(_checkpoint_path, map_location=device)
    for key in state:
        print(key)
    load_checkpoint(state, _network, optimizer)

model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint
model_state_dict
optimizer_state_dict
valid_loss
=> Loading checkpoint


In [20]:
# NOTE(review): `evaluate`, `model` and the `math` module are not defined or
# imported in any visible cell — confirm where they come from.
# Despite the name `test_loss`, the loss is computed on `valid_iter`.
test_loss = evaluate(model, valid_iter, criterion)
# exp(loss); presumably a perplexity if `evaluate` returns a mean cross-entropy — TODO confirm.
print(math.exp(test_loss))

4891.367024264349


In [31]:
# Load the test split and keep, per row, the intro / solo / outro columns both
# as separate arrays and as a list of dicts.
df_intro = pd.read_csv(source_folder + '/test_torchtext.csv')
test_intro, test_solo, test_outro = (
    df_intro[column].values for column in ('intro', 'solo', 'outro')
)
test_data = [
    {'intro': intro_row, 'solo': solo_row, 'outro': outro_row}
    for intro_row, solo_row, outro_row in zip(test_intro, test_solo, test_outro)
]
print(len(test_intro))

112


In [59]:
def translate_sentence_ensemble(model, model2, sentence, sentence2, intro, outro, solo, device, max_length=1200):
    """Greedily decode one sequence, switching from `model` to `model2` part-way.

    While the decoded sequence (including the leading <sos>) is shorter than
    the mean length of the two wrapped source token lists, `model` predicts the
    next token from `sentence`; after that `model2` predicts it from `sentence2`.

    Args:
        model: Callable invoked as ``model(src, trg)``, used for the first part.
        model2: Callable invoked as ``model2(src2, trg)``, used for the rest.
        sentence: First source, tokens separated by single spaces.
        sentence2: Second source, tokens separated by single spaces.
        intro: Field for `sentence` (supplies init/eos tokens and `vocab.stoi`).
        outro: Field for `sentence2` (same role as `intro`).
        solo: Target field (supplies `vocab.stoi` and `vocab.itos`).
        device: Device the index tensors are moved to.
        max_length: Maximum number of decoding steps.

    Returns:
        The decoded tokens as strings. The leading "<sos>" is kept, and so is
        "<eos>" when decoding stopped on it.
    """
    # Lower-cased tokens wrapped in each field's init/eos markers.
    tokens = [intro.init_token] + [piece.lower() for piece in sentence.split(' ')] + [intro.eos_token]
    tokens2 = [outro.init_token] + [piece.lower() for piece in sentence2.split(' ')] + [outro.eos_token]

    # Map tokens to indices and shape them as (len, 1) columns.
    sentence_tensor = torch.LongTensor([intro.vocab.stoi[piece] for piece in tokens]).unsqueeze(1).to(device)
    sentence2_tensor = torch.LongTensor([outro.vocab.stoi[piece] for piece in tokens2]).unsqueeze(1).to(device)

    # Decoded length at which generation is handed over to the second model.
    switch_point = (len(tokens) + len(tokens2)) / 2

    outputs = [solo.vocab.stoi["<sos>"]]
    for _ in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        if len(outputs) < switch_point:
            active_model, active_source = model, sentence_tensor
        else:
            active_model, active_source = model2, sentence2_tensor

        with torch.no_grad():
            logits = active_model(active_source, trg_tensor)

        # Prediction for the last target position of the single batch element.
        best_guess = logits.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == solo.vocab.stoi["<eos>"]:
            break

    return [solo.vocab.itos[idx] for idx in outputs]


In [60]:
# For every test example: generate a prediction with the two-model ensemble,
# then write the intro, reference solo, outro and prediction as MIDI files.
for i, (intro, solo, outro) in enumerate(zip(test_intro, test_solo, test_outro)):
    # The CSV cells hold space-separated integer token ids.
    list_intro = [int(token) for token in intro.split(' ')]
    list_solo = [int(token) for token in solo.split(' ')]
    list_outro = [int(token) for token in outro.split(' ')]

    translated_sentence = translate_sentence_ensemble(
        intro_transformer, outro_transformer, intro, outro,
        intro_field, outro_field, solo_field, device, max_length=1200,
    )
    # Drop the special tokens before converting the prediction back to ints.
    translated_sentence = [
        int(token)
        for token in translated_sentence
        if token not in ('<pad>', '<sos>', '<eos>', '<unk>')
    ]
    print(translated_sentence)

    for _part, _token_ids in (
        ("intro", list_intro),
        ("solo", list_solo),
        ("outro", list_outro),
        ("predict", translated_sentence),
    ):
        # Same paths as before, including the doubled slash after the directory.
        utils.write_midi(_token_ids, word2event, f"{generated_outputs}/{_part}//{_part}{i}.mid")
    print(i)
    
        


[0, 1, 2, 162, 10, 11, 9, 15, 91, 18, 73, 15, 13, 5, 9, 15, 70, 11, 9, 15, 16, 64, 6, 7, 78, 38, 9, 15, 17, 32, 9, 15, 90, 32, 9, 7, 27, 38, 9, 7, 74, 64, 12, 15, 0, 1, 38, 6, 15, 67, 32, 21, 15, 4, 51, 21, 15, 4, 38, 20, 22, 23, 38, 73, 22, 8, 11, 73, 7, 10, 51, 12, 19, 10, 38, 21, 22, 91, 64, 20, 15, 91, 38, 20, 15, 13, 14, 61, 15, 13, 64, 20, 19, 70, 32, 21, 34, 90, 32, 20, 7, 27, 64, 9, 7, 74, 64, 12, 15, 0, 1, 64, 26, 22, 67, 64, 9, 15, 4, 38, 12, 22, 67, 38, 21, 22, 4, 38, 20, 22, 23, 11, 20, 22, 8, 32, 21, 7, 72, 32, 12, 22, 10, 51, 20, 22, 10, 51, 73, 22, 91, 38, 21, 34, 13, 24, 20, 7, 16, 11, 21, 7, 16, 51, 21, 15, 78, 64, 20, 7, 17, 64, 73, 7, 27, 38, 21, 7, 27, 64, 21, 7, 27, 11, 96, 19, 1, 32, 12, 7, 1, 64, 21, 7, 1, 32, 20, 22, 67, 11, 21, 15, 67, 38, 20, 15, 4, 51, 12, 22, 4, 38, 20, 15, 4, 11, 21, 15, 23, 51, 20, 15, 8, 51, 12, 22, 72, 38, 21, 28, 8, 38, 21, 7, 72, 32, 20, 80, 8, 38, 21, 7, 10, 32, 21, 7, 91, 38, 20, 15, 91, 32, 73, 15, 13, 64, 20, 15, 13, 38, 21, 22, 70

[0, 1, 2, 146, 1, 24, 57, 15, 67, 51, 86, 15, 67, 51, 57, 15, 4, 5, 96, 15, 4, 5, 76, 15, 23, 11, 48, 19, 8, 51, 58, 15, 72, 32, 12, 15, 10, 11, 57, 15, 10, 24, 57, 15, 91, 38, 50, 15, 13, 11, 57, 15, 13, 64, 57, 15, 70, 38, 48, 36, 17, 2, 150, 27, 38, 57, 15, 74, 11, 57, 7, 0, 1, 51, 58, 41, 1, 38, 57, 7, 4, 38, 86, 7, 4, 5, 86, 15, 8, 2, 146, 8, 2, 146, 10, 11, 55, 15, 91, 35, 61, 15, 13, 64, 57, 15, 70, 11, 57, 15, 16, 11, 50, 80, 27, 64, 57, 15, 74, 51, 57, 15, 0, 1, 64, 86, 15, 1, 35, 48, 15, 1, 64, 57, 15, 67, 64, 61, 15, 4, 32, 86, 15, 8, 2, 159, 8, 64, 57, 15, 72, 54, 57, 15, 10, 51, 58, 19, 91, 35, 12, 15, 13, 2, 142, 70, 32, 96, 52, 17, 2, 146, 90, 32, 86, 22, 27, 54, 21, 15, 74, 5, 86, 15, 0, 1, 54, 57, 52, 67, 64, 12, 15, 4, 24, 57, 15, 23, 32, 86, 7, 8, 2, 146, 8, 2, 146, 72, 32, 12, 19, 8, 2, 146, 10, 11, 86, 15, 10, 11, 57, 19, 13, 2, 146, 27, 45, 12, 15, 74, 51, 86, 15, 74, 32, 57, 22, 0, 1, 2, 146, 91, 53, 12, 15, 4, 38, 86, 15, 13, 64, 48, 15, 16, 32, 12, 15, 17, 51, 

[0, 1, 43, 161, 4, 35, 58, 52, 8, 45, 58, 15, 10, 64, 62, 31, 13, 64, 46, 52, 16, 64, 48, 52, 17, 49, 48, 7, 27, 64, 58, 31, 0, 1, 64, 57, 52, 4, 35, 62, 52, 8, 64, 48, 52, 10, 64, 46, 31, 13, 54, 58, 31, 16, 53, 62, 52, 17, 64, 50, 31, 27, 64, 46, 31, 0, 1, 64, 48, 31, 4, 53, 46, 52, 8, 64, 48, 31, 10, 64, 48, 31, 13, 32, 58, 31, 16, 35, 62, 52, 17, 32, 62, 31, 27, 54, 46, 52, 0, 1, 38, 58, 31, 4, 64, 57, 52, 8, 64, 50, 7, 8, 64, 62, 66, 17, 64, 48, 31, 27, 35, 50, 31, 0, 1, 64, 48, 41, 4, 51, 50, 34, 8, 11, 62, 31, 10, 32, 48, 31, 13, 32, 62, 31, 16, 51, 48, 31, 17, 64, 58, 7, 27, 32, 58, 31, 0, 1, 54, 58, 34, 17, 32, 58, 34, 0, 1, 64, 48, 41, 8, 64, 58, 31, 27, 51, 62, 66, 8, 64, 50, 19, 13, 64, 48, 41, 17, 32, 48, 15, 0, 1, 38, 58, 31, 10, 64, 58, 66, 8, 32, 48, 31, 16, 64, 58, 52, 13, 51, 57, 52, 17, 64, 58, 31, 27, 38, 48, 7, 0, 1, 32, 58, 31, 4, 64, 46, 31, 8, 54, 58, 31, 10, 64, 48, 34, 10, 53, 48, 31, 13, 51, 46, 7, 16, 54, 17, 51, 46, 41, 17, 64, 46, 41, 17, 35, 100, 41, 0, 1

[0, 1, 2, 142, 8, 2, 142, 8, 38, 76, 7, 10, 51, 73, 31, 13, 2, 142, 13, 32, 86, 15, 16, 51, 21, 31, 17, 2, 142, 17, 54, 9, 31, 27, 35, 61, 19, 0, 1, 2, 142, 8, 2, 142, 8, 51, 58, 34, 8, 51, 97, 7, 10, 35, 73, 31, 13, 2, 142, 17, 2, 142, 27, 35, 62, 7, 27, 32, 84, 15, 74, 38, 84, 15, 0, 1, 51, 84, 34, 8, 2, 142, 8, 53, 84, 31, 10, 54, 40, 31, 13, 51, 62, 85, 0, 8, 2, 142, 13, 2, 142, 16, 51, 62, 15, 78, 53, 97, 31, 17, 2, 142, 17, 51, 84, 34, 17, 51, 26, 7, 27, 35, 20, 15, 27, 51, 73, 85, 0, 1, 2, 142, 1, 2, 142, 1, 2, 142, 8, 54, 73, 31, 10, 51, 20, 31, 10, 51, 96, 31, 13, 2, 142, 13, 63, 62, 7, 16, 35, 84, 15, 78, 35, 76, 31, 17, 64, 86, 31, 17, 2, 142, 17, 2, 142, 17, 64, 86, 15, 27, 63, 84, 15, 74, 49, 84, 31, 0, 1, 35, 76, 15, 1, 2, 142, 1, 2, 142, 1, 53, 61, 31, 8, 64, 84, 15, 72, 64, 96, 15, 10, 53, 84, 15, 91, 49, 20, 15, 91, 45, 76, 15, 13, 2, 142, 13, 38, 84, 15, 13, 2, 142, 17, 2, 142, 17, 51, 20, 15, 90, 35, 86, 31, 27, 35, 20, 28, 0, 1, 2, 142, 1, 2, 142, 8, 2, 142, 8, 35, 

[0, 1, 2, 126, 67, 51, 33, 7, 23, 32, 30, 22, 8, 32, 75, 19, 10, 11, 6, 41, 78, 11, 55, 7, 91, 11, 55, 37, 0, 67, 14, 82, 19, 23, 24, 96, 7, 8, 32, 55, 7, 72, 32, 33, 7, 91, 14, 33, 37, 91, 32, 96, 41, 27, 32, 21, 31, 0, 1, 11, 33, 7, 23, 32, 33, 31, 8, 11, 21, 7, 72, 32, 111, 7, 91, 32, 6, 7, 70, 32, 122, 7, 78, 32, 111, 31, 17, 11, 122, 7, 90, 14, 33, 41, 0, 4, 32, 55, 22, 23, 32, 55, 22, 8, 11, 30, 7, 10, 32, 111, 7, 13, 32, 111, 7, 70, 32, 55, 7, 78, 32, 30, 31, 17, 38, 55, 36, 0, 67, 32, 33, 7, 67, 32, 111, 7, 23, 32, 111, 7, 8, 32, 55, 22, 23, 38, 48, 7, 72, 32, 30, 22, 8, 32, 33, 22, 10, 32, 33, 22, 91, 32, 30, 22, 91, 32, 30, 22, 70, 32, 33, 37, 0, 1, 32, 33, 52, 78, 32, 50, 31, 90, 24, 33, 7, 74, 24, 6, 66, 0, 67, 32, 33, 7, 23, 11, 33, 94, 23, 51, 30, 19, 13, 32, 6, 15, 72, 38, 6, 15, 91, 32, 9, 15, 70, 38, 73, 15, 16, 32, 57, 41, 0, 67, 38, 12, 31, 23, 32, 9, 41, 23, 38, 12, 15, 8, 32, 12, 77, 13, 32, 21, 7, 10, 32, 12, 19, 13, 32, 73, 41, 78, 32, 57, 22, 17, 32, 57, 19, 90,

[0, 1, 43, 110, 8, 2, 110, 10, 63, 62, 7, 13, 2, 110, 78, 45, 58, 7, 17, 45, 58, 31, 17, 123, 61, 15, 17, 45, 58, 7, 27, 45, 58, 7, 8, 5, 58, 7, 4, 45, 57, 7, 5, 84, 7, 8, 5, 58, 7, 8, 123, 61, 7, 8, 5, 62, 7, 10, 63, 86, 56, 17, 63, 50, 76, 15, 27, 45, 61, 31, 0, 1, 2, 110, 4, 45, 58, 34, 8, 2, 110, 10, 123, 58, 7, 13, 63, 48, 15, 70, 123, 86, 15, 16, 63, 61, 15, 17, 123, 62, 7, 17, 123, 57, 15, 90, 35, 61, 31, 27, 45, 58, 15, 27, 63, 86, 31, 0, 1, 2, 110, 4, 45, 61, 15, 8, 123, 57, 15, 8, 63, 61, 31, 10, 45, 61, 15, 10, 63, 62, 15, 13, 63, 62, 7, 16, 63, 46, 15, 17, 63, 58, 15, 17, 45, 58, 15, 27, 45, 58, 15, 74, 51, 61, 15, 0, 1, 45, 58, 15, 1, 53, 73, 15, 4, 45, 58, 31, 8, 5, 61, 15, 8, 71, 48, 31, 10, 71, 57, 31, 13, 45, 57, 36, 13, 5, 61, 15, 16, 63, 58, 15, 16, 63, 73, 15, 17, 5, 61, 15, 17, 53, 61, 15, 27, 45, 73, 15, 74, 45, 57, 31, 0, 1, 45, 58, 15, 17, 45, 58, 7, 27, 45, 62, 7, 1, 63, 61, 15, 4, 123, 86, 31, 8, 71, 86, 31, 13, 63, 61, 15, 27, 71, 61, 15, 27, 45, 61, 15, 74, 

[0, 1, 43, 208, 8, 35, 46, 31, 10, 54, 62, 31, 13, 49, 58, 31, 16, 32, 58, 59, 0, 1, 35, 61, 31, 4, 45, 58, 135, 0, 4, 35, 62, 7, 8, 54, 62, 7, 8, 64, 61, 7, 10, 54, 58, 31, 13, 51, 61, 7, 16, 32, 58, 28, 0, 1, 53, 48, 31, 4, 51, 58, 31, 8, 64, 58, 135, 0, 23, 53, 58, 15, 8, 53, 46, 15, 10, 54, 48, 31, 13, 53, 61, 77, 16, 38, 86, 15, 17, 51, 58, 31, 27, 51, 86, 15, 0, 1, 35, 58, 31, 4, 51, 62, 15, 4, 32, 48, 7, 8, 35, 57, 103, 0, 1, 53, 48, 15, 10, 38, 58, 15, 91, 54, 46, 41]
45
[0, 1, 2, 114, 10, 11, 76, 15, 91, 51, 86, 42, 90, 51, 86, 15, 74, 51, 76, 22, 0, 1, 64, 86, 7, 23, 51, 86, 22, 8, 51, 86, 22, 10, 51, 86, 22, 91, 51, 86, 22, 13, 51, 86, 7, 70, 64, 61, 7, 16, 64, 61, 41, 16, 51, 86, 7, 78, 51, 86, 22, 17, 51, 86, 22, 90, 64, 86, 7, 27, 54, 61, 15, 74, 51, 86, 15, 0, 1, 54, 61, 52, 1, 51, 98, 52, 1, 51, 73, 22, 67, 51, 86, 22, 4, 51, 57, 22, 8, 51, 86, 15, 72, 11, 57, 101, 10, 64, 86, 52, 0, 1, 54, 86, 15, 67, 54, 76, 15, 4, 64, 86, 22, 8, 35, 86, 15, 10, 51, 76, 41, 13, 51, 76

[0, 1, 2, 114, 1, 35, 58, 22, 67, 54, 58, 15, 4, 51, 73, 15, 8, 35, 57, 31, 10, 51, 57, 66, 70, 35, 73, 22, 35, 73, 31, 78, 35, 57, 31, 27, 54, 58, 22, 74, 51, 62, 22, 74, 54, 58, 22, 0, 1, 51, 58, 31, 1, 53, 6, 31, 23, 35, 58, 31, 72, 54, 6, 47, 70, 54, 62, 31, 78, 54, 58, 80, 27, 53, 73, 66, 27, 35, 76, 31, 0, 1, 54, 58, 34, 4, 54, 58, 52, 23, 35, 73, 41, 72, 35, 57, 7, 91, 51, 76, 7, 70, 54, 58, 31, 70, 54, 73, 52, 17, 35, 58, 15, 90, 35, 73, 52, 12, 7, 27, 49, 61, 52, 74, 35, 86, 112, 0, 67, 54, 62, 15, 4, 54, 62, 31, 23, 35, 6, 15, 8, 54, 76, 15, 8, 32, 58, 52, 10, 54, 73, 41, 13, 64, 61, 31, 70, 64, 62, 31, 74, 35, 73, 103, 0, 23, 49, 73, 31, 10, 35, 61, 15, 91, 35, 73, 52, 70, 35, 58, 15, 16, 54, 58, 31, 17, 54, 58, 15, 90, 35, 58, 31, 74, 35, 58, 66, 0, 1, 54, 58, 15, 67, 54, 61, 15, 4, 51, 86, 15, 23, 54, 58, 15, 8, 35, 61, 15, 72, 35, 58, 7, 91, 35, 61, 47, 16, 35, 73, 19, 78, 54, 73, 31, 90, 35, 73, 66, 0, 91, 35, 62, 7, 70, 54, 20, 34, 0, 4, 54, 73, 41, 78, 54, 58, 31, 90, 

[0, 1, 2, 148, 67, 51, 9, 15, 4, 51, 9, 15, 8, 35, 73, 31, 10, 51, 9, 7, 91, 51, 73, 31, 16, 54, 21, 15, 78, 51, 21, 31, 17, 51, 20, 7, 17, 51, 21, 15, 90, 35, 20, 34, 0, 67, 51, 40, 19, 67, 51, 40, 31, 23, 51, 40, 15, 8, 51, 40, 31, 72, 51, 82, 52, 10, 49, 20, 31, 91, 49, 73, 52, 70, 51, 20, 31, 78, 51, 40, 56, 0, 67, 51, 20, 31, 23, 35, 20, 31, 72, 35, 73, 19, 70, 51, 21, 7, 78, 51, 21, 31, 90, 35, 82, 19, 0, 67, 35, 20, 7, 23, 54, 20, 15, 8, 32, 118, 31, 72, 35, 40, 31, 91, 51, 118, 7, 70, 51, 118, 7, 16, 32, 97, 7, 78, 51, 118, 31, 90, 35, 40, 31, 74, 54, 82, 15, 0, 67, 64, 40, 7, 67, 51, 97, 15, 4, 51, 40, 31, 23, 51, 40, 52, 72, 51, 118, 19, 91, 35, 97, 34, 70, 51, 118, 42, 78, 51, 118, 34, 90, 53, 40, 36, 0, 23, 54, 82, 41, 72, 54, 118, 41, 0, 23, 51, 9, 7, 72, 51, 118, 7, 91, 51, 120, 22, 70, 51, 73, 7, 78, 51, 97, 34, 0, 67, 51, 97, 135]
59
[0, 1, 43, 191, 23, 18, 58, 7, 72, 25, 57, 7, 91, 18, 57, 15, 13, 32, 57, 34, 78, 38, 76, 7, 90, 5, 86, 34, 74, 25, 86, 31, 0, 1, 25, 86, 

[0, 1, 2, 89, 67, 24, 46, 7, 23, 79, 26, 85, 70, 25, 62, 7, 78, 24, 58, 34, 0, 1, 25, 62, 7, 23, 79, 62, 15, 72, 5, 62, 7, 91, 38, 26, 7, 70, 38, 62, 7, 78, 11, 46, 31, 74, 14, 9, 7, 0, 67, 79, 26, 31, 23, 24, 62, 41, 70, 11, 26, 7, 0, 67, 25, 26, 31, 23, 14, 26, 59, 78, 11, 26, 66, 74, 32, 62, 19, 0, 67, 24, 62, 7, 23, 32, 62, 31, 72, 32, 62, 80, 74, 79, 73, 7, 0, 67, 5, 26, 15, 23, 24, 62, 7, 10, 11, 73, 7, 70, 25, 26, 7, 70, 25, 62, 15, 78, 64, 26, 7, 90, 32, 61, 31, 74, 38, 20, 66, 0, 23, 11, 61, 31, 0, 67, 38, 73, 7, 23, 24, 73, 7, 72, 11, 73, 31, 91, 11, 61, 15, 70, 38, 26, 85, 78, 32, 73, 52, 74, 5, 26, 7, 0, 67, 14, 26, 31, 23, 32, 62, 7]
70
[0, 1, 2, 89, 23, 5, 75, 7, 8, 2, 89, 10, 11, 55, 15, 13, 14, 30, 47, 70, 14, 33, 34, 17, 2, 89, 17, 2, 89, 90, 64, 111, 31, 0, 67, 14, 30, 7, 4, 11, 111, 7, 23, 32, 30, 7, 8, 2, 89, 8, 25, 33, 7, 72, 24, 55, 7, 91, 5, 33, 31, 13, 11, 55, 31, 16, 11, 111, 7, 78, 25, 29, 7, 90, 64, 48, 7, 74, 18, 33, 7, 0, 1, 32, 50, 31, 4, 64, 48, 52, 13, 2

[0, 1, 2, 177, 8, 24, 122, 31, 10, 79, 48, 7, 13, 14, 50, 31, 16, 79, 12, 31, 17, 79, 57, 52, 27, 49, 6, 7, 0, 1, 32, 50, 19, 8, 11, 12, 31, 10, 32, 50, 7, 13, 38, 86, 7, 16, 25, 50, 52, 17, 5, 57, 7, 27, 51, 55, 15, 74, 24, 55, 7, 0, 1, 11, 12, 66, 8, 14, 57, 15, 72, 18, 12, 31, 91, 5, 50, 7, 13, 24, 61, 7, 16, 11, 12, 15, 17, 5, 57, 52, 27, 24, 57, 31, 0, 1, 24, 58, 7, 8, 5, 12, 7, 10, 5, 57, 7, 13, 11, 57, 52, 17, 38, 57, 7, 27, 11, 57, 15, 74, 38, 6, 7, 0, 1, 24, 57, 31, 4, 11, 57, 15, 8, 24, 86, 15, 10, 24, 12, 94, 13, 11, 86, 15, 16, 11, 12, 31, 17, 11, 57, 7, 27, 11, 57, 31, 0, 1, 25, 12, 7, 1, 11, 96, 7, 4, 14, 57, 36, 10, 24, 96, 15, 13, 11, 57, 15, 16, 38, 86, 19, 17, 5, 57, 7, 27, 24, 86, 31, 0, 1, 25, 57, 7, 4, 25, 58, 31, 8, 11, 12, 7, 13, 11, 57, 7, 16, 25, 57, 31, 27, 24, 50, 15]
79
[0, 1, 2, 139, 23, 64, 20, 15, 8, 11, 82, 31, 10, 24, 82, 7, 91, 32, 20, 15, 13, 24, 82, 7, 16, 79, 76, 7, 78, 79, 20, 31, 90, 38, 96, 41, 0, 23, 38, 20, 31, 8, 24, 82, 15, 72, 5, 21, 31, 91,

[0, 1, 2, 69, 72, 25, 12, 31, 91, 79, 12, 7, 70, 18, 12, 7, 78, 38, 12, 31, 90, 79, 21, 31, 74, 79, 12, 34, 0, 23, 5, 12, 7, 72, 79, 96, 77, 0, 91, 79, 12, 15, 13, 25, 12, 15, 70, 25, 21, 7, 78, 25, 96, 7, 90, 38, 12, 41, 0, 67, 5, 21, 31, 23, 25, 12, 31, 72, 25, 12, 7, 91, 38, 12, 31, 70, 5, 9, 7, 78, 5, 6, 7, 90, 18, 12, 77, 0, 67, 5, 12, 34, 72, 11, 12, 7, 91, 5, 12, 7, 13, 25, 12, 7, 70, 5, 21, 7, 78, 11, 12, 7, 90, 79, 12, 37, 0, 67, 5, 82, 15, 70, 5, 96, 7, 23, 5, 21, 22, 16, 5, 20, 7, 90, 79, 21, 28, 0, 67, 25, 96, 66, 72, 5, 12, 7, 91, 5, 12, 31, 70, 18, 9, 7, 70, 25, 6, 7, 78, 5, 21, 7, 90, 5, 82, 7, 27, 24, 39, 7, 0, 67, 24, 98, 41, 72, 38, 82, 15, 10, 54, 96, 15, 91, 25, 40, 60, 70, 38, 98, 31, 78, 25, 39, 37]
90
[0, 1, 43, 208, 1, 32, 55, 31, 4, 25, 50, 31, 8, 25, 50, 135, 0, 23, 53, 57, 31, 8, 53, 57, 28, 10, 51, 57, 7, 10, 32, 6, 31, 13, 49, 57, 94, 0, 8, 64, 50, 42, 16, 32, 12, 31, 17, 25, 12, 52, 27, 14, 48, 31, 0, 1, 25, 6, 31, 4, 24, 57, 31, 8, 11, 50, 34, 13, 45, 48,

[0, 1, 2, 142, 1, 11, 84, 19, 8, 18, 98, 7, 10, 5, 84, 31, 13, 79, 84, 31, 16, 79, 84, 31, 27, 18, 98, 7, 0, 1, 5, 82, 7, 4, 5, 82, 31, 8, 79, 86, 36, 0, 1, 79, 84, 15, 67, 24, 98, 52, 23, 79, 84, 41, 8, 5, 82, 7, 10, 79, 96, 15, 13, 38, 86, 7, 16, 5, 96, 31, 27, 5, 96, 15, 0, 1, 5, 96, 15, 67, 24, 86, 77, 0, 4, 79, 86, 7, 8, 5, 86, 7, 10, 79, 96, 15, 91, 79, 84, 7, 13, 79, 98, 7, 70, 24, 96, 7, 16, 79, 84, 7, 27, 79, 84, 31, 0, 1, 5, 98, 66, 0, 8, 18, 98, 7, 10, 5, 84, 22, 91, 18, 107, 31, 13, 11, 84, 31, 16, 11, 98, 15, 90, 18, 98, 15, 27, 79, 98, 15, 74, 11, 84, 7, 0, 1, 18, 67, 5, 98, 31, 4, 32, 98, 7, 8, 79, 82, 52, 8, 11, 84, 31, 72, 54, 86, 7, 10, 38, 58, 31, 91, 79, 98, 94, 17, 24, 98, 31, 27, 5, 12, 31, 0, 1, 25, 57, 41, 8, 5, 96, 41, 13, 79, 39, 37, 0, 1, 25, 98, 112, 8, 79, 84, 52, 10, 24, 98, 7, 13, 25, 50, 7, 16, 79, 58, 22, 78, 25, 6, 31, 17, 25, 50, 31, 27, 79, 84, 31]
100
[0, 1, 2, 110, 1, 5, 26, 66, 10, 11, 73, 7, 13, 25, 73, 52, 78, 64, 73, 7, 27, 24, 21, 31, 0, 1, 79

[0, 1, 2, 69, 1, 11, 21, 31, 8, 5, 21, 41, 10, 24, 12, 31, 13, 14, 21, 31, 16, 38, 21, 31, 17, 38, 21, 34, 0, 1, 38, 21, 19, 67, 51, 12, 15, 4, 38, 73, 7, 23, 51, 50, 34, 8, 5, 61, 31, 10, 24, 96, 31, 13, 32, 96, 7, 13, 11, 20, 31, 16, 25, 21, 59, 27, 24, 21, 31, 0, 1, 24, 21, 7, 4, 14, 21, 52, 8, 38, 82, 34, 13, 38, 21, 19, 78, 11, 21, 31, 17, 24, 12, 31, 27, 24, 21, 19, 0, 1, 32, 96, 52, 8, 14, 9, 85, 8, 38, 82, 7, 10, 11, 20, 15, 91, 38, 21, 15, 91, 24, 21, 34, 13, 11, 96, 41, 0, 1, 38, 21, 15, 67, 11, 12, 41, 20, 31, 4, 11, 12, 31, 8, 5, 21, 31, 13, 79, 21, 7, 16, 38, 21, 15, 78, 38, 12, 31, 17, 79, 20, 15, 90, 11, 96, 31, 74, 24, 20, 15, 0, 1, 32, 21, 15, 4, 24, 21, 31, 23, 11, 12, 22, 8, 11, 12, 31, 10, 24, 9, 15, 91, 79, 9, 15, 13, 5, 12, 31, 16, 11, 9, 15, 78, 32, 21, 28, 0, 1, 11, 12, 15, 67, 5, 6, 15, 4, 38, 12, 15, 23, 24, 21, 41, 16, 5, 21, 7, 78, 11, 12, 31, 27, 38, 21, 15, 17, 38, 12, 31, 74, 24, 12, 15, 0, 1, 11, 12, 15, 8, 5, 9, 7, 10, 11, 12, 31, 0, 1, 11, 21, 66, 10, 

In [61]:
import mido
for i in range(len(test_intro)):
    # Load the four per-sample MIDI files from the generated-output folders.
    intro = mido.MidiFile(generated_outputs + "/intro/" + '/intro' + str(i) + '.mid')
    solo = mido.MidiFile(generated_outputs + "/solo/" +'/solo' + str(i) + '.mid')
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(generated_outputs + "/predict/" +'/predict' + str(i) + '.mid')

    # Sum the `time` of every note_on message in track 1 of each section.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = sum(msg.time for msg in solo.tracks[1] if msg.type == "note_on")
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    # Remember the outro's own starting offset before it gets shifted.
    # NOTE(review): message index 1 of track 1 is assumed to be the first
    # timed event of each section - confirm against utils.write_midi.
    original_outro_time = outro.tracks[1][1].time

    # --- intro + ground-truth solo + outro ---------------------------------
    outro_start_after_solo = original_outro_time + total_solo_time + total_intro_time
    print(outro_start_after_solo)
    solo.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = outro_start_after_solo
    print(outro.tracks[1][1].time)

    intro.tracks[1].name = "intro"
    solo.tracks[1].name = "solo"
    outro.tracks[1].name = "outro"
    predict.tracks[1].name = "predict"

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + solo.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged' + str(i) + '.mid')

    # --- intro + predicted solo + outro ------------------------------------
    # Reload the outro so the shift applied above does not carry over.
    outro = mido.MidiFile(generated_outputs + "/outro/" +'/outro' + str(i) + '.mid')

    outro_start_after_predict = original_outro_time + total_predict_time + total_intro_time
    print(outro_start_after_predict)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = outro_start_after_predict
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(generated_outputs + '/merged_predict' + str(i) + '.mid')

30480
30480
23460
23460
40200
40200
37500
37500
27720
27720
32820
32820
53340
53340
49080
49080
25800
25800
30480
30480
20400
20400
35640
35640
30900
30900
27540
27540
31680
31680
31440
31440
30720
30720
32220
32220
61500
61500
58080
58080
31380
31380
29640
29640
32940
32940
30960
30960
29280
29280
38400
38400
22260
22260
21420
21420
28560
28560
28200
28200
60660
60660
60060
60060
23460
23460
35700
35700
30300
30300
37380
37380
30840
30840
41700
41700
73260
73260
51900
51900
42360
42360
38520
38520
23100
23100
32340
32340
32760
32760
39780
39780
30060
30060
34680
34680
30720
30720
37620
37620
31200
31200
32580
32580
38160
38160
35460
35460
35400
35400
32940
32940
35400
35400
41100
41100
15840
15840
17580
17580
28740
28740
36120
36120
28800
28800
40380
40380
15840
15840
20700
20700
30960
30960
38520
38520
67560
67560
60900
60900
16560
16560
22080
22080
28740
28740
29160
29160
61920
61920
47400
47400
30000
30000
31200
31200
32400
32400
34980
34980
27720
27720
40500
40500
25920
25920
2892

In [62]:
# dissimilar_interpolation
# Pair each test intro with the outro located three samples further on and
# let the ensemble generate the connecting section between them.
for i in range(0,len(test_intro)):
    intro = test_intro[i]
    # The last three samples have no "i + 3" partner, so they keep their own outro.
    outro = test_outro[i+3] if i + 3 < (len(test_intro)) else test_outro[i]

    # Both sentences are space-separated integer token ids.
    list_intro = list(map(int, intro.split(' ')))
    list_outro = list(map(int, outro.split(' ')))

    raw_tokens = translate_sentence_ensemble(intro_transformer, outro_transformer, intro, outro, intro_field, outro_field, solo_field, device, max_length=1200)
    # Drop the special vocabulary markers before converting back to integer ids.
    translated_sentence = [int(tok) for tok in raw_tokens if tok not in ('<pad>', '<sos>', '<eos>', '<unk>')]
    print(translated_sentence)

    # Write the three sections of sample i as separate MIDI files.
    utils.write_midi(list_intro, word2event, dissimilar_interpolation + "/intro/" + "/intro" + str(i)  + ".mid")
    utils.write_midi(list_outro, word2event, dissimilar_interpolation + "/outro/" + "/outro" + str(i)  + ".mid")
    utils.write_midi(translated_sentence, word2event, dissimilar_interpolation + "/predict/" + "/predict" + str(i)  + ".mid")
    print(i)
        


[0, 1, 2, 162, 10, 11, 40, 15, 91, 11, 73, 15, 13, 11, 20, 15, 70, 38, 20, 22, 16, 38, 12, 15, 16, 64, 82, 22, 78, 64, 20, 22, 17, 32, 20, 22, 17, 38, 61, 15, 90, 32, 20, 7, 27, 11, 20, 34, 0, 1, 51, 120, 15, 67, 51, 21, 15, 4, 38, 20, 15, 23, 32, 39, 15, 8, 11, 97, 15, 72, 11, 82, 19, 10, 64, 20, 15, 91, 64, 20, 15, 13, 11, 20, 15, 91, 32, 61, 22, 70, 11, 20, 15, 16, 51, 21, 22, 16, 11, 20, 15, 78, 64, 12, 15, 17, 38, 20, 7, 27, 11, 21, 15, 27, 64, 20, 15, 74, 64, 21, 7, 0, 1, 51, 61, 19, 1, 51, 21, 15, 1, 38, 21, 15, 67, 11, 21, 34, 4, 38, 12, 15, 4, 38, 9, 22, 4, 32, 20, 34, 8, 11, 9, 22, 23, 11, 20, 15, 8, 64, 12, 22, 8, 32, 26, 7, 10, 64, 9, 7, 13, 51, 9, 41, 8, 24, 21, 7, 13, 51, 73, 7, 16, 64, 9, 19, 90, 54, 20, 22, 27, 64, 21, 7, 74, 51, 19, 0, 1, 11, 73, 22, 1, 51, 73, 15, 1, 38, 21, 7, 67, 38, 21, 7, 4, 11, 21, 15, 23, 64, 96, 15, 8, 38, 12, 22, 8, 38, 12, 7, 8, 38, 21, 22, 72, 32, 21, 7, 10, 38, 21, 15, 10, 32, 12, 7, 91, 38, 20, 22, 91, 32, 20, 7, 13, 51, 20, 15, 13, 51, 73

[0, 1, 2, 146, 1, 51, 58, 15, 67, 51, 57, 15, 4, 5, 61, 15, 23, 32, 86, 15, 8, 51, 21, 15, 8, 24, 86, 15, 72, 35, 57, 15, 10, 38, 96, 15, 91, 64, 57, 41, 13, 51, 86, 22, 67, 38, 61, 19, 17, 64, 96, 15, 4, 32, 86, 15, 8, 51, 96, 47, 90, 64, 86, 15, 74, 64, 96, 15, 0, 1, 51, 96, 15, 67, 64, 86, 15, 4, 11, 98, 34, 8, 24, 57, 15, 8, 38, 86, 15, 72, 32, 96, 15, 10, 51, 86, 34, 91, 32, 96, 15, 13, 51, 86, 15, 70, 32, 96, 36, 17, 2, 146, 27, 11, 86, 15, 74, 11, 76, 7, 0, 1, 51, 86, 15, 67, 51, 86, 15, 4, 64, 96, 31, 23, 54, 96, 15, 8, 2, 146, 8, 2, 146, 8, 2, 146, 72, 32, 96, 19, 13, 2, 159, 1, 51, 96, 15, 70, 11, 98, 15, 16, 64, 96, 15, 78, 11, 98, 15, 17, 2, 150, 90, 11, 98, 15, 27, 11, 98, 15, 74, 14, 84, 15, 0, 1, 2, 159, 67, 32, 98, 15, 4, 32, 98, 15, 8, 2, 159, 23, 11, 86, 15, 8, 2, 148, 72, 64, 98, 15, 10, 11, 98, 15, 91, 54, 96, 15, 13, 64, 98, 15, 13, 2, 159, 17, 18, 98, 15, 74, 24, 96, 15, 0, 1, 2, 142, 4, 64, 84, 15, 8, 64, 96, 15, 72, 5, 84, 15, 91, 11, 96, 15, 13, 54, 96, 22, 13,

[0, 1, 2, 114, 1, 35, 58, 31, 4, 35, 58, 22, 23, 54, 9, 31, 8, 35, 86, 31, 8, 51, 61, 7, 10, 51, 86, 15, 91, 54, 61, 66, 16, 35, 84, 52, 17, 49, 61, 15, 17, 35, 86, 31, 27, 51, 61, 42, 0, 8, 35, 86, 31, 72, 51, 97, 15, 10, 35, 84, 15, 91, 49, 96, 31, 13, 45, 96, 7, 16, 49, 86, 15, 78, 54, 20, 31, 17, 35, 84, 31, 27, 51, 84, 31, 0, 1, 64, 73, 52, 4, 53, 61, 15, 23, 63, 61, 15, 8, 53, 58, 85, 0, 4, 54, 76, 22, 8, 5, 76, 15, 8, 5, 86, 15, 10, 35, 20, 7, 10, 35, 76, 15, 91, 51, 61, 41, 16, 51, 73, 31, 16, 25, 58, 22, 78, 51, 58, 15, 17, 35, 61, 52, 27, 5, 40, 52, 0, 1, 35, 58, 52, 4, 35, 86, 31, 8, 53, 20, 22, 72, 11, 86, 7, 10, 35, 76, 15, 91, 35, 86, 31, 13, 51, 20, 31, 70, 51, 26, 34, 17, 53, 84, 15, 90, 32, 20, 22, 27, 35, 86, 34, 0, 1, 63, 97, 15, 67, 51, 84, 31, 4, 63, 86, 22, 4, 35, 86, 7, 8, 32, 96, 15, 72, 51, 61, 31, 10, 51, 86, 31, 72, 11, 84, 15, 10, 32, 86, 31, 91, 54, 61, 31, 70, 5, 61, 31, 70, 5, 76, 31, 78, 35, 76, 15, 16, 53, 86, 7, 78, 38, 73, 31, 78, 54, 76, 15, 17, 53, 

[0, 1, 43, 69, 8, 35, 46, 7, 72, 51, 58, 7, 10, 51, 58, 31, 13, 2, 69, 70, 64, 73, 7, 78, 51, 58, 31, 17, 35, 73, 19, 0, 1, 64, 58, 66, 8, 51, 61, 7, 72, 54, 61, 31, 91, 32, 58, 31, 70, 51, 61, 31, 13, 2, 69, 78, 51, 62, 41, 17, 35, 62, 31, 90, 32, 62, 52, 27, 35, 73, 52, 0, 1, 64, 58, 52, 4, 51, 62, 34, 4, 51, 62, 31, 8, 38, 76, 31, 10, 35, 76, 31, 91, 54, 76, 31, 13, 64, 73, 31, 16, 54, 61, 7, 17, 35, 73, 31, 17, 51, 76, 31, 27, 51, 61, 31, 0, 1, 51, 61, 31, 8, 54, 61, 34, 10, 54, 62, 31, 13, 35, 62, 31, 13, 35, 20, 31, 16, 51, 76, 31, 17, 51, 62, 28, 0, 1, 51, 86, 34, 8, 51, 76, 7, 10, 51, 61, 31, 13, 51, 86, 52, 17, 51, 61, 31, 27, 35, 73, 31, 0, 1, 64, 62, 41, 8, 35, 62, 41, 72, 71, 61, 66]
23
[0, 1, 2, 148, 4, 38, 50, 22, 8, 2, 148, 8, 11, 57, 52, 13, 11, 12, 31, 16, 32, 12, 19, 17, 2, 148, 17, 2, 148, 27, 51, 12, 31, 0, 1, 64, 57, 15, 4, 32, 86, 7, 8, 2, 148, 8, 2, 148, 13, 64, 86, 15, 16, 51, 57, 7, 17, 2, 148, 17, 14, 57, 15, 27, 32, 12, 52, 0, 1, 32, 12, 7, 4, 11, 61, 15, 8, 

[0, 1, 43, 124, 8, 35, 97, 31, 10, 54, 84, 15, 91, 51, 84, 31, 13, 35, 107, 15, 70, 51, 84, 15, 16, 64, 118, 15, 78, 32, 84, 31, 90, 64, 107, 31, 27, 54, 76, 31, 0, 1, 54, 61, 15, 4, 51, 97, 31, 8, 35, 86, 31, 10, 54, 20, 15, 91, 54, 76, 15, 13, 32, 76, 31, 16, 54, 97, 7, 17, 53, 76, 15, 27, 54, 76, 31, 0, 1, 54, 76, 31, 4, 35, 76, 31, 8, 54, 61, 31, 10, 54, 73, 31, 13, 32, 73, 7, 13, 32, 76, 31, 16, 63, 84, 15, 17, 45, 76, 31, 27, 35, 76, 31, 0, 1, 49, 86, 31, 4, 35, 76, 31, 8, 54, 61, 31, 10, 54, 73, 31, 13, 51, 86, 7, 16, 35, 61, 31, 17, 54, 61, 31, 27, 35, 73, 56, 0, 23, 54, 61, 7, 8, 35, 84, 31, 10, 51, 73, 31, 13, 64, 61, 7, 13, 51, 73, 42, 17, 35, 76, 31, 27, 54, 61, 31, 0, 1, 64, 61, 28, 10, 49, 84, 34, 16, 54, 61, 31, 10, 54, 61, 7, 17, 35, 61, 34, 17, 35, 61, 31, 27, 35, 61, 31, 0, 1, 54, 61, 31, 4, 54, 58, 34, 8, 35, 76, 34, 8, 32, 61, 31, 10, 35, 86, 215, 0, 13, 49, 96, 28, 16, 35, 86, 94]
33
[0, 1, 43, 116, 1, 51, 6, 19, 8, 43, 116, 8, 51, 6, 15, 10, 51, 6, 52, 13, 43, 21,

[0, 1, 43, 191, 23, 54, 111, 15, 72, 45, 100, 15, 10, 51, 104, 22, 91, 11, 33, 15, 13, 32, 93, 7, 70, 54, 104, 34, 90, 32, 93, 31, 70, 14, 104, 7, 78, 35, 29, 34, 0, 1, 54, 75, 52, 72, 32, 75, 94, 72, 49, 111, 31, 13, 45, 75, 15, 70, 32, 111, 31, 78, 35, 75, 34, 27, 51, 29, 7, 0, 72, 49, 111, 31, 10, 24, 93, 7, 91, 49, 92, 15, 13, 32, 33, 31, 70, 11, 111, 15, 70, 11, 33, 7, 78, 35, 33, 34, 27, 63, 29, 31, 0, 1, 32, 29, 15, 67, 32, 33, 52, 4, 32, 30, 15, 8, 51, 75, 22, 72, 64, 55, 31, 10, 51, 30, 31, 91, 51, 57, 7, 70, 51, 50, 31, 70, 51, 12, 52, 90, 53, 55, 31, 74, 51, 57, 19, 0, 67, 53, 55, 41, 70, 51, 57, 34, 91, 35, 55, 31, 17, 51, 58, 31, 90, 51, 57, 19, 0, 67, 63, 29, 41, 67, 51, 100, 34, 23, 51, 55, 28, 70, 64, 55, 15, 90, 51, 30, 31, 27, 64, 48, 41, 0, 67, 51, 30, 7, 23, 51, 6, 15, 72, 51, 57, 52, 70, 64, 30, 15, 17, 64, 55, 15, 90, 35, 30, 15, 90, 64, 55, 15, 27, 51, 55, 31, 0, 1, 51, 29, 52, 67, 64, 55, 31, 67, 49, 30, 31, 4, 51, 30, 34, 23, 54, 55, 31, 72, 35, 55, 34]
40
[0, 

[0, 1, 2, 139, 4, 24, 49, 12, 15, 8, 38, 21, 15, 72, 38, 40, 7, 91, 38, 21, 15, 13, 32, 21, 15, 70, 64, 12, 66, 0, 4, 38, 12, 15, 23, 11, 12, 7, 72, 11, 21, 7, 10, 54, 12, 15, 91, 38, 9, 15, 13, 32, 12, 15, 70, 38, 21, 19, 16, 38, 21, 7, 16, 11, 12, 7, 78, 38, 12, 7, 17, 11, 12, 77, 0, 23, 11, 12, 15, 8, 38, 12, 52, 72, 38, 21, 7, 10, 38, 12, 37, 0, 67, 38, 21, 15, 91, 38, 96, 52, 13, 38, 20, 41, 16, 38, 82, 15, 78, 32, 21, 15, 17, 38, 21, 15, 90, 38, 9, 15, 27, 38, 82, 15, 74, 38, 20, 15, 0, 67, 64, 20, 7, 23, 38, 96, 15, 23, 38, 40, 7, 72, 64, 20, 15, 10, 38, 96, 15, 91, 38, 21, 52, 70, 38, 96, 15, 78, 54, 96, 15, 90, 38, 20, 15, 0, 8, 38, 20, 7, 72, 64, 12, 47, 16, 38, 21, 47, 0, 72, 32, 96, 15]
50
[0, 1, 2, 150, 8, 5, 82, 22, 72, 5, 82, 31, 10, 51, 82, 15, 91, 11, 96, 7, 13, 14, 84, 7, 16, 64, 82, 15, 78, 53, 98, 15, 17, 11, 82, 31, 90, 32, 98, 59, 27, 51, 96, 34, 0, 8, 51, 96, 15, 72, 54, 82, 15, 10, 64, 82, 15, 10, 64, 96, 15, 13, 51, 96, 7, 17, 49, 96, 22, 27, 51, 82, 15, 0, 1, 

[0, 1, 43, 105, 8, 43, 105, 13, 43, 105, 70, 32, 40, 7, 70, 64, 39, 7, 17, 43, 105, 27, 64, 40, 52, 0, 1, 43, 105, 67, 51, 98, 19, 8, 43, 105, 72, 32, 98, 31, 91, 35, 82, 15, 13, 43, 105, 13, 43, 105, 13, 43, 105, 17, 51, 96, 15, 17, 43, 105, 17, 43, 105, 1, 43, 105, 1, 64, 96, 7, 4, 51, 98, 15, 23, 51, 98, 15, 8, 43, 105, 8, 43, 209, 8, 35, 82, 52, 13, 43, 105, 13, 43, 105, 17, 43, 209, 13, 43, 105, 17, 43, 105, 17, 43, 105, 0, 1, 43, 209, 8, 43, 209, 8, 43, 105, 13, 43, 105, 13, 43, 105, 17, 43, 209, 27, 51, 96, 41, 0, 1, 43, 209, 8, 43, 105, 13, 43, 209, 13, 43, 105, 17, 43, 105, 17, 43, 105, 0, 1, 43, 209, 1, 43, 209, 8, 43, 105, 8, 43, 105, 17, 43, 209, 8, 43, 105, 17, 64, 96, 19, 8, 43, 105, 17, 43, 209, 27, 53, 12, 31, 0, 1, 43, 209, 8, 43, 209, 8, 64, 86, 15, 10, 43, 209, 13, 43, 105, 13, 43, 209, 17, 43, 117, 74, 64, 39, 15, 0, 1, 43, 209, 8, 43, 105, 17, 43, 209, 8, 51, 121, 19, 13, 43, 209, 17, 43, 209, 13, 43, 209, 17, 43, 105, 0, 1, 43, 105, 8, 43, 105, 91, 14, 39, 85, 8, 

[0, 1, 2, 181, 1, 5, 46, 7, 4, 11, 46, 22, 67, 54, 48, 22, 67, 14, 46, 22, 4, 53, 48, 22, 23, 5, 62, 15, 23, 14, 48, 22, 8, 5, 58, 15, 8, 11, 46, 66, 13, 11, 62, 15, 70, 11, 58, 15, 16, 11, 62, 22, 16, 24, 48, 19, 17, 14, 58, 15, 90, 11, 58, 7, 74, 38, 62, 15, 0, 1, 5, 100, 15, 67, 24, 58, 15, 4, 11, 58, 15, 23, 11, 58, 15, 8, 14, 58, 15, 72, 24, 48, 22, 8, 24, 48, 36, 91, 5, 58, 7, 16, 32, 46, 58, 15, 78, 11, 58, 15, 17, 11, 58, 15, 90, 32, 62, 15, 27, 11, 48, 15, 0, 1, 24, 48, 7, 4, 11, 58, 15, 67, 24, 48, 31, 23, 11, 62, 15, 8, 11, 57, 15, 72, 5, 62, 15, 10, 11, 58, 15, 13, 14, 58, 15, 70, 32, 48, 47, 0, 4, 11, 46, 7, 8, 5, 76, 15, 72, 14, 62, 31, 91, 51, 73, 15, 13, 32, 58, 15, 16, 32, 58, 28, 70, 32, 73, 15, 16, 11, 58, 34, 0, 1, 5, 57, 15, 17, 11, 73, 15, 90, 32, 73, 22, 27, 11, 73, 15, 74, 11, 61, 15, 0, 1, 38, 21, 15, 67, 11, 61, 22, 4, 11, 20, 22, 23, 11, 21, 15, 8, 18, 61, 31, 10, 11, 73, 22, 91, 51, 73, 15, 13, 11, 21, 34, 16, 11, 21, 15, 78, 38, 61, 15, 17, 5, 73, 22, 17, 1

[0, 1, 2, 130, 10, 11, 20, 31, 13, 32, 21, 52, 16, 18, 20, 31, 17, 24, 20, 31, 27, 24, 82, 31, 0, 1, 25, 82, 103, 0, 1, 32, 20, 34, 8, 11, 82, 7, 10, 32, 20, 31, 13, 11, 20, 42, 0, 4, 38, 20, 34, 8, 5, 20, 31, 10, 11, 21, 31, 13, 38, 20, 34, 17, 25, 21, 31, 27, 14, 21, 31, 0, 1, 25, 20, 41, 0, 67, 64, 20, 7, 4, 24, 73, 31, 8, 49, 73, 7, 8, 5, 21, 15, 10, 14, 73, 77, 0, 4, 38, 40, 7, 8, 5, 20, 31, 10, 32, 82, 31, 13, 25, 21, 66, 27, 25, 20, 31, 0, 4, 32, 21, 31, 8, 14, 82, 15, 72, 24, 20, 15, 10, 64, 21, 34, 16, 32, 20, 31, 17, 38, 82, 34, 27, 11, 96, 7, 0, 4, 11, 73, 52, 8, 11, 21, 7, 10, 11, 21, 7, 13, 11, 20, 31, 16, 24, 21, 34, 17, 14, 21, 31, 27, 32, 20, 15]
76
[0, 1, 2, 175, 10, 53, 98, 7, 91, 11, 88, 7, 13, 32, 144, 15, 70, 32, 222, 80, 74, 11, 98, 7, 0, 1, 45, 39, 31, 4, 11, 84, 22, 23, 32, 10, 11, 98, 31, 8, 49, 82, 22, 91, 32, 39, 7, 70, 24, 98, 77, 0, 72, 14, 88, 77, 0, 67, 24, 88, 7, 10, 32, 98, 41, 10, 64, 98, 7, 13, 51, 98, 41, 13, 11, 107, 66, 70, 11, 39, 7, 0, 67, 32, 39

[0, 1, 2, 159, 8, 5, 33, 31, 10, 5, 26, 42, 0, 4, 5, 9, 31, 8, 25, 9, 31, 10, 51, 9, 7, 13, 25, 6, 31, 16, 79, 9, 31, 17, 38, 9, 31, 27, 18, 6, 31, 0, 1, 5, 9, 15, 67, 24, 6, 31, 4, 25, 6, 47, 8, 14, 26, 52, 10, 18, 26, 31, 13, 5, 6, 31, 16, 14, 26, 31, 17, 5, 40, 31, 27, 5, 6, 7, 0, 1, 14, 6, 31, 4, 14, 21, 28, 13, 14, 21, 31, 16, 11, 9, 31, 17, 25, 21, 31, 27, 38, 12, 15, 74, 38, 12, 31, 0, 1, 5, 12, 31, 4, 5, 73, 31, 8, 5, 73, 59]
86
[0, 1, 2, 3, 4, 5, 48, 7, 8, 5, 48, 7, 10, 11, 48, 31, 13, 25, 46, 7, 16, 79, 62, 31, 17, 38, 48, 7, 27, 5, 48, 31, 0, 1, 38, 62, 31, 4, 5, 58, 7, 8, 14, 46, 52, 10, 25, 46, 52, 13, 5, 62, 7, 16, 25, 62, 31, 17, 11, 48, 7, 27, 11, 48, 28, 0, 10, 24, 48, 31, 13, 24, 75, 52, 16, 5, 62, 7, 17, 64, 46, 7, 27, 25, 62, 31, 0, 1, 11, 46, 7, 4, 32, 26, 34, 8, 11, 48, 36, 10, 5, 46, 31, 13, 5, 62, 7, 16, 11, 62, 31, 17, 24, 48, 7, 27, 11, 62, 7, 0, 1, 38, 62, 7, 4, 11, 58, 31, 8, 25, 48, 22, 10, 38, 62, 52, 10, 24, 62, 22, 13, 38, 73, 7, 16, 11, 62, 31, 17, 5, 5

[0, 1, 2, 113, 67, 25, 62, 7, 23, 79, 62, 103, 90, 79, 58, 19, 74, 24, 62, 7, 0, 67, 18, 62, 15, 4, 64, 26, 15, 23, 79, 62, 56, 0, 67, 18, 26, 15, 4, 79, 26, 15, 23, 18, 46, 31, 91, 5, 46, 41, 70, 32, 46, 36, 0, 67, 11, 46, 7, 4, 11, 46, 15, 8, 11, 62, 31, 91, 11, 26, 7, 13, 25, 46, 15, 16, 11, 26, 15, 0, 1, 54, 75, 7, 67, 25, 46, 15, 23, 38, 75, 15, 72, 11, 46, 80, 90, 79, 46, 7, 0, 67, 25, 26, 7, 8, 5, 46, 41, 10, 79, 20, 15, 91, 18, 61, 108, 0, 1, 5, 61, 7, 4, 5, 76, 125]
96
[0, 1, 2, 113, 4, 11, 50, 7, 8, 38, 57, 7, 10, 32, 57, 7, 13, 11, 57, 31, 16, 5, 57, 36, 0, 1, 11, 58, 31, 4, 51, 58, 34, 8, 38, 50, 34, 10, 11, 57, 31, 13, 38, 58, 37, 0, 1, 64, 57, 15, 4, 32, 58, 31, 8, 24, 57, 15, 72, 11, 50, 7, 10, 38, 48, 7, 13, 32, 58, 7, 16, 24, 58, 34, 27, 32, 57, 31, 0, 1, 32, 58, 52, 4, 24, 57, 31, 8, 11, 58, 31, 10, 11, 57, 36, 0, 1, 79, 86, 7, 4, 5, 58, 31, 8, 24, 57, 31, 10, 38, 61, 52, 13, 79, 57, 31, 16, 5, 57, 31]
97
[0, 1, 43, 208, 1, 32, 62, 31, 4, 25, 57, 184, 0, 4, 11, 61, 31

[0, 1, 2, 69, 72, 24, 61, 7, 10, 24, 20, 77, 27, 24, 20, 37, 0, 72, 24, 61, 41, 8, 38, 20, 41, 10, 25, 20, 41, 13, 14, 20, 103, 90, 38, 82, 102, 0, 8, 2, 134, 72, 24, 61, 37, 0, 10, 79, 76, 7, 13, 25, 20, 41, 16, 24, 76, 7, 17, 11, 61, 15, 90, 25, 73, 19, 27, 24, 73, 7, 0, 1, 79, 20, 7, 4, 25, 20, 15, 23, 24, 20, 77, 8, 18, 20, 7, 8, 11, 40, 7, 10, 64, 84, 41, 70, 5, 84, 77, 0, 4, 5, 20, 7, 8, 25, 97, 7, 10, 11, 20, 31, 13, 25, 97, 136, 0, 4, 38, 84, 15, 23, 24, 76, 36, 8, 5, 97, 7, 16, 14, 120, 7, 8, 25, 40, 31, 10, 5, 107, 41, 13, 25, 40, 7, 16, 25, 97, 7, 17, 25, 97, 52, 27, 14, 97, 37, 0, 8, 38, 120, 34]
107
[0, 1, 2, 197, 10, 35, 26, 15, 91, 35, 9, 15, 13, 64, 62, 52, 13, 51, 62, 19, 78, 64, 58, 19, 90, 51, 26, 19, 0, 67, 64, 26, 7, 4, 32, 20, 31, 8, 64, 9, 34, 8, 64, 26, 7, 10, 64, 62, 31, 91, 51, 9, 15, 13, 35, 20, 15, 70, 32, 20, 19, 17, 64, 26, 52, 27, 51, 9, 52, 0, 67, 32, 9, 52, 8, 2, 142, 8, 64, 26, 31, 10, 25, 9, 31, 13, 35, 6, 31, 70, 51, 21, 52, 78, 35, 9, 31, 17, 35, 9,

In [63]:
import mido
for i in range(len(test_intro)):
    # Load the three per-sample MIDI files from the interpolation folders.
    intro = mido.MidiFile(dissimilar_interpolation + "/intro/" + '/intro' + str(i) + '.mid')
    outro = mido.MidiFile(dissimilar_interpolation + "/outro/" +'/outro' + str(i) + '.mid')
    predict = mido.MidiFile(dissimilar_interpolation + "/predict/" +'/predict' + str(i) + '.mid')

    # Sum the `time` of every note_on message in track 1 of each section.
    total_intro_time = sum(msg.time for msg in intro.tracks[1] if msg.type == "note_on")
    total_solo_time = 0  # no ground-truth solo exists for interpolated pairs
    total_predict_time = sum(msg.time for msg in predict.tracks[1] if msg.type == "note_on")

    # Remember the outro's own starting offset before it gets shifted.
    # NOTE(review): message index 1 of track 1 is assumed to be the first
    # timed event of each section - confirm against utils.write_midi.
    original_outro_time = outro.tracks[1][1].time

    # Shift the prediction behind the intro and the outro behind both.
    outro_start = original_outro_time + total_predict_time + total_intro_time
    print(outro_start)
    predict.tracks[1][1].time += total_intro_time
    outro.tracks[1][1].time = outro_start
    print(outro.tracks[1][1].time)

    merged_mid = mido.MidiFile()
    merged_mid.ticks_per_beat = intro.ticks_per_beat
    merged_mid.tracks = intro.tracks + predict.tracks + outro.tracks
    merged_mid.save(dissimilar_interpolation + '/merged_predict' + str(i) + '.mid')

28560
28560
35760
35760
39060
39060
46320
46320
29400
29400
24960
24960
36600
36600
34980
34980
31500
31500
44340
44340
25140
25140
27660
27660
36360
36360
25260
25260
22380
22380
53400
53400
51660
51660
36180
36180
41220
41220
56820
56820
42300
42300
17820
17820
28200
28200
25800
25800
34260
34260
34260
34260
34320
34320
28980
28980
31560
31560
22020
22020
26220
26220
21180
21180
25320
25320
31260
31260
60180
60180
24840
24840
31740
31740
47760
47760
31020
31020
34380
34380
28020
28020
28500
28500
32220
32220
23640
23640
27600
27600
16440
16440
27540
27540
23580
23580
44940
44940
29700
29700
23880
23880
26580
26580
30840
30840
32580
32580
31020
31020
34080
34080
30060
30060
33060
33060
40260
40260
30540
30540
29760
29760
27720
27720
32460
32460
29580
29580
14040
14040
29580
29580
33900
33900
33960
33960
36720
36720
32940
32940
35460
35460
25380
25380
33300
33300
35100
35100
19560
19560
46740
46740
33960
33960
27360
27360
27780
27780
41220
41220
28800
28800
24360
24360
32760
32760
2550

In [None]:
class BeamSearchNode(object):
    """A single hypothesis step in the beam-search tree.

    Attributes are stored exactly as passed in:
      prev_node - parent node, or None for the root
      wid       - token id carried by this node
      logp      - cumulative score of the path ending here
      length    - number of tokens on the path, this node included
    """

    def __init__(self, prev_node, wid, logp, length):
        self.prev_node, self.wid = prev_node, wid
        self.logp, self.length = logp, length

    def eval(self):
        """Return the cumulative score divided by (length - 1).

        The 1e-6 term keeps the divisor non-zero for a root node of length 1.
        """
        divisor = float(self.length - 1 + 1e-6)
        return self.logp / divisor
# }}}
import copy
from heapq import heappush, heappop

In [None]:
def translate_sentence_beam(model, sentence, german, english, device, max_length=1200,beam_width=2,max_dec_steps=25000):
    """Decode ``sentence`` with a best-first beam search and return the best hypothesis.

    Args:
        model: callable invoked as ``model(src, trg)`` where both arguments are
            ``(length, 1)`` LongTensors; its result is indexed as
            ``[target position][0][vocabulary entry]``.
            NOTE(review): the scores are summed as if they were log-probabilities;
            confirm whether the model emits log-probs or raw logits.
        sentence: source sequence as a string of space-separated tokens.
        german: source field; supplies ``init_token``, ``eos_token`` and ``vocab.stoi``.
        english: target field; supplies ``vocab.stoi`` and ``vocab.itos``.
        device: torch device the input tensors are moved to.
        max_length: maximum number of generated tokens per hypothesis. A
            hypothesis that reaches it is closed as if it had produced ``<eos>``.
        beam_width: number of candidates expanded per step, and number of
            finished hypotheses collected before the search stops.
        max_dec_steps: budget on the number of candidate nodes created.

    Returns:
        The best hypothesis as a list of target-vocabulary tokens. The list
        starts with the ``<sos>`` token and ends with ``<eos>`` when the
        search reached it; callers filter the special tokens themselves.
    """

    def _trace(node):
        """Return the token ids on the path from the root down to ``node``."""
        ids = [node.wid.item()]
        while node.prev_node is not None:
            node = node.prev_node
            ids.append(node.wid.item())
        return ids[::-1]

    # Tokenise on single spaces and lower-case to match the vocabulary,
    # then wrap the source in its start/end markers.
    tokens = [token.lower() for token in sentence.split(' ')]
    tokens.insert(0, german.init_token)
    tokens.append(german.eos_token)

    eos_token = english.vocab.stoi["<eos>"]

    # Map source tokens to indices and shape them as a (src_len, 1) tensor.
    text_to_indices = [german.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    # The search starts from a single node holding <sos>.
    trg_tensor = torch.LongTensor([english.vocab.stoi["<sos>"]]).to(device)
    root = BeamSearchNode(prev_node=None, wid=trg_tensor, logp=0, length=1)

    # Min-heap ordered by negated score; id(node) breaks ties so that nodes
    # themselves are never compared.
    nodes = []
    heappush(nodes, (-root.eval(), id(root), root))
    end_nodes = []
    n_dec_steps = 0

    while nodes:
        # Give up when decoding takes too long
        if n_dec_steps > max_dec_steps:
            break

        # Fetch the best node
        score, _, n = heappop(nodes)

        reached_eos = n.wid.item() == eos_token and n.prev_node is not None
        # length counts <sos>, so "> max_length" means max_length generated
        # tokens - the same bound the greedy decoder in this notebook applies.
        too_long = n.length > max_length
        if reached_eos or too_long:
            end_nodes.append((score, id(n), n))
            # If we reached maximum # of sentences required
            if len(end_nodes) >= beam_width:
                break
            continue

        sequence = _trace(n)

        with torch.no_grad():
            output = model(sentence_tensor, torch.LongTensor(sequence).unsqueeze(1).to(device))

        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(output, beam_width)

        # Register the new top-k nodes. The next-token scores live at the LAST
        # target position (the greedy decoder reads output.argmax(2)[-1, :]);
        # position 0 only ever scores the token that follows <sos>.
        for new_k in range(beam_width):
            decoded_t = topk_indexes[-1][0][new_k].view(1) # (1)
            logp = topk_log_prob[-1][0][new_k].item()

            node = BeamSearchNode(prev_node=n,
                                  wid=decoded_t,
                                  logp=n.logp+logp,
                                  length=n.length+1)
            heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps += beam_width

    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    if len(end_nodes) == 0:
        # Never pop more entries than the heap holds.
        end_nodes = [heappop(nodes) for _ in range(min(beam_width, len(nodes)))]
    if len(end_nodes) == 0:
        # Nothing was expanded at all; fall back to the bare start node.
        end_nodes = [(0.0, id(root), root)]

    # Lowest negated score is the best hypothesis; back-trace it from its end node.
    _, _, best_node = min(end_nodes, key=lambda entry: entry[0])
    best_sequence = _trace(best_node)

    translated_sentence = [english.vocab.itos[idx] for idx in best_sequence]

    return translated_sentence


In [30]:
def save_vocab(vocab, path):
    """Pickle ``vocab`` to the file at ``path``, overwriting any existing file.

    Args:
        vocab: the object to serialise (here: a field's vocabulary).
        path: destination file path; opened in binary write mode.

    The file is opened in a ``with`` block so the handle is closed even when
    ``pickle.dump`` raises.
    """
    with open(path, 'wb') as output:
        pickle.dump(vocab, output)

In [31]:
vocab_folder = "vocab/"
# Persist each field's vocabulary under the `vocab` directory configured at the
# top of the notebook (note: `vocab_folder` above is not used for these paths).
for _vocab_suffix, _vocab_field in (
    ('/intro_vocab.pkl', intro_field),
    ('/solo_vocab.pkl', solo_field),
    ('/outro_vocab.pkl', outro_field),
):
    save_vocab(_vocab_field.vocab, vocab + _vocab_suffix)

In [27]:
def load_vocab(path):
    """Read the pickle file at ``path`` and return the object stored in it.

    NOTE: unpickling executes code embedded in the file, so only load
    vocabulary files that this project wrote itself.
    """
    with open(path, 'rb') as stream:
        return pickle.load(stream)

In [24]:
vocab_folder = "vocab_intro/"
# Replace each field's vocabulary with the one pickled under `vocab_folder`.
for _vocab_field, _vocab_file in (
    (intro_field, 'intro_vocab.pkl'),
    (solo_field, 'solo_vocab.pkl'),
    (outro_field, 'outro_vocab.pkl'),
):
    _vocab_field.vocab = load_vocab(vocab_folder + _vocab_file)

In [None]:
vocab_folder = "vocab/"
# NOTE(review): this cell was left unfinished (`with open()` without arguments
# and an assignment without a value), which is a syntax error. It is completed
# here as a plain load of the intro vocabulary from `vocab_folder`; the file
# name follows the naming used by the other vocab cells - confirm it is the
# file that was meant.
with open(vocab_folder + 'intro_vocab.pkl', 'rb') as f:
    intro_field.vocab = pickle.load(f)

In [None]:
beam = [2, 59, 119, 13, 212, 59, 212, 59, 59, 75, 59, 59, 13, 119, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 119, 59, 59, 59, 59, 166, 59, 59, 59, 13, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 59, 59, 212, 212, 13, 59, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 14, 59, 59, 212, 59, 212, 212, 59, 59, 59, 68, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 68, 59, 59, 212, 59, 59, 13, 59, 59, 59, 59, 59, 59, 97, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 158, 59, 59, 13, 59, 13, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 13, 212, 13, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 158, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 13, 13, 59, 59, 13, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 212, 212, 59, 59, 59, 59, 59, 59, 13, 59, 166, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 86, 59, 212, 212, 212, 59, 59, 59, 59, 212, 59, 86, 59, 59, 59, 212, 212, 212, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 
13, 59, 59, 59, 212, 59, 212, 59, 59, 166, 59, 86, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 59, 59, 59, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 13, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 212, 212, 158, 59, 59, 13, 212, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 158, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 212, 59, 212, 212, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 13, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 158, 59, 59, 59, 212, 59, 212, 86, 59, 59, 59, 158, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 212, 59, 13, 59, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 13, 212, 59, 59, 59, 59, 59, 68, 59, 13, 59, 59, 13, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 14, 59, 59, 59, 59, 59, 59, 13, 86, 59, 59, 59, 212, 59, 86, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 13, 59, 59, 59, 59, 68, 59, 59, 59, 212, 13, 59, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 212, 212, 212, 13, 59, 166, 59, 212, 59, 59, 59, 13, 59, 59, 59, 59, 59, 59, 59, 59, 166, 212, 212, 59, 59, 212, 59, 212, 59, 59, 13, 59, 59, 59, 59, 13, 59, 59, 14, 13, 59, 59, 86, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 212, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 212, 59, 
59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 166, 59, 59, 59, 59, 59, 59, 59, 59, 13, 13, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 68, 59, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 212, 59, 59, 13, 59, 59, 166, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 158, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 13, 59, 59, 59, 59, 59, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 166, 59, 59, 59, 59, 13, 59, 59, 212, 212, 59, 59, 212, 59, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 59, 59, 59, 212, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59, 212, 59, 59, 59, 59]

# Map the hard-coded `beam` indices back to vocabulary tokens.
translated_sentence1 = list(map(solo_field.vocab.itos.__getitem__, beam))
# Keep only real tokens (special markers removed) and convert them to integer ids.
translated_sentence = [
    int(tok)
    for tok in translated_sentence1
    if tok not in ('<pad>', '<sos>', '<eos>', '<unk>')
]
utils.write_midi(translated_sentence, word2event, generated_outputs + "/predict1.mid")

In [None]:
# Show the current `translated_sentence` list as this cell's output.
translated_sentence

In [None]:
# Print the REMI event sequence of every test intro.
for i in range(len(test_intro)):
    # Each intro is a string of space-separated integer token ids.
    intro_tokens = test_intro[i].split(' ')
    # Skip over-long sentences. The limit applies to the token count of this
    # sentence, not to the number of sentences in the dataset.
    if len(intro_tokens) > 1200:
        continue
    list_sentence = [int(x) for x in intro_tokens]
    # Translate each token id to its event via the word2event dictionary.
    remi = [word2event[x] for x in list_sentence]
    print(remi)

In [None]:
def bleu_translate_sentence(model, sentence, german, english, device, max_length=1200):
    """Greedy-decode a source sequence that is already given as indices.

    Args:
        model: callable invoked as ``model(src, trg)`` with ``(length, 1)``
            LongTensors; its result is arg-maxed over dimension 2.
        sentence: sequence of integers, used directly as the source indices
            (no tokenisation or vocabulary lookup happens here).
        german: unused; kept so the signature matches the other decoders.
        english: target field; supplies ``vocab.stoi`` and ``vocab.itos``.
        device: torch device the tensors are moved to.
        max_length: maximum number of decoding steps.

    Returns:
        The decoded target tokens as strings, including the leading ``<sos>``
        token and the trailing ``<eos>`` when it was produced.
    """
    target_vocab = english.vocab

    # Shape the source as (src_len, 1).
    source = torch.LongTensor(sentence).unsqueeze(1).to(device)

    decoded = [target_vocab.stoi["<sos>"]]
    for _ in range(max_length):
        target_so_far = torch.LongTensor(decoded).unsqueeze(1).to(device)

        with torch.no_grad():
            scores = model(source, target_so_far)

        # The prediction for the next token sits at the last target position.
        next_id = scores.argmax(2)[-1, :].item()
        decoded.append(next_id)

        if next_id == target_vocab.stoi["<eos>"]:
            break

    return [target_vocab.itos[idx] for idx in decoded]


In [None]:
from torchtext.data.metrics import bleu_score

def bleu(data, model, german, english, device):
    """Compute the corpus BLEU score of greedy decodes against the reference solos.

    Args:
        data: iterable of examples whose attributes include ``intro`` (source
            tokens) and ``solo`` (reference tokens), each convertible with ``int``.
        model: model passed through to ``bleu_translate_sentence``.
        german: source field, passed through to ``bleu_translate_sentence``.
        english: target field, passed through to ``bleu_translate_sentence``.
        device: torch device used for decoding.

    Returns:
        The value returned by ``bleu_score`` for the collected corpus.

    Examples whose source or reference is longer than 1200 tokens are skipped.
    """
    targets = []
    outputs = []
    print(len(data))
    for example in data:
        fields = vars(example)
        src = [int(x) for x in fields["intro"]]
        trg = [int(x) for x in fields["solo"]]

        if len(trg) > 1200 or len(src) > 1200:
            continue

        prediction = bleu_translate_sentence(model, src, german, english, device)
        # The decoder returns "<sos>" first and "<eos>" last only when it was
        # actually produced, so strip each marker conditionally instead of
        # blindly dropping the final token.
        if prediction and prediction[0] == "<sos>":
            prediction = prediction[1:]
        if prediction and prediction[-1] == "<eos>":
            prediction = prediction[:-1]

        # bleu_score expects, per candidate, a LIST of reference token lists.
        # Predictions are vocabulary strings, so the reference ids are turned
        # into strings as well - assumes vocabulary tokens are the decimal
        # strings of these ids; TODO confirm.
        targets.append([[str(tok) for tok in trg]])
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [None]:
# Running on the entire test data takes a while, so only the slice test[1:10]
# (nine examples, starting at index 1) is scored here.
score = bleu(test[1:10], model, intro_field, solo_field, device)
# Report the score scaled by 100 with two decimals.
print(f"Bleu score {score * 100:.2f}")

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
# Load the saved training history and plot both loss curves against the global step.
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
for _curve_label, _curve_values in (('Train', train_loss_list), ('Valid', valid_loss_list)):
    plt.plot(global_steps_list, _curve_values, label=_curve_label)
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns