In [1]:
import torch
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import torchaudio
from SongDataset import SongDataset
from TranscriptionModel import GuitarModel
from torch import nn
import os
import logging
import copy
import time
import math
from timeit import default_timer as timer
from tqdm.auto import tqdm


# Larger default figure size for every matplotlib plot in this notebook.
plt.rcParams['figure.figsize'] = [12, 8]
# Report whether PyTorch can use a CUDA GPU (DEVICE falls back to the CPU otherwise).
print(f"Cuda : {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm


Cuda : True


In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Global seed so weight initialisation and DataLoader shuffling are reproducible across runs
# (torch.manual_seed seeds the CPU and all CUDA devices).
SEED = 42
torch.manual_seed(SEED)
# Audio sample rate in Hz, used by the mel transform and SongDataset.
SAMPLE_RATE = 44100

In [7]:
# Mel-spectrogram transform handed to SongDataset (presumably applied to each audio clip — confirm).
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=128,      # number of mel frequency bands
    n_fft=2048,      # FFT window size, in samples
    hop_length=512,  # samples between successive frames
)

# NOTE(review): the data used for both training and validation is read from "test.hdf5";
# confirm this is the intended file.
dataset = SongDataset("test.hdf5", mel_spectrogram, sampleRate=SAMPLE_RATE)


In [9]:
# Transformer hyperparameters.
EMB_SIZE, FFN_HID_DIM = 512, 512            # embedding width and feed-forward width (dim_feedforward)
NHEAD = 8                                   # attention heads (passed as multi_head_attention_size)
NUM_ENCODER_LAYERS = NUM_DECODER_LAYERS = 3  # depth of the encoder and decoder stacks
BATCH_SIZE = 128                            # samples per DataLoader batch

In [None]:
# Instantiate the transcription model. The first argument is a fixed input shape,
# presumably (batch, channels, n_mels, frames): 128 matches n_mels above, and 87 is
# presumably the frame count per clip — TODO confirm against GuitarModel / SongDataset.
model = GuitarModel(
    (BATCH_SIZE, 2, 128, 87),
    emb_size=EMB_SIZE,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    multi_head_attention_size=NHEAD,
    dim_feedforward=FFN_HID_DIM,
    tgt_vocab_size=dataset.vocabSize,
)

# Xavier-uniform initialisation for every parameter with more than one dimension;
# 1-D parameters (e.g. biases) keep whatever initialisation their modules gave them.
matrix_params = [param for param in model.parameters() if param.dim() > 1]
for weight in matrix_params:
    nn.init.xavier_uniform_(weight)

# Module.to() moves parameters in place and returns the same module, so `transformer` is `model`.
transformer = model.to(DEVICE)

# Target positions equal to the dataset's pad token are excluded from the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=dataset.pad_token)
optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

In [None]:
class CheckpointSaver:
    """Save model weights whenever a validation metric improves, keeping only the best few files.

    Each improvement writes ``model.state_dict()`` to
    ``<dirpath>/<ModelClassName>_epoch<epoch>.pt``. At most ``top_n`` of the checkpoints
    recorded by this instance are kept on disk; worse ones are deleted.
    """

    def __init__(self, dirpath, decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights (created if missing)
        decreasing: If `True`, a lower metric is better (e.g. a loss); otherwise higher is better
        top_n: Maximum number of checkpoint files this saver keeps on disk
        """
        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(dirpath, exist_ok=True)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        # Best-first list of {'path': ..., 'score': ...} for checkpoints written by this saver.
        self.top_model_paths = []
        # Lowercase np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.best_metric_val = np.inf if decreasing else -np.inf

    def __call__(self, model, epoch, metric_val):
        """Checkpoint ``model`` if ``metric_val`` strictly beats the best value seen so far."""
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}.pt')
        if self.decreasing:
            improved = metric_val < self.best_metric_val
        else:
            improved = metric_val > self.best_metric_val
        if improved:
            logging.info(f"Current metric value {metric_val} better than best {self.best_metric_val}, saving model at {model_path}")
            self.best_metric_val = metric_val
            torch.save(model.state_dict(), model_path)
            # The file at model_path was just overwritten. Drop any stale entry for the same
            # path so cleanup() cannot later delete the fresh checkpoint through it.
            self.top_model_paths = [entry for entry in self.top_model_paths if entry['path'] != model_path]
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            # Keep the list best-first so cleanup() only has to drop the tail.
            self.top_model_paths.sort(key=lambda entry: entry['score'], reverse=not self.decreasing)
        if len(self.top_model_paths) > self.top_n:
            self.cleanup()

    def cleanup(self):
        """Delete the checkpoint files ranked below ``top_n`` and forget them."""
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for entry in to_remove:
            os.remove(entry['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]


In [None]:
# Reproducible 90/10 train/validation split: the fixed-seed generator yields the same
# partition on every run (fractional lengths need torch >= 1.13).
split_rng = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(dataset, lengths=[0.9, 0.1], generator=split_rng)

In [None]:
def train_epoch(model, optimizer, epoch):
    """Run one training pass over ``train_set`` and return the mean per-batch loss.

    Args:
        model: the transcription model; switched to training mode here.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, only used to label the progress bar.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    running_loss = 0

    for src_batch, tgt_batch in tqdm(loader, desc=f"Epoch {epoch}"):
        src_batch = src_batch.to(DEVICE)
        tgt_batch = tgt_batch.to(DEVICE)

        # Teacher forcing: the decoder input drops the last position and the target drops the first.
        # NOTE(review): both slices act on dim 0, which is the sequence axis only if targets are
        # laid out (seq, batch); default DataLoader collation stacks samples along dim 0 — confirm.
        decoder_input = tgt_batch[:-1, :]
        decoder_target = tgt_batch[1:, :]

        # NOTE(review): create_mask is neither defined nor imported in this notebook;
        # it must be provided before this function runs.
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src_batch, decoder_input)

        optimizer.zero_grad()
        # src_padding_mask is passed twice; presumably the last slot is the memory
        # key-padding mask — confirm against GuitarModel's forward signature.
        logits = model(src_batch, decoder_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Flatten to 2-D logits (positions, classes) and 1-D targets (positions,), the layout
        # CrossEntropyLoss expects; padded targets are skipped via ignore_index.
        batch_loss = loss_fn(logits.reshape(-1, logits.shape[-1]), decoder_target.reshape(-1))
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)


def evaluate(model, epoch, checkpoint_saver):
    """Compute the mean validation loss of ``model`` and offer it to ``checkpoint_saver``.

    Iterates once over ``val_set`` in a fixed order, without tracking gradients.

    Args:
        model: the transcription model; switched to evaluation mode here.
        epoch: epoch number, used for the progress-bar label and passed to the saver.
        checkpoint_saver: callable ``(model, epoch, metric_val)``, e.g. a ``CheckpointSaver``.

    Returns:
        The mean per-batch validation loss (the same value handed to ``checkpoint_saver``).

    Raises:
        ValueError: if the validation loader yields no batches.
    """
    model.eval()
    val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    if len(val_dataloader) == 0:
        raise ValueError("val_set is empty; cannot compute a validation loss")

    total_loss = 0.0
    # Validation needs no gradients: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for src, tgt in tqdm(val_dataloader, desc=f"Eval {epoch}"):
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            # Teacher-forcing split. NOTE(review): this slices dim 0 — confirm the targets are
            # laid out (seq, batch) rather than (batch, seq) after collation.
            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]

            # NOTE(review): create_mask is neither defined nor imported in this notebook;
            # it must be provided before this function runs.
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            # Flatten to 2-D logits and 1-D targets, the layout CrossEntropyLoss expects.
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            total_loss += loss.item()

    # Average exactly once, and report the same value that the checkpoint decision used.
    mean_loss = total_loss / len(val_dataloader)
    checkpoint_saver(model, epoch, mean_loss)
    return mean_loss


In [None]:
# Train for NUM_EPOCHS epochs, validating after each one. The saver keeps only the
# single best checkpoint by validation loss (lower is better).
NUM_EPOCHS = 18
checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=1)

for epoch in range(1, NUM_EPOCHS + 1):
    # The reported "Epoch time" covers the training pass only; the timer stops before validation.
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer, epoch)
    end_time = timer()

    val_loss = evaluate(transformer, epoch, checkpoint_saver)

    summary = (
        f"Epoch: {epoch}, "
        f"Train loss: {train_loss:.3f}, "
        f"Val loss: {val_loss:.3f}, "
        f"Epoch time = {end_time - start_time:.3f}s"
    )
    print(summary)