In [7]:
%cd /home/jaeheonshim/music-vibes

/home/jaeheonshim/music-vibes


In [8]:
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split, DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vibenet import labels
from vibenet.dataset import FMAWaveformDataset
from vibenet.models.teacher import PANNsMLP
from vibenet.train_utils import compute_losses, compute_metrics

# Device the model and every batch are moved to. The training loop's autocast
# context is hard-coded to CUDA as well, so a CUDA-capable GPU is required.
device = 'cuda'


Acousticness, Instrumentalness, and Liveness are **likelihoods** (e.g. an acousticness of 1.0 represents high *confidence* that the track is acoustic). Thus, we use binary cross-entropy loss to optimize them.

Speechiness, Danceability, Energy, and Valence are **perceptual measures** (e.g. a danceability value of 0.0 is least danceable and 1.0 is most danceable). We use Huber and MSE loss to optimize Speechiness and Danceability, and a concordance correlation coefficient (CCC) loss to optimize Energy and Valence (https://arxiv.org/abs/2003.10724).

### Load the pre-split train/validation datasets

In [9]:
# Number of training examples to use (taken from the start of the dataset).
# Set to None to train on the full training set.
TRAIN_SUBSET_SIZE = 1024

train_full_ds = FMAWaveformDataset('data/preprocessed/waveforms_train')
if TRAIN_SUBSET_SIZE is None:
    train_ds = train_full_ds
else:
    # Clamp to the dataset length so a dataset shorter than the requested
    # subset does not produce out-of-range indices at iteration time.
    train_ds = Subset(train_full_ds, range(min(TRAIN_SUBSET_SIZE, len(train_full_ds))))

# NOTE: loaded from the *validation* directory; the `test_*` names are kept
# because the validation function in a later cell refers to `test_dl`.
test_ds = FMAWaveformDataset('data/preprocessed/waveforms_val')

# pin_memory=True places batches in page-locked host memory for faster
# host-to-GPU copies.
train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)

### Initialize model and optimizer

In [10]:
# Total number of passes over the training data; also the cosine schedule length.
NUM_EPOCHS = 25

# Build the network and place it on the training device in one step.
model = PANNsMLP().to(device)

# Adam with a cosine-annealed learning rate going from 1e-3 down to 1e-5.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

### Train/Validate Functions

In [11]:
# Gradient scaler driving the backward/unscale/step/update sequence in train().
scaler = torch.amp.GradScaler(device='cuda')

# TensorBoard writer shared by the training, validation and epoch loops.
writer = SummaryWriter()

def train():
    """Run one training epoch over ``train_dl``.

    Uses the module-level ``model``, ``optimizer``, ``scaler`` and ``writer``.
    For every batch it logs each entry of ``compute_losses`` under
    ``train/loss/<name>`` and the running epoch mean of the summed loss under
    ``train/loss``, then increments the module-level ``global_step`` by one.

    Returns:
        float: mean of the per-batch summed losses over the epoch, or NaN if
        the loader yielded no batches.
    """
    global global_step

    model.train()

    train_losses = []

    with tqdm(train_dl, desc='Training') as pbar:
        for data, label in pbar:
            # train_dl is built with pin_memory=True, so the host-to-device
            # copies can be issued asynchronously.
            data = data.to(device, non_blocking=True).float()
            label = label.to(device, non_blocking=True).float()
            optimizer.zero_grad()

            # Only the forward pass and loss computation run under autocast.
            with autocast(device_type='cuda', dtype=torch.bfloat16):
                pred = model(data)
                losses = compute_losses(pred, label)
                loss_total = sum(losses.values())

            for name, value in losses.items():
                writer.add_scalar(f"train/loss/{name}", value, global_step)

            scaler.scale(loss_total).backward()
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss_total.item())

            # Running mean over the batches seen so far in this epoch.
            mean_loss = np.mean(train_losses)
            writer.add_scalar("train/loss", mean_loss, global_step)
            pbar.set_postfix({'loss': f"{mean_loss:.4f}"})

            global_step += 1

    return float(np.mean(train_losses)) if train_losses else float('nan')

def validate():
    """Evaluate ``model`` on ``test_dl`` and log the validation losses.

    Predictions and targets from all batches are gathered on the CPU and
    concatenated, so ``compute_losses`` and ``compute_metrics`` are evaluated
    once over the whole set rather than averaged per batch. Each loss term is
    logged under ``eval/loss/<name>`` (mirroring the ``train/loss/<name>`` tags
    written during training) and their sum under ``eval/loss``, all at the
    current module-level ``global_step`` (read only, never modified here).

    Returns:
        tuple: ``(metrics, loss_total)`` where ``metrics`` is the value
        returned by ``compute_metrics`` and ``loss_total`` is the summed loss
        as a Python float.
    """
    model.eval()

    preds = defaultdict(list)
    targets = []

    with torch.inference_mode(), tqdm(test_dl, desc='Validation') as pbar:
        for data, label in pbar:
            # test_dl is built with pin_memory=True, so the copy can be async.
            data = data.to(device, non_blocking=True).float()

            pred = model(data)
            for name, output in pred.items():
                preds[name].append(output.cpu())

            # Targets are only consumed on the CPU below, so they are never
            # sent to the GPU.
            targets.append(label.float())

    preds = {name: torch.cat(chunks, dim=0) for name, chunks in preds.items()}
    targets = torch.cat(targets, dim=0)

    losses = compute_losses(preds, targets)
    for name, value in losses.items():
        writer.add_scalar(f"eval/loss/{name}", value, global_step)

    loss_total = sum(losses.values())
    writer.add_scalar("eval/loss", loss_total, global_step)

    metrics = compute_metrics(preds, targets)

    return metrics, float(loss_total)

In [None]:
# TensorBoard step counter shared with train() and validate().
global_step = 0

best_loss = float("inf")

# Destination for the weights with the lowest validation loss seen so far.
CHECKPOINT_PATH = 'checkpoints/pretrained_PANN_best.pt'
# torch.save does not create missing parent directories; without this the
# first checkpoint write fails after a full epoch of work.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}:")

    writer.add_scalar("epoch", epoch, global_step)

    train()
    metrics, loss_total = validate()

    print(f"Loss: {loss_total:.4f}")
    for task, stats in metrics.items():
        print(f"\n[{task}]")
        # default=0 keeps an empty stats dict from raising ValueError.
        max_name_len = max((len(k) for k in stats), default=0)
        for name, val in stats.items():
            writer.add_scalar(f"eval/{task}/{name}", val, global_step)

            # NumPy floating scalars are formatted like built-in floats.
            if isinstance(val, (float, np.floating)):
                print(f"  {name:<{max_name_len}} : {val:.4f}")
            else:
                print(f"  {name:<{max_name_len}} : {val}")

    # Keep only the checkpoint with the lowest validation loss.
    if loss_total < best_loss:
        best_loss = loss_total
        torch.save({'state_dict': model.state_dict()}, CHECKPOINT_PATH)
        print('Saved new best model')

    # The cosine schedule is stepped once per epoch (T_max=NUM_EPOCHS).
    scheduler.step()

    # train() already advances global_step once per batch; this extra
    # increment leaves a one-step gap between epochs.
    global_step += 1

writer.flush()


Epoch 1:


Training: 100%|██████████| 8/8 [00:11<00:00,  1.43s/it, loss=4.1137]
Validation:  52%|█████▏    | 234/451 [01:24<01:08,  3.18it/s]