In [3]:
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer
from datasets import load_dataset, concatenate_datasets
import evaluate
import torch

In [4]:
# The images, the class names and the prompt templates all live in one Hub repo.
repo_id = "clip-benchmark/wds_vtab-caltech101"

# Default configuration of the repo: the image splits.
dataset_dict = load_dataset(repo_id)

# The two auxiliary text files are each loaded as their own dataset; the
# entries are read from the "text" column of the resulting "train" split.
classes, templates = (
    load_dataset(repo_id, data_files=text_file)["train"]["text"]
    for text_file in ("classnames.txt", "zeroshot_classification_templates.txt")
)

Downloading data: 100%|██████████| 174M/174M [00:02<00:00, 66.8MB/s] 
Downloading data: 100%|██████████| 17.0M/17.0M [00:01<00:00, 14.1MB/s]
Downloading data: 100%|██████████| 180M/180M [00:02<00:00, 87.0MB/s] 
Downloading data: 100%|██████████| 176M/176M [00:02<00:00, 81.6MB/s] 
Downloading data: 100%|██████████| 83.4M/83.4M [00:01<00:00, 79.9MB/s]
Generating train split: 100%|██████████| 2753/2753 [00:00<00:00, 4129.46 examples/s]
Generating test split: 100%|██████████| 6085/6085 [00:01<00:00, 4490.00 examples/s]
Downloading data: 100%|██████████| 971/971 [00:00<00:00, 7.42kB/s]
Generating train split: 102 examples [00:00, 105374.14 examples/s]
Downloading data: 100%|██████████| 646/646 [00:00<00:00, 4.29kB/s]
Generating train split: 34 examples [00:00, 18245.44 examples/s]


In [5]:
# Evaluate on the "train" split only. To evaluate on both splits instead, use:
#   concatenate_datasets([dataset_dict["train"], dataset_dict["test"]])
dataset = dataset_dict["train"]

In [6]:
# CLIP checkpoints this notebook can be run with; change the index below to
# switch to another one.
candidate_model_ids = (
    "openai/clip-vit-base-patch16",
    "openai/clip-vit-base-patch32",
    "openai/clip-vit-large-patch14",
    "openai/clip-vit-large-patch14-336",
)
model_id = candidate_model_ids[0]

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model weights, image processor and tokenizer are all loaded from the same
# checkpoint id; only the model is moved to `device`.
model = CLIPModel.from_pretrained(model_id).to(device)
processor = CLIPProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [7]:
# # Baseline
# inputs = tokenizer(classes, padding=True, return_tensors="pt")
# inputs = {k: t.to(model.device) for k, t in inputs.items()}
# with torch.no_grad():
#     text_features = model.get_text_features(**inputs)
# def zero_shot_classify(images, text_features):
#     inputs = processor(images=images, return_tensors="pt")
#     inputs = {k: t.to(model.device) for k, t in inputs.items()}
#     with torch.no_grad():
#         image_features = model.get_image_features(**inputs)

#     # Softmax with temperature
#     image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
#     text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
#     logits = 100.0 * (image_features @ text_features.t())
#     probs = logits.softmax(dim=1)
#     pred = probs.argmax(dim=1)
#     return pred

# # With ensemble mean
# text_features = []
# for name in classes:
#     name_templates = [template.format(c=name) for template in templates]
#     inputs = tokenizer(name_templates, padding=True, return_tensors="pt")
#     inputs = {k: t.to(model.device) for k, t in inputs.items()}
#     with torch.no_grad():
#         name_text_features = model.get_text_features(**inputs)
#     avg_text_features = name_text_features.mean(0)
#     text_features.append(avg_text_features)
# text_features = torch.stack(text_features)

# def zero_shot_classify(images, text_features):
#     inputs = processor(images=images, return_tensors="pt")
#     inputs = {k: t.to(model.device) for k, t in inputs.items()}
#     with torch.no_grad():
#         image_features = model.get_image_features(**inputs)

#     # Softmax with temperature
#     image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
#     text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
#     logits = 100.0 * (image_features @ text_features.t())
#     # probs = logits.softmax(dim=1)
#     pred = logits.argmax(dim=1)
#     return pred

# With ensemble
ensemble_method = "mean_logit" # "mean_logit", "mean_softmax"


def _embed_class_prompts(class_name):
    """Return the CLIP text features of every prompt template filled with `class_name`.

    Each entry of the notebook-level `templates` is formatted with
    ``c=class_name``; the prompts are tokenized together (padded to a common
    length), moved to the model's device and encoded without tracking gradients.
    """
    prompts = [template.format(c=class_name) for template in templates]
    encoded = tokenizer(prompts, padding=True, return_tensors="pt")
    encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}
    with torch.no_grad():
        return model.get_text_features(**encoded)


# One block of per-template features per class, stacked along a new leading
# axis; templates are NOT averaged here so that `zero_shot_classify` can
# ensemble them at the logit or at the probability level.
text_features = torch.stack([_embed_class_prompts(name) for name in classes])

def zero_shot_classify(images, text_features, method=None):
    """Predict one class index per image with prompt-ensembled zero-shot CLIP.

    The images are run through the notebook-level ``processor`` and ``model``.
    Image and prompt embeddings are L2-normalised, their dot products are
    scaled by 100, and the per-template scores of each class are combined into
    a single score per class before taking the argmax.

    Args:
        images: Batch of images, in whatever form ``processor(images=...)``
            accepts.
        text_features: Prompt embeddings shaped
            ``(num_classes, num_templates, dim)``. A 2-D tensor
            ``(num_classes, dim)`` is also accepted and treated as one
            template per class. The tensor does not need to be contiguous.
        method: How template scores are combined.
            ``"mean_logit"`` averages the logits over templates, then applies
            a softmax over classes. ``"mean_softmax"`` applies a softmax over
            classes for every template, then averages the probabilities.
            ``None`` (default) falls back to the notebook-level
            ``ensemble_method`` variable, which is what the two-argument call
            has always used.

    Returns:
        A 1-D tensor holding the predicted class index for each row of image
        features.

    Raises:
        ValueError: If the ensemble method is not one of the two supported
            names, or if ``text_features`` is neither 2-D nor 3-D.
    """
    if method is None:
        # Keep the original behaviour for existing two-argument callers.
        method = ensemble_method
    if method not in ("mean_logit", "mean_softmax"):
        # Previously any unknown value silently selected "mean_softmax",
        # which hid typos such as "mean_logits".
        raise ValueError(
            f"Unknown ensemble method {method!r}; expected 'mean_logit' or 'mean_softmax'"
        )

    if text_features.dim() == 2:
        # Pre-averaged / single-prompt features: add a template axis of size 1.
        text_features = text_features.unsqueeze(1)
    elif text_features.dim() != 3:
        raise ValueError(
            "text_features must be (num_classes, num_templates, dim) or "
            f"(num_classes, dim), got shape {tuple(text_features.shape)}"
        )
    num_classes, num_templates, dim = text_features.shape

    inputs = processor(images=images, return_tensors="pt")
    inputs = {k: t.to(model.device) for k, t in inputs.items()}
    with torch.no_grad():
        image_features = model.get_image_features(**inputs)

    # Flatten (class, template) into one axis so a single matmul scores every
    # prompt; reshape (unlike view) also works for non-contiguous inputs.
    text_flat = text_features.reshape(num_classes * num_templates, dim)

    # Softmax with temperature
    image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
    text_flat = text_flat / text_flat.norm(p=2, dim=-1, keepdim=True)
    logits = 100.0 * (image_features @ text_flat.t())

    # Back to (batch, class, template) so templates can be reduced per class.
    logits = logits.reshape(logits.shape[0], num_classes, num_templates)

    if method == "mean_logit":
        probs = logits.mean(-1).softmax(dim=1)
    else:
        probs = logits.softmax(dim=1).mean(-1)

    return probs.argmax(dim=1)

In [8]:
accuracy = evaluate.load("accuracy")

# Walk the dataset in mini-batches of 64, classify each batch and feed the
# predictions to the metric. For a quick smoke test iterate over
# `dataset.select(range(100))` instead.
# NOTE(review): the recorded run of this cell stopped with
# "AttributeError: module 'PIL.Image' has no attribute 'ExifTags'" -- this
# looks like an installed-Pillow version problem rather than a bug in this
# loop; confirm by upgrading Pillow and re-running.
for batch in dataset.iter(batch_size=64):
    images, labels = batch["webp"], batch["cls"]
    pred = zero_shot_classify(images, text_features)
    accuracy.add_batch(references=labels, predictions=pred)

accuracy.compute()

Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 2.57MB/s]


AttributeError: module 'PIL.Image' has no attribute 'ExifTags'