In [None]:
import sys
from pathlib import Path


PROJECT_ROOT = Path.cwd()
sys.path.append(str(PROJECT_ROOT))

PYTHONWARNINGS="ignore:In 2.9.*:UserWarning:torchaudio._backend.utils,ignore:'pin_memory'.*:UserWarning:torch.utils.data.dataloader"

current_dir = Path.cwd()
project_root = current_dir.parent.parent.parent
data_root = str(project_root / "src" / "datasets" / "esc50" / "data")

print(f"Current directory: {current_dir}")
print(f"Project root: {project_root}")
print(f"Data root: {data_root}")

In [None]:
import os, torch, torch.nn as nn
from types import SimpleNamespace

from src.datasets.esc50.esc50_dataset2 import make_esc50_loaders
from src.models.v2.build_model import BlockConfig, build_model
from src.utils.common import set_seed, device_auto, amp_autocast, count_params
from src.utils.logging import TB

from src.train_utils.ema import EMA
from src.train_utils.early_stopping import EarlyStopping
from src.train_utils.loops import train_one_epoch, evaluate_one_epoch

# Optional eval helpers
from src.eval.infer import predict_loader
from src.eval.metrics import confusion_matrix
from src.eval.report import print_basic_report, plot_confusion, print_per_class

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# --- experiment config ---
args = SimpleNamespace(**{
    # Single source of truth for this run: later cells read these as `args.<name>`,
    # and the training loop stores `vars(args)` in the saved checkpoint.

    # ---- data ----
    "data_root": data_root,            # folder containing audio/ and meta/esc50.csv
    "batch": 32,
    "workers": 4,
    "fold_val": 1,
    "feature": "melspec",
    "sample_rate": 16000,
    "n_mels": 128,
    "hop_length": 160,                 # 160 samples @ 16 kHz = 10 ms hop -> ~500 frames
    "n_fft": 512,
    "target_num_frames": 500,
    "augment": True,

    # ---- optimisation / run control ----
    "epochs": 100,
    "lr": 3e-4,
    "wd": 1e-3,
    "amp": True,                       # only takes effect on CUDA / MPS devices (setup cell)
    "seed": 42,
    "save_dir": "./runs/esc50_v2_nb",
    "ema": True,
    "ema_decay": 0.999,
    "label_smoothing": 0.1,
    "warmup_epochs": 5,
    "patience": 20,
    "min_delta": 0.0,

    # ---- model scaffold ----
    "core": "lmu",
    "d_model": 256,
    "depth": 6,
    "dropout": 0.2,
    "mlp_ratio": 2.0,
    "droppath_final": 0.1,
    "layerscale_init": 1e-2,
    "residual_gain": 1.0,
    "pool": "mean",
})

In [None]:
# Seed before the loaders and the model are built so the run is reproducible
# (assumes set_seed covers every RNG in use — see its implementation).
set_seed(args.seed)
device = device_auto()

# Mixed precision is only requested on accelerator backends; CPU runs full precision.
amp = bool(args.amp) and device.type in ("cuda", "mps")
os.makedirs(args.save_dir, exist_ok=True)

loaders = make_esc50_loaders(
    data_root=args.data_root,
    batch_size=args.batch,
    num_workers=args.workers,
    fold_val=args.fold_val,
    feature=args.feature,
    sample_rate=args.sample_rate,
    n_mels=args.n_mels,
    hop_length=args.hop_length,
    n_fft=args.n_fft,
    target_num_frames=args.target_num_frames,
    augment=args.augment,
    download=False,
)

# The factory may hand back two loaders or three; only the first two are used here.
if not (isinstance(loaders, (list, tuple)) and len(loaders) == 3):
    train_loader, val_loader = loaders
else:
    train_loader, val_loader, _ = loaders

print(f"train size: {len(train_loader.dataset)}, validate size: {len(val_loader.dataset)}, device={device}")

In [None]:
# ESC-50 has 50 target classes. With mel-spectrogram features each frame is an
# n_mels-dimensional vector; any other feature type is fed as a single channel
# (presumably the raw waveform — confirm against the dataset code).
n_classes = 50
d_in = args.n_mels if args.feature == "melspec" else 1

block_cfg = BlockConfig(
    kind=args.core,
    memory_size=256,                  # fixed here rather than exposed in `args`
    theta=args.target_num_frames,     # presumably the LMU window, tied to clip length in frames — confirm
    dropout=args.dropout,
    mlp_ratio=args.mlp_ratio,
    droppath_final=args.droppath_final,
    layerscale_init=args.layerscale_init,
    residual_gain=args.residual_gain,
    pool=args.pool,
)

model = build_model(
    d_in=d_in, n_classes=n_classes, d_model=args.d_model, depth=args.depth, block_cfg=block_cfg
).to(device)

print(f"device={device}")
print(f"Params: {count_params(model):,}")

opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd, betas=(0.9, 0.95))

# Epoch-level LR schedule: optional linear warmup from lr*1e-3 up to lr, then cosine
# decay over the remaining epochs. T_max is clamped to >= 1: CosineAnnealingLR divides
# by T_max, so warmup_epochs >= epochs (or epochs == 0) would otherwise raise
# ZeroDivisionError (or produce a nonsensical negative period) once cosine starts.
warmup_epochs = max(0, int(args.warmup_epochs))
if warmup_epochs > 0:
    sch = SequentialLR(
        opt,
        schedulers=[
            LinearLR(opt, start_factor=1e-3, end_factor=1.0, total_iters=warmup_epochs),
            CosineAnnealingLR(opt, T_max=max(1, args.epochs - warmup_epochs)),
        ],
        milestones=[warmup_epochs],
    )
else:
    sch = CosineAnnealingLR(opt, T_max=max(1, args.epochs))

# Loss scaling is only used for CUDA autocast; MPS/CPU run without a scaler.
if amp and device.type == "cuda":
    try:
        scaler = torch.amp.GradScaler(device="cuda", enabled=True)
    except AttributeError:  # older PyTorch without torch.amp.GradScaler
        scaler = torch.cuda.amp.GradScaler(enabled=True)
else:
    scaler = None

ema = EMA(model, decay=args.ema_decay) if args.ema else None
tb = TB(args.save_dir)

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

# label_smoothing=0.0 is CrossEntropyLoss's default, so non-positive values simply
# disable smoothing — same outcome as the previous two-branch construction.
criterion = nn.CrossEntropyLoss(label_smoothing=max(0.0, float(args.label_smoothing)))


In [None]:
# Train for args.epochs epochs, keep the checkpoint with the best validation accuracy,
# and stop early when the EarlyStopping helper says so.
best_acc = float("-inf")  # -inf so the first epoch always writes best.pt for the eval cell
best_path = os.path.join(args.save_dir, "best.pt")

history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
# NOTE(review): the stopper is fed validation *accuracy* (higher is better) —
# confirm EarlyStopping treats larger values as improvements.
stopper = EarlyStopping(patience=args.patience, min_delta=args.min_delta)

try:
    # range(args.epochs), not args.epochs + 1: the LR schedule is sized for args.epochs
    # scheduler steps, so an extra epoch would train at the end-of-cycle LR.
    for ep in range(args.epochs):
        tr = train_one_epoch(model, train_loader, opt, scaler, device, amp, criterion, ema)
        va = evaluate_one_epoch(model, val_loader, device, amp, criterion, ema)

        # Epoch-based schedule: advance exactly once per epoch. Do NOT call opt.step()
        # here to "prime" the scheduler — that may re-apply whatever gradients the last
        # batch left behind (possibly loss-scaled or non-finite under AMP).
        sch.step()

        history["train_loss"].append(tr["loss"]); history["train_acc"].append(tr["acc"])
        history["val_loss"].append(va["loss"]);   history["val_acc"].append(va["acc"])

        tb.scalars(tr, ep, "train/"); tb.scalars(va, ep, "val/")
        tb.scalars({"lr": tr["lr"]}, ep, "train/")

        print(f"Epoch {ep + 1:03d}/{args.epochs} | "
              f"train {tr['loss']:.4f}/{tr['acc']:.4f} | "
              f"val {va['loss']:.4f}/{va['acc']:.4f} | "
              f"t {tr['time_s']:.1f}s/{va['time_s']:.1f}s | "
              f"lr {tr['lr']:.2e}")

        if va["acc"] > best_acc:
            best_acc = va["acc"]
            # NOTE(review): if evaluate_one_epoch scores the EMA weights, this saves the
            # raw weights rather than the ones that produced best_acc — confirm.
            ckpt = {"model": model.state_dict(), "epoch": ep, "val": va, "args": vars(args)}
            torch.save(ckpt, best_path)
            print(f"✅ new best acc {best_acc:.4f}")

        if stopper.step(va["acc"]):
            print(f"⏹ Early stopping (patience={args.patience}, best_acc={stopper.best:.4f}).")
            break
finally:
    # Close the TensorBoard writer even if training is interrupted (e.g. KeyboardInterrupt).
    tb.close()

best_acc

In [None]:
# Load best if desired
# Reload the best checkpoint written by the training loop and evaluate it on the
# validation split.
best_ckpt = torch.load(os.path.join(args.save_dir, "best.pt"), map_location="cpu")
model.load_state_dict(best_ckpt["model"])
model.to(device)
model.eval()  # inference mode (dropout off) regardless of what predict_loader does

# Logits and labels for the whole validation set.
logits, labels = predict_loader(model, val_loader, device, amp_autocast, amp)

# Basic report
print_basic_report(logits, labels)

# Confusion matrix + per-class stats; reuse the class count the model was built with.
cm = confusion_matrix(logits, labels, num_classes=n_classes)
plot_confusion(cm, class_names=None, normalize=True, figsize=(9, 9))
print_per_class(cm, class_names=None, top_k=10)


In [None]:
import matplotlib.pyplot as plt

def plot_train_val(hist, metric, ylabel, title):
    """Plot the train and validation curves for `metric` against 1-based epoch numbers.

    Args:
        hist: history dict with "train_<metric>" and "val_<metric>" lists.
        metric: metric suffix, e.g. "acc" or "loss".
        ylabel: y-axis label.
        title: figure title.
    """
    fig, ax = plt.subplots()
    for split in ("train", "val"):
        key = f"{split}_{metric}"
        values = hist[key]
        ax.plot(range(1, len(values) + 1), values, label=key)
    ax.set(xlabel="epoch", ylabel=ylabel, title=title)
    ax.legend()
    plt.show()


plot_train_val(history, "acc", "accuracy", "Accuracy")
plot_train_val(history, "loss", "loss", "Loss")