In [1]:
from Datasets import TokenizedChromaDataset
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from transformer.models import TransformerFromModels, EncoderModel, DecoderModel
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os

In [2]:
npz_path = '../data/augmented_and_padded_data.npz'
dataset = TokenizedChromaDataset(npz_path)

train_percentage = 0.9
split_idx = int( len(dataset)*train_percentage )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 16
epochs = 1000

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [3]:
src_vocab_size = 2**12
tgt_vocab_size = 2**12
d_model = 128
num_heads = 4
num_layers = 4
d_ff = 128
max_seq_length = 129
dropout = 0.3

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoderModel = EncoderModel(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
decoderModel = DecoderModel(tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

encoderModel = encoderModel.to(dev)
decoderModel = decoderModel.to(dev)

transformer = TransformerFromModels(encoderModel, decoderModel)

transformer = transformer.to(dev)

In [None]:
criterion = CrossEntropyLoss(ignore_index=0)
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# keep best validation loss for saving
best_val_loss = np.inf
save_dir = '../saved_models/encoder_decoder_one_hot/'
encoder_path = save_dir + 'encoder_one_hot.pt'
decoder_path = save_dir + 'decoder_one_hot.pt'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies, chords[:, :-1])
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:].contiguous().view(-1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            prediction = output.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction == chords[:, 1:]).sum().item()/prediction.shape[1]
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    # validation
    with torch.no_grad():
        val_loss = 0
        running_loss = 0
        samples_num = 0
        running_accuracy = 0
        accuracy = 0
        print('validation...')
        for melodies, chords in test_loader:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            output = transformer(melodies, chords[:, :-1])
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:].contiguous().view(-1))
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            val_loss = running_loss/samples_num
            # accuracy
            prediction = output.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction == chords[:, 1:]).sum().item()/prediction.shape[1]
            accuracy = running_accuracy/samples_num
        if best_val_loss > val_loss:
            print('saving!')
            best_val_loss = val_loss
            torch.save(encoderModel.state_dict(), encoder_path)
            torch.save(decoderModel.state_dict(), decoder_path)
        print(f'validation: accuracy={accuracy}, loss={val_loss}')
    # print(f"Epoch: {epoch+1}, training loss: {train_loss} | validation loss {val_loss}")

Epoch 0 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.56batch/s, accuracy=0.00523, loss=0.413]


validation...
saving!
validation: accuracy=0.012077861163227017, loss=0.36242847907833936


Epoch 1 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.52batch/s, accuracy=0.0199, loss=0.341]


validation...
saving!
validation: accuracy=0.027702861163227017, loss=0.3311176067445336


Epoch 2 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:29<00:00, 10.30batch/s, accuracy=0.0372, loss=0.315]


validation...
saving!
validation: accuracy=0.03959017354596623, loss=0.3105406510673365


Epoch 3 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:29<00:00, 10.34batch/s, accuracy=0.0475, loss=0.297]


validation...
saving!
validation: accuracy=0.044837593808630394, loss=0.2967490049508902


Epoch 4 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:29<00:00, 10.34batch/s, accuracy=0.0528, loss=0.285]


validation...
saving!
validation: accuracy=0.04923487335834897, loss=0.286071726797222


Epoch 5 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.35batch/s, accuracy=0.0566, loss=0.276]


validation...
saving!
validation: accuracy=0.05257680581613509, loss=0.27880688053582


Epoch 6 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.35batch/s, accuracy=0.0601, loss=0.268]


validation...
saving!
validation: accuracy=0.05495133677298311, loss=0.27115610527052886


Epoch 7 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:29<00:00, 10.34batch/s, accuracy=0.0632, loss=0.262]


validation...
saving!
validation: accuracy=0.05802943245778611, loss=0.2669186779974102


Epoch 8 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.35batch/s, accuracy=0.0656, loss=0.256]


validation...
saving!
validation: accuracy=0.05898217636022514, loss=0.2616371947426286


Epoch 9 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 300/300 [00:34<00:00,  8.65batch/s, accuracy=0.068, loss=0.251]


validation...
saving!
validation: accuracy=0.06167917448405253, loss=0.2575103113843621


Epoch 10 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [01:02<00:00,  4.78batch/s, accuracy=0.0703, loss=0.247]


validation...
saving!
validation: accuracy=0.06436151500938087, loss=0.2546437882571909


Epoch 11 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [00:29<00:00, 10.34batch/s, accuracy=0.0724, loss=0.243]


validation...
saving!
validation: accuracy=0.06544617729831144, loss=0.24998085360142347


Epoch 12 | trn: 100%|███████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.35batch/s, accuracy=0.0741, loss=0.24]


validation...
saving!
validation: accuracy=0.06697056754221388, loss=0.2471988474003146


Epoch 13 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:28<00:00, 10.35batch/s, accuracy=0.0758, loss=0.237]


validation...
saving!
validation: accuracy=0.06924249530956848, loss=0.24456376877332048


Epoch 14 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:31<00:00,  9.66batch/s, accuracy=0.0773, loss=0.234]


validation...
saving!
validation: accuracy=0.07098674953095685, loss=0.2422829640515526


Epoch 15 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████████| 300/300 [01:09<00:00,  4.34batch/s, accuracy=0.0784, loss=0.231]


validation...
saving!
validation: accuracy=0.07095743433395872, loss=0.2416520736007261


Epoch 16 | trn:  84%|██████████████████████████████████████████████████████████████████▉             | 251/300 [00:24<00:04, 10.25batch/s, accuracy=0.0797, loss=0.229]

In [None]:
# to load the model
encoderModel = EncoderModel(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
decoderModel = DecoderModel(tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

encoderModel.load_state_dict(torch.load(encoder_dir))
decoderModel.load_state_dict(torch.load(decoder_dir))

transformer = TransformerFromModels(encoderModel, decoderModel)
# for inference, not retraining, we need to run
transformer.eval()