In [1]:
import torch
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
from tqdm import tqdm
import json
import numpy as np
import faiss  
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [3]:
#FAISS
d = 512
index = faiss.IndexFlatIP(d)

In [4]:
with open('datasets/LaSCo/lasco_val.json', 'r') as file:
    lasco_val = json.load(file)
with open('datasets/LaSCo/lasco_val_corpus.json', 'r') as file:
    lasco_val_corpus = json.load(file)

In [5]:
len(lasco_val)

30037

In [6]:
len(lasco_val_corpus)

39826

In [7]:
faiss_index_2_image_map = {}

In [8]:
#Indexing Valiudation Corpus
fi = 0
for record in tqdm(lasco_val_corpus):
    faiss_index_2_image_map.update({fi: record['id']})
    img_name = record['path'].split('/')[-1]
    img_path = os.path.join('datasets', 'LaSCo', 'coco', 'val2014', img_name)
    image = Image.open(img_path).convert("RGB")
    inputs = processor(text=[''], images=[image], return_tensors="pt", padding=True).to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)
        image_embeds = outputs.image_embeds
        text_embeds = outputs.text_embeds
    index.add(image_embeds.detach().cpu().numpy())
    fi+=1

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39826/39826 [14:17<00:00, 46.46it/s]


In [9]:
## Write index and dictionary to to disk
with open('lasco_val_faiss_index_2_image_map.json', 'w') as file:
    json.dump(faiss_index_2_image_map, file)

faiss.write_index(index, 'lasco_val_faiss_index.bin')

In [28]:
"""
## Read index and dictionary to to disk
with open('lasco_val_faiss_index_2_image_map.json', 'r') as file:
    faiss_index_2_image_map = json.load(file)
"""
index = faiss.read_index('lasco_val_faiss_index.bin')


In [10]:
lasco_val_evaluation_results = []

In [33]:
#cntr = 0
## Evaluation Loop
for record in tqdm(lasco_val):
    qid = record['qid']
    query_image = record['query-image'][1].split('/')[-1]
    target_image = record['target-image'][1].split('/')[-1]
    mod_text = record['query-text']

    query_image_name = record['query-image'][1].split('/')[-1]
    query_image_path = os.path.join('datasets', 'LaSCo', 'coco', 'val2014', query_image_name)
    target_image_name = record['target-image'][1].split('/')[-1]
    target_image_path = os.path.join('datasets', 'LaSCo', 'coco', 'val2014', target_image_name)
    mod_text = record['query-text']

    query_image = Image.open(query_image_path).convert("RGB")
    target_image = Image.open(target_image_path).convert("RGB")

    inputs = processor(text=[mod_text, ''], images=[query_image, target_image], return_tensors="pt", padding=True).to("cuda")

    with torch.no_grad():
        outputs = model(**inputs)
        image_embeds = outputs.image_embeds
        text_embeds = outputs.text_embeds

    target_hat = (image_embeds[0] + text_embeds[0]).unsqueeze(0)
    target_hat = F.normalize(target_hat, p=2.0, dim = 1)
    target_hat = target_hat.detach().cpu().numpy()

    _, indices = index.search(target_hat, k=50)
    indices = list(map(lambda x: faiss_index_2_image_map[x], list(indices[0, :])))

    #print(image_embeds[0])
    #print(text_embeds[0])
    #print(target_hat[0])
    #print("----------------------------------------------------------------------------------------------------------------")

    #if cntr == 20:
    #    break

    
    lasco_val_evaluation_results.append(
        {
            'qid': qid,
            'target_img_id': record['target-image'][0],
            'top_50_retrieved': indices
        }
    )

  0%|▎                                                                                                                                                        | 66/30037 [00:04<31:58, 15.62it/s]

KeyboardInterrupt



In [12]:
## Save predictions to file
with open('lasco_val_retrieved_candidates.pkl', 'wb') as file:
    pickle.dump(lasco_val_evaluation_results, file)

In [13]:
lasco_val_evaluation_results[:1]

[{'qid': 318114001,
  'target_img_id': 306889,
  'top_50_retrieved': [318114,
   488645,
   221911,
   231527,
   237277,
   304984,
   123511,
   550862,
   438999,
   434148,
   13145,
   135460,
   136772,
   581829,
   114871,
   467457,
   311041,
   309279,
   512912,
   320972,
   491062,
   301912,
   21397,
   119414,
   235575,
   248912,
   351609,
   577868,
   507318,
   168898,
   549887,
   223384,
   529966,
   100132,
   521540,
   402783,
   577451,
   463555,
   569103,
   136911,
   124135,
   22090,
   405569,
   192457,
   158333,
   234934,
   37958,
   28864,
   75888,
   96327]}]

In [14]:
len(lasco_val_evaluation_results)

30037

In [15]:
recall_counts = [0, 0, 0, 0]

In [17]:
for record in lasco_val_evaluation_results:
    #items = [faiss_index_2_image_map[i] for i in record['top_50_retrieved']]
    try:
        index = record['top_50_retrieved'].index(record['target_img_id'])
        if index < 1:
            recall_counts[0]+=1
        if index < 5:
            recall_counts[1]+=1
        if index < 10:
            recall_counts[2]+=1
        if index < 50:
            recall_counts[3]+=1
    except:
        continue

In [18]:
recall_counts

[1, 1605, 2623, 6736]

In [19]:
print("Average recall Top-1 = {} %".format(100*recall_counts[0]/len(lasco_val_evaluation_results)))
print("Average recall Top-5 = {} %".format(100*recall_counts[1]/len(lasco_val_evaluation_results)))
print("Average recall Top-10 = {} %".format(100*recall_counts[2]/len(lasco_val_evaluation_results)))
print("Average recall Top-50 = {} %".format(100*recall_counts[3]/len(lasco_val_evaluation_results)))

Average recall Top-1 = 0.0033292272863468387 %
Average recall Top-5 = 5.343409794586677 %
Average recall Top-10 = 8.732563172087758 %
Average recall Top-50 = 22.425675000832307 %
