# In [None]:
import math
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader


# -----------------------------
# Utilities
# -----------------------------

def seed_everything(seed: int = 42):
    """
    Seed every RNG used by this script and put cuDNN in deterministic mode.

    Python's ``random``, NumPy's global RNG, torch's CPU generator and all CUDA
    generators receive the same ``seed``; cuDNN autotuning is switched off.
    """
    import random
    import numpy as np

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def count_params(model: nn.Module) -> int:
    """Return the total number of scalar entries across all of ``model``'s parameters."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total


def masked_sparsity(model: nn.Module, masks: Dict[str, torch.Tensor]) -> Tuple[int, int, float]:
    """
    Measure how much of the masked weight tensors is pruned.

    Only parameters whose name contains ``'weight'`` AND that have an entry in
    ``masks`` are counted; sparsity is read from the masks, not the weights.

    Returns:
        (total mask entries, zero mask entries, zero / total — 0.0 when total is 0)
    """
    weight_masks = [
        masks[pname]
        for pname, _ in model.named_parameters()
        if "weight" in pname and pname in masks
    ]
    total = sum(m.numel() for m in weight_masks)
    zero = sum(int((m == 0).sum().item()) for m in weight_masks)
    sp = zero / total if total else 0.0
    return total, zero, sp


# -----------------------------
# CNN Backbone
# -----------------------------

class ConvBlock(nn.Module):
    """
    Conv3x3 (no bias) -> BatchNorm2d -> ReLU -> Dropout2d (identity when dropout == 0).

    padding=1 keeps the spatial size unchanged. The conv has no bias because the
    following BatchNorm supplies its own learnable shift.
    """
    def __init__(self, in_ch: int, out_ch: int, dropout: float = 0.0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.drop = nn.Dropout2d(p=dropout) if dropout > 0 else nn.Identity()

        # He init (matched to the ReLU that follows)
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x, inplace=True)
        x = self.drop(x)
        return x


class SimpleCNN(nn.Module):
    """
    Channels list -> sequence of ConvBlocks. MaxPool(2) after every 2 blocks.
    Global Average Pooling -> Linear head.

    Args:
        in_channels: Channels of the input images.
        num_classes: Output logits per sample.
        channels: Output width of each ConvBlock, in order (at least one entry).
        dropout: Dropout2d probability used inside every ConvBlock.
    """
    def __init__(self, in_channels: int, num_classes: int, channels: List[int], dropout: float = 0.0):
        super().__init__()
        assert len(channels) >= 1, "Provide at least one conv block."

        blocks = []
        c_in = in_channels
        for idx, c_out in enumerate(channels):
            blocks.append(ConvBlock(c_in, c_out, dropout=dropout))
            if (idx + 1) % 2 == 0:
                blocks.append(nn.MaxPool2d(kernel_size=2))
            c_in = c_out
        self.features = nn.Sequential(*blocks)
        self.head = nn.Linear(channels[-1], num_classes)
        nn.init.kaiming_normal_(self.head.weight, nonlinearity='linear')
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        x = self.features(x)
        # GAP head: (B, C, H, W) -> (B, C, 1, 1) -> (B, C).
        # flatten(start_dim=1) drops both trailing singleton dims and, unlike
        # squeeze(), never touches the channel or batch dims (C == 1 or B == 1).
        x = torch.flatten(F.adaptive_avg_pool2d(x, 1), 1)
        return self.head(x)


# -----------------------------
# Neff Pruning
# -----------------------------

@dataclass
class NeffConfig:
    """
    Settings for beta*Neff magnitude pruning and when to apply it.

    The one-shot (``prune_at_epoch``) and periodic (``prune_every``) schedules
    are independent; either, both or neither may be set.

    Raises:
        ValueError: on an unknown ``scope`` or a non-positive ``prune_every``.
    """
    beta: float = 1.0              # keep top floor(beta * Neff)
    scope: str = "tensor"          # "tensor" or "per_out_channel"
    prune_bias: bool = False       # also build masks for Conv2d/Linear biases
    # schedule
    prune_at_epoch: Optional[int] = None  # one-shot prune at the end of this epoch (1-indexed); None disables one-shot
    prune_every: Optional[int] = None     # if set (e.g., 1,2,5), prune every K epochs after warmup_epochs; None disables periodic
    warmup_epochs: int = 1               # epochs before periodic pruning starts (not used by prune_at_epoch)
    verbose: bool = True                 # print sparsity on epochs where pruning happened

    def __post_init__(self):
        # Any other scope string would silently fall back to whole-tensor masks.
        if self.scope not in ("tensor", "per_out_channel"):
            raise ValueError(f"scope must be 'tensor' or 'per_out_channel', got {self.scope!r}")
        # prune_every is used as a modulus; 0 would divide by zero, negatives never match.
        if self.prune_every is not None and self.prune_every < 1:
            raise ValueError(f"prune_every must be a positive integer or None, got {self.prune_every!r}")


def _compute_neff_from_abs(abs_vec: torch.Tensor) -> int:
    """
    abs_vec: flattened absolute values (>=0)
    Returns Neff = floor(1 / sum(p_i^2)), where p = abs_vec / abs_vec.sum() (if sum>0).
    Guarantees at least 1 when sum>0; returns 0 if all zeros.
    """
    sum_abs = abs_vec.sum()
    if sum_abs <= 0:
        return 0
    p = abs_vec / sum_abs
    # numerical guard: use float32 for stability
    neff = int(torch.floor(1.0 / torch.clamp((p * p).sum(), min=1e-12)).item())
    neff = max(neff, 1)
    return neff


def _topk_mask_by_neff(weight: torch.Tensor, beta: float) -> torch.Tensor:
    """
    Build a one-shot {0, 1} keep-mask over the whole tensor.

    Keeps the k largest-magnitude entries, k = clamp(floor(beta * Neff), 1, N),
    with Neff computed from |weight|. An empty tensor yields an all-ones mask.
    The mask has the shape, dtype and device of ``weight``.
    """
    flat_abs = weight.detach().abs().reshape(-1)
    n_elems = flat_abs.numel()
    if not n_elems:
        return torch.ones_like(weight)

    budget = int(math.floor(beta * _compute_neff_from_abs(flat_abs)))
    # An all-zero tensor gives Neff == 0; still keep one (zero-valued) entry.
    k = max(1, min(n_elems, budget))
    keep = torch.topk(flat_abs, k, largest=True, sorted=False).indices

    mask = torch.zeros(n_elems, dtype=weight.dtype, device=weight.device)
    mask[keep] = 1.0
    return mask.view_as(weight)


def _topk_mask_by_neff_per_out_channel(weight: torch.Tensor, beta: float) -> torch.Tensor:
    """
    Per-output-channel Neff pruning.

    Each slice weight[o] gets its own keep-mask holding its k largest magnitudes,
    k = clamp(floor(beta * Neff(|weight[o]|)), 1, numel(weight[o])).
    Conv: weight shape (out, in, kH, kW)
    Linear: weight shape (out, in)  -> "out" treated as per-row
    Any other rank falls back to whole-tensor masking.

    The returned mask is always contiguous (row-major) with weight's shape and
    dtype, so it also works for channels_last / non-contiguous weights.
    """
    w = weight.detach()
    if w.ndim not in (2, 4):
        # Fallback to whole-tensor if unexpected shape
        return _topk_mask_by_neff(w, beta)

    # torch.zeros (not zeros_like) so mask[o].view(-1) is always a valid view;
    # zeros_like would copy non-contiguous strides (e.g. channels_last).
    mask = torch.zeros(w.shape, dtype=w.dtype, device=w.device)
    for o in range(w.shape[0]):
        vec = w[o].reshape(-1).abs()
        n = vec.numel()
        if n == 0:
            continue
        neff = _compute_neff_from_abs(vec)
        k = max(1, min(n, int(math.floor(beta * neff))))
        idx = torch.topk(vec, k, largest=True, sorted=False).indices
        mask[o].view(-1)[idx] = 1.0
    return mask


def build_neff_masks(model: nn.Module, cfg: NeffConfig) -> Dict[str, torch.Tensor]:
    """
    Build masks for Conv2d/Linear weights according to Neff (β·Neff).

    Keys match ``model.named_parameters()`` names, so the dict can be passed
    directly to ``apply_masks_inplace`` / ``masked_sparsity``. Weights are masked
    per output channel when ``cfg.scope == "per_out_channel"`` and per whole tensor
    otherwise. Biases are masked (whole-tensor) only when ``cfg.prune_bias``.
    """
    masks: Dict[str, torch.Tensor] = {}
    with torch.no_grad():
        for name, module in model.named_modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            # named_modules() reports the root module as "", and named_parameters()
            # then names its tensors "weight"/"bias" with no leading dot.
            prefix = f"{name}." if name else ""
            if module.weight is not None:
                if cfg.scope == "per_out_channel":
                    mask = _topk_mask_by_neff_per_out_channel(module.weight, cfg.beta)
                else:
                    mask = _topk_mask_by_neff(module.weight, cfg.beta)
                masks[prefix + "weight"] = mask.to(module.weight.dtype)
            if cfg.prune_bias and module.bias is not None:
                # Bias magnitudes are generally tiny; pruning them rarely helps. Still, we support it.
                mask_b = _topk_mask_by_neff(module.bias, cfg.beta)
                masks[prefix + "bias"] = mask_b.to(module.bias.dtype)
    return masks


def apply_masks_inplace(model: nn.Module, masks: Dict[str, torch.Tensor]):
    """
    Multiply, in place and outside autograd, every parameter named in ``masks``
    by its mask. Parameters without a mask and mask names without a matching
    parameter are left alone.
    """
    targets = [(param, masks[pname]) for pname, param in model.named_parameters() if pname in masks]
    with torch.no_grad():
        for param, mask in targets:
            param.mul_(mask)


def attach_gradient_mask_hooks(model: nn.Module, masks: Dict[str, torch.Tensor]) -> list:
    """
    Ensure pruned weights stay zero by zeroing their gradients.

    Registers a tensor hook on every parameter named in ``masks`` that multiplies
    the incoming gradient by that parameter's mask.

    Hooks stay registered until removed, and each hook keeps its mask alive.
    Calling this again (e.g. after re-pruning) stacks a second hook on top of
    the first. Callers should remove the returned handles before re-attaching.

    Returns:
        The ``RemovableHandle`` objects for the hooks registered by this call,
        in ``named_parameters()`` order.
    """
    handles = []
    for name, p in model.named_parameters():
        if name in masks:
            m = masks[name]
            # m bound as a default so each hook captures its own mask (no late binding).
            handles.append(p.register_hook(lambda g, m=m: g * m))
    return handles


# -----------------------------
# Training / Evaluation
# -----------------------------

def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """
    Run ``model`` over ``loader`` in eval mode with gradients disabled.

    Each batch's mean cross-entropy is weighted by its batch size, so the
    returned loss is a per-sample mean.

    Returns:
        (mean loss, accuracy); both are 0.0 when the loader yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    n_seen = 0
    n_correct = 0
    weighted_loss = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            logits = model(inputs)
            batch = inputs.size(0)
            weighted_loss += criterion(logits, targets).item() * batch
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += batch
    denom = max(1, n_seen)
    return weighted_loss / denom, n_correct / denom


def _neff_prune_due(neff_cfg: NeffConfig, epoch: int) -> bool:
    """Return True if the one-shot or the periodic Neff schedule fires at ``epoch`` (1-indexed)."""
    if neff_cfg.prune_at_epoch is not None and epoch == int(neff_cfg.prune_at_epoch):
        return True
    if neff_cfg.prune_every is not None:
        since_warmup = epoch - neff_cfg.warmup_epochs
        return since_warmup > 0 and since_warmup % neff_cfg.prune_every == 0
    return False


def _apply_neff_pruning(
    model: nn.Module,
    neff_cfg: NeffConfig,
    old_handles: list,
) -> Tuple[Dict[str, torch.Tensor], list]:
    """
    Rebuild Neff masks from the current weights, swap the gradient-mask hooks
    and zero the pruned weights.

    Returns:
        (new masks, handles of the newly registered gradient hooks)
    """
    # Drop the previous round's hooks first. Otherwise hooks pile up every
    # prune and each one keeps its (possibly very large) mask alive.
    for handle in old_handles:
        handle.remove()
    masks = build_neff_masks(model, neff_cfg)
    handles = [
        p.register_hook(lambda g, m=masks[pname]: g * m)
        for pname, p in model.named_parameters()
        if pname in masks
    ]
    apply_masks_inplace(model, masks)
    return masks, handles


def train_one_model(
    name: str,
    cfg: Dict,
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    neff_cfg: Optional[NeffConfig] = None,
    weight_decay: float = 0.0,
    grad_clip: Optional[float] = None,
) -> Dict[str, float]:
    """
    Train ``model`` with Adam + cross-entropy, optionally with Neff pruning.

    Args:
        name: Label used in log lines.
        cfg: Reads ``"epochs"`` (default 10) and ``"lr"`` (default 3e-4).
        model: Network to train; moved to ``device``.
        train_loader, val_loader: Iterables of ``(inputs, labels)`` batches.
        device: Device for the model and batches.
        neff_cfg: Pruning settings/schedule; ``None`` disables pruning.
        weight_decay: Passed to Adam.
        grad_clip: If set, max global grad norm for ``clip_grad_norm_``.

    Pruning happens at the end of an epoch (after its training steps, before its
    validation pass), at most once per epoch, whenever either schedule in
    ``neff_cfg`` fires. From then on, pruned weights are kept at zero two ways:
    gradient hooks, and re-applying the masks after every optimizer step.

    Returns:
        Dict with final ``val_acc``, ``val_loss``, ``params``, ``sparsity``
        (fraction of zero mask entries over masked weight tensors; 0.0 if never
        pruned) and ``elapsed_sec``. With ``epochs == 0`` the metrics are the
        initial evaluation.
    """
    model = model.to(device)
    epochs = int(cfg.get("epochs", 10))
    lr = float(cfg.get("lr", 3e-4))

    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # Populated if/when pruning happens.
    masks: Dict[str, torch.Tensor] = {}
    hook_handles: list = []

    # Initial eval
    val_loss, val_acc = evaluate(model, val_loader, device)
    print(f"[{name}] Init  | val_loss={val_loss:.4f}  val_acc={val_acc:.4f}")
    # Reported if the epoch loop never runs (epochs == 0).
    vloss, vacc = val_loss, val_acc

    start_time = time.time()
    for epoch in range(1, epochs + 1):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            # Masked gradients alone are not enough: Adam's moment estimates can
            # still move pruned weights, so re-zero them after every step.
            if masks:
                apply_masks_inplace(model, masks)

        did_prune = neff_cfg is not None and _neff_prune_due(neff_cfg, epoch)
        if did_prune:
            masks, hook_handles = _apply_neff_pruning(model, neff_cfg, hook_handles)

        vloss, vacc = evaluate(model, val_loader, device)
        if did_prune and neff_cfg.verbose:
            total, zero, sp = masked_sparsity(model, masks)
            print(f"[{name}] Epoch {epoch:02d}  val_loss={vloss:.4f}  val_acc={vacc:.4f}  "
                  f"PRUNE: β={neff_cfg.beta:.3f}, scope={neff_cfg.scope}, sparsity={sp:.3%} ({zero}/{total})")
        else:
            print(f"[{name}] Epoch {epoch:02d}  val_loss={vloss:.4f}  val_acc={vacc:.4f}")

    total_params = count_params(model)
    sp = masked_sparsity(model, masks)[2] if masks else 0.0
    elapsed = time.time() - start_time

    print(f"[{name}] Done in {elapsed:.1f}s | params={total_params/1e6:.2f}M | sparsity={sp:.3%}")

    return {
        "val_acc": vacc,
        "val_loss": vloss,
        "params": float(total_params),
        "sparsity": float(sp),
        "elapsed_sec": float(elapsed),
    }


# -----------------------------
# High-level training harness
# -----------------------------

def build_cnn_from_config(
    in_channels: int,
    num_classes: int,
    cfg: Dict,
) -> nn.Module:
    """
    Build a SimpleCNN from a model config dict.

    Channel widths come from ``cfg["hidden_size"]`` or, if that is missing or
    empty, ``cfg["channels"]``. A list/tuple gives one ConvBlock per entry; a
    single int is treated as a one-block network. ``cfg["dropout"]`` defaults
    to 0.0.

    Raises:
        ValueError: if neither key provides a non-empty channel spec.
    """
    channels = cfg.get("hidden_size", None) or cfg.get("channels", None)
    # Reject None and empty sequences here with a clear message, instead of
    # failing later inside SimpleCNN.
    if not channels:
        raise ValueError("Config must include 'hidden_size' (interpreted as conv channels list) or 'channels'.")
    if isinstance(channels, int):
        channels = [channels]
    dropout = float(cfg.get("dropout", 0.0))
    return SimpleCNN(in_channels, num_classes, channels=list(channels), dropout=dropout)


def train_cnn_suite(
    model_configs: Dict[str, Dict],
    train_loader: DataLoader,
    val_loader: DataLoader,
    in_channels: int,
    num_classes: int,
    device: Optional[torch.device] = None,
    use_models: Optional[List[str]] = None,
    neff_cfg: Optional[NeffConfig] = None,
    weight_decay: float = 0.0,
    grad_clip: Optional[float] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Trains a selection of models in model_configs and returns a dict of final metrics per model.

    Args:
        use_models: Names to train, in order. ``None`` trains every entry of
            ``model_configs``; an empty list trains nothing.
        device: Defaults to CUDA when available, else CPU.
        Remaining arguments are forwarded to ``build_cnn_from_config`` /
        ``train_one_model``; every model shares the same ``neff_cfg``.

    Raises:
        KeyError: if any requested name is missing from ``model_configs``. This
            is checked before any training starts.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    names = list(model_configs) if use_models is None else list(use_models)

    # Fail fast: discovering a typo only after hours of training earlier models is costly.
    unknown = [n for n in names if n not in model_configs]
    if unknown:
        raise KeyError(f"Unknown model name(s) {unknown}; available: {list(model_configs)}")

    results: Dict[str, Dict[str, float]] = {}
    for name in names:
        cfg = model_configs[name]
        model = build_cnn_from_config(in_channels, num_classes, cfg)
        print(f"\n=== Training {name} ===")
        print(model)
        results[name] = train_one_model(
            name=name,
            cfg=cfg,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            neff_cfg=neff_cfg,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
        )
    return results


# -----------------------------
# Example config (copy/paste yours or modify)
# -----------------------------

# Named training setups. Per entry:
#   hidden_size -> ConvBlock output widths (MaxPool after every 2nd block)
#   lr / epochs -> Adam learning rate / number of epochs
#   dropout     -> Dropout2d probability inside each ConvBlock (0.0 disables it)
model_configs = {
    # Underfit & brittle
    "Tiny_Underfit": dict(hidden_size=[64], lr=3e-4, epochs=10, dropout=0.0),
    # Deep-narrow (depth sensitivity)
    "Deep_Narrow": dict(hidden_size=[128] * 8, lr=3e-4, epochs=15, dropout=0.2),
    # Well-trained baseline
    "Balanced": dict(hidden_size=[512, 256], lr=3e-4, epochs=15, dropout=0.2),
    # Deep but still robust
    "Balanced_Deep": dict(hidden_size=[512, 256, 128, 64], lr=3e-4, epochs=20, dropout=0.3),
    # Overparameterized (note: very high channel counts may be memory-heavy)
    "Wide": dict(hidden_size=[2048, 1024], lr=1e-3, epochs=30, dropout=0.0),
    # Very overparameterized (optional)
    "Very_Wide": dict(hidden_size=[4096, 2048, 1024, 512], lr=1e-3, epochs=50, dropout=0.0),
}


# -----------------------------
# How you'll call this (you provide loaders)
# -----------------------------
if __name__ == "__main__":
    seed_everything(0)

    # You supply these:
    # train_loader: DataLoader over (image, label)
    # val_loader:   DataLoader over (image, label)
    #
    # Example:
    # train_loader, val_loader = make_your_fashionmnist_loaders(...)
    # OR
    # train_loader, val_loader = make_your_cifar10_loaders(...)

    # Placeholder to show interface (remove after wiring your loaders).
    # Everything below this raise is commented-out example code.
    raise SystemExit(
        "Wire your own train_loader/val_loader, then call train_cnn_suite(...). "
        "See below for the exact call signature."
    )

    # After you create dataloaders, infer shapes (optional)
    # sample_x, sample_y = next(iter(train_loader))
    # in_channels = sample_x.shape[1]
    # num_classes = int(torch.max(sample_y).item()) + 1  # NOTE: one batch may miss the highest label; prefer the known count
    #                                                    # (10 for FMNIST/CIFAR-10)

    # Or set explicitly for each dataset:
    # Fashion-MNIST
    # in_channels, num_classes = 1, 10
    # CIFAR-10
    # in_channels, num_classes = 3, 10

    # Choose five (skip Very_Wide by default)
    # use_models = ['Tiny_Underfit', 'Deep_Narrow', 'Balanced', 'Balanced_Deep', 'Wide']

    # Neff pruning config examples:
    # 1) One-shot β·Neff at the end of epoch 1 (warmup_epochs does not apply to prune_at_epoch)
    # neff_cfg = NeffConfig(beta=1.0, scope="tensor", prune_at_epoch=1, verbose=True)
    #
    # 2) Periodic β·Neff every 2 epochs after 1 warmup epoch, per-out-channel masks
    #    (prunes at the end of epochs 3, 5, 7, ...)
    # neff_cfg = NeffConfig(beta=1.0, scope="per_out_channel", prune_every=2, warmup_epochs=1, verbose=True)
    #
    # Run:
    # results = train_cnn_suite(
    #     model_configs=model_configs,
    #     train_loader=train_loader,
    #     val_loader=val_loader,
    #     in_channels=in_channels,
    #     num_classes=num_classes,
    #     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    #     use_models=use_models,
    #     neff_cfg=neff_cfg,
    #     weight_decay=0.0,
    #     grad_clip=None,
    # )
    # print("Final results:", results)
