In [None]:
from pathlib import Path


In [None]:
# Resolve the ESC-50 data directory relative to where the kernel was started.
current_dir = Path.cwd()
# Two levels up: src/notebooks/ -> src/ -> project root.
project_root = current_dir.parent.parent
data_root = str(project_root.joinpath("src", "datasets", "esc50", "data"))

# Echo the resolved path as the cell output.
data_root

In [None]:
# python
import warnings

PYTHONWARNINGS="ignore:In 2.9.*:UserWarning:torchaudio._backend.utils,ignore:'pin_memory'.*:UserWarning:torch.utils.data.dataloader"

In [None]:
import sys, pathlib
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
from src.utils.common import device_auto, count_params
from src.models.classifier import SeqClassifier
from src.models.lmu_block import LMUMemoryBlock

# Project root used for the import path and for run outputs (edit if needed).
# NOTE(review): an earlier cell derives the project root as cwd.parent.parent, while
# this one uses cwd itself — confirm which of the two is intended.
# NOTE(review): the `src.*` imports above run before this path entry is added, so
# they must already be resolvable without it — confirm.
PROJECT_ROOT = pathlib.Path.cwd()
if str(PROJECT_ROOT) not in sys.path:
    # Guarded so that re-running the cell does not pile up duplicate entries.
    sys.path.append(str(PROJECT_ROOT))

# dataset loader you already have
from src.datasets.esc50.esc50_dataset import make_esc50_loaders
from src.datasets.esc50.esc50_dataset2 import make_esc50_loaders as make_esc50_loaders2
from tqdm import tqdm
from src.utils.logging import Timer
from src.utils.metrics import top1
from src.utils.common import amp_autocast

In [None]:
# %%
# Notebook-wide configuration: the tunable values read by the cells below.
# ---- Data ----
DATA_ROOT            = data_root
FOLD_VAL             = 1                   # 1..5; passed to the loader factory as fold_val
FEATURE              = "melspec"           # or "waveform"; selects the model input width in build_model
SR                   = 16000
N_MELS               = 128
HOP_LENGTH           = 160 # previously 320
N_FFT                = 512 # previously 1024
TARGET_NUM_FRAMES    = 500 # ~5 s @ 16 kHz with hop=160 (was 250 with hop=320)
AUGMENT              = True

BATCH_SIZE           = 32
NUM_WORKERS          = 4

# ---- Training ----
EPOCHS               = 150
LR                   = 3e-4
WEIGHT_DECAY         = 1e-3
USE_AMP              = True    # later switched off automatically when the device is not CUDA
SEED                 = 42

# ---- Model (shared) ----
N_CLASSES            = 50
D_MODEL              = 256
DEPTH                = 4
DROPOUT              = 0.2
MLP_RATIO            = 2.0

# ---- LMU (external pkg) ----
LMU_MEMORY_SIZE      = 256     # "order" Q
LMU_THETA            = TARGET_NUM_FRAMES   # set to sequence length

# ---- Logging / ckpts ----
# The run name encodes model width/depth, LMU memory size and the validation fold.
RUN_NAME             = f"esc50_lmu_d{D_MODEL}x{DEPTH}_Q{LMU_MEMORY_SIZE}_fold{FOLD_VAL}"
SAVE_DIR             = PROJECT_ROOT / "runs" / RUN_NAME
# Created eagerly so the TensorBoard writer and checkpoints have a place to go.
SAVE_DIR.mkdir(parents=True, exist_ok=True)

print("Save dir:", SAVE_DIR)


In [None]:
from src.utils.common import set_seed

# Seed first so that loader construction and augmentation are reproducible.
set_seed(SEED)

# Keyword arguments for the ESC-50 loader factory, gathered in one place.
_loader_kwargs = dict(
    data_root=DATA_ROOT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    fold_val=FOLD_VAL,
    sample_rate=SR,
    n_mels=N_MELS,
    hop_length=HOP_LENGTH,
    normalize="global_cmvn",
    n_fft=N_FFT,
    target_num_frames=TARGET_NUM_FRAMES,
    augment=AUGMENT,
    download=False,
    wav_time_shift_pct=0.1,
    wav_gain_db=3.0,
)
train_loader, val_loader, cmvn_stats = make_esc50_loaders2(**_loader_kwargs)

# Pull a single training batch as a smoke test; its third element is not used here.
xb, yb, _ = next(iter(train_loader))
(xb.shape, yb.shape)  # expect: (B, T, D), (B,)


In [None]:
# Pick the compute device, then restrict mixed precision to CUDA devices.
device = device_auto()
USE_AMP = USE_AMP and device.type == "cuda"
amp = bool(USE_AMP)
# TensorBoard event files go into the run directory.
writer = SummaryWriter(SAVE_DIR.as_posix())

def build_model():
    """Assemble a ``SeqClassifier`` whose blocks are ``LMUMemoryBlock`` instances.

    All hyper-parameters are read from the notebook-level configuration
    constants. The input width is ``N_MELS`` when ``FEATURE`` is ``"melspec"``
    and ``1`` for any other feature setting.

    Returns:
        The freshly constructed (not yet device-placed) classifier.
    """
    def make_block(d_model: int):
        # One LMU block of the requested width, configured from the constants.
        return LMUMemoryBlock(
            d_model,
            memory_size=LMU_MEMORY_SIZE,
            theta=LMU_THETA,
            dropout=DROPOUT,
            mlp_ratio=MLP_RATIO,
        )

    if FEATURE == "melspec":
        input_dim = N_MELS
    else:
        input_dim = 1
    return SeqClassifier(input_dim, N_CLASSES, D_MODEL, DEPTH, make_block)

# Build the network, move it to the selected device and report its size.
model = build_model()
model = model.to(device)
print(f"Params: {count_params(model):,}")

# ==== Optimizer: weight decay = 0 for bias / norm parameters ====
# (Supersedes the earlier single-group AdamW + plain CosineAnnealingLR setup.)
decay, no_decay = [], []
for n, p in model.named_parameters():
    # Exempt from weight decay: any 1-D parameter, anything whose name ends in
    # ".bias", and anything with "norm" in its (lower-cased) name.
    if p.ndim == 1 or n.endswith(".bias") or "norm" in n.lower():
        no_decay.append(p)
    else:
        decay.append(p)
optim_groups = [
    {"params": decay, "weight_decay": WEIGHT_DECAY},
    {"params": no_decay, "weight_decay": 0.0},
]
# Built exactly once; the previous verbatim second copy of this block only
# rebuilt and replaced an identical optimizer.
opt = AdamW(optim_groups, lr=LR, betas=(0.9, 0.95))

# ==== Scheduler: warmup → cosine ====
class WarmupThenCosine:
    """Epoch-level LR schedule: linear warmup, then cosine annealing.

    ``step()`` is meant to be called once after every training epoch. With
    ``W = warmup_epochs`` the learning rate used for (1-based) epoch ``e`` is
    ``base_lr * e / (W + 1)`` for ``e <= W``, exactly ``base_lr`` for epoch
    ``W + 1``, and follows ``CosineAnnealingLR`` from there on.

    The first warmup value is applied in the constructor. Previously the scale
    was only applied inside ``step()``, so the very first epoch ran at the full
    base learning rate before the rate dropped for the second epoch.

    Args:
        optimizer: Optimizer whose ``param_groups[*]["lr"]`` is driven. The
            learning rates present at construction time are taken as the
            per-group base (peak) rates.
        warmup_epochs: Number of epochs trained below the base rate.
        total_epochs: Planned total number of epochs; the cosine phase uses
            ``T_max = total_epochs - warmup_epochs`` (at least 1).
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.cur_epoch = 0  # number of completed step() calls
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]
        self.cosine = None  # created lazily once warmup is over
        # Make the first epoch already use a warmed-down learning rate.
        self._apply_warmup_scale()

    def _apply_warmup_scale(self):
        """Set each group's LR to its base rate scaled for the upcoming epoch."""
        scale = (self.cur_epoch + 1) / (self.warmup_epochs + 1)
        for pg, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            pg["lr"] = base_lr * scale

    def _step_epoch(self):
        """Advance the schedule by one epoch."""
        self.cur_epoch += 1
        if self.cur_epoch <= self.warmup_epochs:
            # The scale reaches exactly 1.0 at cur_epoch == warmup_epochs, so the
            # cosine scheduler below is always created at the base rate.
            self._apply_warmup_scale()
        else:
            if self.cosine is None:
                # max(1, ...) avoids a zero T_max when total_epochs <= warmup_epochs.
                t_max = max(1, self.total_epochs - self.warmup_epochs)
                self.cosine = CosineAnnealingLR(self.optimizer, T_max=t_max)
            self.cosine.step()

    def step(self):
        """Public entry point: call once per finished epoch."""
        self._step_epoch()

# Epoch-level LR schedule; stepped once per epoch by the training loop.
sch_wrap = WarmupThenCosine(opt, warmup_epochs=3, total_epochs=EPOCHS)

# ==== Loss ====
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# ==== AMP scaler (CUDA only) ====
# torch.cuda.amp.GradScaler is deprecated; torch.amp.GradScaler("cuda", ...) is
# the equivalent replacement. With enabled=False it is a transparent pass-through.
scaler = torch.amp.GradScaler("cuda", enabled=amp)
if torch.cuda.is_available():
    # Start peak-memory tracking from zero for this run.
    torch.cuda.reset_peak_memory_stats()

# Best validation accuracy seen so far and where the matching checkpoint lives.
best_acc, best_path = 0.0, SAVE_DIR / "best.pt"


In [None]:
# Maximum global gradient norm applied before every optimizer step.
CLIP_NORM = 1.0

def train_one_epoch(ep):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``opt``, ``criterion``, ``scaler``,
    ``amp`` and ``device``. Gradients are clipped to ``CLIP_NORM`` before each
    optimizer step; on the AMP path they are unscaled first.

    Args:
        ep: Epoch number. Not used inside the function.

    Returns:
        Dict with the sample-weighted mean ``"loss"`` and ``"acc"`` over the
        epoch, and ``"time_s"`` taken from the ``Timer``'s ``dt`` attribute.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    with Timer() as timer:
        for inputs, targets, _ in tqdm(train_loader, leave=False):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with amp_autocast(amp):
                logits = model(inputs)
                loss = criterion(logits, targets)

            if amp and device.type == "cuda":
                # Scaled backward, then unscale so clipping sees true gradient norms.
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
                opt.step()

            batch_size = inputs.size(0)
            loss_sum += loss.item() * batch_size
            acc_sum += top1(logits.detach(), targets) * batch_size
            seen += batch_size

    return {"loss": loss_sum / seen, "acc": acc_sum / seen, "time_s": timer.dt}

@torch.no_grad()
def evaluate(ep):
    """Score the model on ``val_loader`` without tracking gradients.

    Uses the notebook-level ``model``, ``criterion``, ``amp`` and ``device``.

    Args:
        ep: Epoch number. Not used inside the function.

    Returns:
        Dict with the sample-weighted mean ``"loss"`` and ``"acc"`` over the
        validation set, and ``"time_s"`` taken from the ``Timer``'s ``dt``
        attribute.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    with Timer() as timer:
        for inputs, targets, _ in tqdm(val_loader, leave=False):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            with amp_autocast(amp):
                logits = model(inputs)
                loss = criterion(logits, targets)

            batch_size = inputs.size(0)
            loss_sum += loss.item() * batch_size
            acc_sum += top1(logits, targets) * batch_size
            seen += batch_size

    return {"loss": loss_sum / seen, "acc": acc_sum / seen, "time_s": timer.dt}


In [None]:
# def train_one_epoch(ep):
#     model.train_utils()
#     total_loss=total_acc=0.0; n=0
#     with Timer() as t:
#         for xb,yb in tqdm(train_loader, leave=False):
#             xb=xb.to(device,non_blocking=True); yb=yb.to(device,non_blocking=True)
#             opt.zero_grad(set_to_none=True)
#             with amp_autocast(amp):
#                 logits = model(xb)
#                 loss = criterion(logits, yb)
#             if amp:
#                 scaler.scale(loss).backward()
#                 scaler.step(opt); scaler.update()
#             else:
#                 loss.backward(); opt.step()
#             bs = xb.size(0)
#             total_loss += loss.item()*bs
#             total_acc  += top1(logits.detach(), yb)*bs
#             n += bs
#     return {"loss": total_loss/n, "acc": total_acc/n, "time_s": t.dt}
#
# @torch.no_grad()
# def evaluate(ep):
#     model.eval()
#     total_loss=total_acc=0.0; n=0
#     with Timer() as t:
#         for xb,yb in tqdm(val_loader, leave=False):
#             xb=xb.to(device,non_blocking=True); yb=yb.to(device,non_blocking=True)
#             with amp_autocast(amp):
#                 logits = model(xb)
#                 loss = criterion(logits, yb)
#             bs = xb.size(0)
#             total_loss += loss.item()*bs
#             total_acc  += top1(logits, yb)*bs
#             n += bs
#     return {"loss": total_loss/n, "acc": total_acc/n, "time_s": t.dt}

In [None]:
from src.utils.logging import gpu_mem_mb_peak

# Main training loop: train, validate, advance the LR schedule, log to
# TensorBoard and keep the checkpoint with the best validation accuracy.
# The writer is closed in ``finally`` so event files are flushed even if an
# epoch raises or the run is interrupted.
try:
    for ep in range(1, EPOCHS+1):
        tr = train_one_epoch(ep)
        va = evaluate(ep)
        sch_wrap.step()

        # logs
        # NOTE(review): the "train_utils/" tag prefix looks like a rename artifact of
        # "train/" — confirm before changing, since it alters TensorBoard series names.
        writer.add_scalar("train_utils/loss", tr["loss"], ep)
        writer.add_scalar("train_utils/acc",  tr["acc"],  ep)
        writer.add_scalar("train_utils/time_s", tr["time_s"], ep)

        writer.add_scalar("val/loss", va["loss"], ep)
        writer.add_scalar("val/acc",  va["acc"],  ep)
        writer.add_scalar("val/time_s", va["time_s"], ep)

        writer.add_scalar("sys/gpu_mem_mb_peak", gpu_mem_mb_peak(), ep)

        # checkpoints: only the best-by-validation-accuracy model is written; the
        # state dict is materialised only when it is actually going to be saved.
        if va["acc"] > best_acc:
            best_acc = float(va["acc"])
            ckpt = {"model": model.state_dict(), "epoch": ep, "val": va,
                    "config": {
                      "feature": FEATURE, "sr": SR, "n_mels": N_MELS,
                      "d_model": D_MODEL, "depth": DEPTH, "lmu_memory_size": LMU_MEMORY_SIZE,
                    }}
            torch.save(ckpt, best_path)
            print(f"✅ new best acc {best_acc:.4f} @ epoch {ep}")

        print(f"Epoch {ep}/{EPOCHS} | "
              f"train_utils loss {tr['loss']:.4f} acc {tr['acc']:.4f} "
              f"| val loss {va['loss']:.4f} acc {va['acc']:.4f} "
              f"| t_train {tr['time_s']:.2f}s t_val {va['time_s']:.2f}s")
finally:
    writer.close()

print("Best ckpt:", best_path)
