# Notebook 04: YAMNet Transfer Learning Pipeline

## Overview
This notebook implements a transfer learning approach using Google's YAMNet pretrained audio model. YAMNet embeddings capture rich acoustic features from AudioSet pretraining, which are then used to train a lightweight classifier for instrument family classification.

## Workflow
1. **Model and Data Setup** — Load YAMNet model from TensorFlow Hub and configure data splits
2. **Embedding Extraction** — Extract 1024-dimensional YAMNet embeddings for all audio clips and cache to disk
3. **Classifier Architecture** — Define a simple feedforward classifier on top of frozen YAMNet embeddings
4. **Training** — Train classifier with cross-entropy loss and AdamW optimizer
5. **Evaluation** — Assess model performance on test set with detailed metrics and confusion matrix
6. **Visualization** — Plot training curves and interactive confusion matrix with Bokeh

---

In [1]:
# --- Imports and configuration ---

import warnings
warnings.filterwarnings('ignore')

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'

import tensorflow_hub as hub
import tensorflow as tf
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from tqdm import tqdm
import soundfile as sf

from bokeh.plotting import figure, show, output_notebook
from bokeh.layouts import row
from bokeh.models import HoverTool, LinearColorMapper, ColorBar, BasicTicker, ColumnDataSource
from bokeh.transform import transform
from bokeh.palettes import Blues9

output_notebook()

tf.get_logger().setLevel('ERROR')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Configuration
PROJECT_ROOT = Path("/Users/dghifari/02-University/SEM-2-2025/elec5305-project-520140154")
manifests_dir = PROJECT_ROOT / "Manifests"
train_csv = manifests_dir / "train.csv"
val_csv   = manifests_dir / "val.csv"
test_csv  = manifests_dir / "test.csv"

FAMILY_COLNAME = "family_label"
SAMPLE_RATE = 16000
DURATION_SECONDS = 3
TARGET_NUM_SAMPLES = SAMPLE_RATE * DURATION_SECONDS
BATCH_SIZE = 4
EPOCHS = 25
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

print(f"Device: {DEVICE}")

Device: mps


In [2]:
class Normalize:
    def __call__(self, x: torch.Tensor):
        x = x / (x.abs().max() + 1e-9)
        rms = x.pow(2).mean().sqrt()
        if rms > 0:
            x = x / (rms + 1e-9) * 0.1
        return x

class AudioDatasetForPreprocessing(Dataset):
    """Dataset for loading raw audio files"""
    def __init__(self, csv_path, label_map):
        self.df = pd.read_csv(csv_path)
        self.label_map = label_map
        self.norm = Normalize()
        self.target_length = TARGET_NUM_SAMPLES

    def _fix_length(self, wav: torch.Tensor, target_len: int):
        T = wav.shape[-1]
        if T == target_len:
            return wav
        if T > target_len:
            start = (T - target_len) // 2
            return wav[..., start:start + target_len]
        pad_len = target_len - T
        return torch.nn.functional.pad(wav, (0, pad_len))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        wav_np, sr = sf.read(row['filepath'], dtype='float32')
        
        wav = torch.from_numpy(wav_np)
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        else:
            wav = wav.T
        
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        
        if sr != SAMPLE_RATE:
            import torchaudio.functional as F
            wav = F.resample(wav, sr, SAMPLE_RATE)
        
        wav = self._fix_length(wav, self.target_length)
        wav = self.norm(wav)
        wav = wav.squeeze(0)
        
        label = self.label_map[row[FAMILY_COLNAME]]
        return wav, label, row['filepath']

    def __len__(self):
        return len(self.df)

In [3]:
# Load metadata and create label mappings
df_train = pd.read_csv(train_csv)
families = sorted(df_train[FAMILY_COLNAME].unique())
family_to_idx = {f:i for i,f in enumerate(families)}
idx_to_family = {i:f for f,i in family_to_idx.items()}
num_classes = len(family_to_idx)

print(f"Classes: {families}")
print(f"Train: {len(pd.read_csv(train_csv))} | Val: {len(pd.read_csv(val_csv))} | Test: {len(pd.read_csv(test_csv))}")

Classes: ['keyboards', 'percussion', 'strings', 'voice', 'winds']
Train: 685 | Val: 79 | Test: 110


In [4]:
# Preprocess: Extract YAMNet embeddings and save to disk

def extract_yamnet_embeddings_batched(yamnet_model, waveforms_batch):
    """Extract YAMNet embeddings for a batch of waveforms"""
    embeddings = []
    for waveform in waveforms_batch:
        with tf.device('/CPU:0'):
            _, embeddings_tf, _ = yamnet_model(waveform)
            embedding = tf.reduce_mean(embeddings_tf, axis=0)
            embeddings.append(embedding.numpy())
    return np.array(embeddings)

def preprocess_and_save_embeddings(split_name, csv_path, yamnet_model, output_path):
    """Extract and save YAMNet embeddings for a dataset split"""
    dataset = AudioDatasetForPreprocessing(csv_path, family_to_idx)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    
    all_embeddings = []
    all_labels = []
    all_filepaths = []
    
    for waveforms, labels, filepaths in tqdm(loader, desc=f"{split_name}"):
        waveforms_np = waveforms.numpy().astype(np.float32)
        embeddings = extract_yamnet_embeddings_batched(yamnet_model, waveforms_np)
        all_embeddings.append(embeddings)
        all_labels.extend(labels.numpy())
        all_filepaths.extend(filepaths)
    
    all_embeddings = np.vstack(all_embeddings)
    all_labels = np.array(all_labels)
    
    np.savez_compressed(output_path, embeddings=all_embeddings, labels=all_labels, filepaths=all_filepaths)
    print(f"Saved {len(all_labels)} embeddings: {all_embeddings.shape}")

# Load YAMNet
yamnet_model = hub.load('https://tfhub.dev/google/yamnet/1')

# Define output paths in Results directory
results_dir = PROJECT_ROOT / "Results"
results_dir.mkdir(exist_ok=True)

train_embeddings_path = results_dir / "embeddings_yamnet_train.npz"
val_embeddings_path = results_dir / "embeddings_yamnet_val.npz"
test_embeddings_path = results_dir / "embeddings_yamnet_test.npz"

# Extract embeddings if not already cached
if train_embeddings_path.exists() and val_embeddings_path.exists() and test_embeddings_path.exists():
    print("Loading cached embeddings from Results directory...")
else:
    print("Extracting YAMNet embeddings...")
    preprocess_and_save_embeddings("train", train_csv, yamnet_model, train_embeddings_path)
    preprocess_and_save_embeddings("val", val_csv, yamnet_model, val_embeddings_path)
    preprocess_and_save_embeddings("test", test_csv, yamnet_model, test_embeddings_path)

# Clean up
del yamnet_model
import gc
gc.collect()

Loading cached embeddings from Results directory...


130458

In [5]:
# Dataset for precomputed embeddings

class PrecomputedEmbeddingsDataset(Dataset):
    """Load precomputed YAMNet embeddings"""
    def __init__(self, embeddings_path):
        data = np.load(embeddings_path, allow_pickle=True)
        self.embeddings = torch.FloatTensor(data['embeddings'])
        self.labels = torch.LongTensor(data['labels'])
    
    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]
    
    def __len__(self):
        return len(self.labels)

# Create datasets and loaders
train_ds = PrecomputedEmbeddingsDataset(train_embeddings_path)
val_ds = PrecomputedEmbeddingsDataset(val_embeddings_path)
test_ds = PrecomputedEmbeddingsDataset(test_embeddings_path)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

print(f"Datasets ready | Batch size: 32")

Datasets ready | Batch size: 32


In [6]:
# Classifier model

class YAMNetClassifier(nn.Module):
    """Classifier trained on YAMNet embeddings"""
    def __init__(self, embedding_dim=1024, num_classes=5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, embeddings):
        return self.classifier(embeddings)

model = YAMNetClassifier(embedding_dim=1024, num_classes=num_classes).to(DEVICE)
print(f"Model: 1024 -> 128 -> {num_classes} | Parameters: {sum(p.numel() for p in model.parameters()):,}")

Model: 1024 -> 128 -> 5 | Parameters: 131,845


In [7]:
# Training

# Create Results directory if it doesn't exist
results_dir = PROJECT_ROOT / "Results"
results_dir.mkdir(exist_ok=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

def train_one_epoch(model, loader):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for embeddings, labels in tqdm(loader, leave=False, desc="Training"):
        embeddings, labels = embeddings.to(DEVICE), labels.to(DEVICE)
        outputs = model(embeddings)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return total_loss / len(loader), 100. * correct / total

def evaluate(model, loader):
    model.eval()
    y_true, y_pred = [], []
    
    with torch.no_grad():
        for embeddings, labels in tqdm(loader, leave=False, desc="Evaluating"):
            embeddings, labels = embeddings.to(DEVICE), labels.to(DEVICE)
            outputs = model(embeddings)
            preds = outputs.argmax(1).cpu().numpy()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds)
    
    f1 = f1_score(y_true, y_pred, average='macro')
    acc = 100. * sum([1 for t, p in zip(y_true, y_pred) if t == p]) / len(y_true)
    return f1, acc, classification_report(y_true, y_pred, target_names=families, digits=3)

# Training loop
best_f1 = 0
train_losses, train_accs, val_f1s, val_accs = [], [], [], []

print("Training...")
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_f1, val_acc, report = evaluate(model, val_loader)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_f1s.append(val_f1)
    val_accs.append(val_acc)
    
    print(f"Epoch {epoch+1:2d}/{EPOCHS} | Loss: {train_loss:.4f} | Train: {train_acc:5.1f}% | Val: {val_acc:5.1f}% | F1: {val_f1:.4f}", end="")
    
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), results_dir / "model_yamnet_classifier.pt")
        print(" ✓")
    else:
        print()
    
    if epoch % 5 == 0 and epoch > 0:
        print(report)

print(f"Best F1: {best_f1:.4f}")
print(f"Model saved to: {results_dir / 'model_yamnet_classifier.pt'}")

# Training curves
epochs = list(range(1, EPOCHS + 1))

p1 = figure(width=550, height=400, title="Accuracy", x_axis_label="Epoch", y_axis_label="Accuracy (%)")
p1.line(epochs, train_accs, line_width=2.5, color='#2E86AB', legend_label='Train')
p1.line(epochs, val_accs, line_width=2.5, color='#A23B72', legend_label='Val')
p1.legend.location = "bottom_right"
p1.legend.click_policy = "hide"

p2 = figure(width=550, height=400, title="F1 Score", x_axis_label="Epoch", y_axis_label="F1")
p2.line(epochs, val_f1s, line_width=2.5, color='#F18F01', legend_label='Val F1')
best_epoch = val_f1s.index(max(val_f1s)) + 1
p2.scatter([best_epoch], [max(val_f1s)], size=12, color='#C73E1D')
p2.legend.location = "bottom_right"

show(row(p1, p2))

Training...


                                                        

Epoch  1/25 | Loss: 1.5566 | Train:  38.8% | Val:  68.4% | F1: 0.6093 ✓


                                                 

Epoch  2/25 | Loss: 1.4187 | Train:  74.0% | Val:  81.0% | F1: 0.6899 ✓


                                                 

Epoch  3/25 | Loss: 1.2874 | Train:  82.6% | Val:  82.3% | F1: 0.6864


                                                 

Epoch  4/25 | Loss: 1.1557 | Train:  85.1% | Val:  86.1% | F1: 0.7162 ✓


                                                 

Epoch  5/25 | Loss: 1.0352 | Train:  86.0% | Val:  88.6% | F1: 0.8375 ✓


                                                 

Epoch  6/25 | Loss: 0.9257 | Train:  88.5% | Val:  88.6% | F1: 0.8375
              precision    recall  f1-score   support

   keyboards      1.000     0.800     0.889        20
  percussion      0.933     0.933     0.933        15
     strings      1.000     0.400     0.571         5
       voice      0.850     1.000     0.919        17
       winds      0.808     0.955     0.875        22

    accuracy                          0.886        79
   macro avg      0.918     0.818     0.838        79
weighted avg      0.902     0.886     0.880        79



                                                 

Epoch  7/25 | Loss: 0.8116 | Train:  89.8% | Val:  89.9% | F1: 0.8472 ✓


                                                 

Epoch  8/25 | Loss: 0.7213 | Train:  89.3% | Val:  89.9% | F1: 0.8472


                                                 

Epoch  9/25 | Loss: 0.6426 | Train:  90.8% | Val:  89.9% | F1: 0.8329


                                                 

Epoch 10/25 | Loss: 0.5924 | Train:  91.4% | Val:  89.9% | F1: 0.8329


                                                 

Epoch 11/25 | Loss: 0.5313 | Train:  91.2% | Val:  89.9% | F1: 0.8472
              precision    recall  f1-score   support

   keyboards      1.000     0.800     0.889        20
  percussion      0.933     0.933     0.933        15
     strings      1.000     0.400     0.571         5
       voice      0.895     1.000     0.944        17
       winds      0.815     1.000     0.898        22

    accuracy                          0.899        79
   macro avg      0.929     0.827     0.847        79
weighted avg      0.913     0.899     0.892        79



                                                 

Epoch 12/25 | Loss: 0.4837 | Train:  91.1% | Val:  91.1% | F1: 0.8595 ✓


                                                 

Epoch 13/25 | Loss: 0.4368 | Train:  91.4% | Val:  89.9% | F1: 0.8329


                                                 

Epoch 14/25 | Loss: 0.4048 | Train:  93.0% | Val:  89.9% | F1: 0.8329


                                                 

Epoch 15/25 | Loss: 0.3801 | Train:  92.6% | Val:  89.9% | F1: 0.8329


                                                 

Epoch 16/25 | Loss: 0.3696 | Train:  92.8% | Val:  91.1% | F1: 0.8509
              precision    recall  f1-score   support

   keyboards      1.000     0.800     0.889        20
  percussion      0.938     1.000     0.968        15
     strings      0.667     0.400     0.500         5
       voice      1.000     1.000     1.000        17
       winds      0.815     1.000     0.898        22

    accuracy                          0.911        79
   macro avg      0.884     0.840     0.851        79
weighted avg      0.915     0.911     0.906        79



                                                 

Epoch 17/25 | Loss: 0.3431 | Train:  92.4% | Val:  91.1% | F1: 0.8509


                                                 

Epoch 18/25 | Loss: 0.3275 | Train:  92.8% | Val:  89.9% | F1: 0.8083


                                                 

Epoch 19/25 | Loss: 0.3074 | Train:  93.1% | Val:  89.9% | F1: 0.8045


                                                 

Epoch 20/25 | Loss: 0.2869 | Train:  94.0% | Val:  87.3% | F1: 0.7370


                                                 

Epoch 21/25 | Loss: 0.2958 | Train:  92.4% | Val:  87.3% | F1: 0.7370
              precision    recall  f1-score   support

   keyboards      1.000     0.800     0.889        20
  percussion      0.933     0.933     0.933        15
     strings      0.000     0.000     0.000         5
       voice      1.000     1.000     1.000        17
       winds      0.759     1.000     0.863        22

    accuracy                          0.873        79
   macro avg      0.738     0.747     0.737        79
weighted avg      0.857     0.873     0.858        79



                                                 

Epoch 22/25 | Loss: 0.2645 | Train:  93.9% | Val:  88.6% | F1: 0.7439


                                                 

Epoch 23/25 | Loss: 0.2542 | Train:  93.7% | Val:  88.6% | F1: 0.7382


                                                 

Epoch 24/25 | Loss: 0.2671 | Train:  93.9% | Val:  87.3% | F1: 0.7370


                                                 

Epoch 25/25 | Loss: 0.2432 | Train:  94.0% | Val:  88.6% | F1: 0.7439
Best F1: 0.8595
Model saved to: /Users/dghifari/02-University/SEM-2-2025/elec5305-project-520140154/Results/model_yamnet_classifier.pt


In [8]:
# Test evaluation

model.load_state_dict(torch.load(results_dir / "model_yamnet_classifier.pt", map_location=DEVICE))
test_f1, test_acc, test_report = evaluate(model, test_loader)

print(f"\nTest F1: {test_f1:.4f} | Accuracy: {test_acc:.1f}%")
print(test_report)

# Generate predictions
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for embeddings, labels in test_loader:
        embeddings, labels = embeddings.to(DEVICE), labels.to(DEVICE)
        outputs = model(embeddings)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(outputs.argmax(1).cpu().numpy())

# Confusion matrix
cm_normalized = confusion_matrix(y_true, y_pred, normalize='true')
cm_counts = confusion_matrix(y_true, y_pred)
n_classes = len(families)

# Prepare data
true_labels, pred_labels, values, counts, pct_text, count_text, text_colors = [], [], [], [], [], [], []

for i in range(n_classes):
    for j in range(n_classes):
        true_labels.append(families[i])
        pred_labels.append(families[j])
        val = cm_normalized[i, j]
        cnt = int(cm_counts[i, j])
        values.append(val)
        counts.append(cnt)
        pct_text.append(f'{val:.0%}')
        count_text.append(f'({cnt})')
        
        # Text color for readability
        if val < 0.3:
            text_colors.append('#2b2b2b')
        elif val > 0.7:
            text_colors.append('white')
        else:
            text_colors.append('#2b2b2b')

source = ColumnDataSource(data=dict(
    true_label=true_labels,
    pred_label=pred_labels,
    value=values,
    count=counts,
    pct_text=pct_text,
    count_text=count_text,
    text_color=text_colors
))

# Color mapper - Blues gradient (light to dark)
palette = list(reversed(Blues9))
mapper = LinearColorMapper(palette=palette, low=0, high=1)

# Create figure
p = figure(
    title="YAMNet Confusion Matrix (Test Set)",
    x_range=families, 
    y_range=list(reversed(families)),
    width=750, 
    height=600,
    tools="hover,save,reset"
)

p.rect(x="pred_label", y="true_label", width=1, height=1, source=source,
       fill_color=transform('value', mapper), line_color='white', line_width=2)

p.text(x='pred_label', y='true_label', text='pct_text', source=source,
       text_align='center', text_baseline='middle', text_font_size='14pt',
       text_font_style='bold', text_color='text_color', y_offset=6)

p.text(x='pred_label', y='true_label', text='count_text', source=source,
       text_align='center', text_baseline='middle', text_font_size='10pt',
       text_color='text_color', y_offset=-8)

# Color bar
color_bar = ColorBar(color_mapper=mapper, ticker=BasicTicker(desired_num_ticks=10),
                     label_standoff=12, border_line_color=None, location=(0, 0),
                     title='Accuracy', title_text_font_style='bold')
p.add_layout(color_bar, 'right')

# Hover
hover = p.select_one(HoverTool)
hover.tooltips = [("True", "@true_label"), ("Predicted", "@pred_label"), 
                  ("Accuracy", "@value{0.1%}"), ("Count", "@count")]

# Styling
p.grid.grid_line_color = None
p.axis.axis_line_color = None
p.axis.major_tick_line_color = None
p.xaxis.axis_label = 'Predicted Label'
p.yaxis.axis_label = 'True Label'
p.xaxis.axis_label_text_font_style = "bold"
p.yaxis.axis_label_text_font_style = "bold"

show(p)

                                                 


Test F1: 0.8324 | Accuracy: 84.5%
              precision    recall  f1-score   support

   keyboards      0.952     0.952     0.952        21
  percussion      0.708     1.000     0.829        17
     strings      1.000     0.476     0.645        21
       voice      0.880     0.815     0.846        27
       winds      0.800     1.000     0.889        24

    accuracy                          0.845       110
   macro avg      0.868     0.849     0.832       110
weighted avg      0.873     0.845     0.835       110

