In [1]:
%cd /home/jaeheonshim/music-vibes

/home/jaeheonshim/music-vibes


In [2]:
import torch
import numpy as np
from torch import nn, autocast
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split, DataLoader, Subset
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict

from vibenet import labels
from vibenet.dataset import FMAWaveformDataset
from vibenet.models.teacher import PANNsMLP
from vibenet.train_utils import compute_losses, compute_metrics

device = 'cuda'


Acousticness, Instrumentalness, and Liveness are **likelihoods** (e.g. a track having a 1.0 acousticness represents a high *confidence* that the track is acoustic). Thus, we use binary cross-entropy loss to optimize them.

Speechiness, Danceability, Energy, and Valence are **perceptual measures** (e.g. a danceability value of 0.0.0 is least danceable and 1.0 is most danceable). We use Huber and MSE loss to optimize Speechiness and Danceability, and concordance correlation coefficient (CCC) to optimize energy and valence (https://arxiv.org/abs/2003.10724)

### Load the dataset and perform train/val split

In [3]:
train_ds = FMAWaveformDataset('data/preprocessed/waveforms_train')
test_ds = FMAWaveformDataset('data/preprocessed/waveforms_val')

train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=8, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)

### Initialize model and optimizer

In [4]:
# Additional helper functions for training last layer of PANN Cnn14
def set_bn_eval(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

def trainable_params(module):
    return [p for p in module.parameters() if p.requires_grad]

In [5]:
model = PANNsMLP()
model = model.to(device)

NUM_EPOCHS = 25

pg_backbone6 = trainable_params(model.pann.conv_block6)
pg_heads = trainable_params(model.trunk) + trainable_params(model.heads)

optimizer = torch.optim.Adam([
    {"params": pg_backbone6, "lr": 1e-5, "weight_decay": 1e-4},
    {"params": pg_heads, "lr": 1e-3, "weight_decay": 5e-4},
])

scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-5)

### Train/Validate Functions

In [6]:
writer = SummaryWriter()

scaler = torch.amp.GradScaler()

def train():
    global global_step
    
    model.apply(set_bn_eval)
    model.train()
    
    train_losses = []

    with tqdm(train_dl, desc='Training') as pbar:
        for data, label in pbar:
            data, label = data.to(device).float(), label.to(device).float()
            optimizer.zero_grad()

            with autocast(device_type='cuda', dtype=torch.bfloat16):
                pred = model(data)
                
                loss_total = 0.0
                losses = compute_losses(pred, label)
                for k, l in losses.items():
                    writer.add_scalar(f"train/loss/{k}", l, global_step)
                    loss_total += l
                
            scaler.scale(loss_total).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss_total.item())
            
            mean_loss = np.mean(train_losses)
            writer.add_scalar(f"train/loss", mean_loss, global_step)
            pbar.set_postfix({'loss': f"{mean_loss:.4f}"})
            
            global_step += 1

def validate():
    global global_step
    
    model.eval()
    
    preds = defaultdict(list)
    targets = []

    with tqdm(test_dl, desc='Validation') as pbar:
        with torch.inference_mode():
            for data, label in pbar:
                data, label = data.to(device).float(), label.to(device).float()

                pred = model(data)
                for l, p in pred.items():
                    preds[l].append(p.detach().cpu())

                targets.append(label.detach().cpu())

    preds = {k: torch.cat(preds[k], dim=0) for k in preds.keys()}
    targets = torch.cat(targets, dim=0)

    losses = compute_losses(preds, targets)
    loss_total = 0.0
            
    for k,l in losses.items():
        loss_total += l
        
    writer.add_scalar(f"eval/loss", loss_total, global_step)
    
    metrics = compute_metrics(preds, targets)

    return metrics, loss_total.item()

In [None]:
global_step = 0

best_loss = float("inf")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}:")
    
    writer.add_scalar(f"epoch", epoch, global_step)

    train()
    metrics, loss_total = validate()

    print(f"Loss: {loss_total:.4f}")
    for task, stats in metrics.items():
        print(f"\n[{task}]")
        max_name_len = max(len(k) for k in stats.keys())
        for name, val in stats.items():
            writer.add_scalar(f"eval/{task}/{name}", val, global_step)
            
            if isinstance(val, float):
                print(f"  {name:<{max_name_len}} : {val:.4f}")
            else:
                print(f"  {name:<{max_name_len}} : {val}")

    if loss_total < best_loss:
        best_loss = loss_total
        torch.save({'state_dict': model.state_dict()},
                   'checkpoints/pretrained_PANN_best.pt')
        print('Saved new best model')
        
    scheduler.step()
    
    global_step += 1
    
writer.flush()


Epoch 1:
Training:  69%|███████████████████████████████████████████████████████████████████████████████████████▌                                       | 870/1261 [29:34<08:59,  1.38s/it, loss=1.9525]Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:37<00:00,  2.03s/it, loss=1.9502]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:43<00:00,  2.77it/s]


Training:   0%|                                                                                                                                                      | 0/1261 [00:00<?, ?it/s]

Training:   4%|████▊                                                                                                                           | 48/1261 [01:39<16:14,  1.24it/s, loss=2.7747]

Loss: 1.9666

[acousticness]
  logloss : 0.4856

[danceability]
  mse     : 0.0162
  rmse    : 0.1272
  mae     : 0.1005
  pearson : 0.7505
  ccc     : 0.7418
  r2      : 0.5503

[energy]
  mse     : 0.0266
  rmse    : 0.1631
  mae     : 0.1224
  pearson : 0.8482
  ccc     : 0.8432
  r2      : 0.6534

[instrumentalness]
  logloss : 0.5079

[liveness]
  logloss : 0.4757

[speechiness]
  mse     : 0.0104
  rmse    : 0.1019
  mae     : 0.0492
  pearson : 0.7346
  ccc     : 0.7322
  r2      : 0.4923

[valence]
  mse     : 0.0505
  rmse    : 0.2247
  mae     : 0.1704
  pearson : 0.6995
  ccc     : 0.6956
  r2      : 0.3194

Epoch 20:


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:41<00:00,  2.03s/it, loss=1.9413]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:43<00:00,  2.76it/s]


Loss: 1.9701

[acousticness]
  logloss : 0.4817

[danceability]
  mse     : 0.0166
  rmse    : 0.1288
  mae     : 0.1017
  pearson : 0.7462
  ccc     : 0.7400
  r2      : 0.5391

[energy]
  mse     : 0.0290
  rmse    : 0.1704
  mae     : 0.1298
  pearson : 0.8493
  ccc     : 0.8375
  r2      : 0.6214

[instrumentalness]
  logloss : 0.5088

[liveness]
  logloss : 0.4758

[speechiness]
  mse     : 0.0104
  rmse    : 0.1019
  mae     : 0.0482
  pearson : 0.7335
  ccc     : 0.7306
  r2      : 0.4927

[valence]
  mse     : 0.0519
  rmse    : 0.2279
  mae     : 0.1719
  pearson : 0.7014
  ccc     : 0.6951
  r2      : 0.2997

Epoch 21:


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:41<00:00,  2.03s/it, loss=1.9337]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:43<00:00,  2.76it/s]


Loss: 1.9645

[acousticness]
  logloss : 0.4818

[danceability]
  mse     : 0.0165
  rmse    : 0.1284
  mae     : 0.1013
  pearson : 0.7503
  ccc     : 0.7457
  r2      : 0.5419

[energy]
  mse     : 0.0284
  rmse    : 0.1685
  mae     : 0.1278
  pearson : 0.8486
  ccc     : 0.8391
  r2      : 0.6299

[instrumentalness]
  logloss : 0.5069

[liveness]
  logloss : 0.4757

[speechiness]
  mse     : 0.0105
  rmse    : 0.1026
  mae     : 0.0493
  pearson : 0.7307
  ccc     : 0.7293
  r2      : 0.4849

[valence]
  mse     : 0.0514
  rmse    : 0.2267
  mae     : 0.1718
  pearson : 0.7036
  ccc     : 0.6972
  r2      : 0.3071
Saved new best model

Epoch 22:


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:40<00:00,  2.03s/it, loss=1.9259]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:43<00:00,  2.75it/s]


Loss: 1.9641

[acousticness]
  logloss : 0.4807

[danceability]
  mse     : 0.0162
  rmse    : 0.1273
  mae     : 0.1004
  pearson : 0.7523
  ccc     : 0.7458
  r2      : 0.5501

[energy]
  mse     : 0.0273
  rmse    : 0.1652
  mae     : 0.1245
  pearson : 0.8480
  ccc     : 0.8419
  r2      : 0.6443

[instrumentalness]
  logloss : 0.5078

[liveness]
  logloss : 0.4755

[speechiness]
  mse     : 0.0105
  rmse    : 0.1026
  mae     : 0.0491
  pearson : 0.7319
  ccc     : 0.7306
  r2      : 0.4848

[valence]
  mse     : 0.0521
  rmse    : 0.2283
  mae     : 0.1711
  pearson : 0.7015
  ccc     : 0.6941
  r2      : 0.2972
Saved new best model

Epoch 23:


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:39<00:00,  2.03s/it, loss=1.9169]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:43<00:00,  2.76it/s]


Loss: 1.9599

[acousticness]
  logloss : 0.4784

[danceability]
  mse     : 0.0161
  rmse    : 0.1271
  mae     : 0.1003
  pearson : 0.7520
  ccc     : 0.7445
  r2      : 0.5515

[energy]
  mse     : 0.0276
  rmse    : 0.1661
  mae     : 0.1250
  pearson : 0.8491
  ccc     : 0.8408
  r2      : 0.6403

[instrumentalness]
  logloss : 0.5061

[liveness]
  logloss : 0.4756

[speechiness]
  mse     : 0.0108
  rmse    : 0.1038
  mae     : 0.0493
  pearson : 0.7326
  ccc     : 0.7321
  r2      : 0.4731

[valence]
  mse     : 0.0514
  rmse    : 0.2267
  mae     : 0.1703
  pearson : 0.7013
  ccc     : 0.6958
  r2      : 0.3071
Saved new best model

Epoch 24:


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:41<00:00,  2.03s/it, loss=1.9113]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:44<00:00,  2.75it/s]


Loss: 1.9628

[acousticness]
  logloss : 0.4817

[danceability]
  mse     : 0.0162
  rmse    : 0.1273
  mae     : 0.1006
  pearson : 0.7508
  ccc     : 0.7431
  r2      : 0.5500

[energy]
  mse     : 0.0276
  rmse    : 0.1660
  mae     : 0.1251
  pearson : 0.8491
  ccc     : 0.8420
  r2      : 0.6408

[instrumentalness]
  logloss : 0.5066

[liveness]
  logloss : 0.4754

[speechiness]
  mse     : 0.0106
  rmse    : 0.1028
  mae     : 0.0493
  pearson : 0.7357
  ccc     : 0.7352
  r2      : 0.4832

[valence]
  mse     : 0.0518
  rmse    : 0.2275
  mae     : 0.1703
  pearson : 0.7009
  ccc     : 0.6952
  r2      : 0.3020

Epoch 25:


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1261/1261 [42:40<00:00,  2.03s/it, loss=1.9021]
Validation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [02:44<00:00,  2.74it/s]


Loss: 1.9593

[acousticness]
  logloss : 0.4792

[danceability]
  mse     : 0.0164
  rmse    : 0.1279
  mae     : 0.1009
  pearson : 0.7502
  ccc     : 0.7430
  r2      : 0.5457

[energy]
  mse     : 0.0275
  rmse    : 0.1657
  mae     : 0.1249
  pearson : 0.8506
  ccc     : 0.8430
  r2      : 0.6419

[instrumentalness]
  logloss : 0.5066

[liveness]
  logloss : 0.4755

[speechiness]
  mse     : 0.0108
  rmse    : 0.1042
  mae     : 0.0495
  pearson : 0.7332
  ccc     : 0.7330
  r2      : 0.4695

[valence]
  mse     : 0.0516
  rmse    : 0.2271
  mae     : 0.1697
  pearson : 0.7028
  ccc     : 0.6959
  r2      : 0.3044
Saved new best model


After 25 epochs, model achieved validation loss of 2.1312