# Reveal the attention of CLIP

In [natural-language-joint-query-search](https://github.com/haofanwang/natural-language-joint-query-search), we added support for joint query search. In this project, we slightly modify the CLIP code to visualize CLIP's attention over the words of a query. This shows which keywords CLIP focuses on and makes its behavior easier to interpret.

In [1]:
!git clone https://github.com/haofanwang/natural-language-joint-query-search.git

Mounted at /content/drive


In [2]:
cd natural-language-joint-query-search

/content/drive/MyDrive/natural-language-joint-query-search


## Setup Environment

In this section we set up the environment.

In [3]:
!git clone https://github.com/shashwattrivedi/Attention_visualizer.git

Cloning into 'Attention_visualizer'...
remote: Enumerating objects: 44, done.
remote: Total 44 (delta 0), reused 0 (delta 0), pack-reused 44
Unpacking objects: 100% (44/44), done.


In [4]:
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (735.4MB)
Collecting torchvision==0.8.2+cu101
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.8.2%2Bcu101-cp36-cp36m-linux_x86_64.whl (12.8MB)
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.7.0+cu101
    Uninstalling torch-1.7.0+cu101:
      Successfully uninstalled torch-1.7.0+cu101
  Found existing installation: torchvision 0.8.1+cu101
    Uninstalling torchvision-0.8.1+cu101:
      Successfully uninstalled torchvision-0.8.1+cu101
Successfully installed torch-1.7.1+cu101 torchvision-0.8.2+cu101


In [5]:
!pip install ftfy

Collecting ftfy
  Downloading https://files.pythonhosted.org/packages/04/06/e5c80e2e0f979628d47345efba51f7ba386fe95963b11c594209085f5a9b/ftfy-5.9.tar.gz (66kB)
Building wheels for collected packages: ftfy
  Building wheel for ftfy (setup.py) ... done
  Created wheel for ftfy: filename=ftfy-5.9-cp36-none-any.whl size=46451 sha256=38350511c6b937db13fbddf3b147aae4f50b61cb83e4b92f607b6b5581e74a8c
  Stored in directory: /root/.cache/pip/wheels/5e/2e/f0/b07196e8c929114998f0316894a61c752b63bfa3fdd50d2fc3
Successf

## Loading the Precomputed Data

In this section, the precomputed feature vectors for all photos are loaded. For instructions on downloading the data, see [natural-language-joint-query-search](https://github.com/haofanwang/natural-language-joint-query-search) or [natural-language-image-search](https://github.com/haltakov/natural-language-image-search).

In [6]:
import pandas as pd
import numpy as np

# Load the photo IDs
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the feature vectors
photo_features = np.load("unsplash-dataset/features.npy")

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")

Photos loaded: 1981161
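
The search step later treats a dot product as a cosine similarity, which only holds if every feature row has unit length and the rows line up with the photo IDs. The following is a minimal sketch of such a sanity check. The helper `check_features` and the stand-in arrays are hypothetical and not part of the original notebook; in practice you would pass `photo_ids` and `photo_features` instead.

```python
import numpy as np

def check_features(photo_ids, photo_features, atol=1e-2):
    """Return (rows_match, unit_norm) for a list of IDs and a 2-D feature array."""
    # One feature row is expected per photo ID
    rows_match = len(photo_ids) == photo_features.shape[0]
    # Compute row norms in float32 to avoid precision issues with smaller dtypes
    norms = np.linalg.norm(photo_features.astype(np.float32), axis=1)
    unit_norm = bool(np.allclose(norms, 1.0, atol=atol))
    return rows_match, unit_norm

# Tiny stand-in data (made up); replace with photo_ids and photo_features
demo_ids = ["id_0", "id_1"]
demo_features = np.array([[0.6, 0.8], [1.0, 0.0]], dtype=np.float32)

demo_check = check_features(demo_ids, demo_features)
print(demo_check)
```

If the second value comes back `False` on the real data, normalizing the rows before searching would be one way to keep the similarity scores comparable.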


## Define Functions

Some important functions from CLIP for processing the data are defined here.

In [8]:
def find_best_matches(text_features, photo_features, photo_ids, results_count=3):
  # Compute the similarity between the search query and each photo; the dot product
  # equals the cosine similarity when both sets of features are L2-normalized
  similarities = (photo_features @ text_features.T).squeeze(1)

  # Sort the photos by their similarity score
  best_photo_idx = (-similarities).argsort()

  # Return the photo IDs of the best matches
  return [photo_ids[i] for i in best_photo_idx[:results_count]]
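
To make the ranking concrete, here is a small sketch that runs the same logic on made-up 3-dimensional vectors. The helper `l2_normalize`, the toy features, and the photo IDs are all hypothetical; real CLIP features have many more dimensions. The function is repeated so that the block runs on its own.

```python
import numpy as np

def find_best_matches(text_features, photo_features, photo_ids, results_count=3):
    # Same logic as the function above: dot product, then sort in descending order
    similarities = (photo_features @ text_features.T).squeeze(1)
    best_photo_idx = (-similarities).argsort()
    return [photo_ids[i] for i in best_photo_idx[:results_count]]

def l2_normalize(x):
    # Scale each row to unit length so that a dot product is a cosine similarity
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Made-up photo "features", one row per photo
toy_photo_features = l2_normalize(np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0],
]))
toy_photo_ids = ["photo_a", "photo_b", "photo_c"]  # hypothetical IDs

# A query that points mostly along the first axis, shape (1, 3)
toy_text_features = l2_normalize(np.array([[0.9, 0.1, 0.0]]))

toy_matches = find_best_matches(toy_text_features, toy_photo_features, toy_photo_ids, 2)
print(toy_matches)  # ['photo_a', 'photo_c']
```

The text features need the shape `(1, d)` so that the product has shape `(n, 1)` and `squeeze(1)` leaves one score per photo.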

## Load the Model

Load the model.

In [None]:
import torch
from PIL import Image

from CLIP.clip import clip
from CLIP.clip import model

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

## Visualization

Given a search query, we first compute its embedding and retrieve images from Unsplash as before. In addition, we save the weights of the last attention layer. The visualized results show which words CLIP attends to.

#### "A red flower is under the blue sky and there is a bee on the flower"

In [32]:
search_query = "A red flower is under the blue sky and there is a bee on the flower"

with torch.no_grad():
    # Encode and normalize the search query using CLIP
    text_token = clip.tokenize(search_query).to(device)
    text_encoded, weight = model.encode_text(text_token)
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

text_features = text_encoded.cpu().numpy()
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/_QMxWAa3gXM/download
https://unsplash.com/photos/lp_TphksOrg/download
https://unsplash.com/photos/4pYmH4o0zNo/download
https://unsplash.com/photos/Ye-PdCxCmEQ/download
https://unsplash.com/photos/qyN7CD8qm5M/download


In [33]:
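from Attention_visualizer.attention_visualizer import *

# Split the query into words; this assumes each word maps to exactly one CLIP token
sentence = search_query.split(" ")

# weight[-1][0] is assumed to be the last layer's attention map for this single query.
# Take the row at position 1 + number of words (the end-of-text token under the
# assumption above) and keep only the word positions, skipping the start token at index 0.
eot_row = weight[-1][0][1 + len(sentence)].cpu().numpy()
attention_weights = [float(item) for item in eot_row[1:1 + len(sentence)]]
display_attention(sentence, attention_weights)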
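
The cell above picks one row of the attention map and keeps only the entries that belong to the words of the query. CLIP reads the sentence embedding off the end-of-text position, which is presumably why that row is used. The sketch below reproduces the indexing on a synthetic attention map. The function `word_attention_from_eot` and the toy values are hypothetical, and the token layout (start token, one token per word, end token, padding) is an assumption that breaks when the tokenizer splits a word into several tokens.

```python
import numpy as np

def word_attention_from_eot(attn, num_words):
    """Slice the end-of-text row of a (seq_len, seq_len) attention map.

    Assumed token layout: [start, word_1, ..., word_n, end, padding...].
    """
    eot_index = 1 + num_words   # position of the end-of-text token
    row = attn[eot_index]       # attention from the end-of-text token to every position
    # Keep only the word positions 1..num_words
    return [float(w) for w in row[1:1 + num_words]]

# Synthetic 6x6 attention map for a 3-word sentence (values are made up)
toy_attn = np.zeros((6, 6))
toy_attn[4] = [0.1, 0.2, 0.5, 0.1, 0.1, 0.0]  # row of the end-of-text token

toy_word_weights = word_attention_from_eot(toy_attn, 3)
print(toy_word_weights)  # [0.2, 0.5, 0.1]
```

The kept weights do not sum to 1, because the attention that falls on the start and end tokens is dropped.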

#### "a woman holding an umbrella standing next to a man in a rainy day"

In [39]:
search_query = "a woman holding an umbrella standing next to a man in a rainy day"

with torch.no_grad():
    # Encode and normalize the search query using CLIP
    text_token = clip.tokenize(search_query).to(device)
    text_encoded, weight = model.encode_text(text_token)
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

text_features = text_encoded.cpu().numpy()
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/EFOYS783_D0/download
https://unsplash.com/photos/KKDOB6YLZtM/download
https://unsplash.com/photos/qNo7I5cbZKg/download
https://unsplash.com/photos/cNgiyFNlZw8/download
https://unsplash.com/photos/AVQRYiyXO7o/download


In [40]:
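from Attention_visualizer.attention_visualizer import *

# Split the query into words; this assumes each word maps to exactly one CLIP token
sentence = search_query.split(" ")

# weight[-1][0] is assumed to be the last layer's attention map for this single query.
# Take the row at position 1 + number of words (the end-of-text token under the
# assumption above) and keep only the word positions, skipping the start token at index 0.
eot_row = weight[-1][0][1 + len(sentence)].cpu().numpy()
attention_weights = [float(item) for item in eot_row[1:1 + len(sentence)]]
display_attention(sentence, attention_weights)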