In [15]:
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MultilabelAUROC

from sound_stamp.datasets import MagnaSet
from sound_stamp.tagger import MusicTagger

In [16]:
# Load MagnaTagATune and split it into train / validation / test subsets.
# `random_state=42` presumably seeds the split so it is reproducible across runs.
BATCH_SIZE = 128

dataset = MagnaSet()

train, val, test = dataset.random_split(random_state=42)
# Only the training loader shuffles. Shuffling validation data does not change an
# (assumed order-independent) averaged loss. Also, per PyTorch's RandomSampler, each
# pass over a shuffled loader without an explicit `generator` draws a seed from
# torch's global RNG, which would perturb the training-order random stream.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

Downloading MagnaTagATune ... 
Dataset already downloaded. Remove 'f:\Git\sound-stamp\data\magna' to download again.


In [17]:
# Build a tagger over the dataset's tag vocabulary and restore the weights
# stored at the checkpoint path.
model_path = "models/fcn_tagger.pt"
tagger = MusicTagger(dataset.class_names)
tagger.load(model_path)

Model parameters loaded from 'models/fcn_tagger.pt'.


In [6]:
# Training: run a fixed number of epochs, reporting train and validation loss
# after each one, then persist the final weights.
#
# NOTE: on a fresh Restart & Run All, this cell continues training from the
# weights loaded by the previous cell and then overwrites that same checkpoint file.
# NOTE: the recorded output of this cell reports a save to 'fcn_tagger_1.model',
# which differs from the path passed to `save` below, so that output is stale.
num_epochs = 5

for epoch in range(num_epochs):
    # Tuple elements evaluate left to right: train first, then validate.
    train_loss, val_loss = tagger.train(train_loader), tagger.evaluate(val_loader)
    # The `name=` f-string syntax prints the variable names themselves,
    # so `train_loss` / `val_loss` must keep these exact names.
    print(f"{epoch+1}/{num_epochs}:", f"{train_loss=:.4f}, {val_loss=:.4f}")

tagger.save("models/fcn_tagger.pt")

1/5: train_loss=0.6253, val_loss=0.6109
2/5: train_loss=0.4866, val_loss=0.4123
3/5: train_loss=0.3540, val_loss=0.2781
4/5: train_loss=0.2698, val_loss=0.2359
5/5: train_loss=0.2201, val_loss=0.2104
Model parameters saved to 'fcn_tagger_1.model'.


In [18]:
# Predictions: run the tagger over the whole test split and collect per-batch
# outputs and targets into single CPU tensors for metric computation.
pred_batches = []
target_batches = []

# Scoring needs no gradients. `no_grad` keeps autograd from retaining a graph per
# batch in case `inference` does not already disable it (not visible here).
# NOTE(review): confirm `inference` also switches the model to eval mode.
with torch.no_grad():
    for features, targets in test_loader:
        # Move each batch to CPU right away so device memory does not grow with
        # the size of the test set.
        pred_batches.append(tagger.inference(features).cpu())
        target_batches.append(targets.cpu())

predicted_tags = torch.vstack(pred_batches)
# Targets are cast to an integer dtype; torchmetrics' multilabel metrics
# expect non-floating-point targets.
true_tags = torch.vstack(target_batches).to(torch.int8)

In [20]:
# Area Under the Receiver Operating Characteristic Curve over all tags.
# `average="micro"` pools every (clip, tag) pair into a single ROC curve;
# `thresholds=None` computes the exact (non-binned) curve.
num_tags = len(dataset.class_names)
ml_auroc = MultilabelAUROC(num_labels=num_tags, average="micro", thresholds=None)
ml_auroc.update(predicted_tags, true_tags)
ml_auroc.compute().item()

0.8982025980949402