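# ViQuAE Selection Images for RRT

For every ViQuAE question, this notebook collects the Wikipedia image attached to each of its search results and writes the per-question image lists to text files for the Reranking Transformer (RRT) pipeline.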
## Imports

In [1]:
import json
import numpy as np
from datasets import load_from_disk, set_caching_enabled
# Turn off `datasets` caching for this session.
# NOTE(review): newer `datasets` releases deprecate set_caching_enabled in favour of
# disable_caching() / enable_caching() — confirm against the installed version.
set_caching_enabled(False)

## Loading Data

In [2]:
# Load the ViQuAE question dataset from its on-disk `datasets` directory.
dataset = load_from_disk("data/viquae_dataset")

In [3]:
# Load the passage collection; `kb` is not referenced again in the cells shown here.
kb = load_from_disk('data/viquae_passages/')

In [4]:
# Load the Wikipedia article collection; it is indexed by split name further down
# (non_humans / humans_with_faces / humans_without_faces).
wiki = load_from_disk('data/viquae_wikipedia')

In [5]:
# Pull the three question splits out of the DatasetDict, in train / validation / test order.
train_set, dev_set, test_set = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)

## Passage to Article Mapping

In [6]:
# Load the article -> passages mapping of the `non_humans` split.
# The context manager closes the handle even if json.load raises.
with open('data/viquae_wikipedia/non_humans/article2passage.json') as f:
    n_h_article2passage = json.load(f)

In [7]:
# Invert article -> [passages] into passage -> article for the `non_humans` split.
# The comprehension keeps its loop variables out of the notebook namespace; if a
# passage is listed under several articles, the last article seen wins.
n_h_passage2article = {
    passage: article
    for article, passages in n_h_article2passage.items()
    for passage in passages
}

# Number of distinct passages in this split (displayed as the cell output).
len_n_h = len(n_h_passage2article)
len_n_h

7175529

In [8]:
# Load the article -> passages mapping of the `humans_without_faces` split.
# The context manager closes the handle even if json.load raises.
with open('data/viquae_wikipedia/humans_without_faces/article2passage.json') as f:
    h_wo_f_article2passage = json.load(f)

In [9]:
# Invert article -> [passages] into passage -> article for the `humans_without_faces` split.
# The comprehension keeps its loop variables out of the notebook namespace; if a
# passage is listed under several articles, the last article seen wins.
h_wo_f_passage2article = {
    passage: article
    for article, passages in h_wo_f_article2passage.items()
    for passage in passages
}

# Number of distinct passages in this split (displayed as the cell output).
len_h_wo_f = len(h_wo_f_passage2article)
len_h_wo_f

298698

In [10]:
# Load the article -> passages mapping of the `humans_with_faces` split.
# The context manager closes the handle even if json.load raises.
with open('data/viquae_wikipedia/humans_with_faces/article2passage.json') as f:
    h_w_f_article2passage = json.load(f)

In [11]:
# Invert article -> [passages] into passage -> article for the `humans_with_faces` split.
# The comprehension keeps its loop variables out of the notebook namespace; if a
# passage is listed under several articles, the last article seen wins.
h_w_f_passage2article = {
    passage: article
    for article, passages in h_w_f_article2passage.items()
    for passage in passages
}

# Number of distinct passages in this split (displayed as the cell output).
len_h_w_f = len(h_w_f_passage2article)
len_h_w_f

4411741

In [12]:
# Sanity check: the per-split passage counts must add up to the hard-coded total.
# NOTE(review): 11885968 is presumably the row count of the passage collection — confirm (e.g. against len(kb)).
len_n_h + len_h_w_f + len_h_wo_f == 11885968

True

In [13]:
# Merge the three per-split maps into one passage -> article lookup.
# Later dicts overwrite earlier ones on a shared key, so comparing this length with
# the sum of the three parts reveals any passage id present in more than one split.
passage2article = {**h_w_f_passage2article, **h_wo_f_passage2article, **n_h_passage2article}
len(passage2article)

11885968

In [14]:
# Persist the merged passage -> article lookup as JSON.
# NOTE: JSON object keys are always strings, so any non-string passage ids are written
# as strings; code that reloads this file must look passages up by str (or convert the keys back).
with open("data/viquae_wikipedia/passage2article.json", "w") as outfile:
    json.dump(passage2article, outfile)

In [15]:
# Keep a direct handle on each Wikipedia split.
humans_with_faces = wiki['humans_with_faces']
humans_without_faces = wiki['humans_without_faces']
non_humans = wiki['non_humans']

## Some Exploration

In [16]:
# Show the splits, columns and row counts of the question dataset.
dataset

DatasetDict({
    train: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
        num_rows: 1190
    })
    validation: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50

In [17]:
# Columns and row count of the training split.
train_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
    num_rows: 1190
})

In [18]:
# Columns and row count of the validation split; in the recorded output it also lists
# 'clip-RN50_indices' / 'clip-RN50_scores', which the training split does not.
dev_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
    num_rows: 1250
})

In [32]:
# Columns and row count of the test split; in the recorded output it additionally lists
# 'document_BM25_indices' / 'document_BM25_scores'.
test_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_BM25_indices', 'document_BM25_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id'],
    num_rows: 1257
})

In [33]:
# Tag every passage id with the name of the Wikipedia split it belongs to.
# dict.fromkeys assigns the same label to each key directly, so no list of repeated
# labels (one entry per passage) has to be built, and the result cannot be truncated
# by a length variable that has drifted from the dict it describes.
h_w_f_passage_split = dict.fromkeys(h_w_f_passage2article, 'humans_with_faces')
h_wo_f_passage_split = dict.fromkeys(h_wo_f_passage2article, 'humans_without_faces')
n_h_passage_split = dict.fromkeys(n_h_passage2article, 'non_humans')

# Single passage -> split-name lookup covering all three splits.
passage_wiki_split = {**h_w_f_passage_split, **h_wo_f_passage_split, **n_h_passage_split}
len(passage_wiki_split)

11885968

In [34]:
# Persist the passage -> split-name lookup as JSON.
# NOTE: JSON object keys are always strings, so any non-string passage ids are written
# as strings; code that reloads this file must look passages up by str (or convert the keys back).
with open("data/viquae_wikipedia/passage_wiki_split.json", "w") as outfile:
    json.dump(passage_wiki_split, outfile)

In [35]:
# Show the splits, columns and row counts of the Wikipedia collection.
wiki

DatasetDict({
    non_humans: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 953379
    })
    humans_with_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'history', 'image', 'image_embedding', 'image_hash', 'keep_face_embedding', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 506237
    })
    humans_without_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_embedding', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 35736
    })
})

In [36]:
# Inspect one training question: how many entries each of its index lists holds.
item = train_set[110]
tuple(
    len(item[column])
    for column in (
        'original_answer_provenance_indices',
        'provenance_indices',
        'search_provenance_indices',
        'search_irrelevant_indices',
    )
)

(4, 4, 1, 99)

In [37]:
# First entry of this question's `search_indices`
# (presumably the top-ranked passage id — TODO confirm the list is rank-ordered).
item['search_indices'][0]

2488266

In [38]:
import numpy as np
import os.path as osp
import pickle, json, random

import torch

In [39]:
def pickle_load(path):
    """Read and return the object pickled in the file at ``path``.

    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

In [40]:
def pickle_save(path, data):
    """Pickle ``data`` into the file at ``path``, replacing any existing content."""
    with open(path, 'wb') as handle:
        pickle.dump(data, handle)

In [41]:
# Ground-truth pickle inside the DELG data directory (an Oxford benchmark, judging by the file name).
# NOTE(review): absolute, user-specific path — this cell only runs where that directory exists.
oxford_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/oxford5k/gnd_roxford5k.pkl"

In [42]:
# Load the Oxford ground-truth pickle for inspection.
oxford_gnd = pickle_load(oxford_gnd_file)

In [43]:
# Top-level keys of the loaded ground-truth object.
oxford_gnd.keys()

dict_keys(['gnd', 'imlist', 'qimlist'])

In [44]:
def get_selection_images_for_rrt(data_set,
                                 wikipedia,
                                 passage2article,
                                 passage_wiki_split,
                                 preserve_rank=False):
    """Collect, for every question, the image of the article behind each search result.

    Args:
        data_set: iterable of question records; each record's ``'search_indices'``
            entry lists the passage ids retrieved for that question.
        wikipedia: mapping from split name to an indexable collection of article
            records, each exposing an ``'image'`` entry.
        passage2article: mapping from passage id to the article's row index within
            its split (any value accepted by ``int()``).
        passage_wiki_split: mapping from passage id to the name of the split that
            holds the passage's article.
        preserve_rank: when False (the default) passages are visited in ascending
            passage-id order, so the position of an image in the result says nothing
            about its retrieval rank. Pass True to keep the order of
            ``'search_indices'`` instead.

    Returns:
        A list with one inner list per question, holding the ``'image'`` value of
        the article of each visited passage. Passages that share an article
        contribute the same image more than once.

    Raises:
        KeyError: if a passage id is missing from one of the dict lookups.
    """
    selection_imgs = []

    for item in data_set:
        passages = item['search_indices']
        if not preserve_rank:
            # Default behaviour: visit passages by ascending id, not by rank.
            passages = sorted(passages)

        img_list = []
        for passage in passages:
            # The article index is local to its split, so both lookups are needed.
            wiki_index = int(passage2article[passage])
            wiki_split = passage_wiki_split[passage]
            img_list.append(wikipedia[wiki_split][wiki_index]['image'])

        selection_imgs.append(img_list)

    return selection_imgs
     

In [45]:
# Per-question image lists for the training split.
train_selection_imgs = get_selection_images_for_rrt(
    data_set=train_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split,
)

In [46]:
# Shape check: one row per training question, one column per search result.
np.array(train_selection_imgs).shape

(1190, 100)

In [47]:
# Per-question image lists for the test split.
test_selection_imgs = get_selection_images_for_rrt(
    data_set=test_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split,
)

In [48]:
# Shape check: one row per test question, one column per search result.
np.array(test_selection_imgs).shape

(1257, 100)

In [49]:
# Per-question image lists for the validation split.
dev_selection_imgs = get_selection_images_for_rrt(
    data_set=dev_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split,
)

In [50]:
# Shape check: one row per validation question, one column per search result.
np.array(dev_selection_imgs).shape

(1250, 100)

In [51]:
# Write the training image lists: one line per question, names separated by savetxt's default single space.
# NOTE(review): an image name containing whitespace would split into extra columns when read back — confirm names are space-free.
# NOTE(review): absolute, user-specific output path.
np.savetxt('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/train_selection_imgs.txt', train_selection_imgs, fmt="%s")

In [52]:
# Write the test image lists: one line per question, names separated by savetxt's default single space.
# NOTE(review): an image name containing whitespace would split into extra columns when read back — confirm names are space-free.
# NOTE(review): absolute, user-specific output path.
np.savetxt('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/test_selection_imgs.txt', test_selection_imgs, fmt="%s")

In [53]:
# Write the validation image lists: one line per question, names separated by savetxt's default single space.
# NOTE(review): an image name containing whitespace would split into extra columns when read back — confirm names are space-free.
# NOTE(review): absolute, user-specific output path.
np.savetxt('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/dev_selection_imgs.txt', dev_selection_imgs, fmt="%s")

In [54]:
# Read the training file back as strings and check its shape against the
# (questions, results) shape of the array that was written.
t_selection_imgs = np.genfromtxt('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/train_selection_imgs.txt', dtype='str')
t_selection_imgs.shape

(1190, 100)

In [55]:
# Scratch check: str.zfill left-pads a numeric string with zeros up to the given width.
number_str = str(123)
zero_filled_number = number_str.zfill(5)
zero_filled_number

'00123'