# From Pixels to Prompts: <br><small>A Crash Course on Zero-Shot Classification Using Vision-Language Models</small>

### Imports

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import nn
import open_clip
from typing import List, Optional, Callable
from tqdm import tqdm

from feat_eng import IMAGENET_CLASSES, IMAGENET_TEMPLATES
from utils import DPATH_VALID, init_dataloader, batch_prec1, init_resnet50, init_vlm, print_eval_header

import pdb

### Evaluation Config

In [None]:
# Forwarded to init_dataloader together with N_WORKERS; presumably the number
# of images per batch and the number of loader worker processes (confirm in utils).
BATCH_SIZE = 2048
N_WORKERS  = 8

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
@torch.no_grad()
def run_inference(model:          nn.Module,
                  dataloader:     DataLoader,
                  device:         torch.device,
                  model_name:     str,
                  class_embs_txt: Optional[torch.Tensor] = None) -> None:
    """
    Evaluate a model over a dataloader and print its Top-1 precision (Prec@1).

    Args:
        model: Network to evaluate. It is called directly when `model_name`
            is "ResNet-50"; otherwise its `encode_image` method is used.
        dataloader: Yields `(images, encoded_targets)` batches.
        device: Device the batches are moved to before inference.
        model_name: Selects the inference path. Exactly "ResNet-50" uses
            `batch_inference_res`; any other value uses `batch_inference_vlm`.
        class_embs_txt: Text class prototypes, required on the VLM path and
            ignored on the ResNet-50 path.

    Raises:
        ValueError: If the VLM path is selected but `class_embs_txt` is None.

    NOTE(review): this function does not call `model.eval()`; it assumes the
    caller already put the model in evaluation mode — confirm in the init helpers.
    """
    is_resnet = model_name == "ResNet-50"

    # Fail fast: without prototypes the VLM path would only blow up at the
    # matmul, after the first batch has already been loaded and encoded.
    if not is_resnet and class_embs_txt is None:
        raise ValueError(
            f"class_embs_txt is required to evaluate '{model_name}' "
            "(only 'ResNet-50' runs without text class prototypes)"
        )

    # (1) Initialize some parameters for performance bookkeeping
    n_samps   = 0
    prec1_sum = 0.0
    # (2) Iterate over validation data, loading up a batch of B images and corresponding targets at a time
    for imgs_b, targs_enc_b in tqdm(dataloader, desc="Evaluation"):
        # imgs_b is presumably (B, 3, H, W), with H and W set by the model's
        # preprocessing; targs_enc_b has B entries along its first dimension.

        # (3) Number of samples in the current batch
        B           = targs_enc_b.size(0)
        # (4) Send batched images and targets tensors to the evaluation device
        imgs_b      = imgs_b.to(device, non_blocking=True)
        targs_enc_b = targs_enc_b.to(device, non_blocking=True)

        # (5) Perform batch inference with ResNet-50 or VLM
        if is_resnet:
            logits = batch_inference_res(model, imgs_b)
        else:
            logits = batch_inference_vlm(model, imgs_b, class_embs_txt)

        # (6) Compute Top-1 Precision (Prec@1) for batch, update Prec@1 sum and sample count.
        #     The batch value is weighted by B so a smaller final batch does
        #     not skew the overall average.
        prec1 = batch_prec1(logits, targs_enc_b)
        prec1_sum += prec1.item() * B
        n_samps += B

    # An empty dataloader would otherwise divide by zero below.
    if n_samps == 0:
        print("", "Prec@1: n/a (no samples evaluated)", "", sep="\n")
        return

    print(
        "",
        f"Prec@1: {prec1_sum / n_samps:.1%}",
        "",
        sep="\n"
    )

def batch_inference_res(model:  nn.Module,
                        imgs_b: torch.Tensor) -> torch.Tensor:
    """
    Run one forward pass of `model` on a batch of images.

    Args:
        model: Module called directly on the batch.
        imgs_b: Batched image tensor passed to the model unchanged.

    Returns:
        Whatever `model(imgs_b)` produces, used by the caller as class logits.
    """
    return model(imgs_b)

def batch_inference_vlm(model:          nn.Module,
                        imgs_b:         torch.Tensor,
                        class_embs_txt: torch.Tensor) -> torch.Tensor:
    """
    Score a batch of images against text class prototypes.

    Args:
        model: Vision-language model exposing `encode_image`.
        imgs_b: Batched image tensor fed to the image encoder.
        class_embs_txt: Class prototypes that the image embeddings are
            right-multiplied by; presumably (D, C) with unit-length columns,
            as built by `build_class_protos_text` — confirm against caller.

    Returns:
        The matrix product of the normalized image embeddings and
        `class_embs_txt`, one row of class scores per image.
    """
    # One embedding per image, scaled to unit length along dim 1 so that the
    # product below is a cosine similarity when the prototypes are unit length too.
    img_feats  = model.encode_image(imgs_b)
    unit_feats = F.normalize(img_feats, dim=1)

    return unit_feats @ class_embs_txt

### ResNet-50

In [None]:
# Build the ResNet-50 model together with its image preprocessing, then a
# validation dataloader that applies that preprocessing.
model, img_pp     = init_resnet50(device)
dataloader_resnet = init_dataloader(DPATH_VALID, img_pp, BATCH_SIZE, N_WORKERS)

# This exact string is what run_inference checks to select the ResNet path;
# no text class prototypes are passed for this run.
model_name = "ResNet-50"
print_eval_header(model_name)
run_inference(model, dataloader_resnet, device, model_name)

### Class Prototypes

In [None]:
@torch.no_grad()
def build_class_protos_text(
    model:         nn.Module,
    tokenizer:     Callable,
    txts_class:    List[str],
    txt_templates: List[str],
    device:        torch.device) -> torch.Tensor:
    """
    Build one text prototype per class by averaging prompt-template embeddings.

    For every class label, each template is filled in with the label, the
    resulting prompts are encoded with the model's text encoder, and the
    normalized embeddings are averaged and re-normalized into one prototype.

    Args:
        model: Vision-language model exposing `encode_text`.
        tokenizer: Callable mapping a list of strings to a token tensor.
        txts_class: Class labels, one prototype is produced per entry.
        txt_templates: Format strings with a single `{}` placeholder for the
            class label. Must not be empty.
        device: Device the token tensors are moved to before encoding.

    Returns:
        The prototypes stacked along dim 1, i.e. one column per class
        (presumably shape (D, C) for D-dimensional text embeddings).

    Raises:
        ValueError: If `txt_templates` is empty.
    """
    # With no templates there are no prompts to encode, and averaging over
    # zero embeddings would produce NaN prototypes rather than an error.
    if not txt_templates:
        raise ValueError("txt_templates must contain at least one template")

    protos = []

    for txt in tqdm(txts_class, desc="Building class prototypes"):

        # One prompt per template for the current class label.
        txts      = [temp.format(txt) for temp in txt_templates]
        toks_txts = tokenizer(txts).to(device)

        # Normalize each prompt embedding so every template contributes
        # equally to the mean. F.normalize matches batch_inference_vlm and
        # guards against a zero norm.
        embs_txts = F.normalize(model.encode_text(toks_txts), dim=-1)

        # The mean of unit vectors is generally not unit length, so
        # normalize once more.
        proto = F.normalize(embs_txts.mean(dim=0), dim=0)

        protos.append(proto)

    # Stack as columns so image embeddings can be right-multiplied directly.
    return torch.stack(protos, dim=1)

### List Pretrained Models

In [None]:
# Show the pretrained checkpoints open_clip knows about; the model ids and
# pretrained tags used in VLM_CONFIGS are presumably drawn from this listing.
open_clip.list_pretrained()

### VLM Configurations

In [None]:
# One tuple per evaluated model: (model_id, pretrained, quick_gelu, model_name).
# The first three fields are forwarded to init_vlm; model_name is the display
# label passed to print_eval_header and run_inference.
VLM_CONFIGS = [
    ("RN50",                "openai", True,  "CLIP ResNet-50 (224px)"),
    ("ViT-B-32",            "openai", True,  "CLIP ViT-B/32 (224px)"),
    ("ViT-B-16",            "openai", True,  "CLIP ViT-B/16 (224px)"),
    ("ViT-L-14",            "openai", True,  "CLIP ViT-L/14 (224px)"),
    ("ViT-L-14-336",        "openai", True,  "CLIP ViT-L/14 (336px)"),
    ("ViT-B-16-SigLIP",     "webli",  False, "SigLIP ViT-B/16 (224px)"),
    ("ViT-B-16-SigLIP-256", "webli",  False, "SigLIP ViT-B/16 (256px)"),
    ("ViT-L-16-SigLIP-256", "webli",  False, "SigLIP ViT-L/16 (256px)"),
]

### Template Ensembles

In [None]:
# Each template is a format string whose "{}" is replaced by the class label.
# Raw: the label on its own, with no surrounding prompt text.
ENS_RAW      = ["{}"]
# Standard: a single generic prompt.
ENS_STANDARD = ["a photo of a {}."]
# Template list imported from feat_eng (presumably the 80 CLIP ImageNet prompts — confirm).
ENS_CLIP80   = IMAGENET_TEMPLATES

In [None]:
def ens_across_models(templates: List[str]) -> None:
    """
    Benchmark one template ensemble on every model listed in VLM_CONFIGS.

    For each configuration the model is initialized, text class prototypes
    are built from the templates and the ImageNet class labels, and those
    prototypes are used to perform zero-shot image classification on the
    validation data. Results are printed by run_inference.

    Args:
        templates: Prompt format strings combined with each class label.
    """
    for cfg in VLM_CONFIGS:
        arch_id, weights_tag, use_quick_gelu, display_name = cfg
        print_eval_header(display_name)

        vlm, preprocess, tok = init_vlm(arch_id, weights_tag, use_quick_gelu, device)

        # The dataloader is rebuilt per model because the preprocessing
        # comes from the model that was just initialized.
        val_loader = init_dataloader(DPATH_VALID, preprocess, BATCH_SIZE, N_WORKERS)

        protos = build_class_protos_text(vlm, tok, IMAGENET_CLASSES, templates, device)
        protos = protos.to(device)

        run_inference(vlm, val_loader, device, display_name, class_embs_txt=protos)

### Raw Label Template

In [None]:
# Evaluate every configured VLM using the bare class label as the only prompt.
ens_across_models(ENS_RAW)

### Standard Template

In [None]:
# Evaluate every configured VLM using the single "a photo of a {}." prompt.
ens_across_models(ENS_STANDARD)

### CLIP 80 Templates

In [None]:
# Evaluate every configured VLM using the full IMAGENET_TEMPLATES ensemble.
ens_across_models(ENS_CLIP80)