In [1]:
# NOTE(review): absolute, machine-specific path. It moves the kernel's working directory to the project
# root, presumably so the relative paths used below ('data/preprocessed/...', 'checkpoints/...') resolve;
# adjust for your own checkout.
%cd /home/jaeheonshim/music-vibes

/home/jaeheonshim/music-vibes


In [2]:
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split, DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vibenet import labels
from vibenet.dataset import FMAWaveformDataset
from vibenet.models.teacher import PANNsMLP

# Target device for the model and every batch. CUDA is assumed; there is no CPU fallback.
device = 'cuda'


Acousticness, Instrumentalness, and Liveness are **likelihoods** (e.g. a track with an acousticness of 1.0 represents high *confidence* that the track is acoustic). Thus, we optimize them with binary cross-entropy loss applied to the model's logits (`BCEWithLogitsLoss`).

Speechiness, Danceability, Energy, and Valence are **perceptual measures** (e.g. a danceability value of 0.0 is least danceable and 1.0 is most danceable). We use Huber loss to optimize Speechiness, MSE loss to optimize Danceability, and a concordance correlation coefficient (CCC) loss, $1 - \mathrm{CCC}$, to optimize Energy and Valence (https://arxiv.org/abs/2003.10724).

### Define loss and validation metric functions

In [3]:
def pearsonr(x, y, eps=1e-8):
    """Pearson correlation coefficient between two tensors.

    Both inputs are flattened before comparison. A ``(B, 1)`` prediction and a
    ``(B,)`` target are therefore paired element-wise, instead of being silently
    broadcast to a ``(B, B)`` product. Lower-precision inputs (e.g. ``bfloat16``
    tensors produced under autocast) are upcast to at least ``float32`` before
    the statistics are computed. ``float64`` inputs stay ``float64``.

    Args:
        x: First tensor (e.g. predictions).
        y: Second tensor (e.g. targets). Must have as many elements as ``x``.
        eps: Added to the denominator so constant inputs do not divide by zero.

    Returns:
        0-dim tensor holding the population (biased-variance) correlation.

    Raises:
        ValueError: If ``x`` and ``y`` have different numbers of elements.
    """
    if x.numel() != y.numel():
        raise ValueError(f"pearsonr: size mismatch ({x.numel()} vs {y.numel()} elements)")
    # Compute in >= float32 so autocast-produced bf16/fp16 tensors are not reduced in low precision.
    dtype = torch.promote_types(torch.promote_types(x.dtype, y.dtype), torch.float32)
    x = x.reshape(-1).to(dtype)
    y = y.reshape(-1).to(dtype)

    x = x - x.mean()
    y = y - y.mean()
    return (x * y).mean() / (x.std(unbiased=False) * y.std(unbiased=False) + eps)

def ccc(x, y, eps=1e-8):
    """Concordance correlation coefficient (CCC) between two tensors.

    CCC = 2*cov(x, y) / (var(x) + var(y) + (mean(x) - mean(y))**2), using
    population (biased) variances. Unlike Pearson correlation, it also penalizes
    differences in location and scale between ``x`` and ``y``.

    Both inputs are flattened, so a ``(B, 1)`` prediction and a ``(B,)`` target
    are paired element-wise rather than broadcast to ``(B, B)``. Lower-precision
    inputs (e.g. ``bfloat16`` from autocast) are upcast to at least ``float32``.
    This matters because ``1 - ccc(...)`` is used as a training loss inside a bf16
    autocast region. The function stays differentiable, so it can be
    back-propagated.

    Args:
        x: First tensor (e.g. predictions).
        y: Second tensor (e.g. targets). Must have as many elements as ``x``.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        0-dim tensor with the CCC value (1.0 means perfect agreement).

    Raises:
        ValueError: If ``x`` and ``y`` have different numbers of elements.
    """
    if x.numel() != y.numel():
        raise ValueError(f"ccc: size mismatch ({x.numel()} vs {y.numel()} elements)")
    # Keep mean/var/cov in >= float32; under bf16 autocast these reductions would otherwise run in bf16.
    dtype = torch.promote_types(torch.promote_types(x.dtype, y.dtype), torch.float32)
    x = x.reshape(-1).to(dtype)
    y = y.reshape(-1).to(dtype)

    mx, my = x.mean(), y.mean()
    vx, vy = x.var(unbiased=False), y.var(unbiased=False)
    cov = ((x - mx) * (y - my)).mean()
    return (2 * cov) / (vx + vy + (mx - my).pow(2) + eps)

### Load the pre-split train and validation datasets

In [4]:
# Train and validation data come from separate on-disk directories; no random split happens here.
train_ds = FMAWaveformDataset('data/preprocessed/waveforms_train')
# NOTE: named test_ds / test_dl, but this is the validation directory ('waveforms_val').
test_ds = FMAWaveformDataset('data/preprocessed/waveforms_val')

train_dl = DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,        # new sample order every epoch
    pin_memory=True,     # page-locked host buffers for faster host-to-GPU copies
    num_workers=8,
)
test_dl = DataLoader(
    test_ds,
    batch_size=64,
    shuffle=False,       # fixed order for evaluation
    pin_memory=True,
    num_workers=8,
)

### Initialize loss functions

In [None]:
# Loss criteria shared by the training losses and the validation metrics; all use mean reduction.
# SmoothL1 with beta=0.2: 0.5*err**2/0.2 for |err| < 0.2, and |err| - 0.1 otherwise.
huber = nn.SmoothL1Loss(reduction='mean', beta=0.2)
mse = nn.MSELoss(reduction='mean')
# Takes raw logits and applies the sigmoid internally.
bce = nn.BCEWithLogitsLoss(reduction='mean')

# Compute losses for training
def compute_losses(pred, label):
    """Per-task training losses.

    Args:
        pred: dict mapping task name -> prediction tensor (raw logits for the
            BCE tasks acousticness / liveness / instrumentalness).
        label: target tensor whose column ``i`` holds the target for
            ``labels[i]``. This is the same convention compute_metrics uses.

    Returns:
        dict mapping loss name -> scalar loss tensor. The CCC entries are
        ``1 - ccc`` so that lower is better for every entry.
    """
    # Resolve target columns from the shared `labels` order instead of hard-coding indices.
    # This way the training losses and the validation metrics cannot drift apart.
    col = {name: i for i, name in enumerate(labels)}

    def target(name):
        return label[:, col[name]]

    return {
        'acousticness_bce': bce(pred['acousticness'], target('acousticness')),
        'liveness_bce': bce(pred['liveness'], target('liveness')),
        'instrumentalness_bce': bce(pred['instrumentalness'], target('instrumentalness')),
        'speechiness_huber': huber(pred['speechiness'], target('speechiness')),
        'danceability_mse': mse(pred['danceability'], target('danceability')),
        'energy_ccc_loss': 1 - ccc(pred['energy'], target('energy')),
        'valence_ccc_loss': 1 - ccc(pred['valence'], target('valence')),
    }
    
# Tasks evaluated with logloss (BCE on logits). compute_metrics gives every other task regression metrics.
LIKELIHOODS = set(('acousticness', 'instrumentalness', 'liveness'))
# NOTE(review): CONTINUOUS is not referenced by the visible code; membership in LIKELIHOODS alone drives the metric choice.
CONTINUOUS = set(('danceability', 'energy', 'speechiness', 'valence'))

def compute_metrics(pred, label):
    """Validation metrics for every task in ``labels``.

    Args:
        pred: dict mapping task name -> prediction tensor. A trailing size-1 dim
            is squeezed away when the tensor has more than one dimension.
        label: target tensor whose column ``i`` holds the target for ``labels[i]``.

    Returns:
        dict mapping task name -> {metric name: python float}.
        - Tasks in LIKELIHOODS get ``logloss`` (BCE on logits).
        - All other tasks get ``mse``, ``rmse``, ``mae``, ``pearson``, ``ccc``
          and ``r2`` (1 - MSE / population variance of the target).
    """
    def _as_float(value):
        return value.item() if torch.is_tensor(value) else float(value)

    results = {}
    for col, task in enumerate(labels):
        target = label[:, col]
        raw = pred[task]
        estimate = raw if raw.ndim <= 1 else raw.squeeze(-1)

        if task in LIKELIHOODS:
            stats = {'logloss': bce(estimate, target)}
        else:
            sq_err = F.mse_loss(estimate, target)
            stats = {
                'mse': sq_err,
                'rmse': torch.sqrt(sq_err),
                'mae': (estimate - target).abs().mean(),
                'pearson': pearsonr(estimate, target),
                'ccc': ccc(estimate, target),
                # 1e-8 guards against a constant target column
                'r2': 1.0 - sq_err / (target.var(unbiased=False) + 1e-8),
            }

        results[task] = {key: _as_float(val) for key, val in stats.items()}

    return results

### Initialize model and optimizer

In [6]:
NUM_EPOCHS = 25

# NOTE(review): an earlier run of this cell raised FileNotFoundError for 'Cnn14_mAP=0.431.pth'.
# The constructor therefore appears to load pretrained weights from a path relative to the working directory; confirm.
model = PANNsMLP().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# T_max is measured in epochs, so the scheduler is meant to be stepped once per epoch.
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

FileNotFoundError: [Errno 2] No such file or directory: 'Cnn14_mAP=0.431.pth'

### Train/Validate Functions

In [None]:
# TensorBoard writer. No log_dir is given, so SummaryWriter uses its default run directory.
writer = SummaryWriter()

# Gradient scaler used by train(). 'cuda' is GradScaler's default device, spelled out here.
# NOTE(review): training autocasts to bfloat16, which has float32's exponent range, so loss scaling
# is generally not required; the scaler is kept unchanged.
scaler = torch.amp.GradScaler('cuda')

def train():
    """Run one training epoch over ``train_dl``.

    Uses the module-level model, optimizer, scaler, writer, device and
    compute_losses. The forward pass and the losses run under bfloat16 autocast.
    Gradients are clipped to a global norm of 1.0. ``global_step`` is advanced
    once per batch.

    Logged per step:
        - ``train/loss/<name>``: each individual loss.
        - ``train/loss``: running mean of the summed loss over this epoch.
    """
    global global_step

    model.train()

    # Running sum/count instead of np.mean over the whole history every step,
    # which was O(n^2) per epoch.
    loss_sum = 0.0
    n_batches = 0
    # autocast expects a device *type* ('cuda'/'cpu'), so derive it from `device`
    # (this also handles values like 'cuda:0').
    amp_device = torch.device(device).type

    with tqdm(train_dl, desc='Training') as pbar:
        for data, label in pbar:
            data, label = data.to(device).float(), label.to(device).float()
            optimizer.zero_grad()

            with autocast(device_type=amp_device, dtype=torch.bfloat16):
                pred = model(data)
                losses = compute_losses(pred, label)
                loss_total = sum(losses.values())

            for k, l in losses.items():
                writer.add_scalar(f"train/loss/{k}", l.item(), global_step)

            scaler.scale(loss_total).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            loss_sum += loss_total.item()
            n_batches += 1
            mean_loss = loss_sum / n_batches
            writer.add_scalar("train/loss", mean_loss, global_step)
            pbar.set_postfix({'loss': f"{mean_loss:.4f}"})

            global_step += 1

def validate():
    """Evaluate the model on ``test_dl`` (the validation split).

    Predictions and targets for the whole split are gathered on the CPU first.
    The losses (including the dataset-level CCC) and the metrics are then
    computed once over the full set. The summed loss is logged as ``eval/loss``
    at the current ``global_step``.

    Returns:
        (metrics, loss): the compute_metrics dict and the summed validation loss as a float.
    """
    model.eval()

    collected = defaultdict(list)
    target_chunks = []

    with tqdm(test_dl, desc='Validation') as pbar, torch.inference_mode():
        for data, label in pbar:
            data = data.to(device).float()
            label = label.to(device).float()

            outputs = model(data)
            for task, out in outputs.items():
                collected[task].append(out.detach().cpu())

            target_chunks.append(label.detach().cpu())

    preds = {task: torch.cat(chunks, dim=0) for task, chunks in collected.items()}
    targets = torch.cat(target_chunks, dim=0)

    loss_total = sum(compute_losses(preds, targets).values())
    writer.add_scalar("eval/loss", loss_total, global_step)

    return compute_metrics(preds, targets), loss_total.item()

In [None]:
CHECKPOINT_PATH = 'checkpoints/pretrained_PANN_best.pt'
# Create the checkpoint directory up front. Otherwise torch.save would raise only after the
# first full epoch of training had already been spent.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

global_step = 0
best_loss = float("inf")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}:")
    writer.add_scalar("epoch", epoch, global_step)

    train()
    metrics, loss_total = validate()

    # Print each task's metrics with names aligned, and mirror them to TensorBoard.
    print(f"Loss: {loss_total:.4f}")
    for task, stats in metrics.items():
        print(f"\n[{task}]")
        max_name_len = max((len(k) for k in stats), default=0)
        for name, val in stats.items():
            writer.add_scalar(f"eval/{task}/{name}", val, global_step)

            if isinstance(val, float):
                print(f"  {name:<{max_name_len}} : {val:.4f}")
            else:
                print(f"  {name:<{max_name_len}} : {val}")

    # Model selection uses the summed validation loss.
    if loss_total < best_loss:
        best_loss = loss_total
        torch.save({'state_dict': model.state_dict()}, CHECKPOINT_PATH)
        print('Saved new best model')

    # Cosine schedule advances once per epoch (T_max = NUM_EPOCHS).
    scheduler.step()

    # NOTE(review): this extra per-epoch increment, on top of train()'s per-batch increments,
    # looks unintentional. It is kept to preserve the existing TensorBoard step numbering.
    global_step += 1

writer.flush()


Epoch 1:


Training: 100%|██████████| 8/8 [00:06<00:00,  1.27it/s, loss=3.9243]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


Loss: 2.7351

[acousticness]
  logloss : 0.6202

[danceability]
  mse     : 0.0316
  rmse    : 0.1779
  mae     : 0.1364
  pearson : 0.5124
  ccc     : 0.3944
  r2      : 0.0805

[energy]
  mse     : 0.0825
  rmse    : 0.2873
  mae     : 0.2274
  pearson : 0.7282
  ccc     : 0.6703
  r2      : 0.2086

[instrumentalness]
  logloss : 0.7366

[liveness]
  logloss : 0.4188

[speechiness]
  logloss : 0.6892

[valence]
  mse     : 0.1592
  rmse    : 0.3990
  mae     : 0.3531
  pearson : 0.7234
  ccc     : 0.4249
  r2      : -0.5980
Saved new best model

Epoch 2:


Training: 100%|██████████| 8/8 [00:06<00:00,  1.24it/s, loss=2.7313]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


Loss: 2.1539

[acousticness]
  logloss : 0.5686

[danceability]
  mse     : 0.0129
  rmse    : 0.1134
  mae     : 0.0979
  pearson : 0.8064
  ccc     : 0.7377
  r2      : 0.6263

[energy]
  mse     : 0.0479
  rmse    : 0.2188
  mae     : 0.1625
  pearson : 0.7958
  ccc     : 0.7867
  r2      : 0.5408

[instrumentalness]
  logloss : 0.6037

[liveness]
  logloss : 0.4285

[speechiness]
  logloss : 0.7071

[valence]
  mse     : 0.0591
  rmse    : 0.2430
  mae     : 0.1910
  pearson : 0.6978
  ccc     : 0.6904
  r2      : 0.4071
Saved new best model

Epoch 3:


Training: 100%|██████████| 8/8 [00:05<00:00,  1.37it/s, loss=2.4648]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


Loss: 2.2083

[acousticness]
  logloss : 0.5124

[danceability]
  mse     : 0.0286
  rmse    : 0.1691
  mae     : 0.1351
  pearson : 0.8380
  ccc     : 0.5606
  r2      : 0.1687

[energy]
  mse     : 0.0651
  rmse    : 0.2552
  mae     : 0.1856
  pearson : 0.7284
  ccc     : 0.7027
  r2      : 0.3755

[instrumentalness]
  logloss : 0.5520

[liveness]
  logloss : 0.4246

[speechiness]
  logloss : 0.7729

[valence]
  mse     : 0.0682
  rmse    : 0.2612
  mae     : 0.1996
  pearson : 0.6552
  ccc     : 0.6547
  r2      : 0.3150

Epoch 4:


Training: 100%|██████████| 8/8 [00:05<00:00,  1.42it/s, loss=2.3891]
Validation:   0%|          | 0/2 [00:01<?, ?it/s]


KeyboardInterrupt: 