In [1]:
from Datasets import BinChromaDataset
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from transformer.models import EncoderOnlyWrapper, EncoderModel
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os

In [2]:
npz_path = '../data/augmented_and_padded_data.npz'
dataset = BinChromaDataset(npz_path)

train_percentage = 0.9
split_idx = int( len(dataset)*train_percentage )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 8
epochs = 1000

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [3]:
src_vocab_size = 2**12
tgt_vocab_size = 2**12
d_model = 512
num_heads = 8
num_layers = 8
d_ff = 1024
max_seq_length = 129
dropout = 0.3

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoderModel = EncoderModel(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

transformer = EncoderOnlyWrapper(encoderModel)

transformer = transformer.to(dev)

In [None]:
criterion = CrossEntropyLoss(ignore_index=0)
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# keep best validation loss for saving
best_val_loss = np.inf
save_dir = '../saved_models/encoderOnly_one_hot/'
encoder_path = save_dir + 'encoderOnly_one_hot.pt'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies)
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords.contiguous().view(-1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            prediction = output.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction == chords).sum().item()/prediction.shape[1]
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    # validation
    with torch.no_grad():
        val_loss = 0
        running_loss = 0
        samples_num = 0
        running_accuracy = 0
        accuracy = 0
        print('validation...')
        for melodies, chords in test_loader:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            output = transformer(melodies)
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords.contiguous().view(-1))
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            val_loss = running_loss/samples_num
            # accuracy
            prediction = output.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction == chords).sum().item()/prediction.shape[1]
            accuracy = running_accuracy/samples_num
        if best_val_loss > val_loss:
            print('saving!')
            best_val_loss = val_loss
            torch.save(transformer.state_dict(), encoder_path)
        print(f'validation: accuracy={accuracy}, loss={val_loss}')

Epoch 0 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.33batch/s, accuracy=0.0169, loss=0.68]


validation...
saving!
validation: accuracy=0.03129863141207443, loss=0.5990776031594339


Epoch 1 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:13<00:00,  3.10batch/s, accuracy=0.0359, loss=0.587]


validation...
saving!
validation: accuracy=0.04182846837412919, loss=0.5614077016068221


Epoch 2 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.0438, loss=0.554]


validation...
saving!
validation: accuracy=0.04511540643134517, loss=0.5431299281165032


Epoch 3 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:06<00:00,  3.22batch/s, accuracy=0.0481, loss=0.532]


validation...
saving!
validation: accuracy=0.04694794711811163, loss=0.5340269694408825


Epoch 4 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:07<00:00,  3.20batch/s, accuracy=0.0521, loss=0.515]


validation...
saving!
validation: accuracy=0.04998763762235117, loss=0.5226395926377115


Epoch 5 | trn: 100%|███████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.0558, loss=0.5]


validation...
saving!
validation: accuracy=0.05251828904693342, loss=0.51339879044896


Epoch 6 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:06<00:00,  3.22batch/s, accuracy=0.0593, loss=0.486]


validation...
saving!
validation: accuracy=0.05510711636633362, loss=0.5088100965653755


Epoch 7 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:04<00:00,  3.26batch/s, accuracy=0.0629, loss=0.473]


validation...
saving!
validation: accuracy=0.05465625318149427, loss=0.508454122418087


Epoch 8 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.0667, loss=0.46]


validation...
saving!
validation: accuracy=0.05494713265558416, loss=0.5049541739093429


Epoch 9 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:06<00:00,  3.22batch/s, accuracy=0.0697, loss=0.449]


validation...
saving!
validation: accuracy=0.05661968963160112, loss=0.5019567732963061


Epoch 10 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [02:17<00:00,  4.35batch/s, accuracy=0.0736, loss=0.436]


validation...
validation: accuracy=0.05499076457669766, loss=0.5047033821664205


Epoch 11 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [02:54<00:00,  3.44batch/s, accuracy=0.0768, loss=0.425]


validation...
validation: accuracy=0.05640153002603372, loss=0.5057322124602871


Epoch 12 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 600/600 [03:04<00:00,  3.24batch/s, accuracy=0.08, loss=0.415]


validation...
validation: accuracy=0.05542708378783248, loss=0.504528004054057


Epoch 13 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.0839, loss=0.403]


validation...
validation: accuracy=0.05501985252410663, loss=0.507551374846954


Epoch 14 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [03:05<00:00,  3.24batch/s, accuracy=0.0877, loss=0.392]


validation...
validation: accuracy=0.054641709207789727, loss=0.5093927217916521


Epoch 15 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.0917, loss=0.381]


validation...
validation: accuracy=0.05436537370740434, loss=0.5126664003332829


Epoch 16 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [03:03<00:00,  3.27batch/s, accuracy=0.0955, loss=0.369]


validation...
validation: accuracy=0.05373998283811103, loss=0.5201241021755713


Epoch 17 | trn: 100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.35batch/s, accuracy=0.0996, loss=0.358]


validation...
validation: accuracy=0.05352182323254361, loss=0.5213316529150528


Epoch 18 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:04<00:00,  3.26batch/s, accuracy=0.104, loss=0.347]


validation...
validation: accuracy=0.05426356589147287, loss=0.5244351180364669


Epoch 19 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [02:43<00:00,  3.66batch/s, accuracy=0.108, loss=0.337]


validation...
validation: accuracy=0.05304187210029525, loss=0.5307106385758849


Epoch 20 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [02:38<00:00,  3.78batch/s, accuracy=0.112, loss=0.326]


validation...
validation: accuracy=0.05260555288916037, loss=0.5318418068018014


Epoch 21 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:05<00:00,  3.24batch/s, accuracy=0.116, loss=0.316]


validation...
validation: accuracy=0.05173291446689065, loss=0.5392338160502307


Epoch 22 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.35batch/s, accuracy=0.12, loss=0.306]


validation...
validation: accuracy=0.05200924996727607, loss=0.5398114215142285


Epoch 23 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:04<00:00,  3.26batch/s, accuracy=0.124, loss=0.295]


validation...
validation: accuracy=0.05120933141352881, loss=0.5493320627015706


Epoch 24 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.128, loss=0.286]


validation...
validation: accuracy=0.05061302849164451, loss=0.553057055983266


Epoch 25 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:02<00:00,  3.29batch/s, accuracy=0.132, loss=0.276]


validation...
validation: accuracy=0.051092979623892844, loss=0.5582004474654206


Epoch 26 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.34batch/s, accuracy=0.137, loss=0.267]


validation...
validation: accuracy=0.049158631121194944, loss=0.5634273722292559


Epoch 27 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 600/600 [03:01<00:00,  3.30batch/s, accuracy=0.14, loss=0.258]


validation...
validation: accuracy=0.0499730936486467, loss=0.5710329289284254


Epoch 28 | trn: 100%|██████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.35batch/s, accuracy=0.144, loss=0.25]


validation...
validation: accuracy=0.050002181596055684, loss=0.5736042419920272


Epoch 29 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [03:02<00:00,  3.29batch/s, accuracy=0.148, loss=0.241]


validation...
validation: accuracy=0.04889683959451403, loss=0.5841495310835275


Epoch 30 | trn: 100%|█████████████████████████████████████████████████████████████████████████████████| 600/600 [02:18<00:00,  4.35batch/s, accuracy=0.152, loss=0.232]


validation...
validation: accuracy=0.04902773535785447, loss=0.5927000958297758


Epoch 31 | trn:  23%|███████████████████▏                                                              | 140/600 [00:32<01:45,  4.35batch/s, accuracy=0.16, loss=0.217]