In [1]:
from Datasets import BinChromaDataset
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from transformer.models import ContinuousEncoderOnlyWrapper, ContinuousEncoder
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os

In [2]:
npz_path = '../data/augmented_and_padded_data.npz'
dataset = BinChromaDataset(npz_path)

train_percentage = 0.9
split_idx = int( len(dataset)*train_percentage )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 8
epochs = 1000

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [3]:
src_vocab_size = 12
tgt_vocab_size = 12
d_model = 128
num_heads = 16
num_layers = 4
d_ff = 128
max_seq_length = 129
dropout = 0.3

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoderModel = ContinuousEncoder(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

transformer = ContinuousEncoderOnlyWrapper(encoderModel)

transformer = transformer.to(dev)

In [None]:
criterion = CrossEntropyLoss()
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# keep best validation loss for saving
best_val_loss = np.inf
save_dir = '../saved_models/encoderOnly_one_hot/'
encoder_path = save_dir + 'encoderOnly_one_hot.pt'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies)
            # output = transformer(chords) # identity check
            loss = criterion(output.contiguous().view(-1), chords.contiguous().view(-1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    # validation
    with torch.no_grad():
        val_loss = 0
        running_loss = 0
        samples_num = 0
        running_accuracy = 0
        accuracy = 0
        print('validation...')
        for melodies, chords in test_loader:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            output = transformer(melodies)
            # output = transformer(chords) # identity check
            loss = criterion(output.contiguous().view(-1), chords.contiguous().view(-1))
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            val_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
        if best_val_loss > val_loss:
            print('saving!')
            best_val_loss = val_loss
            torch.save(transformer.state_dict(), encoder_path)
        print(f'validation: accuracy={accuracy}, loss={val_loss}')

Epoch 0 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.25batch/s, accuracy=tensor(0.0782, device='cuda:0'), loss=1.33e+3]


validation...
saving!
validation: accuracy=0.08267159014940262, loss=1289.0719907950281


Epoch 1 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.94batch/s, accuracy=tensor(0.0841, device='cuda:0'), loss=1.28e+3]


validation...
saving!
validation: accuracy=0.08695732057094574, loss=1277.9007937089586


Epoch 2 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.45batch/s, accuracy=tensor(0.0857, device='cuda:0'), loss=1.28e+3]


validation...
saving!
validation: accuracy=0.08658645302057266, loss=1271.5815548780488


Epoch 3 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.41batch/s, accuracy=tensor(0.0863, device='cuda:0'), loss=1.27e+3]


validation...
saving!
validation: accuracy=0.0871955156326294, loss=1268.852763506977


Epoch 4 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.35batch/s, accuracy=tensor(0.0865, device='cuda:0'), loss=1.27e+3]


validation...
saving!
validation: accuracy=0.08791286498308182, loss=1268.167519861046


Epoch 5 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:37<00:00, 16.15batch/s, accuracy=tensor(0.0867, device='cuda:0'), loss=1.27e+3]


validation...
saving!
validation: accuracy=0.08794448524713516, loss=1265.9332538769347


Epoch 6 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.92batch/s, accuracy=tensor(0.0868, device='cuda:0'), loss=1.27e+3]


validation...
saving!
validation: accuracy=0.08682024478912354, loss=1265.2079226152086


Epoch 7 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.84batch/s, accuracy=tensor(0.0869, device='cuda:0'), loss=1.27e+3]


validation...
saving!
validation: accuracy=0.08755327761173248, loss=1263.7257242685857


Epoch 8 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:37<00:00, 16.12batch/s, accuracy=tensor(0.0870, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08721621334552765, loss=1263.2810567029198


Epoch 9 | trn: 100%|██████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.42batch/s, accuracy=tensor(0.0871, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08673519641160965, loss=1261.8820255701805


Epoch 10 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.53batch/s, accuracy=tensor(0.0871, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08813899010419846, loss=1262.3799698419912


Epoch 11 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.53batch/s, accuracy=tensor(0.0872, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08784342557191849, loss=1260.6399773173664


Epoch 12 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.43batch/s, accuracy=tensor(0.0873, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08782415091991425, loss=1260.0176257621952


Epoch 13 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.43batch/s, accuracy=tensor(0.0873, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08784342557191849, loss=1260.8740747390948


Epoch 14 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.41batch/s, accuracy=tensor(0.0874, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08750090003013611, loss=1259.3652307106004


Epoch 15 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 16.10batch/s, accuracy=tensor(0.0874, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08788159489631653, loss=1258.732744342167


Epoch 16 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.72batch/s, accuracy=tensor(0.0875, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08685224503278732, loss=1258.625801587418


Epoch 17 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.92batch/s, accuracy=tensor(0.0874, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08778814226388931, loss=1259.2484618682574


Epoch 18 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.93batch/s, accuracy=tensor(0.0875, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08786016702651978, loss=1257.628956635495


Epoch 19 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.94batch/s, accuracy=tensor(0.0875, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08769142627716064, loss=1257.920869122303


Epoch 20 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.96batch/s, accuracy=tensor(0.0875, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08780013769865036, loss=1257.5488290411


Epoch 21 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.95batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08806303888559341, loss=1257.2266055566956


Epoch 22 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 15.88batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.0881219431757927, loss=1257.7236584632974


Epoch 23 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.54batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08781397342681885, loss=1256.7030691179057


Epoch 24 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.23batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08808556199073792, loss=1256.8162999824108


Epoch 25 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.48batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08803721517324448, loss=1256.8481555244489


Epoch 26 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.40batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08737146109342575, loss=1257.4976263851431


Epoch 27 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.39batch/s, accuracy=tensor(0.0876, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08800448477268219, loss=1256.4704466170263


Epoch 28 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.38batch/s, accuracy=tensor(0.0877, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08820047974586487, loss=1256.047108605476


Epoch 29 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.44batch/s, accuracy=tensor(0.0877, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08775468170642853, loss=1256.0165053720098


Epoch 30 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:37<00:00, 16.18batch/s, accuracy=tensor(0.0877, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08820337802171707, loss=1256.5154857528144


Epoch 31 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.24batch/s, accuracy=tensor(0.0877, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08798304945230484, loss=1254.830596637547


Epoch 32 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.65batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.0879405289888382, loss=1255.078759857235


Epoch 33 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.40batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08806920796632767, loss=1255.575497625469


Epoch 34 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.46batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.088089220225811, loss=1255.0617561268762


Epoch 35 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.35batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08799031376838684, loss=1254.8090270652556


Epoch 36 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:40<00:00, 14.99batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.08827320486307144, loss=1255.748799909123


Epoch 37 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.09batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
validation: accuracy=0.0881045013666153, loss=1254.953310052181


Epoch 38 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.01batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.26e+3]


validation...
saving!
validation: accuracy=0.08805830776691437, loss=1254.522286878518


Epoch 39 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.52batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
saving!
validation: accuracy=0.08817064762115479, loss=1254.3460145989682


Epoch 40 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.52batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08814775943756104, loss=1254.3765454605418


Epoch 41 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.00batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
saving!
validation: accuracy=0.08817430585622787, loss=1254.068369452099


Epoch 42 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.52batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08813902735710144, loss=1254.9770150533536


Epoch 43 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:40<00:00, 14.93batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08822555094957352, loss=1254.6977987951454


Epoch 44 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.66batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
saving!
validation: accuracy=0.08823029696941376, loss=1253.8635899756684


Epoch 45 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.62batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08810922503471375, loss=1254.4010923575281


Epoch 46 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.58batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08792378008365631, loss=1253.937469768703


Epoch 47 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.72batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08824227750301361, loss=1254.5225653728894


Epoch 48 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.57batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
saving!
validation: accuracy=0.08830083161592484, loss=1253.8020482161703


Epoch 49 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.27batch/s, accuracy=tensor(0.0879, device='cuda:0'), loss=1.25e+3]


validation...
saving!
validation: accuracy=0.08818191289901733, loss=1253.6496833958724


Epoch 50 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.06batch/s, accuracy=tensor(0.0879, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08792994916439056, loss=1254.1277107396224


Epoch 51 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:40<00:00, 14.93batch/s, accuracy=tensor(0.0878, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.0879514068365097, loss=1253.7889589807105


Epoch 52 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:40<00:00, 14.81batch/s, accuracy=tensor(0.0879, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08831719309091568, loss=1254.0867006112219


Epoch 53 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.17batch/s, accuracy=tensor(0.0879, device='cuda:0'), loss=1.25e+3]


validation...
saving!
validation: accuracy=0.08821755647659302, loss=1253.3621755174133


Epoch 54 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:39<00:00, 15.15batch/s, accuracy=tensor(0.0879, device='cuda:0'), loss=1.25e+3]


validation...
validation: accuracy=0.08823212236166, loss=1254.2169755144816


Epoch 55 | trn: 100%|█████████████████████████████████████████████████████| 600/600 [00:38<00:00, 15.42batch/s, accuracy=tensor(0.0879, device='cuda:0'), loss=1.25e+3]


validation...
