In [4]:
#!/usr/bin/env python
"""
mini_imagenet_kernels_per_model.py
----------------------------------
â€¢ Uses the open timm/miniâ€‘imagenet dataset (50â€¯k images, 100 classes)
â€¢ For each timm encoder in MODEL_NAMES:
      â€“ extracts features on N_IMAGES random samples
      â€“ builds its own cosineâ€‘similarity kernel  K = Z Záµ€
      â€“ saves to  kernels_out/K_<model>.pt
"""

import random
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import timm
from datasets import load_dataset
from PIL import Image
from tqdm.auto import tqdm

import gc


# â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€ userâ€‘tweakables â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
N_IMAGES   = 8_192
BATCH_SIZE = 1024
NUM_WORKERS = 0
DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_NAMES = [
    # â”€â”€ Classic CNNs â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
    #"resnet18",
    #"resnet34",
    #"resnet50",
    #"resnet101",
    #"resnet152",
    #"wide_resnet50_2",
    #"wide_resnet101_2",
    #"resnext50_32x4d",
    #"resnext101_32x8d",
    #"densenet121",
    #"densenet201",
    #"ese_vovnet39b",
    #"regnety_016",
    #"regnety_032",
    # ConvNeXt family
    #"convnext_small",
    # EfficientNet & friends
    #"efficientnet_b0",
    # Mobile / lightweight
    #"mobilenetv3_large_100",
    #"ghostnet_100",
    # NFâ€‘Nets (DeepMind)
    #"dm_nfnet_f0",
    # â”€â”€ Vision Transformers & hybrids â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
    # ViT
    #"vit_base_patch16_224",
    # DeiT
    #"deit_tiny_patch16_224",
    #"deit_small_patch16_224",
    # BEiT
    #"beit_base_patch16_224",
    #"beit_large_patch16_224",
    # Swin
    #"swin_tiny_patch4_window7_224",
    # PVTâ€‘v2
    #"pvt_v2_b2",
    # CSWin
    #"cswin_tiny_224",
    # CoAtNet
    #"coatnet_0",
    # Mixers / Convmixer
    "mixer_b16_224",
    # GC ViT
    "gcvit_base",
    # ConvNeXtâ€‘v2
    "convnextv2_base",
    # CLIP ViT (image branch only)
    "clip_vit_base_patch32",
]

OUT_DIR = Path("kernels_out_mi_no_pool")
OUT_DIR.mkdir(exist_ok=True)
# â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€


# 1) Dataset -------------------------------------------------------------------
print("ðŸ“¦  loading timm/miniâ€‘imagenet â€¦")
hf_ds = load_dataset("timm/mini-imagenet", split="train")

transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),   # ensure 3â€‘ch
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

class HFWrapper(Dataset):
    def __init__(self, ds, tfm):
        self.ds, self.tfm = ds, tfm
    def __len__(self): return len(self.ds)
    def __getitem__(self, i):
        item = self.ds[int(i)]
        return self.tfm(item["image"]), item["label"]

# sample once, reuse for every model
full_ds   = HFWrapper(hf_ds, transform)
indices   = random.sample(range(len(full_ds)), N_IMAGES)
subset_ds = Subset(full_ds, indices)
loader    = DataLoader(subset_ds, batch_size=BATCH_SIZE,
                       shuffle=False, num_workers=NUM_WORKERS,
                       pin_memory=True)
print(f"âœ“ dataset ready â€” {len(subset_ds)} images\n")


# 2) Feature â†’ kernel â†’ save  (one loop per model) -----------------------------

@torch.no_grad()
def features(model_name: str) -> torch.Tensor:
    model = timm.create_model(model_name, pretrained=True,
                              num_classes=0).to(DEVICE).eval()
    vecs = []
    for imgs, _ in tqdm(loader, desc=f"{model_name:>24}", leave=False):
        # Process in smaller batches if needed
        batch_output = model(imgs.to(DEVICE, non_blocking=True)).flatten(1)
        # Convert to float32 for better memory efficiency
        batch_output = batch_output.to(torch.float32)
        vecs.append(batch_output)
        # Explicitly free memory
        torch.cuda.empty_cache()
    print('done with looping')

    return torch.cat(vecs)


for m in MODEL_NAMES:
    print(f"ðŸš€  processing {m} â€¦")
    F_m = features(m)                            # (N, D_m)
    print('features computed')
    Z_m = F.normalize(F_m, p=2, dim=1)           # rowâ€‘norm
    K_m = Z_m @ Z_m.T                            # (N, N)

    torch.save(
        {"K": K_m.cpu(),                         # kernel
         "Z": F_m.cpu(),                         # normalised feats
         "dim": F_m.shape[1],                    # feature length of this model
         "indices": indices},
        OUT_DIR / f"K_{m}_{N_IMAGES}.pt"
    )
    print(f"   â†³ saved  {OUT_DIR / f'K_{m}_{N_IMAGES}.pt'}\n")

    # Clean up memory before next model
    del F_m, Z_m, K_m
    torch.cuda.empty_cache()
    gc.collect()  # Force garbage collection

print("âœ…  all kernels done.")

ðŸ“¦  loading timm/miniâ€‘imagenet â€¦
âœ“ dataset ready â€” 8192 images

ðŸš€  processing mixer_b16_224 â€¦


model.safetensors:   0%|          | 0.00/240M [00:00<?, ?B/s]

           mixer_b16_224:   0%|          | 0/8 [00:00<?, ?it/s]

done with looping
features computed
   â†³ saved  kernels_out_mi_no_pool/K_mixer_b16_224_8192.pt

ðŸš€  processing gcvit_base â€¦


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth" to /home/user/.cache/torch/hub/checkpoints/gcvit_base_224_nvidia-f009139b.pth


              gcvit_base:   0%|          | 0/8 [00:00<?, ?it/s]

done with looping
features computed
   â†³ saved  kernels_out_mi_no_pool/K_gcvit_base_8192.pt

ðŸš€  processing convnextv2_base â€¦


model.safetensors:   0%|          | 0.00/355M [00:00<?, ?B/s]

         convnextv2_base:   0%|          | 0/8 [00:00<?, ?it/s]

done with looping
features computed
   â†³ saved  kernels_out_mi_no_pool/K_convnextv2_base_8192.pt

ðŸš€  processing clip_vit_base_patch32 â€¦


RuntimeError: Unknown model (clip_vit_base_patch32)