# Adapting ViQuAE images for RRT
## Imports

In [1]:
import json
import numpy as np
from datasets import load_from_disk, set_caching_enabled
set_caching_enabled(False)

## Loading Data

In [2]:
dataset = load_from_disk("data/viquae_dataset")

In [3]:
kb = load_from_disk('data/viquae_passages/')

In [4]:
wiki = load_from_disk('data/viquae_wikipedia')

In [5]:
train_set, dev_set, test_set = dataset['train'], dataset['validation'], dataset['test']

In [6]:
humans_with_faces, humans_without_faces, non_humans = wiki['humans_with_faces'], wiki['humans_without_faces'], wiki['non_humans']

## Some Exploration

In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
        num_rows: 1190
    })
    validation: Dataset({
        features: ['BM25_indices', 'BM25_scores', 'arcface

In [8]:
train_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
    num_rows: 1190
})

In [9]:
dev_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
    num_rows: 1250
})

In [10]:
test_set

Dataset({
    features: ['BM25_indices', 'BM25_scores', 'arcface_indices', 'arcface_scores', 'clip-RN50', 'clip-RN50_indices', 'clip-RN50_scores', 'document_BM25_indices', 'document_BM25_scores', 'document_arcface_indices', 'document_arcface_scores', 'document_provenance_indices', 'document_resnet_indices', 'document_resnet_scores', 'document_search_indices', 'document_search_scores', 'face', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'id', 'image', 'image_embedding', 'image_hash', 'input', 'keep_clip-RN50', 'keep_face_embedding', 'keep_image_embedding', 'kilt_id', 'meta', 'original_answer_provenance_indices', 'original_question', 'output', 'provenance_indices', 'resnet_indices', 'resnet_scores', 'search_indices', 'search_irrelevant_indices', 'search_provenance_indices', 'search_scores', 'semi-oracle_irrelevant_indices', 'semi-oracle_provenance_indices', 'url', 'wikidata_id', 'search_alternative_indices', 'human'],
    num_rows: 1257
})

In [11]:
wiki

DatasetDict({
    non_humans: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 953379
    })
    humans_with_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_box', 'face_embedding', 'face_landmarks', 'face_prob', 'history', 'image', 'image_embedding', 'image_hash', 'keep_face_embedding', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 506237
    })
    humans_without_faces: Dataset({
        features: ['anchors', 'categories', 'clip-RN50', 'document', 'face_embedding', 'history', 'image', 'image_embedding', 'image_hash', 'kilt_id', 'passage_index', 'text', 'url', 'wikidata_info', 'wikipedia_id', 'wikipedia_title'],
        num_rows: 35736
    })
})

## GroundTruth Generation

In [1]:
import numpy as np
import os.path as osp
import pickle, json, random

import torch

In [2]:
def pickle_load(path):
    with open(path, 'rb') as fid:
        data_ = pickle.load(fid)
    return data_

In [3]:
def pickle_save(path, data):
    with open(path, 'wb') as fid:
        pickle.dump(data, fid)

In [22]:
set_name = 'train'
gnd_name = 'gnd_' + set_name + '.pkl'

In [23]:
dataset_name = 'viquae_for_rrt'
data_dir = osp.join('/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data', dataset_name)

In [24]:
gnd = pickle_load(osp.join(data_dir, gnd_name))
non_human_gnd = pickle_load(osp.join(data_dir, 'humans_'+gnd_name))
human_gnd = pickle_load(osp.join(data_dir, 'non_humans_'+gnd_name))

In [25]:
len(gnd['gnd']), len(human_gnd['gnd']), len(non_human_gnd['gnd'])#, len(train_set)

(1190, 656, 534)

In [21]:
gnd['gnd'][0].keys()

dict_keys(['easy', 'hard', 'junk', 'neg', 'provenance_entity', 'ir_order', 'img_rank_dict', 'rank_img_dict', 'r_easy', 'r_hard', 'r_junk', 'r_neg', 'r_ir_order', 'anchor_idx', 'g_easy', 'g_hard', 'g_junk', 'g_neg'])

In [10]:
len(gnd['simlist'][0])

43

In [21]:
for i in range(len(train_set)):
    gnd['gnd'][i]['is_human'] = train_set[i]['human']

In [22]:
train_gnd = pickle_load(osp.join(data_dir, 'gnd_train.pkl'))
for i in range(len(train_gnd['gnd'])):
    train_gnd['gnd'][i]['is_human'] = train_set[i]['human']

In [23]:
dev_gnd = pickle_load(osp.join(data_dir, 'gnd_dev.pkl'))
for i in range(len(dev_gnd['gnd'])):
    dev_gnd['gnd'][i]['is_human'] = dev_set[i]['human']

In [24]:
test_gnd = pickle_load(osp.join(data_dir, 'gnd_test.pkl'))
for i in range(len(test_gnd['gnd'])):
    test_gnd['gnd'][i]['is_human'] = test_set[i]['human']

In [25]:
tuto_gnd = pickle_load(osp.join(data_dir, 'gnd_tuto.pkl'))
for i in range(len(tuto_gnd['gnd'])):
    tuto_gnd['gnd'][i]['is_human'] = train_set[i]['human']

In [26]:
len(tuto_gnd['gnd']), tuto_gnd['gnd'][1].keys()

(120,
 dict_keys(['easy', 'hard', 'junk', 'neg', 'provenance_entity', 'ir_order', 'img_rank_dict', 'rank_img_dict', 'r_easy', 'r_hard', 'r_junk', 'r_neg', 'r_ir_order', 'anchor_idx', 'g_easy', 'g_hard', 'g_junk', 'g_neg', 'is_human']))

In [27]:
(np.sum(train_set['human'][:120]), np.sum([tuto_gnd['gnd'][i]['is_human'] for i in range(len(tuto_gnd['gnd']))])), (np.sum(train_set['human']), np.sum([train_gnd['gnd'][i]['is_human'] for i in range(len(train_gnd['gnd']))])), (np.sum(dev_set['human']), np.sum([dev_gnd['gnd'][i]['is_human'] for i in range(len(dev_gnd['gnd']))])), (np.sum(test_set['human']), np.sum([test_gnd['gnd'][i]['is_human'] for i in range(len(test_gnd['gnd']))]))

((63, 63), (534, 534), (573, 573), (577, 577))

In [28]:
tuto_gnd['gnd'][10]['is_human']

False

In [40]:
num_candidates = 100

labels = [train_gnd['gnd'][i]['anchor_idx'] for i in range(len(train_gnd['gnd']))]

#############################################################################
## Collect valid tuples
valids = np.zeros_like(labels)
counts = 0
len_positives = 0
len_negatives = 0
for i in range(len(train_gnd['qimlist'])):
    positives = train_gnd['gnd'][i]['r_easy']
    negatives = train_gnd['gnd'][i]['r_junk']
    if len(positives) < 1 or len(negatives) < 1:
        continue
    valids[i] = 1
    counts += len(positives) + len(negatives)
    len_positives += len(positives)
    len_negatives += len(negatives)
valids = np.where(valids > 0)[0]
num_samples = len(valids)

In [41]:
num_samples, counts, train_gnd['gnd'][i]['r_easy'], len_positives, len_negatives

(1131, 56080, [7, 6], 7821, 48259)

In [34]:
train_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_train.pkl"
pickle_save(train_gnd_file, train_gnd)

In [35]:
dev_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_dev.pkl"
pickle_save(dev_gnd_file, dev_gnd)

In [36]:
test_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_test.pkl"
pickle_save(test_gnd_file, test_gnd)

In [37]:
tuto_gnd_file = "/mnt/beegfs/home/smessoud/RerankingTransformer/models/research/delf/delf/python/delg/data/viquae_for_rrt/gnd_tuto.pkl"
pickle_save(tuto_gnd_file, tuto_gnd)

In [32]:
def extend(a):
    out = []
    for sublist in a:
        out.extend(sublist)
    return out

In [33]:
def preserve_order(array):
    new_array = []
    for e in array:
        if e in new_array:
            continue
        else:
            new_array.append(e)
    return new_array

In [34]:
def format_selection(selection):
    return [preserve_order(sub) for sub in selection]
    

In [35]:
len(gnd['gnd'][1].keys())

19

In [38]:
def format_gnd_for_entity_type(gnd, entity_type):
    new_gnd = {}
    new_gnd['imlist']  = gnd['imlist']
    
    if entity_type == 'humans':
        new_gnd['qimlist'] = [gnd['qimlist'][i] for i in range(len(gnd['qimlist'])) if gnd['gnd'][i]['is_human']]
        new_gnd['simlist'] = [gnd['simlist'][i] for i in range(len(gnd['simlist'])) if gnd['gnd'][i]['is_human']]
    else:
        new_gnd['qimlist'] = [gnd['qimlist'][i] for i in range(len(gnd['qimlist'])) if not gnd['gnd'][i]['is_human']]
        new_gnd['simlist'] = [gnd['simlist'][i] for i in range(len(gnd['simlist'])) if not gnd['gnd'][i]['is_human']]
    
    new_gnd_gnd =  []
    if entity_type == 'humans':
        for i in range(len(gnd['qimlist'])):
            if gnd['gnd'][i]['is_human']:
                question_gnd = {}
                question_gnd['easy'] = gnd['gnd'][i]['easy']
                question_gnd['hard'] = gnd['gnd'][i]['hard']
                question_gnd['junk'] = gnd['gnd'][i]['junk']
                question_gnd['neg']  = gnd['gnd'][i]['neg']
                
                question_gnd['r_easy'] = gnd['gnd'][i]['r_easy']
                question_gnd['r_hard'] = gnd['gnd'][i]['r_hard']
                question_gnd['r_junk'] = gnd['gnd'][i]['r_junk']
                question_gnd['r_neg']  = gnd['gnd'][i]['r_neg']
                
                question_gnd['g_easy'] = gnd['gnd'][i]['g_easy']
                question_gnd['g_hard'] = gnd['gnd'][i]['g_hard']
                question_gnd['g_junk'] = gnd['gnd'][i]['g_junk']
                question_gnd['g_neg']  = gnd['gnd'][i]['g_neg']
                
                question_gnd['provenance_entity']  = gnd['gnd'][i]['provenance_entity']
                
                question_gnd['ir_order']    = gnd['gnd'][i]['ir_order']
                question_gnd['r_ir_order']  = gnd['gnd'][i]['r_ir_order']
                
                question_gnd['rank_img_dict']  = gnd['gnd'][i]['rank_img_dict']
                question_gnd['img_rank_dict']  = gnd['gnd'][i]['img_rank_dict']
                             
                question_gnd['anchor_idx']  = gnd['gnd'][i]['anchor_idx']
                question_gnd['is_human']    = gnd['gnd'][i]['is_human']
                
                new_gnd_gnd.append(question_gnd)
                
    if entity_type != 'humans':
        for i in range(len(gnd['qimlist'])):
            if not gnd['gnd'][i]['is_human']:
                question_gnd = {}
                question_gnd['easy'] = gnd['gnd'][i]['easy']
                question_gnd['hard'] = gnd['gnd'][i]['hard']
                question_gnd['junk'] = gnd['gnd'][i]['junk']
                question_gnd['neg']  = gnd['gnd'][i]['neg']
                
                question_gnd['r_easy'] = gnd['gnd'][i]['r_easy']
                question_gnd['r_hard'] = gnd['gnd'][i]['r_hard']
                question_gnd['r_junk'] = gnd['gnd'][i]['r_junk']
                question_gnd['r_neg']  = gnd['gnd'][i]['r_neg']
                
                question_gnd['g_easy'] = gnd['gnd'][i]['g_easy']
                question_gnd['g_hard'] = gnd['gnd'][i]['g_hard']
                question_gnd['g_junk'] = gnd['gnd'][i]['g_junk']
                question_gnd['g_neg']  = gnd['gnd'][i]['g_neg']
                
                question_gnd['provenance_entity']  = gnd['gnd'][i]['provenance_entity']
                
                question_gnd['ir_order']    = gnd['gnd'][i]['ir_order']
                question_gnd['r_ir_order']  = gnd['gnd'][i]['r_ir_order']
                
                question_gnd['rank_img_dict']  = gnd['gnd'][i]['rank_img_dict']
                question_gnd['img_rank_dict']  = gnd['gnd'][i]['img_rank_dict']
                             
                question_gnd['anchor_idx']  = gnd['gnd'][i]['anchor_idx']
                question_gnd['is_human']    = gnd['gnd'][i]['is_human']
                
                new_gnd_gnd.append(question_gnd)
    
    new_gnd['gnd'] = new_gnd_gnd
    
    return new_gnd

In [41]:
humans_tuto_gnd     = format_gnd_for_entity_type(tuto_gnd, entity_type='humans')
non_humans_tuto_gnd = format_gnd_for_entity_type(tuto_gnd, entity_type='non_humans')
pickle_save(osp.join(data_dir, "humans_gnd_tuto.pkl"), humans_tuto_gnd)
pickle_save(osp.join(data_dir, "non_humans_gnd_tuto.pkl"), non_humans_tuto_gnd)
len(humans_tuto_gnd['gnd']), len(non_humans_tuto_gnd['gnd']), len(tuto_gnd['gnd']), 63+57

(63, 57, 120, 120)

In [42]:
prefixed = None
prefixed is not None

False

In [43]:
humans_train_gnd     = format_gnd_for_entity_type(train_gnd, entity_type='humans')
non_humans_train_gnd = format_gnd_for_entity_type(train_gnd, entity_type='non_humans')
pickle_save(osp.join(data_dir, "humans_gnd_train.pkl"), humans_train_gnd)
pickle_save(osp.join(data_dir, "non_humans_gnd_train.pkl"), non_humans_train_gnd)
len(humans_train_gnd['qimlist']), len(non_humans_train_gnd['qimlist']), len(train_gnd['qimlist']), 534+656

(534, 656, 1190, 1190)

In [44]:
humans_dev_gnd     = format_gnd_for_entity_type(dev_gnd, entity_type='humans')
non_humans_dev_gnd = format_gnd_for_entity_type(dev_gnd, entity_type='non_humans')
pickle_save(osp.join(data_dir, "humans_gnd_dev.pkl"), humans_dev_gnd)
pickle_save(osp.join(data_dir, "non_humans_gnd_dev.pkl"), non_humans_dev_gnd)
len(humans_dev_gnd['qimlist']), len(non_humans_dev_gnd['qimlist']), len(humans_dev_gnd['qimlist']), 573+677

(573, 677, 573, 1250)

In [45]:
humans_test_gnd     = format_gnd_for_entity_type(test_gnd, entity_type='humans')
non_humans_test_gnd = format_gnd_for_entity_type(test_gnd, entity_type='non_humans')
pickle_save(osp.join(data_dir, "humans_gnd_test.pkl"), humans_test_gnd)
pickle_save(osp.join(data_dir, "non_humans_gnd_test.pkl"), non_humans_test_gnd)
len(humans_test_gnd['qimlist']), len(non_humans_test_gnd['qimlist']), len(test_gnd['qimlist']), 577+680

(577, 680, 1257, 1257)

In [46]:
len(humans_tuto_gnd['qimlist']), len(non_humans_tuto_gnd['qimlist']), len(tuto_gnd['qimlist']), 63+57

(63, 57, 120, 120)

In [None]:
entire_dataset_imgs = list(set(train_gnd['imlist'] + dev_gnd['imlist'] + test_gnd['imlist']))
len(entire_dataset_imgs)

In [None]:
np.savetxt('data/entire_dataset_imgs.txt', entire_dataset_imgs, fmt="%s")

In [None]:
number_str = str(123)
zero_filled_number = number_str.zfill(5)
zero_filled_number