# Source retrieval with CLIP

In [None]:
import os
import json
import torch
import clip
import numpy as np
import IPython.display
import matplotlib.pyplot as plt

import seaborn as sns
from tqdm import tqdm
import pickle

from PIL import Image
from collections import OrderedDict

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Report the installed PyTorch version and the model names listed by
# clip.available_models().
torch_version = torch.__version__
print("Torch version:", torch_version)

available_clip_models = clip.available_models()
print(available_clip_models)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the ViT-B/32 CLIP checkpoint together with its image transform, move the
# model to `device` and put it in inference mode.
model, preprocess = clip.load("ViT-B/32")
model.to(device).eval()

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar weights across every parameter tensor.
n_params = sum(int(np.prod(p.shape)) for p in model.parameters())

for label, value in (
    ("Model parameters:", f"{n_params:,}"),
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(label, value)

## Load dataset

In [None]:
# Load the WebQA train+val annotation JSON into a dict of records.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding (which differs e.g. on Windows).
with open('./data/WebQA_train_val.json', encoding='utf-8') as f:
    json_dict = json.load(f)

# Peek at the fields available on one sample record.
json_dict['d5bbc6d80dba11ecb1e81171463288e9'].keys()

In [None]:
# Pickled validation subset (the file name suggests 1644 questions).
# Unpickling can execute arbitrary code, so only load files you produced yourself.
with open('data/val_subset_1644.pkl', 'rb') as subset_file:
    data = pickle.load(subset_file)

## Compute similarities
### Texts and question

In [None]:
from itertools import islice

# Cosine similarity between each question and its positive (gold) and
# negative (distractor) text snippets, over the first `total_qa` items of `data`.
pos_text_cos_sim = []
neg_text_cos_sim = []

total_qa = 1000


def encode_text_normalized(texts):
    """Tokenize `texts`, encode them with CLIP and L2-normalize each row.

    Tokens are moved to the model's `device`, and the features are cast to
    float32 so every side of the similarity product has the same dtype, even
    when the model itself runs in a lower precision.
    """
    tokens = clip.tokenize(texts).to(device)
    features = model.encode_text(tokens).float()
    return features / features.norm(dim=-1, keepdim=True)


def format_mean(values):
    """Format a running mean for the progress bar; 'n/a' while `values` is empty."""
    return f"{np.mean(values):.3f}" if values else "n/a"


with torch.no_grad(), tqdm(total=min(total_qa, len(data)), dynamic_ncols=True,
                           desc='Text-Q Similarity') as batch_bar:
    for qa in islice(data.values(), total_qa):
        question_features = encode_text_normalized([qa['Q']])

        # Snippets are cut to 60 characters, presumably to stay within CLIP's
        # token context (`context_length`).
        # Positive (gold) texts
        if len(qa['TxtPos']) > 0:
            pos_text_features = encode_text_normalized([txt[:60] for txt in qa['TxtPos']])
            pos_text_similarity = torch.matmul(question_features, pos_text_features.T).squeeze(0)
            pos_text_cos_sim.extend(pos_text_similarity.tolist())

        # Distractor texts
        if len(qa['TxtNeg']) > 0:
            neg_text_features = encode_text_normalized([txt[:60] for txt in qa['TxtNeg']])
            neg_text_similarity = torch.matmul(question_features, neg_text_features.T).squeeze(0)
            neg_text_cos_sim.extend(neg_text_similarity.tolist())

        batch_bar.set_postfix(Avg_pos=format_mean(pos_text_cos_sim),
                              Avg_neg=format_mean(neg_text_cos_sim))
        batch_bar.update()

### Captions and question

In [None]:
from itertools import islice

# Cosine similarity between each question and the captions of its positive
# (gold) and negative (distractor) images, over the first `total_qa` entries
# of `json_dict`.
# NOTE(review): this cell iterates `json_dict` (the full train/val JSON) while
# the text and image cells iterate `data` (the pickled val subset), so the
# question sets may differ — confirm this is intended.
pos_caption_cos_sim = []
neg_caption_cos_sim = []

total_qa = 1000


def encode_text_normalized(texts):
    """Tokenize `texts`, encode them with CLIP and L2-normalize each row.

    Tokens are moved to the model's `device`, and the features are cast to
    float32 so every side of the similarity product has the same dtype, even
    when the model itself runs in a lower precision.
    """
    tokens = clip.tokenize(texts).to(device)
    features = model.encode_text(tokens).float()
    return features / features.norm(dim=-1, keepdim=True)


def format_mean(values):
    """Format a running mean for the progress bar; 'n/a' while `values` is empty."""
    return f"{np.mean(values):.3f}" if values else "n/a"


with torch.no_grad(), tqdm(total=min(total_qa, len(json_dict)), dynamic_ncols=True,
                           desc='Caption-Q Similarity') as batch_bar:
    for qa in islice(json_dict.values(), total_qa):
        question_features = encode_text_normalized([qa['Q']])

        # Captions are cut to 60 characters, presumably to stay within CLIP's
        # token context (`context_length`).
        # Positive (gold) image captions
        if len(qa['img_posFacts']) > 0:
            pos_caption_features = encode_text_normalized(
                [fact['caption'][:60] for fact in qa['img_posFacts']])
            pos_caption_similarity = torch.matmul(question_features, pos_caption_features.T).squeeze(0)
            pos_caption_cos_sim.extend(pos_caption_similarity.tolist())

        # Distractor image captions
        if len(qa['img_negFacts']) > 0:
            neg_caption_features = encode_text_normalized(
                [fact['caption'][:60] for fact in qa['img_negFacts']])
            neg_caption_similarity = torch.matmul(question_features, neg_caption_features.T).squeeze(0)
            neg_caption_cos_sim.extend(neg_caption_similarity.tolist())

        batch_bar.set_postfix(Avg_pos=format_mean(pos_caption_cos_sim),
                              Avg_neg=format_mean(neg_caption_cos_sim))
        batch_bar.update()

### Title and question

In [None]:
from itertools import islice

# Cosine similarity between each question and the titles of its positive
# (gold) and negative (distractor) images, over the first `total_qa` entries
# of `json_dict`.
# NOTE(review): like the caption cell, this iterates `json_dict` rather than
# `data` — confirm this is intended.
pos_title_cos_sim = []
neg_title_cos_sim = []

total_qa = 1000


def encode_text_normalized(texts):
    """Tokenize `texts`, encode them with CLIP and L2-normalize each row.

    Tokens are moved to the model's `device`, and the features are cast to
    float32 so every side of the similarity product has the same dtype, even
    when the model itself runs in a lower precision.
    """
    tokens = clip.tokenize(texts).to(device)
    features = model.encode_text(tokens).float()
    return features / features.norm(dim=-1, keepdim=True)


def format_mean(values):
    """Format a running mean for the progress bar; 'n/a' while `values` is empty."""
    return f"{np.mean(values):.3f}" if values else "n/a"


with torch.no_grad(), tqdm(total=min(total_qa, len(json_dict)), dynamic_ncols=True,
                           desc='Title-Q Similarity') as batch_bar:
    for qa in islice(json_dict.values(), total_qa):
        question_features = encode_text_normalized([qa['Q']])

        # Titles are cut to 60 characters, presumably to stay within CLIP's
        # token context (`context_length`).
        # Positive (gold) image titles
        if len(qa['img_posFacts']) > 0:
            pos_title_features = encode_text_normalized(
                [fact['title'][:60] for fact in qa['img_posFacts']])
            pos_title_similarity = torch.matmul(question_features, pos_title_features.T).squeeze(0)
            pos_title_cos_sim.extend(pos_title_similarity.tolist())

        # Distractor image titles. The encoding now uses this entry's own
        # negative titles; previously it reused a stale caption tensor.
        if len(qa['img_negFacts']) > 0:
            neg_title_features = encode_text_normalized(
                [fact['title'][:60] for fact in qa['img_negFacts']])
            neg_title_similarity = torch.matmul(question_features, neg_title_features.T).squeeze(0)
            neg_title_cos_sim.extend(neg_title_similarity.tolist())

        batch_bar.set_postfix(Avg_pos=format_mean(pos_title_cos_sim),
                              Avg_neg=format_mean(neg_title_cos_sim))
        batch_bar.update()

### Images and question

In [None]:
from itertools import islice

# Cosine similarity between each question and its positive (gold) and
# negative (distractor) images, over the first `total_qa` items of `data`.
pos_image_cos_sim = []
neg_image_cos_sim = []

total_qa = 1000

# Only image arrays with exactly this shape are encoded; any other shape is skipped.
img_shape = (224,224,3)


def encode_text_normalized(texts):
    """Tokenize `texts`, encode them with CLIP and L2-normalize each row."""
    tokens = clip.tokenize(texts).to(device)
    features = model.encode_text(tokens).float()
    return features / features.norm(dim=-1, keepdim=True)


def encode_images_normalized(images):
    """Stack image arrays, reorder to channels-first, encode with CLIP, L2-normalize.

    NOTE(review): the arrays go straight to `model.encode_image` without the
    `preprocess` transform; this assumes they were already resized and
    normalized when the pickle was built — confirm.
    """
    batch = torch.tensor(np.stack(images)).permute([0, 3, 1, 2]).to(device)
    features = model.encode_image(batch).float()
    return features / features.norm(dim=-1, keepdim=True)


def format_mean(values):
    """Format a running mean for the progress bar; 'n/a' while `values` is empty."""
    return f"{np.mean(values):.3f}" if values else "n/a"


with torch.no_grad(), tqdm(total=min(total_qa, len(data)), dynamic_ncols=True,
                           desc='Image-Q Similarity') as batch_bar:
    for qa in islice(data.values(), total_qa):
        question_features = encode_text_normalized([qa['Q']])

        # Positive (gold) images with the expected shape
        pos_images = [img for img in qa['ImgPos'] if np.shape(img) == img_shape]
        if len(pos_images) > 0:
            pos_image_features = encode_images_normalized(pos_images)
            pos_image_similarity = torch.matmul(question_features, pos_image_features.T).squeeze(0)
            pos_image_cos_sim.extend(pos_image_similarity.tolist())

        # Distractor images with the expected shape
        neg_images = [img for img in qa['ImgNeg'] if np.shape(img) == img_shape]
        if len(neg_images) > 0:
            neg_image_features = encode_images_normalized(neg_images)
            neg_image_similarity = torch.matmul(question_features, neg_image_features.T).squeeze(0)
            neg_image_cos_sim.extend(neg_image_similarity.tolist())

        batch_bar.set_postfix(Avg_pos=format_mean(pos_image_cos_sim),
                              Avg_neg=format_mean(neg_image_cos_sim))
        batch_bar.update()

## Similarity histograms

### Images and question

In [None]:
# Overlaid density histograms (with KDE) of question–image cosine similarity.
pos_color = [4/255.0,   101/255.0, 130/255.0]
neg_color = [243.0/255, 145.0/255, 137/255.0]

sns.set_theme(style="darkgrid")

fig, ax = plt.subplots()
# Label the histograms themselves. `ax.legend([...])` with bare labels binds
# them to the first artists on the axes, which (depending on the matplotlib
# version) can be two bars of the positive histogram, giving wrong colors.
sns.histplot(pos_image_cos_sim, bins=30, color=pos_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Positive images")
sns.histplot(neg_image_cos_sim, bins=30, color=neg_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Negative images")
ax.set_xlabel("Cosine similarity with the question")
plt.setp(ax.patches, linewidth=0.2)
ax.legend()
fig.savefig('image_question_cos_distance_subset.pdf', transparent=False)

### Texts and question

In [None]:
# Overlaid density histograms (with KDE) of question–text cosine similarity.
pos_color = [4/255.0,   101/255.0, 130/255.0]
neg_color = [243.0/255, 145.0/255, 137/255.0]

sns.set_theme(style="darkgrid")

fig, ax = plt.subplots()
# Label the histograms themselves. `ax.legend([...])` with bare labels binds
# them to the first artists on the axes, which (depending on the matplotlib
# version) can be two bars of the positive histogram, giving wrong colors.
sns.histplot(pos_text_cos_sim, bins=30, color=pos_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Positive texts")
sns.histplot(neg_text_cos_sim, bins=40, color=neg_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Negative texts")
ax.set_xlabel("Cosine similarity with the question")
plt.setp(ax.patches, linewidth=0.2)
ax.legend()
fig.savefig('text_question_cos_distance_subset.pdf', transparent=False)

### Captions and question

In [None]:
# Overlaid density histograms (with KDE) of question–caption cosine similarity.
pos_color = [4/255.0,   101/255.0, 130/255.0]
neg_color = [243.0/255, 145.0/255, 137/255.0]

sns.set_theme(style="darkgrid")

fig, ax = plt.subplots()
# Label the histograms themselves. `ax.legend([...])` with bare labels binds
# them to the first artists on the axes, which (depending on the matplotlib
# version) can be two bars of the positive histogram, giving wrong colors.
sns.histplot(pos_caption_cos_sim, bins=24, color=pos_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Positive image captions")
sns.histplot(neg_caption_cos_sim, bins=32, color=neg_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Negative image captions")
ax.set_xlabel("Cosine similarity with the question")
plt.setp(ax.patches, linewidth=0.2)
ax.legend()
fig.savefig('caption_question_cos_distance_subset.pdf', transparent=False)

### Titles and question

In [None]:
# Overlaid density histograms (with KDE) of question–title cosine similarity.
pos_color = [4/255.0,   101/255.0, 130/255.0]
neg_color = [243.0/255, 145.0/255, 137/255.0]

sns.set_theme(style="darkgrid")

fig, ax = plt.subplots()
# Label the histograms themselves. `ax.legend([...])` with bare labels binds
# them to the first artists on the axes, which (depending on the matplotlib
# version) can be two bars of the positive histogram, giving wrong colors.
sns.histplot(pos_title_cos_sim, bins=38, color=pos_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Positive image titles")
sns.histplot(neg_title_cos_sim, bins=55, color=neg_color, alpha=0.6, ax=ax, kde=True, stat='density',
             label="Negative image titles")
ax.set_xlabel("Cosine similarity with the question")
plt.setp(ax.patches, linewidth=0.2)
ax.legend()
fig.savefig('title_question_cos_distance_subset.pdf', transparent=False)