# Zero-Shot Learning with CLIP

In this notebook, we explore zero-shot image classification with the pre-trained CLIP model from Hugging Face. We walk through each step, from installing the required libraries to processing an image and a set of candidate labels, and finally classifying the image without any task-specific training.

We'll follow these steps:
1. Install and import necessary libraries.
2. Load the pre-trained CLIP model and its processor.
3. Define a function to perform zero-shot classification.
4. Run an example with a sample image and candidate labels.

In [1]:
# Install the required libraries if you haven't already.
# Uncomment the next line to install the packages in your notebook environment.

# !pip install torch torchvision transformers pillow requests

## Import Libraries

In this cell, we import the libraries we need. We use:
- `torch` for tensor operations,
- `transformers` to load the CLIP model and its processor,
- `PIL` (provided by the Pillow package) to open and convert images,
- `requests` to download images from a URL.

In [2]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests

## Load the Pre-Trained CLIP Model and Processor

CLIP is pre-trained to map images and text into the same embedding space, so an image and a caption that describe the same thing end up with similar vectors. The processor prepares both inputs for the model: it resizes, center-crops, and normalizes the image, and it tokenizes and pads the candidate labels. Here we load the `openai/clip-vit-base-patch32` checkpoint along with its processor.
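To make the processor's image step more concrete, here is a minimal sketch of its last stage, per-channel normalization, applied to a single pixel. It assumes the standard CLIP mean and standard deviation constants (in the notebook, `processor.image_processor.image_mean` and `processor.image_processor.image_std` show the values actually in use) and deliberately skips resizing and center-cropping. `normalize_pixel` is a hypothetical helper for illustration, not part of `transformers`.

```python
# Standard CLIP per-channel normalization constants (R, G, B).
# Assumption: these match the checkpoint's preprocessing config.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb):
    """Rescale one 8-bit RGB pixel to [0, 1], then standardize each channel.

    Sketch only: the real processor also resizes and center-crops the image
    and works on whole arrays rather than single pixels.
    """
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, CLIP_MEAN, CLIP_STD)
    )


print(normalize_pixel((255, 255, 255)))  # white: every channel positive
print(normalize_pixel((0, 0, 0)))        # black: every channel negative
```

Centering each channel around zero this way puts input pixels on the same scale CLIP saw during pre-training, which is why images must go through the processor rather than being fed to the model as raw 0 to 255 values.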

In [3]:
# Define the model name and load the CLIP model and processor.
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)


## Define the Zero-Shot Classification Function

This function, `zero_shot_classify`, does the following:
1. Downloads an image from a given URL.
2. Processes the image and candidate text labels using the processor.
3. Feeds the processed inputs into the CLIP model.
4. Computes similarity scores (logits) between the image and each candidate label.
5. Converts these logits into probabilities using softmax.
6. Returns the candidate label with the highest probability along with all probabilities.

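CLIP's logits are cosine similarities between the L2-normalized image and text embeddings, multiplied by a learned temperature (`logit_scale`), so a higher logit means a closer match in the shared embedding space.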
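The scoring steps above (similarity, softmax, argmax) can be sketched without the model by working directly on embedding vectors. This is a minimal illustration under stated assumptions: the 3-D embeddings are made up (real ViT-B/32 embeddings have 512 dimensions), `scale=100.0` is an assumed stand-in for `model.logit_scale.exp()`, and `zero_shot_from_embeddings` is a hypothetical helper, not a `transformers` function.

```python
import math


def cosine_similarity(a, b):
    """Dot product of the two vectors after L2-normalizing each one."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_from_embeddings(image_emb, text_embs, labels, scale=100.0):
    """Score one image embedding against one text embedding per label.

    `scale` plays the role of CLIP's learned temperature; 100.0 is an
    assumed value for illustration.
    """
    logits = [scale * cosine_similarity(image_emb, t) for t in text_embs]
    probs = softmax(logits)
    best = max(range(len(labels)), key=lambda i: probs[i])
    return labels[best], probs


# Hypothetical toy embeddings, one per label.
image_emb = [0.9, 0.1, 0.2]
text_embs = [[0.1, 0.9, 0.1], [0.8, 0.2, 0.3], [0.0, 0.1, 0.9]]
labels = ["a cat", "a dog", "a car"]

label, probs = zero_shot_from_embeddings(image_emb, text_embs, labels)
print(label, [round(p, 4) for p in probs])
```

Calling the same helper with `scale=1.0` yields a much flatter distribution from identical cosine similarities, which shows why the learned temperature matters: cosine values live in [-1, 1], and without scaling the softmax can barely separate the labels.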

In [4]:
def zero_shot_classify(image_url, candidate_labels):
    """
    Perform zero-shot classification on an image using candidate labels.

    Args:
    - image_url (str): URL to the image.
    - candidate_labels (list of str): List of text labels to classify the image.

    Returns:
    - tuple: (predicted label, tensor of probabilities with shape (1, len(candidate_labels)))
    """
    # Download the image, failing fast on HTTP errors instead of decoding an error page.
    response = requests.get(image_url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")

    # Use the processor to prepare both the image and the candidate text labels.
    inputs = processor(text=candidate_labels, images=image, return_tensors="pt", padding=True)

    # Run inference without tracking gradients, since we never call backward().
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has shape (num_images, num_labels): one scaled similarity per label.
    logits_per_image = outputs.logits_per_image

    # Convert the logits to probabilities over the candidate labels.
    probs = logits_per_image.softmax(dim=1)

    # Index of the most probable label for the single input image.
    max_idx = probs.argmax(dim=1)

    return candidate_labels[max_idx.item()], probs

## Run an Example

Now, let's run an example. We provide an image URL and a list of candidate labels (for example: "a cat", "a dog", "a bird", and "a car"). The notebook will use our `zero_shot_classify` function to predict which label best describes the image.

In [5]:
# Example image URL. You can replace this URL with any image you'd like to classify.
image_url = "https://cdn.pixabay.com/photo/2023/08/18/15/02/dog-8198719_640.jpg"

# Define candidate labels for the classification.
candidate_labels = ["a cat", "a dog", "a bird", "a car"]

# Use the zero_shot_classify function to predict the label for the image.
predicted_label, probabilities = zero_shot_classify(image_url, candidate_labels)

# Print the result.
print("Predicted Label:", predicted_label)
print("Probabilities:", probabilities.detach().numpy())

Predicted Label: a dog
Probabilities: [[1.2293695e-03 9.9801838e-01 4.2982257e-04 3.2242425e-04]]
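The raw probability row is easier to read when each value is paired with its label. The sketch below copies the values printed above; in the notebook you would take them from `probabilities[0].tolist()` instead of pasting them by hand.

```python
# Probabilities copied from the printed output above (one row for one image).
labels = ["a cat", "a dog", "a bird", "a car"]
probs = [1.2293695e-03, 9.9801838e-01, 4.2982257e-04, 3.2242425e-04]

# Pair each label with its probability and sort from most to least likely.
ranking = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
for label, p in ranking:
    print(f"{label:>8}: {p:.2%}")

# The softmax output sums to 1, up to float32 rounding.
print("sum:", round(sum(probs), 6))
```

Because softmax normalizes over the candidate list only, these probabilities are relative: they say "a dog" is by far the best match among these four labels, not that the model is 99.8% certain in any absolute sense.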


## Conclusion

In this notebook, we learned how zero-shot learning can be applied using the pre-trained CLIP model. We went through:
- Installing and importing the required libraries.
- Loading the CLIP model and processor.
- Defining a function for zero-shot classification.
- Running an example to see the model in action.

This approach shows how a model can classify images into arbitrary categories, described in plain text, without any task-specific training, because images and text share one embedding space. Enjoy experimenting with zero-shot learning in your projects!