In [1]:
# Switch to the project directory: later cells use relative paths
# ('data/preprocessed/...', 'checkpoints/...') that resolve against it.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
%cd /home/jaeheonshim/music-vibes

/home/jaeheonshim/music-vibes


In [2]:
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import random_split, DataLoader, Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vibenet import labels
from vibenet.dataset import FMAWaveformDataset
from vibenet.models.teacher import PANNsMLP
from vibenet.train_utils import compute_losses, compute_metrics

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when tensors are moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


Acousticness, Instrumentalness, and Liveness are **likelihoods** (e.g. a track having a 1.0 acousticness represents a high *confidence* that the track is acoustic). Thus, we use binary cross-entropy loss to optimize them.

Speechiness, Danceability, Energy, and Valence are **perceptual measures** (e.g. a danceability value of 0.0 is least danceable and 1.0 is most danceable). We use Huber and MSE loss to optimize Speechiness and Danceability, and the concordance correlation coefficient (CCC) to optimize Energy and Valence (https://arxiv.org/abs/2003.10724).

### Load the pre-split train and validation datasets

In [None]:
# Datasets are read from the separate train / val waveform folders.
# The validation folder backs `test_ds` / `test_dl`.
train_ds = FMAWaveformDataset('data/preprocessed/waveforms_train')
test_ds = FMAWaveformDataset('data/preprocessed/waveforms_val')

# Training batches are reshuffled each epoch.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
# Evaluation batches keep the dataset order.
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=64,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

### Initialize model and optimizer

In [4]:
# Number of training epochs; also used as the cosine schedule length.
NUM_EPOCHS = 25

# Build the network and place it on the training device.
model = PANNsMLP().to(device)

# Adam with a cosine-annealed learning rate (1e-3 down to 1e-5 over
# NUM_EPOCHS scheduler steps).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(
    optimizer=optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-5,
)

### Train/Validate Functions

In [5]:
# Gradient scaler shared by every call to train().
# NOTE(review): train() autocasts to bfloat16; loss scaling is generally only
# needed for float16 — confirm the scaler is still wanted.
scaler = torch.amp.GradScaler("cuda")

# TensorBoard writer used by the training / validation cells below.
writer = SummaryWriter()

def train():
    """Run one training epoch over ``train_dl``.

    For every batch: runs the forward pass under bfloat16 autocast, sums the
    per-task losses returned by ``compute_losses``, back-propagates through
    the gradient scaler, clips the gradient norm to 1.0 and steps the
    optimizer. Each per-task loss is logged under ``train/loss/<task>`` and
    the running mean of the total loss under ``train/loss``.

    Side effects:
        Updates ``model`` parameters, ``optimizer`` and ``scaler`` state, and
        increments the module-level ``global_step`` once per batch.

    Returns:
        float: mean total loss over the batches of this epoch, or ``nan`` if
        ``train_dl`` yielded no batches.
    """
    global global_step

    model.train()

    # Follow `device` rather than hard-coding 'cuda', so the autocast region
    # always targets the device type the tensors actually live on.
    device_type = torch.device(device).type

    # Running sum / count give the epoch mean in O(1) per step instead of
    # re-averaging a growing list on every batch.
    loss_sum = 0.0
    num_batches = 0
    mean_loss = float('nan')

    with tqdm(train_dl, desc='Training') as pbar:
        for data, label in pbar:
            data, label = data.to(device).float(), label.to(device).float()
            optimizer.zero_grad()

            with autocast(device_type=device_type, dtype=torch.bfloat16):
                pred = model(data)

                # Total loss is the unweighted sum of the per-task losses.
                loss_total = 0.0
                losses = compute_losses(pred, label)
                for k, l in losses.items():
                    writer.add_scalar(f"train/loss/{k}", l, global_step)
                    loss_total += l

            scaler.scale(loss_total).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            loss_sum += loss_total.item()
            num_batches += 1
            mean_loss = loss_sum / num_batches

            writer.add_scalar("train/loss", mean_loss, global_step)
            pbar.set_postfix({'loss': f"{mean_loss:.4f}"})

            global_step += 1

    return mean_loss

def validate():
    """Evaluate ``model`` on ``test_dl`` and log the total validation loss.

    Predictions (a mapping of output name to tensor, as returned by the
    model) and labels are collected on the CPU for every batch and
    concatenated along dim 0. Losses and metrics are then computed once over
    the full set.

    Returns:
        tuple: ``(metrics, total_loss)`` where ``metrics`` is the result of
        ``compute_metrics`` and ``total_loss`` is the summed loss as a Python
        number.
    """
    global global_step

    model.eval()

    pred_chunks = defaultdict(list)
    target_chunks = []

    with tqdm(test_dl, desc='Validation') as pbar, torch.inference_mode():
        for data, label in pbar:
            data = data.to(device).float()
            label = label.to(device).float()

            # Move every output off the device right away so only one batch
            # of predictions is held there at a time.
            for head, output in model(data).items():
                pred_chunks[head].append(output.detach().cpu())
            target_chunks.append(label.detach().cpu())

    preds = {head: torch.cat(chunks, dim=0) for head, chunks in pred_chunks.items()}
    targets = torch.cat(target_chunks, dim=0)

    # Unweighted sum of the per-task losses over the whole validation set.
    loss_total = sum(compute_losses(preds, targets).values(), 0.0)
    writer.add_scalar("eval/loss", loss_total, global_step)

    metrics = compute_metrics(preds, targets)

    return metrics, loss_total.item()

In [6]:
global_step = 0

best_loss = float("inf")

# Create the checkpoint directory up front so the first torch.save cannot
# fail on a missing folder after a full epoch of work.
CHECKPOINT_PATH = Path('checkpoints/pretrained_PANN_best.pt')
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

try:
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}:")

        writer.add_scalar("epoch", epoch, global_step)

        train()
        metrics, loss_total = validate()

        print(f"Loss: {loss_total:.4f}")
        for task, stats in metrics.items():
            print(f"\n[{task}]")
            # default=0 keeps an empty stats mapping from raising ValueError.
            max_name_len = max((len(k) for k in stats), default=0)
            for name, val in stats.items():
                writer.add_scalar(f"eval/{task}/{name}", val, global_step)

                if isinstance(val, float):
                    print(f"  {name:<{max_name_len}} : {val:.4f}")
                else:
                    print(f"  {name:<{max_name_len}} : {val}")

        # Keep only the checkpoint with the lowest validation loss so far.
        if loss_total < best_loss:
            best_loss = loss_total
            torch.save({'state_dict': model.state_dict()}, CHECKPOINT_PATH)
            print('Saved new best model')

        scheduler.step()

        # train() already advances global_step once per batch; this adds one
        # more step per epoch.
        global_step += 1
finally:
    # Flush even when the loop is interrupted (e.g. KeyboardInterrupt) so
    # scalars logged so far reach disk.
    writer.flush()


Epoch 1:


Training: 100%|██████████| 8/8 [00:09<00:00,  1.14s/it, loss=3.9124]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.24it/s]


Loss: 2.3906

[acousticness]
  logloss : 0.6125

[danceability]
  mse     : 0.0172
  rmse    : 0.1310
  mae     : 0.1111
  pearson : 0.8084
  ccc     : 0.6809
  r2      : 0.5013

[energy]
  mse     : 0.0622
  rmse    : 0.2493
  mae     : 0.2077
  pearson : 0.7795
  ccc     : 0.7614
  r2      : 0.4039

[instrumentalness]
  logloss : 0.6290

[liveness]
  logloss : 0.4195

[speechiness]
  mse     : 0.0164
  rmse    : 0.1282
  mae     : 0.0981
  pearson : -0.3925
  ccc     : -0.1727
  r2      : -7.2803

[valence]
  mse     : 0.0929
  rmse    : 0.3048
  mae     : 0.2364
  pearson : 0.5733
  ccc     : 0.5656
  r2      : 0.0672
Saved new best model

Epoch 2:


Training: 100%|██████████| 8/8 [00:06<00:00,  1.33it/s, loss=2.7092]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s]


Loss: 2.5135

[acousticness]
  logloss : 0.5945

[danceability]
  mse     : 0.0202
  rmse    : 0.1420
  mae     : 0.1122
  pearson : 0.7979
  ccc     : 0.6800
  r2      : 0.4136

[energy]
  mse     : 0.0680
  rmse    : 0.2607
  mae     : 0.2015
  pearson : 0.7640
  ccc     : 0.7151
  r2      : 0.3481

[instrumentalness]
  logloss : 0.5887

[liveness]
  logloss : 0.4434

[speechiness]
  mse     : 0.0384
  rmse    : 0.1958
  mae     : 0.1802
  pearson : -0.0499
  ccc     : -0.0072
  r2      : -18.3087

[valence]
  mse     : 0.0901
  rmse    : 0.3002
  mae     : 0.2468
  pearson : 0.6010
  ccc     : 0.5091
  r2      : 0.0957

Epoch 3:


Training: 100%|██████████| 8/8 [00:04<00:00,  1.61it/s, loss=2.4631]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.53it/s]


Loss: 2.3369

[acousticness]
  logloss : 0.5330

[danceability]
  mse     : 0.1083
  rmse    : 0.3292
  mae     : 0.3067
  pearson : 0.7815
  ccc     : 0.3378
  r2      : -2.1494

[energy]
  mse     : 0.0561
  rmse    : 0.2368
  mae     : 0.1904
  pearson : 0.7881
  ccc     : 0.7480
  r2      : 0.4625

[instrumentalness]
  logloss : 0.5452

[liveness]
  logloss : 0.4435

[speechiness]
  mse     : 0.0030
  rmse    : 0.0549
  mae     : 0.0418
  pearson : -0.0016
  ccc     : -0.0015
  r2      : -0.5156

[valence]
  mse     : 0.0759
  rmse    : 0.2755
  mae     : 0.2226
  pearson : 0.5953
  ccc     : 0.5526
  r2      : 0.2383
Saved new best model

Epoch 4:


Training: 100%|██████████| 8/8 [00:05<00:00,  1.58it/s, loss=2.3567]
Validation: 100%|██████████| 2/2 [00:01<00:00,  1.53it/s]


Loss: 2.0840

[acousticness]
  logloss : 0.4159

[danceability]
  mse     : 0.0789
  rmse    : 0.2808
  mae     : 0.2571
  pearson : 0.7632
  ccc     : 0.3389
  r2      : -1.2926

[energy]
  mse     : 0.0456
  rmse    : 0.2136
  mae     : 0.1706
  pearson : 0.7896
  ccc     : 0.7847
  r2      : 0.5624

[instrumentalness]
  logloss : 0.5353

[liveness]
  logloss : 0.4378

[speechiness]
  mse     : 0.0043
  rmse    : 0.0659
  mae     : 0.0556
  pearson : 0.2776
  ccc     : 0.2267
  r2      : -1.1845

[valence]
  mse     : 0.0689
  rmse    : 0.2626
  mae     : 0.1958
  pearson : 0.6282
  ccc     : 0.6100
  r2      : 0.3080
Saved new best model

Epoch 5:


Training:   0%|          | 0/8 [00:01<?, ?it/s]
Exception in thread Thread-14 (_pin_memory_loop):
Traceback (most recent call last):
  File "/usr/lib64/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/jaeheonshim/music-vibes/venv/lib64/python3.11/site-packages/ipykernel/ipkernel.py", line 772, in run_closure
    _threading_Thread_run(self)
  File "/usr/lib64/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jaeheonshim/music-vibes/venv/lib64/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 61, in _pin_memory_loop
    do_one_step()
  File "/home/jaeheonshim/music-vibes/venv/lib64/python3.11/site-packages/torch/utils/data/_utils/pin_memory.py", line 37, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.lo

KeyboardInterrupt: 