In [1]:
# https://github.com/lanpa/tensorboardX/tree/8a7925b27cc7233c8fa700c11b741725fc1ce2d9
from Datasets import BinChromaDataset, PermutationsBinChromaDataset
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from transformer.models import TransformerFromModels, ContinuousEncoder, ContinuousDecoderModel
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os

In [2]:
npz_path = '../data/augmented_and_padded_data.npz'
dataset = BinChromaDataset(npz_path)

train_percentage = 0.9
split_idx = int( len(dataset)*train_percentage )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 8
epochs = 1000

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [3]:
permutation_dataset = PermutationsBinChromaDataset(npz_path)
permutation_loader = DataLoader(permutation_dataset, batch_size=batch_size, shuffle=True)

In [4]:
src_vocab_size = 12
tgt_vocab_size = 12
d_model = 256
num_heads = 4
num_layers = 4
d_ff = 256
max_seq_length = 129
dropout = 0.3

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoderModel = ContinuousEncoder(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
decoderModel = ContinuousDecoderModel(tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

encoderModel = encoderModel.to(dev)
decoderModel = decoderModel.to(dev)

transformer = TransformerFromModels(encoderModel, decoderModel)

transformer = transformer.to(dev)

In [5]:
criterion = BCEWithLogitsLoss()
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# keep best validation loss for saving
best_val_loss = np.inf
save_dir = '../saved_models/encoder_decoder_binary/'
encoder_path = save_dir + 'encoder_one_hot.pt'
decoder_path = save_dir + 'decoder_one_hot.pt'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies, chords[:, :-1, :])
            # loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:, :].contiguous().view(-1, tgt_vocab_size))
            loss = criterion(output.permute(0, 2, 1), chords[:, 1:, :].permute(0, 2, 1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords[:, 1:] > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    with tqdm(permutation_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | prm")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies, chords[:, :-1, :])
            # loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:, :].contiguous().view(-1, tgt_vocab_size))
            loss = criterion(output.permute(0, 2, 1), chords[:, 1:, :].permute(0, 2, 1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords[:, 1:] > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    # validation
    with torch.no_grad():
        val_loss = 0
        running_loss = 0
        samples_num = 0
        running_accuracy = 0
        accuracy = 0
        print('validation...')
        for melodies, chords in test_loader:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            output = transformer(melodies, chords[:, :-1, :])
            # loss = criterion(output.contiguous().view(-1, tgt_vocab_size), chords[:, 1:].contiguous().view(-1, tgt_vocab_size))
            loss = criterion(output.permute(0, 2, 1), chords[:, 1:, :].permute(0, 2, 1))
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            val_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords[:, 1:] > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
        if best_val_loss > val_loss:
            print('saving!')
            best_val_loss = val_loss
            torch.save(encoderModel.state_dict(), encoder_path)
            torch.save(decoderModel.state_dict(), decoder_path)
        print(f'validation: accuracy={accuracy}, loss={val_loss}')

Epoch 0 | trn: 100%|███████████████████████████████████████████████████████| 600/600 [00:49<00:00, 12.11batch/s, accuracy=tensor(0.0883, device='cuda:0'), loss=0.0273]
Epoch 0 | prm: 100%|███████████████████████████████████████████████████████| 666/666 [00:57<00:00, 11.66batch/s, accuracy=tensor(0.0883, device='cuda:0'), loss=0.0245]


validation...
saving!
validation: accuracy=0.08851504325866699, loss=0.020841262326007935


Epoch 1 | trn: 100%|███████████████████████████████████████████████████████| 600/600 [00:51<00:00, 11.68batch/s, accuracy=tensor(0.0882, device='cuda:0'), loss=0.0199]
Epoch 1 | prm: 100%|███████████████████████████████████████████████████████| 666/666 [00:57<00:00, 11.54batch/s, accuracy=tensor(0.0882, device='cuda:0'), loss=0.0204]


validation...
saving!
validation: accuracy=0.08871255069971085, loss=0.020141068237919297


Epoch 2 | trn: 100%|███████████████████████████████████████████████████████| 600/600 [00:51<00:00, 11.59batch/s, accuracy=tensor(0.0885, device='cuda:0'), loss=0.0191]
Epoch 2 | prm: 100%|███████████████████████████████████████████████████████| 666/666 [00:57<00:00, 11.50batch/s, accuracy=tensor(0.0885, device='cuda:0'), loss=0.0198]


validation...
saving!
validation: accuracy=0.08918452262878418, loss=0.01958494317520403


Epoch 3 | trn: 100%|███████████████████████████████████████████████████████| 600/600 [00:49<00:00, 12.18batch/s, accuracy=tensor(0.0891, device='cuda:0'), loss=0.0185]
Epoch 3 | prm:  22%|████████████▏                                          | 148/666 [00:11<00:40, 12.70batch/s, accuracy=tensor(0.0892, device='cuda:0'), loss=0.0187]


KeyboardInterrupt: 