# Lab 7 - Multimodal search with CLIP

In [None]:
import requests
import zipfile
import os
import io
from PIL import Image
from IPython.display import Image as IP_Image, display
from open_clip import tokenizer, create_model_and_transforms
import torch
from sklearn.neighbors import NearestNeighbors
import numpy as np
from tqdm import tqdm

## Get the image dataset (interiors of houses)

- Source: https://www.kaggle.com/datasets/mikhailma/house-rooms-streets-image-dataset/data
- Cached: https://max.io/house_data_png.zip (resized to 256x256 and converted to PNG)
- License: CC-0 Public Domain

In [None]:
# Function to download and extract the zip file
def download_and_extract_zip(url, extract_to='.', timeout=60):
    """Download a ZIP archive over HTTP and extract all of its members.

    Args:
        url: Address of the ZIP archive.
        extract_to: Directory the archive members are extracted into.
        timeout: Seconds ``requests`` waits for the server to connect / send
            data before giving up (not a cap on the total download time).
            Without it a stalled server would block the notebook forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
            Checking this up front avoids handing an HTML error page to
            ``zipfile``, which would fail with a confusing ``BadZipFile``.
    """
    print('Downloading and extracting',url)
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        # NOTE: the zipfile docs advise inspecting archives from untrusted
        # sources before extracting them; this assumes `url` is trusted.
        zip_ref.extractall(extract_to)

# Download and extract the example images
url = "https://max.io/house_data_png.zip"
download_and_extract_zip(url)
image_dir = 'house_data_png'
# os.listdir returns entries in arbitrary order. Sort so that row i of the
# embedding matrix (and therefore the pickled array and indices such as the
# one given to search_by_image) refers to the same file on every run.
# Skip hidden entries and subdirectories, which Image.open cannot read.
image_paths = sorted(
    os.path.join(image_dir, filename)
    for filename in os.listdir(image_dir)
    if not filename.startswith('.')
    and os.path.isfile(os.path.join(image_dir, filename))
)
print('Extracted',len(image_paths),'images')

In [None]:
# Load the CLIP ViT-B-32 model with the OpenAI-released weights.
# NOTE(review): open_clip documents the return value as
# (model, preprocess_train, preprocess_val), so `transform` is presumably the
# training-time transform (random crops) and `preprocess` the deterministic
# inference transform -- confirm against the installed open_clip version.
model, transform, preprocess = create_model_and_transforms('ViT-B-32', pretrained='openai')

# Switch to inference mode. eval() returns the model, so as the cell's last expression the notebook prints its architecture; note both the "visual" and "transformer" branches (presumably the image and text encoders respectively).
model.eval()

In [None]:
#Infers images in batches.
def get_image_embeddings(image_paths, batch_size=32):
    """Encode image files with the CLIP image branch, a batch at a time.

    Args:
        image_paths: Sequence of image file paths. Row i of the result
            corresponds to ``image_paths[i]``.
        batch_size: Number of images per forward pass.

    Returns:
        A 2-D tensor with one embedding row per path (not yet L2-normalized),
        on the same device as ``model``.
    """
    # Send inputs to wherever the model's weights live. Choosing 'cuda'
    # independently of the model breaks on GPU machines, because the model is
    # never moved off the CPU when it is loaded.
    device = next(model.parameters()).device
    embeddings = []

    # Process images in batches
    for start in tqdm(range(0, len(image_paths), batch_size), desc="Processing Images"):
        batch_paths = image_paths[start:start + batch_size]
        batch_images = []
        for path in batch_paths:
            # `preprocess` is the deterministic inference transform. The
            # training transform applies random crops, which would make the
            # embeddings change from run to run. `with` closes the file handle.
            with Image.open(path) as img:
                batch_images.append(preprocess(img))

        # Stack into an (N, C, H, W) batch and move it to the model's device.
        batch_images_tensor = torch.stack(batch_images).to(device)

        with torch.no_grad():
            batch_embeddings = model.encode_image(batch_images_tensor)

        embeddings.append(batch_embeddings)

    # Concatenate all embeddings
    return torch.vstack(embeddings)

In [None]:
# Encode every dataset image; row i of the result corresponds to image_paths[i].
image_embeddings = get_image_embeddings(image_paths=image_paths, batch_size=32)

In [None]:
# Scale every row to unit L2 length in place (same tensor object as before),
# so dot products between rows equal cosine similarities.
image_embeddings = image_embeddings.div_(image_embeddings.norm(dim=-1, keepdim=True))

In [None]:
import pickle

# Persist the embeddings as a NumPy array (row i <-> image_paths[i]) so later
# sessions can reload them without re-running the image encoder. The paths
# themselves are not stored, so a reader must rebuild image_paths in the same
# order.
with open('house_data_png.pkl', 'wb') as fd:
    pickle.dump(image_embeddings.cpu().numpy(), fd, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Sanity check: number of embeddings, then the shape of a single embedding.
print(image_embeddings.size(0), image_embeddings[0].shape)

In [None]:
#Encodes the text to the same vector space as the images
def embed_text(text):
    """Encode text with the CLIP text branch and L2-normalize the result.

    Args:
        text: A single query string, or an iterable of query strings.

    Returns:
        A tensor with one unit-length row per input string (a single row when
        ``text`` is a string), on the same device as ``model``.
    """
    texts = [text] if isinstance(text, str) else list(text)
    # Keep the tokens on the same device as the model's weights, so this keeps
    # working if the model is ever moved to a GPU.
    device = next(model.parameters()).device
    tokens = tokenizer.tokenize(texts).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True) #Normalization is required!
    return text_features

In [None]:
# Function to display images
def display_images(image_paths, distances):
    """Render each image inline in the notebook, followed by its distance.

    Args:
        image_paths: Paths of the images to show, in display order.
        distances: Sequence indexed in parallel with ``image_paths``; entry k
            is printed directly beneath image k.
    """
    for k, image_path in enumerate(image_paths):
        shown = IP_Image(filename=image_path)
        display(shown)
        print('👆', distances[k])

In [None]:
# This will search and display nearest images given a text query
# Cosine-metric k-NN index over the image embeddings; row i <-> image_paths[i].
# Clamp the default k so a dataset with fewer than 10 images still answers
# queries (kneighbors rejects k larger than the number of fitted samples).
nbrs = NearestNeighbors(n_neighbors=min(10, len(image_embeddings)), metric='cosine').fit(image_embeddings.cpu().numpy())
def search(text, n_results=None):
    """Display the dataset images closest to a text query.

    Args:
        text: Natural-language query string.
        n_results: How many images to show. ``None`` (the default) uses the
            index's default of up to 10.

    The number printed under each image is the cosine distance
    (1 - cosine similarity), so smaller means a closer match.
    """
    text_embedding = embed_text(text)
    distances, indices = nbrs.kneighbors(text_embedding.cpu().numpy(), n_neighbors=n_results)
    # Only one query row was sent, so take the first row of each result.
    nearest_images = [image_paths[i] for i in indices[0]]
    display_images(nearest_images, distances[0])

In [None]:
# In-domain text query: describe a room feature and inspect the closest images.
search(text='large kitchen island colonial')

In [None]:
# Another in-domain text query, this time about a bathroom fixture.
search(text='white marble shower stall')

In [None]:
# Probe with a query that is presumably rare in this dataset; compare the
# printed distances with those of the in-domain queries above.
search(text='red ferrari')

In [None]:
# A second probe with content presumably absent from house photos.
search(text='nuclear reactor')

In [None]:
def search_by_image(index, n_results=None):
    """Display the dataset images closest to the image at ``image_paths[index]``.

    The query image is itself in the index, so it normally comes back first
    with a cosine distance close to 0.

    Args:
        index: Position of the query image in ``image_paths`` /
            ``image_embeddings``.
        n_results: How many images to show. ``None`` (the default) uses the
            index's default of up to 10.
    """
    image_embedding = image_embeddings[index]
    # kneighbors expects a 2-D (n_queries, dim) array: one query row here.
    query = image_embedding.cpu().numpy().reshape(1, -1)
    distances, indices = nbrs.kneighbors(query, n_neighbors=n_results)
    nearest_images = [image_paths[i] for i in indices[0]]
    display_images(nearest_images, distances[0])

In [None]:
# Query by an existing image (assumes at least 506 images): image 505 itself
# should be the first hit, with a distance of about 0.
search_by_image(index=505)