In [1]:
# if __name__ == "__main__":

import csv

EPOCH_LOSSES_SAVEPATH = "./models/bilstm_epoch_losses.csv"
BATCH_LOSSES_SAVEPATH = "./models/bilstm_batch_losses.csv"

def save_list_to_csv(lst, filename):
    with open(filename, 'w', newline='') as file:
        wr = csv.writer(file)
        wr.writerow(lst)

def evaluate_batch_mean_average_accuracy(y_truth, y_pred):
    matches = [x == y for (x,y) in zip(y_truth, y_pred)]
    maas = []
    for batch in range(0, len(matches), 10):
        num_correct = 0
        summ = 0
        for i in range(0, 10):
            if matches[batch+i] == 1:
                num_correct += 1
                summ += (num_correct / (i+1)) / 10
        maas.append(summ)
    return sum(maas) / len(maas)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models.BiLSTMEncoderDecoder import BiLSTMEncoderDecoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "./models/bilstm_encoder_decoder_1.pt"

from data.SpotifyDataset import SpotifyDataset

train_set = SpotifyDataset("./data/train_data_20.csv", "./data/track_feats.csv")
val_set = SpotifyDataset("./data/val_data_20.csv", "./data/track_feats.csv")
test_set = SpotifyDataset("./data/test_data_20.csv", "./data/track_feats.csv")

datasets = {"train": train_set,
            "val": val_set,
            "test": test_set}

In [2]:
from torch._six import container_abcs, string_classes, int_classes

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

In [3]:
# import torch
# # import multiprocessing as python_multiprocessing
# # import torch.multiprocessing as multiprocessing
# # from . import IterableDataset, Sampler, SequentialSampler, RandomSampler, BatchSampler
# from . import _utils
# from torch._utils import ExceptionWrapper
# import threading
# import itertools
# from torch._six import queue, string_classesimport torch.utils.data.dataloader

### TRAINING BLOC:
model = BiLSTMEncoderDecoder(encode_size=67, decode_size=29).to(device)
loss_fn = nn.CrossEntropyLoss()

batch_size = 64
learning_rate = 1e-3
weight_decay = 2.5e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay)

num_epochs = 3
best_maa = 0
batch_losses = []
epoch_losses = []

# default_collate = _utils.collate.default_collate

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch+1, num_epochs))
    for phase in ['train', 'val']:
        print("Running {} phase...".format(phase))
        total_maa = []
        total_loss = []
        dataloader = DataLoader(datasets[phase], batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=default_collate)
        for i, (X_encode, X_decode, y) in enumerate(dataloader):
            
            # just testing on small sizes for now
            if i > 400:
                break
                
            if i % 10 == 0:
                print("Calculating batch {} / {}".format(i, len(dataloader)))
            X_encode, X_decode, y = X_encode.to(device), X_decode.to(device), y.to(device)
            if phase == 'train':
                model.train()
                scores = model(X_encode, X_decode).flatten(start_dim=0, end_dim=1)
                labels = y.flatten(start_dim=0, end_dim=1).squeeze()
                loss = loss_fn(scores, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                total_loss.append(loss)
                batch_losses.append(loss.item())
            else:
                model.eval()
                scores = model(X_encode, X_decode).flatten(start_dim=0, end_dim=1)
                labels = y.flatten(start_dim=0, end_dim=1).squeeze()
                y_pred = torch.argmax(scores, dim=1)
                maa = evaluate_batch_mean_average_accuracy(labels, y_pred)
                total_maa.append(maa)
        if phase == 'train':
            epoch_loss = sum(total_loss) / len(total_loss)
            print("Epoch {} Avg. Loss: {}".format(epoch, epoch_loss))
            epoch_losses.append(epoch_loss.item())
        else:
            epoch_maa = sum(total_maa) / len(total_maa)
            print("Epoch {} Avg. MAA: {}, Best MAA: {}".format(epoch, epoch_maa, best_maa))
            if epoch_maa > best_maa:
                torch.save(model.state_dict(), MODEL_PATH)
                best_maa = epoch_maa
    print()
print()
print("List of avg. loss across epochs: ")
print(epoch_losses)
save_list_to_csv(epoch_losses, EPOCH_LOSSES_SAVEPATH)
print("List of losses across batches: ")
print(batch_losses)
save_list_to_csv(batch_losses, BATCH_LOSSES_SAVEPATH)

Epoch 1/3
Running train phase...
Calculating batch 0 / 924
Calculating batch 10 / 924
Calculating batch 20 / 924
Calculating batch 30 / 924
Calculating batch 40 / 924
Calculating batch 50 / 924
Calculating batch 60 / 924
Calculating batch 70 / 924
Calculating batch 80 / 924
Calculating batch 90 / 924
Calculating batch 100 / 924
Calculating batch 110 / 924
Calculating batch 120 / 924
Calculating batch 130 / 924
Calculating batch 140 / 924
Calculating batch 150 / 924
Calculating batch 160 / 924
Calculating batch 170 / 924
Calculating batch 180 / 924
Calculating batch 190 / 924
Calculating batch 200 / 924
Calculating batch 210 / 924
Calculating batch 220 / 924
Calculating batch 230 / 924
Calculating batch 240 / 924
Calculating batch 250 / 924
Calculating batch 260 / 924
Calculating batch 270 / 924
Calculating batch 280 / 924
Calculating batch 290 / 924
Calculating batch 300 / 924
Calculating batch 310 / 924
Calculating batch 320 / 924
Calculating batch 330 / 924
Calculating batch 340 / 92

In [4]:
# TESTING BLOCK
test_model = BiLSTMEncoderDecoder(encode_size=67, decode_size=29).to(device)
test_model.load_state_dict(torch.load(MODEL_PATH))

with torch.no_grad():
    dataloader = DataLoader(datasets["test"], batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=default_collate)
    total_maa = []
    for i, (X_encode, X_decode, labels) in enumerate(dataloader):
#         if i > 40:
#             break
        X_encode, X_decode, y = X_encode.to(device), X_decode.to(device), y.to(device)
        labels = y.flatten(start_dim=0, end_dim=1).squeeze()
        scores = test_model(X_encode, X_decode).flatten(start_dim=0, end_dim=1)
        y_pred = torch.argmax(scores, dim=1)
        maa = evaluate_batch_mean_average_accuracy(labels, y_pred)
        total_maa.append(maa)
    print("Average batch MAA over test set: {}".format(sum(total_maa) / len(total_maa)))


Average batch MAA over test set: 0.40402910915841944
