In [1]:
# Switch the kernel's working directory to the project checkout; later cells use
# relative paths ('data/preprocessed/...', 'checkpoints/...') that are resolved from here.
# NOTE(review): this is an absolute, machine-specific path — adjust it when running elsewhere.
%cd /home/jaeheonshim/music-vibes

/home/jaeheonshim/music-vibes


In [2]:
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split, DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vibenet import labels
from vibenet.dataset import FMAWaveformDataset
from vibenet.models.student import EfficientNetRegressor
from vibenet.train_utils import compute_losses, compute_metrics

# Prefer the GPU, but fall back to the CPU so the notebook does not crash at the
# first `.to(device)` call on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# The train and test splits are read from two separate directories.
train_ds = FMAWaveformDataset('data/preprocessed/waveforms_distill_train')
test_ds = FMAWaveformDataset('data/preprocessed/waveforms_distill_test')

# Training batches are reshuffled every epoch; evaluation keeps the dataset order.
train_dl = DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
test_dl = DataLoader(
    test_ds,
    batch_size=64,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Report the number of training examples.
print(len(train_ds))

90437


In [4]:
# Number of passes over the training set; also the period of the cosine schedule.
NUM_EPOCHS = 25

# Build the student network and place it on the training device.
net = EfficientNetRegressor().to(device)

# Adam starting at lr=1e-3, annealed along a cosine curve down to 1e-5 over
# NUM_EPOCHS scheduler steps.
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

In [5]:
# Gradient scaler used by the mixed-precision training step in train().
scaler = torch.amp.GradScaler()

# TensorBoard event writer, created with its default log directory.
writer = SummaryWriter()

def train():
    """Run one training epoch over ``train_dl`` and return the mean batch loss.

    Operates on the notebook-level ``net``, ``optimizer``, ``scaler``,
    ``writer`` and ``device``. ``global_step`` is advanced once per optimizer
    step and is used as the x-axis of the per-batch ``train/loss`` scalar
    written to TensorBoard.

    Returns:
        float: mean over the epoch of the per-batch total loss (the sum of all
        entries returned by ``compute_losses``); ``nan`` if the loader yielded
        no batches.
    """
    global global_step

    net.train()

    # Running sum/count instead of re-averaging a growing list on every step.
    loss_sum = 0.0
    num_batches = 0

    # Keep autocast on the same device type the tensors are moved to.
    amp_device_type = torch.device(device).type

    with tqdm(train_dl, desc='Training') as pbar:
        for data, label in pbar:
            # non_blocking lets the host-to-device copy overlap with compute when
            # the batch lives in pinned memory; otherwise it is a plain copy.
            data = data.to(device, non_blocking=True).float()
            label = label.to(device, non_blocking=True).float()
            optimizer.zero_grad()

            with autocast(device_type=amp_device_type, dtype=torch.bfloat16):
                pred = net(data)
                losses = compute_losses(pred, label)
                # Total objective is the unweighted sum of every returned loss term.
                loss_total = sum(losses.values())

            scaler.scale(loss_total).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            batch_loss = loss_total.item()
            loss_sum += batch_loss
            num_batches += 1

            writer.add_scalar('train/loss', batch_loss, global_step)
            pbar.set_postfix({'loss': f"{loss_sum / num_batches:.4f}"})

            global_step += 1

    return loss_sum / num_batches if num_batches else float('nan')

def validate():
    """Evaluate ``net`` on ``test_dl`` and return ``(metrics, total_loss)``.

    Predictions from every batch are moved to the CPU and concatenated per
    output key, so losses and metrics are computed once over the whole set
    rather than averaged across batches.

    Returns:
        tuple: ``(metrics, loss)`` where ``metrics`` is the value returned by
        ``compute_metrics`` and ``loss`` is a Python float: the sum of all
        entries returned by ``compute_losses``.
    """
    net.eval()

    preds = defaultdict(list)
    targets = []

    with tqdm(test_dl, desc='Validation') as pbar:
        with torch.inference_mode():
            for data, label in pbar:
                data = data.to(device, non_blocking=True).float()

                pred = net(data)
                for head, out in pred.items():
                    # No detach needed: nothing is tracked under inference_mode.
                    preds[head].append(out.cpu())

                # Targets are only consumed on the CPU below, so skip the
                # device round trip. copy=True stores an independent tensor
                # instead of keeping the loader's own batch buffer alive.
                targets.append(label.to(torch.float32, copy=True))

    preds = {head: torch.cat(chunks, dim=0) for head, chunks in preds.items()}
    targets = torch.cat(targets, dim=0)

    losses = compute_losses(preds, targets)
    # Same unweighted sum of loss terms that train() optimises.
    loss_total = sum(losses.values())

    metrics = compute_metrics(preds, targets)

    return metrics, loss_total.item()

Teacher model metrics (for reference)

Loss: 2.1312

[acousticness]
  logloss : 0.5195

[danceability]
  mse     : 0.0179
  rmse    : 0.1338
  mae     : 0.1061
  pearson : 0.7192
  ccc     : 0.7056
  r2      : 0.5025

[energy]
  mse     : 0.0362
  rmse    : 0.1902
  mae     : 0.1468
  pearson : 0.8009
  ccc     : 0.7911
  r2      : 0.5287

[instrumentalness]
  logloss : 0.5282

[liveness]
  logloss : 0.4789

[speechiness]
  mse     : 0.0108
  rmse    : 0.1041
  mae     : 0.0552
  pearson : 0.7217
  ccc     : 0.7205
  r2      : 0.4706

[valence]
  mse     : 0.0631
  rmse    : 0.2512
  mae     : 0.1958
  pearson : 0.6587
  ccc     : 0.6437
  r2      : 0.1492

In [6]:
# Counts optimizer steps; advanced inside train() once per batch.
global_step = 0

best_loss = float("inf")
checkpoint_path = 'checkpoints/efficientnet_best.pt'

# torch.save does not create missing parent directories; make sure the target
# exists up front so a finished epoch is never lost to a failed save.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

try:
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}:")

        # Marks where each epoch starts on the global-step axis.
        writer.add_scalar("epoch", epoch, global_step)

        train()
        metrics, loss_total = validate()
        writer.add_scalar("val/loss", loss_total, global_step)

        print(f"Loss: {loss_total:.4f}")
        for task, stats in metrics.items():
            print(f"\n[{task}]")
            max_name_len = max(len(k) for k in stats.keys())
            for name, val in stats.items():
                if isinstance(val, float):
                    print(f"  {name:<{max_name_len}} : {val:.4f}")
                else:
                    print(f"  {name:<{max_name_len}} : {val}")

        # Keep only the weights with the lowest validation loss seen so far.
        if loss_total < best_loss:
            best_loss = loss_total
            torch.save({'state_dict': net.state_dict()}, checkpoint_path)
            print('Saved new best model')

        # One cosine-schedule step per epoch. global_step is deliberately not
        # touched here so it keeps matching the number of optimizer steps.
        scheduler.step()
finally:
    # Flush pending TensorBoard events even if the run is interrupted.
    writer.flush()


Epoch 1:


Training: 100%|██████████| 707/707 [11:36<00:00,  1.01it/s, loss=2.1483]
Validation:  60%|█████▉    | 845/1414 [02:32<01:42,  5.53it/s]
Exception in thread Thread-7 (_pin_memory_loop):
Traceback (most recent call last):
  File "/usr/lib64/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/jaeheonshim/music-vibes/venv/lib64/python3.11/site-packages/ipykernel/ipkernel.py", line 772, in run_closure
    _threading_Thread_run(self)
  File "/usr/lib64/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jaeheonshim/music-vibes/venv/lib64/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 61, in _pin_memory_loop
    do_one_step()
  File "/home/jaeheonshim/music-vibes/venv/lib64/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 37, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib6

KeyboardInterrupt: 