# From Pixels to Prompts:
## A Crash Course on Zero-Shot Classification Using Language-Vision Models

This tutorial assumes a general knowledge of deep learning and familiarity with PyTorch.

Setup:
1. Pull repo via `git clone xyz`
2. etc
3. Install env
4. etc
5. etc

Recommended resources:
- GPU: x
- RAM: x
- etc.

Note: The environment does not support the newer Hopper or Blackwell GPUs (H100/H200/B100/B200).

Create and activate the environment with:<br>
`conda env create -f environment.yaml`<br>
`conda activate imagenet-zeroshot`

(Note that everything below this point is very similar to what's in the notebook)

In [1]:
import torch
import torchvision
import torch.nn.functional as F
import open_clip
from typing import Dict, List, Optional, Tuple
from pathlib import Path
from tqdm import tqdm

from feat_eng import IMAGENET_CLASSES, IMAGENET_TEMPLATES
from utils import spawn_dataloader, batch_prec1

import pdb



TODO: Move anything that is not essential to the tutorial's learning objectives out of this section.

In [2]:
# Images per evaluation batch (values from 32 to 32768 were also tried)
BATCH_SIZE = 2048

N_WORKERS  = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dpath_valid = Path("/data/ai/ref-data/image/ImageNet/imagenet1k/Data/CLS-LOC/val")
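
As a quick sanity check on the batch size, the number of batches per evaluation pass follows from the dataset size. The sketch below assumes the standard ImageNet-1k validation split of 50,000 images and a dataloader that keeps the final partial batch; `N_VAL_IMAGES` is a name introduced here for illustration only.

```python
import math

N_VAL_IMAGES = 50_000  # assumption: standard ImageNet-1k validation split
BATCH_SIZE = 2048      # same value as in the configuration cell

# Number of batches when the last, smaller batch is kept
n_batches = math.ceil(N_VAL_IMAGES / BATCH_SIZE)

# Size of that final partial batch
last_batch = N_VAL_IMAGES - (n_batches - 1) * BATCH_SIZE

print(n_batches)   # 25
print(last_batch)  # 848
```

Under these assumptions the evaluation loop runs 25 iterations, which agrees with the `25/25` shown by the progress bars in the cell outputs.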

In [3]:
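@torch.no_grad()
def build_class_protos_text(
    model,
    tokenizer,
    txts_class: List[str],
    txt_templates: List[str],
    device: torch.device) -> torch.Tensor:
    """Build one unit-length text prototype per class; returns a (D, C) tensor."""

    protos = []

    for txt in tqdm(txts_class, desc="Creating text embeddings"):

        # Fill every template with the class name and tokenize the resulting prompts
        txts      = [temp.format(txt) for temp in txt_templates]
        toks_txts = tokenizer(txts).to(device)

        # Encode the prompts and normalize each embedding to unit length
        embs_txts = model.encode_text(toks_txts)
        embs_txts = embs_txts / embs_txts.norm(dim=-1, keepdim=True)

        # Average over templates, then renormalize so the prototype has unit length
        proto = embs_txts.mean(dim=0)
        proto = proto / proto.norm()
        protos.append(proto)

    # Stack prototypes as columns, so that (B, D) @ (D, C) -> (B, C) later on
    protos = torch.stack(protos, dim=1)

    return protos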
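
The ensembling step in `build_class_protos_text` can be illustrated without a model. The following is a minimal, dependency-free sketch that uses made-up 2-D vectors in place of real text embeddings; `l2_normalize` and `ensemble_prototype` are hypothetical helper names, not part of the tutorial code.

```python
import math
from typing import List


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def ensemble_prototype(prompt_embs: List[List[float]]) -> List[float]:
    """Normalize each prompt embedding, average them, then renormalize."""
    unit = [l2_normalize(e) for e in prompt_embs]
    dim = len(unit[0])
    mean = [sum(e[i] for e in unit) / len(unit) for i in range(dim)]
    return l2_normalize(mean)


# Made-up embeddings for two prompts that describe the same class
toy_embs = [[3.0, 4.0], [4.0, 3.0]]

proto = ensemble_prototype(toy_embs)
proto_norm = math.sqrt(sum(x * x for x in proto))
print(proto, proto_norm)
```

The second normalization matters because the mean of several unit vectors generally has a norm below 1; without it, classes whose prompts disagree more would receive systematically smaller similarity scores.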

In [4]:
@torch.no_grad()
def run_inference(model, dataloader, device, model_name, class_embs_txt=None):
    # (1) Initialize some parameters for performance bookkeeping
    n_samps   = 0
    prec1_sum = 0.0
    # (2) Iterate over validation data, loading up a batch of B images and corresponding targets at a time
    for imgs_b, targs_enc_b in tqdm(dataloader, desc="Evaluation"):
        # imgs_b -------- Tensor(B, 3, H, W), e.g. H = W = 224 for the 224px models
        # targs_enc_b --- Tensor(B)
        # (3) Number of samples in the current batch
        B           = targs_enc_b.size(0)
        # (4) Send batched images and targets tensors to GPU
        imgs_b      = imgs_b.to(device, non_blocking=True)
        targs_enc_b = targs_enc_b.to(device, non_blocking=True)

        # (5) Perform batch inference with ResNet-50 or VLM. This will be explained next.
        if model_name == "ResNet-50":
            logits = batch_inference_res(model, imgs_b)
        else:
            logits = batch_inference_vlm(model, imgs_b, class_embs_txt)

        # (6) Compute Top-1 Precision (Prec@1) for batch, update Prec@1 sum and sample count
        prec1 = batch_prec1(logits, targs_enc_b)
        prec1_sum += prec1.item() * B
        n_samps += B

    print(
        f"",
        f"Prec@1: {prec1_sum / n_samps:.1%}",
        f"",
        sep="\n"
    )


@torch.no_grad()
def batch_inference_res(model, imgs_b):

    logits = model(imgs_b)

    return logits

@torch.no_grad()
def batch_inference_vlm(model, imgs_b, class_embs_txt):

    embs_img_b = model.encode_image(imgs_b)      # image encoder: one D-dimensional embedding per image, shape (B, D)
    embs_img_b = F.normalize(embs_img_b, dim=1)  # normalize each embedding to unit length
    logits     = embs_img_b @ class_embs_txt     # (1) cosine similarity to every class prototype: (B, D) @ (D, C) -> (B, C)

    return logits
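
To make the matrix product in `batch_inference_vlm` concrete, here is a small dependency-free sketch of zero-shot prediction by cosine similarity. The vectors are made up, the prototypes are stored as a list of rows for readability (the tutorial stores them as columns), and `zero_shot_predict` is a hypothetical helper.

```python
import math
from typing import List, Tuple


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def zero_shot_predict(
    img_emb: List[float], class_protos: List[List[float]]
) -> Tuple[int, List[float]]:
    """Return the index of the most similar class and all cosine similarities."""
    img = l2_normalize(img_emb)
    sims = [
        sum(a * b for a, b in zip(img, l2_normalize(proto)))
        for proto in class_protos
    ]
    pred = max(range(len(sims)), key=lambda i: sims[i])
    return pred, sims


# Made-up 2-D embeddings; real models use hundreds of dimensions
class_names = ["cat", "dog"]
class_protos = [[1.0, 0.0], [0.0, 1.0]]
img_emb = [0.9, 0.1]

pred, sims = zero_shot_predict(img_emb, class_protos)
print(class_names[pred], [round(s, 3) for s in sims])
```

Because both sides are unit length, each dot product is a cosine similarity, and the predicted class is simply the prototype that points in the direction closest to the image embedding.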

# Initialize ResNet-50 + DataLoader

In [5]:
resnet_wts    = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
resnet_img_pp = resnet_wts.transforms()
resnet        = torchvision.models.resnet50(weights=resnet_wts).to(device).eval()

dataloader_resnet = spawn_dataloader(dpath_valid, resnet_img_pp, BATCH_SIZE, N_WORKERS)

# ResNet-50 Inference

In [6]:
print("ResNet-50:")
run_inference(resnet, dataloader_resnet, device, "ResNet-50")

ResNet-50:


Evaluation: 100%|██████████| 25/25 [00:52<00:00,  2.09s/it]


Prec@1: 76.2%
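
The Prec@1 printed by `run_inference` is a sample-weighted average: each batch accuracy is multiplied by its batch size before summing. The sketch below shows why that weighting matters when the final batch is smaller. It assumes, as the `:.1%` format in `run_inference` suggests, that `batch_prec1` returns a fraction between 0 and 1; the per-batch values are made up and `weighted_prec1` is a hypothetical helper.

```python
from typing import List, Tuple


def weighted_prec1(batches: List[Tuple[float, int]]) -> float:
    """Combine per-batch accuracies (fractions) weighted by batch size."""
    correct_sum = sum(prec1 * size for prec1, size in batches)
    n_samps = sum(size for _, size in batches)
    return correct_sum / n_samps


# Made-up accuracies: two full batches and one small final batch
batches = [(0.75, 2048), (0.75, 2048), (0.50, 848)]

weighted = weighted_prec1(batches)
naive = sum(p for p, _ in batches) / len(batches)  # ignores batch sizes

print(f"weighted: {weighted:.4f}, naive: {naive:.4f}")
```

Averaging the per-batch values directly would give the small final batch the same influence as a full one, so the two numbers differ whenever batch sizes are unequal.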






In [7]:
# RUN INFERENCE OF SINGLE MODEL (CLIP FLAGSHIP)

# VLM Inference

Run `open_clip.list_pretrained()` to view the pretrained VLMs available through `open_clip`. It returns (architecture, pretrained tag) pairs; the tag identifies the pretraining dataset and is needed to initialize the model and its image preprocessor. The list does not include the recommended `quick_gelu` setting, so the reader has to look that up per model. As a rule of thumb, models initialized with OpenAI's pretrained weights were trained with QuickGELU and should set `quick_gelu=True`, while the others typically should not, as we will soon see.
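
The rule of thumb for `quick_gelu` can be written down as a tiny helper. This is only a sketch of the heuristic described above, not an `open_clip` API: `suggest_quick_gelu` is a hypothetical function, and the pairs are hard-coded from the listing rather than fetched from the library. Always confirm the setting for the specific model you load.

```python
from typing import List, Tuple


def suggest_quick_gelu(pretrained_tag: str) -> bool:
    """Heuristic only: suggest QuickGELU for OpenAI-released weights."""
    return pretrained_tag == "openai"


# A few (architecture, pretrained tag) pairs of the kind list_pretrained() returns
pairs: List[Tuple[str, str]] = [
    ("RN50", "openai"),
    ("RN50", "yfcc15m"),
    ("ViT-B-32", "openai"),
    ("ViT-B-32", "laion2b_s34b_b79k"),
]

for arch, tag in pairs:
    print(f"{arch:<10} {tag:<20} quick_gelu={suggest_quick_gelu(tag)}")
```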

Interested readers can learn more about some of the more prominent open-source pretraining datasets at the following links:
* [LAION](https://laion.ai/blog/laion-5b/)
* [CommonPool](https://ar5iv.labs.arxiv.org/html/2304.14108)
* [WebLI](https://research.google/blog/pali-scaling-language-image-learning-in-100-languages/)

In [9]:
open_clip.list_pretrained()

[('RN50', 'openai'),
 ('RN50', 'yfcc15m'),
 ('RN50', 'cc12m'),
 ('RN101', 'openai'),
 ('RN101', 'yfcc15m'),
 ('RN50x4', 'openai'),
 ('RN50x16', 'openai'),
 ('RN50x64', 'openai'),
 ('ViT-B-32', 'openai'),
 ('ViT-B-32', 'laion400m_e31'),
 ('ViT-B-32', 'laion400m_e32'),
 ('ViT-B-32', 'laion2b_e16'),
 ('ViT-B-32', 'laion2b_s34b_b79k'),
 ('ViT-B-32', 'datacomp_xl_s13b_b90k'),
 ('ViT-B-32', 'datacomp_m_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_image_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_text_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_s128m_b4k'),
 ('ViT-B-32', 'datacomp_s_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_image_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_text_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_basic_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_s13m_b4k'),
 ('ViT-

In [23]:
TEMPLATES_BARE = ["{}"]
TEMPLATES_BARE_PER = ["{}."]
TEMPLATES_PREP = ["a photo of a {}"]
TEMPLATES_PREP_PER = ["a photo of a {}."]
TEMPLATES_COMBO = ["{}", "{}.", "a photo of a {}", "a photo of a {}."]

TEMPS_EXP1      = [TEMPLATES_BARE, TEMPLATES_BARE_PER, TEMPLATES_PREP, TEMPLATES_PREP_PER,    TEMPLATES_COMBO]
TEMP_NAMES_EXP1 = ["Bare",         "Bare w/ period",   "Prepending",   "Prepending w/ period", "Combo"]

####

TEMPLATES_CUSTOM = [
    "a lucky {}",
    "a bad {}",
    "a good {}",
    "an excellent {}",
    "quite a poor {}",
    "what's going on with this {}",
    "a {} in the sun",
    "a {} in the shadows",
    "a big {}",
    "a little {}",
    "a top notch {}",
    "kind of looks like a {}",
    "it might be a {}",
    "one heck of a {}",
    "the best {} around",
    "found a {}",
    "look at this {}",
    "a great {}",
    "today I went to the grocery store and I saw a {} along the way there",
    "yesterday I climbed a mountain and there was a {} at the gas station on the way there",
    "today in school I learned about the ancient version of a {}",
    "today on the internet I learned about the {} factory",
    "what kind of {} is this?",
    "quite the sharp {}",
    "quite close to the {}",
    "the {} is in sight",
    "today we ran into a {}",
    "futuristic {}",
    "this {} will have the last laugh",
    "what a wild {}",
    "an untame and unruly {}",
    "a quivering {}",
    "a well placed {}",
    "a green, mean, {} cleaning machine",
    "a wild {} has appeared!",
    "{} radio",
    "the official {} convention is in town",
    "an amazing photo of a {}",
    "a decent photo of a {}",
    "a poor photo of a {}",
    "a {} at the bottom",
    "a {} at the top",
    "this {} is winning",
    "{} lol",
]
TEMPLATES_CLIP80 = IMAGENET_TEMPLATES
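
Each template is an ordinary Python format string with a single `{}` placeholder for the class name. As a small illustration, the sketch below expands the four "Combo" templates for one class; `expand_prompts` is a hypothetical helper and "goldfish" is just an example class name.

```python
from typing import List

TEMPLATES_COMBO = ["{}", "{}.", "a photo of a {}", "a photo of a {}."]


def expand_prompts(class_name: str, templates: List[str]) -> List[str]:
    """Fill every template with the given class name."""
    return [template.format(class_name) for template in templates]


prompts = expand_prompts("goldfish", TEMPLATES_COMBO)
for prompt in prompts:
    print(repr(prompt))
# 'goldfish'
# 'goldfish.'
# 'a photo of a goldfish'
# 'a photo of a goldfish.'
```

A template group with N templates therefore yields N prompts per class, all of which are embedded and averaged into a single class prototype.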

In [24]:
# TEMP_GROUPS = TEMPS_EXP1
# TG_NAMES   = TEMP_NAMES_EXP1

# TEMP_GROUPS = [TEMPLATES_CLIP80]
# TG_NAMES   = ["CLIP 80"]

# TEMP_GROUPS = [TEMPLATES_CUSTOM]
# TG_NAMES    = ["Custom"]

TEMP_GROUPS = [TEMPLATES_CUSTOM + TEMPLATES_CLIP80]
TG_NAMES    = ["Custom + CLIP 80"]

def run_experiment(template_groups, tg_names):

    VLM_CONFIGS = [
        ("RN50",                "openai", True,  "CLIP ResNet-50 (224px)"),
        ("ViT-B-32",            "openai", True,  "CLIP ViT-B/32 (224px)"),
        ("ViT-B-16",            "openai", True,  "CLIP ViT-B/16 (224px)"),
        ("ViT-L-14",            "openai", True,  "CLIP ViT-L/14 (224px)"),
        ("ViT-L-14-336",        "openai", True,  "CLIP ViT-L/14 (336px)"),
        ("ViT-B-16-SigLIP",     "webli",  False, "SigLIP ViT-B/16 (224px)"),
        ("ViT-B-16-SigLIP-256", "webli",  False, "SigLIP ViT-B/16 (256px)"),
        ("ViT-L-16-SigLIP-256", "webli",  False, "SigLIP ViT-L/16 (256px)"),
    ]

    for model_id, pretrained, quick_gelu, model_name in VLM_CONFIGS:

        print(
            "============================================",
            f"{model_name}",
            "",
            sep="\n"
        )

        # Load the model, its image preprocessor, and its tokenizer once per model
        model, _, img_pp = open_clip.create_model_and_transforms(model_id, pretrained=pretrained, quick_gelu=quick_gelu, device=device)
        model.eval()
        tokenizer = open_clip.get_tokenizer(model_id)

        dataloader = spawn_dataloader(dpath_valid, img_pp, BATCH_SIZE, N_WORKERS)

        for templates, exp_name in zip(template_groups, tg_names):
            print(
                "--------------------------------------------",
                f"{exp_name}",
                "",
                sep="\n"
            )

            # Build one text prototype per class from the current template group
            class_embs_txt = build_class_protos_text(model, tokenizer, IMAGENET_CLASSES, templates, device).to(device)

            run_inference(model, dataloader, device, model_name, class_embs_txt=class_embs_txt)


run_experiment(TEMP_GROUPS, TG_NAMES)

CLIP ResNet-50 (224px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:27<00:00, 36.91it/s]
Evaluation: 100%|██████████| 25/25 [00:37<00:00,  1.51s/it]



Prec@1: 59.9%

CLIP ViT-B/32 (224px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:25<00:00, 38.76it/s]
Evaluation: 100%|██████████| 25/25 [00:34<00:00,  1.38s/it]



Prec@1: 63.3%

CLIP ViT-B/16 (224px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:25<00:00, 39.84it/s]
Evaluation: 100%|██████████| 25/25 [00:47<00:00,  1.91s/it]



Prec@1: 68.3%

CLIP ViT-L/14 (224px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:43<00:00, 22.87it/s]
Evaluation: 100%|██████████| 25/25 [02:44<00:00,  6.59s/it]



Prec@1: 75.6%

CLIP ViT-L/14 (336px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:43<00:00, 22.89it/s]
Evaluation: 100%|██████████| 25/25 [06:15<00:00, 15.03s/it]



Prec@1: 76.7%

SigLIP ViT-B/16 (224px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:30<00:00, 33.16it/s]
Evaluation: 100%|██████████| 25/25 [00:48<00:00,  1.93s/it]



Prec@1: 76.2%

SigLIP ViT-B/16 (256px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [00:29<00:00, 33.81it/s]
Evaluation: 100%|██████████| 25/25 [01:00<00:00,  2.40s/it]



Prec@1: 76.7%

SigLIP ViT-L/16 (256px)

--------------------------------------------
Custom + CLIP 80



Creating text embeddings: 100%|██████████| 1000/1000 [01:34<00:00, 10.56it/s]
Evaluation: 100%|██████████| 25/25 [02:44<00:00,  6.59s/it]


Prec@1: 80.7%
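
To compare the models at a glance, the numbers from the printout above can be collected into a small table. The sketch below only re-enters the reported Prec@1 values and evaluation wall times for the "Custom + CLIP 80" templates; `RESULTS` and `to_seconds` are names introduced here for illustration.

```python
from typing import Dict, Tuple

# (Prec@1 in %, evaluation wall time "MM:SS"), copied from the printout
RESULTS: Dict[str, Tuple[float, str]] = {
    "CLIP ResNet-50 (224px)": (59.9, "00:37"),
    "CLIP ViT-B/32 (224px)": (63.3, "00:34"),
    "CLIP ViT-B/16 (224px)": (68.3, "00:47"),
    "CLIP ViT-L/14 (224px)": (75.6, "02:44"),
    "CLIP ViT-L/14 (336px)": (76.7, "06:15"),
    "SigLIP ViT-B/16 (224px)": (76.2, "00:48"),
    "SigLIP ViT-B/16 (256px)": (76.7, "01:00"),
    "SigLIP ViT-L/16 (256px)": (80.7, "02:44"),
}


def to_seconds(mmss: str) -> int:
    """Convert a tqdm-style 'MM:SS' duration to seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


# Rank models from highest to lowest Prec@1
ranked = sorted(RESULTS.items(), key=lambda kv: kv[1][0], reverse=True)
for name, (prec1, wall) in ranked:
    print(f"{name:<26} {prec1:5.1f}%  {to_seconds(wall):4d} s")

best_name = ranked[0][0]
```

Wall times depend on hardware and batch size, so they are only comparable within this run.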






Note: `RN50` in the printout refers to CLIP ResNet-50 and `ResNet` refers to the standalone ResNet-50.

TODO: Consider an ablation of the text ensemble: encode the bare class labels directly and compare the result against the templated prompts.

This is not the 76.2% reported in the seminal CLIP paper (Table 11), but it is close.

We conclude with a discussion of the results: