In [1]:
from Datasets import BinChromaDataset
import numpy as np
from torch.utils.data import DataLoader, Subset
import sys
sys.path.insert(0, '..')
from transformer.models import ContinuousEncoderOnlyWrapper, ContinuousEncoder
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
from tqdm import tqdm
import os

In [2]:
npz_path = '../data/augmented_and_padded_data.npz'
dataset = BinChromaDataset(npz_path)

train_percentage = 0.9
split_idx = int( len(dataset)*train_percentage )

train_set = Subset(dataset, range(0,split_idx))
test_set = Subset(dataset, range(split_idx, len(dataset)))

batch_size = 8
epochs = 1000

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

In [3]:
src_vocab_size = 12
tgt_vocab_size = 12
d_model = 128
num_heads = 16
num_layers = 4
d_ff = 128
max_seq_length = 129
dropout = 0.3

dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoderModel = ContinuousEncoder(src_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

transformer = ContinuousEncoderOnlyWrapper(encoderModel)

transformer = transformer.to(dev)

In [4]:
criterion = CrossEntropyLoss()
optimizer = Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# keep best validation loss for saving
best_val_loss = np.inf
save_dir = '../saved_models/encoderOnly_one_hot/'
encoder_path = save_dir + 'encoderOnly_one_hot.pt'
os.makedirs(save_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss = 0
    running_loss = 0
    samples_num = 0
    running_accuracy = 0
    accuracy = 0
    with tqdm(train_loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch} | trn")
        for melodies, chords in tepoch:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            optimizer.zero_grad()
            output = transformer(melodies)
            # output = transformer(chords) # identity check
            loss = criterion(output.contiguous().view(-1), chords.contiguous().view(-1))
            loss.backward()
            optimizer.step()
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            train_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
            tepoch.set_postfix(loss=train_loss, accuracy=accuracy) # tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
    # validation
    with torch.no_grad():
        val_loss = 0
        running_loss = 0
        samples_num = 0
        running_accuracy = 0
        accuracy = 0
        print('validation...')
        for melodies, chords in test_loader:
            melodies = melodies.to(dev)
            chords = chords.to(dev)
            output = transformer(melodies)
            # output = transformer(chords) # identity check
            loss = criterion(output.contiguous().view(-1), chords.contiguous().view(-1))
            # update loss
            samples_num += melodies.shape[0]
            running_loss += loss.item()
            val_loss = running_loss/samples_num
            # accuracy
            bin_output = output > 0.5
            bin_chords = chords > 0.5
            tmp_acc = 0
            tmp_count = 0
            for b_i in range(bin_output.shape[0]):
                for s_i in range(bin_output.shape[1]):
                    tmp_count += 1
                    tmp_acc += torch.all(bin_output[b_i, s_i, :].eq(bin_chords[b_i, s_i, :]))
            running_accuracy += tmp_acc/tmp_count
            accuracy = running_accuracy/samples_num
        if best_val_loss > val_loss:
            print('saving!')
            best_val_loss = val_loss
            torch.save(transformer.state_dict(), encoder_path)
        print(f'validation: accuracy={accuracy}, loss={val_loss}')

Epoch 0 | trn:   0%|                                                                                                                        | 0/600 [00:00<?, ?batch/s]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   0%|▎                                                       | 3/600 [00:00<02:00,  4.96batch/s, accuracy=tensor(0.0088, device='cuda:0'), loss=1.52e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   1%|▍                                                       | 5/600 [00:00<01:20,  7.37batch/s, accuracy=tensor(0.0133, device='cuda:0'), loss=1.54e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   2%|▊                                                        | 9/600 [00:01<00:56, 10.52batch/s, accuracy=tensor(0.0161, device='cuda:0'), loss=1.6e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   2%|█                                                      | 11/600 [00:01<00:51, 11.52batch/s, accuracy=tensor(0.0197, device='cuda:0'), loss=1.57e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   2%|█▍                                                     | 15/600 [00:01<00:45, 12.77batch/s, accuracy=tensor(0.0230, device='cuda:0'), loss=1.56e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   3%|█▌                                                     | 17/600 [00:01<00:43, 13.29batch/s, accuracy=tensor(0.0273, device='cuda:0'), loss=1.54e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   4%|█▉                                                     | 21/600 [00:02<00:42, 13.56batch/s, accuracy=tensor(0.0307, device='cuda:0'), loss=1.54e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   4%|██                                                     | 23/600 [00:02<00:41, 13.75batch/s, accuracy=tensor(0.0342, device='cuda:0'), loss=1.54e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   4%|██▍                                                    | 27/600 [00:02<00:41, 13.88batch/s, accuracy=tensor(0.0377, device='cuda:0'), loss=1.54e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   5%|██▋                                                    | 29/600 [00:02<00:40, 14.06batch/s, accuracy=tensor(0.0410, device='cuda:0'), loss=1.53e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   6%|███                                                    | 33/600 [00:02<00:40, 14.14batch/s, accuracy=tensor(0.0439, device='cuda:0'), loss=1.52e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   6%|███▏                                                   | 35/600 [00:03<00:39, 14.25batch/s, accuracy=tensor(0.0461, device='cuda:0'), loss=1.52e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   6%|███▌                                                   | 39/600 [00:03<00:38, 14.43batch/s, accuracy=tensor(0.0480, device='cuda:0'), loss=1.53e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   7%|███▊                                                   | 41/600 [00:03<00:38, 14.37batch/s, accuracy=tensor(0.0499, device='cuda:0'), loss=1.53e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   8%|████▏                                                  | 45/600 [00:03<00:38, 14.38batch/s, accuracy=tensor(0.0516, device='cuda:0'), loss=1.54e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   8%|████▎                                                  | 47/600 [00:03<00:38, 14.44batch/s, accuracy=tensor(0.0535, device='cuda:0'), loss=1.52e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   8%|████▊                                                   | 51/600 [00:04<00:37, 14.55batch/s, accuracy=tensor(0.0552, device='cuda:0'), loss=1.5e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:   9%|████▊                                                  | 53/600 [00:04<00:38, 14.30batch/s, accuracy=tensor(0.0563, device='cuda:0'), loss=1.51e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  10%|█████▎                                                  | 57/600 [00:04<00:37, 14.35batch/s, accuracy=tensor(0.0577, device='cuda:0'), loss=1.5e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  10%|█████▍                                                 | 59/600 [00:04<00:37, 14.33batch/s, accuracy=tensor(0.0586, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  10%|█████▊                                                 | 63/600 [00:04<00:37, 14.26batch/s, accuracy=tensor(0.0595, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  11%|█████▉                                                 | 65/600 [00:05<00:37, 14.33batch/s, accuracy=tensor(0.0602, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  12%|██████▎                                                | 69/600 [00:05<00:37, 14.20batch/s, accuracy=tensor(0.0608, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  12%|██████▌                                                | 71/600 [00:05<00:37, 14.11batch/s, accuracy=tensor(0.0614, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  12%|██████▉                                                | 75/600 [00:05<00:37, 13.92batch/s, accuracy=tensor(0.0620, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  13%|███████                                                | 77/600 [00:06<00:38, 13.74batch/s, accuracy=tensor(0.0626, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  14%|███████▍                                               | 81/600 [00:06<00:39, 13.29batch/s, accuracy=tensor(0.0632, device='cuda:0'), loss=1.49e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  14%|███████▌                                               | 83/600 [00:06<00:38, 13.32batch/s, accuracy=tensor(0.0637, device='cuda:0'), loss=1.48e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  14%|███████▉                                               | 87/600 [00:06<00:38, 13.26batch/s, accuracy=tensor(0.0642, device='cuda:0'), loss=1.48e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  15%|████████▏                                              | 89/600 [00:06<00:39, 12.95batch/s, accuracy=tensor(0.0647, device='cuda:0'), loss=1.47e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  16%|████████▌                                              | 93/600 [00:07<00:39, 12.71batch/s, accuracy=tensor(0.0652, device='cuda:0'), loss=1.47e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  16%|████████▋                                              | 95/600 [00:07<00:39, 12.69batch/s, accuracy=tensor(0.0656, device='cuda:0'), loss=1.46e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  16%|█████████                                              | 99/600 [00:07<00:39, 12.73batch/s, accuracy=tensor(0.0661, device='cuda:0'), loss=1.46e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  17%|█████████                                             | 101/600 [00:07<00:38, 12.81batch/s, accuracy=tensor(0.0665, device='cuda:0'), loss=1.45e+3]

torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


Epoch 0 | trn:  18%|█████████▍                                            | 105/600 [00:08<00:38, 12.92batch/s, accuracy=tensor(0.0668, device='cuda:0'), loss=1.45e+3]


torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])
torch.Size([8, 129, 12])
torch.Size([8, 1, 1, 129])


KeyboardInterrupt: 