# Adapting ViQuAE images for RRT
## Imports

In [1]:
import json
import numpy as np
from datasets import load_from_disk, set_caching_enabled
set_caching_enabled(False)

## Loading Data

In [2]:
# ViQuAE questions: a DatasetDict with 'train', 'validation' and 'test' splits.
dataset = load_from_disk("data/viquae_dataset")

In [3]:
# Passage knowledge base (presumably the 11885968 passages counted below — TODO confirm).
# NOTE(review): `kb` is not referenced by any later cell of this notebook.
kb = load_from_disk('data/viquae_passages/')

In [4]:
# Wikipedia articles split into 'non_humans', 'humans_with_faces' and
# 'humans_without_faces'; every split has an 'image' column.
wiki = load_from_disk('data/viquae_wikipedia')

In [5]:
# One variable per ViQuAE question split.
train_set = dataset['train']
dev_set = dataset['validation']
test_set = dataset['test']

In [6]:
# One variable per Wikipedia split (not referenced by later cells).
humans_with_faces = wiki['humans_with_faces']
humans_without_faces = wiki['humans_without_faces']
non_humans = wiki['non_humans']

## Passage to Article Mapping

In [7]:
# Article index -> list of passage indices for the 'non_humans' split.
# The `with` block closes the file even when json.load raises; the bare
# open()/close() pair leaked the handle in that case.
with open('data/viquae_wikipedia/non_humans/article2passage.json', encoding='utf-8') as f:
    n_h_article2passage = json.load(f)

In [8]:
# Invert article -> [passages] into passage -> article for the 'non_humans'
# split; if a passage were listed twice, the last article would win.
n_h_passage2article = {passage: article
                       for article, passages in n_h_article2passage.items()
                       for passage in passages}

len_n_h = len(n_h_passage2article)
len_n_h

7175529

In [9]:
# Article index -> list of passage indices for the 'humans_without_faces'
# split. `with` guarantees the file is closed even if json.load raises.
with open('data/viquae_wikipedia/humans_without_faces/article2passage.json', encoding='utf-8') as f:
    h_wo_f_article2passage = json.load(f)

In [10]:
# Invert article -> [passages] into passage -> article for the
# 'humans_without_faces' split (last article wins on a repeated passage).
h_wo_f_passage2article = {passage: article
                          for article, passages in h_wo_f_article2passage.items()
                          for passage in passages}

len_h_wo_f = len(h_wo_f_passage2article)
len_h_wo_f

298698

In [11]:
# Article index -> list of passage indices for the 'humans_with_faces' split.
# `with` guarantees the file is closed even if json.load raises.
with open('data/viquae_wikipedia/humans_with_faces/article2passage.json', encoding='utf-8') as f:
    h_w_f_article2passage = json.load(f)

In [12]:
# Invert article -> [passages] into passage -> article for the
# 'humans_with_faces' split (last article wins on a repeated passage).
h_w_f_passage2article = {passage: article
                         for article, passages in h_w_f_article2passage.items()
                         for passage in passages}

len_h_w_f = len(h_w_f_passage2article)
len_h_w_f

4411741

In [13]:
# The three splits together should account for all 11885968 passages
# (presumably len(kb) — TODO confirm against the passage KB).
sum((len_n_h, len_h_w_f, len_h_wo_f)) == 11885968

True

In [14]:
# Merge the three per-split maps into one passage -> article map. The merged
# size equals the summed split sizes, so no passage appears in two splits.
passage2article = dict(h_w_f_passage2article)
passage2article.update(h_wo_f_passage2article)
passage2article.update(n_h_passage2article)
len(passage2article)

11885968

In [15]:
#with open("data/viquae_wikipedia/passage2article.json", "w") as outfile:
#    json.dump(passage2article, outfile)

## Data Exploration

In [16]:
# Split names, feature columns and row counts of the question dataset.
dataset

DatasetDict({
    train: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices'],
        num_rows: 1190
    })
    validation: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices'

In [17]:
# Training split: 1190 questions.
train_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices'],
    num_rows: 1190
})

In [18]:
# Validation split; unlike train it also carries 'clip-RN50_indices'/'clip-RN50_scores'.
dev_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices'],
    num_rows: 1250
})

In [19]:
# Test split; it additionally carries 'document_BM25_indices'/'document_BM25_scores'.
test_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_BM25_indices', 'document_BM25_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices'],
    num_rows: 1257
})

In [20]:
# Feature columns and sizes of the three Wikipedia splits.
wiki

DatasetDict({
    non_humans: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 953379
    })
    humans_with_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'history', 'image', 'image_embedding', 'image_hash', 'keep_face_embedding', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 506237
    })
    humans_without_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_embedding', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 35736
    })
})

In [21]:
# Sizes of the retrieval lists attached to the first training question.
item = train_set[0]
tuple(len(item[field]) for field in ('search_indices',
                                     'search_provenance_indices',
                                     'search_irrelevant_indices'))

(100, 0, 100)

In [22]:
# Same probe for training question 47 (which does have provenance passages).
item = train_set[47]
tuple(len(item[field]) for field in ('search_indices',
                                     'search_provenance_indices',
                                     'search_irrelevant_indices'))

(100, 57, 100)

In [23]:
# Map each passage index to the name of the Wikipedia split holding its
# article; get_gallery_imgs_for_rrt uses it as the key into `wiki`.
# dict.fromkeys shares a single value string instead of materialising three
# multi-million-element lists only to zip them against the keys.
h_w_f_passage_split  = dict.fromkeys(h_w_f_passage2article,  'humans_with_faces')
h_wo_f_passage_split = dict.fromkeys(h_wo_f_passage2article, 'humans_without_faces')
n_h_passage_split    = dict.fromkeys(n_h_passage2article,    'non_humans')

passage_wiki_split = {**h_w_f_passage_split, **h_wo_f_passage_split, **n_h_passage_split}
len(passage_wiki_split)

11885968

In [24]:
#with open("data/viquae_wikipedia/passage_wiki_split.json", "w") as outfile:
#    json.dump(passage_wiki_split, outfile)

In [25]:
# Retrieval-list sizes of the first training question, including the
# passages that contain an alternative answer.
item = train_set[0]
tuple(len(item[field]) for field in ('search_indices',
                                     'search_provenance_indices',
                                     'search_irrelevant_indices',
                                     'search_alternative_indices'))

(100, 0, 100, 2)

In [26]:
# NOTE(review): exact duplicate of the previous cell (same item, same fields,
# same output); it can be deleted.
item = train_set[0]
len(item['search_indices']), len(item['search_provenance_indices']), len(item['search_irrelevant_indices']), len(item['search_alternative_indices'])

(100, 0, 100, 2)

In [27]:
def get_gallery_imgs_for_rrt(data_set,
                             wikipedia,
                             passage2article,
                             passage_wiki_split,
                             max_questions=100):
    """Collect question images and the Wikipedia images of their passages (flat version).

    NOTE: ``get_gallery_imgs_for_rrt`` is redefined further down this notebook
    with a different signature and return value; once that cell runs, this
    definition is shadowed.

    Args:
        data_set: iterable of ViQuAE items; each needs 'image',
            'search_provenance_indices' and 'search_irrelevant_indices'.
        wikipedia: mapping split name -> indexable article rows with an 'image' field.
        passage2article: passage index -> article index (int or numeric string)
            inside that passage's split.
        passage_wiki_split: passage index -> split name.
        max_questions: stop after this many questions. Defaults to the
            previously hard-coded cap of 100; ``None`` processes all of ``data_set``.

    Returns:
        ``(question_imgs, positive_imgs, negative_imgs)`` as flat lists: one image
        per question, then the images of all provenance passages and of all
        irrelevant passages, concatenated across questions.
    """
    def passage_images(passages):
        # passage -> (article index, split) -> article row -> its image.
        images = []
        for passage in passages:
            wiki_index = int(passage2article[passage])
            wiki_split = passage_wiki_split[passage]
            images.append(wikipedia[wiki_split][wiki_index]['image'])
        return images

    question_imgs, positive_imgs, negative_imgs = [], [], []

    for position, item in enumerate(data_set):
        if max_questions is not None and position >= max_questions:
            break

        question_imgs.append(item['image'])
        # Passages containing the answer.
        positive_imgs.extend(passage_images(item['search_provenance_indices']))
        # Irrelevant passages.
        negative_imgs.extend(passage_images(item['search_irrelevant_indices']))

    return question_imgs, positive_imgs, negative_imgs
     

In [28]:
##train_question_imgs, train_positive_imgs, train_negative_imgs = get_gallery_imgs_for_rrt(train_set, 
##                                                                       wiki, 
##                                                                       passage2article, 
##                                                                       passage_wiki_split)

In [29]:
##dev_question_imgs, dev_positive_imgs, dev_negative_imgs = get_gallery_imgs_for_rrt(dev_set, 
##                                                                       wiki, 
##                                                                       passage2article, 
##                                                                       passage_wiki_split)

In [30]:
##test_question_imgs, test_positive_imgs, test_negative_imgs = get_gallery_imgs_for_rrt(test_set, 
##                                                                       wiki, 
##                                                                       passage2article, 
##                                                                       passage_wiki_split)

In [31]:
##question_imgs = train_question_imgs + dev_question_imgs + test_question_imgs
##positive_imgs = train_positive_imgs + dev_positive_imgs + test_positive_imgs
##negative_imgs = train_negative_imgs + dev_negative_imgs + test_negative_imgs
##len(question_imgs), len(positive_imgs), len(negative_imgs), 369700, (len(positive_imgs) + len(negative_imgs))

In [32]:
##len(list(set(train_question_imgs))), len(list(set(dev_question_imgs))), len(list(set(test_question_imgs)))

In [33]:
##len(list(set(train_positive_imgs))), len(list(set(dev_positive_imgs))), len(list(set(test_positive_imgs)))

In [34]:
##len(list(set(train_negative_imgs))), len(list(set(dev_negative_imgs))), len(list(set(test_negative_imgs)))

In [35]:
##len(list(set(question_imgs))), len(list(set(positive_imgs))), len(list(set(negative_imgs)))

## Ground-Truth Generation

In [36]:
import numpy as np
import os.path as osp
import pickle, json, random

import torch

In [37]:
def pickle_load(path):
    """Unpickle and return the object stored in the file at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)

In [39]:
def pickle_save(path, data):
    """Pickle ``data`` into the file at ``path``, overwriting any existing file."""
    with open(path, mode='wb') as sink:
        pickle.dump(data, sink)

In [38]:
### Code Added from here

In [40]:
#train_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_train.pkl"
#train_gnd = pickle_load(train_gnd_file)

In [41]:
#dev_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_dev.pkl"
#dev_gnd =  pickle_load(dev_gnd_file)

In [42]:
#test_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_test.pkl"
#test_gnd = pickle_load(test_gnd_file)

In [43]:
#tuto_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_tuto.pkl"
#tuto_gnd = pickle_load(tuto_gnd_file)

In [44]:
# NOTE(review): this cell only evaluates a string literal (echoed as the cell
# output); it does NOT define prepare_gnd_selection_images_for_rrt. The cells
# that call that function are commented out, and a commented-out copy of it
# appears again further down. Consider deleting this cell.
"""
def prepare_gnd_selection_images_for_rrt(data_set,
                                         wikipedia,
                                         passage2article,
                                         passage_wiki_split,
                                         tuto=False):
    
    selection_imgs  = []
    iterat = 0
    
    # for every question, get the list images corresponding to the top 100 search results
    for item in data_set:
        
        if iterat >= 100 and tuto:
            break
        
        img_list = []
        
        for passage in sorted(item['search_indices']): 
            wiki_index = int(passage2article[passage])
            wiki_split = passage_wiki_split[passage]
            wiki_item = wikipedia[wiki_split][wiki_index]

            img_list.append(wiki_item['image'])
                
        selection_imgs.append(img_list)
        iterat += 1
    
    return selection_imgs
"""     

"\ndef prepare_gnd_selection_images_for_rrt(data_set,\n                                         wikipedia,\n                                         passage2article,\n                                         passage_wiki_split,\n                                         tuto=False):\n    \n    selection_imgs  = []\n    iterat = 0\n    \n    # for every question, get the list images corresponding to the top 100 search results\n    for item in data_set:\n        \n        if iterat >= 100 and tuto:\n            break\n        \n        img_list = []\n        \n        for passage in sorted(item['search_indices']): \n            wiki_index = int(passage2article[passage])\n            wiki_split = passage_wiki_split[passage]\n            wiki_item = wikipedia[wiki_split][wiki_index]\n\n            img_list.append(wiki_item['image'])\n                \n        selection_imgs.append(img_list)\n        iterat += 1\n    \n    return selection_imgs\n"

In [45]:
##train_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(train_set, 
##                                                            wiki, 
##                                                            passage2article, 
##                                                            passage_wiki_split)

In [46]:
##dev_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(dev_set, 
##                                                          wiki, 
##                                                          passage2article, 
##                                                          passage_wiki_split)

In [47]:
##test_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(test_set, 
##                                                           wiki, 
##                                                           passage2article, 
##                                                           passage_wiki_split)

In [48]:
##train_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_train.pkl"
##pickle_save(train_gnd_file, train_gnd)

In [49]:
##dev_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_dev.pkl"
##pickle_save(dev_gnd_file, dev_gnd)

In [50]:
##test_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_test.pkl"
##pickle_save(test_gnd_file, test_gnd)

In [51]:
### Code Added end here

In [52]:
# Ground-truth pickle of rOxford5k shipped with the RerankingTransformer checkout.
oxford_gnd_file = osp.join(
    "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data",
    "oxford5k",
    "gnd_roxford5k.pkl")

In [53]:
# rOxford5k ground truth; its layout ('gnd'/'imlist'/'qimlist', per-query
# 'easy'/'hard'/'junk') is the one the ViQuAE ground-truth dicts below reuse.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
oxford_gnd = pickle_load(oxford_gnd_file)

In [54]:
# Top-level keys of the rOxford5k ground truth.
oxford_gnd.keys()

dict_keys(['gnd', 'imlist', 'qimlist'])

In [55]:
# Annotation keys of the first rOxford5k query, and the number of top-level keys.
oxford_gnd['gnd'][0].keys(), len(oxford_gnd.keys())

(dict_keys(['bbx', 'easy', 'hard', 'junk']), 3)

In [56]:
# Ground-truth list sizes for the first rOxford5k query. The previous version
# built the (easy, hard, junk) tuple as a non-final expression, so it was
# computed and silently discarded; show it alongside the easy + hard total.
first_query_gnd = oxford_gnd['gnd'][0]
n_easy, n_hard, n_junk = (len(first_query_gnd[key]) for key in ('easy', 'hard', 'junk'))
(n_easy, n_hard, n_junk), n_easy + n_hard

103

In [58]:
def get_gallery_imgs_for_rrt(data_set,
                             wikipedia,
                             passage2article,
                             passage_wiki_split,
                             tuto=False,
                             tuto_size=100):
    """Collect, per question, the Wikipedia images of its retrieved passages.

    For every question the question image is kept, and each passage listed in
    the item's 'search_provenance_indices', 'search_alternative_indices',
    'search_irrelevant_indices' and 'search_indices' is mapped to its article
    (via ``passage2article`` and ``passage_wiki_split``), whose 'image' is taken.

    Args:
        data_set: iterable of ViQuAE items (dict-like rows).
        wikipedia: mapping split name -> indexable article rows with an 'image' field.
        passage2article: passage index -> article index (int or numeric string)
            inside that passage's split.
        passage_wiki_split: passage index -> split name.
        tuto: if True, only the first ``tuto_size`` questions are processed.
        tuto_size: number of questions kept when ``tuto`` is True (was a
            hard-coded 100).

    Returns:
        ``(question_imgs, positive_imgs, alternative_imgs, negative_imgs,
        selection_imgs)``. ``question_imgs`` has one image per question; the
        other four hold one list per question, in the order of the passage
        indices (``selection_imgs`` follows 'search_indices').
    """
    # passage index -> image name. One passage usually appears in several of
    # a question's index lists (and across questions); each uncached lookup
    # fetches a whole Wikipedia row just to read its 'image'.
    image_of_passage = {}

    def passage_images(passages):
        images = []
        for passage in passages:
            if passage not in image_of_passage:
                wiki_index = int(passage2article[passage])
                wiki_split = passage_wiki_split[passage]
                image_of_passage[passage] = wikipedia[wiki_split][wiki_index]['image']
            images.append(image_of_passage[passage])
        return images

    question_imgs, positive_imgs, alternative_imgs, negative_imgs, selection_imgs = [], [], [], [], []

    for position, item in enumerate(data_set):
        # Tutorial mode keeps only the first `tuto_size` questions.
        if tuto and position >= tuto_size:
            break

        question_imgs.append(item['image'])
        # Passages containing the original answer.
        positive_imgs.append(passage_images(item['search_provenance_indices']))
        # Passages containing an alternative answer.
        alternative_imgs.append(passage_images(item['search_alternative_indices']))
        # Irrelevant passages.
        negative_imgs.append(passage_images(item['search_irrelevant_indices']))
        # Passages returned by the IR search.
        selection_imgs.append(passage_images(item['search_indices']))

    return question_imgs, positive_imgs, alternative_imgs, negative_imgs, selection_imgs
     

In [59]:
# Per-question gallery images for the training split.
(train_questions, train_positives, train_alternatives,
 train_negatives, train_selections) = get_gallery_imgs_for_rrt(
    data_set=train_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split)

In [60]:
# Per-question gallery images for the validation split.
(dev_questions, dev_positives, dev_alternatives,
 dev_negatives, dev_selections) = get_gallery_imgs_for_rrt(
    data_set=dev_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split)

In [61]:
# Per-question gallery images for the test split.
(test_questions, test_positives, test_alternatives,
 test_negatives, test_selections) = get_gallery_imgs_for_rrt(
    data_set=test_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split)

In [62]:
# Tutorial subset: the first training questions only (tuto=True).
(tuto_questions, tuto_positives, tuto_alternatives,
 tuto_negatives, tuto_selections) = get_gallery_imgs_for_rrt(
    data_set=train_set,
    wikipedia=wiki,
    passage2article=passage2article,
    passage_wiki_split=passage_wiki_split,
    tuto=True)

In [63]:
def extend(a):
    """Flatten one level: concatenate the sub-sequences of ``a`` into a new list."""
    return [element for sublist in a for element in sublist]

In [64]:
def format_selection(selection):
    """De-duplicate each per-question list of selected images.

    Args:
        selection: list of per-question image-name lists.

    Returns:
        A new list of lists with repeated images removed, each image kept at
        its first position. ``dict.fromkeys`` replaces ``set``: the order of a
        set of strings depends on hash randomisation (PYTHONHASHSEED), so the
        previous output changed between runs and lost the search order.
    """
    return [list(dict.fromkeys(images)) for images in selection]

In [65]:
def _unique_in_order(images):
    """Return the distinct elements of ``images`` in first-occurrence order."""
    return list(dict.fromkeys(images))


def format_gnd_for_rrt(question_imgs, positive_imgs, alternative_imgs, negative_imgs, selection_imgs,
                       n_search_results=100):
    """Build an RRT ground-truth dict in the rOxford5k-style layout.

    Args:
        question_imgs: one query image name per question.
        positive_imgs, alternative_imgs, negative_imgs, selection_imgs:
            per-question lists of image names, aligned with ``question_imgs``.
        n_search_results: negative-list length that sets 'provenance_entity'
            (was a hard-coded 100).

    Returns:
        dict with
        - 'qimlist': ``question_imgs`` itself;
        - 'imlist': every distinct query/positive/alternative/negative image;
        - 'simlist': per-question de-duplicated selection images;
        - 'gnd': per-question dicts with 'easy' (positives), 'hard'
          (alternatives), 'junk' (negatives that are not positives), 'neg'
          (negatives that are neither positive nor alternative) and
          'provenance_entity' (the negative list has ``n_search_results`` entries).

    All de-duplicated lists keep first-occurrence order. The previous
    ``list(set(...))`` produced an order that depends on string-hash
    randomisation, so the saved pickles differed from run to run.

    NOTE(review): train_set[47] reports 57 provenance and 100 irrelevant
    indices, so 'provenance_entity' can be True for a question that does have
    provenance passages — confirm the intended meaning of this flag.
    """
    all_imgs = list(question_imgs)
    for per_question_lists in (positive_imgs, alternative_imgs, negative_imgs):
        for images in per_question_lists:
            all_imgs.extend(images)

    new_gnd = {}
    new_gnd['qimlist'] = question_imgs
    new_gnd['imlist'] = _unique_in_order(all_imgs)
    new_gnd['simlist'] = [_unique_in_order(images) for images in selection_imgs]

    new_gnd_gnd = []
    for i in range(len(question_imgs)):
        positives = set(positive_imgs[i])
        alternatives = set(alternative_imgs[i])
        negatives = _unique_in_order(negative_imgs[i])

        question_gnd = {}
        question_gnd['easy'] = _unique_in_order(positive_imgs[i])
        question_gnd['hard'] = _unique_in_order(alternative_imgs[i])
        question_gnd['junk'] = [img for img in negatives if img not in positives]
        question_gnd['neg'] = [img for img in negatives
                               if img not in positives and img not in alternatives]
        question_gnd['provenance_entity'] = len(negative_imgs[i]) == n_search_results
        new_gnd_gnd.append(question_gnd)

    new_gnd['gnd'] = new_gnd_gnd
    return new_gnd
     

In [66]:
# Does tutorial question 47 (= train_set[47]) have exactly 100 irrelevant-passage
# images? This is the test format_gnd_for_rrt uses for 'provenance_entity'.
# NOTE(review): train_set[47] showed 57 provenance and 100 irrelevant indices
# earlier, so this is True although the question has provenance passages.
# `var` is not used by later cells.
var = len(tuto_negatives[47])==100
var

In [67]:
# Ground-truth dict for the training split.
train_gnd = format_gnd_for_rrt(
    question_imgs=train_questions,
    positive_imgs=train_positives,
    alternative_imgs=train_alternatives,
    negative_imgs=train_negatives,
    selection_imgs=train_selections)

In [68]:
# Ground-truth dict for the validation split.
dev_gnd = format_gnd_for_rrt(
    question_imgs=dev_questions,
    positive_imgs=dev_positives,
    alternative_imgs=dev_alternatives,
    negative_imgs=dev_negatives,
    selection_imgs=dev_selections)

In [69]:
# Ground-truth dict for the test split.
test_gnd = format_gnd_for_rrt(
    question_imgs=test_questions,
    positive_imgs=test_positives,
    alternative_imgs=test_alternatives,
    negative_imgs=test_negatives,
    selection_imgs=test_selections)

In [70]:
# Ground-truth dict for the tutorial subset.
tuto_gnd = format_gnd_for_rrt(
    question_imgs=tuto_questions,
    positive_imgs=tuto_positives,
    alternative_imgs=tuto_alternatives,
    negative_imgs=tuto_negatives,
    selection_imgs=tuto_selections)

In [71]:
def selection_imgs_ranks_for_rrt(new_gnd):
    """Add per-question image <-> rank lookups and rank lists to ``new_gnd``.

    For question ``i`` the candidate pool is ``new_gnd['simlist'][i]`` followed
    by that question's 'easy', 'hard', 'junk' and 'neg' images, de-duplicated
    in first-occurrence order; an image's rank is its position in the pool, so
    selection images come first and keep their ``simlist`` order.

    Each ``new_gnd['gnd'][i]`` gains 'img_rank_dict' (image -> rank),
    'rank_img_dict' (rank -> image) and 'r_easy' / 'r_hard' / 'r_junk' /
    'r_neg' (ranks of the matching image lists).

    Args:
        new_gnd: dict with 'qimlist', 'simlist' and 'gnd' (as built by
            format_gnd_for_rrt).

    Returns:
        The same ``new_gnd`` object, modified in place.
    """
    for i in range(len(new_gnd['qimlist'])):
        question_gnd = new_gnd['gnd'][i]
        # dict.fromkeys instead of set(): a set's order of strings depends on
        # hash randomisation, which made every rank change between runs.
        # 'junk'/'neg' are appended so an image missing from the selection
        # gets a rank instead of raising KeyError.
        pool = list(dict.fromkeys(
            list(new_gnd['simlist'][i])
            + question_gnd['easy'] + question_gnd['hard']
            + question_gnd['junk'] + question_gnd['neg']))

        img_rank_dict = {img: rank for rank, img in enumerate(pool)}
        question_gnd['img_rank_dict'] = img_rank_dict
        question_gnd['rank_img_dict'] = dict(enumerate(pool))

        for key in ('easy', 'hard', 'junk', 'neg'):
            question_gnd['r_' + key] = [img_rank_dict[img] for img in question_gnd[key]]

    return new_gnd

In [72]:
# Attach image ranks to the training ground truth (mutates and returns it).
train_gnd = selection_imgs_ranks_for_rrt(new_gnd=train_gnd)

In [73]:
# Attach image ranks to the validation ground truth (mutates and returns it).
dev_gnd = selection_imgs_ranks_for_rrt(new_gnd=dev_gnd)

In [74]:
# Attach image ranks to the test ground truth (mutates and returns it).
test_gnd = selection_imgs_ranks_for_rrt(new_gnd=test_gnd)

In [75]:
# Attach image ranks to the tutorial ground truth (mutates and returns it).
tuto_gnd = selection_imgs_ranks_for_rrt(new_gnd=tuto_gnd)

In [76]:
# Inspect the ground truth built for the second tutorial question.
tuto_gnd['gnd'][1]

{'easy': ['512px-Eldon_D._Rudd.jpg',
  '512px-Geoff_Edwards.JPG',
  '512px-Dallas_Collage_Montage.png',
  '512px-Edwin_A._Walker.jpg',
  '512px-Dr._Francis_M._Forster.jpg',
  '512px-JFK_limousine.png',
  '512px-Waggoner_Carr.jpg',
  '512px-John_F._Kennedy,_White_House_color_photo_portrait.jpg',
  '512px-Sixth_Floor_Museum_Logo.svg.png',
  '512px-Irving_June_2019_37_(Ruth_Paine_Home).jpg',
  '512px-Dallas_-_Municipal_Building_01A.jpg',
  '512px-Botham_Jean_Blvd_-_Dallas_Police_HQ_-_June_2021_-_03.jpg',
  '512px-SchoolbookDepository.jpg',
  '512px-John_Peel_BBC_cropped.jpg',
  '512px-Jack_Ruby-1.jpg',
  '512px-Tom_Pettit_of_NBC_News_at_1976_DNC.jpg',
  '512px-Lee_Harvey_Oswald_1963.jpg',
  '512px-Bertram_Chalres_Hill.jpg',
  '512px-J._D._Tippit_in_his_Dallas_Police_Department_photo_distributed_in_1963.jpg',
  '512px-Jim_Leavelle_(clear).jpg'],
 'hard': ['512px-Jack_Ruby-1.jpg'],
 'junk': ['512px-Carcano_mod._1891.jpg'],
 'neg': ['512px-Carcano_mod._1891.jpg'],
 'provenance_entity': False

In [77]:
##def prepare_gnd_selection_images_for_rrt(data_set,
##                                         wikipedia,
##                                         passage2article,
##                                         passage_wiki_split,
##                                         tuto=False):
##    
##    selection_imgs  = []
##    iterat = 0
##    
##    ## for every question, get the list images corresponding to the top 100 search results
##    for item in data_set:
##        
##        if iterat >= 100 and tuto:
##            break
##        
##        img_list = []
##        
##        for passage in sorted(item['search_indices']): 
##            wiki_index = int(passage2article[passage])
##            wiki_split = passage_wiki_split[passage]
##            wiki_item = wikipedia[wiki_split][wiki_index]
##
##            img_list.append(wiki_item['image'])
##        
##        selection_imgs.append(img_list)
##        iterat += 1
##    
##    return selection_imgs
##     

In [78]:
##train_selection_imgs = prepare_gnd_selection_images_for_rrt(train_set, 
##                                                            wiki, 
##                                                            passage2article, 
##                                                            passage_wiki_split)

In [79]:
##np.array(train_selection_imgs).shape

In [80]:
##len(list(set(question_imgs))), len(list(set(positive_imgs))), len(list(set(negative_imgs)))

In [81]:
##train_gnd = {}
##train_gnd['qimlist'] = train_question_imgs
##train_gnd['imlist']  = train_question_imgs + train_positive_imgs + train_negative_imgs
##train_gnd['simlist'] = train_selection_imgs

In [82]:
##train_gnd['imlist']  = list(set(train_gnd['imlist']))
##train_gnd['qimlist'] = list(set(train_gnd['qimlist']))

In [83]:
##dev_gnd = {}
##dev_gnd['qimlist'] = dev_question_imgs
##dev_gnd['imlist']  = dev_question_imgs + dev_positive_imgs + dev_negative_imgs
##dev_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(dev_set, 
##                                                          wiki, 
##                                                          passage2article, 
##                                                          passage_wiki_split)

In [84]:
##dev_gnd['imlist']  = list(set(dev_gnd['imlist']))
##dev_gnd['qimlist'] = list(set(dev_gnd['qimlist']))

In [85]:
##test_gnd = {}
##test_gnd['qimlist'] = test_question_imgs
##test_gnd['imlist']  = test_question_imgs + test_positive_imgs + test_negative_imgs
##test_gnd['simlist'] = prepare_gnd_selection_images_for_rrt(test_set, 
##                                                           wiki, 
##                                                           passage2article, 
##                                                           passage_wiki_split)

In [86]:
##test_gnd['imlist']  = list(set(test_gnd['imlist']))
##test_gnd['qimlist'] = list(set(test_gnd['qimlist']))

In [87]:
##len(train_gnd['simlist']), len(train_gnd['qimlist']), len(train_gnd['imlist']),

In [88]:
##viquae_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_images/gnd_dev_viquae.pkl"

In [89]:
##viquae_gnd = pickle_load(viquae_gnd_file)

In [90]:
##viquae_gnd['gnd'][0].keys()

In [91]:
##item = train_set[7]

##print(item['search_indices'] + item['search_provenance_indices'])

In [92]:
##print(item['search_provenance_indices'])

In [93]:
##def prepare_gnd_for_rrt(data_set,
##                        wikipedia,
##                        passage2article,
##                        passage_wiki_split):
##    
##    ## ranks = []
##    data_gnd = []
##    iterat = 0
##    
##    ## for every question, get the list of the top 100 search results
##    for item in data_set:
##        
##        if iterat >= 100:
##            break
##        
##        if set(item['search_provenance_indices']) <= set(item['search_indices']):
##            all_indices = sorted(item['search_indices'])
##        else: all_indices = sorted(item['search_indices']) + item['search_provenance_indices']
##        
##        rank_dict = {all_indices [k]: k for k in range(len(all_indices))}
##        ##print(rank_dict)
##        
##        def loop_over_passages(passages):
##            
##            img_ranks = []
##            
##            for passage in passages:
##                img_ranks.append(rank_dict[passage])
##                
##            return img_ranks
##        
##        question_gnd = {}
##        question_gnd['easy'] = []
##        question_gnd['hard'] = loop_over_passages(item['search_provenance_indices'])
##        question_gnd['junk'] = loop_over_passages(item['search_irrelevant_indices'])
##        
##        data_gnd.append(question_gnd)
##        iterat += 1
##        
##    return data_gnd
##         

In [94]:
##train_gnd['gnd'] = prepare_gnd_for_rrt(train_set,
##                                       wiki,
##                                       passage2article,
##                                       passage_wiki_split)

In [95]:
##dev_gnd['gnd'] = prepare_gnd_for_rrt(dev_set,
##                                     wiki,
##                                     passage2article,
##                                     passage_wiki_split)

In [96]:
##test_gnd['gnd'] = prepare_gnd_for_rrt(test_set,
##                                      wiki,
##                                      passage2article,
##                                      passage_wiki_split)

In [97]:
# Persist the training ground truth as a pickle.
train_gnd_file = osp.join(
    "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt",
    "gnd_train.pkl")
pickle_save(path=train_gnd_file, data=train_gnd)

In [98]:
# Persist the validation ground truth as a pickle.
dev_gnd_file = osp.join(
    "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt",
    "gnd_dev.pkl")
pickle_save(path=dev_gnd_file, data=dev_gnd)

In [99]:
# Persist the test ground truth as a pickle.
test_gnd_file = osp.join(
    "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt",
    "gnd_test.pkl")
pickle_save(path=test_gnd_file, data=test_gnd)

In [100]:
# Persist the tutorial ground truth as a pickle.
tuto_gnd_file = osp.join(
    "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt",
    "gnd_tuto.pkl")
pickle_save(path=tuto_gnd_file, data=tuto_gnd)

In [101]:
# Every distinct image referenced by the train/dev/test ground truths.
# sorted() makes the list, and the text file written from it, identical on
# every run; list(set(...)) ordered the strings by randomised hashes.
entire_dataset_imgs = sorted(set(train_gnd['imlist'] + dev_gnd['imlist'] + test_gnd['imlist']))
len(entire_dataset_imgs)

99954

In [102]:
# Write one image name per line (fmt="%s" on a 1-D list of strings).
# NOTE(review): np.savetxt picks its own default text encoding (see its
# `encoding` parameter) — pass encoding='utf-8' if names contain non-Latin characters.
np.savetxt('data/entire_dataset_imgs.txt', entire_dataset_imgs, fmt="%s")

In [103]:
# Scratch check of zero-padding to 5 characters: 123 -> '00123'.
number_str = f"{123}"
zero_filled_number = number_str.rjust(5, "0")
zero_filled_number

'00123'