In [1]:
from Datasets import BinChromaDataset
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from transformer.models import TransformerFromModels, EncoderModel, DecoderModel
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os

In [2]:
npz_path = '../data/augmented_and_padded_data.npz'
dataset = BinChromaDataset(npz_path)

split_idx = int( len(dataset)*0.8 )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 16
epochs = 1000

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [3]:
src_vocab_size = 2**12
tgt_vocab_size = 2**12
d_model = 512
num_heads = 4
num_layers = 3
d_ff = 2048
max_seq_length = 129
dropout = 0.1

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoderModel = EncoderModel(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
decoderModel = DecoderModel(tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

encoderModel = encoderModel.to(dev)
decoderModel = decoderModel.to(dev)

transformer = TransformerFromModels(encoderModel, decoderModel)

transformer = transformer.to(dev)

In [None]:
criterion = CrossEntropyLoss(ignore_index=0)
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# keep best validation loss for saving
best_val_loss = np.inf
save_dir = '../saved_models/encoder_decoder_one_hot/'
encoder_path = save_dir + 'encoder_one_hot.pt'
decoder_path = save_dir + 'decoder_one_hot.pt'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies, chords[:, :-1])
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:].contiguous().view(-1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            prediction = output.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction == chords[:, 1:]).sum().item()/prediction.shape[1]
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    # validation
    with torch.no_grad():
        val_loss = 0
        running_loss = 0
        samples_num = 0
        running_accuracy = 0
        accuracy = 0
        print('validation...')
        for melodies, chords in test_loader:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            output = transformer(melodies, chords[:, :-1])
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:].contiguous().view(-1))
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            val_loss = running_loss/samples_num
            # accuracy
            prediction = output.argmax(dim=2, keepdim=True).squeeze()
            running_accuracy += (prediction == chords[:, 1:]).sum().item()/prediction.shape[1]
            accuracy = running_accuracy/samples_num
        if best_val_loss > val_loss:
            print('saving!')
            best_val_loss = val_loss
            torch.save(encoderModel.state_dict(), encoder_path)
            torch.save(decoderModel.state_dict(), decoder_path)
        print(f'validation: accuracy={accuracy}, loss={val_loss}')
    # print(f"Epoch: {epoch+1}, training loss: {train_loss} | validation loss {val_loss}")

Epoch 0 | trn: 100%|██████████████████████████████████| 267/267 [02:09<00:00,  2.06batch/s, accuracy=0.0456, loss=0.322]


validation...
saving!
validation: accuracy=0.07223997420262664, loss=0.2615051819727971


Epoch 1 | trn: 100%|██████████████████████████████████| 267/267 [02:24<00:00,  1.85batch/s, accuracy=0.0799, loss=0.242]


validation...
saving!
validation: accuracy=0.08576160881801126, loss=0.23122246739564647


Epoch 2 | trn: 100%|██████████████████████████████████| 267/267 [02:28<00:00,  1.80batch/s, accuracy=0.0914, loss=0.216]


validation...
saving!
validation: accuracy=0.09158800422138837, loss=0.21817203332961835


Epoch 3 | trn: 100%|██████████████████████████████████| 267/267 [02:28<00:00,  1.80batch/s, accuracy=0.0982, loss=0.201]


validation...
saving!
validation: accuracy=0.0949812382739212, loss=0.2107072241087121


Epoch 4 | trn: 100%|███████████████████████████████████| 267/267 [02:29<00:00,  1.78batch/s, accuracy=0.105, loss=0.189]


validation...
saving!
validation: accuracy=0.09759761960600376, loss=0.20659623875179614


Epoch 5 | trn: 100%|████████████████████████████████████| 267/267 [02:29<00:00,  1.79batch/s, accuracy=0.11, loss=0.178]


validation...
saving!
validation: accuracy=0.09852837711069419, loss=0.2036456386322823


Epoch 6 | trn: 100%|███████████████████████████████████| 267/267 [02:29<00:00,  1.78batch/s, accuracy=0.117, loss=0.169]


validation...
saving!
validation: accuracy=0.09989886257035648, loss=0.20135189556493993


Epoch 7 | trn: 100%|███████████████████████████████████| 267/267 [02:29<00:00,  1.79batch/s, accuracy=0.123, loss=0.159]


validation...
saving!
validation: accuracy=0.1006390712945591, loss=0.20040623779368447


Epoch 8 | trn: 100%|████████████████████████████████████| 267/267 [02:28<00:00,  1.79batch/s, accuracy=0.129, loss=0.15]


validation...
saving!
validation: accuracy=0.1020095567542214, loss=0.19954654624717097


Epoch 9 | trn: 100%|███████████████████████████████████| 267/267 [02:30<00:00,  1.77batch/s, accuracy=0.135, loss=0.141]


validation...
validation: accuracy=0.10137928001876173, loss=0.20108424073983314


Epoch 10 | trn: 100%|██████████████████████████████████| 267/267 [02:29<00:00,  1.78batch/s, accuracy=0.142, loss=0.132]


validation...
validation: accuracy=0.1017530487804878, loss=0.20106109423216914


Epoch 11 | trn: 100%|███████████████████████████████████| 267/267 [02:28<00:00,  1.80batch/s, accuracy=0.15, loss=0.123]


validation...
validation: accuracy=0.10107879924953096, loss=0.2036576642328087


Epoch 12 | trn: 100%|██████████████████████████████████| 267/267 [02:27<00:00,  1.80batch/s, accuracy=0.157, loss=0.115]


validation...
validation: accuracy=0.09983290337711069, loss=0.20632456122822432


Epoch 13 | trn: 100%|██████████████████████████████████| 267/267 [02:28<00:00,  1.80batch/s, accuracy=0.164, loss=0.106]


validation...
validation: accuracy=0.10017002814258912, loss=0.2090108300686777


Epoch 14 | trn: 100%|█████████████████████████████████| 267/267 [02:30<00:00,  1.78batch/s, accuracy=0.171, loss=0.0984]


validation...
validation: accuracy=0.09807399155722327, loss=0.21130200980080896


Epoch 15 | trn: 100%|██████████████████████████████████| 267/267 [02:28<00:00,  1.80batch/s, accuracy=0.179, loss=0.091]


validation...
validation: accuracy=0.09899742026266417, loss=0.2151619475509615


Epoch 16 | trn: 100%|█████████████████████████████████| 267/267 [02:28<00:00,  1.80batch/s, accuracy=0.186, loss=0.0835]


validation...
validation: accuracy=0.09781748358348968, loss=0.21856679947693844


Epoch 17 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.193, loss=0.0767]


validation...
validation: accuracy=0.09690138367729831, loss=0.2221321466939758


Epoch 18 | trn: 100%|███████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.2, loss=0.0703]


validation...
validation: accuracy=0.09567747420262664, loss=0.22692838842381233


Epoch 19 | trn: 100%|█████████████████████████████████| 267/267 [02:26<00:00,  1.82batch/s, accuracy=0.207, loss=0.0643]


validation...
validation: accuracy=0.09535500703564728, loss=0.23196345869640472


Epoch 20 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.213, loss=0.0588]


validation...
validation: accuracy=0.09429965994371482, loss=0.23652788420779172


Epoch 21 | trn: 100%|██████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.219, loss=0.054]


validation...
validation: accuracy=0.0945341815196998, loss=0.24038596895801193


Epoch 22 | trn: 100%|█████████████████████████████████| 267/267 [02:26<00:00,  1.82batch/s, accuracy=0.224, loss=0.0493]


validation...
validation: accuracy=0.09432164634146341, loss=0.24539264952711495


Epoch 23 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.229, loss=0.0451]


validation...
validation: accuracy=0.09360342401500939, loss=0.25014967475554734


Epoch 24 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.234, loss=0.0411]


validation...
validation: accuracy=0.09325897045028142, loss=0.25379395328364274


Epoch 25 | trn: 100%|█████████████████████████████████| 267/267 [02:29<00:00,  1.78batch/s, accuracy=0.237, loss=0.0382]


validation...
validation: accuracy=0.09388191838649156, loss=0.2578807407203803


Epoch 26 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.241, loss=0.0351]


validation...
validation: accuracy=0.09441692073170732, loss=0.2624143901059149


Epoch 27 | trn: 100%|█████████████████████████████████| 267/267 [02:29<00:00,  1.78batch/s, accuracy=0.245, loss=0.0326]


validation...
validation: accuracy=0.09246013133208256, loss=0.266965063159506


Epoch 28 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.82batch/s, accuracy=0.248, loss=0.0304]


validation...
validation: accuracy=0.09298047607879925, loss=0.26919845233044076


Epoch 29 | trn: 100%|██████████████████████████████████| 267/267 [02:27<00:00,  1.82batch/s, accuracy=0.25, loss=0.0285]


validation...
validation: accuracy=0.09276061210131333, loss=0.2741384293751242


Epoch 30 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.252, loss=0.0268]


validation...
validation: accuracy=0.09229889774859287, loss=0.27908645897376827


Epoch 31 | trn: 100%|█████████████████████████████████| 267/267 [02:26<00:00,  1.82batch/s, accuracy=0.255, loss=0.0251]


validation...
validation: accuracy=0.09328095684803002, loss=0.2777940023683473


Epoch 32 | trn: 100%|█████████████████████████████████| 267/267 [02:26<00:00,  1.82batch/s, accuracy=0.256, loss=0.0237]


validation...
validation: accuracy=0.09226225375234522, loss=0.28254063804869806


Epoch 33 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.82batch/s, accuracy=0.258, loss=0.0224]


validation...
validation: accuracy=0.09347150562851782, loss=0.2863316249668486


Epoch 34 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.259, loss=0.0214]


validation...
validation: accuracy=0.09237218574108819, loss=0.28939841097783414


Epoch 35 | trn: 100%|█████████████████████████████████| 267/267 [02:26<00:00,  1.82batch/s, accuracy=0.261, loss=0.0202]


validation...
validation: accuracy=0.09271663930581614, loss=0.28955115028438605


Epoch 36 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.82batch/s, accuracy=0.263, loss=0.0192]


validation...
validation: accuracy=0.09304643527204502, loss=0.292732980864133


Epoch 37 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.263, loss=0.0185]


validation...
validation: accuracy=0.09237218574108819, loss=0.2950204615297729


Epoch 38 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.264, loss=0.0176]


validation...
validation: accuracy=0.0921889657598499, loss=0.2990224446409415


Epoch 39 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.265, loss=0.0171]


validation...
validation: accuracy=0.09266533771106941, loss=0.3005620131573131


Epoch 40 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.81batch/s, accuracy=0.266, loss=0.0161]


validation...
validation: accuracy=0.09316369606003752, loss=0.3039305359963852


Epoch 41 | trn: 100%|█████████████████████████████████| 267/267 [02:27<00:00,  1.80batch/s, accuracy=0.267, loss=0.0156]


validation...
validation: accuracy=0.0933469160412758, loss=0.30462620115190686


Epoch 42 | trn:  69%|███████████████████████▎          | 183/267 [01:40<00:47,  1.77batch/s, accuracy=0.268, loss=0.015]

In [None]:
# to load the model
encoderModel = EncoderModel(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
decoderModel = DecoderModel(tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

encoderModel.load_state_dict(torch.load(encoder_dir))
decoderModel.load_state_dict(torch.load(decoder_dir))

transformer = TransformerFromModels(encoderModel, decoderModel)
# for inference, not retraining, we need to run
transformer.eval()