[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dbamman/anlp25/blob/main/12.multimodal/Multimodal.ipynb)

# Multimodal embeddings with CLIP

In this notebook, we will explore multimodal embeddings using CLIP, a model with an image encoder and a text encoder that project images and text into the same embedding space.

We will use CLIP to explore a small subset of images from the National Gallery of Art.

In [None]:
from itertools import islice
from pathlib import Path

import torch
import matplotlib.pyplot as plt

from torchvision.io import read_image
from transformers import CLIPModel, AutoProcessor
from tqdm import tqdm

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

## Load the data

We've compiled some public-domain images from [the National Gallery of Art](https://www.nga.gov/artworks/free-images-and-open-access#a-section-header-p102746). Let's download the data and take a look at a few of the images.

In [None]:
!wget https://github.com/dbamman/anlp25/raw/refs/heads/main/data/nga.tar.gz -O nga.tar.gz
!tar -xzf nga.tar.gz

In [None]:
image_paths = sorted(Path("images/").glob("*.jpg"))

In [None]:
def show_image(path):
    plt.imshow(read_image(path).permute(1, 2, 0))
    plt.axis('off')

def show_images(paths, num_per_row=5):
    num_images = len(paths)
    num_rows = (num_images + num_per_row - 1) // num_per_row  # Ceiling division
    
    # squeeze=False always returns a 2D array of axes, even for a single row or column,
    # so one flatten() call covers every grid shape
    fig, axes = plt.subplots(
        num_rows, num_per_row, figsize=(num_per_row * 3, num_rows * 3), squeeze=False
    )
    axes = axes.flatten()
    
    for idx, path in enumerate(paths):
        axes[idx].imshow(read_image(path).permute(1, 2, 0))
        axes[idx].axis('off')  # Hide axes for cleaner display
    
    # Hide any unused subplots
    for idx in range(num_images, len(axes)):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()

In [None]:
show_images(image_paths[:15])

## Embedding images

We will use the CLIP model to get the image embeddings for all of the images in our dataset.

In [None]:
def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 2) → AB CD EF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch

def get_image_embeddings():
    with torch.no_grad():
        all_outputs = []
        for batch in batched(tqdm(image_paths), 32):
            batch_images = [read_image(path) for path in batch]
            inputs = processor.image_processor(batch_images, return_tensors="pt").to(model.device)
            outputs = model.get_image_features(**inputs)
            all_outputs.append(outputs.cpu())
    return torch.vstack(all_outputs)

    
embeds = get_image_embeddings()
embeds.shape
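
To see how the batching helper splits the dataset before each forward pass, here is a small standalone sketch. It re-declares a simplified `batched` (without the `strict` option) so that it runs on its own, and it uses made-up file names rather than the real image paths. On Python 3.12 and later, `itertools.batched` offers the same behavior.

```python
from itertools import islice


def batched(iterable, n):
    """Yield successive tuples of up to n items (simplified copy of the helper above)."""
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch


# Hypothetical file names standing in for the real image paths.
fake_paths = [f"img_{i:03d}.jpg" for i in range(70)]

# 70 items in batches of 32: two full batches and one partial batch.
batch_sizes = [len(batch) for batch in batched(fake_paths, 32)]
print(batch_sizes)  # [32, 32, 6]
```

The last batch is simply shorter than the others, which is why `get_image_embeddings` can stack the per-batch outputs with `torch.vstack` without any padding.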

## Querying against the embeddings

Recall our previous experiments with word and sentence embeddings, where we queried for nearest neighbors based on cosine similarity. We can do the same here.

In [None]:
def get_nn(query_vec, n=10):
    if len(query_vec.shape) < 2:
        # if query_vec is a single vector, make it a batch of size 1
        query_vec = query_vec.unsqueeze(0)
    sims = torch.cosine_similarity(query_vec, embeds)
    return sims.argsort()[-n:].tolist()[::-1]
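
As a reminder of what `get_nn` computes, the sketch below ranks a few toy 2-dimensional vectors by cosine similarity to a query, in plain Python with no model involved. The vectors and the `top_n` helper are invented for illustration; real CLIP embeddings have many more dimensions.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_n(query, vectors, n=2):
    """Return indices of the n vectors most similar to query, best first."""
    sims = [cosine_similarity(query, v) for v in vectors]
    return sorted(range(len(vectors)), key=lambda i: sims[i], reverse=True)[:n]


# Toy "embeddings" (hypothetical values, not produced by CLIP).
toy_embeds = [
    [1.0, 0.0],   # index 0
    [0.9, 0.1],   # index 1: nearly the same direction as index 0
    [0.0, 1.0],   # index 2: orthogonal to index 0
    [-1.0, 0.0],  # index 3: opposite direction to index 0
]

# The query has length 2, but only its direction matters for cosine similarity.
print(top_n([2.0, 0.0], toy_embeds, n=2))  # [0, 1]
```

Cosine similarity ignores vector length, so the ranking depends only on direction. `get_nn` does the same thing with `torch.cosine_similarity` over all image embeddings at once.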

For example, let's find the 10 images most similar to this image of a bucket. Because the bucket image is itself in `embeds`, expect it to come back as its own top match.

In [None]:
show_image(image_paths[754])

In [None]:
neighbors = get_nn(embeds[754])
show_images([image_paths[i] for i in neighbors])

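Because CLIP projects text into the same embedding space as images, we can also search the images with a text query: embed the query with the text encoder, then look up its nearest image embeddings.

In [None]: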
def search(query, n=10):
    processed = processor.tokenizer([query], return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.get_text_features(**processed)
        embedding = out.cpu()
    nns = get_nn(embedding, n)
    return nns

In [None]:
show_images([image_paths[i] for i in search("A serene lake", n=6)], num_per_row=3)

## Explore

What else can you find in this dataset?