# Adapting ViQuAE images for RRT
## Imports

In [1]:
import json
import numpy as np
from datasets import load_from_disk, set_caching_enabled
set_caching_enabled(False)

## Loading Data

In [2]:
dataset = load_from_disk("data/viquae_dataset")

In [3]:
kb = load_from_disk('data/viquae_passages/')

In [4]:
wiki = load_from_disk('data/viquae_wikipedia')

In [5]:
train_set, dev_set, test_set = dataset['train'], dataset['validation'], dataset['test']

In [6]:
humans_with_faces, humans_without_faces, non_humans = wiki['humans_with_faces'], wiki['humans_without_faces'], wiki['non_humans']

## Passage to Article Mapping

In [7]:
f = open('data/viquae_wikipedia/non_humans/article2passage.json')
n_h_article2passage = json.load(f)
f.close()                                                              

In [8]:
n_h_passage2article = {}
for k, v in n_h_article2passage.items(): 
    for x in v: 
        n_h_passage2article[x] = k 

len_n_h = len(n_h_passage2article)
len_n_h

7175529

In [9]:
f = open('data/viquae_wikipedia/humans_without_faces/article2passage.json')
h_wo_f_article2passage = json.load(f)                                  
f.close()

In [10]:
h_wo_f_passage2article = {} 
for k, v in h_wo_f_article2passage.items(): 
    for x in v: 
        h_wo_f_passage2article[x] = k 

len_h_wo_f = len(h_wo_f_passage2article)
len_h_wo_f

298698

In [11]:
f = open('data/viquae_wikipedia/humans_with_faces/article2passage.json')                                                                      
h_w_f_article2passage = json.load(f)                                   
f.close()                  

In [12]:
h_w_f_passage2article = {} 
for k, v in h_w_f_article2passage.items(): 
    for x in v: 
        h_w_f_passage2article[x] = k

len_h_w_f = len(h_w_f_passage2article)
len_h_w_f

4411741

In [13]:
len_n_h + len_h_w_f + len_h_wo_f == 11885968

True

In [14]:
passage2article = {**h_w_f_passage2article, **h_wo_f_passage2article, **n_h_passage2article}
len(passage2article)

11885968

In [15]:
with open("data/viquae_wikipedia/passage2article.json", "w") as outfile:
    json.dump(passage2article, outfile)

## Some Exploration

In [16]:
dataset

DatasetDict({
    train: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
        num_rows: 1190
    })
    validation: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50

In [17]:
train_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
    num_rows: 1190
})

In [18]:
dev_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
    num_rows: 1250
})

In [19]:
test_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_BM25_indices', 'document_BM25_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
    num_rows: 1257
})

In [20]:
wiki

DatasetDict({
    non_humans: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 953379
    })
    humans_with_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'history', 'image', 'image_embedding', 'image_hash', 'keep_face_embedding', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 506237
    })
    humans_without_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_embedding', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 35736
    })
})

In [21]:
item = train_set[0]
#print(item['provenance_indices'] in )
len(item['original_answer_provenance_indices']), len(item['provenance_indices']), len(item['search_provenance_indices']), len(item['search_irrelevant_indices'])
len(item['search_provenance_indices']), len(item['search_irrelevant_indices'])

(0, 100)

In [22]:
h_w_f_passage_split  = dict(zip(h_w_f_passage2article.keys(),  ['humans_with_faces'] * len_h_w_f))
h_wo_f_passage_split = dict(zip(h_wo_f_passage2article.keys(), ['humans_without_faces'] * len_h_wo_f))
n_h_passage_split    = dict(zip(n_h_passage2article.keys(),    ['non_humans'] * len_n_h))

passage_wiki_split = {**h_w_f_passage_split, **h_wo_f_passage_split, **n_h_passage_split}
len(passage_wiki_split)

11885968

In [23]:
with open("data/viquae_wikipedia/passage_wiki_split.json", "w") as outfile:
    json.dump(passage_wiki_split, outfile)

In [24]:
item = train_set[2]
len(item['search_provenance_indices']), len(item['search_irrelevant_indices']), len(list(set(item['search_provenance_indices']))), len(list(set(item['search_irrelevant_indices'])))

(22, 78, 22, 78)

In [25]:
def get_gallery_imgs_for_rrt(data_set,
                             wikipedia,
                             passage2article,
                             passage_wiki_split):
    
    question_imgs, positive_imgs, negative_imgs = [], [], []
    
    # for every question, get the list of the top 100 search results
    for item in data_set:
        # item_relevant_passages = list(set(item['search_indices']) - set(item['search_irrelevant_indices']))

        # append the question image
        question_imgs.append(item['image'])
        
        def loop_over_passages(passages):
            
            img_list = []
            
            for passage in passages: 
                # for every passage, get the list its corresponding wikipedia article id and split
                wiki_index = int(passage2article[passage])
                
                wiki_split = passage_wiki_split[passage]
                
                wiki_item = wikipedia[wiki_split][wiki_index]

                img_list.append(wiki_item['image'])
                
            return img_list
        
        
        # append the images of passages containing the answer
        positive_imgs.extend(loop_over_passages(item['search_provenance_indices']))
        
        # append the images of irrelevant passages
        negative_imgs.extend(loop_over_passages(item['search_irrelevant_indices']))
        
    return question_imgs, positive_imgs, negative_imgs
     

In [26]:
train_question_imgs, train_positive_imgs, train_negative_imgs = get_gallery_imgs_for_rrt(train_set, 
                                                                       wiki, 
                                                                       passage2article, 
                                                                       passage_wiki_split)

In [27]:
dev_question_imgs, dev_positive_imgs, dev_negative_imgs = get_gallery_imgs_for_rrt(dev_set, 
                                                                       wiki, 
                                                                       passage2article, 
                                                                       passage_wiki_split)

In [28]:
test_question_imgs, test_positive_imgs, test_negative_imgs = get_gallery_imgs_for_rrt(test_set, 
                                                                       wiki, 
                                                                       passage2article, 
                                                                       passage_wiki_split)

In [29]:
question_imgs = train_question_imgs + dev_question_imgs + test_question_imgs
positive_imgs = train_positive_imgs + dev_positive_imgs + test_positive_imgs
negative_imgs = train_negative_imgs + dev_negative_imgs + test_negative_imgs
len(question_imgs), len(positive_imgs), len(negative_imgs), 369700, (len(positive_imgs) + len(negative_imgs))

(3697, 46866, 333623, 369700, 380489)

In [30]:
len(list(set(train_question_imgs))), len(list(set(dev_question_imgs))), len(list(set(test_question_imgs)))

(1105, 1108, 1105)

In [31]:
len(list(set(train_positive_imgs))), len(list(set(dev_positive_imgs))), len(list(set(test_positive_imgs)))

(7034, 6114, 6644)

In [32]:
len(list(set(train_negative_imgs))), len(list(set(dev_negative_imgs))), len(list(set(test_negative_imgs)))

(37613, 39935, 40529)

In [33]:
len(list(set(question_imgs))), len(list(set(positive_imgs))), len(list(set(negative_imgs)))

(3318, 17155, 89554)

In [34]:
entire_dataset_imgs = list(set(list(set(question_imgs)) + list(set(positive_imgs)) + list(set(negative_imgs))))
len(entire_dataset_imgs)

99954

In [35]:
np.savetxt('data/entire_dataset_imgs.txt', entire_dataset_imgs, fmt="%s")

## GroundTruth Generation

In [36]:
import numpy as np
import os.path as osp
import pickle, json, random

import torch

In [37]:
def pickle_load(path):
    with open(path, 'rb') as fid:
        data_ = pickle.load(fid)
    return data_

In [38]:
def pickle_save(path, data):
    with open(path, 'wb') as fid:
        pickle.dump(data, fid)

In [39]:
oxford_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/oxford5k/gnd_roxford5k.pkl"

In [40]:
oxford_gnd = pickle_load(oxford_gnd_file)

In [41]:
oxford_gnd.keys()

dict_keys(['gnd', 'imlist', 'qimlist'])

In [42]:
oxford_gnd['gnd'][0].keys(), len(oxford_gnd.keys())

(dict_keys(['bbx', 'easy', 'hard', 'junk']), 3)

In [43]:
(len(oxford_gnd['gnd'][0]['easy']), len(oxford_gnd['gnd'][0]['hard']), len(oxford_gnd['gnd'][0]['junk']))
np.sum((len(oxford_gnd['gnd'][0]['easy']), len(oxford_gnd['gnd'][0]['hard'])))

103

In [44]:
def prepare_gnd_selection_images_for_rrt(data_set,
                                         wikipedia,
                                         passage2article,
                                         passage_wiki_split):
    
    selection_imgs  = []
    
    # for every question, get the list images corresponding to the top 100 search results
    for item in data_set:
        
        img_list = []
        
        for passage in item['search_indices']: 
            wiki_index = int(passage2article[passage])
            wiki_split = passage_wiki_split[passage]
            wiki_item = wikipedia[wiki_split][wiki_index]

            img_list.append(wiki_item['image'])
                
        selection_imgs.append(img_list)
    
    return selection_imgs
     

In [45]:
train_selection_imgs = prepare_gnd_selection_images_for_rrt(train_set, 
                                                            wiki, 
                                                            passage2article, 
                                                            passage_wiki_split)

In [46]:
np.array(train_selection_imgs).shape

(1190, 100)

In [47]:
len(list(set(question_imgs))), len(list(set(positive_imgs))), len(list(set(negative_imgs)))

(3318, 17155, 89554)

In [48]:
train_gnd = {}
train_gnd['qimlist'] = train_question_imgs
train_gnd['imlist']  = train_question_imgs + train_positive_imgs + train_negative_imgs
train_gnd['simlist'] = train_selection_imgs

In [49]:
train_gnd['imlist']  = list(set(train_gnd['imlist']))
#train_gnd['qimlist'] = list(set(train_gnd['qimlist']))

In [50]:
dev_gnd = {}
dev_gnd['qimlist'] = dev_question_imgs
dev_gnd['imlist']  = dev_question_imgs + dev_positive_imgs + dev_negative_imgs
dev_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(dev_set, 
                                                          wiki, 
                                                          passage2article, 
                                                          passage_wiki_split)

In [51]:
dev_gnd['imlist']  = list(set(dev_gnd['imlist']))
#dev_gnd['qimlist'] = list(set(dev_gnd['qimlist']))

In [52]:
test_gnd = {}
test_gnd['qimlist'] = test_question_imgs
test_gnd['imlist']  = test_question_imgs + test_positive_imgs + test_negative_imgs
test_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(test_set, 
                                                           wiki, 
                                                           passage2article, 
                                                           passage_wiki_split)

In [53]:
test_gnd['imlist']  = list(set(test_gnd['imlist']))
#test_gnd['qimlist'] = list(set(test_gnd['qimlist']))

In [54]:
len(dev_gnd['simlist']), len(dev_gnd['qimlist']), len(dev_gnd['imlist']),

(1250, 1250, 44325)

In [55]:
#viquae_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_images/gnd_dev_viquae.pkl"

In [56]:
#viquae_gnd = pickle_load(viquae_gnd_file)

In [57]:
#viquae_gnd['gnd'][0].keys()

In [115]:
item = train_set[7]

print(item['search_indices'] + item['search_provenance_indices'])

[2622393, 8814396, 8646819, 1562495, 1562496, 1281262, 4696984, 8646838, 1280203, 8209050, 1281255, 3803785, 11138918, 5668626, 5873302, 6606045, 6606046, 6606047, 1314171, 462157, 4697063, 3586842, 462106, 4523355, 341683, 4697011, 8646833, 1922175, 11347466, 7660524, 8073939, 11279406, 11279407, 11279408, 11279409, 11279410, 11279411, 11279412, 11279413, 11279414, 11279415, 11279416, 11279417, 11279418, 11279419, 11279420, 11279421, 11279422, 11279423, 11279424, 11279425, 11279426, 11279427, 11279428, 11279429, 11279430, 11279431, 11279432, 11279433, 11279434, 11279435, 7824223, 3290460, 5873294, 3698059, 4697068, 10992510, 4148951, 3803752, 4696998, 3665117, 3665119, 3665120, 3665121, 3665118, 8646820, 6971366, 11282428, 10795668, 2698285, 4788496, 5115289, 10269213, 3233235, 8814398, 4788576, 462088, 7614546, 7614547, 7614548, 7614549, 7614550, 1922146, 462090, 11138920, 1238729, 4223041, 9576926, 5873295, 1922171, 5867811, 5867827]


In [116]:
print(item['search_provenance_indices'])

[5867811, 5867827]


In [155]:
def prepare_gnd_for_rrt(data_set,
                        wikipedia,
                        passage2article,
                        passage_wiki_split):
    
    # ranks = []
    data_gnd = []
    
    # for every question, get the list of the top 100 search results
    for item in data_set:
        if set(item['search_provenance_indices']) <= set(item['search_indices']):
            all_indices = sorted(item['search_indices'])
        else: all_indices = sorted(item['search_indices']) + item['search_provenance_indices']
        
        rank_dict = {all_indices [k]: k for k in range(len(all_indices))}
        #print(rank_dict)
        
        def loop_over_passages(passages):
            
            img_ranks = []
            
            for passage in passages:
                img_ranks.append(rank_dict[passage])
                
            return img_ranks
        
        question_gnd = {}
        question_gnd['easy'] = []
        question_gnd['hard'] = loop_over_passages(item['search_provenance_indices'])
        question_gnd['junk'] = loop_over_passages(item['search_irrelevant_indices'])
        
        data_gnd.append(question_gnd)
        
    return data_gnd
         

In [156]:
train_gnd['gnd'] = prepare_gnd_for_rrt(train_set,
                                       wiki,
                                       passage2article,
                                       passage_wiki_split)

In [157]:
dev_gnd['gnd'] = prepare_gnd_for_rrt(dev_set,
                                     wiki,
                                     passage2article,
                                     passage_wiki_split)

In [158]:
test_gnd['gnd'] = prepare_gnd_for_rrt(test_set,
                                      wiki,
                                      passage2article,
                                      passage_wiki_split)

In [159]:
train_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_train.pkl"
pickle_save(train_gnd_file, train_gnd)

In [160]:
dev_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_dev.pkl"
pickle_save(dev_gnd_file, dev_gnd)

In [161]:
test_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_test.pkl"
pickle_save(test_gnd_file, test_gnd)

In [162]:
test_gnd['gnd'][0]['hard']

[44, 16, 36]

In [68]:
number_str = str(123)
zero_filled_number = number_str.zfill(5)
zero_filled_number

'00123'