In [1]:
from data_processing.Datasets import get_maestro_dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import torch.nn as nn
from model.transformer import Transformer
import numpy as np
import random
import muspy
import collections.abc as collections

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# hyperparams
EPOCHS = 10

In [4]:
# load device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Collect Data

In [5]:
train_data, test_data = get_maestro_dataset("data/maestro")

Skip downloading as the `.muspy.success` file is found.
Skip extracting as the `.muspy.success` file is found.
Skip conversion as the `.muspy.success` file is found.


In [6]:
def collate_fn_padd(batch):
    '''
    Pads batch of variable length

    note: it converts things ToTensor manually here since the ToTensor transform
    assume it takes in images rather than arbitrary tensors.
    '''
    ## get sequence lengths
    lengths = torch.tensor([ t.shape[0] for t in batch ]).to(device)
    ## pad (in this case, just shorten)
    batch = [ torch.Tensor(t)[:2048].to(device) for t in batch ]
    batch = torch.nn.utils.rnn.pad_sequence(batch)
    
    ## compute mask
    mask = (batch != 0).to(device)
    return batch, lengths, mask, batch # batch sent in twice for quick usage

batch_size = 20
eval_batch_size = 10

train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    collate_fn=collate_fn_padd,
    shuffle=False,
)

val_dataloader = DataLoader(
    dataset=test_data,
    batch_size=eval_batch_size,
    collate_fn=collate_fn_padd,
    shuffle=False,
)


In [7]:
# create model, optim, criterion

model = Transformer(
    num_tokens=128, dim_model=512, num_heads=2, num_encoder_layers=3, num_decoder_layers=3, dropout_p=0.1
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

## Train

In [8]:
def train_loop(model, opt, loss_fn, dataloader):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """

    model.train()
    total_loss = 0

    for batch in dataloader:
        X, y = batch[0], batch[3] # look at collate_fn() above
        X, y = torch.tensor(X, dtype=torch.long, device=device), torch.tensor(
            y, dtype=torch.long, device=device
        )

        # shift elements over to collect the next element, as a label
        y_input = y[:-1].permute(1, 0, 2)
        y_expected = y[1:]

        # Get mask to mask out the next words
        sequence_length = y_input.size(0)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)

        X_inp = X.permute(1, 0, 2)

        # Standard training except we pass in y_input and tgt_mask
        pred = model(X_inp, y_input, tgt_mask)

        # Permute pred to have batch size first again
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.detach().item()

    return total_loss / len(dataloader)


def validation_loop(model, loss_fn, dataloader):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """

    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            X, y = batch[0], batch[3]
            X, y = torch.tensor(X, dtype=torch.long, device=device), torch.tensor(
                y, dtype=torch.long, device=device
            )

            # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
            y_input = y[:-1]
            y_expected = y[1:]

            # Get mask to mask out the next words
            sequence_length = y_input.size(0)
            tgt_mask = model.get_tgt_mask(sequence_length).to(device)

            # Standard training except we pass in y_input and src_mask
            pred = model(X, y_input, tgt_mask)

            # Permute pred to have batch size first again
            pred = pred.permute(1, 2, 0)
            loss = loss_fn(pred, y_expected)
            total_loss += loss.detach().item()

    return total_loss / len(dataloader)


def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """

    # Used for plotting later on
    train_loss_list, validation_loss_list = [], []

    print("Training and validating model")
    for epoch in range(epochs):
        print("-" * 25, f"Epoch {epoch + 1}", "-" * 25)

        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        train_loss_list += [train_loss]

        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        validation_loss_list += [validation_loss]

        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()

    return train_loss_list, validation_loss_list


train_loss_list, validation_loss_list = fit(
    model, opt, loss_fn, train_dataloader, val_dataloader, 10
)


Training and validating model
------------------------- Epoch 1 -------------------------


  X, y = torch.tensor(X, dtype=torch.long, device=device), torch.tensor(


torch.Size([20, 2048, 128, 512])


KeyboardInterrupt: 