<a href="https://colab.research.google.com/github/mesnico/DTfH-Laboratory/blob/main/2025/text_to_image_similarity_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Laboratory on Text and Image Representations for Text-to-Image Similarity Search
-----------------

You'll learn to:

*   Represent images using features extracted from multimodal deep neural networks.
*   Search images using textual descriptions.

## Getting Started

First of all, we need to download and unzip the image dataset, and install some Python dependencies.

We will use MIRFlickr5k, a subset of the larger [MIRFlickr25k](https://press.liacs.nl/mirflickr/mirdownload.html), which contains photographs downloaded from the popular Flickr website.

Run the following cell to get the environment ready! This could take a few minutes...



In [None]:
import os
if not os.path.exists('mirflickr5k'):
  # !wget mb-messina.isti.cnr.it/mirflickr5k.zip
  !gdown 1sEBg-sZgSQac0W7fyPecj-s_uWAjLLh9  # pass the Google Drive file ID directly (the --id flag is deprecated)
  !unzip -n mirflickr5k.zip
else:
  print('Dataset already downloaded!')

Next, we install some Python dependencies and import the needed Python packages.

In [None]:
!pip install transformers

import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
import transformers
from PIL import Image

from pathlib import Path
from tqdm.auto import tqdm

import matplotlib
import matplotlib.pyplot as plt
from skimage.transform import resize
import requests

import random
import os

%matplotlib inline
random.seed(42)

euclidean_distance = torch.cdist  # pairwise Euclidean (p=2) distances between the rows of two matrices

In [None]:
# define the function for computing the k-NN from the distances
def k_nearest_neighbors(distances, k=5):
  nq, ndb = distances.shape

  sorted_indices = distances.argsort(axis=1)  # sort the distances ascending (closest first), for each query
  topk = sorted_indices[:, :k]  # get the **indices** of the top-k images for each query (row)
  topk_distances = distances[np.arange(nq)[:, None], topk]  # use the indices to get the top-k distances (magic slicing version)
  # topk_distances = torch.stack([distances[i, topk[i]] for i in range(nq)])  # same result, more readable (for torch tensors)
  return topk, topk_distances

# define a helper function to view the results
def show_images(urls, figsize=None):
  n_images = len(urls)
  # squeeze=False keeps `axes` 2-D, so the loop also works with a single image
  fig, axes = plt.subplots(1, n_images, figsize=figsize, squeeze=False)
  for ax, url in zip(axes[0], urls):
    # download the image and convert it to a NumPy array
    image_np = np.asarray(Image.open(requests.get(url, stream=True).raw))
    image_np = resize(image_np, (400, 300))
    ax.set_ylabel('Query')
    ax.imshow(image_np)
    ax.set_xticks([])
    ax.set_yticks([])

  return fig

## Data Loading

Let's inspect the data.

In [None]:
image_paths = Path('mirflickr5k').rglob('*.jpg')
image_paths = sorted(image_paths)
image_paths[:5]

In [None]:
fig, axes = plt.subplots(5, 10, figsize=(20, 10))
for ax, image_path in zip(axes.flatten(), image_paths):
  image_np = plt.imread(image_path)
  image_np = resize(image_np, (400, 300))
  ax.imshow(image_np)
  ax.axis('off')
plt.subplots_adjust(wspace=0, hspace=0)

Consider a (potentially large) database of images and a set of queries (in this laboratory, short textual descriptions).
Our goal is to retrieve the images from the database that best match each query.

Let's first select some images among which we will search. We will define our queries later.

In [None]:
ndb = 1000 # number of samples in the database to consider

selected_image_paths = random.sample(image_paths, ndb)
db_image_paths = selected_image_paths[:ndb]
db_image_paths[:5]

## Text to Image Retrieval

We will try to retrieve images using natural-language text as queries.

We need:
- a feature extractor for the _images_ (our _database_)
- a feature extractor for the _texts_ (our _queries_)

We will use the [CLIP](https://huggingface.co/docs/transformers/v4.19.2/en/model_doc/clip) deep neural network, which implements both feature extractors.

This model is trained
- to extract **representations of images** (image features)
- and **representations of short text sentences** (text features)
- such that those representations **match** when the text describes the image content.
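To make "match" a bit more concrete: CLIP is trained contrastively on batches of image-text pairs, so that each image ends up closest to its own caption and far from the other captions in the batch (and vice versa). The following is a simplified NumPy sketch of such a symmetric contrastive loss, not CLIP's actual training code; the random placeholder features and the fixed temperature of 0.07 are assumptions for illustration (the real model learns its temperature).

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    # scale every row to unit Euclidean length
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)

def log_softmax(logits, axis):
    # numerically stable log-softmax along the given axis
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def clip_style_loss(image_feats, text_feats, temperature=0.07):
    """Symmetric contrastive loss over a batch of N (image, text) pairs.

    Row i of image_feats and row i of text_feats are assumed to describe the
    same content; every other pairing in the batch is treated as a negative.
    """
    img = l2_normalize(image_feats)
    txt = l2_normalize(text_feats)
    logits = img @ txt.T / temperature  # (N, N) scaled cosine similarities
    idx = np.arange(len(logits))
    # cross-entropy with the diagonal as the correct "class", in both directions
    loss_image_to_text = -log_softmax(logits, axis=1)[idx, idx].mean()
    loss_text_to_image = -log_softmax(logits, axis=0)[idx, idx].mean()
    return (loss_image_to_text + loss_text_to_image) / 2

rng = np.random.default_rng(0)
images = rng.normal(size=(4, 512))
aligned_texts = images + 0.01 * rng.normal(size=(4, 512))  # texts that "match" their images
random_texts = rng.normal(size=(4, 512))                   # texts unrelated to the images

print(clip_style_loss(images, aligned_texts))  # small: each image is closest to its own text
print(clip_style_loss(images, random_texts))   # larger: no consistent matching
```

Minimizing this kind of loss is what pushes matching image and text representations close together, which is exactly the property we exploit for search.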

Let's initialize CLIP.

In [None]:
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

### Image and text representations

Now we define two helper functions:
- `extract_features_from_images`, which extracts representations from the images;
- `extract_features_from_texts`, which extracts representations from the texts.

We hide the internals of these functions, as they contain some unimportant details. Feel free to unhide them if you want to know more.

However, run the following block before moving on!

In [None]:
def extract_features_from_images(images):
  features = []

  # we repeat the extraction for each image in the given list of images
  for image in tqdm(images):

    # open the image
    image = Image.open(image)

    # perform some preprocessing (scale, normalization) on the image
    inputs = processor(images=image, return_tensors="pt")
    inputs = inputs.to(device)

    # extract the features from the image using our CLIP deep neural network
    with torch.no_grad():
      feature = model.get_image_features(**inputs)

    # save the features in a list that we will return
    features.append(feature)

  # do some post-processing on the features before returning them
  features = torch.cat(features, dim=0)
  features = F.normalize(features)

  return features


def extract_features_from_texts(texts):

  # preprocess the words of the text
  inputs = tokenizer(texts, padding=True, return_tensors="pt")
  inputs = inputs.to(device)

  # extract the features from the text using the CLIP deep neural network
  with torch.no_grad():
    features = model.get_text_features(**inputs)

  # do some post-processing on the features before returning them
  features = F.normalize(features)
  return features

### Let's go with feature extraction!

Now, call the `extract_features_from_images` function to extract the image features from the whole dataset.

Look at the dimensionality of the extracted features: for each image in the database, we have a 512-dimensional feature vector (with ResNet it was 2048).

In [None]:
image_dataset_features = extract_features_from_images(db_image_paths)
print(image_dataset_features.shape)

Then, we define some textual queries and extract their features using the `extract_features_from_texts` function.

Note that the textual features are also 512-dimensional, just like the image features: CLIP maps texts and images into the same feature space. We can therefore compute the Euclidean distance between them.
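Both helper functions end with `F.normalize(features)`, which with its defaults (`p=2`, `dim=1`) divides each row by its Euclidean norm. The NumPy sketch below mirrors that step on random placeholder vectors (assumed stand-ins for real CLIP outputs) and shows that, once texts and images share the same dimensionality and have unit length, a distance between a text vector and an image vector is well defined.

```python
import numpy as np

def l2_normalize(features, eps=1e-12):
    # NumPy mirror of torch.nn.functional.normalize(features) with its defaults (p=2, dim=1):
    # divide each row by its Euclidean norm, never dividing by less than eps
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / np.maximum(norms, eps)

rng = np.random.default_rng(0)
# random placeholders standing in for CLIP outputs (4 text queries, 1000 images, 512 dims each)
text_features = l2_normalize(rng.normal(size=(4, 512)))
image_features = l2_normalize(rng.normal(size=(1000, 512)))

print(np.linalg.norm(text_features, axis=1))               # every row has length 1 (up to rounding)
print(text_features.shape[1] == image_features.shape[1])   # True: same space, so distances make sense

# Euclidean distance between query 0 and image 0; for unit vectors it always lies in [0, 2]
d = np.sqrt(((text_features[0] - image_features[0]) ** 2).sum())
print(d)
```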

In [None]:
textual_queries = [
                   'A person riding a bike',
                   'A picture of a young child',
                   'A view of some mountains',
                   'A laptop'
]

textual_queries_features = extract_features_from_texts(textual_queries)
print(textual_queries_features.shape)

### The core of text-to-image similarity search

Once we have feature-vector representations, we can search for similar items by comparing feature vectors instead of raw pixels.

We will compare feature vectors using the Euclidean distance (exactly as in the image-to-image retrieval case: nothing changes here):
$$
d(\mathbf{x}, \mathbf{y}) = \sqrt{(x_1-y_1)^2 + (x_2-y_2)^2 + \dots + (x_n-y_n)^2},
$$
where $\mathbf{x} = (x_1, x_2, \ldots, x_n)$ are the coordinates of the first feature vector and $\mathbf{y} = (y_1, y_2, \ldots, y_n)$ those of the second.
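The formula above, applied to every (query, image) pair at once, is what `euclidean_distance` (i.e., `torch.cdist` with its default `p=2`) computes. Here is a NumPy sketch on random placeholder features (an assumption, not real CLIP outputs). It also checks a useful identity: for unit-length vectors, $\lVert\mathbf{x}-\mathbf{y}\rVert^2 = 2 - 2\,\mathbf{x}\cdot\mathbf{y}$, so sorting by increasing Euclidean distance gives the same ranking as sorting by decreasing cosine similarity.

```python
import numpy as np

def pairwise_euclidean(queries, database):
    # queries: (nq, n), database: (ndb, n) -> distances: (nq, ndb)
    # broadcasting builds every difference x - y, then applies the formula above
    diff = queries[:, None, :] - database[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

rng = np.random.default_rng(0)
queries = rng.normal(size=(4, 512))
database = rng.normal(size=(1000, 512))

# unit-normalize, as extract_features_from_images / extract_features_from_texts do
queries /= np.linalg.norm(queries, axis=1, keepdims=True)
database /= np.linalg.norm(database, axis=1, keepdims=True)

d = pairwise_euclidean(queries, database)
cosine = queries @ database.T  # for unit vectors, cosine similarity = dot product

print(d.shape)                              # (4, 1000)
print(np.allclose(d ** 2, 2 - 2 * cosine))  # True: d^2 = 2 - 2 cos
# smallest distance <=> largest cosine similarity, for every query
print(np.array_equal(d.argmin(axis=1), cosine.argmax(axis=1)))  # True
```

Because the CLIP features are normalized, ranking by Euclidean distance and ranking by cosine similarity return the same neighbors.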

In [None]:
distances = euclidean_distance(textual_queries_features, image_dataset_features)

Now we can reuse the `k_nearest_neighbors` function from Part 1, as is, to find the image representations closest to each text representation!

Remember what we are doing under the hood:
1. For each query, we sort the distances from the smallest to the largest.
2. For each query, we take the first $k$ images as the result.
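The two steps can be reproduced on a tiny, made-up distance matrix. The first part follows the same logic as `k_nearest_neighbors` (full `argsort`, then keep the first $k$ columns; `np.take_along_axis` plays the role of the "magic slicing"). The second part is a variant not used in this notebook: `np.argpartition` selects the $k$ smallest entries of each row without sorting the whole row, and only those $k$ are then sorted, which can help when the database is large.

```python
import numpy as np

# made-up distances: 2 queries (rows) x 5 database images (columns)
distances = np.array([
    [0.9, 0.2, 1.4, 0.5, 1.1],
    [1.3, 0.8, 0.1, 1.2, 0.6],
])
k = 3

# step 1: sort each row by increasing distance; step 2: keep the first k indices
topk = distances.argsort(axis=1)[:, :k]
topk_distances = np.take_along_axis(distances, topk, axis=1)
print(topk)            # [[1 3 0]
                       #  [2 4 1]]
print(topk_distances)  # [[0.2 0.5 0.9]
                       #  [0.1 0.6 0.8]]

# variant: partial selection of the k smallest per row, then sort only those k
part = np.argpartition(distances, k - 1, axis=1)[:, :k]
order = np.take_along_axis(distances, part, axis=1).argsort(axis=1)
part_sorted = np.take_along_axis(part, order, axis=1)
print(np.array_equal(part_sorted, topk))  # True
```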

In [None]:
k = 5
topk, topk_distances = k_nearest_neighbors(distances, k)
print(topk_distances)

### Look at the results!

Let's finally view the results.

In [None]:
nq, ndb = distances.shape

# show topk similar
fig, axes = plt.subplots(k, nq, figsize=(18, 4*k), squeeze=False)  # keep axes 2-D even when k == 1 or nq == 1
for j in range(k):
  axes[j, 0].set_ylabel(f'Rank #{j}')
  for i in range(nq):
    if j == 0:
      axes[0, i].set_title(textual_queries[i])
    image_np = plt.imread(db_image_paths[topk[i, j]])
    # image_np = unpad_image(image_np)
    image_np = resize(image_np, (400, 300))
    axes[j, i].imshow(image_np)
    axes[j, i].xaxis.set_label_position('top')
    axes[j, i].set_xlabel('dist = {:.2f}'.format(topk_distances[i, j]))
    axes[j, i].set_xticks([])
    axes[j, i].set_yticks([])

### Try yourself
You could try the following things:
- change the number $k$ of nearest neighbors to retrieve for each query;
- write other textual queries to understand the strengths and the limitations of this approach. For example:
  - try queries with colors (e.g., "There is a _red_ thing on top of the table");
  - try queries with spatial indications (e.g., "a person _to the right of_ a car");
- change the number of images in the database (`ndb`, e.g., from 1000 to 3000) to see whether the results change. _Warning: feature extraction will be very slow :(_