# MERT LoRA Fine-tuning for Meter Classification

**GPU**: Runtime → Change runtime type → T4 GPU (free) or A100 (Pro)

This notebook:
1. Downloads the METER2800 dataset from Harvard Dataverse to local Colab storage (fast, but re-downloaded each session)
2. Downloads WIKIMETER audio from YouTube to **Google Drive** (slow, but persists across sessions)
3. Fine-tunes MERT with LoRA for multi-label meter classification
4. Saves checkpoints to Drive every epoch (crash-safe)

**Resumable**: Just re-run all cells; each step skips work that has already been done.

## Configuration

Change these before running. Everything else auto-configures.

In [None]:
# ┌──────────────────────────────────────────┐
# │  CHANGE THESE BEFORE RUNNING             │
# └──────────────────────────────────────────┘

# NOTE: MODEL_CONFIGS (section 5) only defines presets for the two model names
# mentioned on the next line.
MODEL_NAME = "m-a-p/MERT-v1-330M"  # or "m-a-p/MERT-v1-95M" for free T4
RUN_NAME = None    # None = auto-unique (params + timestamp + random suffix); set manually to resume same run.
AUTO_RESUME_TRACKED_RUN = True  # when RUN_NAME is None, reuse last tracked run name from Drive
# NOTE(review): the three LoRA settings below are consumed by cells outside this
# section; presumably they toggle/configure the LoRA adapters — confirm there.
USE_LORA = True
LORA_RANK = 16
LORA_ALPHA = 32
EPOCHS = 80        # presumably the maximum epoch count (see PATIENCE) — confirm in the training loop
PATIENCE = 15      # early stopping patience (epochs without improvement)
NOISE_STD = 0.01   # augmentation noise
HEAD_DROPOUT = 0.4
WIKI_VAL_RATIO = 0.1  # fraction of WIKIMETER songs held out for val
NUM_WORKERS = 2     # DataLoader workers in Colab (2 is a safe default)
CHUNK_BATCH_SIZE = 0  # 0=auto (2 on 16-24GB GPUs, 4 on >=30GB); higher is faster but uses more VRAM
AUDIO_CACHE_DIR = "/content/audio_cache_24k"  # set to None to disable waveform cache
USE_AMP = True      # mixed precision on CUDA
AMP_DTYPE = "bf16"  # "bf16" (L4/A100) or "fp16" (fallback / older GPUs)
USE_TF32 = True     # faster matmul/conv on NVIDIA Ampere+


## 0. Install & GPU check

In [None]:
# Notebook shell magics: install the Python dependencies and ffmpeg
# (ffmpeg/ffprobe are invoked via subprocess in the WIKIMETER download cell).
!pip install -q transformers peft librosa scikit-learn tqdm yt-dlp
!apt-get -qq install ffmpeg

# Install torchaudio only if missing, matching installed torch build.
import importlib.util
import subprocess
import sys

if importlib.util.find_spec("torchaudio") is None:
    import torch
    # Drop any local build tag after "+" so only the bare version is pinned.
    torch_ver = torch.__version__.split("+")[0]
    # CUDA version without the dot (e.g. "12.1" -> "121"); empty on CPU-only builds.
    cuda_ver = (torch.version.cuda or "").replace(".", "")
    # Pick the PyTorch wheel index matching the CUDA build (or the CPU index).
    index_url = f"https://download.pytorch.org/whl/cu{cuda_ver}" if cuda_ver else "https://download.pytorch.org/whl/cpu"
    # check_call raises CalledProcessError if pip fails, stopping the cell.
    subprocess.check_call([
        sys.executable,
        "-m",
        "pip",
        "install",
        "-q",
        "--index-url",
        index_url,
        f"torchaudio=={torch_ver}",
    ])

# Fail early here (ImportError) if torchaudio still cannot be imported.
import torchaudio
print(f"torchaudio: {torchaudio.__version__}")


In [None]:
import os

# Helps reduce CUDA allocator fragmentation on long Colab runs.
# The variable is set before `import torch` below in this cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU'}")
print(f"CUDA alloc conf: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')}")
# Abort the notebook run immediately when no CUDA device is visible.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU! Go to Runtime → Change runtime type → T4 GPU")


## 1. Mount Google Drive & setup paths

All persistent data lives on Drive:
- `wikimeter/audio/` — YouTube segments (survive restarts)
- `checkpoints/` — model checkpoints (every epoch)
- `wikimeter.json` — song catalog (upload once)

In [None]:
from google.colab import drive
from pathlib import Path

# Mount Google Drive so the persistent paths below are reachable.
drive.mount('/content/drive')

# Persistent paths (Google Drive)
DRIVE_DIR = Path("/content/drive/MyDrive/beatmeter")
DRIVE_DIR.mkdir(parents=True, exist_ok=True)
WIKIMETER_DIR = DRIVE_DIR / "wikimeter"
WIKIMETER_AUDIO = WIKIMETER_DIR / "audio"            # created lazily by the download cell
WIKIMETER_TAB = WIKIMETER_DIR / "data_wikimeter.tab"  # label file written by the download cell
CATALOG_PATH = DRIVE_DIR / "wikimeter.json"
CHECKPOINT_DIR = DRIVE_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Local paths (ephemeral, re-downloaded each session)
METER2800_DIR = Path("/content/data/meter2800")
METER2800_AUDIO = METER2800_DIR / "audio"

print(f"Drive dir:     {DRIVE_DIR}")
print(f"Wikimeter:     {WIKIMETER_DIR}")
print(f"Checkpoints:   {CHECKPOINT_DIR}")
print(f"METER2800:     {METER2800_DIR}")

## 2. Upload wikimeter.json (first time only)

Upload `scripts/setup/wikimeter.json` from the repo. Only needed once — it stays on Drive.

In [None]:
import json

# Load the WIKIMETER song catalog from Drive, or ask the user to upload it once.
if CATALOG_PATH.exists():
    catalog = json.loads(CATALOG_PATH.read_text())
    print(f"Catalog already on Drive: {len(catalog)} songs")
else:
    from google.colab import files
    print("Upload wikimeter.json from scripts/setup/wikimeter.json")
    uploaded = files.upload()
    if not uploaded:
        # Without this guard `catalog` would stay undefined and the balance
        # loop below would fail with a confusing NameError.
        raise RuntimeError("No file was uploaded; re-run this cell and upload wikimeter.json")
    # If several files were selected, the last one wins (same as before).
    fname, data = list(uploaded.items())[-1]
    # Parse BEFORE persisting so a wrong or corrupt upload never lands on Drive
    # (where it would be silently reused by every later run).
    catalog = json.loads(data)
    CATALOG_PATH.write_bytes(data)
    print(f"Saved {len(catalog)} songs to {CATALOG_PATH}")

# Show balance: songs with exactly one meter are counted under that meter,
# songs with several meters are grouped under 'poly'.
from collections import Counter
cats = Counter()
for e in catalog:
    keys = list(e['meters'].keys())
    if len(keys) == 1:
        cats[keys[0]] += 1
    else:
        cats['poly'] += 1
print(f"Balance: {dict(sorted(cats.items()))}")

## 3. Download METER2800

In [None]:
import hashlib
import ssl
import tarfile
import time
import urllib.request

# File suffixes (lower-case) that count as audio when extracting and counting files.
AUDIO_EXTENSIONS = {".wav", ".mp3", ".ogg", ".flac", ".oga", ".opus", ".aiff", ".aif"}
# Persistent identifier passed to the Dataverse dataset-metadata endpoint.
DOI = "doi:10.7910/DVN/0CLXBQ"
API_BASE = "https://dataverse.harvard.edu/api"
# Dataset files that are downloaded as archives and then extracted.
TAR_FILES = {"FMA.tar.gz", "MAG.tar.gz", "OWN.tar.gz"}


def api_request(url, max_retries=3):
    """Fetch ``url`` and return the raw response body as bytes.

    Failed attempts are retried with exponential back-off (2 s, 4 s, ...);
    the exception from the final attempt is re-raised. Returns ``None`` if
    ``max_retries`` is not positive (no attempt is made).
    """
    ssl_context = ssl.create_default_context()
    final_attempt = max_retries - 1
    for attempt_no in range(max_retries):
        try:
            request = urllib.request.Request(url, headers={"User-Agent": "RhythmAnalyzer/1.0"})
            with urllib.request.urlopen(request, timeout=30, context=ssl_context) as response:
                return response.read()
        except Exception:
            # Out of retries: surface the original error to the caller.
            if attempt_no >= final_attempt:
                raise
            time.sleep(2 * (2 ** attempt_no))

def download_dataverse_file(file_id, dest):
    """Download one Dataverse data file to ``dest`` (a ``Path``).

    The body is streamed to a sibling ``<name>.part`` file and renamed into
    place only after the transfer completes. This keeps large archives out of
    RAM and guarantees that an interrupted download never leaves a truncated
    ``dest`` behind (callers skip files that already exist).

    Raises whatever ``urllib`` / the filesystem raises; the partial file is
    removed before the exception propagates.
    """
    import shutil

    url = f"{API_BASE}/access/datafile/{file_id}"
    ctx = ssl.create_default_context()
    req = urllib.request.Request(url, headers={"User-Agent": "RhythmAnalyzer/1.0"})
    dest.parent.mkdir(parents=True, exist_ok=True)
    partial = dest.with_name(dest.name + ".part")
    try:
        with urllib.request.urlopen(req, timeout=300, context=ctx) as resp, open(partial, "wb") as out:
            shutil.copyfileobj(resp, out, 1024 * 1024)
        # Rename only once the whole body has been written.
        partial.replace(dest)
    except BaseException:
        # Also covers KeyboardInterrupt (cell interrupted mid-download).
        partial.unlink(missing_ok=True)
        raise


def extract_tar(tar_path, audio_dir):
    """Extract the audio files of a ``.tar.gz`` archive flat into ``audio_dir``.

    Only regular members whose suffix is in ``AUDIO_EXTENSIONS`` are written;
    directory structure is dropped. If a file with the same base name already
    exists, the member is written as ``<parent-dir>_<name>`` instead.

    Returns the number of files written.
    """
    import shutil

    audio_dir.mkdir(parents=True, exist_ok=True)
    count = 0
    with tarfile.open(tar_path, "r:gz") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue
            name = Path(member.name).name
            if Path(name).suffix.lower() not in AUDIO_EXTENSIONS:
                continue
            dest = audio_dir / name
            if dest.exists():
                # Name clash: disambiguate with the member's parent folder name.
                dest = audio_dir / f"{Path(member.name).parent.name}_{name}"
            src = tar.extractfile(member)
            # Check for None BEFORE entering the context manager; the previous
            # code tested it inside `with`, where it could no longer help.
            if src is None:
                continue
            with src, open(dest, "wb") as out:
                shutil.copyfileobj(src, out)
            count += 1
    return count


# Skip the whole download when the local audio folder already looks complete
# (more than 2000 directory entries of any kind).
if METER2800_AUDIO.exists() and len(list(METER2800_AUDIO.glob("*"))) > 2000:
    n = len([f for f in METER2800_AUDIO.iterdir() if f.suffix.lower() in AUDIO_EXTENSIONS])
    print(f"METER2800 already downloaded: {n} audio files")
else:
    METER2800_DIR.mkdir(parents=True, exist_ok=True)
    downloads_dir = METER2800_DIR / "downloads"
    downloads_dir.mkdir(parents=True, exist_ok=True)
    METER2800_AUDIO.mkdir(parents=True, exist_ok=True)

    print("Fetching METER2800 metadata...")
    metadata = json.loads(api_request(f"{API_BASE}/datasets/:persistentId/?persistentId={DOI}"))
    # NOTE: this rebinds the global name `files`, which the catalog-upload cell
    # may have bound to `google.colab.files`.
    files = metadata["data"]["latestVersion"]["files"]

    for f in files:
        df = f["dataFile"]
        fname = df["filename"]
        if fname in TAR_FILES:
            # Audio archives go to downloads/ and are extracted below.
            dest = downloads_dir / fname
            if not dest.exists():
                print(f"  Downloading {fname}...", end=" ", flush=True)
                download_dataverse_file(df["id"], dest)
                print("OK")
        elif fname.endswith((".csv", ".tab", ".tsv")):
            # Tabular files are stored directly in the dataset root.
            dest = METER2800_DIR / fname
            if not dest.exists():
                print(f"  Downloading {fname}...", end=" ", flush=True)
                download_dataverse_file(df["id"], dest)
                print("OK")

    # Extract every downloaded archive into the flat audio folder.
    total = 0
    for tar_path in sorted(downloads_dir.glob("*.tar.gz")):
        print(f"  Extracting {tar_path.name}...", end=" ", flush=True)
        count = extract_tar(tar_path, METER2800_AUDIO)
        print(f"{count} files")
        total += count
    print(f"Total: {total} audio files")

# Final summary, printed on both the skip and the download path.
n_audio = len([f for f in METER2800_AUDIO.iterdir() if f.suffix.lower() in AUDIO_EXTENSIONS])
n_tabs = len(list(METER2800_DIR.glob("*.tab")))
print(f"METER2800: {n_audio} audio, {n_tabs} label files")

## 4. Download WIKIMETER (YouTube → Google Drive)

Downloads are stored on Drive — survives Colab restarts.
Re-running skips already downloaded songs.

In [None]:
import re
import subprocess
import tempfile
import unicodedata

# Text label written to the "label" column of the WIKIMETER .tab file for each primary meter.
METER_LABELS = {3: "three", 4: "four", 5: "five", 7: "seven", 9: "nine", 11: "eleven"}
# Upper bound on the number of segments kept per song.
MAX_SEGMENTS = 5
# Target segment duration in seconds (passed to ffmpeg's -t option).
SEGMENT_LENGTH = 35


def _slugify(name, ascii_only=False):
    s = name.lower()
    if ascii_only:
        s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    s = re.sub(r"[^\w\s-]", "", s)
    s = re.sub(r"[\s]+", "_", s)
    s = re.sub(r"_+", "_", s).strip("_")
    return s[:80]


def sanitize_filename(artist, title):
    """Return the canonical (ASCII-only) file stem for a song; used for new downloads."""
    combined = f"{artist}_{title}"
    return _slugify(combined, ascii_only=True)


def legacy_sanitize_filename(artist, title):
    """Return the file stem produced by older notebook versions (non-ASCII characters kept)."""
    combined = f"{artist}_{title}"
    return _slugify(combined, ascii_only=False)


def get_duration(path):
    """Return the duration of an audio file in seconds via ``ffprobe``.

    Any failure (ffprobe missing, timeout, unparsable output, missing field)
    yields ``None`` instead of raising.
    """
    try:
        probe_cmd = ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", str(path)]
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
        container_info = json.loads(probe.stdout)["format"]
        return float(container_info["duration"])
    except Exception:
        return None


def download_and_segment(video_id, stem, audio_dir):
    """Download a YouTube video's audio and cut it into evenly spread segments.

    The audio is fetched with ``yt-dlp`` into a temporary directory, then up to
    ``MAX_SEGMENTS`` mono mp3 segments of at most ``SEGMENT_LENGTH`` seconds are
    written to ``audio_dir`` as ``<stem>_segNN.mp3``.

    Returns the list of segment stems that were written successfully; an empty
    list means the download failed or the track was shorter than 15 s.

    A segment whose ffmpeg run fails or times out is deleted again, so a
    truncated file can never be mistaken for a finished segment by the
    "already downloaded" check on a later run.
    """

    def _discard(path):
        # Best-effort removal of a partial output file.
        try:
            path.unlink(missing_ok=True)
        except OSError:
            pass

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir) / "full.%(ext)s"
        url = f"https://www.youtube.com/watch?v={video_id}"
        cmd = ["yt-dlp", url, "-x", "--audio-format", "mp3", "--audio-quality", "5",
               "-o", str(tmp_path), "--no-playlist", "--quiet", "--no-warnings"]
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
            if r.returncode != 0:
                return []
        except (subprocess.TimeoutExpired, FileNotFoundError):
            return []

        downloaded = list(Path(tmpdir).glob("full.*"))
        if not downloaded:
            return []
        src = downloaded[0]

        duration = get_duration(src)
        if not duration or duration < 15:
            return []

        # Segment evenly across track, skipping a 10 s margin at both ends of
        # longer tracks (falls back to the whole track if that leaves < 15 s).
        margin = 10.0 if duration > 40 else 0.0
        usable_start = margin
        usable_duration = duration - 2 * margin
        if usable_duration < 15:
            usable_start, usable_duration = 0, duration

        MIN_GAP = 5.0
        stride = SEGMENT_LENGTH + MIN_GAP
        n_seg = min(MAX_SEGMENTS, max(1, int((usable_duration + MIN_GAP) // stride)))

        if n_seg == 1:
            # Single segment: centre it in the usable range.
            starts = [max(0.0, (usable_duration - SEGMENT_LENGTH) / 2)]
        else:
            # First segment starts at 0, last one ends at the end of the range.
            actual_stride = (usable_duration - SEGMENT_LENGTH) / (n_seg - 1)
            starts = [i * actual_stride for i in range(n_seg)]

        audio_dir.mkdir(parents=True, exist_ok=True)
        segments = []
        for idx, offset in enumerate(starts):
            seg_dur = min(SEGMENT_LENGTH, usable_duration - offset)
            if seg_dur < 15:
                continue
            seg_stem = f"{stem}_seg{idx:02d}"
            dest = audio_dir / f"{seg_stem}.mp3"
            try:
                r = subprocess.run(
                    ["ffmpeg", "-y", "-i", str(src),
                     "-ss", f"{usable_start + offset:.1f}", "-t", f"{seg_dur:.1f}",
                     "-acodec", "libmp3lame", "-ab", "192k", "-ar", "44100", "-ac", "1",
                     str(dest)],
                    capture_output=True, text=True, timeout=60)
            except subprocess.TimeoutExpired:
                _discard(dest)
                continue
            if r.returncode == 0 and dest.exists() and dest.stat().st_size > 1000:
                segments.append(seg_stem)
            else:
                _discard(dest)
        return segments


def meters_to_tab_str(meters_dict):
    """Serialise a ``{meter: weight}`` mapping for the ``meter`` column.

    Entries are comma-separated; a weight equal to 1.0 is written as the bare
    meter, any other weight as ``meter:weight``.
    """
    return ",".join(
        f"{meter}:{weight}" if weight != 1.0 else str(meter)
        for meter, weight in meters_dict.items()
    )


# Load catalog
catalog = json.loads(CATALOG_PATH.read_text())
print(f"WIKIMETER: {len(catalog)} songs")

# Download (skips existing)
# `successful` collects (segment_stem, meters) pairs for the .tab file written afterwards.
successful = []
stats = {"ok": 0, "skip": 0, "fail": 0}

for i, song in enumerate(catalog, 1):
    artist, title = song["artist"], song["title"]
    video_id = song["video_id"]
    # JSON object keys are strings; convert meter keys to int.
    meters = {int(k): v for k, v in song["meters"].items()}
    stem = sanitize_filename(artist, title)
    legacy_stem = legacy_sanitize_filename(artist, title)
    meters_str = "+".join(str(m) for m in meters)

    # Look for segments saved under either the canonical or the legacy stem;
    # files of 1000 bytes or less are ignored.
    existing = []
    if WIKIMETER_AUDIO.exists():
        # Keyed by path so a file matched by both stems is counted once.
        existing_map = {}
        for stem_candidate in {stem, legacy_stem}:
            for f in WIKIMETER_AUDIO.glob(f"{stem_candidate}_seg*.mp3"):
                if f.stat().st_size > 1000:
                    existing_map[str(f)] = f
        existing = sorted(existing_map.values(), key=lambda x: x.name)
    if existing:
        for f in existing[:MAX_SEGMENTS]:
            successful.append((f.stem, meters))
        stats["skip"] += 1
        continue

    print(f"  [{i:3d}/{len(catalog)}] {artist} — {title} ({meters_str})", end="", flush=True)
    segments = download_and_segment(video_id, stem, WIKIMETER_AUDIO)

    if segments:
        print(f" — {len(segments)} segs")
        for s in segments:
            successful.append((s, meters))
        stats["ok"] += 1
    else:
        print(" — FAILED")
        stats["fail"] += 1

# Write .tab
# The file is rewritten from scratch on every run from `successful`.
WIKIMETER_DIR.mkdir(parents=True, exist_ok=True)
with open(WIKIMETER_TAB, "w") as f:
    f.write("filename\tlabel\tmeter\talt_meter\n")
    for stem, meters in successful:
        meter_list = list(meters.keys())
        # The first meter in the catalog entry is treated as the primary one.
        primary = meter_list[0]
        label = METER_LABELS.get(primary, str(primary))
        meter_str = meters_to_tab_str(meters)
        # alt_meter is written as twice the primary meter.
        f.write(f'"/WIKIMETER/{stem}.mp3"\t"{label}"\t{meter_str}\t{primary * 2}\n')

# Summary
# A segment with several meters is counted once under each of them.
seg_cats = Counter()
for _, meters in successful:
    for m in meters:
        seg_cats[m] += 1
seg_summary = ", ".join(f"{seg_cats[m]}×{m}/x" for m in sorted(seg_cats))
print(f"\nDone: {stats['ok']} downloaded, {stats['skip']} skipped, {stats['fail']} failed")
print(f"Segments: {len(successful)} ({seg_summary})")


## 4b. Copy WIKIMETER to local SSD (faster I/O)

Drive FUSE is slow. Copying ~900MB to `/content/` local SSD speeds up training significantly.

In [None]:
import shutil

WIKIMETER_LOCAL = Path("/content/wikimeter_local")
WIKIMETER_LOCAL_AUDIO = WIKIMETER_LOCAL / "audio"
WIKIMETER_LOCAL_TAB = WIKIMETER_LOCAL / "data_wikimeter.tab"

if WIKIMETER_AUDIO.exists():
    n_drive = len(list(WIKIMETER_AUDIO.glob("*.mp3")))
    n_local = len(list(WIKIMETER_LOCAL_AUDIO.glob("*.mp3"))) if WIKIMETER_LOCAL_AUDIO.exists() else 0

    # The local copy is considered current when it holds at least as many mp3s as Drive.
    # NOTE: in that case the .tab file is not re-copied either.
    if n_local >= n_drive and n_local > 0:
        print(f"WIKIMETER already on local SSD: {n_local} files")
    else:
        print(f"Copying WIKIMETER to local SSD ({n_drive} files)...", end=" ", flush=True)
        # Start from a clean directory: copytree requires that the target does not exist.
        if WIKIMETER_LOCAL.exists():
            shutil.rmtree(WIKIMETER_LOCAL)
        WIKIMETER_LOCAL.mkdir(parents=True, exist_ok=True)
        shutil.copytree(WIKIMETER_AUDIO, WIKIMETER_LOCAL_AUDIO)
        if WIKIMETER_TAB.exists():
            shutil.copy2(WIKIMETER_TAB, WIKIMETER_LOCAL_TAB)
        n_copied = len(list(WIKIMETER_LOCAL_AUDIO.glob("*.mp3")))
        print(f"Done — {n_copied} files")
else:
    # Signals "no local WIKIMETER copy"; WIKIMETER_LOCAL_AUDIO / _TAB keep their Path values.
    WIKIMETER_LOCAL = None
    print("No WIKIMETER audio on Drive yet")

## 5. Training code

In [None]:
import csv
import hashlib
import os
import time
import warnings

import librosa
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

# torchaudio is optional here: when the import fails, decoding falls back to librosa.
try:
    import torchaudio
    HAS_TORCHAUDIO = True
except Exception:
    torchaudio = None
    HAS_TORCHAUDIO = False

if not HAS_TORCHAUDIO:
    print("WARNING: torchaudio unavailable; falling back to librosa decoding")

# Meter classes; the list position is the class index used in label vectors.
CLASS_METERS = [3, 4, 5, 7, 9, 11]
METER_TO_IDX = {m: i for i, m in enumerate(CLASS_METERS)}
IDX_TO_METER = {i: m for i, m in enumerate(CLASS_METERS)}
# Sample rate all audio is resampled to (also passed to the processor as sampling_rate).
MERT_SR = 24000
# Chunk length fed to the model: 5 seconds' worth of samples.
CHUNK_SAMPLES = 5 * MERT_SR
# Longer clips are cropped to this many seconds by the dataset.
MAX_DURATION_S = 30
# Target value used for the non-matching classes in the label vector (instead of 0).
LABEL_SMOOTH_NEG = 0.1

# Per-model presets. NOTE(review): these are read by cells outside this section;
# presumably learning rates, batch size and gradient-accumulation steps — confirm there.
MODEL_CONFIGS = {
    "m-a-p/MERT-v1-95M": {"num_layers": 12, "hidden_dim": 768,
                           "head_lr": 1e-3, "lora_lr": 5e-5, "batch_size": 8, "grad_accum": 4},
    "m-a-p/MERT-v1-330M": {"num_layers": 24, "hidden_dim": 1024,
                            "head_lr": 5e-4, "lora_lr": 5e-5, "batch_size": 2, "grad_accum": 16},
}


def resolve_audio_path(raw_fname, data_dir):
    """Map a filename from a label file to an existing file in ``data_dir/audio``.

    ``raw_fname`` may be wrapped in quotes and/or whitespace and may carry a
    source folder (e.g. ``"/FMA/song.mp3"``). For each candidate extension
    (``.mp3``, ``.wav``, ``.ogg``, ``.flac``, then the label's own suffix) the
    following names are tried, in order: ``<stem><ext>``,
    ``<src_dir>_<stem><ext>`` and ``<src_dir>_._<stem><ext>``.

    Returns the first candidate that is a regular file, or ``None``.
    """
    # Strip whitespace first so quotes that follow/precede blanks are removed too.
    cleaned = raw_fname.strip().strip("\"'").strip()
    if not cleaned:
        return None
    p = Path(cleaned)
    src_dir = p.parent.name
    stem = p.stem
    audio_dir = data_dir / "audio"

    # De-duplicate while keeping priority order (the label's own suffix is
    # usually already one of the standard ones).
    extensions = []
    for ext in (".mp3", ".wav", ".ogg", ".flac", p.suffix):
        if ext not in extensions:
            extensions.append(ext)

    for ext in extensions:
        for candidate in (
            audio_dir / f"{stem}{ext}",
            audio_dir / f"{src_dir}_{stem}{ext}",
            audio_dir / f"{src_dir}_._{stem}{ext}",
        ):
            # is_file(), not exists(): a directory must never be returned as audio.
            if candidate.is_file():
                return candidate
    return None


def parse_label_file(label_path, data_dir, valid_meters=None):
    """Parse a tab- or comma-separated label file into ``(audio_path, meter)`` pairs.

    The delimiter is detected from the header line (tab if present, else
    comma). The filename and label columns are located by case-insensitive
    header name from a list of known aliases. A label such as ``"3/4"`` or
    ``"3"`` is reduced to its integer numerator.

    Rows are skipped when a field is missing or empty, the label is not an
    integer, the meter is not in ``valid_meters`` (if given), or the audio
    file cannot be resolved. Returns an empty list if the expected columns
    are absent.
    """
    filename_cols = ["filename", "file", "audio_file", "audio_path", "audio", "path"]
    label_cols = ["meter", "time_signature", "ts", "time_sig", "signature", "label"]
    entries = []
    with open(label_path, newline="") as f:
        # Sniff the delimiter from the header, then rewind for the CSV reader.
        first_line = f.readline()
        f.seek(0)
        delimiter = "\t" if "\t" in first_line else ","
        reader = csv.DictReader(f, delimiter=delimiter)
        if not reader.fieldnames:
            return entries
        header_map = {h.strip().lower().strip('"'): h for h in reader.fieldnames}
        fname_key = next((header_map[c] for c in filename_cols if c in header_map), None)
        label_key = next((header_map[c] for c in label_cols if c in header_map), None)
        if not fname_key or not label_key:
            return entries
        for row in reader:
            # DictReader fills missing trailing fields with None, so `or ""`
            # is required to keep short rows from raising AttributeError.
            raw_fname = (row.get(fname_key) or "").strip().strip('"')
            raw_label = (row.get(label_key) or "").strip().strip('"')
            if not raw_fname or not raw_label:
                continue
            try:
                meter = int(raw_label.split("/")[0].strip())
            except ValueError:
                continue
            if valid_meters and meter not in valid_meters:
                continue
            audio_path = resolve_audio_path(raw_fname, data_dir)
            if audio_path:
                entries.append((audio_path, meter))
    return entries


def load_split(data_dir, split):
    """Load one dataset split from ``data_<split>_4_classes.{tab,csv,tsv}``.

    The first existing file (in that extension order) is parsed. Returns a
    list of ``(audio_path, [meter])`` entries, or ``None`` when no label file
    for the split exists.
    """
    allowed_meters = set(METER_TO_IDX)
    candidates = (data_dir / f"data_{split}_4_classes{ext}" for ext in (".tab", ".csv", ".tsv"))
    label_file = next((c for c in candidates if c.exists()), None)
    if label_file is None:
        return None

    # Wrap each single meter in a list to match the multi-label entry format.
    entries = [
        (audio_path, [meter])
        for audio_path, meter in parse_label_file(label_file, data_dir, allowed_meters)
    ]
    print(f"  {split}: {len(entries)} entries")
    return entries


def detect_delimiter(path):
    """Return ``"\\t"`` if the first line of ``path`` contains a tab, otherwise ``","``."""
    with open(path, "r", encoding="utf-8") as handle:
        header = handle.readline()
    if "\t" in header:
        return "\t"
    return ","


def load_extra_data(extra_dir):
    """Load multi-label entries from every ``.tab``/``.csv``/``.tsv`` file in ``extra_dir``.

    Each file must have ``filename`` and ``meter`` columns. The meter field is
    a comma-separated list of ``meter`` or ``meter:weight`` items; weights are
    ignored and only meters present in ``METER_TO_IDX`` are kept.

    Rows with a missing field, no usable meter, or an unresolvable audio file
    are skipped. A malformed meter token is skipped on its own instead of
    aborting the whole load.

    Returns a list of ``(audio_path, [meter, ...])`` tuples.
    """
    extra_dir = Path(extra_dir)
    valid = set(METER_TO_IDX.keys())
    entries = []
    for tab_file in sorted(p for p in extra_dir.iterdir() if p.suffix in (".tab", ".csv", ".tsv")):
        delimiter = detect_delimiter(tab_file)
        # Same encoding as detect_delimiter() uses for the header line.
        with open(tab_file, newline="", encoding="utf-8") as fh:
            reader = csv.DictReader(fh, delimiter=delimiter)
            for row in reader:
                # `or ""`: DictReader yields None for fields missing from short rows.
                raw_fname = (row.get("filename") or "").strip().strip('"')
                raw_meter = (row.get("meter") or "").strip().strip('"')
                if not raw_fname or not raw_meter:
                    continue
                # Parse soft labels: "3:0.9,4:0.8" or hard: "3,4" or "3"
                meters = []
                for part in raw_meter.split(","):
                    token = part.split(":", 1)[0].strip()
                    try:
                        m = int(token)
                    except ValueError:
                        # e.g. trailing comma ("3,") or a stray "3/4"
                        continue
                    if m in valid:
                        meters.append(m)
                if not meters:
                    continue
                audio_path = resolve_audio_path(raw_fname, extra_dir)
                if audio_path:
                    entries.append((audio_path, meters))
    return entries


def _audio_cache_key(path):
    """Return a SHA-1 hex digest identifying one version of an audio file.

    The key combines the resolved path, file size, modification time (ns) and
    the target sample rate, so it changes whenever any of them changes.
    """
    info = path.stat()
    fields = (path.resolve(), info.st_size, info.st_mtime_ns, MERT_SR)
    fingerprint = "::".join(str(field) for field in fields)
    return hashlib.sha1(fingerprint.encode("utf-8")).hexdigest()


def _decode_audio_resampled(path):
    """Decode ``path`` to a mono float32 NumPy waveform at ``MERT_SR``.

    torchaudio is tried first (when available), then librosa. If both fail,
    one second of silence is returned so a single bad file cannot stop a run,
    and a ``RuntimeWarning`` naming the file and both errors is emitted so the
    substitution does not go unnoticed.
    """
    import warnings

    torchaudio_error = None
    if HAS_TORCHAUDIO:
        try:
            wave, sr = torchaudio.load(str(path))
            # Down-mix multi-channel audio by averaging the channels.
            if wave.ndim == 2 and wave.shape[0] > 1:
                wave = wave.mean(dim=0, keepdim=True)
            if sr != MERT_SR:
                wave = torchaudio.functional.resample(wave, sr, MERT_SR)
            return wave.squeeze(0).cpu().numpy().astype(np.float32, copy=False)
        except Exception as exc:
            # Remember the reason and fall through to librosa.
            torchaudio_error = exc

    try:
        audio, _ = librosa.load(str(path), sr=MERT_SR, mono=True)
        return audio.astype(np.float32, copy=False)
    except Exception as exc:
        warnings.warn(
            f"Could not decode {path} (torchaudio: {torchaudio_error!r}; librosa: {exc!r}); "
            "substituting 1 s of silence",
            RuntimeWarning,
        )
        return np.zeros(MERT_SR, dtype=np.float32)


class MERTAudioDataset(Dataset):
    """Dataset yielding ``(waveform, label_vector)`` pairs for meter classification.

    ``entries`` is an iterable of ``(audio_path, meters)``; entries whose file
    does not exist or that contain no known meter are dropped. Waveforms are
    float32 NumPy arrays at ``MERT_SR``; labels are float32 vectors with 1.0
    for each of the entry's meters and ``LABEL_SMOOTH_NEG`` elsewhere.

    With ``cache_dir`` set, decoded waveforms are stored as ``.npy`` files
    keyed by ``_audio_cache_key`` and reused on later accesses.

    With ``augment=True`` a random crop (for clips longer than
    ``MAX_DURATION_S``), additive Gaussian noise and a random circular shift
    of up to half a second are applied. All random draws use the *torch* RNG:
    PyTorch seeds it differently in every DataLoader worker, whereas NumPy's
    global RNG is not re-seeded per worker and can therefore produce the same
    "random" numbers in all of them.
    """

    def __init__(self, entries, augment=False, noise_std=0.01, cache_dir=None):
        self.entries = []
        self.cache_paths = []
        self.augment = augment
        self.noise_std = noise_std
        self.cache_dir = Path(cache_dir).resolve() if cache_dir is not None else None
        if self.cache_dir is not None:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        for p, m in entries:
            valid = [x for x in m if x in METER_TO_IDX]
            if not valid or not p.exists():
                continue
            self.entries.append((p, valid))
            if self.cache_dir is not None:
                self.cache_paths.append(self.cache_dir / f"{_audio_cache_key(p)}.npy")
            else:
                self.cache_paths.append(None)

    def __len__(self):
        return len(self.entries)

    @staticmethod
    def _randint(low, high):
        """Return a random int in ``[low, high)`` drawn from the torch RNG."""
        return int(torch.randint(low, high, (1,)).item())

    def __getitem__(self, idx):
        path, meters = self.entries[idx]
        cache_path = self.cache_paths[idx]

        label = np.full(len(CLASS_METERS), LABEL_SMOOTH_NEG, dtype=np.float32)
        for m in meters:
            if m in METER_TO_IDX:
                label[METER_TO_IDX[m]] = 1.0

        audio = None
        if cache_path is not None and cache_path.exists():
            try:
                audio = np.load(cache_path, allow_pickle=False).astype(np.float32, copy=False)
            except Exception:
                # Unreadable/corrupt cache entry: decode again below.
                audio = None
        if audio is None:
            audio = _decode_audio_resampled(path)
            # An all-zero waveform is what the decoder returns when decoding
            # failed; never persist that, or the failure becomes permanent.
            if cache_path is not None and audio.any():
                # Write to a unique temp file and rename, so concurrent workers
                # never observe a half-written cache entry.
                tmp_name = f"{cache_path.name}.tmp-{os.getpid()}-{time.time_ns()}"
                tmp_path = cache_path.with_name(tmp_name)
                try:
                    with open(tmp_path, "wb") as fh:
                        np.save(fh, audio.astype(np.float32, copy=False))
                    os.replace(tmp_path, cache_path)
                except Exception:
                    # Caching is best-effort; clean up and continue uncached.
                    try:
                        tmp_path.unlink(missing_ok=True)
                    except Exception:
                        pass

        max_samples = MAX_DURATION_S * MERT_SR
        if len(audio) > max_samples:
            # Random crop when augmenting, centre crop otherwise.
            if self.augment:
                start = self._randint(0, len(audio) - max_samples)
            else:
                start = (len(audio) - max_samples) // 2
            audio = audio[start:start + max_samples]
        if len(audio) < MERT_SR:
            # Zero-pad very short clips to at least one second.
            audio = np.pad(audio, (0, MERT_SR - len(audio)))
        if self.augment:
            noise = torch.randn(len(audio)).numpy()
            audio = audio + self.noise_std * noise
            audio = np.roll(audio, self._randint(-MERT_SR // 2, MERT_SR // 2))
        return audio.astype(np.float32), label


def simple_collate(batch):
    """Collate ``(waveform, label)`` samples without padding the audio.

    Returns ``(waveforms, labels)``: a list of the variable-length waveforms
    and a float32 tensor stacking the label vectors.
    """
    waveforms = [sample[0] for sample in batch]
    label_rows = [sample[1] for sample in batch]
    stacked_labels = torch.tensor(np.stack(label_rows), dtype=torch.float32)
    return waveforms, stacked_labels


class MERTClassificationHead(nn.Module):
    """Classifier over per-layer pooled features.

    The input is fused across the layer axis (dim 1) with learned
    softmax weights (initialised uniform), then passed through
    LayerNorm -> Linear -> GELU -> Dropout -> Linear to produce logits.
    """

    def __init__(self, num_layers, pooled_dim, num_classes=6, head_dim=256, dropout=0.4):
        super().__init__()
        self.num_layers = num_layers
        # One logit per layer; zeros give equal weights after softmax.
        self.layer_logits = nn.Parameter(torch.zeros(num_layers))
        mlp_layers = [
            nn.LayerNorm(pooled_dim),
            nn.Linear(pooled_dim, head_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_dim, num_classes),
        ]
        self.head = nn.Sequential(*mlp_layers)

    def forward(self, stacked):
        """Fuse ``stacked`` over dim 1 with the layer weights and return class logits."""
        layer_weights = torch.softmax(self.layer_logits, dim=0)
        layer_weights = layer_weights.view(1, self.num_layers, 1)
        weighted = stacked * layer_weights
        return self.head(weighted.sum(dim=1))


def mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim, chunk_batch_size=1):
    """Run the model over each audio in chunks and pool hidden states per layer.

    Each waveform is split into consecutive ``CHUNK_SAMPLES``-long chunks;
    chunks shorter than ``MERT_SR`` samples are dropped (unless that would
    leave none, in which case the whole waveform is used as one chunk).
    Full-length chunks are batched ``chunk_batch_size`` at a time, shorter
    tail chunks are run one by one.

    For every layer ``li``, ``outputs.hidden_states[li + 1]`` is reduced over
    dim 1 with mean and max; across an audio's chunks the means are averaged
    and the maxes are max-reduced, then both are concatenated.

    Returns one stacked tensor with an entry per audio and per layer
    (presumably of shape (len(audios), num_layers, 2 * hidden), given the
    inline shape comment below — TODO confirm).

    NOTE: ``hidden_dim`` is accepted but not used in this function.
    Re-raises ``RuntimeError`` that is not an out-of-memory error, or an OOM
    that occurs at batch size 1.
    """
    def _is_oom_error(exc):
        # Detect OOM by message text.
        return "out of memory" in str(exc).lower()

    def _run_chunk_items(chunk_items, per_audio_means, per_audio_maxes, batch_size):
        # Forward `chunk_items` ((audio_idx, chunk) pairs) in mini-batches and
        # append the per-layer mean/max vectors to the per-audio accumulators.
        if not chunk_items:
            return

        start = 0
        target_bs = max(1, int(batch_size))
        current_bs = target_bs
        warned_oom_downscale = False

        while start < len(chunk_items):
            bs = min(current_bs, len(chunk_items) - start)
            batch_items = chunk_items[start:start + bs]
            audio_idxs = [aidx for aidx, _ in batch_items]
            chunk_batch = [chunk for _, chunk in batch_items]

            try:
                inputs = processor(
                    chunk_batch,
                    sampling_rate=MERT_SR,
                    return_tensors="pt",
                    return_attention_mask=False,
                    padding=True,
                )
                inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}
                outputs = mert_model(**inputs)
            except RuntimeError as exc:
                # On OOM, halve the batch and retry the same items (`start` is
                # not advanced); give up if already at batch size 1.
                if not _is_oom_error(exc) or bs == 1:
                    raise
                current_bs = max(1, bs // 2)
                if device.type == "cuda":
                    torch.cuda.empty_cache()
                if not warned_oom_downscale:
                    print(f"  WARNING: OOM in chunk batching; reducing chunk batch size to {current_bs}")
                    warned_oom_downscale = True
                continue

            for li in range(num_layers):
                # Index 0 of hidden_states is skipped (li + 1).
                hs = outputs.hidden_states[li + 1]  # (bs, T, hidden)
                means = hs.mean(dim=1)
                maxes = hs.max(dim=1).values
                for bi, audio_idx in enumerate(audio_idxs):
                    per_audio_means[audio_idx][li].append(means[bi])
                    per_audio_maxes[audio_idx][li].append(maxes[bi])

            start += bs
            # After a successful batch, go back to the requested batch size.
            current_bs = target_bs

    # Accumulators indexed as [audio_idx][layer] -> list of per-chunk vectors.
    per_audio_means = [[[] for _ in range(num_layers)] for _ in audios]
    per_audio_maxes = [[[] for _ in range(num_layers)] for _ in audios]

    full_chunks = []
    tail_chunks = []
    for audio_idx, audio_np in enumerate(audios):
        chunks = [audio_np[s:s + CHUNK_SAMPLES] for s in range(0, len(audio_np), CHUNK_SAMPLES)
                  if len(audio_np[s:s + CHUNK_SAMPLES]) >= MERT_SR]
        if not chunks:
            chunks = [audio_np]
        for chunk in chunks:
            if len(chunk) == CHUNK_SAMPLES:
                full_chunks.append((audio_idx, chunk))
            else:
                tail_chunks.append((audio_idx, chunk))

    # Equal-length chunks are batched together; tail chunks go through one at a time.
    _run_chunk_items(full_chunks, per_audio_means, per_audio_maxes, batch_size=max(1, chunk_batch_size))
    _run_chunk_items(tail_chunks, per_audio_means, per_audio_maxes, batch_size=1)

    batch_pooled = []
    for audio_idx in range(len(audios)):
        layer_pooled = []
        for li in range(num_layers):
            # Aggregate across chunks: mean of means, max of maxes.
            mean_agg = torch.stack(per_audio_means[audio_idx][li]).mean(dim=0)
            max_agg = torch.stack(per_audio_maxes[audio_idx][li]).max(dim=0).values
            layer_pooled.append(torch.cat([mean_agg, max_agg]))
        batch_pooled.append(torch.stack(layer_pooled))

    return torch.stack(batch_pooled)



def train_one_epoch(mert_model, head, processor, loader, criterion, optimizer,
                    device, num_layers, hidden_dim, grad_accum=1, use_lora=True,
                    use_amp=False, amp_dtype=torch.float16, scaler=None, chunk_batch_size=1):
    """Train for one pass over ``loader`` with gradient accumulation.

    With ``use_lora`` the backbone is put in train mode and receives
    gradients; otherwise its forward pass runs under ``no_grad`` and only the
    head is trained. Autocast is enabled only when ``use_amp`` is set and the
    device is CUDA; ``scaler`` is used only if it is provided and enabled.
    Gradients are clipped to norm 1.0 before each optimizer step.

    Returns ``(mean_loss, accuracy)`` over the samples seen, where accuracy
    compares the argmax of the logits with the argmax of the label vector.
    """
    head.train()
    if use_lora:
        mert_model.train()
    total_loss = correct = total = 0
    amp_enabled = bool(use_amp and device.type == "cuda")
    use_scaler = bool(scaler is not None and scaler.is_enabled())
    # Parameters subject to gradient clipping: the head plus any backbone
    # parameters that require gradients.
    trainable_params = list(head.parameters()) + [p for p in mert_model.parameters() if p.requires_grad]

    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(loader, desc="Train", leave=False)
    for step, (audios, labels) in enumerate(pbar):
        labels = labels.to(device, non_blocking=True)

        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp_enabled):
            if use_lora:
                pooled = mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim, chunk_batch_size)
            else:
                with torch.no_grad():
                    pooled = mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim, chunk_batch_size)
                pooled = pooled.detach()
            logits = head(pooled)
            # Scale the loss so accumulated gradients average over grad_accum steps.
            loss = criterion(logits, labels) / grad_accum

        if use_scaler:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step every grad_accum batches, and on the last batch of the epoch.
        if (step + 1) % grad_accum == 0 or step == len(loader) - 1:
            if use_scaler:
                # Unscale first so clipping operates on the true gradient values.
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        # Undo the grad_accum scaling and weight by batch size for the running mean.
        total_loss += loss.item() * grad_accum * len(audios)
        preds = logits.argmax(dim=1)
        correct += (preds == labels.argmax(dim=1)).sum().item()
        total += len(audios)
        pbar.set_postfix_str(f"loss={total_loss/total:.3f} acc={correct/total:.0%}")

    return total_loss / max(total, 1), correct / max(total, 1)


@torch.no_grad()
def evaluate(mert_model, head, processor, loader, criterion, device, num_layers, hidden_dim,
             use_amp=False, amp_dtype=torch.float16, chunk_batch_size=1):
    """Evaluate backbone + head on ``loader`` without gradients.

    Returns a 6-tuple:
    ``(mean_loss, accuracy, true_class_indices, predicted_class_indices,
    probabilities, label_vectors)``. The class indices are argmaxes of the
    label vectors / logits; probabilities are ``sigmoid(logits)``; the last
    two items are NumPy arrays concatenated over all batches.
    """
    head.eval()
    mert_model.eval()

    autocast_on = bool(use_amp and device.type == "cuda")
    running_loss = n_correct = n_seen = 0
    true_indices = []
    pred_indices = []
    prob_batches = []
    label_batches = []

    for audios, labels in tqdm(loader, desc="Eval", leave=False):
        batch_size = len(audios)
        labels = labels.to(device, non_blocking=True)
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=autocast_on):
            pooled = mert_forward_pool(audios, mert_model, processor, device, num_layers, hidden_dim, chunk_batch_size)
            logits = head(pooled)
            loss = criterion(logits, labels)

        running_loss += loss.item() * batch_size
        probs = torch.sigmoid(logits)
        predicted = logits.argmax(dim=1)
        target = labels.argmax(dim=1)
        n_correct += (predicted == target).sum().item()
        n_seen += batch_size

        true_indices.extend(target.cpu().tolist())
        pred_indices.extend(predicted.cpu().tolist())
        prob_batches.append(probs.detach().float().cpu().numpy())
        label_batches.append(labels.detach().float().cpu().numpy())

    denom = max(n_seen, 1)
    mean_loss = running_loss / denom
    accuracy = n_correct / denom
    return (mean_loss, accuracy,
            true_indices, pred_indices,
            np.concatenate(prob_batches), np.concatenate(label_batches))


def print_eval_metrics(labels_idx, preds_idx, probs, labels_multihot):
    """Print a per-class and aggregate evaluation report.

    Args:
        labels_idx: Primary true class index per sample (sequence of ints).
        preds_idx: Predicted class index per sample (sequence of ints).
        probs: 2-D array of per-class scores, indexed ``[sample, class]``.
        labels_multihot: 2-D array of label values indexed like ``probs``;
            entries above 0.5 are treated as positives.

    Prints, in order: one row per entry of ``CLASS_METERS`` (primary-class
    accuracy and average precision), a Total row with overall accuracy and
    mAP, the macro-F1 of ``probs > 0.5`` over classes that have positives,
    and the mean/median gap between the two highest scores of each sample.
    Returns ``None``.
    """
    num_classes = len(CLASS_METERS)
    labels_binary = (labels_multihot > 0.5).astype(np.float32)
    # Convert once up front instead of rebuilding the arrays for every class.
    labels_arr = np.asarray(labels_idx)
    preds_arr = np.asarray(preds_idx)

    # Per-class accuracy
    print(f"\n{'Meter':>6s}  {'Correct':>7s}  {'Total':>5s}  {'Acc':>6s}  {'AP':>6s}")
    print("-" * 40)
    total_correct = total_count = 0
    aps = []
    for i, m in enumerate(CLASS_METERS):
        mask = labels_arr == i
        n = int(mask.sum())
        if n == 0:
            print(f"{m:>4d}/x  {'—':>7s}  {0:>5d}  {'—':>6s}  {'—':>6s}")
            continue
        c = int((preds_arr[mask] == i).sum())
        acc = c / n
        total_correct += c
        total_count += n
        # AP is only defined when the class has at least one positive label.
        ap_str = "—"
        if labels_binary[:, i].sum() > 0:
            ap = average_precision_score(labels_binary[:, i], probs[:, i])
            aps.append(ap)
            ap_str = f"{ap:.3f}"
        print(f"{m:>4d}/x  {c:>5d}/{n:<5d}       {acc:>5.1%}  {ap_str:>6s}")
    if total_count:
        # Always show the overall accuracy; only the mAP field degrades to a
        # dash when no per-class AP could be computed.
        map_str = f"{np.mean(aps):.3f}" if aps else "—"
        print("-" * 40)
        print(f"{'Total':>6s}  {total_correct:>5d}/{total_count:<5d}       {total_correct/total_count:>5.1%}  mAP={map_str}")

    # Macro-F1
    preds_binary = (probs > 0.5).astype(np.int32)
    cols = [i for i in range(num_classes) if labels_binary[:, i].sum() > 0]
    if cols:
        mf1 = f1_score(labels_binary[:, cols], preds_binary[:, cols], average='macro', zero_division=0)
        print(f"Macro-F1: {mf1:.3f}")

    # Confidence gap (top-1 minus top-2 score); needs >= 1 sample and >= 2 classes.
    if probs.shape[0] > 0 and probs.shape[1] >= 2:
        sorted_p = np.sort(probs, axis=1)[:, ::-1]
        gap = sorted_p[:, 0] - sorted_p[:, 1]
        print(f"Confidence gap: mean={gap.mean():.3f}, median={np.median(gap):.3f}")
    else:
        print("Confidence gap: n/a (needs at least one sample and two classes)")


def format_per_class_watchdog(labels_idx, preds_idx, probs, labels_multihot):
    """Build a one-line, per-class summary string.

    For every entry of ``CLASS_METERS`` the summary holds the number of
    samples whose primary label is that class, the accuracy on those samples
    and the average precision of the class scores (``—`` when the class has
    no positive label, or no support at all). Segments are joined by `` | ``.

    Args:
        labels_idx: Primary true class index per sample.
        preds_idx: Predicted class index per sample.
        probs: 2-D array of per-class scores, indexed ``[sample, class]``.
        labels_multihot: 2-D label array; entries above 0.5 are positives.

    Returns:
        The formatted summary string.
    """
    true_idx = np.array(labels_idx)
    pred_idx = np.array(preds_idx)
    positives = (labels_multihot > 0.5).astype(np.float32)

    def _describe(cls_pos, meter):
        # Summary for one class; cls_pos is its column index.
        selected = true_idx == cls_pos
        n_support = int(selected.sum())
        if n_support == 0:
            return f"{meter}/x n=0 acc=— ap=—"

        n_hit = int((pred_idx[selected] == cls_pos).sum())
        target_col = positives[:, cls_pos]
        if target_col.sum() > 0:
            ap_text = f"{average_precision_score(target_col, probs[:, cls_pos]):.3f}"
        else:
            ap_text = "—"
        return f"{meter}/x n={n_support} acc={n_hit / n_support:.0%} ap={ap_text}"

    return " | ".join(_describe(i, m) for i, m in enumerate(CLASS_METERS))


# Confirmation message printed once the definitions above have been executed.
print("Training code loaded")


## 6. Setup model & data

In [None]:
# ══════════════════════════════════════════
#  AUTO-DERIVED CONFIG (from top-level settings)
# ══════════════════════════════════════════

# Per-model hyperparameters come from the MODEL_CONFIGS entry for MODEL_NAME.
_cfg = MODEL_CONFIGS[MODEL_NAME]
BATCH_SIZE, GRAD_ACCUM = _cfg["batch_size"], _cfg["grad_accum"]  # 95M: 8 x 4, 330M: 4 x 8 → effective batch = 32
HEAD_LR, LORA_LR = _cfg["head_lr"], _cfg["lora_lr"]              # 95M: 1e-3 / 5e-5, 330M: 5e-4 / 5e-5

# The most recent run name is remembered in a small text file beside the checkpoints.
run_name_state_path = CHECKPOINT_DIR / ".last_run_name.txt"

# Step 1: no explicit RUN_NAME, so optionally reuse the name recorded earlier.
if RUN_NAME is None and AUTO_RESUME_TRACKED_RUN:
    if run_name_state_path.exists():
        tracked = run_name_state_path.read_text(encoding="utf-8").strip()
        if tracked:
            RUN_NAME = tracked
            print(f"Tracked run name: {RUN_NAME}")

# Step 2: still unnamed, so compose model + training mode + hyperparameters +
# timestamp + random suffix into a unique name.
if RUN_NAME is None:
    import secrets
    from datetime import datetime

    _model_short = MODEL_NAME.split("/")[-1].replace("MERT-v1-", "")
    _train_tag = f"lora_r{LORA_RANK}_a{LORA_ALPHA}" if USE_LORA else "frozen"
    _hp_tag = "_".join([
        f"hlr{HEAD_LR:.0e}",
        f"llr{LORA_LR:.0e}",
        f"bs{BATCH_SIZE}",
        f"ga{GRAD_ACCUM}",
        f"n{NOISE_STD:g}",
        f"d{HEAD_DROPOUT:g}",
    ])
    _stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    _nonce = secrets.token_hex(2)
    # Any "+" (e.g. from exponent formatting) is stripped from the final name.
    RUN_NAME = "_".join([_model_short, _train_tag, _hp_tag, _stamp, _nonce]).replace("+", "")
    print(f"Generated run name: {RUN_NAME}")

# Record whichever name is in effect so the next session can pick it up.
run_name_state_path.parent.mkdir(parents=True, exist_ok=True)
run_name_state_path.write_text(RUN_NAME + "\n", encoding="utf-8")

# Latest and best checkpoints share one stem derived from the run name.
_ckpt_stem = f"meter_mert_{RUN_NAME}"
CKPT_PATH = CHECKPOINT_DIR / f"{_ckpt_stem}.pt"
BEST_CKPT_PATH = CHECKPOINT_DIR / f"{_ckpt_stem}.best.pt"

# Prefer the GPU but fall back to CPU. Every check below is written as
# `device.type == "cuda"`, i.e. it already expects CPU as a possible value;
# with a hard-coded "cuda" device those branches were unreachable and a
# runtime without a GPU crashed instead of degrading.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("WARNING: CUDA is not available; falling back to CPU (AMP/TF32 disabled, training will be very slow)")
num_layers = _cfg["num_layers"]
hidden_dim = _cfg["hidden_dim"]
pooled_dim = hidden_dim * 2

# Speed-oriented backend switches, applied only on CUDA when USE_TF32 is set.
if USE_TF32 and device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision("high")

# Validate the requested autocast dtype and downgrade bf16 → fp16 when the GPU lacks bf16.
_amp_mode = str(AMP_DTYPE).strip().lower()
if _amp_mode not in {"bf16", "fp16"}:
    raise ValueError(f"AMP_DTYPE must be 'bf16' or 'fp16', got: {AMP_DTYPE}")
if _amp_mode == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
    print("WARNING: bf16 not supported on this GPU; falling back to fp16")
    _amp_mode = "fp16"
AMP_TORCH_DTYPE = torch.bfloat16 if _amp_mode == "bf16" else torch.float16
AMP_ENABLED = bool(USE_AMP and device.type == "cuda")
# The gradient scaler is only active for fp16 autocast; otherwise it is created disabled.
scaler = torch.cuda.amp.GradScaler(enabled=AMP_ENABLED and AMP_TORCH_DTYPE == torch.float16)

# Chunk batch size: explicit setting wins, otherwise chosen from total GPU memory
# (4 at >= 30 GB, else 2); 1 when running without CUDA.
if CHUNK_BATCH_SIZE > 0:
    CHUNK_BATCH_SIZE_EFFECTIVE = int(CHUNK_BATCH_SIZE)
elif device.type == "cuda":
    _vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    CHUNK_BATCH_SIZE_EFFECTIVE = 4 if _vram_gb >= 30 else 2
else:
    CHUNK_BATCH_SIZE_EFFECTIVE = 1

print(f"Config: {MODEL_NAME}")
print(f"  batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM} (effective={BATCH_SIZE*GRAD_ACCUM})")
print(f"  head_lr={HEAD_LR}, lora_lr={LORA_LR}")
print(f"  noise={NOISE_STD}, dropout={HEAD_DROPOUT}")
print(f"  epochs={EPOCHS}, patience={PATIENCE}")
print(f"  amp={AMP_ENABLED} ({_amp_mode}), tf32={USE_TF32 and device.type == 'cuda'}")
print(f"  chunk_batch={CHUNK_BATCH_SIZE_EFFECTIVE}")

# Load data
print("\nLoading data...")
train_entries = load_split(METER2800_DIR, "train")
val_entries = load_split(METER2800_DIR, "val")
test_entries = load_split(METER2800_DIR, "test")

# WIKIMETER: use local SSD copy if available (faster I/O)
_wiki_dir = WIKIMETER_LOCAL if (WIKIMETER_LOCAL is not None and WIKIMETER_LOCAL_TAB.exists()) else WIKIMETER_DIR
_wiki_tab = _wiki_dir / "data_wikimeter.tab"
if _wiki_tab.exists():
    wiki_all = load_extra_data(_wiki_dir)
    _src = "local SSD" if _wiki_dir == WIKIMETER_LOCAL else "Drive"
    print(f"  WIKIMETER: {len(wiki_all)} segments total (from {_src})")

    # Group segments by song (stem without _segXX)
    from collections import defaultdict
    import random
    random.seed(42)  # reproducible split

    song_segments = defaultdict(list)  # song_stem → [(path, meters), ...]
    for path, meters in wiki_all:
        song_stem = re.sub(r"_seg\d+$", "", path.stem)
        song_segments[song_stem].append((path, meters))

    # Group songs by primary meter for stratified split
    meter_songs = defaultdict(list)  # meter → [song_stem, ...]
    for song_stem, segments in song_segments.items():
        primary_meter = segments[0][1][0]  # first segment's first meter
        meter_songs[primary_meter].append(song_stem)

    # Hold out WIKI_VAL_RATIO of the songs of each meter for val. Splitting by
    # song (not by segment) keeps all segments of a song on the same side.
    wiki_val_songs = set()
    for meter, songs in sorted(meter_songs.items()):
        # Shuffle unconditionally so the seeded RNG advances the same way
        # regardless of the branch taken below.
        random.shuffle(songs)
        if WIKI_VAL_RATIO <= 0 or len(songs) <= 1:
            # A non-positive ratio disables the hold-out entirely (previously
            # max(1, ...) still moved one song per meter to val); a lone song
            # always stays in train.
            n_val = 0
        else:
            n_val = max(1, int(len(songs) * WIKI_VAL_RATIO))
            n_val = min(n_val, len(songs) - 1)  # keep at least one song in train
        wiki_val_songs.update(songs[:n_val])

    # Split segments
    wiki_train, wiki_val = [], []
    for song_stem, segments in song_segments.items():
        if song_stem in wiki_val_songs:
            wiki_val.extend(segments)
        else:
            wiki_train.extend(segments)

    train_entries.extend(wiki_train)
    val_entries.extend(wiki_val)
    print(f"  WIKIMETER split: {len(wiki_train)} train ({len(song_segments) - len(wiki_val_songs)} songs)"
          f" + {len(wiki_val)} val ({len(wiki_val_songs)} songs)")

# Label occurrences over all three splits, counted without first building a
# concatenated copy of the entry lists.
counts = Counter(
    m
    for split_entries in (train_entries, val_entries, test_entries)
    for _, meters in split_entries
    for m in meters
)
print(f"\nTotal: {len(train_entries)} train, {len(val_entries)} val, {len(test_entries)} test")
for m in CLASS_METERS:
    print(f"  {m}/x: {counts.get(m, 0)}")

# Load MERT
print(f"\nLoading {MODEL_NAME}...")
from transformers import AutoModel, Wav2Vec2FeatureExtractor

# Feature extractor and backbone are both fetched by MODEL_NAME; the backbone
# is asked to return all hidden states.
processor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME, trust_remote_code=True)
mert_model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    output_hidden_states=True,
)
mert_model = mert_model.to(device)

if not USE_LORA:
    # Frozen backbone: no MERT parameter receives gradients.
    mert_model.requires_grad_(False)
else:
    # Wrap the backbone with LoRA adapters on the q_proj / v_proj modules.
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
    )
    mert_model = get_peft_model(mert_model, lora_config)
    trainable, total_p = mert_model.get_nb_trainable_parameters()
    print(f"  LoRA: {trainable:,} / {total_p:,} ({trainable/total_p:.2%})")

# Classification head with one output per entry of CLASS_METERS.
head = MERTClassificationHead(num_layers, pooled_dim, len(CLASS_METERS), dropout=HEAD_DROPOUT)
head = head.to(device)

# Datasets
# Waveform cache: resolve and create the directory when AUDIO_CACHE_DIR is set, else disable.
if AUDIO_CACHE_DIR:
    AUDIO_CACHE_PATH = Path(AUDIO_CACHE_DIR).expanduser().resolve()
    AUDIO_CACHE_PATH.mkdir(parents=True, exist_ok=True)
else:
    AUDIO_CACHE_PATH = None
audio_backend = "torchaudio" if HAS_TORCHAUDIO else "librosa"
_cache_desc = "disabled" if AUDIO_CACHE_PATH is None else AUDIO_CACHE_PATH
print(f"Audio: backend={audio_backend}, cache={_cache_desc}")

# Only the training set is augmented (noise); val/test use the dataset defaults.
train_ds = MERTAudioDataset(train_entries, augment=True, noise_std=NOISE_STD, cache_dir=AUDIO_CACHE_PATH)
val_ds = MERTAudioDataset(val_entries, cache_dir=AUDIO_CACHE_PATH)
test_ds = MERTAudioDataset(test_entries, cache_dir=AUDIO_CACHE_PATH)

# Settings shared by all three loaders; negative worker counts are clamped to 0.
loader_workers = max(0, int(NUM_WORKERS))
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    collate_fn=simple_collate,
    num_workers=loader_workers,
    pin_memory=(device.type == "cuda"),
)
if loader_workers:
    # persistent_workers is only passed when worker processes are in use.
    loader_kwargs["persistent_workers"] = True

# Only the training loader shuffles.
train_loader, val_loader, test_loader = (
    DataLoader(_dataset, shuffle=_shuffle, **loader_kwargs)
    for _dataset, _shuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
)
print(f"DataLoader: workers={loader_workers}, pin_memory={loader_kwargs['pin_memory']}, "
      f"persistent_workers={loader_kwargs.get('persistent_workers', False)}")

# Loss with pos_weight
# Count, per class, how many training entries list that meter among their labels.
pos_counts = np.zeros(len(CLASS_METERS), dtype=np.float32)
for _, meters in train_ds.entries:
    for m in meters:
        if m not in METER_TO_IDX:
            continue  # meter has no class index; not counted
        pos_counts[METER_TO_IDX[m]] += 1
neg_counts = len(train_ds) - pos_counts

# pos_weight = negatives / positives per class; classes without any positive keep weight 1.
_pos_weight_np = np.ones_like(pos_counts)
np.divide(neg_counts, pos_counts, out=_pos_weight_np, where=pos_counts > 0)
pos_weights = torch.tensor(_pos_weight_np, dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

# Optimizer
# The head always trains at HEAD_LR; with LoRA, every trainable MERT parameter
# forms a second group at LORA_LR.
param_groups = [{"params": head.parameters(), "lr": HEAD_LR}]
if USE_LORA:
    lora_params = list(filter(lambda p: p.requires_grad, mert_model.parameters()))
    if lora_params:
        param_groups.append({"params": lora_params, "lr": LORA_LR})
optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)

# Scheduler: ReduceLROnPlateau — drops LR only when val stops improving
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

# Resume from checkpoint (auto-detect)
START_EPOCH = 1
best_val_acc = -1.0
if CKPT_PATH.exists():
    print(f"\nResuming from {CKPT_PATH.name}")
    # NOTE: weights_only=False unpickles arbitrary Python objects — only resume
    # from checkpoint files written by this notebook.
    ckpt = torch.load(CKPT_PATH, weights_only=False, map_location=device)

    # Fail early with a readable message when the checkpoint belongs to another
    # backbone (e.g. a tracked run name reused after MODEL_NAME was changed),
    # instead of a shape-mismatch error from load_state_dict further down.
    _ckpt_model = ckpt.get("model_name")
    if _ckpt_model is not None and _ckpt_model != MODEL_NAME:
        raise ValueError(
            f"Checkpoint {CKPT_PATH.name} was trained with {_ckpt_model}, but MODEL_NAME is {MODEL_NAME}. "
            "Set RUN_NAME to a new name (or AUTO_RESUME_TRACKED_RUN = False) to start a fresh run.")
    # The checkpoint stores lora_rank=0 for runs without LoRA; a differing
    # value means the adapters/head were trained under another configuration.
    _expected_rank = LORA_RANK if USE_LORA else 0
    _ckpt_rank = ckpt.get("lora_rank")
    if _ckpt_rank is not None and _ckpt_rank != _expected_rank:
        print(f"WARNING: checkpoint lora_rank={_ckpt_rank} differs from current setting ({_expected_rank})")

    head.load_state_dict(ckpt["head_state_dict"])
    if ckpt.get("lora_state_dict") and USE_LORA:
        # Saved names are dotted parameter paths: walk the module tree to the
        # owning object, then overwrite the parameter's values in place.
        for name, param_data in ckpt["lora_state_dict"].items():
            parts = name.split(".")
            obj = mert_model
            for part in parts[:-1]:
                obj = getattr(obj, part)
            getattr(obj, parts[-1]).data.copy_(param_data.to(device))
    if "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    if "scheduler_state_dict" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
    # Continue with the epoch after the last one recorded in the checkpoint.
    START_EPOCH = ckpt.get("epoch", 0) + 1
    best_val_acc = ckpt.get("val_accuracy", -1.0)
    print(f"  Epoch {START_EPOCH}, best val: {best_val_acc:.1%}")

print(f"\nRun: {RUN_NAME}")
print(f"Ready: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test")

print(f"Checkpoint: {CKPT_PATH}")


In [None]:
# ══════════════════════════════════════════
#  TRAINING LOOP
# ══════════════════════════════════════════

def _atomic_torch_save(obj, path):
    """Save ``obj`` to ``path`` via a temporary sibling file plus rename.

    A crash or disconnect in the middle of ``torch.save`` then leaves the
    previous checkpoint intact instead of a truncated file.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(obj, tmp_path)
    tmp_path.replace(path)


best_val_loss = float("inf")
patience_counter = 0
if START_EPOCH > 1 and CKPT_PATH.exists():
    # Resuming: restore the early-stopping state as well. Without this the
    # patience budget restarted from zero on every resume and the stored
    # "val_loss" was overwritten with inf until the next improvement.
    _resume_meta = torch.load(CKPT_PATH, weights_only=False, map_location="cpu")
    best_val_loss = _resume_meta.get("val_loss", best_val_loss)
    patience_counter = int(_resume_meta.get("patience_counter", 0))
    del _resume_meta

print(f"{'Ep':>3s}  {'TrLoss':>7s}  {'TrAcc':>6s}  {'VaLoss':>7s}  {'VaAcc':>6s}  {'LR':>8s}  {'Time':>5s}")
print("-" * 55)

for epoch in range(START_EPOCH, EPOCHS + 1):
    t0 = time.time()

    train_loss, train_acc = train_one_epoch(
        mert_model, head, processor, train_loader, criterion, optimizer,
        device, num_layers, hidden_dim, GRAD_ACCUM, USE_LORA,
        AMP_ENABLED, AMP_TORCH_DTYPE, scaler, CHUNK_BATCH_SIZE_EFFECTIVE)

    val_loss, val_acc, val_labels_idx, val_preds_idx, val_probs, val_labels_mh = evaluate(
        mert_model, head, processor, val_loader, criterion,
        device, num_layers, hidden_dim, AMP_ENABLED, AMP_TORCH_DTYPE, CHUNK_BATCH_SIZE_EFFECTIVE)

    # The plateau scheduler runs in "max" mode and is driven by validation accuracy.
    scheduler.step(val_acc)
    elapsed = time.time() - t0

    # Track best (by validation accuracy; best_val_loss is the loss at that epoch)
    improved = val_acc > best_val_acc
    if improved:
        best_val_acc = val_acc
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1

    # Save every epoch
    ckpt_data = {
        "head_state_dict": {k: v.cpu().clone() for k, v in head.state_dict().items()},
        "lora_state_dict": {n: p.cpu().clone() for n, p in mert_model.named_parameters() if p.requires_grad} if USE_LORA else None,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "class_map": IDX_TO_METER,
        "model_name": MODEL_NAME,
        "num_layers": num_layers,
        "hidden_dim": hidden_dim,
        "pooled_dim": pooled_dim,
        "head_dim": 256,
        "num_classes": len(CLASS_METERS),
        "dropout": HEAD_DROPOUT,
        "val_accuracy": best_val_acc,
        "val_loss": best_val_loss,
        "patience_counter": patience_counter,  # lets a resumed run continue early stopping
        "epoch": epoch,
        "lora_rank": LORA_RANK if USE_LORA else 0,
        "lora_alpha": LORA_ALPHA if USE_LORA else 0,
        "model_type": "MERTFineTuned",
    }
    _atomic_torch_save(ckpt_data, CKPT_PATH)
    if improved:
        _atomic_torch_save(ckpt_data, BEST_CKPT_PATH)

    current_lr = optimizer.param_groups[0]["lr"]  # first param group = head
    marker = " *" if improved else ""
    print(f"{epoch:3d}  {train_loss:7.4f}  {train_acc:5.1%}  {val_loss:7.4f}  {val_acc:5.1%}  {current_lr:.1e}  {elapsed:4.0f}s{marker}")
    print(f"      Val class: {format_per_class_watchdog(val_labels_idx, val_preds_idx, val_probs, val_labels_mh)}")

    if patience_counter >= PATIENCE:
        print(f"\nEarly stopping at epoch {epoch}")
        break

print(f"\nBest val: {best_val_acc:.1%}, saved to {BEST_CKPT_PATH.name if BEST_CKPT_PATH.exists() else CKPT_PATH.name}")


## 7. Test evaluation

In [None]:
# Load best checkpoint (fallback to latest)
EVAL_CKPT_PATH = BEST_CKPT_PATH if BEST_CKPT_PATH.exists() else CKPT_PATH
if EVAL_CKPT_PATH.exists():
    ckpt = torch.load(EVAL_CKPT_PATH, weights_only=False, map_location=device)
    head.load_state_dict(ckpt["head_state_dict"])
    head = head.to(device)
    if USE_LORA and ckpt.get("lora_state_dict"):
        # Each saved name is a dotted parameter path: resolve the owning
        # object, then copy the stored values into the live parameter.
        for name, param_data in ckpt["lora_state_dict"].items():
            *owner_path, leaf = name.split(".")
            obj = mert_model
            for part in owner_path:
                obj = getattr(obj, part)
            getattr(obj, leaf).data.copy_(param_data.to(device))
    print(f"Loaded: {EVAL_CKPT_PATH.name} (epoch {ckpt.get('epoch', '?')}, val {ckpt.get('val_accuracy', 0):.1%})")

# Score the test split with whatever weights are now in memory.
test_loss, test_acc, test_labels, test_preds, test_probs, test_labels_mh = evaluate(
    mert_model, head, processor, test_loader, criterion,
    device, num_layers, hidden_dim, AMP_ENABLED, AMP_TORCH_DTYPE, CHUNK_BATCH_SIZE_EFFECTIVE)

n_correct = sum(a == b for a, b in zip(test_labels, test_preds))
print(f"\nTest: {n_correct}/{len(test_labels)} = {test_acc:.1%}")
print_eval_metrics(test_labels, test_preds, test_probs, test_labels_mh)

# Save test acc to checkpoints (latest and best, whichever exist)
for _ckpt_file in (CKPT_PATH, BEST_CKPT_PATH):
    if not _ckpt_file.exists():
        continue
    _payload = torch.load(_ckpt_file, weights_only=False, map_location=device)
    _payload["test_accuracy"] = test_acc
    torch.save(_payload, _ckpt_file)

# Release GPU
print("\nReleasing GPU...")
from google.colab import runtime
runtime.unassign()
