# Reveal the attention of CLIP

In [natural-language-joint-query-search](https://github.com/haofanwang/natural-language-joint-query-search), we support joint query search. In this project, we slightly modify the CLIP code to visualize CLIP's attention. This shows which keywords in a query CLIP focuses on, which helps improve the interpretability of CLIP.

In [1]:
!git clone https://github.com/haofanwang/natural-language-joint-query-search.git

Cloning into 'natural-language-joint-query-search'...
remote: Enumerating objects: 116, done.
remote: Counting objects: 100% (116/116), done.
remote: Compressing objects: 100% (106/106), done.
remote: Total 116 (delta 37), reused 43 (delta 5), pack-reused 0
Receiving objects: 100% (116/116), 13.12 MiB | 29.58 MiB/s, done.
Resolving deltas: 100% (37/37), done.


In [2]:
cd natural-language-joint-query-search

/content/natural-language-joint-query-search


## Setup Environment

In this section we set up the environment.

In [3]:
!git clone https://github.com/shashwattrivedi/Attention_visualizer.git

Cloning into 'Attention_visualizer'...
remote: Enumerating objects: 44, done.
remote: Total 44 (delta 0), reused 0 (delta 0), pack-reused 44

In [4]:
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install ftfy regex tqdm

Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.7.1+cu101
  Downloading https://download.pytorch.org/whl/cu101/torch-1.7.1%2Bcu101-cp36-cp36m-linux_x86_64.whl (735.4MB)
     |████████████████████████████████| 735.4MB 25kB/s
Collecting torchvision==0.8.2+cu101
  Downloading https://download.pytorch.org/whl/cu101/torchvision-0.8.2%2Bcu101-cp36-cp36m-linux_x86_64.whl (12.8MB)
     |████████████████████████████████| 12.8MB 253kB/s
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.7.0+cu101
    Uninstalling torch-1.7.0+cu101:
      Successfully uninstalled torch-1.7.0+cu101
  Found existing installation: torchvision 0.8.1+cu101
    Uninstalling torchvision-0.8.1+cu101:
      Successfully uninstalled torchvision-0.8.1+cu101
Successfully installed torch-1.7.1+cu101 torchvision-0.8.2+cu101
Collecting ftfy
  Downloading https://files.pythonhosted.org/packages/04/06/e5c80e2e0f979628d47

## Loading the Precomputed Data

In this section, the precomputed feature vectors for all photos are loaded. For details on how to download the data, please refer to [natural-language-joint-query-search](https://github.com/haofanwang/natural-language-joint-query-search) or [natural-language-image-search](https://github.com/haltakov/natural-language-image-search).

In [6]:
from pathlib import Path

# Create a folder for the precomputed features
!mkdir unsplash-dataset

# Download the photo IDs and the feature vectors
!gdown --id 1FdmDEzBQCf3OxqY9SbU-jLfH_yZ6UPSj -O unsplash-dataset/photo_ids.csv
!gdown --id 1L7ulhn4VeN-2aOM-fYmljza_TQok-j9F -O unsplash-dataset/features.npy

# Download from alternative source, if the download doesn't work for some reason (for example download quota limit exceeded)
if not Path('unsplash-dataset/photo_ids.csv').exists():
  !wget https://transfer.army/api/download/TuWWFTe2spg/EDm6KBjc -O unsplash-dataset/photo_ids.csv

if not Path('unsplash-dataset/features.npy').exists():
  !wget https://transfer.army/api/download/LGXAaiNnMLA/AamL9PpU -O unsplash-dataset/features.npy

Downloading...
From: https://drive.google.com/uc?id=1FdmDEzBQCf3OxqY9SbU-jLfH_yZ6UPSj
To: /content/natural-language-joint-query-search/unsplash-dataset/photo_ids.csv
23.8MB [00:00, 65.4MB/s]
Downloading...
From: https://drive.google.com/uc?id=1L7ulhn4VeN-2aOM-fYmljza_TQok-j9F
To: /content/natural-language-joint-query-search/unsplash-dataset/features.npy
2.03GB [00:18, 108MB/s] 


In [7]:
import pandas as pd
import numpy as np

# Load the photo IDs
photo_ids = pd.read_csv("unsplash-dataset/photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the feature vectors
photo_features = np.load("unsplash-dataset/features.npy")

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")

Photos loaded: 1981161


## Define Functions

Some important functions from CLIP for processing the data are defined here.

In [8]:
def find_best_matches(text_features, photo_features, photo_ids, results_count=3):
  # Compute the cosine similarity between the search query and each photo
  # (a plain dot product, assuming both sets of features are L2-normalized)
  similarities = (photo_features @ text_features.T).squeeze(1)

  # Sort the photos by their similarity score
  best_photo_idx = (-similarities).argsort()

  # Return the photo IDs of the best matches
  return [photo_ids[i] for i in best_photo_idx[:results_count]]
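
To make the ranking step concrete, here is a minimal sketch of the same idea on toy data. The helper name `top_k_by_cosine`, the 2-D vectors, and the photo IDs are all made up for illustration. It assumes that both the query and the photo features are L2-normalized, so the dot product equals the cosine similarity.

```python
import numpy as np

def top_k_by_cosine(query, features, ids, k=3):
    """Return the k ids whose feature rows are most similar to the query."""
    # query: (1, d), features: (n, d); both are assumed to be L2-normalized
    similarities = (features @ query.T).squeeze(1)
    # Negate so that argsort (ascending) puts the highest similarity first
    best = (-similarities).argsort()[:k]
    return [ids[i] for i in best]

# Toy data: three 2-D "photo" vectors and one query (hypothetical values)
features = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
features /= np.linalg.norm(features, axis=1, keepdims=True)
query = np.array([[2.0, 1.0]])
query /= np.linalg.norm(query, axis=-1, keepdims=True)

ids = ["photo_a", "photo_b", "photo_c"]
result = top_k_by_cosine(query, features, ids, k=2)
print(result)
```

If either side were not normalized, the dot product would also reflect vector length, and the ranking would no longer be a pure cosine ranking.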

## Load the Model

Here we load the modified CLIP model, whose text encoder also returns the attention weights.

In [9]:
import torch
from PIL import Image

from CLIP.clip import clip
from CLIP.clip import model

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

100%|████████████████████████████████████████| 354M/354M [00:02<00:00, 137MiB/s]


## Visualization

Given a search query, we first compute its embedding and retrieve images from Unsplash as before. In addition, we save the weights of the last attention layer and visualize them to show which words CLIP attends to.

#### "A red flower is under the blue sky and there is a bee on the flower"

In [10]:
search_query = "A red flower is under the blue sky and there is a bee on the flower"

with torch.no_grad():
    # Encode and normalize the search query using CLIP
    text_token = clip.tokenize(search_query).to(device)
    text_encoded, weight = model.encode_text(text_token)
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

text_features = text_encoded.cpu().numpy()
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/_QMxWAa3gXM/download
https://unsplash.com/photos/lp_TphksOrg/download
https://unsplash.com/photos/4pYmH4o0zNo/download
https://unsplash.com/photos/Ye-PdCxCmEQ/download
https://unsplash.com/photos/qyN7CD8qm5M/download


In [11]:
from Attention_visualizer.attention_visualizer import *

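# Split the query into words; this assumes each word maps to exactly one CLIP token
sentence = search_query.split(" ")
# Take the attention row at the end-of-text position (index 1 + number of words)
# and keep only the word positions, dropping start-of-text (index 0) and end-of-text
attention_weights = weight[-1][0][1 + len(sentence)].cpu().numpy()[1:1 + len(sentence)]
attention_weights = [float(item) for item in attention_weights]
display_attention(sentence, attention_weights)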
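
The indexing in the cell above is compact, so here is a small sketch of what it appears to do, using a toy attention matrix instead of the real CLIP output. It assumes a single `(seq_len, seq_len)` attention map in which position 0 is the start-of-text token, positions 1 to `n_words` are the words (one token per word), and position `1 + n_words` is the end-of-text token. The helper name `word_attention_from_eot` and the toy values are hypothetical.

```python
import numpy as np

def word_attention_from_eot(attn, n_words):
    """Slice the end-of-text row of one attention map down to the word positions."""
    eot_pos = 1 + n_words          # 0 is start-of-text, 1..n_words are the words
    row = attn[eot_pos]            # how much end-of-text attends to each position
    return [float(w) for w in row[1:1 + n_words]]

sentence = "a red flower".split(" ")

# Toy attention map with made-up values; each row is normalized to sum to 1
seq_len = 6
attn = np.arange(1, seq_len * seq_len + 1, dtype=float).reshape(seq_len, seq_len)
attn /= attn.sum(axis=1, keepdims=True)

weights = word_attention_from_eot(attn, len(sentence))
for word, w in zip(sentence, weights):
    print(f"{word}: {w:.3f}")
```

Words that the tokenizer splits into several tokens would break the one-token-per-word assumption and shift the positions, so the weights would no longer line up with the words.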

#### "A woman holding an umbrella standing next to a man in a rainy day"

In [12]:
search_query = "A woman holding an umbrella standing next to a man in a rainy day"

with torch.no_grad():
    # Encode and normalize the search query using CLIP
    text_token = clip.tokenize(search_query).to(device)
    text_encoded, weight = model.encode_text(text_token)
    text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

text_features = text_encoded.cpu().numpy()
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, 5)

for photo_id in best_photo_ids:
  print("https://unsplash.com/photos/{}/download".format(photo_id))

https://unsplash.com/photos/EFOYS783_D0/download
https://unsplash.com/photos/KKDOB6YLZtM/download
https://unsplash.com/photos/qNo7I5cbZKg/download
https://unsplash.com/photos/cNgiyFNlZw8/download
https://unsplash.com/photos/AVQRYiyXO7o/download


In [13]:
from Attention_visualizer.attention_visualizer import *

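# Split the query into words; this assumes each word maps to exactly one CLIP token
sentence = search_query.split(" ")
# Take the attention row at the end-of-text position (index 1 + number of words)
# and keep only the word positions, dropping start-of-text (index 0) and end-of-text
attention_weights = weight[-1][0][1 + len(sentence)].cpu().numpy()[1:1 + len(sentence)]
attention_weights = [float(item) for item in attention_weights]
display_attention(sentence, attention_weights)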