In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from torchtext.legacy.data import Field, TabularDataset, BucketIterator,ReversibleField
import matplotlib.pyplot as plt
from ast import literal_eval
import remi_utils as utils
import pickle
source_folder = "solo_generation_dataset"
destination_folder = "solo_generation_weights"

In [2]:
event2word, word2event = pickle.load(open('dictionary.pkl', 'rb'))

In [3]:
if torch.cuda.is_available():  
    dev = "cuda:1" 
else:  
    dev = "cpu" 
print(dev)
device = torch.device(dev)
print(device)

cuda:1
cuda:1


In [4]:
# Fields

intro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
intro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
outro_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
solo_piano_field = Field(tokenize=None, lower=True, include_lengths=True, batch_first=True, init_token="<sos>", eos_token="<eos>")
fields = [('intro', intro_field), ('intro_piano', intro_piano_field), \
          ('outro', outro_field), ('outro_piano', outro_piano_field), \
          ('solo', solo_field), ('solo_piano', solo_piano_field)]

# TabularDataset

train, valid, test = TabularDataset.splits(path=source_folder, train='train_torchtext.csv', validation='val_torchtext.csv', test='test_torchtext.csv',
                                           format='CSV', fields=fields, skip_header=True)

# Iterators

train_iter = BucketIterator(train, batch_size=8, sort_key=lambda x: len(x.solo),
                            device=device, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=8, sort_key=lambda x: len(x.solo),
                            device=device, sort=False, sort_within_batch=True)
test_iter = BucketIterator(test, batch_size=8, sort_key=lambda x: len(x.solo),
                            device=device, sort=False, sort_within_batch=True)

# Vocabulary

intro_field.build_vocab(train, min_freq=2)
intro_piano_field.build_vocab(train, min_freq=2)
outro_field.build_vocab(train, min_freq=2)
outro_piano_field.build_vocab(train, min_freq=2)
solo_field.build_vocab(train, min_freq=2)
solo_piano_field.build_vocab(train, min_freq=2)

In [5]:
for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (train_iter):
    print(intro.transpose(1,0).size())

torch.Size([226, 8])
torch.Size([138, 8])
torch.Size([170, 8])
torch.Size([219, 8])
torch.Size([191, 8])
torch.Size([223, 8])
torch.Size([214, 8])
torch.Size([271, 8])
torch.Size([231, 8])
torch.Size([218, 8])
torch.Size([182, 8])
torch.Size([158, 8])
torch.Size([182, 8])
torch.Size([186, 8])
torch.Size([122, 8])
torch.Size([154, 8])
torch.Size([195, 8])
torch.Size([198, 8])
torch.Size([199, 8])
torch.Size([190, 8])
torch.Size([174, 8])
torch.Size([142, 8])
torch.Size([154, 8])
torch.Size([158, 8])
torch.Size([162, 8])
torch.Size([126, 8])
torch.Size([131, 8])
torch.Size([218, 8])
torch.Size([219, 8])
torch.Size([190, 8])
torch.Size([182, 8])
torch.Size([196, 8])
torch.Size([204, 8])
torch.Size([199, 8])
torch.Size([146, 8])
torch.Size([189, 8])
torch.Size([186, 8])
torch.Size([191, 8])
torch.Size([232, 8])
torch.Size([158, 8])
torch.Size([182, 8])
torch.Size([170, 8])
torch.Size([201, 8])
torch.Size([251, 8])
torch.Size([187, 8])
torch.Size([145, 8])
torch.Size([230, 8])
torch.Size([1

In [6]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.backends.cudnn.enabled=False

In [7]:
import random
from typing import Tuple

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor

In [8]:

class Encoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers, p):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, bidirectional=True)

        self.fc_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.fc_cell = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        # x: (seq_length, N) where N is batch size

        embedding = self.dropout(self.embedding(x))
        # embedding shape: (seq_length, N, embedding_size)

        encoder_states, (hidden, cell) = self.rnn(embedding)
        # outputs shape: (seq_length, N, hidden_size)

        # Use forward, backward cells and hidden through a linear layer
        # so that it can be input to the decoder which is not bidirectional
        # Also using index slicing ([idx:idx+1]) to keep the dimension
        hidden = self.fc_hidden(torch.cat((hidden[0:1], hidden[1:2]), dim=2))
        cell = self.fc_cell(torch.cat((cell[0:1], cell[1:2]), dim=2))

        return encoder_states, hidden, cell


class Decoder(nn.Module):
    def __init__(
        self, input_size, embedding_size, hidden_size, output_size, num_layers, p
    ):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.LSTM(hidden_size * 2 + embedding_size, hidden_size, num_layers)

        self.energy = nn.Linear(hidden_size * 3, 1)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(p)
        self.softmax = nn.Softmax(dim=0)
        self.relu = nn.ReLU()

    def forward(self, x, encoder_states, hidden, cell):
        x = x.unsqueeze(0)
        # x: (1, N) where N is the batch size

        embedding = self.dropout(self.embedding(x))
        # embedding shape: (1, N, embedding_size)

        sequence_length = encoder_states.shape[0]
        h_reshaped = hidden.repeat(sequence_length, 1, 1)
        # h_reshaped: (seq_length, N, hidden_size*2)

        energy = self.relu(self.energy(torch.cat((h_reshaped, encoder_states), dim=2)))
        # energy: (seq_length, N, 1)

        attention = self.softmax(energy)
        # attention: (seq_length, N, 1)

        # attention: (seq_length, N, 1), snk
        # encoder_states: (seq_length, N, hidden_size*2), snl
        # we want context_vector: (1, N, hidden_size*2), i.e knl
        context_vector = torch.einsum("snk,snl->knl", attention, encoder_states)

        rnn_input = torch.cat((context_vector, embedding), dim=2)
        # rnn_input: (1, N, hidden_size*2 + embedding_size)

        outputs, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
        # outputs shape: (1, N, hidden_size)

        predictions = self.fc(outputs).squeeze(0)
        # predictions: (N, hidden_size)

        return predictions, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_force_ratio=0.5):
        batch_size = source.shape[1]
        target_len = target.shape[0]
        target_vocab_size = len(solo_field.vocab)

        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(device)
        encoder_states, hidden, cell = self.encoder(source)

        # First input will be <SOS> token
        x = target[0]

        for t in range(1, target_len):
            # At every time step use encoder_states and update hidden, cell
            output, hidden, cell = self.decoder(x, encoder_states, hidden, cell)

            # Store prediction for current time step
            outputs[t] = output

            # Get the best word the Decoder predicted (index in the vocabulary)
            best_guess = output.argmax(1)

            # With probability of teacher_force_ratio we take the actual next word
            # otherwise we take the word that the Decoder predicted it to be.
            # Teacher Forcing is used so that the model gets used to seeing
            # similar inputs at training and testing time, if teacher forcing is 1
            # then inputs at test time might be completely different than what the
            # network is used to. This was a long comment.
            x = target[t] if random.random() < teacher_force_ratio else best_guess

        return outputs

In [9]:


INPUT_DIM = len(intro_field.vocab)
OUTPUT_DIM = len(solo_field.vocab)
OUTPUT_SIZE = len(solo_field.vocab)
# ENC_EMB_DIM = 256
# DEC_EMB_DIM = 256
# ENC_HID_DIM = 512
# DEC_HID_DIM = 512
# ATTN_DIM = 64
# ENC_DROPOUT = 0.5
# DEC_DROPOUT = 0.5

ENC_EMB_DIM = 300
DEC_EMB_DIM = 300
ENC_HID_DIM = 256
DEC_HID_DIM = 256
ATTN_DIM = 1024
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
NUM_LAYERS = 1

enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, NUM_LAYERS, ENC_DROPOUT)

#attn = Attention(ENC_HID_DIM, DEC_HID_DIM, ATTN_DIM)

dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, DEC_HID_DIM, OUTPUT_SIZE, NUM_LAYERS, DEC_DROPOUT)

model = Seq2Seq(enc, dec)
# model = nn.DataParallel(model.to(device))
model = model.to(device)

In [10]:
def init_weights(m: nn.Module):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


model.apply(init_weights)

optimizer = optim.Adam(model.parameters())


def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 2,682,286 trainable parameters


In [11]:
# stoi input str get int
# intro_field.vocab.stoi
# itos input into get token/str
# intro_field.vocab.itos[4]

In [12]:
PAD_IDX = 1

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
#criterion = nn.CrossEntropyLoss()

In [13]:
import math
import time


def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()

    epoch_loss = 0

    #for _, (src, _,trg,_) in enumerate(iterator):
    for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
        src, trg = intro.transpose(1,0), solo.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()

        output = model(src, trg)
        
#         print(output.size())
#         print(trg.size())
        
        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].reshape(-1)

        loss = criterion(output, trg)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             criterion: nn.Module):

    model.eval()

    epoch_loss = 0

    with torch.no_grad():

        #for _, (src, _,trg,_) in enumerate(iterator):
        for ((intro, intro_len), (intro_piano, intro_piano_len),\
         (outro, outro_len),(outro_piano, outro_piano_len),\
         (solo, solo_len),(solo_piano, solo_piano_len)), _ in (iterator):
            src, trg = intro.transpose(1,0), solo.transpose(1,0)
            src, trg = src.to(device), trg.to(device)

            output = model(src, trg, 0) #turn off teacher forcing

            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].reshape(-1)

            loss = criterion(output, trg)

            epoch_loss += loss.cpu().detach().item()

    return epoch_loss / len(iterator)


def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


N_EPOCHS = 20
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_iter, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iter, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

test_loss = evaluate(model, test_iter, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 0m 32s
	Train Loss: 4.518 | Train PPL:  91.656
	 Val. Loss: 4.398 |  Val. PPL:  81.292
Epoch: 02 | Time: 0m 31s
	Train Loss: 4.346 | Train PPL:  77.176
	 Val. Loss: 4.355 |  Val. PPL:  77.898
Epoch: 03 | Time: 0m 31s
	Train Loss: 4.239 | Train PPL:  69.333
	 Val. Loss: 4.338 |  Val. PPL:  76.523
Epoch: 04 | Time: 0m 31s
	Train Loss: 3.638 | Train PPL:  38.025
	 Val. Loss: 5.339 |  Val. PPL: 208.363
Epoch: 05 | Time: 0m 31s
	Train Loss: 3.376 | Train PPL:  29.243
	 Val. Loss: 5.647 |  Val. PPL: 283.518
Epoch: 06 | Time: 0m 31s
	Train Loss: 3.286 | Train PPL:  26.733
	 Val. Loss: 5.902 |  Val. PPL: 365.599
Epoch: 07 | Time: 0m 31s
	Train Loss: 3.257 | Train PPL:  25.959
	 Val. Loss: 6.105 |  Val. PPL: 448.038
Epoch: 08 | Time: 0m 31s
	Train Loss: 3.204 | Train PPL:  24.633
	 Val. Loss: 6.322 |  Val. PPL: 556.826
Epoch: 09 | Time: 0m 31s
	Train Loss: 3.181 | Train PPL:  24.066
	 Val. Loss: 6.497 |  Val. PPL: 663.107
Epoch: 10 | Time: 0m 31s
	Train Loss: 3.165 | Train PPL

In [None]:
model.eval()

with torch.no_grad():

    for ((intro, intro_len), (intro_piano, intro_piano_len),\
     (outro, outro_len),(outro_piano, outro_piano_len),\
     (solo, solo_len),(solo_piano, solo_piano_len)), _ in (test_iter):
        src, trg = intro.transpose(1,0), solo.transpose(1,0)
        src, trg = src.to(device), trg.to(device)

        output = model(src, trg, 0) #turn off teacher forcing
        orig = src[4].cpu().detach().tolist()
        test = torch.argmax(output[1:], dim=2)[4].cpu().detach().tolist()
        
        ouput_t = output.transpose(1,0)
        #print(trg[1:].size())
        for i in range(len(torch.argmax(ouput_t, dim=2))):
            midi_series = torch.argmax(ouput_t, dim=2)[i].tolist()
            #print(midi_series)
            #utils.write_midi(midi_series, word2event, "test" + str(i) + ".midi")
            
#         output = output[1:].view(-1, output.shape[-1])
#         trg = trg[1:].view(-1)
        #print(output)

In [14]:
def translate_sentence(model, sentence, german, english, device, max_length=50):

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    tokens = [token.lower() for token in sentence.split(' ')]
    # print(tokens)

    # sys.exit()
    # Add <SOS> and <EOS> in beginning and end respectively
#     tokens.insert(0, german.init_token)
#     tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    # Build encoder hidden, cell state
    with torch.no_grad():
        encoder_states, hidden, cell = model.encoder(sentence_tensor)

                
    outputs = [english.vocab.stoi["0"]]

    for _ in range(max_length):
        previous_word = torch.LongTensor([outputs[-1]]).to(device)
        with torch.no_grad():
            output, hidden, cell = model.decoder(previous_word, encoder_states, hidden, cell)
            best_guess = output.argmax(1).item()
        outputs.append(best_guess)

        # Model predicts it's the end of the sentence
#         if output.argmax(1).item() == english.vocab.stoi["<eos>"]:
#             break

    translated_sentence = [english.vocab.itos[idx] for idx in outputs]

    # remove start token
    return translated_sentence


In [19]:
df_intro = pd.read_csv(source_folder + '/train_torchtext.csv')
test_intro = df_intro['intro'].values
print(len(test_intro))

526


In [20]:
for i in range(len(test_intro)):
    sentence = test_intro[i]
    translated_sentence = translate_sentence(model, sentence, intro_field, solo_field, device, max_length=50)
    translated_sentence = [int(x) for x in translated_sentence if x != '<pad>' or x != '<eos>' or x != '<sos>']
    #print(sentence)
    print(translated_sentence)
    #utils.write_midi(utils.remove_padding(translated_sentence,word2event), word2event, "test" + str(i)  + ".mid")

[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 10, 34, 42, 31, 30, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]
[0, 0, 1, 2, 87, 1, 60, 12, 31, 10, 60, 12, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 60, 42, 31, 10, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34, 42, 31, 14, 34]


In [22]:
sentence = "0 1 2 155 29 4 51 27 10 4 51 25 11 4 5 25 14 21 5 25 15 4 63 31 33 4 63 13 20 19 65 6 0 1 4 63 31 28 4 5 9 10 16 51 25 11 4 63 25 14 4 63 25 15 4 65 25 18 21 8 31 24 19 65 27 26 19 65 31 0 1 41 63 31 28 16 5 27 7 4 65 27 29 4 86 31 10 4 65 31 37 19 92 31 11 21 86 88 0 29 52 86 27 10 4 86 25 11 16 92 23 15 4 63 108 0 7 4 5 25 10 4 5 6 11 21 63 31 30 21 65 58"
translated_sentence = translate_sentence(model, sentence, intro_field, solo_field, device, max_length=50)
translated_sentence = [int(x) for x in translated_sentence if x != '<pad>']
print(translated_sentence)

[0, 1, 2, 1, 2, 28, 60, 42, 31, 10, 71, 42, 31, 10, 60, 42, 31, 10, 2, 12, 31, 10, 16, 51, 31, 10, 16, 51, 31, 10, 16, 51, 31, 10, 16, 51, 31, 10, 16, 51, 31, 10, 16, 51, 31, 10, 16, 51, 31, 10, 16]


In [None]:
print(utils.remove_padding(translated_sentence,word2event))

In [None]:
utils.write_midi(utils.remove_padding(translated_sentence,word2event), word2event, "test.mid")

In [None]:
class LSTM(nn.Module):

    def __init__(self, dimension=64):
        super(LSTM, self).__init__()

        self.embedding = nn.Embedding(len(text_field.vocab), 300)
        self.dimension = dimension
        self.lstm = nn.LSTM(input_size=300,
                            hidden_size=dimension,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.drop = nn.Dropout(p=0.5)

        self.fc = nn.Linear(2*dimension, 1)

    def forward(self, notes, notes_len):
        notes_emb = self.embedding(notes)

        packed_input = pack_padded_sequence(notes_emb, notes_len, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        out_forward = output[range(len(output)), notes_len - 1, :self.dimension]
        out_reverse = output[:, 0, self.dimension:]
        out_reduced = torch.cat((out_forward, out_reverse), 1)
        notes_fea = self.drop(out_reduced)

        notes_fea = self.fc(notes_fea)
        notes_fea = torch.squeeze(notes_fea, 1)
        notes_out = torch.sigmoid(notes_fea)

        return notes_out

In [None]:
# Save and Load Functions https://towardsdatascience.com/lstm-text-classification-using-pytorch-2c6c657f8fc0

def save_checkpoint(save_path, model, optimizer, valid_loss):

    if save_path == None:
        return
    
    state_dict = {'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'valid_loss': valid_loss}
    
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')


def load_checkpoint(load_path, model, optimizer):

    if load_path==None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    
    model.load_state_dict(state_dict['model_state_dict'])
    optimizer.load_state_dict(state_dict['optimizer_state_dict'])
    
    return state_dict['valid_loss']


def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):

    if save_path == None:
        return
    
    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}
    
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')


def load_metrics(load_path):

    if load_path==None:
        return
    
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    
    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']


In [None]:
# Training Function

def train(model,
          optimizer,
          criterion = nn.BCELoss(),
          train_loader = train_iter,
          valid_loader = valid_iter,
          num_epochs = 10,
          eval_every = len(train_iter) // 2,
          file_path = destination_folder,
          best_valid_loss = float("Inf")):
    
    # initialize running values
    running_loss = 0.0
    valid_running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    # training loop
    model.train()
    for epoch in range(num_epochs):
        total = 0
        total_correct = 0
        for (labels, (notes, notes_len)), _ in (train_loader):           
            labels = labels.to(device)
            notes = notes.to(device)
            notes_len = notes_len.cpu()
            output = model(notes.long(), notes_len.long())

            loss = criterion(output, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            labels_max = labels.detach().cpu()
            output_max = torch.round(output.detach().cpu())

            for i in range(len(labels_max)):
                total+=1
                if labels_max[i] ==  output_max[i]:
                    total_correct += 1
            accuracy = accuracy_score(labels_max, output_max)
            
            # update running values
            running_loss += loss.item()
            global_step += 1

            # evaluation step
            if global_step % eval_every == 0:
                model.eval()
                with torch.no_grad():                    
                  # validation loop
                    for (labels, (notes, notes_len)), _ in (valid_loader):
                        labels = labels.to(device)
                        notes = notes.to(device)
                        notes_len = notes_len.cpu()
                        output = model(notes.long(), notes_len.long())
                        loss = criterion(output, labels.float())
                        valid_running_loss += loss.item()

                # evaluation
                average_train_loss = running_loss / eval_every
                average_valid_loss = valid_running_loss / len(valid_loader)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)

                # resetting running values
                running_loss = 0.0                
                valid_running_loss = 0.0
                model.train()

                # print progress
                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                      .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader),
                              average_train_loss, average_valid_loss))
                
                # checkpoint
                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_checkpoint(file_path + '/model.pt', model, optimizer, best_valid_loss)
                    save_metrics(file_path + '/metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
        print("Epoch Accuracy: {}".format(total_correct/total))
    save_metrics(file_path + '/metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
    print('Finished Training!')


model = LSTM().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

train(model=model, optimizer=optimizer, num_epochs=25)

In [None]:
# torch.backends.cudnn.enabled = False

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')
plt.plot(global_steps_list, train_loss_list, label='Train')
plt.plot(global_steps_list, valid_loss_list, label='Valid')
plt.xlabel('Global Steps')
plt.ylabel('Loss')
plt.legend()
plt.show() 

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns

In [None]:
def evaluate(model, test_loader, version='title', threshold=0.5):
    y_pred = []
    y_true = []

    model.eval()
    with torch.no_grad():
        for (labels, (notes, notes_len)), _ in test_loader:           
            labels = labels.to(device)
            notes = notes.to(device)
            notes_len = notes_len.cpu()
            output = model(notes.long(), notes_len.long())

            output = (output > threshold).int()
            y_pred.extend(output.tolist())
            y_true.extend(labels.tolist())
    
    print('Classification Report:')
    print(classification_report(y_true, y_pred, labels=[1,0], digits=4))
    
    cm = confusion_matrix(y_true, y_pred, labels=[1,0])
    ax= plt.subplot()
    sns.heatmap(cm, annot=True, ax = ax, cmap='Blues', fmt="d")

    ax.set_title('Confusion Matrix')

    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')

    ax.xaxis.set_ticklabels(['NON-SOLO', 'SOLO'])
    ax.yaxis.set_ticklabels(['NON-SOLO', 'SOLO'])
    
    
best_model = LSTM().to(device)
optimizer = optim.Adam(best_model.parameters(), lr=0.001)

load_checkpoint(destination_folder + '/model.pt', best_model, optimizer)
evaluate(best_model, test_iter)