
# Finetune SMAD Student (AST) End-to-End

This notebook builds the fused pseudo-label manifest, validates it, trains the AST student, and evaluates on gold.

**Assumptions**
- Run from within the repo; audio segments live in `data/segments/`.
- Dependencies installed: torch, torchaudio, transformers, pandas, datasets, scikit-learn.
- Teacher HF datasets are stored under `data/metadata/blocs_smad_v2_*`.


In [1]:
# Resolve project root and set paths
import sys
from pathlib import Path

import os

# Treat the current directory as the repo root when it contains `data/`;
# otherwise assume the notebook was launched one level below the root.
_launch_dir = Path.cwd()
PROJECT_ROOT = _launch_dir if (_launch_dir / 'data').exists() else _launch_dir.parent

# Make the repo importable so that `scripts.*` / `finetune.*` resolve.
_root_entry = str(PROJECT_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
print(f"Using project root: {PROJECT_ROOT}")

# Switch the working directory to the root so relative paths used later resolve there.
os.chdir(PROJECT_ROOT)
print(f'Working dir set to: {Path.cwd()}')

# Shared locations used by the rest of the notebook.
METADATA_DIR = PROJECT_ROOT / 'data/metadata'
SEGMENTS_DIR = PROJECT_ROOT / 'data/segments'
MANIFEST = METADATA_DIR / 'blocs_smad_v2_finetune.csv'
GOLD = METADATA_DIR / 'blocs_smad_gold_annotations_v1.csv'
CHECKPOINT = PROJECT_ROOT / 'checkpoints/student_ast.pt'

# Create the checkpoint folder up front so training can write into it.
CHECKPOINT.parent.mkdir(parents=True, exist_ok=True)

Using project root: /Users/benji/Desktop/columbia/dams
Working dir set to: /Users/benji/Desktop/columbia/dams


In [2]:

# Hyperparameters (everything a reader is likely to tune lives here)

# Pretrained AST checkpoint used to initialise the student.
AST_MODEL = 'MIT/ast-finetuned-audioset-10-10-0.4593'

# Data split / batching.
VAL_FRACTION = 0.1   # share of the training manifest held out for validation
BATCH_SIZE_AST = 2   # forwarded to the trainer as `batch_size_ast`

# Optimisation schedule.
EPOCHS = 10
LR = 1e-4
WEIGHT_DECAY = 1e-5



## Build finetune manifest
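Uses `scripts/build_finetune_dataset.py`: selects the per-class labels (F1 winner on non-IRR gold), inner-joins the teachers, and writes the CSV/Parquet/HF dataset. Note: the run log below reports consensus voting (threshold 2) across the teachers, so confirm which label-selection strategy is currently active.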


In [3]:

from scripts.build_finetune_dataset import build_dataset

# Output targets for the fused manifest: HF dataset directory, Parquet file, and CSV.
# The CSV target is the same path as MANIFEST, which later cells read back.
out_disk, out_parquet, out_csv = (
    METADATA_DIR / 'blocs_smad_v2_finetune',
    METADATA_DIR / 'blocs_smad_v2_finetune.parquet',
    MANIFEST,
)

# Rebuilds the manifest on every run.
build_dataset(METADATA_DIR, out_disk, out_parquet, out_csv)


Applied consensus voting with threshold 2 over teachers: ['ast', 'clap', 'm2d', 'whisper']
Teacher row counts: {'ast': 6196, 'clap': 6196, 'm2d': 6196, 'whisper': 6196}
Segment intersection size across teachers: 6196
Built merged dataset with 6196 rows and 34 columns
Wrote Parquet to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune.parquet
Wrote CSV to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune.csv


Saving the dataset (0/1 shards):   0%|          | 0/6196 [00:00<?, ? examples/s]

Saved HF dataset to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune



## Validate manifest
Checks for duplicate `segment_path` entries, missing required columns, nulls in the `chosen_*` columns, and (optionally) agreement with the IRR gold labels.


In [4]:

# Run the manifest validator with its defaults (no arguments are passed here).
# NOTE(review): the run log shows a relative manifest path, so this presumably relies on the
# working directory set in the first cell — confirm against the script's defaults.
from scripts.validate_finetune_manifest import main as validate_main
validate_main()


Loaded manifest: data/metadata/blocs_smad_v2_finetune.csv rows=6196 columns=34
No duplicate segment_path entries.
All required teacher and chosen columns present.
Chosen columns have no nulls.
Value counts for chosen_speech_label: {1: 5715, 0: 481}
Value counts for chosen_music_label: {0: 5510, 1: 686}
Value counts for chosen_noise_label: {0: 6182, 1: 14}
Merged IRR gold rows: 174

IRR gold sanity for speech:
              precision    recall  f1-score   support

           0     0.8293    1.0000    0.9067        34
           1     1.0000    0.9500    0.9744       140

    accuracy                         0.9598       174
   macro avg     0.9146    0.9750    0.9405       174
weighted avg     0.9666    0.9598    0.9611       174


IRR gold sanity for music:
              precision    recall  f1-score   support

           0     0.9746    1.0000    0.9871       115
           1     1.0000    0.9492    0.9739        59

    accuracy                         0.9828       174
   macro avg  


## Train AST student
Fine-tune AST with BCEWithLogits, class pos_weight, train/val split.



## Evaluate on gold
Default filter is IRR; adjust `gold_filter` or `threshold` as needed.



## Holdout Gold Eval (non-IRR)
We reserve a gold, non-IRR subset for evaluation and drop it from training. Teacher metrics and the student are compared on this holdout set.


In [5]:

import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from pathlib import Path

# Label set and sampling configuration for the gold holdout.
CLASSES = ["speech", "music", "noise"]
HOLDOUT_FRAC = 0.2  # share of non-IRR gold rows reserved for evaluation
RNG_SEED = 42       # fixed seed so the same holdout rows are drawn on every run

# Read the gold annotations and keep only rows not flagged as IRR segments.
GOLD_PATH = METADATA_DIR / "blocs_smad_gold_annotations_v1.csv"
gold_df = pd.read_csv(GOLD_PATH)
gold_non_irr = gold_df.loc[gold_df["is_irr_segment"].eq(False)].reset_index(drop=True)

# Draw the holdout rows; everything not drawn stays available for calibration.
holdout_df = gold_non_irr.sample(frac=HOLDOUT_FRAC, random_state=RNG_SEED)
train_gold_df = gold_non_irr.loc[~gold_non_irr.index.isin(holdout_df.index)]

print(f"Gold non-IRR total: {len(gold_non_irr)} | holdout: {len(holdout_df)} | remaining for calibration: {len(train_gold_df)}")


Gold non-IRR total: 1569 | holdout: 314 | remaining for calibration: 1255


In [6]:

# Read the fused manifest back from disk and exclude every holdout segment from training.
manifest_df = pd.read_csv(MANIFEST)
holdout_paths = set(holdout_df["segment_path"])

# Boolean mask over manifest rows whose segment belongs to the holdout.
in_holdout = manifest_df["segment_path"].isin(holdout_paths)
train_manifest = manifest_df.loc[~in_holdout].reset_index(drop=True)

print(f"Training manifest rows after dropping holdout: {len(train_manifest)} (dropped {len(manifest_df)-len(train_manifest)})")


Training manifest rows after dropping holdout: 5882 (dropped 314)


In [7]:

# Persist the holdout-free manifest; the training cell reads it from this path.
TRAIN_MANIFEST = METADATA_DIR / 'blocs_smad_v2_finetune_train.csv'
train_manifest.to_csv(path_or_buf=TRAIN_MANIFEST, index=False)
print(f"Wrote training manifest (holdout removed) to {TRAIN_MANIFEST}")


Wrote training manifest (holdout removed) to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune_train.csv


In [None]:

import torchaudio
import pandas as pd
import math

# Compute stats for train/holdout/total and gold overlap.
#
# This cell keeps its manifest-side holdout under its own names (`holdout_manifest_*`).
# Rebinding `holdout_df` / `holdout_paths` here would overwrite the gold holdout frame
# created by the holdout sampling cell, which the metrics cell relies on.
manifest_df = pd.read_csv(MANIFEST)
train_df = pd.read_csv(TRAIN_MANIFEST)
holdout_manifest_paths = set(manifest_df["segment_path"]) - set(train_df["segment_path"])
holdout_manifest_df = manifest_df[manifest_df["segment_path"].isin(holdout_manifest_paths)].reset_index(drop=True)

# Derive segment duration from one file.
# NOTE(review): assumes every segment has the same length as the first one — confirm.
sample_path = SEGMENTS_DIR / manifest_df.iloc[0]["segment_path"]
wav, sr = torchaudio.load(sample_path)
duration_sec = wav.shape[-1] / sr

# Load gold annotations to count gold segments in each split
gold_df = pd.read_csv(GOLD)


def stats(df, name):
    """Summarise one split of the manifest.

    Args:
        df: DataFrame with `segment_path` and the `chosen_{speech,music,noise}_label` columns.
        name: Label stored in the `split` field of the result.

    Returns:
        dict with the split name, number of segments, total hours (segments times the
        module-level `duration_sec`), the mean of each chosen label expressed as a
        percentage, and the number of rows of the module-level `gold_df` whose
        `segment_path` occurs in `df`.
    """
    n = len(df)
    # Rows of gold_df (not unique paths) whose segment is present in this split.
    gold_count = int(gold_df["segment_path"].isin(df["segment_path"]).sum())
    return {
        "split": name,
        "segments": n,
        "hours": n * duration_sec / 3600.0,
        "%speech": float(df["chosen_speech_label"].mean() * 100),
        "%music": float(df["chosen_music_label"].mean() * 100),
        "%noise": float(df["chosen_noise_label"].mean() * 100),
        "gold_segs": gold_count,
    }


rows = [
    stats(train_df, "train"),
    stats(holdout_manifest_df, "holdout"),
    stats(manifest_df, "total"),
]

stats_df = pd.DataFrame(rows)
# Round only the displayed float columns; counts stay exact.
rounded_cols = ["hours", "%speech", "%music", "%noise"]
stats_df[rounded_cols] = stats_df[rounded_cols].round(2)
stats_df


In [8]:

import argparse
from scripts.train_student import train

# Arguments handed to `train`, collected in one mapping and wrapped in a Namespace.
train_config = {
    # Data
    "manifest": TRAIN_MANIFEST,      # holdout-free manifest written earlier
    "segments_dir": SEGMENTS_DIR,
    # Audio front-end settings
    "sample_rate": 16000,
    "n_mels": 128,
    "hop_length": 160,
    "win_length": 400,
    # Optimisation (values come from the hyperparameter cell)
    "batch_size_ast": BATCH_SIZE_AST,
    "epochs": EPOCHS,
    "lr": LR,
    "weight_decay": WEIGHT_DECAY,
    "val_fraction": VAL_FRACTION,
    # Output / model
    "output": CHECKPOINT,
    "ast_model": AST_MODEL,
}
train_args = argparse.Namespace(**train_config)
train(train_args)


Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Frozen encoder; training head only (3,843/86,191,107 parameters).


Train 1/10:   0%|          | 0/2647 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:

# Helper: compute per-class precision/recall/F1 given gold + prediction columns
import numpy as np

def compute_metrics(y_true, y_pred):
    """Score predictions per class and summarise them with a macro F1.

    Calls `precision_recall_fscore_support` with `average=None` and `zero_division=0`,
    then pairs the returned per-label scores with the names in `CLASSES`, in order.

    Args:
        y_true: Ground-truth labels, passed through to scikit-learn unchanged.
        y_pred: Predicted labels, passed through to scikit-learn unchanged.

    Returns:
        (per_class, macro_f1), where `per_class` maps each class name to a dict with
        "precision", "recall" and "f1" floats, and `macro_f1` is the unweighted mean
        of the per-label F1 scores.
    """
    precision, recall, f1_scores, _support = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )
    per_class = {}
    for class_name, p, r, f in zip(CLASSES, precision, recall, f1_scores):
        per_class[class_name] = {"precision": float(p), "recall": float(r), "f1": float(f)}
    return per_class, float(f1_scores.mean())

# Compare each teacher (and the chosen/consensus labels) against gold on the holdout.
#
# Gold and predictions are both read from the same merged-and-filtered frame, so their
# rows are always aligned, including when some holdout segments are absent from the manifest.
gold_cols = [f"{c}_gold" for c in CLASSES]
absent_gold = [col for col in gold_cols if col not in holdout_df.columns]
if absent_gold:
    raise KeyError(
        f"holdout_df is missing gold columns {absent_gold}; re-run the holdout sampling cell "
        "so that holdout_df holds the gold annotations."
    )

teacher_sources = {
    "ast": [f"ast_{c}_label" for c in CLASSES],
    "whisper": [f"whisper_{c}_label" for c in CLASSES],
    "clap": [f"clap_{c}_label" for c in CLASSES],
    "m2d": [f"m2d_{c}_label" for c in CLASSES],
    "consensus": [f"chosen_{c}_label" for c in CLASSES],
}

# Keep the gold columns coming from holdout_df only: drop same-named columns on the manifest side.
manifest_side = manifest_df.drop(columns=[col for col in gold_cols if col in manifest_df.columns])

# Join holdout with teacher predictions. With a left join the key column is never NaN, so
# unmatched rows are detected through the merge indicator, not through `segment_path`.
joined = holdout_df.merge(
    manifest_side, on="segment_path", how="left", suffixes=("_gold", ""), indicator=True
)
unmatched = joined["_merge"] != "both"
missing = int(unmatched.sum())
if missing:
    print(f"Warning: {missing} holdout segments missing from manifest (will drop)")
joined = joined.loc[~unmatched].drop(columns="_merge").reset_index(drop=True)

# Build gold matrix for holdout (after filtering, so it lines up with the predictions).
holdout_gold = joined[gold_cols].to_numpy()

metrics_rows = []
if joined.empty:
    print("No holdout segments matched the manifest; skipping teacher metrics.")
else:
    for name, cols in teacher_sources.items():
        # Skip sources whose label columns are not present in the manifest.
        if not set(cols).issubset(joined.columns):
            continue
        # The cast fails loudly if a prediction column still contains NaN.
        y_pred = joined[cols].astype(int).to_numpy()
        per_class, macro_f1 = compute_metrics(holdout_gold, y_pred)
        metrics_rows.append({"model": name, "macro_f1": macro_f1, **{f"f1_{c}": per_class[c]["f1"] for c in CLASSES}})

metrics_df = pd.DataFrame(metrics_rows)
metrics_df


In [None]:

# Plot per-class F1 for teachers/consensus
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar chart: one group per model, one bar per class.
if metrics_df.empty:
    print("No metrics to plot (metrics_df empty)")
else:
    bar_width = 0.2
    group_pos = np.arange(len(metrics_df))
    fig, ax = plt.subplots(figsize=(8, 5))
    for offset, class_name in enumerate(CLASSES):
        # Shift each class by one bar width so the bars sit side by side within a group.
        ax.bar(group_pos + offset * bar_width, metrics_df[f"f1_{class_name}"], bar_width, label=class_name)
    # Tick under the second bar of each group (the centre when there are three classes).
    ax.set_xticks(group_pos + bar_width)
    ax.set_xticklabels(metrics_df["model"], rotation=45, ha="right")
    ax.set_ylabel("F1 (holdout)")
    ax.set_title("Teacher/Consensus F1 on gold holdout")
    ax.legend()
    plt.tight_layout()


In [None]:

#evaluate trained student on holdout (requires checkpoint_path)
import torch
from torch.utils.data import DataLoader
from finetune.ast_model import ASTClassifier, get_feature_extractor
from finetune.dataset import SmadDataset
from pathlib import Path

# Evaluate the trained student on the holdout segments.
# Reuse the checkpoint location from the setup cell (the same path the training cell
# passes as `output`) so the two cannot drift apart.
checkpoint_path = CHECKPOINT
if checkpoint_path.exists():
    # Device preference: Apple MPS, then CUDA, then CPU.
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    state = torch.load(checkpoint_path, map_location=device)
    student = ASTClassifier(model_name=AST_MODEL, num_labels=len(CLASSES)).to(device)
    # Non-strict loading tolerates key mismatches; report them so a partly initialised
    # model is never evaluated without notice.
    load_result = student.load_state_dict(state["model_state_dict"], strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", None) or [])
    unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing_keys or unexpected_keys:
        print(
            f"Warning: checkpoint/model key mismatch "
            f"(missing={len(missing_keys)}, unexpected={len(unexpected_keys)}); "
            "metrics below may not reflect the trained weights."
        )
    student.eval()
    processor = get_feature_extractor(AST_MODEL)

    # NOTE(review): the dataset is built from MANIFEST, so `batch["labels"]` are presumably the
    # manifest's training targets rather than the gold annotations — confirm before reporting
    # this number as a gold metric.
    # NOTE(review): waveforms are requested at 16 kHz while the processor is called with
    # `processor.sampling_rate`; these are assumed to match — confirm.
    holdout_ds = SmadDataset(
        manifest_path=MANIFEST,
        segments_dir=SEGMENTS_DIR,
        sample_rate=16000,
        n_mels=128,
        hop_length=160,
        win_length=400,
        return_waveform=True,
    )
    # filter to holdout paths
    holdout_indices = [i for i, p in enumerate(holdout_ds.df["segment_path"]) if p in holdout_paths]

    if not holdout_indices:
        # torch.cat on an empty list would raise; report and skip.
        print("No holdout segments found in the manifest; skip student eval.")
    else:
        holdout_subset = torch.utils.data.Subset(holdout_ds, holdout_indices)
        loader = DataLoader(holdout_subset, batch_size=4, shuffle=False)

        all_logits, all_labels = [], []
        with torch.no_grad():
            for batch in loader:
                labels = batch["labels"].to(device)
                wavs = batch["waveform"].squeeze(1)
                # The feature extractor is fed a list of per-example numpy waveforms.
                wav_list = [w.cpu().numpy() for w in wavs]
                inputs = processor(wav_list, sampling_rate=processor.sampling_rate, return_tensors="pt", padding=True)
                input_values = inputs["input_values"].to(device)
                attention_mask = inputs.get("attention_mask")
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                logits = student(input_values=input_values, attention_mask=attention_mask)
                all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())
        logits = torch.cat(all_logits)
        labels = torch.cat(all_labels)
        # Independent sigmoid per class with a fixed 0.5 decision threshold.
        probs = torch.sigmoid(logits).numpy()
        preds = (probs >= 0.5).astype(int)
        student_per_class, student_macro = compute_metrics(labels.numpy(), preds)
        print("Student macro F1 on holdout:", student_macro)
else:
    print(f"No checkpoint found at {checkpoint_path}; skip student eval.")
