This tutorial assumes general knowledge of deep learning (e.g. CS 7643).

Setup:
1. Pull repo via `git clone xyz`
2. etc
3. Install env
4. etc
5. etc

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
import open_clip
from typing import Dict, List, Optional, Tuple
from pathlib import Path
from tqdm import tqdm

from feat_eng import IMAGENET_CLASSES, IMAGENET_TEMPLATES

import pdb

Recommended resources:
- GPU: x
- RAM: x
- etc.

Note: The environment is not configured to support newer GPUs (e.g. H100/H200/B200).

Create and activate environment with:<br>
`conda env create -f environment.yaml`<br>
`conda activate imagenet-zeroshot`

TODO: Move everything that is not essential to the tutorial's learning objectives out of this notebook.

Dimension annotations:<br>
B: Batch dimension i.e. sample dimension<br>
D: Embedding dimension<br>
T: Context length (77 for CLIP)<br>
L: Num. classes (1000 in this case)

In [None]:
# Number of images per batch; passed as `batch_size` to every DataLoader in this notebook.
BATCH_SIZE = 2_048
# Much smaller alternative kept for quick switching (presumably for debugging or low-memory runs):
# BATCH_SIZE = 8

# Passed as `num_workers` to every DataLoader in this notebook (number of loading worker processes).
NUM_WORKERS = 8

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ImageNet-1k root directory and the locations derived from it.
dpath_imagenet = Path("/data/ai/ref-data/image/ImageNet/imagenet1k")
root_ilsvrc    = dpath_imagenet  # second name bound to the same root Path object

dpath_valid    = dpath_imagenet.joinpath("Data/CLS-LOC/val")         # validation image directory
fpath_csv_val  = dpath_imagenet.joinpath("LOC_val_solution.csv")     # presumably validation labels -- confirm
fpath_synmap   = dpath_imagenet.joinpath("LOC_synset_mapping.txt")   # presumably synset -> class-name map -- confirm

In [None]:
@torch.no_grad()
def build_class_protos_text_feat_eng(model, tokenizer, txts_class: List[str], txt_templates: List[str], device: torch.device) -> torch.Tensor:
    """
    Build one text prototype per class by prompt ensembling.

    Each class name is substituted into every template with `str.format`, the
    resulting prompts are tokenized and encoded with `model.encode_text`, every
    prompt embedding is scaled to unit L2 norm, and the embeddings are averaged.
    The average is scaled to unit L2 norm again to give the class prototype.

    Returns the prototypes stacked along dim 1, i.e. one column per class, in the
    order of `txts_class` (D x L in the notebook's dimension annotations, assuming
    `encode_text` returns one D-dimensional row per prompt).
    """
    class_columns = []

    for class_name in tqdm(txts_class):
        # All prompt variants for this class, tokenized in a single batch.
        prompts = [template.format(class_name) for template in txt_templates]
        tokens  = tokenizer(prompts).to(device)

        # Unit-normalize each prompt embedding before averaging.
        prompt_embs  = model.encode_text(tokens)
        prompt_norms = prompt_embs.norm(dim=-1, keepdim=True)
        unit_embs    = prompt_embs / prompt_norms

        # The mean of unit vectors is generally not unit length, so normalize again.
        mean_emb = unit_embs.mean(dim=0)
        class_columns.append(mean_emb / mean_emb.norm())

    return torch.stack(class_columns, dim=1)

# METRICS

@torch.no_grad()
def topk_accuracy(logits: torch.Tensor, targs: torch.Tensor, ks: Tuple[int, ...] = (1, 5)) -> List[torch.Tensor]:
    """
    Produces a list of Precision@k batch-average scores given logits and targets.
    Prec@k's included are specified with the `ks` arg, configured to Prec@1 and Prec@5 by default.

    Args:
        logits: Tensor(B, L) of per-class scores; L must be at least max(ks).
        targs:  Tensor(B) of integer class indices.
        ks:     the k values to evaluate.

    Returns:
        One 0-dim tensor per entry of `ks`, in the same order: the fraction of
        samples whose target is among the k highest-scoring classes. Call
        `.item()` on an entry to get a Python float.
    """
    B     = targs.size(0)
    k_max = max(ks)

    # Indices of the k_max highest-scoring classes per sample, best first: (B, k_max).
    _, pred = logits.topk(k_max, dim=1)

    # hits[i, j] is True iff the j-th ranked prediction for sample i is its target.
    # Each row has at most one True, so summing the first k columns counts the
    # samples whose target is within the top k.
    hits = pred.eq(targs.view(-1, 1))

    return [hits[:, :k].float().sum() / B for k in ks]

# Initialize ResNet-50

In [None]:
# Pretrained ImageNet-1k weights (V1 variant) for torchvision's ResNet-50.
resnet_wts = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
# Preprocessing transform obtained from the weights object; used as the dataset transform for ResNet-50.
resnet_tf  = resnet_wts.transforms()
# Build the model with those weights, move it to `device`, and switch it to eval mode.
resnet     = torchvision.models.resnet50(weights=resnet_wts).to(device).eval()

# Initialize Data Loaders

In [None]:
# TODO: the dataset/loader construction could probably be folded into the inference code.

# Validation images read from `dpath_valid`, preprocessed with the ResNet-50 transform.
# NOTE(review): ImageFolder takes labels from sub-directory names, so this presumably assumes
# `dpath_valid` contains one folder per class -- confirm against the on-disk layout.
dataset_resnet    = ImageFolder(dpath_valid, transform=resnet_tf)
# Batched loader; pin_memory=True pairs with the non_blocking host-to-device copies in the inference loop.
dataloader_resnet = DataLoader(dataset_resnet, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

The code is arranged to highlight the similarities and differences between the classic and VLM paradigms.

# ResNet-50 Inference

In [None]:
# Running totals over the whole validation set: number of samples seen and the
# sample-weighted sums of the per-batch Prec@1 / Prec@5 scores.
n         = 0
prec1_sum = 0.0
prec5_sum = 0.0
with torch.no_grad():
    for imgs_b, targs_enc_b in tqdm(dataloader_resnet):
        # imgs_b -------- Tensor(B, 3, 224, 224)
        # targs_enc_b --- Tensor(B)
        B           = targs_enc_b.size(0)
        imgs_b      = imgs_b.to(device, non_blocking=True)
        targs_enc_b = targs_enc_b.to(device, non_blocking=True)

        # The classifier outputs one logit per class directly.
        logits = resnet(imgs_b)

        # topk_accuracy returns batch averages; weight by B so that a smaller
        # final batch does not skew the overall average.
        prec1, prec5 = topk_accuracy(logits, targs_enc_b)
        prec1_sum += prec1.item() * B
        prec5_sum += prec5.item() * B
        n += B
print(
    "ResNet-50:",
    f"Prec@1 --- {prec1_sum / n:.4f}",
    f"Prec@5 --- {prec5_sum / n:.4f}",
    sep="\n"
)

# VLM Inference

(1) During training, the logits are scaled by a temperature parameter, which adjusts the sharpness of the resulting probability distribution. At inference the scaling can be skipped: multiplying every logit by the same positive constant does not change their ranking (i.e. the order the logits would have if sorted), so Prec@k is unaffected.

In [None]:

@torch.no_grad()
def run_inference_vlm(model, dataloader, class_protos_txt, device, model_name):
    """
    Zero-shot classification with a VLM: score every image against the class text
    prototypes, then print and return dataset-level Prec@1 / Prec@5.

    Args:
        model:            object exposing `encode_image(images)`.
        dataloader:       yields `(images, integer class targets)` batches.
        class_protos_txt: class text prototypes, right-multiplied onto the image
                          embeddings -- expected as (D, L), one column per class.
        device:           device the batches are moved to before encoding.
        model_name:       label used as the header of the printed results.

    Returns:
        `(prec1, prec5)` as Python floats averaged over every sample seen.
    """
    n_samps   = 0
    prec1_sum = 0.0
    prec5_sum = 0.0
    for imgs_b, targs_enc_b in tqdm(dataloader):
        # imgs_b -------- Tensor(B, 3, H, W); H and W depend on the model's transform
        # targs_enc_b --- Tensor(B)
        B           = targs_enc_b.size(0)
        imgs_b      = imgs_b.to(device, non_blocking=True)
        targs_enc_b = targs_enc_b.to(device, non_blocking=True)

        # Run the batch through the image encoder: one D-dimensional embedding per image.
        embs_img_b = model.encode_image(imgs_b)
        # Embeddings are normalized to unit length.
        embs_img_b = F.normalize(embs_img_b, dim=1)
        # (1) Image-to-prototype similarities act as the logits; no temperature
        # scaling is applied (see note (1) in the accompanying text).
        logits     = embs_img_b @ class_protos_txt

        # topk_accuracy returns batch averages; weight by B so that a smaller
        # final batch does not skew the overall average.
        prec1, prec5 = topk_accuracy(logits, targs_enc_b)
        prec1_sum += prec1.item() * B
        prec5_sum += prec5.item() * B
        n_samps += B

    prec1_avg = prec1_sum / n_samps
    prec5_avg = prec5_sum / n_samps
    print(
        f"{model_name}:",
        f"Prec@1 --- {prec1_avg:.4f}",
        f"Prec@5 --- {prec5_avg:.4f}",
        sep="\n"
    )

    # Returned in addition to being printed so results can be compared across models.
    return prec1_avg, prec5_avg


In [None]:
# open_clip models to evaluate; each entry is (model name, pretrained tag, quick_gelu flag).
# NOTE(review): quick_gelu is True only for the "openai" checkpoints -- presumably because those
# were trained with the QuickGELU activation; confirm against the open_clip documentation.
vlm_configs = [
    ("ViT-B-32",            "openai", True),
    ("ViT-B-16",            "openai", True),
    ("ViT-L-14",            "openai", True),
    ("ViT-L-14-336",        "openai", True),
    ("ViT-B-16-SigLIP",     "webli",  False),
    ("ViT-B-16-SigLIP-256", "webli",  False),
    ("ViT-L-16-SigLIP-256", "webli",  False),
]

for id_model, pretrained, quick_gelu in vlm_configs:

    # Load the model directly onto `device`; the third return value is used as the
    # image transform and the second is discarded.
    model, _, model_tf = open_clip.create_model_and_transforms(id_model, pretrained=pretrained, quick_gelu=quick_gelu, device=device)
    model.eval()
    tokenizer = open_clip.get_tokenizer(id_model)

    # Dataset and loader are rebuilt per model because the dataset applies that model's own transform.
    dataset    = ImageFolder(dpath_valid, transform=model_tf)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

    # One text prototype per ImageNet class, built from the class names and the prompt templates.
    class_protos_txt = build_class_protos_text_feat_eng(model, tokenizer, IMAGENET_CLASSES, IMAGENET_TEMPLATES, device).to(device)

    # Evaluate on the validation set and print Prec@1 / Prec@5 under the model's name.
    run_inference_vlm(model, dataloader, class_protos_txt, device, id_model)

Note: `RN50` in the printout refers to the CLIP ResNet-50, and `ResNet-50` refers to the standalone torchvision ResNet-50.

TODO: Consider an ablation of the text ensemble: encode the bare class labels directly and compare the results against the template-based prototypes.

This is not the 76.2% reported in the seminal CLIP paper (Table 11), but it is close enough.

TODO: Also compare CLIP_MODEL = "RN50", PRETRAINED = "openai".

We conclude with a discussion of the results: