
# ViT-B/16 (224) with Superpixel‑Based Patch Pooling (SPPP)

This notebook implements ViT **without** `timm` and adds **superpixel-based patch pooling** (SPPP) on top of the standard patch embedding: patch tokens are pooled into superpixel tokens before they enter the transformer. It also provides:
- ImageNet loaders
- Optional SLIC pre-processing and visualization
- AMP training loop with cosine schedule
- Top‑1/Top‑5 metrics
- Overhead analysis for SLIC on ImageNet‑1k

> **Note:** `skimage.segmentation.slic` is CPU‑only. For best throughput, precompute superpixels for validation (and optionally for training with deterministic transforms).


In [1]:
# ==========================
# Imports — optimized for fast I/O and tensor-based transforms
# ==========================

# Core utilities
import os, math, json, time, random, hashlib
from pathlib import Path
from typing import Optional, Tuple, List

# Numeric & torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Vision + transforms
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

# Image handling (RGBA-aware)
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True        # prevent hangs on partial PNGs

# Visualization
import matplotlib.pyplot as plt

# Tensor rearrangement helpers
from einops import rearrange

# --- Optional performance flags ---
# Import the submodule explicitly *before* touching its config so this cell
# does not depend on torch lazily loading `torch._dynamo` on attribute access.
import torch._dynamo

torch.backends.cudnn.benchmark = True          # let cuDNN pick fastest kernels
torch.set_float32_matmul_precision('medium')     # torch ≥ 2.0
torch._dynamo.config.capture_scalar_outputs = True
# NOTE(review): with suppress_errors enabled, compile failures are not raised —
# confirm this is intended, since it can hide a silent fallback to eager mode.
torch._dynamo.config.suppress_errors = True

# --- Notes ---
# • No need for skimage.segmentation.slic or mark_boundaries — SLIC is pre-baked.
# • torchvision.io.read_image() only returns the alpha channel when the file is read in UNCHANGED (or RGB_ALPHA) mode; PIL is the alternative RGBA loader.
# • Rest of pipeline remains the same: we’ll split RGB and SLIC channels later.


In [2]:
# Report the PyTorch build and which accelerator (if any) is visible.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if not torch.cuda.is_available():
    print("Running on CPU (this will be slow).")
else:
    print("GPU:", torch.cuda.get_device_name(0))

PyTorch: 2.7.0+cu128
CUDA available: True
GPU: NVIDIA H100 NVL


In [3]:
# ==========================
# Global Configuration (Updated for 4-Channel PNG Dataset)
# ==========================

from pathlib import Path
import torch

class CFG:
    """Global hyper-parameters and paths, used as a plain namespace of class attributes.

    The dataset root and output directory can be overridden without editing the
    notebook by setting the ``IMAGENET_ROOT`` / ``SPPP_OUT_DIR`` environment
    variables before this cell runs; otherwise the original defaults are used.
    """

    # --- Data & model params ---
    img_size: int = 224
    patch_size: int = 16
    in_chans: int = 3                  # ViT still receives only RGB channels
    embed_dim: int = 768               # ViT-B
    depth: int = 12
    num_heads: int = 12
    mlp_ratio: float = 4.0
    num_classes: int = 1000

    # --- SLIC info (for documentation) ---
    # These are no longer *used for generation* — they describe what’s encoded
    # in the 4th channel of each PNG.
    num_superpixels: int = 196         # expected number of regions per image
    compactness: float = 0.1
    slic_sigma: float = 1.0
    pooling_type: str = "mean"         # mean | max (for later region pooling)

    # --- Training params ---
    epochs: int = 10
    batch_size: int = 768
    num_workers: int = 16
    lr: float = 3e-4
    weight_decay: float = 0.01
    drop_rate: float = 0.05
    attn_drop_rate: float = 0.05
    label_smoothing: float = 0.0
    stoch_depth: float = 0.0

    # --- System ---
    # Evaluated once, when the class body runs.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Paths ---
    # Folder containing 4-channel PNGs (RGB + SLIC in alpha)
    imagenet_root: Path = Path(
        os.environ.get("IMAGENET_ROOT", "/home/jovyan/scratch/Imagenet1K/ILSVRC/Data/fused")
    )
    out_dir: Path = Path(os.environ.get("SPPP_OUT_DIR", "./outputs_imageNet1K"))

    # --- Performance options ---
    pin_memory: bool = True
    persistent_workers: bool = True
    prefetch_factor: int = 4
    non_blocking: bool = True
    benchmark: bool = True

# --- Make sure the output directory exists (parents included) ---
CFG.out_dir.mkdir(parents=True, exist_ok=True)

# --- Turn on cuDNN autotuning when the config asks for it ---
if CFG.benchmark:
    torch.backends.cudnn.benchmark = True

# --- Echo the resolved device and dataset location ---
print(f"Using device: {CFG.device}")
print(f"Image root directory: {CFG.imagenet_root}")


Using device: cuda
Image root directory: /home/jovyan/scratch/Imagenet1K/ILSVRC/Data/fused


In [4]:
# ==========================
# Cell 2 — Patch Embedding (Conv2d, like ViT)
# ==========================
# For ViT, only the RGB channels (first 3 of the 4-channel PNG) are used here.

import torch
import torch.nn as nn

class PatchEmbedConv(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single convolution whose kernel size and stride both equal ``patch_size``
    performs the projection, so each output location corresponds to one patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"

        side = img_size // patch_size
        self.grid_size = side
        self.num_patches = side * side

        # Patch projection: one kernel application per patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``x`` of shape [B, C, H, W] to patch tokens of shape [B, N, D]."""
        feature_map = self.proj(x)                    # [B, D, Gh, Gw]
        tokens = feature_map.flatten(start_dim=2)     # [B, D, N]
        return tokens.permute(0, 2, 1)                # [B, N, D]


In [5]:
# ==========================
# Cell 3 — SPPP Components (for pre-baked SLIC channel)
# ==========================
# The SLIC segmentation is now stored in the 4th channel of each PNG.
# We no longer compute it via skimage — we just decode and use it.

import torch
import torch._dynamo as dynamo

def preprocess_slic_from_png(slic_channel: torch.Tensor) -> torch.Tensor:
    """
    Prepare SLIC labels extracted from the 4th PNG channel.

    Args:
        slic_channel: [H, W] tensor, dtype=uint8 or uint16, loaded from PNG alpha.
    Returns:
        labels: int64 tensor with the same shape as the input. When every value
        is <= 255 the labels are renumbered to contiguous ids 0..K-1, in
        ascending order of the original values. Maps containing a value > 255
        are returned unchanged apart from the int64 cast. An empty input is
        returned as an empty int64 tensor.
    """
    labels = slic_channel.to(torch.int64)

    # Nothing to renumber; also avoids calling max() on an empty tensor.
    if labels.numel() == 0:
        return labels

    # 16-bit style map: leave as-is.
    if labels.max() > 255:
        return labels

    # torch.unique returns sorted values; the inverse index of each pixel is
    # exactly its rank among the unique values, i.e. the contiguous label.
    # This replaces a per-pixel Python loop with a single vectorised call.
    _, inverse = torch.unique(labels, return_inverse=True)
    return inverse.reshape(labels.shape)


@dynamo.disable
def dominant_superpixel_per_patch(seg: torch.Tensor, patch_size: int, num_superpixels: int) -> torch.Tensor:
    """
    Find the most frequent superpixel label inside every square patch.

    seg: [H, W] integer label map. Rows/columns that do not fill a whole patch
         are cropped away.
    Returns a [Gh*Gw] tensor (row-major over the patch grid) of winning labels,
    clamped into 0..num_superpixels-1.
    """
    height, width = seg.shape
    n_rows = height // patch_size
    n_cols = width // patch_size
    n_patches = n_rows * n_cols

    # Keep only the area covered by whole patches, then lay each patch out as a row.
    cropped = seg[: n_rows * patch_size, : n_cols * patch_size]
    per_patch = (
        cropped.view(n_rows, patch_size, n_cols, patch_size)
        .permute(0, 2, 1, 3)                           # [Gh, Gw, ps, ps]
        .reshape(n_patches, patch_size * patch_size)   # [N, P]
    )

    # Histogram width: large enough for the biggest label actually present.
    # The .item() call is fine here because this function runs outside TorchDynamo.
    if per_patch.numel() > 0:
        n_bins = int(per_patch.max().cpu().item()) + 1
    else:
        n_bins = num_superpixels
    n_bins = max(n_bins, num_superpixels)

    # Per-patch label histogram.
    hist = torch.zeros(per_patch.size(0), n_bins, device=per_patch.device, dtype=torch.int32)
    hist.scatter_add_(1, per_patch, torch.ones_like(per_patch, dtype=torch.int32))

    # Winning label per patch, clipped to the valid token range.
    return hist.argmax(dim=1).clamp_(0, num_superpixels - 1)  # [N]


def pool_patch_tokens_by_superpixel(
    patch_tokens: torch.Tensor,
    labels_per_patch: torch.Tensor,
    num_superpixels: int,
    mode: str = "mean"
) -> torch.Tensor:
    """
    Aggregate patch tokens that share a superpixel label into one token each.

    patch_tokens: [N, D]
    labels_per_patch: [N] integer labels in 0..num_superpixels-1
    mode: "mean" averages the member tokens, "max" takes the element-wise maximum.
    Returns a [num_superpixels, D] tensor; rows for labels with no patch are zero.
    Raises ValueError for any other mode.
    """
    n_tokens, dim = patch_tokens.shape
    dev, dt = patch_tokens.device, patch_tokens.dtype
    pooled = torch.zeros(num_superpixels, dim, device=dev, dtype=dt)
    counts = torch.zeros(num_superpixels, 1, device=dev, dtype=dt)

    if mode == "mean":
        # Sum the tokens and the membership count per label, then divide.
        pooled.index_add_(0, labels_per_patch, patch_tokens)
        counts.index_add_(0, labels_per_patch, torch.ones(n_tokens, 1, device=dev, dtype=dt))
        averaged = pooled / counts.clamp_min(1.0)
        return torch.where(counts > 0, averaged, pooled)  # [R, D]

    if mode == "max":
        neg_inf = -float("inf")
        pooled.fill_(neg_inf)
        perm = torch.argsort(labels_per_patch)
        sorted_labels = labels_per_patch[perm]
        sorted_tokens = patch_tokens[perm]
        for region in sorted_labels.unique():
            members = sorted_tokens[sorted_labels == region]
            pooled[region] = torch.maximum(pooled[region], members.max(dim=0).values)
        # Rows that never received a token still hold -inf; reset them to zero.
        pooled[pooled == neg_inf] = 0.0
        return pooled  # [R, D]

    raise ValueError(f"Unknown pooling mode: {mode}")


In [6]:
# ==========================
# Cell 4 — Positional Encoding (Sinusoidal)
# ==========================
# This module operates on the flattened patch tokens produced from RGB images.
# The SLIC (4th channel) does not affect positional encoding.

import math
import torch
import torch.nn as nn

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position codes to a token sequence, then apply dropout.

    The table is rebuilt on every call from the sequence length and width of the
    input, so the module holds no parameters or buffers.
    """

    def __init__(self, dim: int, dropout: float = 0.0):
        super().__init__()
        self.dim = dim
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, N, D] — token embeddings; position i is the i-th token along N.
        Returns:
            [B, N, D] — positionally encoded embeddings
        """
        B, N, D = x.shape
        device = x.device
        pos = torch.arange(N, device=device, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, D, 2, device=device, dtype=torch.float32)
                        * (-math.log(10000.0) / D))
        angles = pos * div                                   # [N, ceil(D/2)]
        pe = torch.zeros(N, D, device=device, dtype=x.dtype)
        pe[:, 0::2] = torch.sin(angles)
        # For odd D there is one fewer odd column than even column, so trim the
        # cosine table to D // 2 columns (a no-op when D is even).
        pe[:, 1::2] = torch.cos(angles[:, : D // 2])
        x = x + pe.unsqueeze(0)  # broadcast to batch
        return self.drop(x)


In [7]:
# ==========================
# Cell 5 — Transformer Encoder (ViT style, ready for RGB + optional SLIC bias)
# ==========================

import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Stochastic depth: randomly zero a whole residual branch per sample.

    Surviving samples are rescaled by 1 / keep_prob. The input is returned
    unchanged in eval mode or when ``drop_prob`` is 0.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        active = self.training and self.drop_prob != 0.0
        if not active:
            return x

        keep_prob = 1.0 - self.drop_prob
        # One draw per sample, broadcast over every remaining dimension.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape)
        mask.bernoulli_(keep_prob)
        mask.div_(keep_prob)
        return x * mask


class MLP(nn.Module):
    """Two-layer feed-forward network used inside a transformer block.

    Linear(dim -> dim * mlp_ratio), GELU, dropout, Linear(back to dim), dropout.
    The same dropout module is applied after both linear stages.
    """

    def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0):
        super().__init__()
        width = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(expanded))


class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    ``dim`` must be divisible by ``num_heads``. Dropout is applied to the
    attention weights (``attn_drop``) and to the output projection (``proj_drop``).
    """

    def __init__(self, dim: int, num_heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5   # 1 / sqrt(head_dim)

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, n_tokens, width = x.shape

        # One linear layer produces Q, K and V; split into [3, B, heads, N, head_dim].
        fused = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, self.head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention weights over the token axis.
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Weighted sum of values, heads merged back into the channel axis.
        context = torch.matmul(weights, value).transpose(1, 2).reshape(batch, n_tokens, width)
        return self.proj_drop(self.proj(context))


class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Two residual branches — self-attention, then MLP — each preceded by a
    LayerNorm and optionally followed by stochastic depth (``drop_path`` > 0).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0
    ):
        super().__init__()
        use_drop_path = drop_path > 0.0

        # Attention branch.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, attn_drop, drop)
        self.drop_path1 = DropPath(drop_path) if use_drop_path else nn.Identity()

        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, drop)
        self.drop_path2 = DropPath(drop_path) if use_drop_path else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path2(mlp_out)


In [8]:
# ==========================
# Cell 6 — SPPPViT (final version for aligned 4-channel PNG dataset)
# ==========================

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
class SPPPViT(nn.Module):
    """
    Superpixel-Pooled Vision Transformer.
    - RGB passes through standard ViT patch embedding.
    - SLIC map (4th PNG channel) guides token pooling via superpixel regions.

    The transformer sees one class token plus ``cfg.num_superpixels`` pooled
    tokens per image; superpixel ids with no patch contribute an all-zero token.
    """

    def __init__(self, cfg: CFG):
        super().__init__()
        self.cfg = cfg

        # --- Patch embedding (RGB only) ---
        self.patch_embed = PatchEmbedConv(
            cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
        )

        # --- Tokens & positional encoding ---
        self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim))
        self.pos_embed = SinusoidalPositionalEncoding(cfg.embed_dim, dropout=0.0)

        # --- Transformer backbone ---
        # Every block receives the same stochastic-depth rate.
        self.blocks = nn.ModuleList([
            Block(
                cfg.embed_dim,
                cfg.num_heads,
                cfg.mlp_ratio,
                drop=cfg.drop_rate,
                attn_drop=cfg.attn_drop_rate,
                drop_path=getattr(cfg, "stoch_depth", 0.0)
            )
            for _ in range(cfg.depth)
        ])
        self.norm = nn.LayerNorm(cfg.embed_dim)
        self.head = nn.Linear(cfg.embed_dim, cfg.num_classes)

        # --- Initialization ---
        nn.init.normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    # ----------------------------------------------------
    # Pooling helper: one image's patch tokens -> superpixel tokens
    # ----------------------------------------------------
    def _pool_single(self, patch_tokens: torch.Tensor, slic_map: torch.Tensor) -> torch.Tensor:
        """
        patch_tokens: [N,D] patch embeddings of one image
        slic_map    : [H,W] int64 superpixel IDs
        Returns [R,D] pooled tokens, R = cfg.num_superpixels.
        """
        labels_per_patch = dominant_superpixel_per_patch(
            slic_map.to(patch_tokens.device, non_blocking=True),
            self.cfg.patch_size,
            self.cfg.num_superpixels,
        )

        pooled = pool_patch_tokens_by_superpixel(
            patch_tokens,
            labels_per_patch,
            self.cfg.num_superpixels,
            self.cfg.pooling_type,
        )  # [R,D]

        # Parameter-free layer norm (no affine weight/bias) for stability.
        return F.layer_norm(pooled, normalized_shape=(pooled.shape[-1],))

    # ----------------------------------------------------
    # Forward: single image (RGB + SLIC)
    # ----------------------------------------------------
    def forward_single(self, rgb_img: torch.Tensor, slic_map: torch.Tensor) -> torch.Tensor:
        """
        rgb_img : [1,3,H,W]  float32 in [0,1]
        slic_map: [H,W]      int64  superpixel IDs
        Returns [R,D] pooled, normalized superpixel tokens.
        """
        patch_tokens = self.patch_embed(rgb_img).squeeze(0)  # [N,D]
        return self._pool_single(patch_tokens, slic_map)

    # ----------------------------------------------------
    # Forward: batch mode
    # ----------------------------------------------------
    def forward(self, imgs: torch.Tensor, slic_maps: torch.Tensor) -> torch.Tensor:
        """
        imgs      : [B,3,H,W]
        slic_maps : [B,H,W]
        Returns logits: [B,num_classes]
        """
        B = imgs.size(0)

        # Embed the whole batch with a single convolution call instead of one
        # call per image; only the superpixel pooling stays per-sample.
        patch_tokens = self.patch_embed(imgs)                # [B,N,D]
        slic_maps = slic_maps.to(patch_tokens.device, non_blocking=True)

        pooled_batch = [
            self._pool_single(patch_tokens[i], slic_maps[i]) for i in range(B)
        ]
        tokens = torch.stack(pooled_batch, dim=0)            # [B,R,D]
        cls_token = self.cls_token.expand(B, -1, -1)         # [B,1,D]
        x = torch.cat([cls_token, tokens], dim=1)            # [B,1+R,D]
        x = self.pos_embed(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        logits = self.head(x[:, 0])                          # classify from the class token
        return logits


print("✅ SPPPViT finalized — fully aligned with JointTransform + ImageNet4Ch pipeline.")


✅ SPPPViT finalized — fully aligned with JointTransform + ImageNet4Ch pipeline.


In [9]:
# ==========================
# Cell 07 — ImageNet dataset wrapper for 4-channel PNGs (fast, no caching)
# ==========================

from torch.utils.data import Dataset
import torch
import torchvision.io as io
import os
from pathlib import Path

class ImageNet4Ch(Dataset):
    """
    Loads 4-channel PNGs (RGB + SLIC in alpha channel) with torchvision.io.
    Works with either:
        1. root/split/class_name/image.png
        2. root/split/<class>_<anything>.png  (flat folder; the class is the
           filename prefix before the first underscore)

    Class indices are assigned in sorted class-name order, so two splits that
    contain the same class names get the same label mapping.

    NOTE(review): a later cell defines another ``ImageNet4Ch`` (PIL-based); when
    both cells are run, that later definition shadows this one.
    """

    def __init__(self, root: str, split: str, transform, img_size: int):
        self.root = Path(root)
        self.split = split
        self.img_size = img_size
        self.transform = transform
        self.split_dir = self.root / split

        # Detect structure safely (ignore hidden/system folders)
        subdirs = [d for d in self.split_dir.iterdir() if d.is_dir() and not d.name.startswith(".")]
        self.flat = len(subdirs) == 0

        # Build (path, label) list
        self.samples, self.class_to_idx = self._gather_samples()

    def _gather_samples(self):
        """Return (samples, class_to_idx); samples is a sorted list of (path, label)."""
        samples, class_to_idx = [], {}
        if self.flat:
            # Flat folder — infer class from filename prefix before first underscore.
            # Sort both files and class names: glob order is filesystem-dependent,
            # and an order-dependent mapping would differ between splits.
            all_pngs = sorted(self.split_dir.glob("*.png"))
            class_names = sorted({p.name.split("_")[0] for p in all_pngs})
            class_to_idx = {name: i for i, name in enumerate(class_names)}
            samples = [(str(p), class_to_idx[p.name.split("_")[0]]) for p in all_pngs]
        else:
            # Folder-per-class layout
            for cls in sorted(os.listdir(self.split_dir)):
                cls_dir = self.split_dir / cls
                # Hidden folders (e.g. .ipynb_checkpoints) are not classes.
                if cls.startswith(".") or not cls_dir.is_dir():
                    continue
                idx = len(class_to_idx)
                class_to_idx[cls] = idx
                for f in sorted(cls_dir.glob("*.png")):
                    samples.append((str(f), idx))
        return samples, class_to_idx

    def __len__(self):
        return len(self.samples)

    # --- ⚡ Fast C++ PNG reader ---
    def _load_4ch_png(self, path: str):
        """
        Load RGBA PNG via torchvision.io (no PIL/NumPy).
        Returns (rgb_t [3,H,W] float32 0-1, slic_t [H,W] int64)
        Raises ValueError if the decoded image does not have exactly 4 channels.
        """
        img = io.read_image(str(path), mode=io.ImageReadMode.UNCHANGED)  # [C,H,W], uint8
        if img.size(0) != 4:
            raise ValueError(f"{path} expected RGBA, got {img.size(0)} channels")

        rgb_t = img[:3].float() / 255.0
        slic_t = img[3].to(torch.int64)
        return rgb_t, slic_t

    def __getitem__(self, idx: int):
        """Return one sample as (rgb, target, slic) — note the label is in the middle."""
        path, target = self.samples[idx]
        rgb_t, slic_t = self._load_4ch_png(path)

        if self.transform is not None:
            rgb_t, slic_t = self.transform(rgb_t, slic_t)

        return rgb_t, target, slic_t


In [9]:
# ==========================
# Cell 07 — ImageNet dataset wrapper for 4-channel PNGs (no caching)
# ==========================

from torch.utils.data import Dataset
from PIL import Image
import torch
import numpy as np
import os
from pathlib import Path

class ImageNet4Ch(Dataset):
    """
    Loads 4-channel PNGs (RGB + SLIC in alpha channel) with PIL.
    Works with either:
        1.  root/split/class_name/image.png
        2.  root/split/<class>_<anything>.png   (flat folder; the class is the
            filename prefix before the first underscore)

    Class indices are assigned in sorted class-name order, so two splits that
    contain the same class names get the same label mapping.

    NOTE(review): an earlier cell defines another ``ImageNet4Ch``
    (torchvision.io-based); this definition shadows it when both cells are run.
    """

    def __init__(self, root: str, split: str, transform, img_size: int):
        self.root = Path(root)
        self.split = split
        self.img_size = img_size
        self.transform = transform
        self.split_dir = self.root / split

        # Detect structure safely (ignore hidden/system folders)
        subdirs = [d for d in self.split_dir.iterdir() if d.is_dir() and not d.name.startswith(".")]
        self.flat = len(subdirs) == 0

        # Build (path, label) list
        self.samples, self.class_to_idx = self._gather_samples()

    def _gather_samples(self):
        """Return (samples, class_to_idx); samples is a sorted list of (path, label)."""
        samples, class_to_idx = [], {}
        if self.flat:
            # Flat folder — infer class from filename prefix before first underscore.
            # Sort both files and class names: glob order is filesystem-dependent,
            # and an order-dependent mapping would differ between splits.
            all_pngs = sorted(self.split_dir.glob("*.png"))
            class_names = sorted({p.name.split("_")[0] for p in all_pngs})
            class_to_idx = {name: i for i, name in enumerate(class_names)}
            samples = [(str(p), class_to_idx[p.name.split("_")[0]]) for p in all_pngs]
        else:
            # Folder-per-class layout
            for cls in sorted(os.listdir(self.split_dir)):
                cls_dir = self.split_dir / cls
                # Hidden folders (e.g. .ipynb_checkpoints) are not classes.
                if cls.startswith(".") or not cls_dir.is_dir():
                    continue
                idx = len(class_to_idx)
                class_to_idx[cls] = idx
                for f in sorted(cls_dir.glob("*.png")):
                    samples.append((str(f), idx))
        return samples, class_to_idx

    def __len__(self):
        return len(self.samples)

    def _load_4ch_png(self, path: str):
        """
        Load RGBA PNG and split into RGB + SLIC tensors.
        Returns (rgb_t [3,H,W] float32 0-1, slic_t [H,W] int64).
        Raises ValueError if the image mode is not "RGBA".
        """
        # The context manager closes the underlying file handle once the pixel
        # data has been copied into the NumPy array.
        with Image.open(path) as img:
            if img.mode != "RGBA":
                raise ValueError(f"{path} is not 4-channel (mode={img.mode})")
            arr = np.array(img)  # [H, W, 4]
        rgb = arr[..., :3]
        slic = arr[..., 3]
        rgb_t = torch.from_numpy(rgb).permute(2, 0, 1).float().div_(255.0)
        slic_t = torch.from_numpy(slic).to(torch.int64)
        return rgb_t, slic_t

    def __getitem__(self, idx: int):
        """Return one sample as (rgb, target, slic) — note the label is in the middle."""
        path, target = self.samples[idx]
        rgb_t, slic_t = self._load_4ch_png(path)

        if self.transform is not None:
            rgb_t, slic_t = self.transform(rgb_t, slic_t)

        return rgb_t, target, slic_t


In [10]:
# ==========================
# Cell 8 — Build dataloaders (4-channel PNGs, aligned RGB + SLIC, Windows-safe)
# ==========================

import os
import random
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader

# ------------------------------------------------------------
# Joint geometric transforms (keeps RGB↔SLIC spatial alignment)
# ------------------------------------------------------------
class JointTransform:
    """
    Paired augmentation for an RGB image and its SLIC label map.

    Geometric operations (flip, resize) are applied to both inputs so they stay
    pixel-aligned; brightness/contrast jitter touches the RGB image only.
    Random operations run only when ``train`` is true.
    """

    def __init__(self, size: int = 224, train: bool = True):
        self.size = size
        self.train = train

    def __call__(self, rgb: torch.Tensor, slic: torch.Tensor):
        # rgb:  [3,H,W] float32 in [0,1]
        # slic: [H,W]   int64
        target_hw = (self.size, self.size)

        # --- Coin-flip horizontal mirror (training only) ---
        if self.train and random.random() < 0.5:
            rgb = TF.hflip(rgb)
            slic = TF.hflip(slic[None]).squeeze(0)

        # --- Resize only when the spatial size is not already the target ---
        if rgb.shape[1:] != target_hw:
            rgb = TF.resize(rgb, list(target_hw), antialias=True)
            # Nearest-neighbour keeps label ids intact (no blending between ids).
            slic_resized = TF.resize(
                slic[None].float(),
                list(target_hw),
                interpolation=TF.InterpolationMode.NEAREST
            )
            slic = slic_resized.squeeze(0).long()

        # --- Mild photometric jitter, RGB only: factors drawn from [0.9, 1.1) ---
        if self.train:
            brightness = 1.0 + (random.random() - 0.5) * 0.2
            rgb = TF.adjust_brightness(rgb, brightness)
            contrast = 1.0 + (random.random() - 0.5) * 0.2
            rgb = TF.adjust_contrast(rgb, contrast)

        return rgb, slic


# ------------------------------------------------------------
# Build DataLoaders
# ------------------------------------------------------------
def build_loaders(CFG):
    """
    Build the train and validation DataLoaders for the 4-channel PNG dataset.

    Args:
        CFG: config namespace providing img_size, imagenet_root, batch_size,
             num_workers, pin_memory, prefetch_factor and (optionally)
             persistent_workers.
    Returns:
        (train_loader, val_loader). Validation uses half as many workers as
        training (integer division).

    Worker-only options (persistent_workers, prefetch_factor, the Windows
    'spawn' context) are applied per loader and only when that loader has
    workers, because DataLoader rejects them when num_workers == 0.
    """
    import sys

    # --- Transforms ---
    train_tfms = JointTransform(size=CFG.img_size, train=True)
    val_tfms   = JointTransform(size=CFG.img_size, train=False)

    # --- Dataset wrappers (4-channel PNGs) ---
    train_ds = ImageNet4Ch(CFG.imagenet_root, "train", train_tfms, CFG.img_size)
    val_ds   = ImageNet4Ch(CFG.imagenet_root, "val",   val_tfms,   CFG.img_size)

    # --- Worker setup ---
    TRAIN_NUM_WORKERS = CFG.num_workers
    VAL_NUM_WORKERS   = TRAIN_NUM_WORKERS // 2 if TRAIN_NUM_WORKERS > 0 else 0

    def _loader_kwargs(num_workers: int) -> dict:
        """DataLoader keyword arguments that are valid for this worker count."""
        kwargs = dict(num_workers=num_workers, pin_memory=CFG.pin_memory)
        if num_workers > 0:
            kwargs["persistent_workers"] = bool(getattr(CFG, "persistent_workers", True))
            kwargs["prefetch_factor"] = (
                CFG.prefetch_factor if isinstance(CFG.prefetch_factor, int) else 2
            )
            # On Linux, DO NOT force 'spawn' inside Jupyter
            if sys.platform == "win32":
                kwargs["multiprocessing_context"] = "spawn"
        return kwargs

    # --- DataLoaders ---
    train_loader = DataLoader(
        train_ds,
        batch_size=CFG.batch_size,
        shuffle=True,
        **_loader_kwargs(TRAIN_NUM_WORKERS),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=CFG.batch_size,
        shuffle=False,
        **_loader_kwargs(VAL_NUM_WORKERS),
    )

    print(
        f"✅ DataLoaders ready — "
        f"train:{len(train_ds)}, val:{len(val_ds)}, "
        f"num_workers={TRAIN_NUM_WORKERS}/{VAL_NUM_WORKERS}"
    )

    return train_loader, val_loader


In [13]:
import os
# Debug aid: names of the sub-directories found directly under the training split.
print([entry.name for entry in Path(CFG.imagenet_root / "train").iterdir() if entry.is_dir()])


['.ipynb_checkpoints']


In [10]:
'''
# ==========================
# Cell 8 — Build dataloaders & verify existing SLIC cache (Windows-safe)
# ==========================
import os
import torchvision.transforms as T
from torch.utils.data import DataLoader

def build_loaders(CFG):
    # --- Transforms ---
    train_tfms = T.Compose([
        T.RandomHorizontalFlip(),
        # T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ])
    val_tfms = T.Compose([])

    # --- Dataset wrappers ---
    train_ds = ImageNetWithSLIC(
        CFG.imagenet_root, "train", train_tfms,
        CFG.img_size, CFG.num_superpixels, CFG.compactness, CFG.slic_sigma,
        CFG.slic_cache, compute_on_the_fly=False
    )
    val_ds = ImageNetWithSLIC(
        CFG.imagenet_root, "val", val_tfms,
        CFG.img_size, CFG.num_superpixels, CFG.compactness, CFG.slic_sigma,
        CFG.slic_cache, compute_on_the_fly=False
    )

    # --- Verify cache completeness (optional) ---
    def verify_precompute_split(split: str, ds: ImageNetWithSLIC):
        cache_dir = CFG.slic_cache / split
        cache_dir.mkdir(parents=True, exist_ok=True)
        cached = list(cache_dir.glob("*.npy"))
        print(f"[{split}] {len(cached)} cached files for {len(ds)} images.")
    verify_precompute_split("train", train_ds)
    verify_precompute_split("val",   val_ds)

    # --- Worker setup ---
    TRAIN_NUM_WORKERS = min(CFG.num_workers, os.cpu_count() or 8)
    VAL_NUM_WORKERS   = max(2, TRAIN_NUM_WORKERS // 2)

    # --- DataLoaders ---
    train_loader = DataLoader(
        train_ds,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=TRAIN_NUM_WORKERS,
        pin_memory=CFG.pin_memory,
        persistent_workers=False,          # force off for Windows/Jupyter
        prefetch_factor= CFG.prefetch_factor,
        multiprocessing_context="spawn",   # crucial for Windows
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=VAL_NUM_WORKERS,
        pin_memory=CFG.pin_memory,
        persistent_workers=False,
        prefetch_factor= CFG.prefetch_factor,
        multiprocessing_context="spawn",
    )

    print(f"✅ DataLoaders ready — "
          f"train:{len(train_ds)}, val:{len(val_ds)}, "
          f"num_workers={TRAIN_NUM_WORKERS}/{VAL_NUM_WORKERS}")
    return train_loader, val_loader
'''

In [10]:
'''
# ==========================
# Cell 8 — Build dataloaders & verify existing SLIC cache
# ==========================
from tqdm import tqdm
import multiprocessing as mp
import os

# --- Safe multiprocessing start (Linux) ---
#try:
#    mp.set_start_method("fork", force=True)
#except RuntimeError:
#    pass  # already set

# --- Simple tensor transforms (ViT from scratch) ---
train_tfms = T.Compose([
    T.RandomHorizontalFlip(),
    # Optional small augmentation:
    # T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])

val_tfms = T.Compose([])

assert os.path.isdir(CFG.imagenet_root), f"ImageNet root not found: {CFG.imagenet_root}"

# --- Dataset wrappers ---
train_ds = ImageNetWithSLIC(
    CFG.imagenet_root, "train", train_tfms,
    CFG.img_size, CFG.num_superpixels, CFG.compactness, CFG.slic_sigma,
    CFG.slic_cache, compute_on_the_fly=False
)
val_ds = ImageNetWithSLIC(
    CFG.imagenet_root, "val", val_tfms,
    CFG.img_size, CFG.num_superpixels, CFG.compactness, CFG.slic_sigma,
    CFG.slic_cache, compute_on_the_fly=False
)

# ---------------------------------------------------------------------
# Verify that SLIC cache exists and matches dataset
# ---------------------------------------------------------------------
def _gather_image_paths(ds: ImageNetWithSLIC) -> list[str]:
    return [os.path.abspath(p) for (p, _) in ds.base.samples]

def verify_precompute_split(split: str, ds: ImageNetWithSLIC):
    cache_dir = CFG.slic_cache / split
    cache_dir.mkdir(parents=True, exist_ok=True)
    paths = _gather_image_paths(ds)
    cached = list(cache_dir.glob("*.npy"))

    print(f"[{split}] cache check: {len(cached)} cached files for {len(paths)} images.")
    if len(cached) < len(paths):
        print(f"[{split}] ⚠️  Cache incomplete ({len(cached)}/{len(paths)}). Missing SLIC maps may raise FileNotFoundError.")
    else:
        print(f"[{split}] ✅ Cache looks complete.")

verify_precompute_split("train", train_ds)
verify_precompute_split("val",   val_ds)

# ---------------------------------------------------------------------
# DataLoaders (optimized for Linux and large batch training)
# ---------------------------------------------------------------------
TRAIN_NUM_WORKERS = min(CFG.num_workers, os.cpu_count() or 32)
VAL_NUM_WORKERS   = max(4, TRAIN_NUM_WORKERS // 4)

train_loader = DataLoader(
    train_ds,
    batch_size=CFG.batch_size,
    shuffle=True,
    num_workers=TRAIN_NUM_WORKERS,
    pin_memory=CFG.pin_memory,
    persistent_workers=CFG.persistent_workers,
    prefetch_factor=CFG.prefetch_factor,
)

val_loader = DataLoader(
    val_ds,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=VAL_NUM_WORKERS,
    pin_memory=CFG.pin_memory,
    persistent_workers=CFG.persistent_workers,
    prefetch_factor=CFG.prefetch_factor,
)

print(f"Train/Val sizes: {len(train_ds)}, {len(val_ds)}")
print(f"num_workers={TRAIN_NUM_WORKERS}, persistent={CFG.persistent_workers}, prefetch={CFG.prefetch_factor}")
'''

[train] cache check: 34745 cached files for 34759 images.
[train] ⚠️  Cache incomplete (34745/34759). Missing SLIC maps may raise FileNotFoundError.
[val] cache check: 3923 cached files for 3923 images.
[val] ✅ Cache looks complete.
Train/Val sizes: 34759, 3923
num_workers=8, persistent=True, prefetch=2


In [None]:
'''
# ==========================
# Cell 9 — Training Utilities
# ==========================

def accuracy_topk(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in `topk`.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: tensor holding one class index per sample.
        topk: iterable of k values to report.

    Returns:
        List of 0-dim tensors, one percentage per requested k.
    """
    largest_k = max(topk)
    # Indices of the largest_k best-scoring classes per sample, best first,
    # transposed to [largest_k, batch] so row r holds every sample's rank-r guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    to_percent = 100.0 / target.size(0)
    return [hits[:k].reshape(-1).float().sum(0) * to_percent for k in topk]

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float, float]:
    """Run one pass over `loader` and return (mean loss, top-1 %, top-5 %).

    Metrics are weighted by the number of samples in each batch, so a short
    final batch does not skew the result. An empty loader yields (0.0, 0.0, 0.0)
    instead of raising ZeroDivisionError.

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        loader: iterable yielding (imgs, labels, segs) batches.
    """
    model.eval()
    loss_sum, top1_sum, top5_sum = 0.0, 0.0, 0.0
    n_seen = 0
    for imgs, labels, segs in loader:
        imgs = imgs.to(CFG.device, non_blocking=True)
        labels = labels.to(CFG.device, non_blocking=True)
        segs = segs.to(CFG.device, non_blocking=True)
        logits = model(imgs, segs)
        loss = F.cross_entropy(logits, labels)
        top1, top5 = accuracy_topk(logits, labels, (1, 5))
        # Batch means times batch size give per-sample sums.
        bs = labels.size(0)
        loss_sum += loss.item() * bs
        top1_sum += top1.item() * bs
        top5_sum += top5.item() * bs
        n_seen += bs
    denom = max(n_seen, 1)
    return loss_sum / denom, top1_sum / denom, top5_sum / denom

def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader):
    """Train `model` for CFG.epochs epochs with AMP and a cosine LR schedule.

    After each epoch the model is validated with `evaluate` and the metrics are
    appended to a history dict. When training finishes, the history is written
    to CFG.out_dir / "history.json" and the weights to
    CFG.out_dir / "sppp_vit_b16.pth".

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        train_loader: iterable yielding (imgs, labels, segs) training batches.
        val_loader: iterable yielding (imgs, labels, segs) validation batches.

    Returns:
        dict of per-epoch lists: train_loss, train_acc1, val_loss, val_acc1, val_acc5.
    """
    use_amp = (CFG.device == 'cuda')
    # torch.amp.* is the non-deprecated spelling of torch.cuda.amp.*.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)

    history = {"train_loss": [], "train_acc1": [], "val_loss": [], "val_acc1": [], "val_acc5": []}

    for epoch in range(CFG.epochs):
        model.train()
        loss_sum, acc1_sum, n_seen = 0.0, 0.0, 0
        for imgs, labels, segs in train_loader:
            imgs = imgs.to(CFG.device, non_blocking=True)
            labels = labels.to(CFG.device, non_blocking=True)
            segs = segs.to(CFG.device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                logits = model(imgs, segs)
                loss = F.cross_entropy(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            acc1, = accuracy_topk(logits, labels, (1,))
            # Weight by batch size so a short last batch does not skew the epoch mean.
            bs = labels.size(0)
            loss_sum += loss.item() * bs
            acc1_sum += acc1.item() * bs
            n_seen += bs

        scheduler.step()
        denom = max(n_seen, 1)  # guards an empty train_loader
        tl = loss_sum / denom
        ta1 = acc1_sum / denom
        vl, va1, va5 = evaluate(model, val_loader)

        history["train_loss"].append(tl)
        history["train_acc1"].append(ta1)
        history["val_loss"].append(vl)
        history["val_acc1"].append(va1)
        history["val_acc5"].append(va5)

        print(f"Epoch {epoch+1}/{CFG.epochs} | Train: loss {tl:.4f}, acc@1 {ta1:.2f} | Val: loss {vl:.4f}, acc@1 {va1:.2f}, acc@5 {va5:.2f}")

    with open(CFG.out_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)
    torch.save(model.state_dict(), CFG.out_dir / "sppp_vit_b16.pth")
    return history
    '''


In [13]:
#train_loader = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True,
#                          num_workers=16, pin_memory=True)
#imgs, labels, segs = next(iter(train_loader))
#print("Batch shapes:", imgs.shape, labels.shape, segs.shape)



In [19]:
'''
# --- Safe diagnostic version ---
batch = next(iter(train_loader))
print("Batch type:", type(batch))

if isinstance(batch, (list, tuple)):
    print("Batch length:", len(batch))
    for i, item in enumerate(batch):
        if hasattr(item, 'shape'):
            print(f"Item {i} shape:", item.shape, "| dtype:", item.dtype)
            print("Range:", item.min().item(), item.max().item())
        else:
            print(f"Item {i}:", type(item))
else:
    print("Unexpected batch type:", type(batch))
'''

Batch type: <class 'list'>
Batch length: 3
Item 0 shape: torch.Size([64, 3, 224, 224]) | dtype: torch.float32
Range: 0.0 1.0
Item 1 shape: torch.Size([64]) | dtype: torch.int64
Range: 31 995
Item 2 shape: torch.Size([64, 224, 224]) | dtype: torch.int32
Range: 0 146


### Probe A - Are logits flat?

In [14]:
'''
# After model = SPPPViT(CFG).to(CFG.device) and before training:
imgs, labels, segs = next(iter(train_loader))
imgs, labels, segs = imgs.to(CFG.device), labels.to(CFG.device), segs.to(CFG.device)
with torch.no_grad():
    logits = model(imgs[:8], segs[:8])
print("Logits std:", logits.std().item())
print("Avg max prob:", torch.softmax(logits, dim=1).max(dim=1).values.mean().item())
'''

NameError: name 'model' is not defined

In [20]:
'''
import torch

# --- 1. Ensure model exists ---
try:
    model
except NameError:
    print("Model not defined yet — creating a fresh one.")
    model = SPPPViT(CFG).to(CFG.device)
    model.eval()

# --- 2. Fetch one batch ---
imgs, labels, segs = next(iter(train_loader))
imgs, labels, segs = imgs.to(CFG.device), labels.to(CFG.device), segs.to(CFG.device)

# --- 3. Forward pass ---
with torch.no_grad():
    logits = model(imgs[:8], segs[:8])  # forward with superpixel maps
probs = torch.softmax(logits, dim=1)

# --- 4. Diagnostics ---
print(f"Logits mean: {logits.mean().item():.4f} | std: {logits.std().item():.4f}")
print(f"Average max probability (top-1 confidence): {probs.max(dim=1).values.mean().item():.4f}")
print(f"Predicted class range: {probs.argmax(dim=1).min().item()}–{probs.argmax(dim=1).max().item()}")
'''

Model not defined yet — creating a fresh one.
Logits mean: -0.0024 | std: 0.5495
Average max probability (top-1 confidence): 0.0046
Predicted class range: 221–408


### Probe B - superpixel assignment degeneracy check

In [22]:
'''
import torch

def _dominant_labels_batch(imgs, segs):
    """Return the dominant-superpixel map of each of the first (up to) 4 samples.

    Each segmentation map is passed to `dominant_superpixel_per_patch` with
    CFG.patch_size and CFG.num_superpixels. The images themselves are only used
    for the batch size: the earlier `model.patch_embed` call produced tokens that
    were never read, so it was dropped and this probe no longer needs a global
    `model` to exist.
    """
    labs = []
    with torch.no_grad():
        for i in range(min(4, imgs.size(0))):  # check first 4 samples
            # Compute dominant superpixel per patch
            dom = dominant_superpixel_per_patch(
                segs[i].long(),
                CFG.patch_size,
                CFG.num_superpixels
            )
            labs.append(dom)
    return labs


# --- Run probe on one batch ---
imgs, labels, segs = next(iter(train_loader))
imgs, segs = imgs.to(CFG.device), segs.to(CFG.device)

labs = _dominant_labels_batch(imgs, segs)

# --- Report results ---
for i, dom in enumerate(labs):
    uniq, cnt = torch.unique(dom, return_counts=True)
    print(f"Sample {i}: unique labels {len(uniq)} / {CFG.num_superpixels}; "
          f"top region sizes: {cnt.topk(min(5, cnt.numel())).values.tolist()}")
'''

Sample 0: unique labels 62 / 196; top region sizes: [21, 21, 10, 10, 8]
Sample 1: unique labels 99 / 196; top region sizes: [19, 7, 6, 5, 4]
Sample 2: unique labels 102 / 196; top region sizes: [6, 6, 6, 4, 4]
Sample 3: unique labels 99 / 196; top region sizes: [11, 10, 8, 7, 7]


### Quick sanity check for correct reading of .npy files

In [12]:
'''
sp = "train"
samples = train_ds.base.samples[:10]

for img_path, _ in samples:
    npy_path = train_ds._cache_file(img_path)
    print(f"Image: {os.path.basename(img_path)}  →  NPY: {npy_path.name if npy_path and npy_path.exists() else 'NOT FOUND'}")

'''

Image: n01440764_10026.jpg  →  NPY: n01440764_n01440764_10026.npy
Image: n01440764_10027.jpg  →  NPY: n01440764_n01440764_10027.npy
Image: n01440764_10029.jpg  →  NPY: n01440764_n01440764_10029.npy
Image: n01440764_10040.jpg  →  NPY: n01440764_n01440764_10040.npy
Image: n01440764_10042.jpg  →  NPY: n01440764_n01440764_10042.npy
Image: n01440764_10043.jpg  →  NPY: n01440764_n01440764_10043.npy
Image: n01440764_10048.jpg  →  NPY: n01440764_n01440764_10048.npy
Image: n01440764_10066.jpg  →  NPY: n01440764_n01440764_10066.npy
Image: n01440764_10074.jpg  →  NPY: n01440764_n01440764_10074.npy
Image: n01440764_1009.jpg  →  NPY: n01440764_n01440764_1009.npy


In [15]:
'''
for i in range(3):
    img, label, seg = train_ds[i]
    print(f"✅ Loaded sample {i}: img {img.shape}, seg {seg.shape}, label {label}")
'''

✅ Loaded sample 0: img torch.Size([3, 224, 224]), seg torch.Size([224, 224]), label 0
✅ Loaded sample 1: img torch.Size([3, 224, 224]), seg torch.Size([224, 224]), label 0
✅ Loaded sample 2: img torch.Size([3, 224, 224]), seg torch.Size([224, 224]), label 0


In [12]:
'''
import time
t0 = time.time()
for i in range(50):
    img, label, seg = train_ds[i]
    if i % 10 == 0:
        print(f"{i} samples loaded OK")
print("✅ 50 samples done in", time.time()-t0, "sec")
'''

0 samples loaded OK
10 samples loaded OK
20 samples loaded OK
30 samples loaded OK
40 samples loaded OK
✅ 50 samples done in 38.413800954818726 sec


In [None]:
'''
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # allow truncated JPEGs

def probe_first_batch(ds, batch_size):
    """Check that the first `batch_size` samples of `ds` can be read.

    For each sample the image is opened, converted to RGB and resized to
    ds.img_size, and its SLIC map is loaded and required to be 2-D.
    The probe stops at the end of the dataset when it holds fewer than
    `batch_size` samples, instead of failing with IndexError.

    Raises:
        FileNotFoundError: if a sample has no SLIC map on disk.
        ValueError: if a loaded SLIC map is not 2-D.
    """
    import numpy as np, os
    n_probe = min(batch_size, len(ds.base.samples))
    for i in range(n_probe):
        p, _ = ds.base.samples[i]
        base = os.path.basename(p)
        cf = ds._cache_file(p)
        print(f"[{i}] IMG={base}  SLIC={cf.name if cf else 'None'}")
        # Image check: the result is discarded, only a successful decode matters.
        with Image.open(p) as im:
            im.convert("RGB").resize((ds.img_size, ds.img_size))
        # SLIC map check
        if cf is None or not cf.exists():
            raise FileNotFoundError(f"Missing SLIC map for {base}")
        seg = np.load(cf)  # will error/hang if corrupted
        if getattr(seg, "ndim", 0) != 2:
            raise ValueError(f"SLIC map not 2D for {base}: shape={getattr(seg, 'shape', None)}")
    print("✅ First batch looks OK")

probe_first_batch(train_ds, CFG.batch_size)
'''

In [17]:
'''import torch
print("CUDA visible:", torch.cuda.device_count())
print("CUDA current device:", torch.cuda.current_device())
print("CUDA name:", torch.cuda.get_device_name())
'''


CUDA visible: 1
CUDA current device: 0
CUDA name: NVIDIA H100 NVL


In [12]:
'''
import time
t0 = time.time()
for i, (x, y, seg) in enumerate(train_loader):
    if i % 1 == 0:
        print(f"Batch {i} done in {time.time()-t0:.1f}s")
    if i == 3:  # just test 20 batches
        break
'''

Batch 0 done in 412.9s
Batch 1 done in 412.9s
Batch 2 done in 420.6s
Batch 3 done in 420.6s


In [14]:
'''
def __getitem__(self, idx: int):
    """Debug variant of sample loading: returns (image tensor, target, SLIC map tensor).

    Prints the file being loaded, requires a cached SLIC map to exist, and
    reports (then re-raises) any error from reading it.

    Raises:
        FileNotFoundError: if no cached SLIC map exists for the sample.
    """
    path, target = self.base.samples[idx]
    print(f"[DEBUG] Loading {os.path.basename(path)}")
    cache_f = self._cache_file(path)

    if not (cache_f and cache_f.exists()):
        raise FileNotFoundError(f"No SLIC map found for {path}")
    try:
        seg = np.load(cache_f)
    except Exception as e:
        print(f"[ERROR] Failed to load {cache_f.name}: {e}")
        raise

    seg_t = torch.from_numpy(seg.astype(np.int32))

    # Close the file handle promptly; convert() returns a fully loaded copy,
    # so `img` stays usable after the context manager exits.
    with Image.open(path) as im:
        img = im.convert("RGB")
    img_t = self.transform(img) if self.transform else T.ToTensor()(img)

    return img_t, target, seg_t
'''

In [None]:
'''
for i, (x, y, seg) in enumerate(train_loader):
    print(f"Batch {i} loaded")
    if i == 2:
        break
'''

In [None]:
'''
for i, (x, y, seg) in enumerate(train_loader):
    print(f"Batch {i} loaded: {x.shape}, {seg.shape}")
    if i == 2:
        break
'''

In [None]:
'''
from torch.utils.data import DataLoader, Dataset
import torch

class Dummy(Dataset):
    """Tiny map-style dataset: 100 items, item `idx` is the scalar tensor `idx`."""

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.tensor(idx)

if __name__ == "__main__":
    torch.multiprocessing.set_start_method('spawn', force=True)
    loader = DataLoader(Dummy(), num_workers=4)
    for x in loader:
        print(x)
        break
'''

### GPU Logger

In [11]:

import subprocess, sys, textwrap, os

LOGGER_PID_FILE = ".gpu_logger.pid"

# Re-running this cell would otherwise leave the previous logger alive: both
# processes would keep writing to gpu_log.csv and the old one would be orphaned.
# Stop the logger this kernel started earlier (if it is still running) first.
_previous_logger = globals().get("GPU_LOGGER_PROC")
if isinstance(_previous_logger, subprocess.Popen) and _previous_logger.poll() is None:
    _previous_logger.terminate()
    try:
        _previous_logger.wait(timeout=5)
    except subprocess.TimeoutExpired:
        _previous_logger.kill()

# Source of the background process: polls nvidia-smi every INTERVAL seconds and
# appends one CSV row per GPU to gpu_log.csv until it is terminated.
logger_py = r"""
import time, subprocess, csv, os, math

INTERVAL = 1.0
CMD = "nvidia-smi --query-gpu=index,utilization.gpu,temperature.gpu,memory.used,memory.total --format=csv,noheader,nounits"

with open("gpu_log.csv","w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["t_sec","gpu","util_percent","temp_c","mem_used_mib","mem_total_mib"])
    t0 = time.time()
    while True:
        try:
            out = subprocess.check_output(CMD, shell=True, text=True).strip().splitlines()
            now = time.time() - t0
            for line in out:
                idx, util, temp, mu, mt = [x.strip() for x in line.split(",")]
                w.writerow([round(now,2), int(idx), int(util), int(temp), int(mu), int(mt)])
            f.flush()
            time.sleep(INTERVAL)
        except Exception:
            # if nvidia-smi isn't available, wait a bit and retry
            time.sleep(INTERVAL)
"""

# `p` is kept for backward compatibility; GPU_LOGGER_PROC is the handle the
# guard above looks for on the next run of this cell.
p = GPU_LOGGER_PROC = subprocess.Popen([sys.executable, "-u", "-c", logger_py])
with open(LOGGER_PID_FILE, "w") as f:
    f.write(str(p.pid))

print(f"GPU logger started. PID={p.pid}  | writing to gpu_log.csv")


GPU logger started. PID=197938  | writing to gpu_log.csv


In [12]:

import time

# Build the loaders, then time how long it takes to produce the first training
# batch from a fresh iterator.
train_loader, val_loader = build_loaders(CFG)
start = time.time()
imgs, lbls, segs = next(iter(train_loader))
elapsed = time.time() - start
print("Batch load time:", elapsed, "s")


✅ DataLoaders ready — train:1281166, val:50000, num_workers=16/8
Batch load time: 4.51272177696228 s


In [None]:
# ==========================
# Cell 10 — Training loop (Fixed full-width tqdm for Jupyter)
# ==========================
from tqdm.auto import tqdm   # auto works for both console + notebook
import os, shutil, json, time, torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# Detect usable width for tqdm dynamically
# NOTE(review): term_width is computed here but is not referenced anywhere else
# in this cell; the progress bars rely on dynamic_ncols=True below instead.
try:
    term_width = shutil.get_terminal_size((120, 20)).columns
except Exception:
    term_width = 120  # fallback if inside Jupyter
# Keyword arguments shared by every tqdm(...) call in this cell.
tqdm_params = dict(dynamic_ncols=True, mininterval=0.1, maxinterval=1.0, ascii=False)

def accuracy_topk(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in `topk`.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: tensor holding one class index per sample.
        topk: iterable of k values to report.

    Returns:
        List of 0-dim tensors, one percentage per requested k.
    """
    largest_k = max(topk)
    # Indices of the largest_k best-scoring classes per sample, best first,
    # transposed to [largest_k, batch] so row r holds every sample's rank-r guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    to_percent = 100.0 / target.size(0)
    return [hits[:k].reshape(-1).float().sum(0) * to_percent for k in topk]


@torch.no_grad()
def evaluate(model, loader):
    """Run one validation pass and return (mean loss, top-1 %, top-5 %).

    The loss uses the same label smoothing as training (CFG.label_smoothing,
    0.0 when the attribute is absent). Metrics are weighted by the number of
    samples in each batch, so a short final batch does not skew the result.
    An empty loader yields (0.0, 0.0, 0.0) instead of raising ZeroDivisionError.

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        loader: iterable yielding (imgs, labels, segs) batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(label_smoothing=getattr(CFG, "label_smoothing", 0.0))
    loss_sum = top1_sum = top5_sum = 0.0
    n_seen = 0

    with tqdm(loader, desc="🔍 Validating", leave=False, **tqdm_params) as vbar:
        for imgs, labels, segs in vbar:
            imgs, labels, segs = (imgs.to(CFG.device), labels.to(CFG.device), segs.to(CFG.device))
            logits = model(imgs, segs)
            loss = criterion(logits, labels)
            top1, top5 = accuracy_topk(logits, labels, (1, 5))
            # Read each scalar once; every .item() forces a device sync.
            batch_loss, batch_top1, batch_top5 = loss.item(), top1.item(), top5.item()
            bs = labels.size(0)
            loss_sum += batch_loss * bs
            top1_sum += batch_top1 * bs
            top5_sum += batch_top5 * bs
            n_seen += bs
            vbar.set_postfix(loss=f"{batch_loss:.4f}", acc1=f"{batch_top1:.2f}")
    denom = max(n_seen, 1)
    return loss_sum / denom, top1_sum / denom, top5_sum / denom


def train_with_progress(model, train_loader, val_loader):
    """Train `model` for CFG.epochs epochs with AMP, gradient clipping and a
    linear-warmup then cosine learning-rate schedule, showing tqdm progress bars.

    After each epoch the model is validated with `evaluate`, the metrics are
    appended to a history dict and a summary is printed. When training finishes,
    the history is written to CFG.out_dir / "history.json" and the weights to
    CFG.out_dir / "sppp_vit_b16.pth".

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        train_loader: iterable yielding (imgs, labels, segs) training batches.
        val_loader: iterable yielding (imgs, labels, segs) validation batches.

    Returns:
        dict of per-epoch lists: train_loss, train_acc1, val_loss, val_acc1, val_acc5.
    """
    use_amp = (CFG.device == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, betas=(0.9, 0.999))

    # Warmup is capped so at least one epoch remains for the cosine phase.
    # Without the cap, CFG.epochs <= 3 gives T_max <= 0 and the cosine scheduler
    # divides by zero when the schedule switches over.
    warmup_epochs = min(max(3, int(0.05 * CFG.epochs)), max(CFG.epochs - 1, 0))
    main_epochs = max(1, CFG.epochs - warmup_epochs)
    if warmup_epochs > 0:
        scheduler = SequentialLR(
            optimizer,
            schedulers=[
                LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs),
                CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6),
            ],
            milestones=[warmup_epochs],
        )
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6)

    criterion = nn.CrossEntropyLoss(label_smoothing=getattr(CFG, "label_smoothing", 0.0))
    history = {"train_loss": [], "train_acc1": [], "val_loss": [], "val_acc1": [], "val_acc5": []}

    for epoch in range(CFG.epochs):
        print(f"\n🚀 Epoch {epoch+1}/{CFG.epochs}", flush=True)
        model.train()
        loss_sum = acc1_sum = 0.0
        n_seen = 0

        with tqdm(train_loader, desc=f"🧠 Training [{epoch+1}/{CFG.epochs}]", leave=True, **tqdm_params) as pbar:
            for imgs, labels, segs in pbar:
                imgs, labels, segs = (imgs.to(CFG.device, non_blocking=True),
                                      labels.to(CFG.device, non_blocking=True),
                                      segs.to(CFG.device, non_blocking=True))
                optimizer.zero_grad(set_to_none=True)
                with torch.amp.autocast("cuda", enabled=use_amp):
                    logits = model(imgs, segs)
                    loss = criterion(logits, labels)
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer); scaler.update()

                acc1, = accuracy_topk(logits, labels, (1,))
                # Read each scalar once; every .item() forces a device sync.
                batch_loss, batch_acc1 = loss.item(), acc1.item()
                # Weight by batch size so a short last batch does not skew the epoch mean.
                bs = labels.size(0)
                loss_sum += batch_loss * bs
                acc1_sum += batch_acc1 * bs
                n_seen += bs
                pbar.set_postfix(loss=f"{batch_loss:.4f}", acc1=f"{batch_acc1:.2f}")

        scheduler.step()
        denom = max(n_seen, 1)  # guards an empty train_loader
        tl, ta1 = loss_sum / denom, acc1_sum / denom
        vl, va1, va5 = evaluate(model, val_loader)

        history["train_loss"].append(tl); history["train_acc1"].append(ta1)
        history["val_loss"].append(vl);   history["val_acc1"].append(va1)
        history["val_acc5"].append(va5)

        print(f"✅ Epoch {epoch+1}/{CFG.epochs} | "
              f"Train: loss {tl:.4f}, acc@1 {ta1:.2f} | "
              f"Val: loss {vl:.4f}, acc@1 {va1:.2f}, acc@5 {va5:.2f}")
        # Printed after scheduler.step(), so this is the rate for the next epoch.
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6e}", flush=True)

    with open(CFG.out_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)
    torch.save(model.state_dict(), CFG.out_dir / "sppp_vit_b16.pth")
    return history


# --------------------------
# Run training
# --------------------------
if __name__ == "__main__":
    # Data pipeline first, then the model (built, then moved to the target device).
    train_loader, val_loader = build_loaders(CFG)
    model = SPPPViT(CFG)
    model = model.to(CFG.device)
    # Request 'medium' float32 matmul precision before the training run starts.
    torch.set_float32_matmul_precision("medium")
    history = train_with_progress(model, train_loader, val_loader)


✅ DataLoaders ready — train:1281166, val:50000, num_workers=16/8

🚀 Epoch 1/10


🧠 Training [1/10]:   0%|          | 0/1669 [00:00<?, ?it/s]

🔍 Validating:   0%|          | 0/66 [00:00<?, ?it/s]

✅ Epoch 1/10 | Train: loss 5.9624, acc@1 4.32 | Val: loss 7.7376, acc@1 0.11, acc@5 0.51
Learning rate: 1.200000e-04

🚀 Epoch 2/10


🧠 Training [2/10]:   0%|          | 0/1669 [00:00<?, ?it/s]

In [12]:
# ==========================
# Cell 10 — Training loop (Jupyter-friendly with stable progress bars)
# ==========================
from tqdm.auto import tqdm   # works in both notebooks & terminals
import json, torch, time
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from torch.utils.data import DataLoader

# ----------------------------
# Accuracy function
# ----------------------------
def accuracy_topk(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in `topk`.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: tensor holding one class index per sample.
        topk: iterable of k values to report.

    Returns:
        List of 0-dim tensors, one percentage per requested k.
    """
    largest_k = max(topk)
    # Indices of the largest_k best-scoring classes per sample, best first,
    # transposed to [largest_k, batch] so row r holds every sample's rank-r guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    to_percent = 100.0 / target.size(0)
    return [hits[:k].reshape(-1).float().sum(0) * to_percent for k in topk]


# ----------------------------
# Validation
# ----------------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader):
    """Run one validation pass and return (mean loss, top-1 %, top-5 %).

    The loss uses the same label smoothing as training (CFG.label_smoothing,
    0.0 when the attribute is absent). Metrics are weighted by the number of
    samples in each batch, so a short final batch does not skew the result.
    An empty loader yields (0.0, 0.0, 0.0) instead of raising ZeroDivisionError.

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        loader: iterable yielding (imgs, labels, segs) batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(label_smoothing=getattr(CFG, "label_smoothing", 0.0))
    loss_sum = top1_sum = top5_sum = 0.0
    n_seen = 0

    with tqdm(loader, desc="🔍 Validating", leave=False, ncols=100) as vbar:
        for imgs, labels, segs in vbar:
            imgs = imgs.to(CFG.device, non_blocking=True)
            labels = labels.to(CFG.device, non_blocking=True)
            segs = segs.to(CFG.device, non_blocking=True)

            logits = model(imgs, segs)
            loss = criterion(logits, labels)
            top1, top5 = accuracy_topk(logits, labels, (1, 5))

            # Read each scalar once; every .item() forces a device sync.
            batch_loss, batch_top1, batch_top5 = loss.item(), top1.item(), top5.item()
            bs = labels.size(0)
            loss_sum += batch_loss * bs
            top1_sum += batch_top1 * bs
            top5_sum += batch_top5 * bs
            n_seen += bs

            vbar.set_postfix(loss=f"{batch_loss:.4f}", acc1=f"{batch_top1:.2f}")

    denom = max(n_seen, 1)
    return loss_sum / denom, top1_sum / denom, top5_sum / denom


# ----------------------------
# Training
# ----------------------------
def train_with_progress(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader):
    """Train `model` for CFG.epochs epochs with AMP, gradient clipping and a
    linear-warmup then cosine learning-rate schedule, showing tqdm progress bars.

    After each epoch the model is validated with `evaluate`, the metrics are
    appended to a history dict and a summary is printed. When training finishes,
    the history is written to CFG.out_dir / "history.json" and the weights to
    CFG.out_dir / "sppp_vit_b16.pth".

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        train_loader: iterable yielding (imgs, labels, segs) training batches.
        val_loader: iterable yielding (imgs, labels, segs) validation batches.

    Returns:
        dict of per-epoch lists: train_loss, train_acc1, val_loss, val_acc1, val_acc5.
    """
    use_amp = (CFG.device == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=CFG.lr,
        weight_decay=CFG.weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    # Warmup is capped so at least one epoch remains for the cosine phase.
    # Without the cap, CFG.epochs <= 3 gives T_max <= 0 and the cosine scheduler
    # divides by zero when the schedule switches over.
    warmup_epochs = min(max(3, int(0.05 * CFG.epochs)), max(CFG.epochs - 1, 0))
    main_epochs = max(1, CFG.epochs - warmup_epochs)
    if warmup_epochs > 0:
        scheduler = SequentialLR(
            optimizer,
            schedulers=[
                LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs),
                CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6),
            ],
            milestones=[warmup_epochs],
        )
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6)

    criterion = nn.CrossEntropyLoss(label_smoothing=getattr(CFG, "label_smoothing", 0.0))
    history = {"train_loss": [], "train_acc1": [], "val_loss": [], "val_acc1": [], "val_acc5": []}

    for epoch in range(CFG.epochs):
        print(f"\n🚀 Epoch {epoch+1}/{CFG.epochs}", flush=True)
        model.train()
        loss_sum = acc1_sum = 0.0
        n_seen = 0

        # Proper live progress bar
        with tqdm(train_loader, desc=f"🧠 Training [{epoch+1}/{CFG.epochs}]", ncols=100, leave=True) as pbar:
            for imgs, labels, segs in pbar:
                imgs = imgs.to(CFG.device, non_blocking=True)
                labels = labels.to(CFG.device, non_blocking=True)
                segs = segs.to(CFG.device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)
                with torch.amp.autocast("cuda", enabled=use_amp):
                    logits = model(imgs, segs)
                    loss = criterion(logits, labels)

                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                acc1, = accuracy_topk(logits, labels, (1,))
                # Read each scalar once; every .item() forces a device sync.
                batch_loss, batch_acc1 = loss.item(), acc1.item()
                # Weight by batch size so a short last batch does not skew the epoch mean.
                bs = labels.size(0)
                loss_sum += batch_loss * bs
                acc1_sum += batch_acc1 * bs
                n_seen += bs

                # update progress bar text (no heavy prints)
                pbar.set_postfix(loss=f"{batch_loss:.4f}", acc1=f"{batch_acc1:.2f}")

        scheduler.step()

        # Epoch summary
        denom = max(n_seen, 1)  # guards an empty train_loader
        train_loss = loss_sum / denom
        train_acc1 = acc1_sum / denom
        val_loss, val_acc1, val_acc5 = evaluate(model, val_loader)

        history["train_loss"].append(train_loss)
        history["train_acc1"].append(train_acc1)
        history["val_loss"].append(val_loss)
        history["val_acc1"].append(val_acc1)
        history["val_acc5"].append(val_acc5)

        print(f"✅ Epoch {epoch+1}/{CFG.epochs} | "
              f"Train: loss {train_loss:.4f}, acc@1 {train_acc1:.2f} | "
              f"Val: loss {val_loss:.4f}, acc@1 {val_acc1:.2f}, acc@5 {val_acc5:.2f}", flush=True)
        # Printed after scheduler.step(), so this is the rate for the next epoch.
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6e}", flush=True)

    # Save logs
    with open(CFG.out_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)
    torch.save(model.state_dict(), CFG.out_dir / "sppp_vit_b16.pth")

    return history


# --------------------------
# Run training
# --------------------------
if __name__ == "__main__":
    # Data pipeline first, then the model (built, then moved to the target device).
    train_loader, val_loader = build_loaders(CFG)

    model = SPPPViT(CFG)
    model = model.to(CFG.device)
    # Request 'medium' float32 matmul precision before the training run starts.
    torch.set_float32_matmul_precision("medium")

    # torch.compile is intentionally left disabled; to try it, uncomment:
    # if hasattr(torch, "compile"):
    #     model = torch.compile(model, backend="inductor", mode="reduce-overhead")

    history = train_with_progress(model, train_loader, val_loader)


✅ DataLoaders ready — train:1281166, val:50000, num_workers=32/16

🚀 Epoch 1/10


🧠 Training [1/10]:   0%|                                                  | 0/1252 [00:01<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 910.00 MiB. GPU 0 has a total capacity of 93.10 GiB of which 385.75 MiB is free. Process 965231 has 516.00 MiB memory in use. Process 970455 has 516.00 MiB memory in use. Process 2504479 has 91.64 GiB memory in use. Of the allocated memory 90.09 GiB is allocated by PyTorch, and 900.76 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [12]:
# ==========================
# Cell 10 — Training loop (Windows + Jupyter visible)
# ==========================
from tqdm.notebook import tqdm      # <-- use notebook-aware tqdm
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from torch.utils.data import DataLoader
import time
##import sys
##import io

# Works in both Jupyter and normal Python
##if hasattr(sys.stdout, "reconfigure"):
##    sys.stdout.reconfigure(line_buffering=True)
##else:
##    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, line_buffering=True)

def accuracy_topk(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in `topk`.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        target: tensor holding one class index per sample.
        topk: iterable of k values to report.

    Returns:
        List of 0-dim tensors, one percentage per requested k.
    """
    largest_k = max(topk)
    # Indices of the largest_k best-scoring classes per sample, best first,
    # transposed to [largest_k, batch] so row r holds every sample's rank-r guess.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    to_percent = 100.0 / target.size(0)
    return [hits[:k].reshape(-1).float().sum(0) * to_percent for k in topk]


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader):
    """Run one validation pass and return (mean loss, top-1 %, top-5 %).

    The loss uses the same label smoothing as training (CFG.label_smoothing,
    0.0 when the attribute is absent). Metrics are weighted by the number of
    samples in each batch, so a short final batch does not skew the result.
    An empty loader yields (0.0, 0.0, 0.0) instead of raising ZeroDivisionError.

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        loader: iterable yielding (imgs, labels, segs) batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(label_smoothing=getattr(CFG, "label_smoothing", 0.0))
    loss_sum = top1_sum = top5_sum = 0.0
    n_seen = 0

    for imgs, labels, segs in tqdm(loader, desc="Validating", leave=False):
        imgs = imgs.to(CFG.device, non_blocking=True)
        labels = labels.to(CFG.device, non_blocking=True)
        segs = segs.to(CFG.device, non_blocking=True)

        logits = model(imgs, segs)
        loss = criterion(logits, labels)
        top1, top5 = accuracy_topk(logits, labels, (1, 5))

        # Batch means times batch size give per-sample sums.
        bs = labels.size(0)
        loss_sum += loss.item() * bs
        top1_sum += top1.item() * bs
        top5_sum += top5.item() * bs
        n_seen += bs

    denom = max(n_seen, 1)
    return loss_sum / denom, top1_sum / denom, top5_sum / denom


def train_with_progress(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader):
    """Train `model` for CFG.epochs epochs with AMP, gradient clipping and a
    linear-warmup then cosine learning-rate schedule, showing a tqdm progress bar
    and a printed status line every 10 iterations.

    After each epoch the model is validated with `evaluate`, the metrics are
    appended to a history dict and a summary is printed. When training finishes,
    the history is written to CFG.out_dir / "history.json" and the weights to
    CFG.out_dir / "sppp_vit_b16.pth".

    Args:
        model: network called as model(imgs, segs) and returning class logits.
        train_loader: iterable yielding (imgs, labels, segs) training batches.
        val_loader: iterable yielding (imgs, labels, segs) validation batches.

    Returns:
        dict of per-epoch lists: train_loss, train_acc1, val_loss, val_acc1, val_acc5.
    """
    use_amp = (CFG.device == "cuda")
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=CFG.lr,
        weight_decay=CFG.weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    # Warmup is capped so at least one epoch remains for the cosine phase.
    # Without the cap, CFG.epochs <= 3 gives T_max <= 0 and the cosine scheduler
    # divides by zero when the schedule switches over.
    warmup_epochs = min(max(3, int(0.05 * CFG.epochs)), max(CFG.epochs - 1, 0))
    main_epochs = max(1, CFG.epochs - warmup_epochs)
    if warmup_epochs > 0:
        scheduler = SequentialLR(
            optimizer,
            schedulers=[
                LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs),
                CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6),
            ],
            milestones=[warmup_epochs],
        )
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=main_epochs, eta_min=1e-6)

    criterion = nn.CrossEntropyLoss(label_smoothing=getattr(CFG, "label_smoothing", 0.0))
    history = {"train_loss": [], "train_acc1": [], "val_loss": [], "val_acc1": [], "val_acc5": []}

    for epoch in range(CFG.epochs):
        print(f"\n🚀 Starting Epoch {epoch+1}/{CFG.epochs}", flush=True)
        model.train()
        loss_sum, acc1_sum, n_seen = 0.0, 0.0, 0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{CFG.epochs}", ncols=100, leave=True) as pbar:
            for i, (imgs, labels, segs) in enumerate(pbar):
                imgs = imgs.to(CFG.device, non_blocking=True)
                labels = labels.to(CFG.device, non_blocking=True)
                segs = segs.to(CFG.device, non_blocking=True)

                # The extra no-grad forward pass left over from timing experiments
                # was removed: its result was discarded and it doubled the forward
                # cost of every training step.
                optimizer.zero_grad(set_to_none=True)

                with torch.amp.autocast("cuda", enabled=use_amp):
                    logits = model(imgs, segs)
                    loss = criterion(logits, labels)

                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                acc1, = accuracy_topk(logits, labels, (1,))
                # Read each scalar once; every .item() forces a device sync.
                batch_loss, batch_acc1 = loss.item(), acc1.item()
                # Weight by batch size so a short last batch does not skew the epoch mean.
                bs = labels.size(0)
                loss_sum += batch_loss * bs
                acc1_sum += batch_acc1 * bs
                n_seen += bs

                if i % 10 == 0:
                    print(f"  iter {i:4d}/{len(train_loader)} loss={batch_loss:.4f} acc1={batch_acc1:.2f}", flush=True)

                pbar.set_postfix({"loss": f"{batch_loss:.4f}", "acc@1": f"{batch_acc1:.2f}"})

        scheduler.step()
        denom = max(n_seen, 1)  # guards an empty train_loader
        tl = loss_sum / denom
        ta1 = acc1_sum / denom
        vl, va1, va5 = evaluate(model, val_loader)

        history["train_loss"].append(tl)
        history["train_acc1"].append(ta1)
        history["val_loss"].append(vl)
        history["val_acc1"].append(va1)
        history["val_acc5"].append(va5)

        print(f"✅ Epoch {epoch+1}/{CFG.epochs} | "
              f"Train: loss {tl:.4f}, acc@1 {ta1:.2f} | "
              f"Val: loss {vl:.4f}, acc@1 {va1:.2f}, acc@5 {va5:.2f}", flush=True)
        # Printed after scheduler.step(), so this is the rate for the next epoch.
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6e}", flush=True)

    with open(CFG.out_dir / "history.json", "w") as f:
        json.dump(history, f, indent=2)
    torch.save(model.state_dict(), CFG.out_dir / "sppp_vit_b16.pth")
    return history


# --------------------------
# Run training
# --------------------------
if __name__ == "__main__":
    # Multiprocessing start-method overrides are left disabled:
    #torch.multiprocessing.freeze_support()
    #torch.multiprocessing.set_start_method("spawn", force=True)

    # Data pipeline.
    train_loader, val_loader = build_loaders(CFG)

    # Build the model; it is moved to CFG.device further down.
    model = SPPPViT(CFG)

    # Request 'medium' float32 matmul precision.
    torch.set_float32_matmul_precision("medium")

    # torch.compile is left disabled; if enabled it would go here, before the
    # move to the device:
    #if hasattr(torch, "compile"):
    #    model = torch.compile(model, backend="inductor", mode="max-autotune")

    model = model.to(CFG.device)

    # Run the full training loop and keep the per-epoch metrics.
    history = train_with_progress(model, train_loader, val_loader)



✅ DataLoaders ready — train:1281166, val:50000, num_workers=16/8

🚀 Starting Epoch 1/10


Epoch 1/10:   0%|                                                         | 0/10010 [00:00<?, ?it/s]

  iter    0/10010 loss=7.1322 acc1=0.00
  iter   10/10010 loss=7.0790 acc1=0.00
  iter   20/10010 loss=7.0226 acc1=0.00
  iter   30/10010 loss=7.0398 acc1=0.00
  iter   40/10010 loss=6.9733 acc1=0.00
  iter   50/10010 loss=7.0319 acc1=0.00
  iter   60/10010 loss=7.0114 acc1=0.00
  iter   70/10010 loss=7.0285 acc1=0.00
  iter   80/10010 loss=6.9835 acc1=0.00
  iter   90/10010 loss=6.9722 acc1=0.78
  iter  100/10010 loss=7.0126 acc1=0.00
  iter  110/10010 loss=6.9820 acc1=0.00
  iter  120/10010 loss=6.9276 acc1=0.78
  iter  130/10010 loss=6.9049 acc1=0.00
  iter  140/10010 loss=6.9825 acc1=0.00
  iter  150/10010 loss=6.9492 acc1=0.00
  iter  160/10010 loss=6.9646 acc1=0.78
  iter  170/10010 loss=6.9316 acc1=0.00
  iter  180/10010 loss=6.9511 acc1=0.78
  iter  190/10010 loss=6.9088 acc1=0.00
  iter  200/10010 loss=6.9403 acc1=0.00
  iter  210/10010 loss=6.9569 acc1=0.00
  iter  220/10010 loss=6.9270 acc1=0.78
  iter  230/10010 loss=6.9637 acc1=0.00
  iter  240/10010 loss=6.9672 acc1=0.00


KeyboardInterrupt: 

In [None]:
'''
# ==========================
# Cell 11 — Plot training curves
# ==========================

hist_path = CFG.out_dir / "history.json"
if hist_path.exists():
    with open(hist_path, "r") as f:
        history = json.load(f)
    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.plot(history["train_loss"], label="train")
    plt.plot(history["val_loss"], label="val")
    plt.title("Loss"); plt.legend()
    plt.subplot(1,2,2)
    plt.plot(history["train_acc1"], label="train@1")
    plt.plot(history["val_acc1"], label="val@1")
    plt.plot(history["val_acc5"], label="val@5")
    plt.title("Accuracy"); plt.legend()
    plt.show()
else:
    print("No history yet. Set RUN_TRAINING=True and run Cell 10 to train.")
'''

In [None]:
'''
# ==========================
# Cell 12 — Visualize superpixels on random train/val samples (with filenames)
# ==========================
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries

def visualize_random_samples(split: str = "val", num: int = 4, show_raw: bool = False):
    """
    Plot randomly chosen dataset images with their SLIC superpixel boundaries.

    For each sampled image the segmentation is loaded from the dataset's cache
    file when that file exists; otherwise SLIC is run on the spot and the result
    is saved to the cache file before plotting.

    Args:
        split: "val" selects ``val_ds``; any other value selects ``train_ds``.
        num: number of random samples requested; capped at the dataset size.
        show_raw: if True, every sample gets its own row with the raw image on
            the left and the boundary overlay on the right. If False, only the
            overlays are drawn, in a grid of at most 4 columns.

    Returns:
        None. Shows a matplotlib figure, or prints a notice when there is
        nothing to draw (``num <= 0`` or an empty dataset).
    """
    ds = val_ds if split == "val" else train_ds

    # Size the grid by the number of samples actually drawn, not by `num`:
    # this avoids blank rows for small datasets and the division by zero /
    # negative sample size that num <= 0 would otherwise trigger.
    n_show = min(num, len(ds))
    if n_show <= 0:
        print(f"Nothing to visualize for split={split!r}: num={num}, dataset size={len(ds)}.")
        return
    idxs = random.sample(range(len(ds)), k=n_show)

    if show_raw:
        cols = 2
        rows = n_show
        figsize = (10, 4 * rows)
    else:
        cols = min(n_show, 4)
        rows = math.ceil(n_show / cols)
        figsize = (5 * cols, 5 * rows)

    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    free_axes = iter(axes.ravel())  # panels are consumed in row-major order

    def _draw(image, title):
        """Render one image into the next free panel with a small title."""
        ax = next(free_axes)
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(title, fontsize=9)

    for i in idxs:
        path, _ = ds.base.samples[i]             # actual file path (label is not needed here)
        fname = os.path.basename(path)           # file name shown in the panel title
        arr = ds._load_image_numpy(path)         # image array as loaded by the dataset helper
        cache_f = ds._cache_file(path)

        if cache_f.exists():
            seg = np.load(cache_f, allow_pickle=False)
        else:
            # Imported here because this cell only imports mark_boundaries;
            # without it a cache miss fails with NameError.
            from skimage.segmentation import slic
            seg = slic(arr, n_segments=ds.num_superpixels,
                       compactness=ds.compactness, sigma=ds.sigma, start_label=0)
            # NOTE(review): np.save appends ".npy" when the target name lacks it;
            # confirm ds._cache_file returns a ".npy" path, otherwise the
            # exists() check above never hits.
            np.save(cache_f, seg)

        overlay = mark_boundaries(arr, seg)

        if show_raw:
            _draw(arr, f"{split}: {fname}")
            _draw(overlay, f"{split} superpixels: {fname}")
        else:
            _draw(overlay, f"{split}: {fname}")

    # Hide grid cells left over when n_show is not a multiple of cols.
    for ax in free_axes:
        ax.axis("off")

    fig.tight_layout()
    plt.show()

# Example usage: train samples with raw/overlay pairs, then val samples as overlays only
for demo_split, demo_show_raw in (("train", True), ("val", False)):
    visualize_random_samples(demo_split, num=4, show_raw=demo_show_raw)
'''

In [15]:
'''
# ==========================
# Cell 13 — SLIC Overhead Analysis for ImageNet‑1k
# ==========================

# Back-of-the-envelope cost model: total time = image count x per-image latency.
# NOTE(review): 1.28M is roughly the ImageNet-1k *train* split alone; the 50k
# validation images are not included in this figure.
num_images = 1_280_000
ms_per_img = 25  # assumed SLIC latency per 224x224 image on one CPU core

total_seconds = num_images * ms_per_img / 1000
hours_1core = total_seconds / 3600
print(f"Estimated SLIC time on 1 CPU core: ~{hours_1core:.1f} hours")

# Assumes ideal linear scaling across cores (no I/O or scheduling overhead).
core_counts = (8, 16, 32, 64)
for cores in core_counts:
    print(f"With {cores:>2} cores: ~{hours_1core/cores:.2f} hours")

slic_notes = """
Notes:
- skimage.slic is CPU-only; no official GPU implementation in torchvision/torch today.
- Precompute for validation is strongly recommended; for training, either compute on-the-fly
  or switch to deterministic transforms if you want to precompute as well.
"""
print(slic_notes)
'''

Estimated SLIC time on 1 CPU core: ~8.9 hours
With  8 cores: ~1.11 hours
With 16 cores: ~0.56 hours
With 32 cores: ~0.28 hours
With 64 cores: ~0.14 hours

Notes:
- skimage.slic is CPU-only; no official GPU implementation in torchvision/torch today.
- Precompute for validation is strongly recommended; for training, either compute on-the-fly
  or switch to deterministic transforms if you want to precompute as well.

