# Demo for CLIP
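CLIP is a vision-language model published by OpenAI. It maps images and text into a shared embedding space, so an image can be compared directly with a piece of text.
This notebook demonstrates a few things you can do with it. It uses the open-source OpenCLIP implementation (a ViT-B-32 model with the `laion2b_s34b_b79k` weights).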

In [None]:
# Install the required packages
%pip install torch torchvision
%pip install open_clip_torch

In [None]:
import torch
from PIL import Image
import open_clip
from pathlib import Path
from tqdm import tqdm

In [None]:
# Load the model together with its matching image preprocessing transform and
# tokenizer. The second return value (the training-time transform) is unused.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
# Per the open_clip README, models are created in train mode. Switch to eval
# mode so train-only behaviour (e.g. dropout) is disabled during inference.
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [None]:
# Simple function for loading an image and pre-processing it
def load_and_process_image(image_path):
    """Load an image and compute its L2-normalised CLIP embedding.

    Args:
        image_path: Path (or string) of the image file to open.

    Returns:
        A ``(image_features, image)`` tuple. ``image_features`` is the output
        of ``model.encode_image`` for a batch holding this single image,
        divided by its norm along the last dimension. ``image`` is the opened
        PIL image.
    """
    image = Image.open(str(image_path))
    inputs = preprocess(image).unsqueeze(0)  # add a leading batch dimension of 1
    # Inference only. Without no_grad every call records an autograd graph that
    # stays attached to the returned features. Collecting many embeddings (as
    # the retrieval cell does) would then keep all of those graphs in memory.
    with torch.no_grad():
        image_features = model.encode_image(inputs)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features, image

# Shortcut function for processing text in CLIP
def process_text(texts, formatted="this is an image of {}"):
    """Encode a list of labels as L2-normalised CLIP text embeddings.

    Args:
        texts: Sequence of strings (e.g. class names) to encode.
        formatted: Prompt template with one ``{}`` placeholder; each entry of
            ``texts`` is inserted into it before encoding. Pass a falsy value
            (``None`` or ``""``) to encode ``texts`` unchanged.

    Returns:
        The ``model.encode_text`` output for the (templated) texts, divided by
        its norm along the last dimension. Row ``i`` corresponds to ``texts[i]``.
    """
    if formatted:
        texts = [formatted.format(c) for c in texts]
    tokens = tokenizer(texts)

    # Inference only: don't build an autograd graph for the text features.
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    return text_features

# Helper to display images
def show_image(pil_image, max_size=400):
    """Display a downscaled copy of ``pil_image`` in the notebook.

    ``Image.thumbnail`` resizes in place, so it is applied to a copy. The
    caller's image keeps its original resolution and can be shown again
    later, e.g. with a larger ``max_size``.

    Args:
        pil_image: PIL image to display; it is not modified.
        max_size: Maximum width and height of the displayed copy, in pixels.

    Returns:
        Whatever the notebook's ``display`` function returns.
    """
    preview = pil_image.copy()
    preview.thumbnail((max_size, max_size))
    return display(preview)

# Helper to display the results of classification
def print_classes(classes, text_probs, min_width=12):
    """Print one line per class with its probability, marking the best match.

    Args:
        classes: Class labels, in the same order as the columns of
            ``text_probs``.
        text_probs: Tensor of class probabilities for a single image. Only
            row 0 is printed, and the best-match lookup (``.item()`` on the
            per-row argmax) requires exactly one row.
        min_width: Minimum width of the label column. The column widens to
            fit the longest label so the percentages stay aligned.
    """
    best_idx = torch.max(text_probs, dim=1)[1].item()
    width = max(min_width, max((len(str(c)) for c in classes), default=0))
    for i, label in enumerate(classes):
        line = f"- {label:<{width}} {text_probs[0, i].item()*100:.2f}%"
        if i == best_idx:
            line += " (best match)"
        print(line)

## Different concepts
CLIP can recognize very different kinds of concepts. In the example below, each image is classified as a diagram, a dog, a cat, or a car.

In [None]:
classes = ["a diagram", "a dog", "a cat", "a car"]
text_features = process_text(classes)

# Classify every file in the folder. Sorting gives the same output order on
# every machine (glob order depends on the filesystem). The is_file() check
# skips sub-directories, which Image.open cannot read.
image_dir = Path("images/start/")
for image_path in sorted(p for p in image_dir.glob("*") if p.is_file()):
    image_features, image = load_and_process_image(image_path)
    # Both feature sets are unit-normalised, so this is cosine similarity,
    # scaled by 100 and softmaxed over the classes.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    show_image(image)
    print_classes(classes, text_probs)
    print("--------------------\n")

## Specific concepts
It can also make finer-grained distinctions, such as telling different cat breeds apart.

In [None]:
classes = ["European Shorthair", "Maine Coon", "Siamese", "British Shorthair", "Persian"]
text_features = process_text(classes, formatted="this image contains a {}, a cat breed")

# Sorted so the images appear in the same order on every machine
# (glob order depends on the filesystem).
image_dir = Path("images/cat_breeds/")
for image_path in sorted(image_dir.glob("*.jpg")):
    image_features, image = load_and_process_image(image_path)
    # Scaled cosine similarity, softmaxed over the breeds.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    show_image(image)
    print_classes(classes, text_probs)
    print("--------------------\n")

## Abstract concepts
It also handles more abstract concepts, such as seasons and cities.

In [None]:
classes = ["spring", "summer", "autumn", "winter"]
text_features = process_text(classes, formatted="this image is taken in {}")

# Sorted for a deterministic order; is_file() skips sub-directories,
# which Image.open cannot read.
image_dir = Path("images/seasons/")
for image_path in sorted(p for p in image_dir.glob("*") if p.is_file()):
    image_features, image = load_and_process_image(image_path)
    # Scaled cosine similarity, softmaxed over the seasons.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    show_image(image)
    print_classes(classes, text_probs)
    print("--------------------\n")

In [None]:
classes = ["amsterdam", "new york", "tokyo", "paris"]
text_features = process_text(classes, formatted="this image is taken in {}")

# Sorted for a deterministic order; is_file() skips sub-directories,
# which Image.open cannot read.
image_dir = Path("images/cities/")
for image_path in sorted(p for p in image_dir.glob("*") if p.is_file()):
    image_features, image = load_and_process_image(image_path)
    # Scaled cosine similarity, softmaxed over the cities.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    show_image(image)
    print_classes(classes, text_probs)
    print("--------------------\n")

## Image Retrieval
We can also use CLIP for image retrieval: given a text query, rank a collection of images by how well they match it.

In [None]:
image_dir = Path("images/cats/")
# Sorted for a deterministic order; is_file() skips sub-directories,
# which Image.open cannot read.
image_paths = sorted(p for p in image_dir.glob("*") if p.is_file())

# Embed every image once up front so each query below only needs a text
# encoding. Gradients are disabled so that no autograd graph stays attached
# to the feature tensors collected in all_image_features.
all_images = []
all_image_features = []
with torch.no_grad():
    for image_path in tqdm(image_paths):
        image_features, image = load_and_process_image(image_path)
        all_images.append(image)
        all_image_features.append(image_features)

In [None]:
classes = ["multiple kittens"]
text_features = process_text(classes, formatted="this image contains {}")

# Scaled cosine similarity between every image and the single query, shape
# (num_images, 1). Despite the name, these are similarity scores, not
# probabilities (no softmax is applied).
with torch.no_grad():
    probs = 100.0 * torch.concat(all_image_features) @ text_features.T
# Select the query column explicitly rather than using squeeze(). With only one
# image, squeeze() yields a 0-d tensor, .tolist() then returns a bare int, and
# the loop below fails.
sorted_indexes = torch.argsort(probs[:, 0], dim=0, descending=True).tolist()

for i in sorted_indexes:
    show_image(all_images[i])
    print(f"Similarity: {probs[i, 0].item():.2f}")
    print("--------------------")