In [1]:
from data_utils import LSTMHarmonyDataset, TransitionMatrixDataset, BagOfTransitionsDataset, LSTMPaddingCollator
from torch.utils.data import DataLoader
import torch
import GridMLM_tokenizers
from GridMLM_tokenizers import CSGridMLMTokenizer
from data_utils import CSGridMLMDataset
import pickle
from models_baseline import LSTMHarmonyModel, TransitionMatrixAutoencoder, BagOfTransitionsAutoencoder
from models_baseline import train_lstm, train_matrix_ae, train_bot_ae
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the pickled bag-of-transitions vocabulary, relative to the
# notebook's working directory.
BoT_vocab_path = 'data/BoT_vocab.pickle'

# NOTE(review): pickle.load can execute arbitrary code while unpickling, so this
# file must come from a trusted source.
with open(BoT_vocab_path, 'rb') as f:
    BoT_vocab = pickle.load(f)

In [3]:
# NOTE(review): train_dir and val_dir below are the same directory (CA_test), so
# any validation metric computed from val_dataset is measured on the training
# data. The CA_train path is kept commented out; restore it for a real run.
# train_dir = '/media/maindisk/data/hooktheory_midi_hr/CA_train'
train_dir = '/media/maindisk/data/hooktheory_midi_hr/CA_test'
val_dir = '/media/maindisk/data/hooktheory_midi_hr/CA_test'

# A single tokenizer (fixed_length=256) is shared by both datasets.
tokenizer = CSGridMLMTokenizer(fixed_length=256)

# Both splits are built with identical options; only the directory differs.
train_dataset = CSGridMLMDataset(train_dir, tokenizer, frontloading=True, name_suffix='Q4_L80_bar_PC')
val_dataset = CSGridMLMDataset(val_dir, tokenizer, frontloading=True, name_suffix='Q4_L80_bar_PC')

# Re-key the module-level CHORD_FEATURES table by vocabulary id: every key k is
# looked up in tokenizer.vocab and its value is carried over unchanged.
chord_features = GridMLM_tokenizers.CHORD_FEATURES
chord_id_features = {tokenizer.vocab[k]: v for k, v in chord_features.items()}

Loading data file.
Loading data file.


In [4]:
# Wrap the tokenized train/val datasets for the LSTM baseline, passing the
# id-keyed chord feature table to each wrapper.
lstm_train_dataset = LSTMHarmonyDataset(train_dataset, chord_id_features)
lstm_val_dataset = LSTMHarmonyDataset(val_dataset, chord_id_features)

In [5]:
# Wrap the same train/val datasets for the transition-matrix baseline, again
# passing the id-keyed chord feature table.
matrix_train_dataset = TransitionMatrixDataset(train_dataset, chord_id_features)
matrix_val_dataset = TransitionMatrixDataset(val_dataset, chord_id_features)

In [6]:
# Wrap the same train/val datasets for the bag-of-transitions baseline; this
# wrapper additionally receives the BoT vocabulary loaded from the pickle file.
bot_train_dataset = BagOfTransitionsDataset(train_dataset, chord_id_features, BoT_vocab)
bot_val_dataset = BagOfTransitionsDataset(val_dataset, chord_id_features, BoT_vocab)

In [7]:
# Sanity check: fetch and print the first item of the LSTM training dataset.
# NOTE(review): the name d0 is reassigned by the later inspection cells for the
# other datasets, so its value depends on which cell ran last.
d0 = lstm_train_dataset[0]
print(d0)

{'chord_sequence': tensor([269, 152,  66, 269, 152,  66, 269, 152,  66, 269, 152,  66])}


In [8]:
# Sanity check: fetch and print the first item of the transition-matrix
# training dataset. This overwrites any earlier value of d0.
d0 = matrix_train_dataset[0]
print(d0)

{'transition_matrix': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}


In [9]:
# Sanity check: fetch and print the first item of the bag-of-transitions
# training dataset. This overwrites any earlier value of d0.
d0 = bot_train_dataset[0]
print(d0)

{'bag_of_transitions': tensor([0., 0., 0.,  ..., 0., 0., 0.])}


In [10]:
# Collate function handed to the LSTM DataLoaders, constructed with pad_id=0.
# Presumably pads the chord sequences of a batch to a common length with that
# id — confirm in data_utils.
collator = LSTMPaddingCollator(pad_id=0)

In [11]:
# Training batches for the LSTM baseline: 32 items per batch, reshuffled on
# every pass, assembled by the shared collator.
train_loader_lstm = DataLoader(
    dataset=lstm_train_dataset,
    collate_fn=collator,
    shuffle=True,
    batch_size=32,
)

# Validation batches: same batch size and collator, but in fixed order.
val_loader_lstm = DataLoader(
    dataset=lstm_val_dataset,
    collate_fn=collator,
    shuffle=False,
    batch_size=32,
)

In [12]:
# Transition-matrix baseline loaders. No custom collate_fn is given, so the
# DataLoader's default batching is used.
train_loader_matrix = DataLoader(
    dataset=matrix_train_dataset,
    shuffle=True,
    batch_size=32,
)
val_loader_matrix = DataLoader(
    dataset=matrix_val_dataset,
    shuffle=False,
    batch_size=32,
)

In [13]:
# Bag-of-transitions baseline loaders. No custom collate_fn is given, so the
# DataLoader's default batching is used.
train_loader_bot = DataLoader(
    dataset=bot_train_dataset,
    shuffle=True,
    batch_size=32,
)
val_loader_bot = DataLoader(
    dataset=bot_val_dataset,
    shuffle=False,
    batch_size=32,
)

In [14]:
# Run on the first CUDA GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [15]:
# Seed torch's global RNG before any model is constructed so that the randomly
# initialised weights are identical on every Restart & Run All. Without this,
# each run starts from different weights and the logged losses cannot be
# reproduced.
torch.manual_seed(0)

# LSTM baseline: vocab_size is the number of entries in chord_id_features.
model_lstm = LSTMHarmonyModel(vocab_size=len(chord_id_features))
model_lstm.train()
model_lstm.to(device)

# Transition-matrix autoencoder: D is likewise the number of chord ids.
model_matrix = TransitionMatrixAutoencoder(D=len(chord_id_features))
model_matrix.train()
model_matrix.to(device)

# Bag-of-transitions autoencoder: sized by the BoT vocabulary.
model_bot = BagOfTransitionsAutoencoder(vocab_size=len(BoT_vocab))
model_bot.train()
# Last expression of the cell, so the notebook displays this model's summary.
model_bot.to(device)

BagOfTransitionsAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=3870, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=128, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=3870, bias=True)
  )
)

In [16]:
# One AdamW optimizer per baseline model, all with the same learning rate
# (1e-4); every other AdamW setting is left at its default.
optimizer_lstm = torch.optim.AdamW(model_lstm.parameters(), lr=1e-4)
optimizer_matrix = torch.optim.AdamW(model_matrix.parameters(), lr=1e-4)
optimizer_bot = torch.optim.AdamW(model_bot.parameters(), lr=1e-4)

In [17]:
# Make sure the output directory used by the save_path arguments of the
# training calls exists; exist_ok=True makes re-running this cell a no-op.
os.makedirs('saved_models', exist_ok=True)

In [18]:
# Train the bag-of-transitions autoencoder for 5 epochs.
# Presumably the checkpoint is written to save_path whenever the trainer prints
# "saving..." — confirm in models_baseline.
train_bot_ae(
    model_bot,
    train_loader_bot,
    val_loader_bot,
    optimizer_bot,
    device,
    save_path='saved_models/bot.pt',
    num_epochs=5,
)

Epoch 0| trn: 100%|██████████| 24/24 [00:00<00:00, 63.92batch/s, loss=6.45]
Epoch 0| val: 100%|██████████| 24/24 [00:00<00:00, 132.01batch/s, loss=6.42]


saving...


Epoch 1| trn: 100%|██████████| 24/24 [00:00<00:00, 88.32batch/s, loss=6.38]
Epoch 1| val: 100%|██████████| 24/24 [00:00<00:00, 143.22batch/s, loss=6.34]


saving...


Epoch 2| trn: 100%|██████████| 24/24 [00:00<00:00, 100.16batch/s, loss=6.25]
Epoch 2| val: 100%|██████████| 24/24 [00:00<00:00, 126.74batch/s, loss=6.12]


saving...


Epoch 3| trn: 100%|██████████| 24/24 [00:00<00:00, 97.54batch/s, loss=5.85]
Epoch 3| val: 100%|██████████| 24/24 [00:00<00:00, 100.89batch/s, loss=5.43]


saving...


Epoch 4| trn: 100%|██████████| 24/24 [00:00<00:00, 106.14batch/s, loss=4.77]
Epoch 4| val: 100%|██████████| 24/24 [00:00<00:00, 149.40batch/s, loss=3.95]


saving...


In [19]:
# Train the transition-matrix autoencoder for 5 epochs.
# Presumably the checkpoint is written to save_path whenever the trainer prints
# "saving..." — confirm in models_baseline.
train_matrix_ae(
    model_matrix,
    train_loader_matrix,
    val_loader_matrix,
    optimizer_matrix,
    device,
    save_path='saved_models/matrix.pt',
    num_epochs=5,
)

Epoch 0| trn: 100%|██████████| 24/24 [00:01<00:00, 17.76batch/s, loss=26.2]
Epoch 0| val: 100%|██████████| 24/24 [00:00<00:00, 62.11batch/s, loss=26]  


saving...


Epoch 1| trn: 100%|██████████| 24/24 [00:01<00:00, 19.37batch/s, loss=25.3]
Epoch 1| val: 100%|██████████| 24/24 [00:00<00:00, 62.54batch/s, loss=23.8]


saving...


Epoch 2| trn: 100%|██████████| 24/24 [00:01<00:00, 17.24batch/s, loss=19.1]
Epoch 2| val: 100%|██████████| 24/24 [00:00<00:00, 59.56batch/s, loss=12.6]


saving...


Epoch 3| trn: 100%|██████████| 24/24 [00:01<00:00, 18.97batch/s, loss=11.8]
Epoch 3| val: 100%|██████████| 24/24 [00:00<00:00, 67.76batch/s, loss=11]  


saving...


Epoch 4| trn: 100%|██████████| 24/24 [00:01<00:00, 18.51batch/s, loss=11]  
Epoch 4| val: 100%|██████████| 24/24 [00:00<00:00, 62.28batch/s, loss=10.6]


saving...


In [20]:
# Train the LSTM baseline for 5 epochs.
# Presumably the checkpoint is written to save_path whenever the trainer prints
# "saving..." — confirm in models_baseline.
train_lstm(
    model_lstm,
    train_loader_lstm,
    val_loader_lstm,
    optimizer_lstm,
    device,
    save_path='saved_models/lstm.pt',
    num_epochs=5,
)

  sequences = [torch.tensor(seq, dtype=torch.long) for seq in sequences]
Epoch 0| trn: 100%|██████████| 24/24 [00:00<00:00, 72.57batch/s, loss=5.79]
Epoch 0| val: 100%|██████████| 24/24 [00:00<00:00, 276.38batch/s, loss=5.69]


saving...


Epoch 1| trn: 100%|██████████| 24/24 [00:00<00:00, 94.48batch/s, loss=5.32]
Epoch 1| val: 100%|██████████| 24/24 [00:00<00:00, 215.86batch/s, loss=4.57]


saving...


Epoch 2| trn: 100%|██████████| 24/24 [00:00<00:00, 101.75batch/s, loss=3.93]
Epoch 2| val: 100%|██████████| 24/24 [00:00<00:00, 213.91batch/s, loss=3.51]


saving...


Epoch 3| trn: 100%|██████████| 24/24 [00:00<00:00, 102.72batch/s, loss=3.37]
Epoch 3| val: 100%|██████████| 24/24 [00:00<00:00, 331.44batch/s, loss=3.27]


saving...


Epoch 4| trn: 100%|██████████| 24/24 [00:00<00:00, 97.53batch/s, loss=3.22]
Epoch 4| val: 100%|██████████| 24/24 [00:00<00:00, 215.88batch/s, loss=3.17]


saving...
