# Zero-shot image classification

## Zero-shot learning with CLIP

In [None]:
import matplotlib.pyplot as plt
from datasets import load_dataset

# Hugging Face Hub dataset of e-commerce products; each example carries an
# "image" field and a "Description" text field.
dset = "rajuptvs/ecommerce_products_clip"
dataset = load_dataset(dset)

# Inspect the first training example: print its text, then display its image.
first_example = dataset["train"][0]
print(first_example["Description"])
plt.imshow(first_example["image"])
plt.show()

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel

# Load the pretrained CLIP ViT-B/32 checkpoint together with its matching
# processor (text tokenizer + image preprocessing).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Candidate labels for zero-shot classification: every label is encoded as
# text and compared against the image embedding.
# NOTE(review): CLIP often performs better with prompt templates such as
# f"a photo of a {label}" instead of bare words — worth trying.
categories = ["shirt", "trousers", "shoes", "dress", "hat", "bag", "watch"]

# padding=True pads the tokenized labels to a common length so they batch.
inputs = processor(text=categories, images=dataset["train"][0]["image"], return_tensors="pt", padding=True)

# Pure inference: disable autograd so no computation graph is recorded,
# saving memory and time.
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image holds one row per image and one column per text label;
# softmax over dim=1 turns the similarities into a distribution over labels.
probs = outputs.logits_per_image.softmax(dim=1)

# Highest-probability label for the (single) image; as the last expression
# of the cell it is displayed by the notebook.
categories[probs.argmax().item()]

## Automated caption quality assessment

In [None]:
from torchmetrics.functional.multimodal import clip_score
from torchvision.transforms import ToTensor

# Measure how well the first product's text description matches its image.
image = dataset["train"][0]["image"]
description = dataset["train"][0]["Description"]

# CLIP expects 3-channel input, but product images may be grayscale,
# palette or RGBA, which ToTensor would turn into 1- or 4-channel tensors.
# Assumes the dataset decodes images as PIL images (the Hugging Face
# `Image` feature default) — TODO confirm if the feature config changes.
image = image.convert("RGB")

# ToTensor returns a float (C, H, W) tensor scaled to [0, 1]; clip_score
# expects pixel values in the 0-255 range, so scale back up.
image = ToTensor()(image) * 255

# NOTE(review): CLIP's text encoder accepts at most 77 tokens; long product
# descriptions may be truncated or rejected depending on the torchmetrics
# version — verify for this dataset.
score = clip_score(image, description, "openai/clip-vit-base-patch32")

# clip_score returns a scalar tensor; .item() prints a plain number rather
# than the tensor repr.
print(f"CLIP score: {score.item():.4f}")