
# Fine-tune SMAD Student (AST) End-to-End

This notebook builds the fused pseudo-label manifest, validates it, fine-tunes the AST student on it, and evaluates the student on the gold annotations.

**Assumptions**
- The notebook is run from the repo root or from a directory one level below it; audio segments live in `data/segments/`.
- Dependencies are installed: torch, torchaudio, transformers, pandas, datasets, scikit-learn.
- The teacher HF datasets are under `data/metadata/blocs_smad_v2_*`.
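To confirm the dependencies listed above are importable before running the heavier cells, a minimal sketch can query `importlib.util.find_spec`. The helper name `missing_modules` is hypothetical; note that scikit-learn is imported as `sklearn`.

```python
import importlib.util

# Import names for the dependencies listed above (scikit-learn is imported as `sklearn`)
REQUIRED = ['torch', 'torchaudio', 'transformers', 'pandas', 'datasets', 'sklearn']

def missing_modules(names):
    """Return the module names that cannot be found on the current sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules(REQUIRED)
print('Missing:', missing or 'none')
```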


In [1]:
# Resolve project root and set paths
import os
import sys
from pathlib import Path

# Use the current directory if it contains data/, otherwise assume we are one level below the root
PROJECT_ROOT = Path.cwd()
if not (PROJECT_ROOT / 'data').exists():
    PROJECT_ROOT = PROJECT_ROOT.parent
# Make the `scripts.*` modules importable
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
print(f'Using project root: {PROJECT_ROOT}')

# Relative paths such as data/metadata/... now resolve against the project root
os.chdir(PROJECT_ROOT)
print(f'Working dir set to: {Path.cwd()}')

METADATA_DIR = PROJECT_ROOT / 'data/metadata'
SEGMENTS_DIR = PROJECT_ROOT / 'data/segments'
MANIFEST = METADATA_DIR / 'blocs_smad_v2_finetune.csv'
GOLD = METADATA_DIR / 'blocs_smad_gold_annotations_v1.csv'
CHECKPOINT = PROJECT_ROOT / 'checkpoints/student_ast.ipynb_run.pt'

Using project root: /Users/benji/Desktop/columbia/dams
Working dir set to: /Users/benji/Desktop/columbia/dams
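The root-resolution cell only looks at the current directory and its immediate parent. A slightly more general sketch walks upward until it finds a directory containing `data/`; the function name `find_project_root` is hypothetical and not part of the repo.

```python
from pathlib import Path

def find_project_root(start: Path, marker: str = 'data') -> Path:
    """Return the first directory at or above `start` that contains `marker`."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f'No directory containing {marker!r} found above {start}')
```

Usage: `PROJECT_ROOT = find_project_root(Path.cwd())` would replace the two-step check above.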


In [2]:

# Hyperparameters
AST_MODEL = 'MIT/ast-finetuned-audioset-10-10-0.4593'
EPOCHS = 10
BATCH_SIZE_AST = 2
LR = 1e-4
WEIGHT_DECAY = 1e-5
VAL_FRACTION = 0.1
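With the 6196 manifest rows reported by the build step, `VAL_FRACTION = 0.1`, and `BATCH_SIZE_AST = 2`, the number of training steps per epoch can be estimated as below. Rounding the validation size down and keeping the last partial batch are assumptions about `scripts/train_student.py`, but together they reproduce the 2789 steps shown in the training progress bar.

```python
import math

def steps_per_epoch(n_rows: int, val_fraction: float, batch_size: int) -> int:
    # Assumption: the validation size is rounded down, so the remainder stays in training
    n_val = int(n_rows * val_fraction)
    n_train = n_rows - n_val
    # Assumption: the last partial batch still counts as a step (drop_last=False)
    return math.ceil(n_train / batch_size)

print(steps_per_epoch(6196, 0.1, 2))  # 5577 training rows / 2 per batch
```

Over `EPOCHS = 10` that is roughly 27,890 optimizer steps, so a larger batch size shortens training considerably if memory allows.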



## Build finetune manifest
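Uses `scripts/build_finetune_dataset.py`. For each class (speech, music, noise) it picks the teacher with the highest F1 on the non-IRR gold subset, inner-joins the teacher outputs so that only segments labeled by every teacher are kept, and writes the merged manifest as CSV, Parquet, and a Hugging Face dataset.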
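The per-class teacher selection and the inner join described above can be sketched in plain Python. The data layout (dicts of aligned 0/1 label lists), the function names, and the toy labels are hypothetical and need not match `scripts/build_finetune_dataset.py`; ties go to the teacher listed first.

```python
def f1_binary(y_true, y_pred):
    """F1 of the positive label, 2TP / (2TP + FP + FN); 0.0 when undefined."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def pick_teachers(gold, teacher_preds, classes):
    """For each class, return (teacher, F1) for the teacher that best matches gold.

    gold:          {class: [0/1 labels]} on the calibration (non-IRR) subset
    teacher_preds: {teacher: {class: [0/1 labels]}} aligned with the gold rows
    """
    best = {}
    for cls in classes:
        scores = {name: f1_binary(gold[cls], preds[cls]) for name, preds in teacher_preds.items()}
        winner = max(scores, key=scores.get)
        best[cls] = (winner, scores[winner])
    return best

def common_segments(teacher_segments):
    """Inner join on segment ids: keep only segments that every teacher labeled."""
    return set.intersection(*(set(segs) for segs in teacher_segments.values()))

# Toy labels, for illustration only
gold = {'speech': [1, 1, 0, 0], 'music': [0, 1, 0, 1]}
teacher_preds = {
    'ast': {'speech': [1, 0, 0, 0], 'music': [0, 1, 0, 1]},
    'm2d': {'speech': [1, 1, 0, 0], 'music': [0, 1, 1, 1]},
}
print(pick_teachers(gold, teacher_preds, ['speech', 'music']))
```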


In [3]:

from scripts.build_finetune_dataset import build_dataset

out_disk = METADATA_DIR / 'blocs_smad_v2_finetune'
out_parquet = METADATA_DIR / 'blocs_smad_v2_finetune.parquet'
out_csv = MANIFEST

build_dataset(METADATA_DIR, out_disk, out_parquet, out_csv)


Using non-IRR calibration subset with 1569 rows
Best speech teacher: m2d (F1 0.9592)
Best music teacher: ast (F1 0.9768)
Best noise teacher: clap (F1 0.0391)
Teacher row counts: {'ast': 6196, 'clap': 6196, 'm2d': 6196, 'whisper': 6196}
Segment intersection size across teachers: 6196
Built merged dataset with 6196 rows and 34 columns
Wrote Parquet to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune.parquet
Wrote CSV to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune.csv



Saved HF dataset to /Users/benji/Desktop/columbia/dams/data/metadata/blocs_smad_v2_finetune



## Validate manifest
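Checks for duplicate `segment_path` entries, missing teacher or `chosen_*` columns, and nulls in the `chosen_*` columns, prints per-class label counts, and optionally reports sanity metrics against the IRR gold labels.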
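The core checks can be approximated with the standard library alone. The column names `segment_path` and `chosen_{class}_label` come from the validation output; the function name, the report layout, and treating empty cells as nulls are assumptions rather than a description of `scripts/validate_finetune_manifest.py`. Teacher columns are not checked because their names are not shown here.

```python
import csv
from collections import Counter

CLASSES = ('speech', 'music', 'noise')

def validate_manifest(path, classes=CLASSES):
    """Basic manifest checks: required columns, duplicate segments, nulls and label counts."""
    with open(path, newline='') as f:
        reader = csv.DictReader(f)
        columns = reader.fieldnames or []
        rows = list(reader)

    required = ['segment_path'] + [f'chosen_{c}_label' for c in classes]
    missing = [c for c in required if c not in columns]
    if missing:
        # The remaining checks need these columns
        return {'rows': len(rows), 'missing_columns': missing}

    path_counts = Counter(row['segment_path'] for row in rows)
    report = {
        'rows': len(rows),
        'missing_columns': [],
        'duplicates': sorted(p for p, n in path_counts.items() if n > 1),
    }
    for c in classes:
        col = f'chosen_{c}_label'
        values = [row[col] for row in rows]
        # DictReader yields '' for empty cells and None for cells missing from short rows
        is_null = [v is None or v.strip() == '' for v in values]
        report[f'{col}_nulls'] = sum(is_null)
        report[f'{col}_counts'] = dict(Counter(v for v, null in zip(values, is_null) if not null))
    return report
```

Usage: `validate_manifest(MANIFEST)` accepts the `Path` defined earlier; label counts come back keyed by the raw CSV strings (`'0'`, `'1'`).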


In [4]:

from scripts.validate_finetune_manifest import main as validate_main
validate_main()


Loaded manifest: data/metadata/blocs_smad_v2_finetune.csv rows=6196 columns=34
No duplicate segment_path entries.
All required teacher and chosen columns present.
Chosen columns have no nulls.
Value counts for chosen_speech_label: {1: 5698, 0: 498}
Value counts for chosen_music_label: {0: 5488, 1: 708}
Value counts for chosen_noise_label: {0: 6178, 1: 18}
Merged IRR gold rows: 174

IRR gold sanity for speech:
              precision    recall  f1-score   support

           0     0.8718    1.0000    0.9315        34
           1     1.0000    0.9643    0.9818       140

    accuracy                         0.9713       174
   macro avg     0.9359    0.9821    0.9567       174
weighted avg     0.9749    0.9713    0.9720       174


IRR gold sanity for music:
              precision    recall  f1-score   support

           0     0.9746    1.0000    0.9871       115
           1     1.0000    0.9492    0.9739        59

    accuracy                         0.9828       174
   macro avg  


## Train AST student
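Fine-tunes AST as a three-label (speech, music, noise) multi-label classifier using `BCEWithLogitsLoss` with per-class `pos_weight`, on a train/validation split set by `VAL_FRACTION`. The pretrained 527-class AudioSet head is replaced by a freshly initialized 3-class head, which is why the weight-initialization warning below appears.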
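A common way to set `pos_weight` is negatives divided by positives per class, the convention given as an example in the PyTorch `BCEWithLogitsLoss` documentation; whether `scripts/train_student.py` does exactly this is an assumption. The sketch computes those weights from the `chosen_*` value counts in the validation output and writes out the per-element weighted loss in plain Python. In PyTorch the weights would be passed as a tensor, e.g. `torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([...]))`.

```python
import math

def pos_weights(pos_counts, total):
    """Per-class pos_weight = negatives / positives."""
    return {c: (total - p) / p for c, p in pos_counts.items()}

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def weighted_bce_with_logits(logit, target, pos_weight):
    """Per-element loss -[pw * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    # Uses log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
    return pos_weight * target * softplus(-logit) + (1 - target) * softplus(logit)

# Positive counts from the chosen_* value counts printed by the validation step
# (computed over the whole manifest here, not only the training split)
weights = pos_weights({'speech': 5698, 'music': 708, 'noise': 18}, total=6196)
print({c: round(w, 2) for c, w in weights.items()})
```

With these counts the noise class gets a weight of about 343 (6178 / 18), so each of its few positives weighs heavily in the loss, while speech, which is positive in most segments, gets a weight below 1 (498 / 5698).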


In [None]:

import argparse
from scripts.train_student import train

train_args = argparse.Namespace(
    manifest=MANIFEST,
    segments_dir=SEGMENTS_DIR,
    sample_rate=16000,
    n_mels=128,
    hop_length=160,
    win_length=400,
    batch_size_ast=BATCH_SIZE_AST,
    epochs=EPOCHS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    val_fraction=VAL_FRACTION,
    output=CHECKPOINT,
    ast_model=AST_MODEL,
)
train(train_args)


Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train 1/10:   0%|          | 0/2789 [00:02<?, ?it/s]

libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
Exception ignored in: <generator object tqdm_notebook.__iter__ at 0x107057a60>
Traceback (most recent call last):
  File "/Users/benji/Desktop/columbia/dams/.venv/lib/python3.13/site-packages/tqdm/notebook.py", line 255, in __iter__
    self.disp(bar_style='danger')
  File "/Users/benji/Desktop/columbia/dams/.venv/lib/python3.13/site-packages/tqdm/notebook.py", line 139, in display
    def display(self, msg=None, pos=None,
  File "/Users/benji/Desktop/columbia/dams/.venv/lib/python3.13/site-packages/torch/utils/data/_utils/signal_handling.py", line 73, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader 

KeyboardInterrupt: 



## Evaluate on gold
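Scores the trained student against the gold annotations. The cell loads `CHECKPOINT`, so the training cell must have completed first. It uses the IRR subset (`gold_filter='irr'`, the default) and a 0.5 decision threshold; adjust `gold_filter` or `threshold` as needed.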
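Applying `threshold` presumably means passing each class logit through a sigmoid and predicting positive when the probability meets the threshold. The sketch shows that step plus per-class precision, recall, and F1 for the positive label, the same quantities reported in the IRR sanity check above. The logit layout (one row per segment, columns in speech/music/noise order), the function names, and the toy values are assumptions.

```python
import math

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def binarize(logits, threshold=0.5):
    """Map per-class logits (one row per segment) to 0/1 predictions."""
    return [[int(sigmoid(x) >= threshold) for x in row] for row in logits]

def per_class_prf(y_true, y_pred, classes):
    """Precision, recall and F1 of the positive label for each class column."""
    out = {}
    for j, cls in enumerate(classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 1 and p[j] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 0 and p[j] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 1 and p[j] == 0)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        out[cls] = (precision, recall, f1)
    return out

CLASSES = ['speech', 'music', 'noise']  # assumed column order
# Toy logits and gold labels, for illustration only
logits = [[2.0, -1.0, -3.0], [-0.5, 1.5, -2.0], [0.3, 0.2, 0.1]]
gold = [[1, 0, 0], [0, 1, 0], [1, 0, 0]]
preds = binarize(logits, threshold=0.5)
for cls, (p, r, f1) in per_class_prf(gold, preds, CLASSES).items():
    print(f'{cls}: precision={p:.4f} recall={r:.4f} f1={f1:.4f}')
```

Because the threshold is applied to each class independently, it could also be set per class if a single value does not suit all three.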


In [None]:

import argparse
from scripts.evaluate_student import evaluate

eval_args = argparse.Namespace(
    checkpoint=CHECKPOINT,
    manifest=MANIFEST,
    segments_dir=SEGMENTS_DIR,
    gold=GOLD,
    gold_filter='irr',  # options: 'irr', 'non-irr', 'all'
    batch_size=2,
    threshold=0.5,
    sample_rate=16000,
    ast_model=AST_MODEL,
)
evaluate(eval_args)
