In [16]:
import sys
import os
# sys.path.append(os.path.abspath(".."))  # or "." if you're in the root
sys.path.append(os.path.abspath("../../"))
%load_ext autoreload
%autoreload 2



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
import torch

from tqdm.notebook import tqdm

from packages.tokenisers.midi_tokeniser import MidiTokenizer
from packages.dataloaders.midi_dataloader import get_dataloader
from packages.models.transformer import MusicTransformerVAE
from packages.losses.transformer_losses import vae_loss


In [59]:
tokenizer = MidiTokenizer()
# dataloader = get_dataloader('../../maestro-v3.0.0/2018', tokenizer, batch_size=4)
years = ['2004', '2006', '2008', '2009', '2011', '2013', '2014', '2015', '2017', '2018']
batch_size = 8
dataloader = get_dataloader('../../maestro-v3.0.0', years, tokenizer, batch_size=batch_size)
num_epochs = 10
beta = 0.0000001
# Set up device
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Getting dataloader
Processing year: 2004
Processing year: 2006
Processing year: 2008
Processing year: 2009
Processing year: 2011
Processing year: 2013
Processing year: 2014
Processing year: 2015
Processing year: 2017
Processing year: 2018
Found 1276 MIDI files
Creating dataloader
Device: mps


In [60]:
model = MusicTransformerVAE(tokenizer.vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model = model.to(device)

# Set up learning rate scheduler
total_steps = num_epochs * len(dataloader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=total_steps,
    eta_min=1e-5
)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_ce_loss = 0
    train_kl_loss = 0
    
    print(f"Epoch {epoch+1}/{num_epochs}")
    for i, batch in enumerate(dataloader):
        batch = batch.to(device)
        src = batch
        tgt = src.clone()
        
        optimizer.zero_grad()
        logits, mu, logvar = model(src, tgt[:, :-1])
        loss, ce_loss, kl_loss = vae_loss(
            logits, 
            tgt[:, 1:], 
            mu, 
            logvar, 
            beta=beta,
            current_step=epoch * len(dataloader) + i
        )
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        train_loss += loss.item()
        train_ce_loss += ce_loss
        train_kl_loss += kl_loss
        
        # Print batch results
        print(f"Batch {i+1}/{len(dataloader)} - Loss: {loss.item():.4f}, CE: {ce_loss:.4f}, KL: {kl_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")
    
    avg_loss = train_loss / len(dataloader)
    avg_ce_loss = train_ce_loss / len(dataloader)
    avg_kl_loss = train_kl_loss / len(dataloader)
    
    print(f"Epoch {epoch+1}/{num_epochs} Summary:")
    print(f"Train Loss: {avg_loss:.4f}")
    print(f"CE Loss: {avg_ce_loss:.4f}")
    print(f"KL Loss: {avg_kl_loss:.4f}")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
    print("-" * 50)

Epoch 1/10
Batch 1/159 - Loss: 6.6765, CE: 6.6765, KL: 24.9309, LR: 0.000100
Batch 2/159 - Loss: 6.5619, CE: 6.5619, KL: 35.8271, LR: 0.000100
Batch 3/159 - Loss: 6.4514, CE: 6.4514, KL: 42.4187, LR: 0.000100
Batch 4/159 - Loss: 6.3990, CE: 6.3990, KL: 45.9711, LR: 0.000100
Batch 5/159 - Loss: 6.3037, CE: 6.3037, KL: 49.4307, LR: 0.000100
Batch 6/159 - Loss: 6.2655, CE: 6.2655, KL: 52.9448, LR: 0.000100


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x112c304c0>
Traceback (most recent call last):
  File "/Users/jackfoxabbott/TuneGenerator/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/Users/jackfoxabbott/TuneGenerator/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1582, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/multiprocessing/connecti

KeyboardInterrupt: 