In [17]:
# traffic_classifier_resnet.py
import os
import sys
import json
import random
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet18, ResNet18_Weights
import cv2

# --------------------
# Reproducibility
# --------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# --------------------
# Paths
# --------------------
IMAGES_PATH = "../output/traffic_lights_images.npy"
LABELS_PATH = "../output/traffic_lights_label.npy"
SAVE_DIR = "../Models"
CKPT_DIR = os.path.join(SAVE_DIR, "checkpoints")
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

# Checkpoint cadence
SAVE_EVERY = 2  # set 0 to disable periodic epoch_XXX saves

# --------------------
# cuDNN speedups
# --------------------
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

def traffic_light_crop(img: np.ndarray) -> np.ndarray:
    """
    Black out the bottom third, blur and desaturate the left & right quarters.
    Works for (H, W) grayscale or (H, W, C) RGB images.
    Returns a copy with same shape/dtype.
    """
    if img is None or img.ndim not in (2, 3):
        return img

    out = img.copy()
    h, w = out.shape[:2]

    # ---- Center zoom ----
    new_w = max(1, int(w / z))
    new_h = max(1, int(h / z))
    x1 = max(0, (w - new_w) // 2)
    y1 = max(0, (h - new_h) // 2)
    cropped = out[y1:y1 + new_h, x1:x1 + new_w]
    if cropped.size > 0:
        out = cv2.resize(cropped, (w, h))

    # --- Blackout bottom third ---
    out[h * 2 // 3 :, ...] = 0

    # --- Prepare for color manipulation ---
    is_gray = (out.ndim == 2) or (out.shape[2] == 1)
    if is_gray:
        # Expand grayscale to 3 channels for consistent processing
        out_rgb = cv2.cvtColor(out, cv2.COLOR_GRAY2BGR)
    else:
        out_rgb = out

    quarter_w = max(1, w // 4)

    def blur_and_desaturate(region: np.ndarray) -> np.ndarray:
        blurred = cv2.GaussianBlur(region, (21, 21), 0)
        # hsv = cv2.cvtColor(blurred, cv2.COLOR_BGR2HSV)
        # hsv[:, :, 1] = 0
        # desat = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        return blurred
    
    # Left strip
    out_rgb[:, :quarter_w, :] = blur_and_desaturate(out_rgb[:, :quarter_w, :])
    # Right strip
    out_rgb[:, -quarter_w:, :] = blur_and_desaturate(out_rgb[:, -quarter_w:, :])

    # If original was grayscale, convert back
    if is_gray:
        out_final = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2GRAY)
    else:
        out_final = out_rgb

    return out_final

# --------------------
# Load data (NumPy)
# --------------------
X = np.load(IMAGES_PATH)  # (N,H,W) or (N,H,W,C)
y_raw = np.load(LABELS_PATH)  # (N,) in {-1,0,1}

if X.ndim not in (3, 4):
    raise ValueError(f"Expected images with 3 or 4 dims (N,H,W[,C]); got {X.shape}")

# >>> Apply traffic_light_crop to every image BEFORE any channel moves/normalization
X = np.array([traffic_light_crop(x) for x in X], dtype=X.dtype)

# Channel-first
if X.ndim == 3:
    X = np.expand_dims(X, axis=1)             # (N,1,H,W)
else:
    X = np.transpose(X, (0, 3, 1, 2))         # (N,C,H,W)

X = X.astype("float32")
if X.max() > 1.5:
    X /= 255.0

# Labels {-1,0,1} -> {0,1,2}
label_to_index = {-1: 0, 0: 1, 1: 2}
index_to_label = {v: k for k, v in label_to_index.items()}
try:
    y = np.vectorize(label_to_index.__getitem__)(y_raw)
except KeyError as e:
    raise ValueError(f"Unknown label {e}; expected only -1, 0, 1")

num_classes = 3

# --------------------
# Tensors
# --------------------
X_torch = torch.from_numpy(X).contiguous()                    # (N,C,H,W) float32
y_torch = torch.from_numpy(y.astype(np.int64)).contiguous()   # (N,)

# --------------------
# Stratified split (80/20)
# --------------------
X_train_np, X_val_np, y_train_np, y_val_np = train_test_split(
    X_torch.numpy(), y_torch.numpy(),
    test_size=0.2, random_state=SEED, stratify=y_torch.numpy(), shuffle=True
)
X_train = torch.from_numpy(X_train_np).contiguous()
X_val   = torch.from_numpy(X_val_np).contiguous()
y_train = torch.from_numpy(y_train_np).contiguous()
y_val   = torch.from_numpy(y_val_np).contiguous()

# --------------------
# Transforms (RGB, 224, flip/rotate, cutout, ImageNet norm)
# --------------------
def to_rgb_tensor(x: torch.Tensor) -> torch.Tensor:
    # x: (C,H,W) float in [0,1]
    if x.dim() != 3:
        raise ValueError(f"Expected (C,H,W), got {tuple(x.shape)}")
    if x.shape[0] == 1:  # gray -> RGB
        x = x.repeat(3, 1, 1)
    elif x.shape[0] != 3:
        raise ValueError(f"Unsupported channels: {x.shape[0]}, expected 1 or 3.")
    return x

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_transform = T.Compose([
    T.Lambda(to_rgb_tensor),
    T.ToPILImage(),
    T.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10, interpolation=InterpolationMode.BILINEAR, expand=False),
    T.ToTensor(),
    # blackout/cutout (chain for occasional multiple holes)
    T.RandomErasing(p=0.5, scale=(0.02, 0.12), ratio=(0.3, 3.3), value=0.0, inplace=False),
    T.RandomErasing(p=0.25, scale=(0.01, 0.06), ratio=(0.3, 3.3), value=0.0, inplace=False),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = T.Compose([
    T.Lambda(to_rgb_tensor),
    T.ToPILImage(),
    T.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# --------------------
# Dataset wrapper
# --------------------
class AugmentedTensorDataset(TensorDataset):
    def __init__(self, *tensors, transform=None):
        super().__init__(*tensors); self.transform = transform
    def __getitem__(self, index):
        *xs, y = super().__getitem__(index)
        x = xs[0]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

# --------------------
# DataLoaders (Windows/Jupyter safe)
# --------------------
def _workers():
    is_windows = (os.name == "nt")
    is_interactive = ("ipykernel" in sys.modules)
    if is_windows or is_interactive:
        return 0
    c = os.cpu_count() or 2
    return max(min(4, c - 1), 0)

num_workers = _workers()
persistent = True if num_workers > 0 else False

train_ds = AugmentedTensorDataset(X_train, y_train, transform=train_transform)
val_ds   = AugmentedTensorDataset(X_val,   y_val,   transform=val_transform)

train_loader = DataLoader(
    train_ds, batch_size=64, shuffle=True,
    num_workers=num_workers, pin_memory=True, persistent_workers=persistent
)
val_loader = DataLoader(
    val_ds, batch_size=64, shuffle=False,
    num_workers=num_workers, pin_memory=True, persistent_workers=persistent
)

# --------------------
# Model: ResNet-18 wrapped in ResNetClassifier (so keys are "net.*")
# --------------------
class ResNetClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)  # expects 3ch
        in_feat = backbone.fc.in_features
        backbone.fc = nn.Sequential(nn.Dropout(0.3), nn.Linear(in_feat, num_classes))
        self.net = backbone

    def forward(self, x):
        return self.net(x)

# --------------------
# Device & AMP
# --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
print("AMP enabled:", torch.cuda.is_available())

model = ResNetClassifier(num_classes=num_classes).to(device)

# --------------------
# Loss / Optim / Scheduler
# --------------------
ctr = Counter(y_train.tolist())
counts = np.array([ctr.get(i, 0) for i in range(num_classes)], dtype=np.float32)
if np.any(counts == 0):
    class_weights = None
else:
    inv = 1.0 / counts
    class_weights = torch.tensor(inv / inv.sum() * num_classes, dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

# two lrs: backbone vs head
head_params = list(model.net.fc.parameters())
backbone_params = [p for n, p in model.net.named_parameters() if not n.startswith("fc.")]
optimizer = optim.AdamW(
    [
        {"params": backbone_params, "lr": 1e-4},
        {"params": head_params,     "lr": 3e-4},
    ],
    weight_decay=1e-4
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-6
)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# --------------------
# Checkpoint helpers
# --------------------
def _checkpoint_payload(epoch:int, best_val_loss:float, best_val_acc:float):
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "val_loss": float(best_val_loss),
        "val_acc": float(best_val_acc),
        "num_classes": int(num_classes),
        "label_to_index": label_to_index,
        "index_to_label": index_to_label,
        "imagenet_mean": IMAGENET_MEAN,
        "imagenet_std": IMAGENET_STD,
        "seed": SEED,
    }

def save_last(epoch, best_val_loss, best_val_acc):
    path = os.path.join(CKPT_DIR, "last.pth")
    torch.save(_checkpoint_payload(epoch, best_val_loss, best_val_acc), path)
    print(f"[ckpt] Saved last -> {path}")

def save_best(epoch, best_val_loss, best_val_acc):
    path = os.path.join(CKPT_DIR, "best.pth")
    torch.save(_checkpoint_payload(epoch, best_val_loss, best_val_acc), path)
    print(f"[ckpt] ✅ New best (val_loss={best_val_loss:.4f}) -> {path}")

def save_periodic(epoch, best_val_loss, best_val_acc):
    if SAVE_EVERY and (epoch % SAVE_EVERY == 0):
        path = os.path.join(CKPT_DIR, f"epoch_{epoch:03d}.pth")
        torch.save(_checkpoint_payload(epoch, best_val_loss, best_val_acc), path)
        print(f"[ckpt] Saved periodic -> {path}")

# --------------------
# Eval
# --------------------
def evaluate(loader):
    model.eval()
    total, correct, run_loss = 0, 0, 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                logits = model(xb)
                loss = criterion(logits, yb)
            run_loss += loss.item() * yb.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
    return run_loss / max(total, 1), correct / max(total, 1)

# --------------------
# Train
# --------------------
EPOCHS = 15
best_val_acc = 0.0
best_val_loss = float("inf")
best_state = None
es_patience = 6
es_wait = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    run_loss = 0.0; seen = 0

    for xb, yb in train_loader:
        xb = xb.to(device, non_blocking=True); yb = yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(xb)
            loss = criterion(logits, yb)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()
        run_loss += loss.item() * yb.size(0); seen += yb.size(0)

    train_loss = run_loss / max(seen, 1)
    val_loss, val_acc = evaluate(val_loader)

    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr != old_lr:
        print(f"[LR] reduced: {old_lr:.6g} -> {new_lr:.6g}")

    print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")

    # Save "last" every epoch + periodic snapshots
    save_last(epoch, best_val_loss, best_val_acc)
    save_periodic(epoch, best_val_loss, best_val_acc)

    # Track & save best
    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        best_val_acc = val_acc
        best_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_acc": float(val_acc),
            "val_loss": float(val_loss),
            "num_classes": int(num_classes),
        }
        save_best(epoch, best_val_loss, best_val_acc)
        es_wait = 0
    else:
        es_wait += 1
        if es_wait >= es_patience:
            print(f"Early stopping at epoch {epoch} (no val_loss improvement for {es_patience} epochs).")
            break

# --------------------
# Final eval
# --------------------
if best_state is not None:
    model.load_state_dict(best_state["model_state_dict"])
final_val_loss, final_val_acc = evaluate(val_loader)
print(f"Best validation loss:     {best_val_loss:.4f}")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Final validation loss:    {final_val_loss:.4f}")
print(f"Final validation accuracy:{final_val_acc:.4f}")

# --------------------
# Save artifacts (keys start with 'net.' -> matches testing script)
# --------------------
pth_path = os.path.join(SAVE_DIR, "traffic_classifier_state.pth")
torch.save(best_state if best_state is not None else model.state_dict(), pth_path)
print(f"Saved state to: {pth_path}")

# TorchScript (scripted model expects 3x224x224 normalized)
# Build a normalized example via val_transform
ex = val_transform(X_val[:1].squeeze(0) if X_val[:1].ndim == 4 else X_val[:1])
example = ex.unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    scripted = torch.jit.trace(model, example)
ts_path = os.path.join(SAVE_DIR, "traffic_classifier_scripted.pt")
scripted.save(ts_path)
print(f"Saved TorchScript model to: {ts_path}")

# Class mapping & metrics
with open(os.path.join(SAVE_DIR, "class_mapping.json"), "w") as f:
    json.dump(
        {
            "index_to_label": {str(k): int(v) for k, v in index_to_label.items()},
            "label_to_index": {str(k): int(v) for k, v in label_to_index.items()},
            "semantics": {"-1": "no light", "0": "red", "1": "green"}
        },
        f,
        indent=2
    )

with open(os.path.join(SAVE_DIR, "metrics.txt"), "w") as f:
    f.write(f"Final validation accuracy: {final_val_acc:.4f}\n")
    f.write(f"Final validation loss: {final_val_loss:.4f}\n")
    f.write(f"Best validation accuracy: {best_val_acc:.4f}\n")
    f.write(f"Best validation loss: {best_val_loss:.4f}\n")

# --------------------
# Prediction helper (applies same val pipeline)
# --------------------
def predict_labels_numpy(np_batch: np.ndarray, batch_size: int = 128) -> np.ndarray:
    model.eval()
    if np_batch.ndim == 3:
        np_batch = np.expand_dims(np_batch, axis=0)
    if np_batch.shape[-1] in (1,3) and np_batch.shape[1] not in (1,3):
        np_batch = np.transpose(np_batch, (0,3,1,2))
    if np_batch.dtype != np.float32:
        np_batch = np_batch.astype("float32")
    if np_batch.max() > 1.5:
        np_batch /= 255.0

    preds = []
    with torch.no_grad():
        for i in range(0, np_batch.shape[0], batch_size):
            chunk = torch.from_numpy(np_batch[i:i+batch_size])
            batch_t = [val_transform(x) for x in chunk]  # ensure RGB+224+norm
            batch_t = torch.stack(batch_t, dim=0).to(device, non_blocking=True)
            logits = model(batch_t)
            preds.append(logits.argmax(dim=1).cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    return np.vectorize(index_to_label.__getitem__)(preds)


NameError: name 'z' is not defined

In [19]:
# traffic_classifier_resnet.py
import os
import sys
import json
import random
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet18, ResNet18_Weights
import cv2

# --------------------
# Reproducibility
# --------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

# --------------------
# Paths
# --------------------
IMAGES_PATH = "../output/traffic_lights_images.npy"
LABELS_PATH = "../output/traffic_lights_label.npy"
SAVE_DIR = "../Models"
CKPT_DIR = os.path.join(SAVE_DIR, "checkpoints")
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(CKPT_DIR, exist_ok=True)

# Checkpoint cadence
SAVE_EVERY = 2  # set 0 to disable periodic epoch_XXX saves

# --------------------
# cuDNN speedups
# --------------------
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

def traffic_light_crop(img: np.ndarray) -> np.ndarray:
    """
    Black out the bottom third, blur and desaturate the left & right quarters.
    Works for (H, W) grayscale or (H, W, C) RGB images.
    Returns a copy with same shape/dtype.
    """
    if img is None or img.ndim not in (2, 3):
        return img

    out = img.copy()
    h, w = out.shape[:2]

    # ---- Center zoom ----
    z = 1.0 + (20.0 / 100.0)
    new_w = max(1, int(w / z))
    new_h = max(1, int(h / z))
    x1 = max(0, (w - new_w) // 2)
    y1 = max(0, (h - new_h) // 2)
    cropped = out[y1:y1 + new_h, x1:x1 + new_w]
    if cropped.size > 0:
        out = cv2.resize(cropped, (w, h))

    # --- Blackout bottom third ---
    out[h * 2 // 3 :, ...] = 0

    # --- Prepare for color manipulation ---
    is_gray = (out.ndim == 2) or (out.shape[2] == 1)
    if is_gray:
        # Expand grayscale to 3 channels for consistent processing
        out_rgb = cv2.cvtColor(out, cv2.COLOR_GRAY2BGR)
    else:
        out_rgb = out

    quarter_w = max(1, w // 4)

    def blur_and_desaturate(region: np.ndarray) -> np.ndarray:
        # Blur
        blurred = cv2.GaussianBlur(region, (21, 21), 0)
        # # Convert to HSV for saturation adjustment
        # hsv = cv2.cvtColor(blurred, cv2.COLOR_BGR2HSV)
        # hsv[:, :, 1] = 0  # set saturation to zero (pure grayscale)
        # desat = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
        return blurred

    # Left strip
    out_rgb[:, :quarter_w, :] = blur_and_desaturate(out_rgb[:, :quarter_w, :])
    # Right strip
    out_rgb[:, -quarter_w:, :] = blur_and_desaturate(out_rgb[:, -quarter_w:, :])

    # If original was grayscale, convert back
    if is_gray:
        out_final = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2GRAY)
    else:
        out_final = out_rgb

    return out_final

# --------------------
# Load data (NumPy)
# --------------------
X = np.load(IMAGES_PATH)  # (N,H,W) or (N,H,W,C)
y_raw = np.load(LABELS_PATH)  # (N,) in {-1,0,1}

if X.ndim not in (3, 4):
    raise ValueError(f"Expected images with 3 or 4 dims (N,H,W[,C]); got {X.shape}")

# >>> Apply traffic_light_crop to every image BEFORE any channel moves/normalization
X = np.array([traffic_light_crop(x) for x in X], dtype=X.dtype)

# Channel-first
if X.ndim == 3:
    X = np.expand_dims(X, axis=1)             # (N,1,H,W)
else:
    X = np.transpose(X, (0, 3, 1, 2))         # (N,C,H,W)

X = X.astype("float32")
if X.max() > 1.5:
    X /= 255.0

# Labels {-1,0,1} -> {0,1,2}
label_to_index = {-1: 0, 0: 1, 1: 2}
index_to_label = {v: k for k, v in label_to_index.items()}
try:
    y = np.vectorize(label_to_index.__getitem__)(y_raw)
except KeyError as e:
    raise ValueError(f"Unknown label {e}; expected only -1, 0, 1")

num_classes = 3

# --------------------
# Tensors
# --------------------
X_torch = torch.from_numpy(X).contiguous()                    # (N,C,H,W) float32
y_torch = torch.from_numpy(y.astype(np.int64)).contiguous()   # (N,)

# --------------------
# Stratified split (80/20)
# --------------------
X_train_np, X_val_np, y_train_np, y_val_np = train_test_split(
    X_torch.numpy(), y_torch.numpy(),
    test_size=0.2, random_state=SEED, stratify=y_torch.numpy(), shuffle=True
)
X_train = torch.from_numpy(X_train_np).contiguous()
X_val   = torch.from_numpy(X_val_np).contiguous()
y_train = torch.from_numpy(y_train_np).contiguous()
y_val   = torch.from_numpy(y_val_np).contiguous()

# --------------------
# Transforms (RGB, 224, flip/rotate, cutout, ImageNet norm)
# --------------------
def to_rgb_tensor(x: torch.Tensor) -> torch.Tensor:
    # x: (C,H,W) float in [0,1]
    if x.dim() != 3:
        raise ValueError(f"Expected (C,H,W), got {tuple(x.shape)}")
    if x.shape[0] == 1:  # gray -> RGB
        x = x.repeat(3, 1, 1)
    elif x.shape[0] != 3:
        raise ValueError(f"Unsupported channels: {x.shape[0]}, expected 1 or 3.")
    return x

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_transform = T.Compose([
    T.Lambda(to_rgb_tensor),
    T.ToPILImage(),
    T.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10, interpolation=InterpolationMode.BILINEAR, expand=False),
    T.ToTensor(),
    # blackout/cutout (chain for occasional multiple holes)
    T.RandomErasing(p=0.5, scale=(0.02, 0.12), ratio=(0.3, 3.3), value=0.0, inplace=False),
    T.RandomErasing(p=0.25, scale=(0.01, 0.06), ratio=(0.3, 3.3), value=0.0, inplace=False),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = T.Compose([
    T.Lambda(to_rgb_tensor),
    T.ToPILImage(),
    T.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# --------------------
# Dataset wrapper
# --------------------
class AugmentedTensorDataset(TensorDataset):
    def __init__(self, *tensors, transform=None):
        super().__init__(*tensors); self.transform = transform
    def __getitem__(self, index):
        *xs, y = super().__getitem__(index)
        x = xs[0]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

# --------------------
# DataLoaders (Windows/Jupyter safe) + Balanced Sampler
# --------------------
def _workers():
    is_windows = (os.name == "nt")
    is_interactive = ("ipykernel" in sys.modules)
    if is_windows or is_interactive:
        return 0
    c = os.cpu_count() or 2
    return max(min(4, c - 1), 0)

num_workers = _workers()
persistent = True if num_workers > 0 else False

train_ds = AugmentedTensorDataset(X_train, y_train, transform=train_transform)
val_ds   = AugmentedTensorDataset(X_val,   y_val,   transform=val_transform)

# ---- Balanced sampler to equalize exposure per class each epoch ----
class_counts = np.bincount(y_train_np, minlength=num_classes).astype(np.float64)
class_counts[class_counts == 0] = 1.0  # safety
inv_freq = 1.0 / class_counts
sample_weights = inv_freq[y_train_np]  # per-sample weights

train_sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights).float(),
    num_samples=len(y_train_np),  # approx one pass over "balanced" epoch
    replacement=True
)

train_loader = DataLoader(
    train_ds, batch_size=64, shuffle=False,    # shuffle must be False when sampler is set
    sampler=train_sampler,
    num_workers=num_workers, pin_memory=True, persistent_workers=persistent
)
val_loader = DataLoader(
    val_ds, batch_size=64, shuffle=False,
    num_workers=num_workers, pin_memory=True, persistent_workers=persistent
)

# Optional: quick look at sampled mix this epoch
with torch.no_grad():
    idxs = torch.tensor(list(train_sampler))
epoch_counts = np.bincount(y_train_np[idxs.numpy()], minlength=num_classes)
print("Sampled class counts (this epoch):", epoch_counts.tolist(),
      " [order: index 0=-1 (no), 1=0 (red), 2=1 (green)]")

# --------------------
# Model: ResNet-18 wrapped in ResNetClassifier (so keys are "net.*")
# --------------------
class ResNetClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)  # expects 3ch
        in_feat = backbone.fc.in_features
        backbone.fc = nn.Sequential(nn.Dropout(0.3), nn.Linear(in_feat, num_classes))
        self.net = backbone

    def forward(self, x):
        return self.net(x)

# --------------------
# Device & AMP
# --------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
print("AMP enabled:", torch.cuda.is_available())

model = ResNetClassifier(num_classes=num_classes).to(device)

# --------------------
# Loss / Optim / Scheduler
# --------------------
# With balanced sampling, do NOT also weight the loss (avoid double-compensation)
class_weights = None
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

# two lrs: backbone vs head
head_params = list(model.net.fc.parameters())
backbone_params = [p for n, p in model.net.named_parameters() if not n.startswith("fc.")]
optimizer = optim.AdamW(
    [
        {"params": backbone_params, "lr": 1e-4},
        {"params": head_params,     "lr": 3e-4},
    ],
    weight_decay=1e-4
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-6
)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# --------------------
# Checkpoint helpers
# --------------------
def _checkpoint_payload(epoch:int, best_val_loss:float, best_val_acc:float):
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "val_loss": float(best_val_loss),
        "val_acc": float(best_val_acc),
        "num_classes": int(num_classes),
        "label_to_index": label_to_index,
        "index_to_label": index_to_label,
        "imagenet_mean": IMAGENET_MEAN,
        "imagenet_std": IMAGENET_STD,
        "seed": SEED,
    }

def save_last(epoch, best_val_loss, best_val_acc):
    path = os.path.join(CKPT_DIR, "last.pth")
    torch.save(_checkpoint_payload(epoch, best_val_loss, best_val_acc), path)
    print(f"[ckpt] Saved last -> {path}")

def save_best(epoch, best_val_loss, best_val_acc):
    path = os.path.join(CKPT_DIR, "best.pth")
    torch.save(_checkpoint_payload(epoch, best_val_loss, best_val_acc), path)
    print(f"[ckpt] ✅ New best (val_loss={best_val_loss:.4f}) -> {path}")

def save_periodic(epoch, best_val_loss, best_val_acc):
    if SAVE_EVERY and (epoch % SAVE_EVERY == 0):
        path = os.path.join(CKPT_DIR, f"epoch_{epoch:03d}.pth")
        torch.save(_checkpoint_payload(epoch, best_val_loss, best_val_acc), path)
        print(f"[ckpt] Saved periodic -> {path}")

# --------------------
# Eval (returns loss, overall acc, per-class accs, macro acc)
# --------------------
SEMANTICS = { -1: "no light", 0: "red", 1: "green" }

def per_class_metrics(all_targets, all_preds, num_classes):
    cm = confusion_matrix(all_targets, all_preds, labels=list(range(num_classes)))
    per_class_acc = cm.diagonal() / np.maximum(cm.sum(axis=1), 1)
    macro = per_class_acc.mean() if len(per_class_acc) else 0.0
    return per_class_acc, macro

def evaluate(loader):
    model.eval()
    total, correct, run_loss = 0, 0, 0.0
    all_t, all_p = [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                logits = model(xb)
                loss = criterion(logits, yb)
            run_loss += loss.item() * yb.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)
            all_t.append(yb.cpu().numpy()); all_p.append(preds.cpu().numpy())
    all_t = np.concatenate(all_t) if all_t else np.array([])
    all_p = np.concatenate(all_p) if all_p else np.array([])
    per_cls, macro = per_class_metrics(all_t, all_p, num_classes) if all_t.size else (np.zeros(num_classes), 0.0)

    # Pretty per-class print using label semantics
    # index_to_label: {0:-1, 1:0, 2:1}
    readable = []
    for idx in range(num_classes):
        raw_label = index_to_label[idx]
        name = SEMANTICS[raw_label]
        readable.append(f"{name}={per_cls[idx]:.3f}")
    print("Per-class acc:", ", ".join(readable), f"| macro={macro:.3f}")

    return run_loss / max(total, 1), correct / max(total, 1), per_cls, macro

# --------------------
# Train
# --------------------
EPOCHS = 15
best_val_acc = 0.0
best_val_loss = float("inf")
best_state = None
es_patience = 6
es_wait = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    run_loss = 0.0; seen = 0

    for xb, yb in train_loader:
        xb = xb.to(device, non_blocking=True); yb = yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(xb)
            loss = criterion(logits, yb)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()
        run_loss += loss.item() * yb.size(0); seen += yb.size(0)

    train_loss = run_loss / max(seen, 1)
    val_loss, val_acc, val_per_cls, val_macro = evaluate(val_loader)

    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr != old_lr:
        print(f"[LR] reduced: {old_lr:.6g} -> {new_lr:.6g}")

    print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f} | val_macro: {val_macro:.4f}")

    # Save "last" every epoch + periodic snapshots
    save_last(epoch, best_val_loss, best_val_acc)
    save_periodic(epoch, best_val_loss, best_val_acc)

    # Track & save best by val_loss
    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        best_val_acc = val_acc
        best_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_acc": float(val_acc),
            "val_loss": float(val_loss),
            "num_classes": int(num_classes),
        }
        save_best(epoch, best_val_loss, best_val_acc)
        es_wait = 0
    else:
        es_wait += 1
        if es_wait >= es_patience:
            print(f"Early stopping at epoch {epoch} (no val_loss improvement for {es_patience} epochs).")
            break

# --------------------
# Final eval
# --------------------
if best_state is not None:
    model.load_state_dict(best_state["model_state_dict"])
final_val_loss, final_val_acc, final_per_cls, final_macro = evaluate(val_loader)
print(f"Best validation loss:     {best_val_loss:.4f}")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Final validation loss:    {final_val_loss:.4f}")
print(f"Final validation accuracy:{final_val_acc:.4f}")
print("Final per-class acc:",
      f"no-light={final_per_cls[0]:.4f}, red={final_per_cls[1]:.4f}, green={final_per_cls[2]:.4f}",
      f"| macro={final_macro:.4f}")

# --------------------
# Save artifacts (keys start with 'net.' -> matches testing script)
# --------------------
pth_path = os.path.join(SAVE_DIR, "traffic_classifier_state.pth")
torch.save(best_state if best_state is not None else model.state_dict(), pth_path)
print(f"Saved state to: {pth_path}")

# TorchScript (scripted model expects 3x224x224 normalized)
# Build a normalized example via val_transform
ex = val_transform(X_val[:1].squeeze(0) if X_val[:1].ndim == 4 else X_val[:1])
example = ex.unsqueeze(0).to(device)
model.eval()
with torch.no_grad():
    scripted = torch.jit.trace(model, example)
ts_path = os.path.join(SAVE_DIR, "traffic_classifier_scripted.pt")
scripted.save(ts_path)
print(f"Saved TorchScript model to: {ts_path}")

# Class mapping & metrics
with open(os.path.join(SAVE_DIR, "class_mapping.json"), "w") as f:
    json.dump(
        {
            "index_to_label": {str(k): int(v) for k, v in index_to_label.items()},
            "label_to_index": {str(k): int(v) for k, v in label_to_index.items()},
            "semantics": {"-1": "no light", "0": "red", "1": "green"}
        },
        f,
        indent=2
    )

with open(os.path.join(SAVE_DIR, "metrics.txt"), "w") as f:
    f.write(f"Final validation accuracy: {final_val_acc:.4f}\n")
    f.write(f"Final validation loss: {final_val_loss:.4f}\n")
    f.write(f"Best validation accuracy: {best_val_acc:.4f}\n")
    f.write(f"Best validation loss: {best_val_loss:.4f}\n")
    f.write(f"Per-class accuracy (no, red, green): "
            f"{final_per_cls[0]:.4f}, {final_per_cls[1]:.4f}, {final_per_cls[2]:.4f} | macro={final_macro:.4f}\n")

# --------------------
# Prediction helper (applies same val pipeline)
# --------------------
def predict_labels_numpy(np_batch: np.ndarray, batch_size: int = 128) -> np.ndarray:
    model.eval()
    if np_batch.ndim == 3:
        np_batch = np.expand_dims(np_batch, axis=0)
    if np_batch.shape[-1] in (1,3) and np_batch.shape[1] not in (1,3):
        np_batch = np.transpose(np_batch, (0,3,1,2))
    if np_batch.dtype != np.float32:
        np_batch = np_batch.astype("float32")
    if np_batch.max() > 1.5:
        np_batch /= 255.0

    preds = []
    with torch.no_grad():
        for i in range(0, np_batch.shape[0], batch_size):
            chunk = torch.from_numpy(np_batch[i:i+batch_size])
            batch_t = [val_transform(x) for x in chunk]  # ensure RGB+224+norm
            batch_t = torch.stack(batch_t, dim=0).to(device, non_blocking=True)
            logits = model(batch_t)
            preds.append(logits.argmax(dim=1).cpu().numpy())
    preds = np.concatenate(preds, axis=0)
    return np.vectorize(index_to_label.__getitem__)(preds)


Sampled class counts (this epoch): [835, 848, 814]  [order: index 0=-1 (no), 1=0 (red), 2=1 (green)]
CUDA available: True
GPU: NVIDIA GeForce GTX 1060 6GB
AMP enabled: True


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.847, red=0.792, green=0.782 | macro=0.807
Epoch 01/15 | train_loss: 0.8714 | val_loss: 0.5925 | val_acc: 0.8048 | val_macro: 0.8068
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] ✅ New best (val_loss=0.5925) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.952, red=0.903, green=0.852 | macro=0.902
Epoch 02/15 | train_loss: 0.4872 | val_loss: 0.4246 | val_acc: 0.8992 | val_macro: 0.9024
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_002.pth
[ckpt] ✅ New best (val_loss=0.4246) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.958, red=0.903, green=0.913 | macro=0.925
Epoch 03/15 | train_loss: 0.3585 | val_loss: 0.3693 | val_acc: 0.9232 | val_macro: 0.9246
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] ✅ New best (val_loss=0.3693) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.984, red=0.961, green=0.921 | macro=0.956
Epoch 04/15 | train_loss: 0.3129 | val_loss: 0.3038 | val_acc: 0.9536 | val_macro: 0.9556
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_004.pth
[ckpt] ✅ New best (val_loss=0.3038) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.989, red=0.971, green=0.965 | macro=0.975
Epoch 05/15 | train_loss: 0.2936 | val_loss: 0.2619 | val_acc: 0.9744 | val_macro: 0.9752
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] ✅ New best (val_loss=0.2619) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.963, red=0.971, green=0.974 | macro=0.969
Epoch 06/15 | train_loss: 0.2824 | val_loss: 0.2823 | val_acc: 0.9696 | val_macro: 0.9693
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_006.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.968, red=0.957, green=0.956 | macro=0.960
Epoch 07/15 | train_loss: 0.2642 | val_loss: 0.2721 | val_acc: 0.9600 | val_macro: 0.9604
[ckpt] Saved last -> ../Models\checkpoints\last.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=1.000, red=0.976, green=0.948 | macro=0.974
Epoch 08/15 | train_loss: 0.2529 | val_loss: 0.2617 | val_acc: 0.9728 | val_macro: 0.9745
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_008.pth
[ckpt] ✅ New best (val_loss=0.2617) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.995, red=0.971, green=0.991 | macro=0.986
Epoch 09/15 | train_loss: 0.2441 | val_loss: 0.2377 | val_acc: 0.9856 | val_macro: 0.9857
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] ✅ New best (val_loss=0.2377) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=1.000, red=0.971, green=0.983 | macro=0.985
Epoch 10/15 | train_loss: 0.2494 | val_loss: 0.2305 | val_acc: 0.9840 | val_macro: 0.9845
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_010.pth
[ckpt] ✅ New best (val_loss=0.2305) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.995, red=0.990, green=0.991 | macro=0.992
Epoch 11/15 | train_loss: 0.2269 | val_loss: 0.2278 | val_acc: 0.9920 | val_macro: 0.9921
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] ✅ New best (val_loss=0.2278) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=0.974, red=0.990, green=0.987 | macro=0.984
Epoch 12/15 | train_loss: 0.2302 | val_loss: 0.2322 | val_acc: 0.9840 | val_macro: 0.9836
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_012.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=1.000, red=0.981, green=0.978 | macro=0.986
Epoch 13/15 | train_loss: 0.2257 | val_loss: 0.2323 | val_acc: 0.9856 | val_macro: 0.9863
[ckpt] Saved last -> ../Models\checkpoints\last.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=1.000, red=0.981, green=0.978 | macro=0.986
Epoch 14/15 | train_loss: 0.2199 | val_loss: 0.2228 | val_acc: 0.9856 | val_macro: 0.9863
[ckpt] Saved last -> ../Models\checkpoints\last.pth
[ckpt] Saved periodic -> ../Models\checkpoints\epoch_014.pth
[ckpt] ✅ New best (val_loss=0.2228) -> ../Models\checkpoints\best.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=1.000, red=0.976, green=0.974 | macro=0.983
Epoch 15/15 | train_loss: 0.2137 | val_loss: 0.2241 | val_acc: 0.9824 | val_macro: 0.9832
[ckpt] Saved last -> ../Models\checkpoints\last.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):


Per-class acc: no light=1.000, red=0.976, green=0.974 | macro=0.983
Best validation loss:     0.2228
Best validation accuracy: 0.9856
Final validation loss:    0.2241
Final validation accuracy:0.9824
Final per-class acc: no-light=1.0000, red=0.9758, green=0.9738 | macro=0.9832
Saved state to: ../Models\traffic_classifier_state.pth
Saved TorchScript model to: ../Models\traffic_classifier_scripted.pt
