(compact baseline, torchvision-free)

This is a small 3D CNN baseline (C3D-lite). It’s not as strong as an official R(2+1)D/ResNet-3D, but it trains fast and avoids extra library constraints. You can later swap in a stronger backbone with the same dataloaders.

### Cell 1 — Root, config, utils

In [1]:
# --- Cell 1: project root, sys.path, config, utils ---
from pathlib import Path
import sys, yaml

# 1) Point to project root and add to sys.path BEFORE any src imports
root = Path("..").resolve()
if str(root) not in sys.path:
    sys.path.insert(0, str(root))
if str(root / "src") not in sys.path:
    sys.path.insert(0, str(root / "src"))
print("PYTHONPATH added:", root)

# 2) Now it's safe to import from src/*
from src.data.wlasl_ds import WLASLDataset
from src.utils.seed import seed_everything
from src.utils.checkpoints import save_checkpoint, load_checkpoint

# 3) Load config & prepare dirs
cfg_path = root / "configs" / "wlasl100.yaml"
assert cfg_path.exists(), f"Config not found: {cfg_path}"
CFG = yaml.safe_load(open(cfg_path, "r"))

CKPT_DIR = root / CFG["paths"]["checkpoints_dir"]
LOG_DIR  = root / CFG["paths"]["logs_dir"]
CKPT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

# 4) Seed
seed_everything(CFG["wlasl"]["split_seed"])


PYTHONPATH added: /home/falasoul/notebooks/USD/AAI-590/Capstone/AAI-590-G3-ASL


42

In [2]:
from src.utils.seed import seed_everything
from src.utils.checkpoints import save_checkpoint, load_checkpoint
print("imports OK; seed:", seed_everything(42))


imports OK; seed: 42


In [6]:
import torch, numpy as np, cv2, decord, random
from torch.utils.data import Dataset, DataLoader
import pandas as pd

decord.bridge.set_bridge("torch")

def _resize_112(frame_tchw: torch.Tensor) -> torch.Tensor:
    # frame_tchw: [T,C,H,W] float32 [0,1]
    T,C,H,W = frame_tchw.shape
    # Use OpenCV for speed; convert to NHWC
    arr = frame_tchw.permute(0,2,3,1).cpu().numpy()  # T,H,W,C
    out = np.empty((T,112,112,C), dtype=np.float32)
    for t in range(T):
        out[t] = cv2.resize(arr[t], (112,112), interpolation=cv2.INTER_AREA)
    out = torch.from_numpy(out).permute(0,3,1,2)  # T,C,112,112
    return out

def _normalize(frame_tchw: torch.Tensor, mean=(0.45,0.45,0.45), std=(0.225,0.225,0.225)) -> torch.Tensor:
    # per-channel normalization
    mean = torch.tensor(mean, dtype=frame_tchw.dtype, device=frame_tchw.device)[None,:,None,None]
    std  = torch.tensor(std,  dtype=frame_tchw.dtype, device=frame_tchw.device)[None,:,None,None]
    return (frame_tchw - mean) / std

def uniform_temporal_indices(n_total, clip_len, stride):
    # Aim to cover as much as possible; for short videos, loop-pad
    if n_total <= 0: return [0]*clip_len
    wanted = (clip_len-1)*stride + 1
    if n_total >= wanted:
        # center-start for consistent coverage
        start = (n_total - wanted)//2
        return [start + i*stride for i in range(clip_len)]
    # not enough frames: repeat last index
    idxs = [min(i*stride, n_total-1) for i in range(clip_len)]
    return idxs

class WLASLDataset(Dataset):
    def __init__(self, df: pd.DataFrame, clip_len=32, stride=2, train=False):
        self.df = df.reset_index(drop=True)
        self.clip_len = clip_len
        self.stride = stride
        self.train = train

    def __len__(self): return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        path = row["path"]
        label = int(row["label"])
        vr = decord.VideoReader(path)
        n = len(vr)

        idxs = uniform_temporal_indices(n, self.clip_len, self.stride)
        batch = vr.get_batch(idxs)  # [T,H,W,C] uint8
        # to float [0,1], TCHW
        x = batch.float()/255.0
        x = x.permute(0,3,1,2)
        # spatial resize 112x112
        x = _resize_112(x)
        # normalize
        x = _normalize(x)
        return x, label, path


####  Cell 2 — Load manifest & build DataLoaders

In [7]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from collections import Counter

# === Load dataset manifest (created in 02_preprocess_segments.ipynb) ===
MANIFEST = root / "data" / "metadata" / "wlasl100_manifest.csv"
m = pd.read_csv(MANIFEST)
print("Loaded manifest:", MANIFEST)
print("Total samples:", len(m))
print("Splits:", dict(m["split"].value_counts()))

# === Split subsets ===
train_df = m[m["split"] == "train"].copy()
val_df   = m[m["split"] == "val"].copy()
test_df  = m[m["split"] == "test"].copy()

# === Read config values ===
clip_len   = CFG["model"]["clip_len"]
frame_step = CFG["model"]["frame_stride"]
bs         = CFG["train"]["batch_size"]
nw         = CFG["train"]["num_workers"]

# === Import the dataset class (from 03_dataset_preview.ipynb or src/data/wlasl_ds.py) ===
# If you have the Dataset defined in the preview notebook, just re-run that cell before this.
# Otherwise, place it in `src/data/wlasl_ds.py` and import as shown:
# from src.data.wlasl_ds import WLASLDataset

# === Create train/val/test datasets ===
train_ds = WLASLDataset(train_df, clip_len=clip_len, stride=frame_step, train=True)
val_ds   = WLASLDataset(val_df,   clip_len=clip_len, stride=frame_step, train=False)
test_ds  = WLASLDataset(test_df,  clip_len=clip_len, stride=frame_step, train=False)

# === Handle class imbalance via WeightedRandomSampler ===
counts = train_df["label"].value_counts().to_dict()
weights = train_df["label"].map(lambda y: 1.0 / max(1, counts[y])).values
sampler = WeightedRandomSampler(
    torch.tensor(weights, dtype=torch.double),
    num_samples=len(train_df),
    replacement=True
)

# === Build DataLoaders ===
train_loader = DataLoader(
    train_ds, batch_size=bs, sampler=sampler,
    num_workers=nw, pin_memory=True
)
val_loader = DataLoader(
    val_ds, batch_size=bs, shuffle=False,
    num_workers=nw, pin_memory=True
)
test_loader = DataLoader(
    test_ds, batch_size=bs, shuffle=False,
    num_workers=nw, pin_memory=True
)

# === Confirm stats ===
num_classes = m["label"].nunique()
print(f"Classes: {num_classes}")
print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)} | Test batches: {len(test_loader)}")


Loaded manifest: /home/falasoul/notebooks/USD/AAI-590/Capstone/AAI-590-G3-ASL/data/metadata/wlasl100_manifest.csv
Total samples: 752
Splits: {'train': np.int64(547), 'val': np.int64(124), 'test': np.int64(81)}
Classes: 100
Train batches: 69 | Val batches: 16 | Test batches: 11


### Cell 3 (Notebook) — Small 3D CNN, AMP-ready, compile-ready

In [8]:
import torch
import torch.nn as nn

# (Optional) Slightly faster matmul on Ada/Lovelace
torch.set_float32_matmul_precision('high')

class C3Dlite(nn.Module):
    """
    A compact 3D CNN that trains fast on WLASL100 clips (112x112, T=32).
    Input expected as [B, T, C, H, W]; we permute internally to [B, C, T, H, W].
    """
    def __init__(self, num_classes=100, drop=0.5):
        super().__init__()
        def block(cin, cout, pool_t=2):
            return nn.Sequential(
                nn.Conv3d(cin, cout, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(cout),
                nn.ReLU(inplace=True),
                nn.MaxPool3d(kernel_size=(pool_t,2,2), stride=(pool_t,2,2))
            )
        self.stem = nn.Sequential(
            nn.Conv3d(3, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True)
        )
        self.b1 = block(32,  64, pool_t=2)
        self.b2 = block(64, 128, pool_t=2)
        self.b3 = block(128, 256, pool_t=2)
        self.b4 = block(256, 256, pool_t=2)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Dropout(p=drop),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):           # x: [B,T,C,H,W]
        x = x.permute(0,2,1,3,4)    # -> [B,C,T,H,W]
        # ✅ make 5D tensor channels-last for 3D convs (NDHWC)
        x = x.contiguous(memory_format=torch.channels_last_3d)
        x = self.stem(x)
        x = self.b1(x); x = self.b2(x); x = self.b3(x); x = self.b4(x)
        x = self.head(x)
        return x

num_classes = m["label"].nunique()
model = C3Dlite(num_classes=num_classes).cuda()

# Speed/memory hints
#model = model.to(memory_format=torch.channels_last)  # helps on Ada/Lovelace
use_compile = True
if use_compile:
    try:
        model = torch.compile(model)  # PyTorch 2.7.1 present in your env
        print("torch.compile enabled")
    except Exception as e:
        print("compile skipped:", e)

print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")


torch.compile enabled
Params: 2.96M


#### Cell 4 (Notebook) — Train loop with AMP, checkpoints, resume

In [9]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler       # ✅ new AMP API

from src.utils.seed import seed_everything
from src.utils.checkpoints import save_checkpoint, load_checkpoint

# reproducibility
seed_everything(CFG["wlasl"]["split_seed"])

# --- helpers to safely cast config values ---
def as_float(x, default):
    try:
        return float(x)
    except Exception:
        return default

def as_int(x, default):
    try:
        return int(x)
    except Exception:
        return default

def as_bool(x, default):
    if isinstance(x, bool):
        return x
    if isinstance(x, str):
        return x.strip().lower() in {"1","true","yes","y","on"}
    return bool(x) if x is not None else default

# --- parse training parameters from YAML ---
epochs   = as_int(CFG["train"].get("epochs", 12), 12)
lr       = as_float(CFG["train"].get("lr", 3e-4), 3e-4)
wd       = as_float(CFG["train"].get("weight_decay", 0.01), 0.01)
amp_on   = as_bool(CFG["train"].get("amp", True), True)
grad_acc = as_int(CFG["train"].get("grad_accum_steps", 1), 1)

print(f"epochs={epochs} (int)  lr={lr} (float)  wd={wd} (float)  "
      f"amp_on={amp_on} (bool)  grad_acc={grad_acc} (int)")

# --- optimizer + scaler ---
opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
scaler = GradScaler("cuda", enabled=amp_on)     # ✅ new API

best_val_top1 = -1.0
start_epoch = 0

# optional resume checkpoint
resume_path = CFG["train"].get("resume_checkpoint", None)
if resume_path:
    rp = root / resume_path
    if rp.exists():
        start_epoch, best_val_top1 = load_checkpoint(str(rp), model, opt, scaler)
        print(f"Resumed from {rp} at epoch {start_epoch}, best={best_val_top1:.3f}")

# --- metric helper ---
def topk_acc(logits, target, k=1):
    with torch.no_grad():
        pred = logits.topk(k, dim=1).indices
        return (pred.eq(target[:, None]).any(dim=1).float().mean().item())

# --- one epoch loop ---
def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    total_loss, total_top1, total_n = 0.0, 0.0, 0
    opt.zero_grad(set_to_none=True)

    for step, (x, y, _) in enumerate(loader):
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)

        with autocast("cuda", enabled=amp_on):   # ✅ updated syntax
            logits = model(x)
            loss = F.cross_entropy(logits, y) / grad_acc

        if train:
            scaler.scale(loss).backward()
            if (step + 1) % grad_acc == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

        with torch.no_grad():
            top1 = topk_acc(logits.detach(), y, k=1)
            bs = x.size(0)
            total_loss += (loss.item() * grad_acc) * bs
            total_top1 += top1 * bs
            total_n += bs

    return total_loss / total_n, total_top1 / total_n

# --- training loop ---
for epoch in range(start_epoch, epochs):
    tr_loss, tr_top1 = run_epoch(train_loader, train=True)
    va_loss, va_top1 = run_epoch(val_loader, train=False)

    print(f"Epoch {epoch+1:03d}/{epochs} | "
          f"train loss {tr_loss:.4f} top1 {tr_top1:.3f} | "
          f"val loss {va_loss:.4f} top1 {va_top1:.3f}")

    # save checkpoints
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": opt.state_dict(),
        "scaler_state": scaler.state_dict(),
        "best_metric": best_val_top1,
    }
    save_checkpoint(state, is_best=False, ckpt_dir=str(CKPT_DIR), filename="last.pt")

    if va_top1 > best_val_top1:
        best_val_top1 = va_top1
        save_checkpoint(state, is_best=True, ckpt_dir=str(CKPT_DIR), filename="best.pt")


epochs=30 (int)  lr=0.001 (float)  wd=0.0001 (float)  amp_on=True (bool)  grad_acc=1 (int)
Epoch 001/30 | train loss 4.8943 top1 0.015 | val loss 4.8427 top1 0.024
Epoch 002/30 | train loss 4.7149 top1 0.015 | val loss 4.8105 top1 0.000
Epoch 003/30 | train loss 4.6237 top1 0.018 | val loss 4.7979 top1 0.016
Epoch 004/30 | train loss 4.5790 top1 0.027 | val loss 4.6666 top1 0.008
Epoch 005/30 | train loss 4.5973 top1 0.024 | val loss 4.8655 top1 0.008
Epoch 006/30 | train loss 4.6130 top1 0.024 | val loss 4.6790 top1 0.024
Epoch 007/30 | train loss 4.5344 top1 0.022 | val loss 4.6636 top1 0.016
Epoch 008/30 | train loss 4.4970 top1 0.024 | val loss 4.6514 top1 0.016
Epoch 009/30 | train loss 4.5345 top1 0.015 | val loss 4.7258 top1 0.000
Epoch 010/30 | train loss 4.5081 top1 0.024 | val loss 4.6124 top1 0.024
Epoch 011/30 | train loss 4.4786 top1 0.035 | val loss 4.6409 top1 0.016
Epoch 012/30 | train loss 4.4210 top1 0.022 | val loss 4.5000 top1 0.040
Epoch 013/30 | train loss 4.4085 

#### Cell 5 (Notebook) — Evaluate best checkpoint on test split

In [10]:
# Load best.pt and evaluate on test_loader (AMP-friendly eval)
from torch.amp import autocast  # new API

best_path = CKPT_DIR / "best.pt"
if best_path.exists():
    _, _ = load_checkpoint(str(best_path), model)  # weights only
else:
    print("best.pt not found, evaluating with current weights.")

model.eval()
test_loss, test_top1, n = 0.0, 0.0, 0

with torch.no_grad():
    for x, y, _ in test_loader:
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)

        # AMP in eval for speed
        with autocast("cuda", enabled=True):
            logits = model(x)
            loss = F.cross_entropy(logits, y)

        bs = x.size(0)
        test_loss += loss.item() * bs
        test_top1 += (logits.topk(1, dim=1).indices.squeeze(1) == y).float().sum().item()
        n += bs

if n > 0:
    print(f"[TEST] loss {test_loss/n:.4f} top1 {test_top1/n:.3f}")
else:
    print("[TEST] loader is empty; no samples to evaluate.")


[TEST] loss 4.7616 top1 0.025


Early epochs with very low val top-1 (0–2%) usually means one of these: label mismatch, BN instability (small batches), learning rate/schedule not helping, or input/normalization quirks.

### A) Verify label space & splits

In [11]:
# Run in the training notebook
import pandas as pd
m = pd.read_csv(root / "data" / "metadata" / "wlasl100_manifest.csv")
print("classes:", m["label"].nunique())
print("train classes:", m[m.split=="train"]["label"].nunique())
print("val classes:",   m[m.split=="val"]["label"].nunique())
print("test classes:",  m[m.split=="test"]["label"].nunique())
missing_in_train = set(m[m.split=="val"]["label"].unique()) - set(m[m.split=="train"]["label"].unique())
print("val-only classes (should be empty or small):", sorted(list(missing_in_train))[:10])
assert m["label"].min()==0 and m["label"].max()==99


classes: 100
train classes: 100
val classes: 82
test classes: 65
val-only classes (should be empty or small): []


### B) Overfit a tiny subset (should climb fast if pipeline is healthy)

In [12]:
# Build a tiny train subset
tiny_idx = train_df.sample(200, random_state=0).index
tiny_ds = WLASLDataset(train_df.loc[tiny_idx], clip_len=clip_len, stride=frame_step, train=True)
tiny_loader = DataLoader(tiny_ds, batch_size=bs, shuffle=True, num_workers=0)

# One-epoch loop on tiny set
model.train()
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
from torch.amp import autocast, GradScaler
scaler = GradScaler("cuda", enabled=True)

for e in range(5):
    tot, acc, n = 0.0, 0.0, 0
    for x,y,_ in tiny_loader:
        x=x.cuda(); y=y.cuda()
        opt.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=True):
            logits = model(x)
            loss = F.cross_entropy(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
        # metrics
        with torch.no_grad():
            pred = logits.argmax(1)
            bs = y.size(0); n += bs
            tot += loss.item()*bs
            acc += (pred==y).float().sum().item()
    print(f"TINY epoch {e+1} loss {tot/n:.4f} acc {acc/n:.3f}")


TINY epoch 1 loss 5.0827 acc 0.005
TINY epoch 2 loss 4.7057 acc 0.015
TINY epoch 3 loss 4.6524 acc 0.015
TINY epoch 4 loss 4.5761 acc 0.010
TINY epoch 5 loss 4.5739 acc 0.010


##### since we are not over fitting, we will be doing some sanity and label diagnostic check. 

1) Sanity on labels vs logits

In [13]:
import pandas as pd, torch

# Manifest label space
m = pd.read_csv(root / "data" / "metadata" / "wlasl100_manifest.csv")
print("label dtype:", m["label"].dtype, "min/max:", m["label"].min(), m["label"].max(), "unique:", m["label"].nunique())

# Model output size
dummy_x, dummy_y, _ = next(iter(train_loader))
with torch.no_grad():
    logits = model(dummy_x.cuda())
print("logits shape:", logits.shape, "min/max label in this batch:", int(dummy_y.min()), int(dummy_y.max()))
assert logits.shape[1] == m["label"].nunique(), "num_classes mismatch!"
assert int(m["label"].min()) == 0 and int(m["label"].max()) == logits.shape[1]-1, "labels must be 0..num_classes-1"


label dtype: int64 min/max: 0 99 unique: 100
logits shape: torch.Size([8, 100]) min/max label in this batch: 29 85


2) Check for NaNs / exploding activations

In [14]:
import torch.nn.functional as F

x, y, _ = next(iter(train_loader))
x = x.cuda(); y = y.cuda()

model.eval()
with torch.no_grad():
    z = model(x)
    print("logits stats:", z.float().mean().item(), z.float().std().item(),
          "any_nan:", torch.isnan(z).any().item())
    print("loss:", F.cross_entropy(z, y).item())


logits stats: -11.715280532836914 6.308036804199219 any_nan: False
loss: 5.692608833312988


3) BatchNorm vs small batches

BatchNorm can be unstable with small, variable batches (video lengths, sampler). Two quick experiments:

(a) Freeze BN (uses running stats only)

In [16]:
def set_bn_eval(m):
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.eval()
model.apply(set_bn_eval)


OptimizedModule(
  (_orig_mod): C3Dlite(
    (stem): Sequential(
      (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (b1): Sequential(
      (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (b2): Sequential(
      (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (b3): S

In [18]:
# Build a tiny train subset
tiny_idx = train_df.sample(200, random_state=0).index
tiny_ds  = WLASLDataset(train_df.loc[tiny_idx], clip_len=clip_len, stride=frame_step, train=True)
tiny_loader = DataLoader(tiny_ds, batch_size=bs, shuffle=True, num_workers=0)


# One-epoch loop on tiny set
model.train()
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
from torch.amp import autocast, GradScaler
scaler = GradScaler("cuda", enabled=True)

for e in range(5):
    tot, acc, n = 0.0, 0.0, 0
    for x,y,_ in tiny_loader:
        x=x.cuda(); y=y.cuda()
        opt.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=True):
            logits = model(x)
            loss = F.cross_entropy(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
        # metrics
        with torch.no_grad():
            pred = logits.argmax(1)
            bs = y.size(0); n += bs
            tot += loss.item()*bs
            acc += (pred==y).float().sum().item()
    print(f"TINY epoch {e+1} loss {tot/n:.4f} acc {acc/n:.3f}")


TINY epoch 1 loss 4.4329 acc 0.045
TINY epoch 2 loss 4.3162 acc 0.045
TINY epoch 3 loss 4.2662 acc 0.055
TINY epoch 4 loss 4.3038 acc 0.055
TINY epoch 5 loss 4.0960 acc 0.065


1) Sanity check: labels vs. logits (must be contiguous 0..99)

In [19]:
import pandas as pd, torch, torch.nn.functional as F

m = pd.read_csv(root / "data" / "metadata" / "wlasl100_manifest.csv")
print("labels unique:", m["label"].nunique(), "min/max:", m["label"].min(), m["label"].max())

xb, yb, _ = next(iter(train_loader))
xb = xb.cuda(); yb = yb.cuda()
with torch.no_grad():
    z = model(xb)  # logits
print("logits shape:", tuple(z.shape), "| batch label min/max:", int(yb.min()), int(yb.max()))
print("any NaN logits?:", torch.isnan(z).any().item(), "loss:", F.cross_entropy(z, yb).item())

# Hard assertions
assert z.shape[1] == m["label"].nunique(), "num_classes mismatch"
assert int(m["label"].min()) == 0 and int(m["label"].max()) == z.shape[1]-1, "labels must be 0..num_classes-1"


labels unique: 100 min/max: 0 99
logits shape: (8, 100) | batch label min/max: 7 81
any NaN logits?: False loss: 5.221561431884766


Great—that output tells us a lot:

Labels are contiguous 0..99 ✅

Model outputs 100 logits ✅

No NaNs ✅

So the pipeline is sane. The remaining culprit is almost certainly BatchNorm instability with small/variable video batches + early LR dynamics.

1) Do the one-batch overfit (must pass)

Run this as-is. If the model can’t overfit a single batch to ~100% within ~200 steps, we still have a hidden issue; if it can, it’s just training dynamics.

In [20]:
# One-batch overfit (should approach ~1.0 acc)
from torch.amp import autocast, GradScaler
import torch.optim as optim, torch.nn.functional as F

model.train()
xb, yb, _ = next(iter(train_loader))
xb = xb.cuda(); yb = yb.cuda()

opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scaler = GradScaler("cuda", enabled=True)

for t in range(200):
    opt.zero_grad(set_to_none=True)
    with autocast("cuda", enabled=True):
        logits = model(xb)
        loss = F.cross_entropy(logits, yb)
    scaler.scale(loss).backward()
    scaler.step(opt); scaler.update()

    if (t+1) % 20 == 0:
        acc = (logits.argmax(1) == yb).float().mean().item()
        print(f"step {t+1:03d} | loss {loss.item():.3f} | acc {acc:.3f}")


step 020 | loss 3.979 | acc 0.250
step 040 | loss 2.976 | acc 0.375
step 060 | loss 2.160 | acc 0.625
step 080 | loss 1.521 | acc 0.750
step 100 | loss 1.232 | acc 0.625
step 120 | loss 0.899 | acc 0.750
step 140 | loss 0.668 | acc 0.875
step 160 | loss 0.335 | acc 0.875
step 180 | loss 0.214 | acc 1.000
step 200 | loss 0.031 | acc 1.000


2) Make normalization video-friendly (small but helpful)

In [21]:
def _normalize(x, mean=(0.432, 0.394, 0.376), std=(0.228, 0.221, 0.223)):
    mean = torch.tensor(mean, dtype=x.dtype, device=x.device)[None,:,None,None]
    std  = torch.tensor(std,  dtype=x.dtype, device=x.device)[None,:,None,None]
    return (x - mean) / std


3) Replace BatchNorm3d with GroupNorm (robust for small batches)

In [22]:
import torch
import torch.nn as nn

def gn(c): return nn.GroupNorm(num_groups=8, num_channels=c)

class C3DliteGN(nn.Module):
    def __init__(self, num_classes=100, drop=0.5):
        super().__init__()
        def block(cin, cout, pool_t=2):
            return nn.Sequential(
                nn.Conv3d(cin, cout, kernel_size=3, padding=1, bias=False),
                gn(cout), nn.ReLU(inplace=True),
                nn.MaxPool3d(kernel_size=(pool_t,2,2), stride=(pool_t,2,2))
            )
        self.stem = nn.Sequential(
            nn.Conv3d(3, 32, kernel_size=3, padding=1, bias=False),
            gn(32), nn.ReLU(inplace=True)
        )
        self.b1 = block(32,  64)
        self.b2 = block(64, 128)
        self.b3 = block(128, 256)
        self.b4 = block(256, 256)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool3d(1), nn.Flatten(),
            nn.Dropout(p=drop), nn.Linear(256, num_classes)
        )

    def forward(self, x):              # x: [B,T,C,H,W]
        x = x.permute(0,2,1,3,4).contiguous()  # [B,C,T,H,W]
        x = self.stem(x)
        x = self.b1(x); x = self.b2(x); x = self.b3(x); x = self.b4(x)
        return self.head(x)

num_classes = m["label"].nunique()
model = C3DliteGN(num_classes=num_classes).cuda()
print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")


Params: 2.96M


4) Add warmup + cosine schedule (prevents early stalls)

In [24]:
from torch.optim.lr_scheduler import CosineAnnealingLR

opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
scaler = GradScaler("cuda", enabled=amp_on)
warmup_epochs = 2
sched = CosineAnnealingLR(opt, T_max=max(1, epochs - warmup_epochs), eta_min=lr*0.1)

def run_epoch(loader, train=True, epoch=0):
    model.train() if train else model.eval()
    total_loss=total_top1=total_n=0.0
    opt.zero_grad(set_to_none=True)
    for step, (x,y,_) in enumerate(loader):
        x=x.cuda(non_blocking=True); y=y.cuda(non_blocking=True)
        with autocast("cuda", enabled=amp_on):
            logits = model(x)
            loss = F.cross_entropy(logits, y) / grad_acc
        if train:
            scaler.scale(loss).backward()
            if (step+1) % grad_acc == 0:
                scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True)
        with torch.no_grad():
            bs = x.size(0); total_n += bs
            total_loss += (loss.item()*grad_acc)*bs
            total_top1 += (logits.argmax(1)==y).float().sum().item()
    if train:
        if epoch < warmup_epochs:
            for g in opt.param_groups:
                g["lr"] = lr * float(epoch + 1) / warmup_epochs
        else:
            sched.step()
    return total_loss/total_n, total_top1/total_n

# --- training loop ---
for epoch in range(start_epoch, epochs):
    tr_loss, tr_top1 = run_epoch(train_loader, train=True,  epoch=epoch)
    va_loss, va_top1 = run_epoch(val_loader,   train=False, epoch=epoch)

    print(f"Epoch {epoch+1:03d}/{epochs} | "
          f"train loss {tr_loss:.4f} top1 {tr_top1:.3f} | "
          f"val loss {va_loss:.4f} top1 {va_top1:.3f}")

    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optim_state": opt.state_dict(),
        "scaler_state": scaler.state_dict(),
        "best_metric": best_val_top1,
    }

    # always save rolling "last.pt"
    save_checkpoint(state, is_best=False, ckpt_dir=str(CKPT_DIR), filename="last.pt")

    # update "best.pt" if improved (let save_checkpoint copy last->best)
    if va_top1 > best_val_top1:
        best_val_top1 = va_top1
        save_checkpoint(state, is_best=True, ckpt_dir=str(CKPT_DIR), filename="last.pt")



Epoch 001/30 | train loss 4.8558 top1 0.011 | val loss 4.7708 top1 0.008
Epoch 002/30 | train loss 4.7311 top1 0.013 | val loss 4.6861 top1 0.008
Epoch 003/30 | train loss 4.7015 top1 0.016 | val loss 4.6427 top1 0.008
Epoch 004/30 | train loss 4.6561 top1 0.013 | val loss 4.6425 top1 0.008
Epoch 005/30 | train loss 4.6802 top1 0.011 | val loss 4.6015 top1 0.016
Epoch 006/30 | train loss 4.6378 top1 0.011 | val loss 4.6174 top1 0.000
Epoch 007/30 | train loss 4.6476 top1 0.009 | val loss 4.6177 top1 0.000
Epoch 008/30 | train loss 4.6480 top1 0.009 | val loss 4.6186 top1 0.016
Epoch 009/30 | train loss 4.6171 top1 0.009 | val loss 4.6275 top1 0.016
Epoch 010/30 | train loss 4.6390 top1 0.009 | val loss 4.6112 top1 0.016
Epoch 011/30 | train loss 4.6401 top1 0.011 | val loss 4.6106 top1 0.008
Epoch 012/30 | train loss 4.6275 top1 0.013 | val loss 4.6050 top1 0.008
Epoch 013/30 | train loss 4.6067 top1 0.020 | val loss 4.6238 top1 0.008
Epoch 014/30 | train loss 4.6326 top1 0.018 | val l