# MERT LoRA Fine-tuning for Meter Classification

**GPU**: Runtime → Change runtime type → T4 GPU (free) or A100 (Pro)

This notebook:
1. Downloads METER2800 dataset from Harvard Dataverse
2. Downloads WIKIMETER (curated songs from YouTube: 3/4, 4/4, 5/x, 7/x, 9/x, 11/x)
3. Fine-tunes MERT with LoRA for multi-label meter classification

**Estimated time**: ~2-3h on T4 (95M), ~5h on A100 (330M)

## 0. Install dependencies

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# NOTE(review): versions are unpinned; pin them once a known-good set is established.
%pip install -q transformers peft librosa scikit-learn tqdm yt-dlp
!apt-get -qq -y install ffmpeg

In [None]:
import torch

# Report the accelerator and stop early so the rest of the notebook is not run on CPU.
cuda_ok = torch.cuda.is_available()
gpu_name = torch.cuda.get_device_name(0) if cuda_ok else 'NO GPU'
print(f"GPU: {gpu_name}")
print(f"CUDA: {cuda_ok}")
if not cuda_ok:
    raise RuntimeError("No GPU! Go to Runtime → Change runtime type → T4 GPU")

## 1. Download METER2800

In [None]:
import hashlib
import json
import ssl
import sys
import tarfile
import time
import urllib.error
import urllib.request
from pathlib import Path

# Local (Colab) working copy of METER2800.
DATA_DIR = Path("/content/data/meter2800")
AUDIO_DIR = DATA_DIR / "audio"            # flat directory of extracted clips
DOWNLOADS_DIR = DATA_DIR / "downloads"    # raw archives fetched from Dataverse
# Suffixes (lower-case) treated as audio when extracting and counting files.
AUDIO_EXTENSIONS = {
    ".wav", ".mp3", ".ogg", ".flac",
    ".oga", ".opus", ".aiff", ".aif",
}

# Harvard Dataverse location of the dataset.
DOI = "doi:10.7910/DVN/0CLXBQ"
API_BASE = "https://dataverse.harvard.edu/api"
DATASET_URL = API_BASE + "/datasets/:persistentId/?persistentId=" + DOI
# Archives in the dataset that contain the audio.
TAR_FILES = {"FMA.tar.gz", "MAG.tar.gz", "OWN.tar.gz"}


def api_request(url, max_retries=3):
    """GET ``url`` (asking for JSON) and return the raw response body.

    Transient failures are retried with exponential backoff (2 s, 4 s, ...):
    network/timeout errors and HTTP 429 / 5xx responses. Other HTTP errors
    (e.g. 404, 403) are raised immediately, because retrying cannot fix them.

    Args:
        url: Absolute URL to fetch.
        max_retries: Total number of attempts; values below 1 are treated as 1
            (previously such values silently returned None).

    Returns:
        The response body as ``bytes``.

    Raises:
        urllib.error.HTTPError: on a non-retryable status, or the last
            retryable one once attempts are exhausted.
        Exception: the last network-level error once attempts are exhausted.
    """
    ctx = ssl.create_default_context()
    headers = {"User-Agent": "RhythmAnalyzer/1.0", "Accept": "application/json"}
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        is_last = attempt == attempts - 1
        try:
            req = urllib.request.Request(url, headers=headers)
            with urllib.request.urlopen(req, timeout=30, context=ctx) as resp:
                return resp.read()
        except urllib.error.HTTPError as e:
            retryable = e.code == 429 or e.code >= 500
            if is_last or not retryable:
                raise
        except Exception:
            if is_last:
                raise
        time.sleep(2 * (2 ** attempt))


def download_dataverse_file(file_id, dest, expected_size=0):
    """Download Dataverse datafile ``file_id`` to ``dest`` safely.

    The body is streamed in 1 MiB chunks into a sibling ``<name>.part`` file.
    That file is renamed to ``dest`` only after the download completes. An
    interrupted download therefore never leaves a truncated file at ``dest``.
    This matters because the METER2800 cell skips every file that already
    exists, so a truncated archive would be reused on every re-run. Streaming
    also avoids holding a multi-GB archive in memory.

    Args:
        file_id: Dataverse datafile id (``dataFile["id"]`` in the dataset metadata).
        dest: Target ``Path``; parent directories are created.
        expected_size: Byte size reported by Dataverse. When > 0, the downloaded
            size must match it exactly.

    Returns:
        True if ``dest`` exists after the call.

    Raises:
        IOError: the downloaded size differs from ``expected_size``. Nothing is
            left at ``dest`` in that case.
        urllib.error.URLError / OSError: network or filesystem failures.
    """
    url = f"{API_BASE}/access/datafile/{file_id}"
    ctx = ssl.create_default_context()
    req = urllib.request.Request(url, headers={"User-Agent": "RhythmAnalyzer/1.0"})
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_name(dest.name + ".part")
    written = 0
    try:
        with urllib.request.urlopen(req, timeout=300, context=ctx) as resp, open(tmp, "wb") as out:
            for chunk in iter(lambda: resp.read(1 << 20), b""):
                out.write(chunk)
                written += len(chunk)
        if expected_size and written != expected_size:
            raise IOError(f"{dest.name}: expected {expected_size} bytes, got {written}")
        tmp.replace(dest)  # atomic on the same filesystem
    finally:
        # Only present if something failed before the rename.
        if tmp.exists():
            tmp.unlink()
    return dest.exists()


def extract_tar(tar_path, audio_dir, extensions=None):
    """Extract the audio files of a ``.tar.gz`` archive into one flat directory.

    Only regular files whose suffix (case-insensitive) is in ``extensions`` are
    written. Directory structure is discarded and only the base name is kept,
    so member paths cannot escape ``audio_dir``. If a file with that name
    already exists, the clip is written as ``<parent-dir>_<name>`` instead.

    The archive is read in a single sequential pass: members are iterated
    directly instead of calling ``getmembers()`` first, which would decompress
    the whole gzip stream once to build the index and then again to read the
    data.

    Args:
        tar_path: Path of the gzip-compressed tar archive.
        audio_dir: Output directory (created if missing).
        extensions: Iterable of suffixes such as ``".wav"``; defaults to
            AUDIO_EXTENSIONS.

    Returns:
        Number of files written.
    """
    exts = AUDIO_EXTENSIONS if extensions is None else {e.lower() for e in extensions}
    audio_dir.mkdir(parents=True, exist_ok=True)
    count = 0
    with tarfile.open(tar_path, "r:gz") as tar:
        for member in tar:
            if not member.isfile():
                continue
            member_path = Path(member.name)
            name = member_path.name
            if Path(name).suffix.lower() not in exts:
                continue
            dest = audio_dir / name
            if dest.exists():
                dest = audio_dir / f"{member_path.parent.name}_{name}"
            src = tar.extractfile(member)
            if src is None:  # defensive: regular-file members are always readable
                continue
            with src:
                dest.write_bytes(src.read())
            count += 1
    return count


METER2800_MIN_AUDIO_FILES = 2000  # at or below this many clips, (re)download and extract


def _count_audio_files(directory):
    """Return how many regular files in ``directory`` have a suffix in AUDIO_EXTENSIONS (0 if it is missing)."""
    if not directory.is_dir():
        return 0
    return sum(1 for p in directory.iterdir() if p.is_file() and p.suffix.lower() in AUDIO_EXTENSIONS)


# Skip download + extraction if a previous run already populated AUDIO_DIR.
# Only audio files are counted, matching the verification at the end of the cell.
n_existing = _count_audio_files(AUDIO_DIR)
if n_existing > METER2800_MIN_AUDIO_FILES:
    print(f"METER2800 already downloaded: {n_existing} files")
else:
    for directory in (DATA_DIR, DOWNLOADS_DIR, AUDIO_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    # Get file list
    print("Fetching METER2800 metadata...")
    metadata = json.loads(api_request(DATASET_URL))
    files = metadata["data"]["latestVersion"]["files"]
    print(f"Found {len(files)} files")

    for file_meta in files:
        datafile = file_meta["dataFile"]
        fname = datafile["filename"]
        if fname in TAR_FILES:
            # Audio archives: pass the reported size so a truncated download is detected.
            dest = DOWNLOADS_DIR / fname
            expected_size = datafile.get("filesize", 0)
        elif fname.endswith((".csv", ".tab", ".tsv")):
            # Label tables: no size check is requested for these.
            dest = DATA_DIR / fname
            expected_size = 0
        else:
            continue
        if dest.exists():
            continue
        print(f"Downloading {fname}...", end=" ", flush=True)
        download_dataverse_file(datafile["id"], dest, expected_size)
        print("OK")

    # Extract
    total = 0
    for tar_path in sorted(DOWNLOADS_DIR.glob("*.tar.gz")):
        if tar_path.name in TAR_FILES:
            print(f"Extracting {tar_path.name}...", end=" ", flush=True)
            count = extract_tar(tar_path, AUDIO_DIR)
            print(f"{count} files")
            total += count
    print(f"\nTotal extracted: {total} audio files")

# Verify
n_audio = _count_audio_files(AUDIO_DIR)
n_tabs = len(list(DATA_DIR.glob("*.tab")))
print(f"\nMETER2800: {n_audio} audio files, {n_tabs} label files")

## 2. Download WIKIMETER (YouTube)

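Curated songs with known time signatures (3/4, 4/4, 5/x, 7/x, 9/x, 11/x), plus polyrhythmic pieces that get multi-label targets.
Each song has an expected duration. Only videos lasting between half and twice that length are accepted, which filters out albums, compilations and extended mixes.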

In [None]:
import re
import subprocess
import tempfile

# WIKIMETER: segments cut from curated YouTube recordings with known meters.
WIKIMETER_DIR = Path("/content/data/wikimeter")
WIKIMETER_AUDIO = WIKIMETER_DIR / "audio"
WIKIMETER_TAB = WIKIMETER_DIR / "data_wikimeter.tab"
# Meter numerator -> word written to the "label" column of WIKIMETER_TAB.
METER_LABELS = {
    3: "three",
    4: "four",
    5: "five",
    7: "seven",
    9: "nine",
    11: "eleven",
}
MAX_SEGMENTS = 25                  # per-song cap on the number of segments kept
DURATION_TOLERANCE = 2.0           # accept videos lasting expected/2 .. expected*2 seconds
MAX_VIDEO_DURATION_FALLBACK = 900  # seconds; limit used when no expected duration is given

# (artist, title, meters, search_query_override, expected_duration_s)
#
# meters:                every meter numerator the piece is labeled with (first = primary).
# search_query_override: replaces the default "<artist> <title>" YouTube query when set.
# expected_duration_s:   drives the per-song duration filter in download_and_segment.
#
# Label corrections:
#   relabeled - Brubeck "Three to Get Ready" (alternates 3/4 and 4/4; was 5/x),
#               Chick Corea "Spain" and Vulfpeck "Dean Town" (4/4; were 5/x),
#               Bernstein "America" (alternating 6/8 and 3/4; was 7/x)
#   dropped   - Mike Oldfield "Tubular Bells Part 1" (15/8 = 7+8 opening, many meters later),
#               Bulgarian "Pravo Horo" (duple 2/4), Bartók Mikrokosmos 151 (3+2+3/8) and
#               153 (3+3+2/8), Mahler Symphony 9 Rondo-Burleske (2/2; "9" is the symphony
#               number), Stravinsky "Sacrificial Dance" (constantly changing meter)
SONGS = [
    # === 3/4 ===
    ("Strauss II", "The Blue Danube", [3], "Johann Strauss Blue Danube waltz", 600),
    ("Strauss II", "Tales from the Vienna Woods", [3], "Strauss Tales Vienna Woods waltz", 720),
    ("Strauss II", "Emperor Waltz", [3], "Strauss Emperor Waltz Kaiserwalzer", 660),
    ("Chopin", "Waltz in C-sharp minor Op 64 No 2", [3], "Chopin Waltz Op 64 No 2", 210),
    ("Chopin", "Minute Waltz", [3], "Chopin Minute Waltz Op 64 No 1", 120),
    ("Chopin", "Grande Valse Brillante", [3], "Chopin Grande Valse Brillante Op 18", 300),
    ("Chopin", "Waltz in A minor", [3], "Chopin Waltz A minor B 150", 180),
    ("Tchaikovsky", "Waltz of the Flowers", [3], "Tchaikovsky Waltz of the Flowers Nutcracker", 420),
    ("Tchaikovsky", "Sleeping Beauty Waltz", [3], "Tchaikovsky Sleeping Beauty Waltz", 360),
    ("Tchaikovsky", "Swan Lake Waltz", [3], "Tchaikovsky Swan Lake Waltz", 390),
    ("Shostakovich", "Waltz No 2", [3], "Shostakovich Waltz No 2", 210),
    ("Khachaturian", "Masquerade Waltz", [3], "Khachaturian Masquerade Waltz", 270),
    ("Ravel", "La Valse", [3], "Ravel La Valse orchestral", 780),
    ("Brahms", "Waltz in A-flat major Op 39 No 15", [3], "Brahms Waltz Op 39 No 15", 120),
    ("The Beatles", "Norwegian Wood", [3], None, 125),
    ("Leonard Cohen", "Take This Waltz", [3], None, 420),
    ("Norah Jones", "Come Away With Me", [3], None, 198),
    ("Jeff Buckley", "Lilac Wine", [3], None, 280),
    ("Radiohead", "Codex", [3], None, 280),
    ("Elliott Smith", "Waltz No 2", [3], "Elliott Smith Waltz No 2 XO", 270),
    ("Damien Rice", "The Blower's Daughter", [3], "Damien Rice Blowers Daughter", 260),
    ("Mazzy Star", "Fade Into You", [3], None, 290),
    ("R.E.M.", "Everybody Hurts", [3], None, 318),
    ("Counting Crows", "A Long December", [3], None, 330),
    ("Traditional", "Greensleeves", [3], "Greensleeves traditional", 240),
    ("Traditional", "Scarborough Fair", [3], "Scarborough Fair traditional", 210),
    # NOTE(review): the Londonderry Air is often notated in 4/4 — verify this label.
    ("Traditional", "Danny Boy", [3], "Danny Boy traditional Irish", 240),
    ("Traditional", "Amazing Grace", [3], "Amazing Grace traditional", 300),
    ("Traditional", "Edelweiss", [3], "Edelweiss Sound of Music", 150),
    # Alternates 6/8 and 3/4 (both ternary groupings).
    ("Bernstein", "America from West Side Story", [3], "Bernstein America West Side Story", 300),
    # === 4/4 ===
    ("Queen", "We Will Rock You", [4], None, 122),
    ("Queen", "Another One Bites the Dust", [4], None, 215),
    ("AC/DC", "Back in Black", [4], None, 255),
    ("Deep Purple", "Smoke on the Water", [4], None, 340),
    ("Led Zeppelin", "Whole Lotta Love", [4], None, 333),
    ("The Rolling Stones", "Satisfaction", [4], "Rolling Stones Satisfaction", 224),
    ("Nirvana", "Smells Like Teen Spirit", [4], None, 301),
    ("Metallica", "Enter Sandman", [4], None, 331),
    ("Guns N Roses", "Sweet Child O Mine", [4], "Guns N Roses Sweet Child O Mine", 356),
    ("The White Stripes", "Seven Nation Army", [4], None, 232),
    ("Michael Jackson", "Billie Jean", [4], None, 294),
    ("Michael Jackson", "Beat It", [4], None, 258),
    ("Stevie Wonder", "Superstition", [4], None, 245),
    ("Bee Gees", "Stayin Alive", [4], "Bee Gees Stayin Alive", 285),
    ("ABBA", "Dancing Queen", [4], None, 231),
    ("Daft Punk", "Get Lucky", [4], None, 369),
    ("Daft Punk", "Around the World", [4], None, 427),
    ("Bruno Mars", "Uptown Funk", [4], None, 270),
    ("Pharrell Williams", "Happy", [4], None, 233),
    ("Marvin Gaye", "Aint No Mountain High Enough", [4], "Marvin Gaye Aint No Mountain High Enough", 150),
    ("Kraftwerk", "The Model", [4], None, 220),
    ("Kraftwerk", "Autobahn", [4], "Kraftwerk Autobahn", 660),
    ("New Order", "Blue Monday", [4], None, 442),
    ("Depeche Mode", "Personal Jesus", [4], None, 295),
    ("The Prodigy", "Firestarter", [4], None, 280),
    ("James Brown", "I Got You", [4], "James Brown I Got You I Feel Good", 167),
    ("Parliament", "Give Up the Funk", [4], "Parliament Give Up the Funk", 348),
    ("Grandmaster Flash", "The Message", [4], None, 445),
    ("A Tribe Called Quest", "Can I Kick It", [4], None, 260),
    ("Chick Corea", "Spain", [4], "Chick Corea Spain", 600),
    ("Vulfpeck", "Dean Town", [4], None, 210),
    # === 5/x ===
    ("Dave Brubeck", "Take Five", [5], None, 325),
    ("Paul Desmond", "Take Ten", [5], "Paul Desmond Take Ten", 340),
    ("Radiohead", "15 Step", [5], None, 237),
    ("Radiohead", "Morning Bell", [5], None, 270),
    ("Radiohead", "Everything in Its Right Place", [5], None, 250),
    ("Radiohead", "Sail to the Moon", [5], None, 290),
    # NOTE(review): despite its title this song is sometimes described as 4/4 — verify.
    ("Gorillaz", "5/4", [5], "Gorillaz 5/4", 241),
    ("Jethro Tull", "Living in the Past", [5], None, 205),
    ("Sting", "Seven Days", [5], None, 290),
    ("Nick Drake", "River Man", [5], None, 275),
    ("The Mars Volta", "Inertiatic ESP", [5], None, 230),
    ("Sufjan Stevens", "A Good Man Is Hard to Find", [5], None, 300),
    ("Donovan", "Atlantis", [5], None, 295),
    ("Soundgarden", "My Wave", [5], "Soundgarden My Wave", 312),
    ("Tame Impala", "Apocalypse Dreams", [5], None, 390),
    ("Lalo Schifrin", "Mission Impossible Theme", [5], "Mission Impossible theme original", 180),
    ("Chopin", "Piano Sonata No 1 Larghetto", [5], "Chopin Piano Sonata No 1 Op 4 Larghetto 5/4", 480),
    ("Tchaikovsky", "Symphony No 6 Movement 2", [5], "Tchaikovsky Symphony 6 second movement 5/4", 510),
    ("Dream Theater", "The Mirror", [5], None, 660),
    ("Animals as Leaders", "CAFO", [5], None, 370),
    ("King Crimson", "Discipline", [5], "King Crimson Discipline", 305),
    # NOTE(review): often described as 4/4 — verify.
    ("Björk", "Army of Me", [5], None, 224),
    ("Béla Fleck", "Sinister Minister", [5], "Bela Fleck Sinister Minister", 360),
    ("Jacob Collier", "In My Room", [5], "Jacob Collier In My Room", 300),
    # NOTE(review): Eleno Mome is commonly cited as a 7/8 dance — verify.
    ("Traditional", "Eleno Mome", [5], "Eleno Mome Bulgarian folk", 240),
    ("Traditional", "Paidushko Horo", [5], "Paidushko Horo Bulgarian folk", 240),
    # === 7/x ===
    ("Pink Floyd", "Money", [7], None, 382),
    ("Peter Gabriel", "Solsbury Hill", [7], None, 260),
    ("Soundgarden", "Outshined", [7], "Soundgarden Outshined", 312),
    ("The Beatles", "All You Need Is Love", [7], "Beatles All You Need Is Love", 237),
    ("Radiohead", "2 + 2 = 5", [7], "Radiohead 2+2=5", 202),
    # NOTE(review): largely 4/4 with a 7/8 instrumental section — many segments may not be 7/x.
    ("Rush", "Tom Sawyer", [7], "Rush Tom Sawyer", 276),
    ("Gentle Giant", "The Runaway", [7], "Gentle Giant The Runaway", 300),
    ("Alice in Chains", "Them Bones", [7], "Alice in Chains Them Bones", 147),
    ("King Crimson", "Frame by Frame", [7], "King Crimson Frame by Frame", 310),
    ("Opeth", "The Drapery Falls", [7], None, 630),
    ("Gentle Giant", "Knots", [7], "Gentle Giant Knots", 300),
    ("Robert Fripp", "Exposure", [7], "Robert Fripp Exposure", 260),
    ("Dave Holland", "Conference of the Birds", [7], "Dave Holland Conference of the Birds", 480),
    ("John McLaughlin", "Meeting of the Spirits", [7], "Mahavishnu Orchestra Meeting of the Spirits", 420),
    ("Traditional", "Rachenitsa", [7], "Rachenitsa Bulgarian folk 7/8", 240),
    ("Traditional", "Makedonsko Devojche", [7], "Makedonsko Devojche folk 7/8", 240),
    ("Traditional", "Chetvorno Horo", [7], "Chetvorno Horo Bulgarian 7/8", 240),
    ("Traditional", "Lesnoto", [7], "Lesnoto Macedonian 7/8", 240),
    ("Traditional", "Ivailo", [7], "Ivailo Bulgarian folk 7/8", 240),
    ("Goran Bregović", "Mesečina", [7], "Goran Bregovic Mesecina", 240),
    ("Fanfare Ciocărlia", "Born to Be Wild", [7], "Fanfare Ciocarlia Born to Be Wild", 270),
    ("Hans Zimmer", "Mombasa", [7], "Hans Zimmer Mombasa Inception", 295),
    ("Bear McCreary", "BSG Main Theme", [7], "Bear McCreary Battlestar Galactica theme", 240),
    ("Aimee Mann", "Momentum", [7], None, 260),
    ("Joni Mitchell", "The Silky Veils of Ardor", [7], None, 240),
    ("Broken Social Scene", "7/4 Shoreline", [7], "Broken Social Scene 7/4 Shoreline", 260),
    ("Iron Maiden", "The Loneliness of the Long Distance Runner", [7], None, 390),
    ("Led Zeppelin", "The Ocean", [7], "Led Zeppelin The Ocean", 266),
    ("Jeff Beck", "Led Boots", [7], None, 340),
    ("Porcupine Tree", "The Sound of Muzak", [7], None, 290),
    # === 9/8 ===
    ("Dave Brubeck", "Blue Rondo à la Turk", [9], "Dave Brubeck Blue Rondo a la Turk", 402),
    ("Traditional", "Daichovo Horo", [9], "Daichovo Horo Bulgarian 9/8", 240),
    ("Traditional", "Zeimbekiko", [9], "Zeimbekiko Greek dance 9/8", 300),
    ("Traditional", "Karsilama", [9], "Karsilama Turkish 9/8", 240),
    ("Traditional", "Arap", [9], "Arap Turkish 9/8 dance", 240),
    ("Bartók", "Six Dances in Bulgarian Rhythm No 5", [9], "Bartok Mikrokosmos 152 Bulgarian Rhythm 5", 90),
    ("Toto", "Mushanga", [9], "Toto Mushanga", 360),
    # NOTE(review): predominantly 4/4 — verify the 9/x label.
    ("Muse", "Butterflies and Hurricanes", [9], "Muse Butterflies and Hurricanes", 330),
    ("Bartók", "Six Dances in Bulgarian Rhythm No 1", [9], "Bartok Bulgarian Rhythm No 1 Mikrokosmos", 90),
    # === 11/8 ===
    ("Traditional", "Gankino Horo", [11], "Gankino Horo Bulgarian 11/8", 240),
    ("Traditional", "Kopanitsa", [11], "Kopanitsa Bulgarian folk 11/8", 240),
    ("Traditional", "Ispayche", [11], "Ispayche Bulgarian 11/8", 240),
    ("Primus", "Eleven", [11], "Primus Eleven", 330),
    ("Grateful Dead", "The Eleven", [11], "Grateful Dead The Eleven", 480),
    ("Aksak Maboul", "Saure Gurke", [11], "Aksak Maboul Saure Gurke", 300),
    ("Frank Zappa", "Outside Now", [11], "Frank Zappa Outside Now", 340),
    # === Polyrhythmic (multi-label: e.g. [3,4] = both 3/x and 4/x active) ===
    ("Traditional", "Kpanlogo", [3, 4], "Kpanlogo Ghanaian drumming", 300),
    ("Traditional", "Agbekor", [3, 4], "Agbekor Ewe drumming Ghana", 300),
    ("Traditional", "Gahu", [3, 4], "Gahu drumming Ghana", 300),
    ("Traditional", "Rumba Guaguancó", [3, 4], "Rumba Guaguanco Cuban", 300),
    ("Traditional", "Bembe", [3, 4], "Bembe Cuban drumming 6/8 over 4/4", 300),
    ("Traditional", "Afoxé", [3, 4], "Afoxe Brazilian rhythm", 300),
    ("Fela Kuti", "Zombie", [3, 4], "Fela Kuti Zombie afrobeat", 745),
    ("Fela Kuti", "Water No Get Enemy", [3, 4], "Fela Kuti Water No Get Enemy", 600),
    ("Tony Allen", "Asiko", [3, 4], "Tony Allen Asiko afrobeat", 360),
    ("Babatunde Olatunji", "Jin-Go-Lo-Ba", [3, 4], "Babatunde Olatunji Jingo", 300),
    ("Talking Heads", "I Zimbra", [3, 4], "Talking Heads I Zimbra", 195),
    ("Vampire Weekend", "Cape Cod Kwassa Kwassa", [3, 4], None, 230),
    ("Chopin", "Fantaisie-Impromptu", [3, 4], "Chopin Fantaisie Impromptu", 300),
    # NOTE(review): notated in 9/8 — verify the [3, 4] label.
    ("Debussy", "Clair de Lune", [3, 4], "Debussy Clair de Lune", 330),
    ("Brahms", "Piano Concerto No 1 Rondo", [3, 4], "Brahms Piano Concerto 1 Rondo", 600),
    ("Meshuggah", "Bleed", [5, 4], "Meshuggah Bleed", 445),
    ("Meshuggah", "Rational Gaze", [5, 4], None, 330),
    ("Meshuggah", "New Millennium Cyanide Christ", [7, 4], None, 360),
    ("Traditional", "Gamelan Gong Kebyar", [3, 4], "Gamelan Gong Kebyar Bali", 420),
    ("Traditional", "Gamelan Jegog", [3, 4], "Gamelan Jegog bamboo Bali", 360),
    ("Traditional", "Djembe Dununba", [3, 4], "Dununba djembe rhythm West Africa", 300),
    ("Traditional", "Sinte", [3, 4], "Sinte djembe rhythm Mande", 300),
    ("Traditional", "Kuku", [3, 4], "Kuku djembe rhythm Guinea", 300),
    # === Alternating meters (multi-label) ===
    # Two bars of 3/4 alternate with two bars of 4/4.
    ("Brubeck", "Three to Get Ready", [3, 4], "Dave Brubeck Three to Get Ready", 340),
]


def sanitize_filename(artist, title):
    """Build a lowercase, filesystem-safe ``<artist>_<title>`` slug of at most 80 characters.

    Characters other than word characters, whitespace and ``-`` are dropped.
    Whitespace runs become ``_``, repeated underscores collapse, and
    leading/trailing underscores are stripped before truncation.
    """
    slug = f"{artist}_{title}".lower()
    for pattern, replacement in ((r"[^\w\s-]", ""), (r"\s+", "_"), (r"_+", "_")):
        slug = re.sub(pattern, replacement, slug)
    return slug.strip("_")[:80]


def get_duration(path):
    """Return the duration of media file ``path`` in seconds, as reported by ``ffprobe``.

    Best-effort: any failure (ffprobe missing, 10 s timeout, unreadable file,
    or JSON without ``format.duration``) yields None.
    """
    try:
        probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", str(path)]
        proc = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
        info = json.loads(proc.stdout)
        return float(info["format"]["duration"])
    except Exception:
        return None


def _yt_duration_filter(expected_duration):
    """Return the yt-dlp ``--match-filter`` expression bounding a video's length in seconds."""
    if expected_duration:
        min_dur = int(expected_duration / DURATION_TOLERANCE)
        max_dur = int(expected_duration * DURATION_TOLERANCE)
        return f"duration > {min_dur} & duration < {max_dur}"
    return f"duration < {MAX_VIDEO_DURATION_FALLBACK}"


def _usable_window(duration):
    """Return ``(start_s, length_s)`` to segment, skipping a 10 s intro/outro on tracks longer than 40 s."""
    margin = 10.0 if duration > 40 else 0.0
    usable_duration = duration - 2 * margin
    if usable_duration < 15:
        return 0, duration
    return margin, usable_duration


def download_and_segment(query, stem, audio_dir, segment_length=30, expected_duration=None,
                         search_results=5):
    """Download the best YouTube match for ``query`` and cut it into mono MP3 segments.

    yt-dlp searches the top ``search_results`` hits and downloads the first
    one whose duration passes the filter: within a DURATION_TOLERANCE factor
    of ``expected_duration``, or under MAX_VIDEO_DURATION_FALLBACK seconds
    when no expectation is given. Searching only a single hit made a song fail
    whenever that top hit had the wrong length, even if the next hit was the
    right recording.

    The usable part of the track is cut into consecutive ``segment_length``
    second pieces (a shorter trailing piece is kept only if it is >= 15 s).
    At most MAX_SEGMENTS pieces are written as ``audio_dir/<stem>_segNN.mp3``
    (libmp3lame, 192 kb/s, 44.1 kHz, mono). A segment that fails to encode is
    deleted rather than left on disk.

    Args:
        query: YouTube search string.
        stem: File-name prefix for the segments.
        audio_dir: Output directory (created if needed).
        segment_length: Segment length in seconds.
        expected_duration: Expected track length in seconds, or None.
        search_results: Number of search hits yt-dlp may consider (min 1).

    Returns:
        Stems (file names without ``.mp3``) of the successfully written
        segments; empty if the download or probing failed.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir) / "full.%(ext)s"
        cmd = ["yt-dlp", "--default-search", f"ytsearch{max(1, int(search_results))}", query,
               "-x", "--audio-format", "mp3", "--audio-quality", "5",
               "-o", str(tmp_path), "--no-playlist", "--max-downloads", "1",
               "--match-filter", _yt_duration_filter(expected_duration),
               "--quiet", "--no-warnings"]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
            # yt-dlp exits with 101 when --max-downloads stops it after a download.
            if result.returncode not in (0, 101):
                return []
        except (subprocess.TimeoutExpired, FileNotFoundError):
            return []

        downloaded = list(Path(tmpdir).glob("full.*"))
        if not downloaded:
            return []
        src = downloaded[0]

        duration = get_duration(src)
        if not duration or duration < 15:
            return []

        usable_start, usable_duration = _usable_window(duration)
        audio_dir.mkdir(parents=True, exist_ok=True)
        segments = []
        seg_idx = 0
        offset = 0.0
        while offset + 15 <= usable_duration and seg_idx < MAX_SEGMENTS:
            seg_dur = min(segment_length, usable_duration - offset)
            seg_stem = f"{stem}_seg{seg_idx:02d}"
            dest = audio_dir / f"{seg_stem}.mp3"
            ffmpeg_cmd = ["ffmpeg", "-y",
                          # -ss before -i seeks in the input instead of decoding
                          # everything up to the offset for every segment.
                          "-ss", f"{usable_start + offset:.1f}",
                          "-i", str(src),
                          "-t", f"{seg_dur:.1f}",
                          "-acodec", "libmp3lame", "-ab", "192k",
                          "-ar", "44100", "-ac", "1", str(dest)]
            try:
                r = subprocess.run(ffmpeg_cmd, capture_output=True, text=True, timeout=60)
                ok = r.returncode == 0 and dest.exists() and dest.stat().st_size > 1000
            except subprocess.TimeoutExpired:
                ok = False
            if ok:
                segments.append(seg_stem)
            else:
                # Don't leave partial/tiny files behind: the download loop treats any
                # existing <stem>_seg*.mp3 as an already-downloaded segment.
                dest.unlink(missing_ok=True)
            offset += segment_length
            seg_idx += 1

        return segments


# Fetch every song in SONGS, single-meter and polyrhythmic (multi-label) alike.
# Songs that already have segments on disk are reused instead of re-downloaded.
print(f"Downloading {len(SONGS)} songs from YouTube...")
print(f"Max {MAX_SEGMENTS} segments per song, duration-filtered per song\n")

n_songs = len(SONGS)
successful = []  # (segment stem, meters) for every segment listed in WIKIMETER_TAB
for song_no, (artist, title, meters, query_override, exp_dur) in enumerate(SONGS, start=1):
    stem = sanitize_filename(artist, title)
    meter_str = "+".join(f"{m}/x" for m in meters)
    print(f"[{song_no:3d}/{n_songs}] {artist} — {title} ({meter_str})", end="", flush=True)

    cached = sorted(WIKIMETER_AUDIO.glob(f"{stem}_seg*.mp3"))
    if cached:
        kept = cached[:MAX_SEGMENTS]
        successful.extend((seg_path.stem, meters) for seg_path in kept)
        print(f" — skipped ({len(kept)} segs)")
        continue

    query = query_override or f"{artist} {title}"
    segs = download_and_segment(query, stem, WIKIMETER_AUDIO, expected_duration=exp_dur)
    if not segs:
        print(" — FAILED")
        continue
    print(f" — OK ({len(segs)} segs)")
    successful.extend((seg_stem, meters) for seg_stem in segs)

# One row per segment. "label" holds the word for the primary (first) meter,
# "meter" lists all meters comma-separated, and "alt_meter" is twice the primary.
WIKIMETER_DIR.mkdir(parents=True, exist_ok=True)
tab_lines = ["filename\tlabel\tmeter\talt_meter\n"]
for seg_stem, seg_meters in successful:
    primary = seg_meters[0]
    word = METER_LABELS.get(primary, str(primary))
    joined = ",".join(map(str, seg_meters))
    tab_lines.append(f'"/{seg_stem}.mp3"\t"{word}"\t{joined}\t{primary * 2}\n')
with open(WIKIMETER_TAB, "w") as tab:
    tab.writelines(tab_lines)

print(f"\nTotal: {len(successful)} segments saved to {WIKIMETER_TAB}")

## 3. Training code

In [None]:
import csv
import warnings
from collections import Counter

import librosa
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, confusion_matrix, f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# ── Constants ──
CLASS_METERS = [3, 4, 5, 7, 9, 11]  # output classes, in logit order
METER_TO_IDX = {meter: idx for idx, meter in enumerate(CLASS_METERS)}
IDX_TO_METER = dict(enumerate(CLASS_METERS))
MERT_SR = 24000               # sample rate audio is loaded at for MERT
CHUNK_SAMPLES = MERT_SR * 5   # 5-second chunks per MERT forward pass
MAX_DURATION_S = 30           # clips are cropped to at most this many seconds
LABEL_SMOOTH_NEG = 0.1        # target value for classes that are not active

# Hugging Face model id -> (num_layers, hidden_dim) pair used for this model.
MODEL_CONFIGS = {
    "m-a-p/MERT-v1-95M": (12, 768),
    "m-a-p/MERT-v1-330M": (24, 1024),
}


# ── Data loading ──
def resolve_audio_path(raw_fname, data_dir):
    """Map a filename from a label table to an existing audio file in ``data_dir/audio``.

    ``extract_tar`` flattens archives into one directory. It writes a clip as
    ``<stem><ext>``, or as ``<parent>_<stem><ext>`` when that plain name is
    already taken by a clip from another folder. So when the label path has a
    parent directory, the parent-qualified names are tried first. Trying the
    plain name first would pair e.g. ``MAG/123.wav``'s label with
    ``FMA/123.wav``'s audio. The plain name is the fallback (it is the right
    file whenever no collision occurred).

    Args:
        raw_fname: Path as written in the label file; surrounding quotes and
            whitespace are ignored (e.g. ``"/MAG/123.wav"``).
        data_dir: Dataset root containing the ``audio`` directory.

    Returns:
        Path of the first matching regular file, or None if nothing matches.
    """
    raw_fname = raw_fname.strip('"').strip("'").strip()
    if not raw_fname:
        return None
    p = Path(raw_fname)
    src_dir = p.parent.name
    stem = p.stem
    audio_dir = data_dir / "audio"
    # dict.fromkeys de-duplicates while keeping the preference order.
    extensions = list(dict.fromkeys((".mp3", ".wav", ".ogg", ".flac", p.suffix)))
    qualified = ([f"{src_dir}_{stem}{ext}" for ext in extensions]
                 + [f"{src_dir}_._{stem}{ext}" for ext in extensions])
    plain = [f"{stem}{ext}" for ext in extensions]
    ordered = qualified + plain if src_dir else plain + qualified
    for name in ordered:
        candidate = audio_dir / name
        if candidate.is_file():
            return candidate
    return None


def _parse_meter(raw_label):
    """Return the numerator of a meter label ("7", "7/8", "7.0") as int, or None if it is not one."""
    head = raw_label.split("/")[0].strip()
    try:
        return int(head)
    except ValueError:
        pass
    try:
        value = float(head)
    except ValueError:
        return None
    return int(value) if value.is_integer() else None


def parse_label_file(label_path, data_dir, valid_meters=None):
    """Parse a METER2800-style label table into ``(audio_path, meter)`` pairs.

    The delimiter is a tab if the header line contains one, otherwise a comma.
    The filename and label columns are the first header names (case-insensitive,
    quotes stripped) found in ``filename_cols`` / ``label_cols`` respectively.
    Labels may be integers ("7"), time signatures ("7/8") or integral floats
    ("7.0"). Other rows are skipped, as are short rows missing either column.

    Args:
        label_path: Path of the ``.tab``/``.csv``/``.tsv`` file.
        data_dir: Dataset root passed to ``resolve_audio_path``.
        valid_meters: Optional collection of allowed meters; others are skipped.

    Returns:
        List of ``(Path, int)``. Rows whose audio cannot be resolved are dropped.
        The list is empty if the header lacks a usable filename or label column.
    """
    filename_cols = ["filename", "file", "audio_file", "audio_path", "audio", "path"]
    label_cols = ["meter", "time_signature", "ts", "time_sig", "signature", "label"]
    entries = []
    with open(label_path, "r") as f:
        first_line = f.readline()
    delimiter = "\t" if "\t" in first_line else ","
    with open(label_path, newline="") as f:
        reader = csv.DictReader(f, delimiter=delimiter)
        if not reader.fieldnames:
            return entries
        header_map = {h.strip().lower().strip('"'): h for h in reader.fieldnames}
        fname_key = next((header_map[c] for c in filename_cols if c in header_map), None)
        label_key = next((header_map[c] for c in label_cols if c in header_map), None)
        if not fname_key or not label_key:
            return entries
        for row in reader:
            # DictReader fills missing trailing fields with None, so short rows
            # must not be .strip()-ed directly.
            raw_fname = (row.get(fname_key) or "").strip().strip('"')
            raw_label = (row.get(label_key) or "").strip().strip('"')
            if not raw_fname or not raw_label:
                continue
            meter = _parse_meter(raw_label)
            if meter is None:
                continue
            if valid_meters and meter not in valid_meters:
                continue
            audio_path = resolve_audio_path(raw_fname, data_dir)
            if audio_path:
                entries.append((audio_path, meter))
    return entries


def load_split(data_dir, split):
    """Load METER2800 split ``split`` from ``data_<split>_4_classes.{tab,csv,tsv}``.

    The first existing file in that extension order is parsed. Only meters in
    CLASS_METERS are kept, and each meter is wrapped in a one-element list to
    match the multi-label entry format.

    Returns:
        List of ``(Path, [meter])``, or None when no label file exists for the split.
    """
    valid = set(METER_TO_IDX.keys())
    candidates = (data_dir / f"data_{split}_4_classes{ext}" for ext in (".tab", ".csv", ".tsv"))
    label_path = next((c for c in candidates if c.exists()), None)
    if label_path is None:
        return None
    entries = [(audio_path, [meter]) for audio_path, meter in parse_label_file(label_path, data_dir, valid)]
    print(f"  {split}: {len(entries)} entries from {label_path.name}")
    return entries


def load_extra_data(extra_dir):
    """Load multi-label entries from every ``*.tab`` file in ``extra_dir`` (e.g. WIKIMETER).

    Each tab-delimited file must have ``filename`` and ``meter`` columns.
    ``meter`` is a comma-separated list of numerators (``"3,4"``); items in
    ``"7/8"`` form are also accepted. Meters outside CLASS_METERS are dropped.
    Rows that are short, have unparseable meters, end up with no meter, or
    whose audio cannot be resolved are skipped.

    Returns:
        List of ``(Path, [meter, ...])`` keeping the listed meter order
        (first = primary).
    """
    extra_dir = Path(extra_dir)
    valid = set(METER_TO_IDX.keys())
    entries = []
    for tab_file in sorted(extra_dir.glob("*.tab")):
        with open(tab_file, newline="") as fh:
            for row in csv.DictReader(fh, delimiter="\t"):
                # Short rows give None for missing columns; treat them as empty.
                raw_fname = (row.get("filename") or "").strip().strip('"')
                raw_meter = (row.get("meter") or "").strip().strip('"')
                if not raw_fname or not raw_meter:
                    continue
                try:
                    meters = [int(item.split("/")[0]) for item in raw_meter.split(",")]
                except ValueError:
                    continue
                meters = [m for m in meters if m in valid]
                if not meters:
                    continue
                audio_path = resolve_audio_path(raw_fname, extra_dir)
                if audio_path:
                    entries.append((audio_path, meters))
    return entries


# ── Dataset ──
class MERTAudioDataset(Dataset):
    """Map-style dataset of ``(waveform, target)`` pairs for MERT fine-tuning.

    Args:
        entries: List of ``(Path, [meter, ...])``. Entries whose file does not
            exist, or none of whose meters is in CLASS_METERS, are dropped.
        augment: If True, use a random crop, additive Gaussian noise
            (std 0.005) and a random circular shift of up to ±0.5 s.
            Otherwise use a centre crop and no augmentation.

    Each item is ``(audio, label)``:
        audio: Mono float32 waveform at MERT_SR, at most MAX_DURATION_S seconds
            long and zero-padded to at least one second.
        label: float32 vector of length len(CLASS_METERS): 1.0 for every listed
            meter and LABEL_SMOOTH_NEG for all other classes.
    """

    def __init__(self, entries, augment=False):
        self.entries = [(p, m) for p, m in entries if p.exists() and any(x in METER_TO_IDX for x in m)]
        self.augment = augment

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        path, meters = self.entries[idx]
        label = np.full(len(CLASS_METERS), LABEL_SMOOTH_NEG, dtype=np.float32)
        for m in meters:
            if m in METER_TO_IDX:
                label[METER_TO_IDX[m]] = 1.0
        try:
            audio, _ = librosa.load(str(path), sr=MERT_SR, mono=True)
        except Exception as exc:
            # Keep training running, but make the failure visible: a silent clip
            # that still carries a real meter label is pure label noise.
            warnings.warn(f"Could not load {path} ({exc!r}); substituting 1 s of silence")
            audio = np.zeros(MERT_SR, dtype=np.float32)
        max_samples = MAX_DURATION_S * MERT_SR
        if len(audio) > max_samples:
            if self.augment:
                # randint's upper bound is exclusive; +1 makes the final window reachable.
                start = np.random.randint(0, len(audio) - max_samples + 1)
            else:
                start = (len(audio) - max_samples) // 2
            audio = audio[start:start + max_samples]
        if len(audio) < MERT_SR:
            audio = np.pad(audio, (0, MERT_SR - len(audio)))
        if self.augment:
            audio = audio + 0.005 * np.random.randn(len(audio)).astype(np.float32)
            # NOTE(review): np.roll wraps the tail around to the front, splicing a
            # discontinuity into the rhythm; a non-circular shift may be preferable.
            audio = np.roll(audio, np.random.randint(-MERT_SR // 2, MERT_SR // 2))
        return audio.astype(np.float32), label


def simple_collate(batch):
    """Collate ``(audio, label)`` items for a DataLoader.

    Waveforms have different lengths, so they are returned as a plain list.
    Labels are stacked into one float32 tensor of shape (batch, num_classes).
    """
    audios = [item[0] for item in batch]
    label_matrix = np.stack([item[1] for item in batch])
    return audios, torch.tensor(label_matrix, dtype=torch.float32)


# ── Model ──
class MERTClassificationHead(nn.Module):
    """Softmax-weighted fusion of per-layer pooled features followed by a small MLP.

    ``forward`` expects ``stacked`` of shape (batch, num_layers, pooled_dim). It
    returns raw logits of shape (batch, num_classes).
    """

    def __init__(self, num_layers, pooled_dim, num_classes=6, head_dim=256, dropout=0.3):
        super().__init__()
        self.num_layers = num_layers
        # One learnable logit per layer; all zeros means uniform layer weights at init.
        self.layer_logits = nn.Parameter(torch.zeros(num_layers))
        mlp = [
            nn.LayerNorm(pooled_dim),
            nn.Linear(pooled_dim, head_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, num_classes),
        ]
        self.head = nn.Sequential(*mlp)

    def forward(self, stacked):
        layer_weights = self.layer_logits.softmax(dim=0)
        weighted = stacked * layer_weights[None, :, None]
        return self.head(weighted.sum(dim=1))


def mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim):
    """Encode each clip with MERT and pool every transformer layer to a fixed-size vector.

    Steps for each clip:
    - Split the clip into CHUNK_SAMPLES-long chunks. Chunks shorter than one
      second (MERT_SR samples) are dropped, unless that would leave none, in
      which case the whole clip is used.
    - For every chunk, mean- and max-pool the hidden states of layers
      1..num_layers over time.
    - Average the per-chunk means and take the max of the per-chunk maxes.
    - Concatenate the two statistics.

    Args:
        audios: Sequence of 1-D numpy waveforms sampled at MERT_SR.
        mert_model: MERT encoder (possibly LoRA-wrapped).
        processor: Feature extractor matching ``mert_model``.
        device: Device the model inputs are moved to.
        num_layers: Number of transformer layers to pool.
        hidden_dim: Unused; kept for call-site compatibility.

    Returns:
        Tensor of shape (batch, num_layers, 2 * D), where D is the size of the
        last axis of the model's hidden states.

    Raises:
        ValueError: The model returned fewer than ``num_layers`` transformer layers.
    """
    batch_pooled = []
    for audio_np in audios:
        chunks = [audio_np[s:s + CHUNK_SAMPLES] for s in range(0, len(audio_np), CHUNK_SAMPLES)]
        chunks = [c for c in chunks if len(c) >= MERT_SR] or [audio_np]
        chunk_means, chunk_maxes = [], []
        for chunk in chunks:
            inputs = processor(chunk, sampling_rate=MERT_SR, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # Request hidden states explicitly rather than relying on the model config.
            outputs = mert_model(**inputs, output_hidden_states=True)
            # Index 0 is skipped (presumably the pre-transformer embedding output).
            layer_states = outputs.hidden_states[1:num_layers + 1]
            if len(layer_states) < num_layers:
                raise ValueError(f"model returned {len(layer_states)} layers, expected {num_layers}")
            frames = [h.squeeze(0) for h in layer_states]  # each (time, D) for a single chunk
            chunk_means.append(torch.stack([f.mean(dim=0) for f in frames]))        # (L, D)
            chunk_maxes.append(torch.stack([f.max(dim=0).values for f in frames]))  # (L, D)
        mean_agg = torch.stack(chunk_means).mean(dim=0)
        max_agg = torch.stack(chunk_maxes).max(dim=0).values
        batch_pooled.append(torch.cat([mean_agg, max_agg], dim=-1))  # (L, 2D)
    return torch.stack(batch_pooled)


# ── Training ──
def train_one_epoch(mert_model, head, processor, loader, criterion, optimizer,
                    device, num_layers, hidden_dim, grad_accum=1, use_lora=True):
    """Run one training epoch with gradient accumulation and gradient clipping.

    With ``use_lora`` the MERT forward pass is differentiable (LoRA weights
    train). Otherwise it runs under ``torch.no_grad()`` and only ``head``
    receives gradients. The optimizer steps every ``grad_accum`` batches and
    on the final batch. Gradients of the head and of every MERT parameter with
    ``requires_grad`` are clipped to norm 1.0.

    Returns:
        ``(mean loss per sample, accuracy)``. Accuracy compares argmax(logits)
        with argmax(target). For multi-label targets this is the lowest-index
        active class.
    """
    head.train()
    if use_lora:
        mert_model.train()
    # The trainable parameter set is fixed for the whole epoch: build it once
    # instead of on every optimizer step.
    clip_params = list(head.parameters()) + [p for p in mert_model.parameters() if p.requires_grad]
    n_steps = len(loader)
    total_loss = correct = total = 0
    optimizer.zero_grad()
    pbar = tqdm(loader, desc="Train", leave=False)
    for step, (audios, labels) in enumerate(pbar):
        labels = labels.to(device)
        if use_lora:
            pooled = mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim)
        else:
            # Frozen encoder: tensors created under no_grad carry no graph.
            with torch.no_grad():
                pooled = mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim)
        logits = head(pooled)
        loss = criterion(logits, labels) / grad_accum
        loss.backward()
        if (step + 1) % grad_accum == 0 or step == n_steps - 1:
            torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
            optimizer.step()
            optimizer.zero_grad()
        batch_size = len(audios)
        total_loss += loss.item() * grad_accum * batch_size  # undo the accumulation scaling
        correct += (logits.argmax(dim=1) == labels.argmax(dim=1)).sum().item()
        total += batch_size
        pbar.set_postfix_str(f"loss={total_loss/total:.3f} acc={correct/total:.0%}")
    return total_loss / max(total, 1), correct / max(total, 1)


@torch.no_grad()
def evaluate(mert_model, head, processor, loader, criterion, device, num_layers, hidden_dim):
    """Evaluate on ``loader`` without gradients, with both models in eval mode.

    Returns:
        A tuple ``(mean loss, accuracy, true_idx, pred_idx, probs, targets)``:
        - ``true_idx``: argmax of each target row (for multi-label rows, the
          lowest-index active class).
        - ``pred_idx``: argmax of the logits.
        - ``probs``: sigmoid probabilities, shape (N, C).
        - ``targets``: raw target rows, shape (N, C).
        With an empty loader both arrays are empty with
        C = len(CLASS_METERS), instead of ``np.concatenate`` raising on an
        empty list.
    """
    head.eval()
    mert_model.eval()
    total_loss = correct = total = 0
    all_labels_idx, all_preds_idx = [], []
    all_probs, all_labels_raw = [], []
    for audios, labels in tqdm(loader, desc="Eval", leave=False):
        labels = labels.to(device)
        pooled = mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim)
        logits = head(pooled)
        loss = criterion(logits, labels)
        total_loss += loss.item() * len(audios)
        probs = torch.sigmoid(logits)
        preds = logits.argmax(dim=1)
        true_primary = labels.argmax(dim=1)
        correct += (preds == true_primary).sum().item()
        total += len(audios)
        all_labels_idx.extend(true_primary.cpu().tolist())
        all_preds_idx.extend(preds.cpu().tolist())
        all_probs.append(probs.cpu().numpy())
        all_labels_raw.append(labels.cpu().numpy())
    if all_probs:
        probs_arr = np.concatenate(all_probs)
        labels_arr = np.concatenate(all_labels_raw)
    else:
        probs_arr = np.zeros((0, len(CLASS_METERS)), dtype=np.float32)
        labels_arr = np.zeros((0, len(CLASS_METERS)), dtype=np.float32)
    return (total_loss / max(total, 1), correct / max(total, 1),
            all_labels_idx, all_preds_idx, probs_arr, labels_arr)


def print_eval_metrics(labels_idx, preds_idx, probs, labels_multihot):
    meter_names = [f"{m}/x" for m in CLASS_METERS]
    num_classes = len(CLASS_METERS)
    true_primary = np.array(labels_idx)

    # Confusion Matrix
    cm = confusion_matrix(labels_idx, preds_idx, labels=list(range(num_classes)))
    print("\nConfusion Matrix:")
    print("          " + "  ".join(f"{n:>6s}" for n in meter_names))
    for i, row in enumerate(cm):
        row_str = "  ".join(f"{v:6d}" for v in row)
        acc = row[i] / row.sum() * 100 if row.sum() > 0 else 0
        print(f"  {meter_names[i]:>6s} | {row_str}   ({acc:5.1f}%)")

    # mAP
    labels_binary = (labels_multihot > 0.5).astype(np.float32)
    print("\nMulti-label metrics:")
    aps = []
    for i, m in enumerate(CLASS_METERS):
        if labels_binary[:, i].sum() > 0:
            ap = average_precision_score(labels_binary[:, i], probs[:, i])
            aps.append(ap)
            print(f"  AP({m}/x): {ap:.3f}")
    if aps:
        print(f"  mAP: {np.mean(aps):.3f}")

    # Macro-F1
    preds_binary = (probs > 0.5).astype(np.int32)
    cols = [i for i in range(num_classes) if labels_binary[:, i].sum() > 0]
    if cols:
        print(f"  Macro-F1: {f1_score(labels_binary[:, cols], preds_binary[:, cols], average='macro', zero_division=0):.3f}")

    # Confidence Gap
    sorted_p = np.sort(probs, axis=1)[:, ::-1]
    p_top1, p_top2 = sorted_p[:, 0], sorted_p[:, 1]
    gaps = p_top1 - p_top2
    print(f"\nConfidence Gap: mean={gaps.mean():.3f}, median={np.median(gaps):.3f}")

    # Entropy
    eps = 1e-7
    pc = np.clip(probs, eps, 1 - eps)
    H = -(pc * np.log2(pc) + (1 - pc) * np.log2(1 - pc)).sum(axis=1)
    H_norm = H / (num_classes * np.log2(2))
    print(f"H_norm: mean={H_norm.mean():.3f}, median={np.median(H_norm):.3f}")

    # Correlation
    corr = np.corrcoef(probs.T)
    pairs = []
    for i in range(num_classes):
        for j in range(i + 1, num_classes):
            pairs.append((corr[i, j], f"{CLASS_METERS[i]}/x↔{CLASS_METERS[j]}/x"))
    pairs.sort(key=lambda x: -abs(x[0]))
    print("\nTop correlations:")
    for r, name in pairs[:5]:
        flag = " ⚠" if r > 0.3 else ""
        print(f"  {name}: r={r:+.3f}{flag}")

    # Noise floor
    print(f"\nNoise Floor (P_top2): P95={np.percentile(p_top2, 95):.4f}, P99={np.percentile(p_top2, 99):.4f}")

print("Training code loaded ✓")

## 4. Configure & run training

In [None]:
# ══════════════════════════════════════════════
#  CONFIGURATION — edit these!
# ══════════════════════════════════════════════

MODEL_NAME = "m-a-p/MERT-v1-330M"  # or "m-a-p/MERT-v1-95M" for free T4
EPOCHS = 30
BATCH_SIZE = 4
GRAD_ACCUM = 8          # optimizer steps every GRAD_ACCUM batches → effective batch BATCH_SIZE * GRAD_ACCUM
LORA_RANK = 16
LORA_ALPHA = 32
USE_LORA = True         # False → MERT is frozen and only the classification head trains
USE_EXTRA_DATA = True   # append WIKIMETER songs to the training split (if downloaded)
CHECKPOINT_PATH = Path("/content/meter_mert_lora.pt")
SEED = 42               # seeds Python, NumPy and torch RNGs (GPU kernels may still be nondeterministic)

# Auto-scale LR per model size
LR_CONFIGS = {
    "m-a-p/MERT-v1-95M":  {"head_lr": 5e-4, "lora_lr": 1e-4},
    "m-a-p/MERT-v1-330M": {"head_lr": 1e-4, "lora_lr": 2e-5},
}
if MODEL_NAME not in LR_CONFIGS:
    raise ValueError(f"No learning-rate preset for {MODEL_NAME!r}; expected one of {sorted(LR_CONFIGS)}")
lr_cfg = LR_CONFIGS[MODEL_NAME]
HEAD_LR = lr_cfg["head_lr"]
LORA_LR = lr_cfg["lora_lr"]

# Seed every RNG before the LoRA adapters, head, augmentation and DataLoader shuffling
# below draw random numbers, so re-runs start from the same initialisation and order.
import random
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# ══════════════════════════════════════════════

device = torch.device("cuda")
num_layers, hidden_dim = MODEL_CONFIGS[MODEL_NAME]
pooled_dim = 2 * hidden_dim  # feature width handed to the classification head

# ── Data ─────────────────────────────────────
print("Loading data...")
train_entries, val_entries, test_entries = (
    load_split(DATA_DIR, split_name) for split_name in ("train", "val", "test")
)

# Optional WIKIMETER songs are added to the training split only.
if USE_EXTRA_DATA and WIKIMETER_TAB.exists():
    extra = load_extra_data(WIKIMETER_DIR)
    print(f"  Extra data: +{len(extra)} entries")
    train_entries.extend(extra)

# Meter label counts across all three splits (an entry counts once per meter it lists).
counts = Counter(
    meter
    for _, entry_meters in (train_entries + val_entries + test_entries)
    for meter in entry_meters
)
print(f"\nTotal: {len(train_entries)} train, {len(val_entries)} val, {len(test_entries)} test")
for class_meter in CLASS_METERS:
    print(f"  Meter {class_meter}: {counts.get(class_meter, 0)}")

# ── Backbone ─────────────────────────────────
print(f"\nLoading {MODEL_NAME}...")
from transformers import AutoModel, Wav2Vec2FeatureExtractor

hub_kwargs = {"trust_remote_code": True}
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME, **hub_kwargs)
mert_model = AutoModel.from_pretrained(MODEL_NAME, **hub_kwargs, output_hidden_states=True).to(device)
n_backbone_params = sum(p.numel() for p in mert_model.parameters())
print(f"  Parameters: {n_backbone_params/1e6:.0f}M")

# ── Adapters (LoRA on q_proj / v_proj) or a fully frozen backbone ──
if not USE_LORA:
    mert_model.requires_grad_(False)
    print("  MERT frozen (no LoRA)")
else:
    from peft import LoraConfig, get_peft_model

    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
    )
    mert_model = get_peft_model(mert_model, lora_config)
    trainable, total_p = mert_model.get_nb_trainable_parameters()
    print(f"  LoRA: {trainable:,} / {total_p:,} ({trainable/total_p:.2%})")

# ── Classification head ──────────────────────
head = MERTClassificationHead(num_layers, pooled_dim, len(CLASS_METERS)).to(device)
n_head_params = sum(p.numel() for p in head.parameters())
print(f"  Head: {n_head_params:,} params")

# ── Datasets & loaders ───────────────────────
# Only the training split is augmented and shuffled.
train_ds, val_ds, test_ds = (
    MERTAudioDataset(train_entries, augment=True),
    MERTAudioDataset(val_entries),
    MERTAudioDataset(test_entries),
)
loader_kwargs = {"batch_size": BATCH_SIZE, "collate_fn": simple_collate}
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader, test_loader = (DataLoader(ds, **loader_kwargs) for ds in (val_ds, test_ds))

# ── Loss: BCE with per-class positive weighting ──
# Count positives per class over the training entries; meters missing from METER_TO_IDX are ignored.
pos_counts = np.zeros(len(CLASS_METERS), dtype=np.float32)
for _, meters in train_ds.entries:
    for idx in [METER_TO_IDX[m] for m in meters if m in METER_TO_IDX]:
        pos_counts[idx] += 1

# pos_weight = negatives / positives; classes with no positives fall back to 1.0.
neg_counts = len(train_ds) - pos_counts
ratio = np.divide(neg_counts, pos_counts, out=np.ones_like(pos_counts), where=pos_counts > 0)
pos_weights = torch.as_tensor(ratio, dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
print(f"Pos weights: {pos_weights.tolist()}")

# ── Optimizer & schedule ─────────────────────
# The head and the trainable (LoRA) backbone parameters get separate learning rates.
lora_params = [p for p in mert_model.parameters() if p.requires_grad] if USE_LORA else []
param_groups = [{"params": head.parameters(), "lr": HEAD_LR}]
if lora_params:
    param_groups.append({"params": lora_params, "lr": LORA_LR})
optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)
# Cosine annealing over EPOCHS scheduler steps (the training loop steps once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

summary_lines = [
    f"\nReady! Model: {MODEL_NAME}",
    f"  Head LR: {HEAD_LR}, LoRA LR: {LORA_LR}",
    f"  Effective batch: {BATCH_SIZE * GRAD_ACCUM}",
    f"  Datasets: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test",
]
for line in summary_lines:
    print(line)

In [None]:
# ══════════════════════════════════════════════
#  TRAINING LOOP
# ══════════════════════════════════════════════
# The checkpoint follows the best *validation accuracy*; early stopping watches
# *validation loss* (stop after PATIENCE epochs without a >1e-4 improvement).

best_val_acc = 0.0
best_val_loss = float("inf")
best_head_state = None
best_lora_state = None
patience_counter = 0
PATIENCE = 10

print(f"{'Epoch':>5s}  {'TrainLoss':>10s}  {'TrainAcc':>9s}  {'ValLoss':>10s}  {'ValAcc':>9s}  {'Time':>6s}")
print("-" * 60)

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    train_loss, train_acc = train_one_epoch(
        mert_model, head, processor, train_loader, criterion, optimizer,
        device, num_layers, hidden_dim, GRAD_ACCUM, USE_LORA)

    val_loss, val_acc, _, _, _, _ = evaluate(
        mert_model, head, processor, val_loader, criterion,
        device, num_layers, hidden_dim)

    scheduler.step()
    elapsed = time.time() - t0

    marker = ""
    # The first epoch always snapshots, so a checkpoint exists even if val accuracy stays at 0.
    if best_head_state is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        best_head_state = {k: v.cpu().clone() for k, v in head.state_dict().items()}
        if USE_LORA:
            # detach(): store plain value copies, not autograd-tracked tensors
            # (requires_grad=True with a grad_fn) derived from the live parameters.
            best_lora_state = {
                n: p.detach().cpu().clone()
                for n, p in mert_model.named_parameters() if p.requires_grad
            }
        # Save to disk immediately so crashes don't lose progress
        torch.save({
            "head_state_dict": best_head_state,
            "lora_state_dict": best_lora_state,
            "class_map": IDX_TO_METER,
            "model_name": MODEL_NAME,
            "num_layers": num_layers,
            "hidden_dim": hidden_dim,
            "pooled_dim": pooled_dim,
            # NOTE(review): head_dim / dropout are hard-coded here; presumably they must
            # match MERTClassificationHead's defaults — confirm before changing the head.
            "head_dim": 256,
            "num_classes": len(CLASS_METERS),
            "dropout": 0.3,
            "val_accuracy": best_val_acc,
            "epoch": epoch,
            "lora_rank": LORA_RANK if USE_LORA else 0,
            "lora_alpha": LORA_ALPHA if USE_LORA else 0,
            "model_type": "MERTFineTuned",
        }, CHECKPOINT_PATH)
        marker = f"  ** saved {CHECKPOINT_PATH.name}"

    print(f"{epoch:5d}  {train_loss:10.4f}  {train_acc:8.1%}  {val_loss:10.4f}  {val_acc:8.1%}  {elapsed:5.0f}s{marker}")

    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\nEarly stopping at epoch {epoch}")
            break

print(f"\nBest val accuracy: {best_val_acc:.1%}")
if CHECKPOINT_PATH.exists():
    print(f"Checkpoint on disk: {CHECKPOINT_PATH} ({CHECKPOINT_PATH.stat().st_size / 1e6:.1f} MB)")

## 5. Test evaluation

In [None]:
# Restore the best snapshot captured by the training loop (if any) before testing.
if best_head_state:
    head.load_state_dict(best_head_state)
if best_lora_state:
    # Names come from named_parameters(), so get_parameter resolves them directly.
    with torch.no_grad():
        for name, saved in best_lora_state.items():
            mert_model.get_parameter(name).copy_(saved.to(device))

# Evaluate
test_loss, test_acc, test_labels, test_preds, test_probs, test_labels_mh = evaluate(
    mert_model, head, processor, test_loader, criterion, device, num_layers, hidden_dim)

n_correct = sum(1 for a, b in zip(test_labels, test_preds) if a == b)
print(f"Test loss: {test_loss:.4f}")
print(f"Test accuracy: {test_acc:.1%} ({n_correct}/{len(test_labels)})")
print(f"Best val accuracy: {best_val_acc:.1%}")

print_eval_metrics(test_labels, test_preds, test_probs, test_labels_mh)

# Record the score in the on-disk checkpoint so the save cell's `'test_accuracy' not in ckpt`
# check does not re-run this whole evaluation. Only do so when the weights just tested are
# the ones that file holds (same model, same best validation accuracy as this run's snapshot).
if best_head_state and CHECKPOINT_PATH.exists():
    ckpt = torch.load(CHECKPOINT_PATH, weights_only=False)
    if ckpt.get("model_name") == MODEL_NAME and ckpt.get("val_accuracy") == best_val_acc:
        ckpt["test_accuracy"] = test_acc
        torch.save(ckpt, CHECKPOINT_PATH)
        print(f"\nRecorded test accuracy in {CHECKPOINT_PATH}")

## 6. Save & download checkpoint

In [None]:
# Checkpoint is already saved to disk during training.
# This cell verifies it, runs the test evaluation only if no test accuracy is
# recorded in the checkpoint yet, and then offers the file for download.

if CHECKPOINT_PATH.exists():
    # weights_only=False unpickles arbitrary objects — acceptable only for a checkpoint this notebook wrote.
    ckpt = torch.load(CHECKPOINT_PATH, weights_only=False)
    print(f"Checkpoint: {CHECKPOINT_PATH}")
    print(f"  Model: {ckpt.get('model_name', '?')}")
    print(f"  Best val: {ckpt.get('val_accuracy', 0):.1%} (epoch {ckpt.get('epoch', '?')})")
    print(f"  Size: {CHECKPOINT_PATH.stat().st_size / 1e6:.1f} MB")

    if "test_accuracy" in ckpt:
        print(f"  Test acc: {ckpt['test_accuracy']:.1%}")
    else:
        # Run test eval with the saved weights, restored into the live head / MERT model.
        print("\nRunning test evaluation...")
        head.load_state_dict(ckpt["head_state_dict"])
        if ckpt.get("lora_state_dict"):
            with torch.no_grad():
                for name, saved in ckpt["lora_state_dict"].items():
                    mert_model.get_parameter(name).copy_(saved.to(device))

        test_loss, test_acc, test_labels, test_preds, test_probs, test_labels_mh = evaluate(
            mert_model, head, processor, test_loader, criterion, device, num_layers, hidden_dim)
        n_correct = sum(1 for a, b in zip(test_labels, test_preds) if a == b)
        print(f"Test accuracy: {test_acc:.1%} ({n_correct}/{len(test_labels)})")
        print_eval_metrics(test_labels, test_preds, test_probs, test_labels_mh)

        ckpt["test_accuracy"] = test_acc
        torch.save(ckpt, CHECKPOINT_PATH)

    # google.colab is only importable inside Colab; elsewhere just point at the file.
    try:
        from google.colab import files
    except ImportError:
        print(f"\nNot running in Colab — copy {CHECKPOINT_PATH} manually.")
    else:
        files.download(str(CHECKPOINT_PATH))
else:
    print("No checkpoint found! Training may not have completed.")