# From Pixels to Prompts: <br><small>A Crash Course on Zero-Shot Classification Using Vision-Language Models</small>

## Evaluation Utilities

### Imports

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import open_clip
from typing import List, Optional, Callable
from tqdm import tqdm

from feat_eng import LABELS, TEMPLATES
from utils import (
    DPATH_VALID, 
    init_dataloader, 
    batch_prec1, 
    init_resnet50, 
    init_vlm, 
    print_eval_header,
)

import pdb

### Hardware Config

In [None]:
# Evaluation throughput knobs: images per mini-batch and DataLoader worker processes.
# NOTE(review): 2048 images per batch assumes a GPU with plenty of memory — tune for smaller cards.
BATCH_SIZE, N_WORKERS = 2048, 8

# Run on the default CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Batched Inference: ResNet-50

In [None]:
def batch_inference_res(
    model:  nn.Module, 
    imgs_b: torch.Tensor,
) -> torch.Tensor:
    """
    Classify a mini-batch of images with ResNet-50.

    The model's output is returned as-is; no softmax or other post-processing is applied here.

    Args:
        model ---- ResNet-50 classifier
        imgs_b --- Mini-batch of preprocessed images, tensor of shape (B, C, H, W)

    Returns:
        Logits tensor of shape (B, L), one row of L class scores per image
    """
    return model(imgs_b)

### Batched Inference: VLM

In [None]:
def batch_inference_vlm(
    model:      nn.Module, 
    imgs_b:     torch.Tensor, 
    protos_txt: torch.Tensor,
) -> torch.Tensor:
    """
    Score a mini-batch of images against per-class text prototypes using the VLM image encoder.

    Each image is embedded by the image encoder and rescaled to unit length. Its score for class l is
    then the dot product with row l of `protos_txt`. When the prototype rows are also unit-length,
    these scores are cosine similarities in [-1, 1]. Note that no temperature / logit scale is applied,
    which does not change the arg-max prediction.

    Args:
        model -------- VLM exposing `encode_image`
        imgs_b ------- Mini-batch of preprocessed images, tensor of shape (B, C, H, W)
        protos_txt --- Class prototype text embeddings, tensor of shape (L, D)

    Returns:
        Logits tensor of shape (B, L)
    """
    # (x) (B, D) image embeddings projected onto the unit hypersphere
    img_feats = F.normalize(model.encode_image(imgs_b), dim=1)
    # (x) (B, D) @ (D, L) -> (B, L): a larger value means the image points more nearly in the prototype's direction
    return torch.matmul(img_feats, protos_txt.T)

### Evaluation Loop

In [None]:
@torch.no_grad()
def run_inference(
    model:      nn.Module, 
    dataloader: DataLoader, 
    model_name: str, 
    protos_txt: Optional[torch.Tensor] = None,
) -> None:
    """
    Run evaluation on the ImageNet-1k validation set and print Top-1 Precision (Prec@1).

    The model is switched to eval mode before the loop. This means layers such as BatchNorm and
    Dropout behave deterministically regardless of the mode the model was constructed in.

    Args:
        model -------- ResNet-50 or VLM
        dataloader --- Iterable yielding mini-batches of images and corresponding target encodings
        model_name --- Model display name; exactly "ResNet-50" selects ResNet-50 inference, any other value is treated as a VLM
        protos_txt --- Class prototype text embeddings used by VLM to perform image classification, tensor of shape (L, D);
                       required for VLMs, ignored for ResNet-50

    Raises:
        ValueError -- if a VLM is evaluated without `protos_txt`, or if `dataloader` yields no samples
    """
    is_resnet = model_name == "ResNet-50"
    # (x) Fail before loading any data. Otherwise the VLM path crashes on the first batch with an opaque error on `None.T`.
    if not is_resnet and protos_txt is None:
        raise ValueError(f"protos_txt is required to evaluate VLM '{model_name}'")

    # (x) Defensive: model constructors (e.g. open_clip / torchvision) may return modules in train mode.
    # This is a no-op if the model is already in eval mode.
    model.eval()

    # (x) Running totals for a per-sample (not per-batch) average
    n_samps   = 0
    prec1_sum = 0.0
    # (x) imgs_b = tensor of shape (B, C, H, W), targs_b = tensor of shape (B)
    for imgs_b, targs_b in tqdm(dataloader, desc="Evaluation"):
        imgs_b  = imgs_b.to(device, non_blocking=True)
        targs_b = targs_b.to(device, non_blocking=True)

        # (x) Logits tensor of shape (B, L)
        if is_resnet:
            logits = batch_inference_res(model, imgs_b)
        else:
            logits = batch_inference_vlm(model, imgs_b, protos_txt)

        # (x) batch_prec1 presumably returns the batch-mean Prec@1 as a 0-dim tensor (TODO confirm).
        # Weighting it by batch size keeps the final average exact even when the last batch is smaller.
        B          = targs_b.size(0)
        prec1_sum += batch_prec1(logits, targs_b).item() * B
        n_samps   += B

    if n_samps == 0:
        raise ValueError("dataloader yielded no samples; cannot compute Prec@1")

    print(
        "",
        f"Prec@1: {prec1_sum / n_samps:.1%}",
        "",
        sep="\n",
    )

## ResNet-50

In [None]:
# (x) Baseline: evaluate a ResNet-50 classifier on the same validation set (DPATH_VALID) used for the VLMs below.
# NOTE(review): which weights init_resnet50 loads is defined in utils — confirm there.
model_name = "ResNet-50"  # this exact string makes run_inference take the ResNet-50 path
print_eval_header(model_name)

# (x) img_pp is handed to init_dataloader, presumably as the image preprocessing transform matching the model
model, img_pp = init_resnet50(device)
dataloader    = init_dataloader(DPATH_VALID, img_pp, BATCH_SIZE, N_WORKERS)

run_inference(model, dataloader, model_name)

## VLMs

### Zero-Shot Classifier

In [None]:
@torch.no_grad()
def encode_txts(
    model:     nn.Module,
    tokenizer: Callable,
    txts:      List[str],
) -> torch.Tensor:
    """
    Encode text prompts with the VLM text encoder.
    P text prompts are tokenized, encoded into P text embeddings, and normalized to unit length.

    Runs under `torch.no_grad()`. The embeddings serve as fixed class prototypes, so no autograd graph
    is needed. Without this, encoding ~1000 prompts through a large text tower would keep every
    intermediate activation alive and return tensors that require grad.

    Args:
        model ------- VLM exposing `encode_text`
        tokenizer --- Tokenizer corresponding to VLM
        txts -------- Text prompts

    Returns:
        Text embeddings tensor of shape (P, D), rows of unit L2 norm
    """
    # (x) Token IDs (not embeddings), presumably a tensor of shape (P, T) with T = tokenizer context length
    toks_txt = tokenizer(txts).to(device)

    # (x) tensor of shape (P, D)
    embs_txt = model.encode_text(toks_txt)
    # (x) normalize text embeddings to unit length so dot products with image embeddings are cosine similarities
    return F.normalize(embs_txt, dim=1)

### List Pretrained Models

In [None]:
# (x) Show the pretrained checkpoints open_clip can load (architecture / pretrained-tag pairs); the `model_id` and `pretrained` values used below should match one of these.
open_clip.list_pretrained()

### Flagship CLIP Config

In [None]:
model_id   = "ViT-L-14-336"           # open_clip architecture name
pretrained = "openai"                 # open_clip pretrained-weights tag
quick_gelu = True                     # passed to init_vlm; presumably selects QuickGELU activations, as used by the OpenAI CLIP weights — confirm in init_vlm
model_name = "CLIP ViT-L/14 (336px)"  # display name; any value other than "ResNet-50" makes run_inference use the VLM path

### Zero-Shot Classification: Raw Labels

In [None]:
# (x) Zero-shot baseline: each class prototype is the unit-length text embedding of the bare class label, with no prompt template.
print_eval_header(model_name)

model, img_pp, tokenizer = init_vlm(model_id, pretrained, quick_gelu, device)

dataloader = init_dataloader(DPATH_VALID, img_pp, BATCH_SIZE, N_WORKERS)
protos_txt = encode_txts(model, tokenizer, LABELS).to(device)  # tensor of shape (L, D), one row per class label

run_inference(model, dataloader, model_name, protos_txt=protos_txt)

### Zero-Shot Classification: Standard Template

In [None]:
# (x) Same evaluation as the raw-label cell, but each label is wrapped in the single template "a photo of a {}." before encoding.
print_eval_header(model_name)

model, img_pp, tokenizer = init_vlm(model_id, pretrained, quick_gelu, device)

dataloader = init_dataloader(DPATH_VALID, img_pp, BATCH_SIZE, N_WORKERS)
prompts    = ["a photo of a {}.".format(label) for label in LABELS]
protos_txt = encode_txts(model, tokenizer, prompts).to(device)  # tensor of shape (L, D), one row per class

run_inference(model, dataloader, model_name, protos_txt=protos_txt)

### Prompt Ensembling

In [None]:
# TODO: consider a more descriptive name; the function averages per-template text embeddings into one prototype per class.
@torch.no_grad()
def build_ensemble_prototypes(
    model:     nn.Module,
    tokenizer: Callable,
    labels:    List[str],
    temps:     List[str],
) -> torch.Tensor:
    """
    Build one prompt-ensembled prototype embedding per class in the text space.

    For each class label, every template is filled in with the label and the resulting prompts are
    encoded with `encode_txts`, which yields one unit-length embedding per template. Those embeddings
    are averaged. The mean is then re-normalized to unit length, because averaging unit vectors
    generally produces a shorter vector.

    Args:
        model ------- VLM
        tokenizer --- Tokenizer corresponding to VLM
        labels ------ Class labels
        temps ------- Prompt templates with a "{}" placeholder for class labels

    Returns:
        Class prototypes tensor of shape (L, D); row i corresponds to labels[i]
    """

    def _prototype(label: str) -> torch.Tensor:
        # (x) (P, D) unit-length embeddings, one per template
        per_template = encode_txts(model, tokenizer, [temp.format(label) for temp in temps])
        # (x) average over templates -> (D,), then project back onto the unit hypersphere
        return F.normalize(per_template.mean(dim=0), dim=0)

    # (x) stack the L per-class (D,) vectors into a (L, D) tensor
    return torch.stack([_prototype(label) for label in tqdm(labels, desc="Building prototypes")])

### VLM Configurations

In [None]:
# (x) VLMs evaluated with prompt ensembling. Each entry is unpacked by the evaluation loop as
# (model_id, pretrained, quick_gelu, model_name):
#   model_id ---- open_clip architecture name
#   pretrained -- open_clip pretrained-weights tag
#   quick_gelu -- flag passed to init_vlm (True for the "openai" CLIP entries, False for the "webli" SigLIP entries)
#   model_name -- display name printed in the evaluation header
VLM_CONFIGS = [
    ("RN50",                "openai", True,  "CLIP ResNet-50 (224px)"),
    ("ViT-B-32",            "openai", True,  "CLIP ViT-B/32 (224px)"),
    ("ViT-B-16",            "openai", True,  "CLIP ViT-B/16 (224px)"),
    ("ViT-L-14",            "openai", True,  "CLIP ViT-L/14 (224px)"),
    ("ViT-L-14-336",        "openai", True,  "CLIP ViT-L/14 (336px)"),
    ("ViT-B-16-SigLIP",     "webli",  False, "SigLIP ViT-B/16 (224px)"),
    ("ViT-B-16-SigLIP-256", "webli",  False, "SigLIP ViT-B/16 (256px)"),
    ("ViT-L-16-SigLIP-256", "webli",  False, "SigLIP ViT-L/16 (256px)"),
]

### Zero-Shot Classification: OpenAI ImageNet1k Templates

In [None]:
for model_id, pretrained, quick_gelu, model_name in VLM_CONFIGS:
    # (x) Same pipeline as the single-template evaluations above, except that each class prototype averages
    # (x) the text embeddings of every template in TEMPLATES (prompt ensembling).
    print_eval_header(model_name)
    
    # NOTE(review): the previous iteration's `model` stays referenced until init_vlm returns, so two models
    # may briefly coexist in GPU memory; consider `del model` before loading the next one if memory is tight.
    model, img_pp, tokenizer = init_vlm(model_id, pretrained, quick_gelu, device)
    
    dataloader = init_dataloader(DPATH_VALID, img_pp, BATCH_SIZE, N_WORKERS)
    protos_txt = build_ensemble_prototypes(model, tokenizer, LABELS, TEMPLATES).to(device)  # tensor of shape (L, D)

    run_inference(model, dataloader, model_name, protos_txt=protos_txt)