# Adapting ViQuAE - ResNet Image Embeddings
## Imports

In [41]:
import json
import numpy as np
from datasets import load_from_disk, set_caching_enabled
set_caching_enabled(False)

## Loading Data

In [2]:
# Load the ViQuAE question dataset previously saved to disk; it is split into
# 'train', 'validation' and 'test' further down.
dataset = load_from_disk("data/viquae_dataset")

In [3]:
# Load the dataset stored under data/viquae_passages/.
# NOTE(review): `kb` is not referenced again in the cells visible here — confirm
# it is still needed before paying the load cost.
kb = load_from_disk('data/viquae_passages/')

In [4]:
# Load the Wikipedia article collection; it is indexed by split name below
# ('humans_with_faces', 'humans_without_faces', 'non_humans').
wiki = load_from_disk('data/viquae_wikipedia')

In [5]:
# Bind each question split to its own name for convenience.
train_set = dataset['train']
dev_set = dataset['validation']
test_set = dataset['test']

In [6]:
# Bind each Wikipedia article split to its own name for convenience.
humans_with_faces = wiki['humans_with_faces']
humans_without_faces = wiki['humans_without_faces']
non_humans = wiki['non_humans']

## Passage to Article Mapping

In [7]:
# Read the article -> passages mapping of the 'non_humans' split.
# The context manager closes the file even if json.load raises.
with open('data/viquae_wikipedia/non_humans/article2passage.json') as f:
    n_h_article2passage = json.load(f)

In [8]:
# Invert article -> passages into passage -> article for the 'non_humans' split.
n_h_passage2article = {
    passage: article
    for article, passages in n_h_article2passage.items()
    for passage in passages
}

# Number of passages in this split; reused in later cells.
len_n_h = len(n_h_passage2article)
len_n_h

7175529

In [9]:
# Read the article -> passages mapping of the 'humans_without_faces' split.
# The context manager closes the file even if json.load raises.
with open('data/viquae_wikipedia/humans_without_faces/article2passage.json') as f:
    h_wo_f_article2passage = json.load(f)

In [10]:
# Invert article -> passages into passage -> article for the
# 'humans_without_faces' split.
h_wo_f_passage2article = {
    passage: article
    for article, passages in h_wo_f_article2passage.items()
    for passage in passages
}

# Number of passages in this split; reused in later cells.
len_h_wo_f = len(h_wo_f_passage2article)
len_h_wo_f

298698

In [11]:
# Read the article -> passages mapping of the 'humans_with_faces' split.
# The context manager closes the file even if json.load raises.
with open('data/viquae_wikipedia/humans_with_faces/article2passage.json') as f:
    h_w_f_article2passage = json.load(f)

In [12]:
# Invert article -> passages into passage -> article for the
# 'humans_with_faces' split.
h_w_f_passage2article = {
    passage: article
    for article, passages in h_w_f_article2passage.items()
    for passage in passages
}

# Number of passages in this split; reused in later cells.
len_h_w_f = len(h_w_f_passage2article)
len_h_w_f

4411741

In [13]:
# Sanity check: the three per-split passage counts add up to the expected total.
sum((len_n_h, len_h_w_f, len_h_wo_f)) == 11885968

True

In [14]:
# Merge the three per-split lookups into a single passage -> article table.
# On a key collision the later split wins, exactly as with dict unpacking.
passage2article = {}
for split_lookup in (h_w_f_passage2article, h_wo_f_passage2article, n_h_passage2article):
    passage2article.update(split_lookup)
len(passage2article)

11885968

In [15]:
#with open("data/viquae_wikipedia/passage2article.json", "w") as outfile:
#    json.dump(passage2article, outfile)

## Some Exploration

In [16]:
dataset

DatasetDict({
    train: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
        num_rows: 1190
    })
    validation: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface

In [17]:
train_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
    num_rows: 1190
})

In [18]:
dev_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
    num_rows: 1250
})

In [19]:
test_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_BM25_indices', 'document_BM25_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
    num_rows: 1257
})

In [18]:
wiki

DatasetDict({
    non_humans: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 953379
    })
    humans_with_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'history', 'image', 'image_embedding', 'image_hash', 'keep_face_embedding', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 506237
    })
    humans_without_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_embedding', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 35736
    })
})

In [19]:
# Sizes of the search-related passage lists of the first training question.
item = train_set[0]
tuple(
    len(item[column])
    for column in ('search_indices', 'search_provenance_indices', 'search_irrelevant_indices')
)

(100, 0, 100)

In [20]:
# Same inspection for training question 47.
item = train_set[47]
tuple(
    len(item[column])
    for column in ('search_indices', 'search_provenance_indices', 'search_irrelevant_indices')
)

(100, 57, 100)

In [28]:
# Length of the image embedding stored on the first 'humans_without_faces' article.
first_faceless_article = wiki['humans_without_faces'][0]
len(first_faceless_article['image_embedding'])

2048

In [29]:
# Tag every passage with the name of the wiki split it was listed under.
h_w_f_passage_split = dict.fromkeys(h_w_f_passage2article, 'humans_with_faces')
h_wo_f_passage_split = dict.fromkeys(h_wo_f_passage2article, 'humans_without_faces')
n_h_passage_split = dict.fromkeys(n_h_passage2article, 'non_humans')

# Single passage -> split-name lookup; later splits win on a key collision.
passage_wiki_split = {}
for split_tags in (h_w_f_passage_split, h_wo_f_passage_split, n_h_passage_split):
    passage_wiki_split.update(split_tags)
len(passage_wiki_split)

11885968

In [30]:
#with open("data/viquae_wikipedia/passage_wiki_split.json", "w") as outfile:
#    json.dump(passage_wiki_split, outfile)

In [31]:
# Sizes of all four search-related passage lists of the first training question.
item = train_set[0]
tuple(
    len(item[column])
    for column in (
        'search_indices',
        'search_provenance_indices',
        'search_irrelevant_indices',
        'search_alternative_indices',
    )
)

(100, 0, 100, 2)

In [32]:
# NOTE(review): this cell repeats the previous one verbatim (same question,
# same columns) and could be dropped.
item = train_set[0]
tuple(
    len(item[column])
    for column in (
        'search_indices',
        'search_provenance_indices',
        'search_irrelevant_indices',
        'search_alternative_indices',
    )
)

(100, 0, 100, 2)

In [44]:
# Length of the image embedding stored on the first 'non_humans' article.
first_non_human_article = wiki['non_humans'][0]
len(first_non_human_article['image_embedding'])

2048

## Ground Truth Generation

In [33]:
import numpy as np
import os.path as osp
import pickle, json, random

import torch

In [34]:
def pickle_load(path):
    """Return the object unpickled from the file at `path`.

    The file is opened in binary mode and closed before returning.
    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

In [35]:
def pickle_save(path, data):
    """Pickle `data` into the file at `path`, overwriting any existing file.

    The file is opened in binary write mode and closed before returning.
    Nothing is returned.
    """
    serialized = pickle.dumps(data)
    with open(path, 'wb') as handle:
        handle.write(serialized)

In [36]:
# Location of the Oxford5k ground-truth pickle that the next cells load and
# inspect (presumably as a reference for the ground-truth layout — confirm).
# NOTE(review): absolute, user-specific path; this cell will not run on another
# machine. Consider a path relative to a configurable data directory.
oxford_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/oxford5k/gnd_roxford5k.pkl"

In [37]:
# Unpickle the Oxford ground-truth file; the next cell shows it is a dict with
# keys 'gnd', 'imlist' and 'qimlist'.
oxford_gnd = pickle_load(oxford_gnd_file)

In [38]:
oxford_gnd.keys()

dict_keys(['gnd', 'imlist', 'qimlist'])

In [39]:
# Keys of the first ground-truth entry, and the number of top-level keys.
first_query_gnd = oxford_gnd['gnd'][0]
first_query_gnd.keys(), len(oxford_gnd)

(dict_keys(['bbx', 'easy', 'hard', 'junk']), 3)

In [40]:
# Sizes of the 'easy' / 'hard' / 'junk' lists of the first ground-truth entry;
# only the easy + hard total is displayed.
first_query_gnd = oxford_gnd['gnd'][0]
n_easy, n_hard, n_junk = (len(first_query_gnd[level]) for level in ('easy', 'hard', 'junk'))
np.sum((n_easy, n_hard))

103

In [45]:
def save_gallery_img_embeddings(data_set,
                                wikipedia,
                                passage2article,
                                passage_wiki_split,
                                tuto=False,
                                img_embeddings=None):
    """Collect image-name -> embedding pairs for questions and their passages.

    For every question in `data_set`, the question image is recorded, followed
    by the image of the Wikipedia article behind each passage listed in the
    question's 'search_provenance_indices', 'search_alternative_indices',
    'search_irrelevant_indices' and 'search_indices' (in that order). When the
    same image name is met again, its entry is overwritten.

    Args:
        data_set: iterable of question items, each indexable by 'image',
            'keep_image_embedding' and the four passage-list columns above.
        wikipedia: mapping from split name to an indexable collection of
            articles; each article is indexable by 'image' and
            'image_embedding'.
        passage2article: mapping from passage index to the index of its
            article inside its split (converted with `int`).
        passage_wiki_split: mapping from passage index to the split name used
            to index `wikipedia`.
        tuto: when true, stop after the first 120 questions.
        img_embeddings: optional dict to extend in place. When omitted, a new
            dict is created for this call.

    Returns:
        The dict mapping image name to embedding (the same object as
        `img_embeddings` when one was passed in).

    Raises:
        KeyError: if a passage is missing from `passage2article` or
            `passage_wiki_split`.
    """
    # A `{}` default is evaluated once at definition time and would be shared
    # by every call that omits the argument, silently merging the results of
    # unrelated calls. Create the dict per call instead.
    if img_embeddings is None:
        img_embeddings = {}

    def add_passage_images(passages):
        """Record the article image behind each passage in `passages`."""
        for passage in passages:
            # Resolve the passage to its article: split name + row in that split.
            wiki_index = int(passage2article[passage])
            wiki_split = passage_wiki_split[passage]
            wiki_item = wikipedia[wiki_split][wiki_index]
            img_embeddings[wiki_item['image']] = wiki_item['image_embedding']

    # Passage lists visited for every question, in the original order:
    # original answer, alternative answer, irrelevant, then raw search results.
    passage_columns = (
        'search_provenance_indices',
        'search_alternative_indices',
        'search_irrelevant_indices',
        'search_indices',
    )

    for question_count, item in enumerate(data_set):
        if tuto and question_count >= 120:
            break

        # NOTE(review): articles are read from 'image_embedding' but the question
        # image is read from 'keep_image_embedding' — confirm that this column
        # holds the embedding vector and not a keep/skip flag.
        img_embeddings[item['image']] = item['keep_image_embedding']

        for column in passage_columns:
            add_passage_images(item[column])

    return img_embeddings
     

In [46]:
# Run the gallery collection over the training questions
# (img_embeddings is left at its default).
train_img_embeddings = save_gallery_img_embeddings(train_set, wiki, passage2article, passage_wiki_split)

In [47]:
# Run the gallery collection over the validation questions
# (img_embeddings is left at its default).
dev_img_embeddings   = save_gallery_img_embeddings(dev_set, wiki,   passage2article, passage_wiki_split)

In [48]:
# Run the gallery collection over the test questions
# (img_embeddings is left at its default).
test_img_embeddings  = save_gallery_img_embeddings(test_set, wiki,  passage2article, passage_wiki_split)

In [49]:
# Union of the three per-split galleries, keyed by image name.
# On a key collision the later split wins, exactly as with dict unpacking.
dataset_img_embeddings = {}
for split_embeddings in (train_img_embeddings, dev_img_embeddings, test_img_embeddings):
    dataset_img_embeddings.update(split_embeddings)
len(dataset_img_embeddings)

99954

In [58]:
import os.path as osp
import numpy as np
from tqdm import tqdm
from utils import pickle_save, pickle_load
from pprint import pprint
from utils.data.delf import datum_io
from copy import deepcopy

In [65]:
# Suffix appended to an image's filename stem to name its feature file.
_IMAGENET_EXTENSION = '.imagenet'

In [66]:
# Where the exported feature files go: delg/data/<dataset_name>/imagenet_r50/
# (relative to the notebook's working directory).
dataset_name = 'viquae_for_rrt'
data_dir = osp.join('delg/data', dataset_name)
# NOTE(review): nothing in the visible cells creates this directory — confirm it
# exists before running the export loop.
output_features_dir = osp.join(data_dir, 'imagenet_r50/')

In [74]:
# Preview the filename stem (extension stripped) used to name feature files.
# The sample is taken from the collected gallery so that this cell runs on a
# fresh kernel; `image_name` is otherwise only bound by the export loop in the
# following cell.
sample_image_name = next(iter(dataset_img_embeddings))
'.'.join(sample_image_name.split('.')[:-1])

'512px-James_II_of_Scotland_17th_century'

In [75]:
# Write one feature file per gallery image into `output_features_dir`.
for image_name, embedding in dataset_img_embeddings.items():
    # Keep only what precedes the last '.', i.e. drop the file extension.
    image_name = image_name.rpartition('.')[0]
    output_feature_filename = osp.join(output_features_dir, f"{image_name}{_IMAGENET_EXTENSION}")
    embedding_array = np.array(embedding)
    datum_io.WriteToFile(embedding_array, output_feature_filename)