# CLIP: Multimodal foundation model

CLIP can compare a text with an image and predict a matching score.

**Goal.** The goal of this notebook is to demonstrate zero-shot capabilities of foundation models.

You need the following extra libraries beyond PyTorch:
* transformers
* Pillow (PIL)

In [None]:
# Uncomment in Google Colab.
# The two shell commands below create an `images` directory and use wget to
# download six sample pictures (cat, cat_tree, landscape, opossum, room,
# street_upside_down) into it as <name>.jpg.
#! mkdir images
#! for name in cat cat_tree landscape opossum room street_upside_down; do wget -O images/${name}.jpg https://raw.github.com/ivan-chai/isscai-cv-2024/master/05-foundation/images/${name}.jpg; done

In [None]:
import math
import pathlib

import torch
import transformers
from matplotlib import pyplot as plt
from PIL import Image

# Collect the sample images. Sorting makes the tile order reproducible, because
# glob() yields files in an arbitrary, filesystem-dependent order.
paths = sorted(pathlib.Path("./images/").glob("*.jpg"))
if not paths:
    # Fail with a clear message instead of an obscure subplots() error on zero rows.
    raise FileNotFoundError("No .jpg files found in ./images/ - download the sample images first.")
images = [Image.open(path) for path in paths]

# Number of columns in the image grid (also read by draw_tile in a later cell).
TILE_SIZE = 3
n_rows = math.ceil(len(images) / TILE_SIZE)
# squeeze=False keeps `axs` two-dimensional even when there is only one row,
# so the grid works for any number of images.
fig, axs = plt.subplots(n_rows, TILE_SIZE, squeeze=False)
for ax in axs.flat:
    # Hide every frame up front so unused trailing cells stay blank.
    ax.set_axis_off()
# axs.flat walks the grid row by row, matching the (i // TILE_SIZE, i % TILE_SIZE) layout.
for ax, path, image in zip(axs.flat, paths, images):
    ax.set_title(path.stem)
    ax.imshow(image)
plt.show()

In [None]:
# The model weights and the matching preprocessing pipeline are loaded from the same checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = transformers.CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = transformers.CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [None]:
# Candidate captions; every image is scored against every caption.
prompts = ["a photo of a cat", "a beautiful image", "an outdoor scene", "an upside-down image", "a landscape"]
inputs = processor(text=prompts,
                   images=images,
                   return_tensors="pt", padding=True)
# Inference only: disabling autograd avoids building a gradient graph and
# leaves `logits` as a plain tensor without requires_grad.
with torch.no_grad():
    logits = model(**inputs).logits_per_image
# Rows are indexed by image and columns by prompt in the cells that consume `logits`.
logits.shape

In [None]:
def draw_tile(scores, threshold):
    """Show every image in a grid, each titled with its score.

    Reads the notebook-level ``images`` list and the ``TILE_SIZE`` column count.

    Args:
        scores: Indexable sequence holding one score per entry of ``images``,
            in the same order (e.g. one column of the CLIP logits).
        threshold: Scores strictly above this value get a green title;
            all other scores get a red title.
    """
    n_rows = math.ceil(len(images) / TILE_SIZE)
    # squeeze=False keeps `axs` two-dimensional even when there is only one row.
    fig, axs = plt.subplots(n_rows, TILE_SIZE, squeeze=False)
    for ax in axs.flat:
        # Hide every frame up front so unused trailing cells stay blank.
        ax.set_axis_off()
    # axs.flat walks the grid row by row, matching the (i // TILE_SIZE, i % TILE_SIZE) layout.
    for i, (ax, image) in enumerate(zip(axs.flat, images)):
        # Convert to a plain float so comparison and formatting do not depend
        # on the container type (tensor element, numpy scalar, Python number).
        score = float(scores[i])
        color = "green" if score > threshold else "red"
        ax.set_title(f"Score: {score:.1f}", color=color)
        ax.imshow(image)
    plt.show()

In [None]:
# One hand-picked decision threshold per prompt, in the same order as `prompts`.
thresholds = [25, 21, 20, 25, 23]
banner = "=" * 50
for column, prompt in enumerate(prompts):
    # Header: the upper-cased prompt framed by two separator lines.
    print(banner, prompt.upper(), banner, sep="\n")
    # Column `column` of the logits holds every image's score for this prompt.
    draw_tile(logits[:, column], thresholds[column])