In [1]:
from data_utils import LSTMHarmonyDataset, TransitionMatrixDataset, BagOfTransitionsDataset, LSTMPaddingCollator
from torch.utils.data import DataLoader
import torch
import GridMLM_tokenizers
from GridMLM_tokenizers import CSGridMLMTokenizer
from data_utils import CSGridMLMDataset
import pickle
from models_baseline import LSTMHarmonyModel, TransitionMatrixAutoencoder, BagOfTransitionsAutoencoder
from models_baseline import train_lstm, train_matrix_ae, train_bot_ae
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pickled bag-of-transitions vocabulary (path is relative to the working directory).
BoT_vocab_path = 'data/BoT_vocab.pickle'

# NOTE(review): pickle.load can run arbitrary code while unpickling -- only
# point this at a vocabulary file that comes from a trusted source.
with open(BoT_vocab_path, 'rb') as vocab_file:
    BoT_vocab = pickle.load(vocab_file)

In [3]:
# Absolute locations of the two data splits. Note that validation reads the
# 'CA_test' directory.
train_dir = '/media/maindisk/data/hooktheory_midi_hr/CA_train'
val_dir = '/media/maindisk/data/hooktheory_midi_hr/CA_test'

tokenizer = CSGridMLMTokenizer(fixed_length=256)

# Both splits are built with exactly the same options.
dataset_options = dict(frontloading=True, name_suffix='Q4_L80_bar_PC')
train_dataset = CSGridMLMDataset(train_dir, tokenizer, **dataset_options)
val_dataset = CSGridMLMDataset(val_dir, tokenizer, **dataset_options)

# Re-key the chord feature table through tokenizer.vocab
# (presumably token string -> integer token id; confirm against the tokenizer).
chord_features = GridMLM_tokenizers.CHORD_FEATURES
chord_id_features = {
    tokenizer.vocab[chord_token]: chord_feats
    for chord_token, chord_feats in chord_features.items()
}

Loading data file.
Loading data file.


In [4]:
# Wrap each split for the LSTM baseline (train first, then validation).
lstm_train_dataset, lstm_val_dataset = (
    LSTMHarmonyDataset(split, chord_id_features)
    for split in (train_dataset, val_dataset)
)

In [5]:
# Wrap each split for the transition-matrix autoencoder (train first, then validation).
matrix_train_dataset, matrix_val_dataset = (
    TransitionMatrixDataset(split, chord_id_features, tokenizer)
    for split in (train_dataset, val_dataset)
)

In [6]:
# Wrap each split for the bag-of-transitions autoencoder (train first, then validation).
bot_train_dataset, bot_val_dataset = (
    BagOfTransitionsDataset(split, chord_id_features, BoT_vocab)
    for split in (train_dataset, val_dataset)
)

In [7]:
# Sanity check: show the first LSTM training item. In the recorded run this is
# a dict holding a single 1-D integer tensor under 'chord_sequence'.
sample_index = 0
d0 = lstm_train_dataset[sample_index]
print(d0)

{'chord_sequence': tensor([269, 123,   7, 123, 269, 123,   7, 123,  66, 123, 269,   7, 123,  66,
        210,   7, 152, 153])}


In [8]:
# Sanity check: show the first transition-matrix training item. In the recorded
# run this is a dict holding a 2-D float tensor under 'transition_matrix'.
sample_index = 0
d0 = matrix_train_dataset[sample_index]
print(d0)

{'transition_matrix': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}


In [9]:
# Sanity check: show the first bag-of-transitions training item. In the recorded
# run this is a dict holding a 1-D float tensor under 'bag_of_transitions'.
sample_index = 0
d0 = bot_train_dataset[sample_index]
print(d0)

{'bag_of_transitions': tensor([0.1176, 0.1176, 0.1765,  ..., 0.0000, 0.0000, 0.0000])}


In [10]:
# Collate function for the LSTM loaders; id 0 is passed as the padding token.
# (Presumably it pads variable-length chord sequences to a common length --
# confirm in data_utils.)
pad_token_id = 0
collator = LSTMPaddingCollator(pad_id=pad_token_id)

In [11]:
# Options common to both LSTM loaders: batches of 32 assembled by the padding collator.
lstm_loader_options = dict(batch_size=32, collate_fn=collator)

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader_lstm = DataLoader(lstm_train_dataset, shuffle=True, **lstm_loader_options)
val_loader_lstm = DataLoader(lstm_val_dataset, shuffle=False, **lstm_loader_options)

In [12]:
# Transition-matrix loaders use the default collate function with batches of 32.
matrix_loader_options = dict(batch_size=32)

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader_matrix = DataLoader(matrix_train_dataset, shuffle=True, **matrix_loader_options)
val_loader_matrix = DataLoader(matrix_val_dataset, shuffle=False, **matrix_loader_options)

In [13]:
# Bag-of-transitions loaders use the default collate function with batches of 32.
bot_loader_options = dict(batch_size=32)

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader_bot = DataLoader(bot_train_dataset, shuffle=True, **bot_loader_options)
val_loader_bot = DataLoader(bot_val_dataset, shuffle=False, **bot_loader_options)

In [14]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [18]:
# Seed torch's global RNG before any weights are created. Without this, weight
# initialisation (and anything else that later draws from the global RNG, such
# as DataLoader shuffling) differs on every Restart & Run All, so the recorded
# losses cannot be reproduced.
SEED = 42
torch.manual_seed(SEED)

# The LSTM vocabulary covers the chord ids plus the token ids that precede the
# first chord token.
# NOTE(review): the attribute really is spelled FIST_CHORD_TOKEN_INDEX on the
# tokenizer (this cell ran with it), so the spelling is kept as-is.
lstm_vocab_size = len(chord_id_features) + tokenizer.FIST_CHORD_TOKEN_INDEX
model_lstm = LSTMHarmonyModel(vocab_size=lstm_vocab_size)
model_lstm.train()
model_lstm.to(device)

# The transition-matrix autoencoder is sized by the number of chord ids.
model_matrix = TransitionMatrixAutoencoder(D=len(chord_id_features))
model_matrix.train()
model_matrix.to(device)

# The bag-of-transitions autoencoder is sized by the BoT vocabulary.
model_bot = BagOfTransitionsAutoencoder(vocab_size=len(BoT_vocab))
model_bot.train()
# Kept as the cell's last expression so the module summary is still displayed.
model_bot.to(device)

BagOfTransitionsAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=3870, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=128, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=3870, bias=True)
  )
)

In [19]:
# One independent AdamW optimizer per baseline, all using the same learning rate.
learning_rate = 1e-4

optimizer_lstm = torch.optim.AdamW(
    model_lstm.parameters(),
    lr=learning_rate,
)
optimizer_matrix = torch.optim.AdamW(
    model_matrix.parameters(),
    lr=learning_rate,
)
optimizer_bot = torch.optim.AdamW(
    model_bot.parameters(),
    lr=learning_rate,
)

In [20]:
# Make sure the directory named in the training cells' save_path arguments
# exists; this is a no-op when it is already there.
os.makedirs(name='saved_models', exist_ok=True)

In [21]:
# Train the LSTM baseline for 5 epochs. The recorded run prints "saving..."
# after every epoch; save_path is presumably where the checkpoint is written.
train_lstm(
    model_lstm,
    train_loader_lstm,
    val_loader_lstm,
    optimizer_lstm,
    device,
    save_path='saved_models/lstm.pt',
    num_epochs=5,
)

Epoch 0| trn: 100%|██████████| 459/459 [00:03<00:00, 121.35batch/s, loss=3.33]
Epoch 0| val: 100%|██████████| 24/24 [00:00<00:00, 292.47batch/s, loss=2.8]


saving...


Epoch 1| trn: 100%|██████████| 459/459 [00:04<00:00, 110.49batch/s, loss=2.75]
Epoch 1| val: 100%|██████████| 24/24 [00:00<00:00, 345.58batch/s, loss=2.66]


saving...


Epoch 2| trn: 100%|██████████| 459/459 [00:03<00:00, 126.49batch/s, loss=2.62]
Epoch 2| val: 100%|██████████| 24/24 [00:00<00:00, 296.31batch/s, loss=2.55]


saving...


Epoch 3| trn: 100%|██████████| 459/459 [00:04<00:00, 112.97batch/s, loss=2.54]
Epoch 3| val: 100%|██████████| 24/24 [00:00<00:00, 317.53batch/s, loss=2.48]


saving...


Epoch 4| trn: 100%|██████████| 459/459 [00:03<00:00, 122.81batch/s, loss=2.47]
Epoch 4| val: 100%|██████████| 24/24 [00:00<00:00, 263.83batch/s, loss=2.43]


saving...


In [18]:
# Train the bag-of-transitions autoencoder for 5 epochs. The recorded run prints
# "saving..." after every epoch; save_path is presumably where the checkpoint is written.
train_bot_ae(
    model_bot,
    train_loader_bot,
    val_loader_bot,
    optimizer_bot,
    device,
    save_path='saved_models/bot.pt',
    num_epochs=5,
)

Epoch 0| trn: 100%|██████████| 459/459 [00:04<00:00, 93.00batch/s, loss=4.11] 
Epoch 0| val: 100%|██████████| 24/24 [00:00<00:00, 122.74batch/s, loss=3.3] 


saving...


Epoch 1| trn: 100%|██████████| 459/459 [00:04<00:00, 102.62batch/s, loss=3.33]
Epoch 1| val: 100%|██████████| 24/24 [00:00<00:00, 121.34batch/s, loss=3.24]


saving...


Epoch 2| trn: 100%|██████████| 459/459 [00:04<00:00, 92.71batch/s, loss=3.23] 
Epoch 2| val: 100%|██████████| 24/24 [00:00<00:00, 136.20batch/s, loss=3.01]


saving...


Epoch 3| trn: 100%|██████████| 459/459 [00:04<00:00, 93.31batch/s, loss=2.9]  
Epoch 3| val: 100%|██████████| 24/24 [00:00<00:00, 129.82batch/s, loss=2.7] 


saving...


Epoch 4| trn: 100%|██████████| 459/459 [00:04<00:00, 94.35batch/s, loss=2.65] 
Epoch 4| val: 100%|██████████| 24/24 [00:00<00:00, 117.66batch/s, loss=2.51]


saving...


In [19]:
# Train the transition-matrix autoencoder for 5 epochs. The recorded run prints
# "saving..." after every epoch; save_path is presumably where the checkpoint is written.
train_matrix_ae(
    model_matrix,
    train_loader_matrix,
    val_loader_matrix,
    optimizer_matrix,
    device,
    save_path='saved_models/matrix.pt',
    num_epochs=5,
)

Epoch 0| trn: 100%|██████████| 459/459 [00:24<00:00, 18.77batch/s, loss=14.1]
Epoch 0| val: 100%|██████████| 24/24 [00:00<00:00, 61.77batch/s, loss=11.5]


saving...


Epoch 1| trn: 100%|██████████| 459/459 [00:24<00:00, 18.78batch/s, loss=11]  
Epoch 1| val: 100%|██████████| 24/24 [00:00<00:00, 59.08batch/s, loss=10.5]


saving...


Epoch 2| trn: 100%|██████████| 459/459 [00:24<00:00, 18.69batch/s, loss=9.7] 
Epoch 2| val: 100%|██████████| 24/24 [00:00<00:00, 65.93batch/s, loss=9.4] 


saving...


Epoch 3| trn: 100%|██████████| 459/459 [00:24<00:00, 18.56batch/s, loss=8.35]
Epoch 3| val: 100%|██████████| 24/24 [00:00<00:00, 61.19batch/s, loss=8.23]


saving...


Epoch 4| trn: 100%|██████████| 459/459 [00:24<00:00, 18.53batch/s, loss=7.11]
Epoch 4| val: 100%|██████████| 24/24 [00:00<00:00, 62.95batch/s, loss=7.52]


saving...
