In [17]:
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MultilabelAUROC, MulticlassROC

from sound_stamp.datasets import MagnaSet
from sound_stamp.tagger import MusicTagger

In [2]:
# Settings shared by the split and all three loaders.
SEED = 42
BATCH_SIZE = 128

dataset = MagnaSet()

# Fixed random_state so the train/val/test partition is the same on every run.
train, val, test = dataset.random_split(random_state=SEED)

# Only the training loader shuffles. The validation and test loaders keep the
# default sequential order so evaluation sees identical batches every epoch,
# which keeps the validation loss comparable across epochs.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

Downloading MagnaTagATune ... 
Dataset already downloaded. Remove 'f:\Git\sound-stamp\data\magna' to download again.


In [3]:
# Build a tagger over the dataset's class names and restore previously saved
# parameters. NOTE(review): if the training cell below is executed as well,
# training continues from these loaded weights rather than from scratch.
tag_names = dataset.class_names
tagger = MusicTagger(tag_names)

final_model_path = "models/fcn_tagger_final.pt"
tagger.load(final_model_path)

Model parameters loaded from 'models/fcn_tagger_final.pt'.


In [None]:
# Training with early stopping and learning-rate refinement.
NUM_EPOCHS = 10
PATIENCE = 3  # non-improving epochs tolerated; the next one triggers refinement/stop
REFINEMENT = 2  # how many times the learning rate may be reduced
REFINEMENT_LR_FACTOR = 0.1  # multiplier applied to the learning rate per refinement
LEARNING_RATE = 3e-4

best_val_loss = float("inf")
cur_patience, cur_refinement, cur_learning_rate = PATIENCE, REFINEMENT, LEARNING_RATE

print("Training ...")
for epoch in range(NUM_EPOCHS):
    train_loss = tagger.train(train_loader, cur_learning_rate)
    val_loss = tagger.evaluate(val_loader)
    print(f"{epoch+1}/{NUM_EPOCHS}: {train_loss=:.4f}, {val_loss=:.4f}")

    # New best validation loss: checkpoint the model and reset the patience counter.
    if val_loss < best_val_loss:
        tagger.save("models/fcn_tagger_ckpt.pt", verbose=False)
        best_val_loss = val_loss
        cur_patience = PATIENCE
        continue

    # No improvement, but patience remains: use one unit up and keep training.
    if cur_patience != 0:
        cur_patience -= 1
        continue

    # Patience exhausted and no refinements left: give up.
    if cur_refinement <= 0:
        print("Stopped early!")
        break

    # Patience exhausted: roll back to the best checkpoint and continue with a
    # smaller learning rate.
    tagger.load("models/fcn_tagger_ckpt.pt", verbose=False)
    cur_learning_rate *= REFINEMENT_LR_FACTOR
    cur_patience = PATIENCE
    cur_refinement -= 1
    print(f"Refinement with new learning rate {cur_learning_rate}.")

# Restore the best checkpoint and store it under the final model name.
tagger.load("models/fcn_tagger_ckpt.pt", verbose=False)
tagger.save("models/fcn_tagger_final.pt")

In [4]:
# Predictions: run the tagger over the test set and collect outputs and targets.
predicted_batches = []
true_batches = []

for features, targets in test_loader:
    # Move every batch to the CPU as soon as it is produced, so the full set of
    # test predictions never has to sit in accelerator memory at once
    # (presumably `tagger.inference` may return tensors on another device).
    predicted_batches.append(tagger.inference(features).cpu())
    true_batches.append(targets.cpu())

# Stack the per-batch tensors row-wise into one tensor each; targets are cast
# to int8 before being passed to the metric.
predicted_tags = torch.vstack(predicted_batches)
true_tags = torch.vstack(true_batches).type(torch.int8)

In [20]:
# Micro-averaged Area Under the ROC Curve over all tags of the test set.
num_tags = len(dataset.class_names)
auroc = MultilabelAUROC(num_labels=num_tags, average="micro", thresholds=None)

test_auroc = auroc(predicted_tags, true_tags)
test_auroc.item()

0.9183872938156128