In [1]:
#!/usr/bin/env python
"""
mini_imagenet_kernels_per_model.py
----------------------------------
• Uses the open timm/mini‑imagenet dataset (50 k images, 100 classes)
• For each timm encoder in MODEL_NAMES:
      – extracts features on N_IMAGES random samples
      – builds its own cosine‑similarity kernel  K = Z Zᵀ
      – saves to  kernels_out_more_images/K_<model>_<N_IMAGES>.pt
"""

import random
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import timm
from datasets import load_dataset
from PIL import Image
from tqdm.auto import tqdm


# ───────────── user‑tweakables ─────────────
# How many images are sampled from the dataset.
N_IMAGES = 2 ** 14  # 16 384

# Images per forward pass (reduced to lower the memory footprint).
BATCH_SIZE = 512

# Rows of the kernel matrix produced per chunk.
FEATURE_BATCH_SIZE = 2048

# DataLoader worker processes (reduced from 8 to lower memory pressure).
NUM_WORKERS = 1

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# timm encoder names, processed one after another.
MODEL_NAMES = [
    # CNNs
    "resnet50",
    "resnet152",
    "convnext_base",
    "efficientnet_b5",
    # ViT / MLP‑Mixer style
    "vit_base_patch16_224",
    "deit_base_patch16_224",
    "swin_base_patch4_window7_224",
]

# Output directory for the saved kernels; created if it does not exist yet.
OUT_DIR = Path("kernels_out_more_images")
OUT_DIR.mkdir(exist_ok=True)
# ───────────────────────────────────────────


# 1) Dataset -------------------------------------------------------------------
print("📦  loading timm/mini‑imagenet …")
hf_ds = load_dataset("timm/mini-imagenet", split="train")


def _to_rgb(img):
    """Return *img* converted to 3-channel RGB mode."""
    return img.convert("RGB")


# A named module-level function is used instead of a lambda so the transform
# pipeline can be pickled when DataLoader workers are started with the
# "spawn" method (a lambda cannot be pickled).
transform = transforms.Compose([
    transforms.Lambda(_to_rgb),                           # ensure 3‑ch
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Per-channel mean / std — presumably the standard ImageNet statistics.
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

class HFWrapper(Dataset):
    """Map-style dataset view over an indexable collection of records.

    Every record must provide an ``"image"`` and a ``"label"`` entry. On
    access the image is passed through ``tfm``; the label is returned as-is.
    """

    def __init__(self, ds, tfm):
        self.ds = ds
        self.tfm = tfm

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        # Coerce the index to a builtin int before handing it to the
        # underlying collection.
        record = self.ds[int(i)]
        image = self.tfm(record["image"])
        return image, record["label"]

# sample once, reuse for every model
full_ds   = HFWrapper(hf_ds, transform)
# N_IMAGES distinct dataset positions, drawn without replacement.
indices   = random.sample(range(len(full_ds)), N_IMAGES)
subset_ds = Subset(full_ds, indices)
loader    = DataLoader(
    subset_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,                      # keep batches in the order of `indices`
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host→GPU copies; on a CPU-only run it is
    # useless and makes PyTorch emit a warning.
    pin_memory=DEVICE.type == "cuda",
)
print(f"✓ dataset ready — {len(subset_ds)} images\n")


# 2) Feature → kernel → save  (one loop per model) -----------------------------
@torch.no_grad()
def features(model_name: str) -> torch.Tensor:
    """Extract L2-normalised feature vectors for every image in ``loader``.

    Creates the pretrained timm model ``model_name`` with ``num_classes=0``
    and ``global_pool=""``, runs it in eval mode on ``DEVICE`` over the
    module-level ``loader``, and flattens each output to one vector per image.

    Args:
        model_name: Model identifier passed to ``timm.create_model``.

    Returns:
        A CPU tensor of shape ``(N, D)`` with L2-normalised rows, in the
        order in which the loader yields the images.
    """
    model = timm.create_model(model_name, pretrained=True,
                              num_classes=0, global_pool="")
    model = model.to(DEVICE).eval()

    vecs = []
    for imgs, _ in tqdm(loader, desc=f"{model_name:>24}", leave=False):
        out = model(imgs.to(DEVICE, non_blocking=True)).flatten(1)
        # Normalisation is row-wise, so doing it per batch gives the same
        # result as normalising the concatenated matrix, without keeping an
        # extra full-size (N, D) copy alive on the host.
        vecs.append(F.normalize(out, p=2, dim=1).cpu())
    return torch.cat(vecs)                      # (N, D)

def compute_kernel_in_chunks(Z: torch.Tensor,
                             chunk_size: "int | None" = None) -> torch.Tensor:
    """Compute the Gram matrix ``K = Z @ Z.T`` one block of rows at a time.

    Only ``chunk_size`` rows of the product are computed per step, so the
    temporary result of each matmul stays small.

    Args:
        Z: 2-D tensor of shape ``(n, d)``.
        chunk_size: Number of rows of ``K`` computed per step. Defaults to
            the module-level ``FEATURE_BATCH_SIZE``.

    Returns:
        An ``(n, n)`` tensor with the dtype of ``Z``, allocated on torch's
        default device.

    Raises:
        ValueError: If ``chunk_size`` is not a positive number.
    """
    if chunk_size is None:
        chunk_size = FEATURE_BATCH_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    n = Z.shape[0]
    K = torch.zeros((n, n), dtype=Z.dtype)
    cuda_available = torch.cuda.is_available()

    for start in tqdm(range(0, n, chunk_size), desc="Computing kernel chunks"):
        stop = min(start + chunk_size, n)
        # Rows start..stop of K are the dot products of those rows of Z
        # with every row of Z.
        K[start:stop] = Z[start:stop] @ Z.T
        # Hand cached, currently unused GPU memory back between chunks.
        if cuda_available:
            torch.cuda.empty_cache()

    return K

import gc  # imported once here rather than on every loop iteration

# One pass per encoder: extract features, build the kernel, save, free memory.
for m in MODEL_NAMES:
    print(f"🚀  processing {m} …")

    # features() returns the rows already L2-normalised and on the CPU.
    Z_m = features(m)                            # (N, D_m)
    print(f"   ↳ features extracted, shape: {Z_m.shape}")

    # Compute kernel matrix in chunks
    K_m = compute_kernel_in_chunks(Z_m)
    print(f"   ↳ kernel computed, shape: {K_m.shape}")

    # Build the destination once so the saved path and the logged path
    # cannot drift apart.
    out_path = OUT_DIR / f"K_{m}_{N_IMAGES}.pt"
    torch.save(
        {"K": K_m.cpu(),                         # kernel
         "Z": Z_m.cpu(),                         # normalised feats
         "dim": Z_m.shape[1],                    # feature length of this model
         "indices": indices},                    # sampled dataset positions
        out_path,
    )
    print(f"   ↳ saved  {out_path}")

    # Drop the large tensors before the next model is loaded. Collect garbage
    # first so the memory it frees can be released from the CUDA cache too.
    del Z_m, K_m
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print()

print("✅  all kernels done.")

📦  loading timm/mini‑imagenet …
✓ dataset ready — 16384 images

🚀  processing resnet50 …


                resnet50:   0%|          | 0/32 [00:00<?, ?it/s]

: 