### Imports

In [None]:
import json
import os

from pathlib import Path
from pprint import pprint

import h5py
import matplotlib.pyplot as plt
import numpy as np

from mpl_toolkits.axes_grid1 import ImageGrid
%matplotlib inline

### Loading the JSONs for Comparison

In [None]:
def load_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is read as UTF-8 (the standard JSON encoding) instead of the
    platform's locale encoding, so the notebook loads the same data on every
    operating system.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Per-model results, read below via the ['i2t'] / ['t2i'] and ['hard'] keys
clip_results = load_json("clip_1k_test.json")
lxmert_results = load_json("lxmert_1k_test.json")

# Each model's retrievals for the other model's queries (file naming suggests
# this; the 't2i' lists are plotted further down)
clip_retrievals = load_json("clip_retrievals_for_lxmert.json")
lxmert_retrievals = load_json("lxmert_retrievals_for_clip.json")

# Side-by-side CLIP / LXMERT values per query, used in the final section
compare = load_json("comparison.json")

### Get the image keys to properly display the images

In [None]:
# One image key per line; the position of a key in this list is what the
# integer indices in the result JSONs are looked up against below.
img_keys_path = Path(os.environ['WORK_BASE']) / "datasets/coco_ir/test/test_img_keys_1k.tsv"
with open(img_keys_path, "r") as f:
    test_img_keys = [line.strip() for line in f]

### Opening the `.h5` file with all of the test images

In [None]:
# Read-only handle on the HDF5 file of test images; datasets are looked up by
# image key in the cells below. The handle stays open for the whole session.
test_imgs_path = Path(os.environ['WORK_BASE']).joinpath("datasets/coco_ir/test/test_imgs.h5")
test_imgs = h5py.File(test_imgs_path, mode="r")

### Converting the image indices to image IDs

In [None]:
# Attach the image key next to each integer 'query' index (the index itself is
# kept), so the query image can be fetched from the HDF5 file later.
clip_i2t_hard = clip_results['i2t']['hard']
for entry in clip_i2t_hard:
    entry.update(query_key=test_img_keys[entry['query']])

In [None]:
# Attach the image key next to each integer 'query' index (the index itself is
# kept), so the query image can be fetched from the HDF5 file later.
lxmert_i2t_hard = lxmert_results['i2t']['hard']
for entry in lxmert_i2t_hard:
    entry.update(query_key=test_img_keys[entry['query']])

In [None]:
clip_t2i_hard = clip_results['t2i']['hard']
for query in clip_t2i_hard:
    # Entries are rewritten in place: integer index -> image key.
    # A value that is no longer an integer was converted by an earlier run of
    # this cell and is left alone, so re-running the cell is safe (previously a
    # second run indexed the key list with a string and raised TypeError).
    if isinstance(query['ground_truth'], int):
        query['ground_truth'] = test_img_keys[query['ground_truth']]
    query['retrieved'] = [
        test_img_keys[item] if isinstance(item, int) else item
        for item in query['retrieved']
    ]

In [None]:
lxmert_t2i_hard = lxmert_results['t2i']['hard']
for query in lxmert_t2i_hard:
    # Entries are rewritten in place: integer index -> image key.
    # A value that is no longer an integer was converted by an earlier run of
    # this cell and is left alone, so re-running the cell is safe (previously a
    # second run indexed the key list with a string and raised TypeError).
    if isinstance(query['ground_truth'], int):
        query['ground_truth'] = test_img_keys[query['ground_truth']]
    query['retrieved'] = [
        test_img_keys[item] if isinstance(item, int) else item
        for item in query['retrieved']
    ]

## "Hard" Image-Based Text Retrieval

In [None]:
banner = '*' * 70
print(banner)
print('CLIP Image -> Text (HARD)')
print(banner)

# For every hard image query: draw the query image, then pretty-print its
# 'ground_truth' and 'retrieved' entries (presumably captions).
for entry in clip_i2t_hard:
    pixels = test_imgs[entry['query_key']][()].astype(int)
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()
    for field in ('ground_truth', 'retrieved'):
        pprint(entry[field])

In [None]:
banner = '*' * 70
print(banner)
print('LXMERT Image -> Text (HARD)')
print(banner)

# For every hard image query: draw the query image, then pretty-print its
# 'ground_truth' and 'retrieved' entries (presumably captions).
for entry in lxmert_i2t_hard:
    pixels = test_imgs[entry['query_key']][()].astype(int)
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()
    for field in ('ground_truth', 'retrieved'):
        pprint(entry[field])

## "Hard" Text-Based Image Retrieval

In [None]:
banner = '*' * 70
print(banner)
print('CLIP Text -> Image (HARD)')
print(banner)

# For every hard caption query for CLIP: print the caption, draw the
# ground-truth image, then draw the retrieved images in a grid.
for entry in clip_t2i_hard:
    print(entry['query'])

    truth_pixels = test_imgs[entry['ground_truth']][()].astype(int)
    plt.imshow(truth_pixels)
    plt.axis('off')
    plt.show()

    # 2 rows x 5 columns of axes with 0.1 padding between them; zip() stops at
    # the shorter of (axes, retrieved images).
    fig = plt.figure(figsize=(20, 20))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 5), axes_pad=0.1)
    for ax, img_key in zip(grid, entry['retrieved']):
        ax.imshow(test_imgs[img_key][()].astype(int))
        ax.axis('off')

    plt.show()

For the examples in the paper, we analyzed both retrieval results for each "hard" query. Whereas with text retrieval the results were able to be interpreted immediately (since the retrievals were already converted to the captions), we need to actually plot the im

In [None]:
banner = '*' * 70
print(banner)
print('LXMERT retrieval for CLIP Text -> Image (HARD)')
print(banner)

# For every caption in lxmert_retrievals["t2i"]: print the caption and draw
# the images LXMERT retrieved for it.
# NOTE(review): 'retrieved' is used directly as HDF5 keys here, so it is
# assumed to already hold image keys rather than integer indices - confirm.
for entry in lxmert_retrievals["t2i"]:
    print(entry['query'])

    # 2 rows x 5 columns of axes with 0.1 padding between them
    fig = plt.figure(figsize=(20, 20))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 5), axes_pad=0.1)
    for ax, img_key in zip(grid, entry['retrieved']):
        ax.imshow(test_imgs[img_key][()].astype(int))
        ax.axis('off')

    plt.show()

In [None]:
banner = '*' * 70
print(banner)
print('LXMERT Text -> Image (HARD)')
print(banner)

# For every hard caption query for LXMERT: print the caption, draw the
# ground-truth image, then draw the retrieved images in a grid.
for entry in lxmert_t2i_hard:
    print(entry['query'])

    truth_pixels = test_imgs[entry['ground_truth']][()].astype(int)
    plt.imshow(truth_pixels)
    plt.axis('off')
    plt.show()

    # 2 rows x 5 columns of axes with 0.1 padding between them; zip() stops at
    # the shorter of (axes, retrieved images).
    fig = plt.figure(figsize=(20, 20))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 5), axes_pad=0.1)
    for ax, img_key in zip(grid, entry['retrieved']):
        ax.imshow(test_imgs[img_key][()].astype(int))
        ax.axis('off')

    plt.show()

In [None]:
banner = '*' * 70
print(banner)
print('CLIP retrieval for LXMERT Text -> Image (HARD)')
print(banner)

# For every caption in clip_retrievals["t2i"]: print the caption and draw the
# images CLIP retrieved for it.
# NOTE(review): 'retrieved' is used directly as HDF5 keys here, so it is
# assumed to already hold image keys rather than integer indices - confirm.
for entry in clip_retrievals["t2i"]:
    print(entry['query'])

    # 2 rows x 5 columns of axes with 0.1 padding between them
    fig = plt.figure(figsize=(20, 20))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 5), axes_pad=0.1)
    for ax, img_key in zip(grid, entry['retrieved']):
        ax.imshow(test_imgs[img_key][()].astype(int))
        ax.axis('off')

    plt.show()

## When does CLIP perform better than the fine-tuned LXMERT?

In [None]:
# For each image query in the comparison file: draw the image, then print the
# value stored for each model.
# NOTE(review): 'query' is used directly as an HDF5 key - assumed to be an
# image key, confirm against the script that wrote comparison.json.
for item in compare["i2t"]:
    pixels = test_imgs[item["query"]][()].astype(int)
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()
    clip_val, lxmert_val = item['clip'], item['lxmert']
    print(f"CLIP: [{clip_val}], LXMERT: [{lxmert_val}]")

In [None]:
# For each caption query in the comparison file: print the caption, then the
# value stored for each model, followed by a blank line.
for item in compare["t2i"]:
    print(item['query'])
    clip_val, lxmert_val = item['clip'], item['lxmert']
    print(f"CLIP: [{clip_val}], LXMERT: [{lxmert_val}]\n")