In [1]:
%load_ext autoreload
%autoreload 2

from time import time
import pandas as pd
import numpy as np
import os
from collections import Counter, defaultdict
import pickle

In [2]:
import sys
sys.path.insert(0, "/data3/muntean/DRhard")

In [3]:
import argparse
import subprocess
import sys
sys.path.append("./")
import faiss
import logging
import os
import numpy as np
# import torch
from transformers import RobertaConfig
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler

from model import RobertaDot
from dataset import (
    TextTokenIdsCache, load_rel, SubsetSeqDataset, SequenceDataset,
    single_get_collate_function
)
from retrieve_utils import (
    construct_flatindex_from_embeddings, 
    index_retrieve, convert_index_to_gpu,
    update_flatindex_from_embeddings
)
logger = logging.Logger(__name__)

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
doc_memmap_path = "/data3/muntean/DRhard/data/passage/evaluate/star/passages.memmap"
docid_memmap_path = "/data3/muntean/DRhard/data/passage/evaluate/star/passages-id.memmap"
query_memmap_path = "/data3/muntean/DRhard/data/passage/evaluate/star/test-manual-query.memmap"
queryids_memmap_path = "/data3/muntean/DRhard/data/passage/evaluate/star/test-manual-query-id.memmap"

In [5]:
doc_embeddings = np.memmap(doc_memmap_path, dtype=np.float32, mode="r")
doc_ids = np.memmap(docid_memmap_path, dtype=np.int32, mode="r")
doc_embeddings = doc_embeddings.reshape(-1, 768)

query_embeddings = np.memmap(query_memmap_path, dtype=np.float32, mode="r")
query_embeddings = query_embeddings.reshape(-1, 768)
query_ids = np.memmap(queryids_memmap_path, dtype=np.int32, mode="r")

In [6]:
%time
index = construct_flatindex_from_embeddings(doc_embeddings, doc_ids)

CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 8.34 µs
embedding shape: (38626614, 768)
(38626614,) int64


In [7]:
type(index)

faiss.swigfaiss.IndexIDMap2

# Select certain queries and certain docs for small index

In [8]:
# Load our qid and docid remapping dictionaries

# query id dict
qid_mapping_path = "/data3/muntean/DRhard/data/passage/dataset/queries.CASTmanual.QID2newID.test.tsv"
queries_df = pd.read_csv(qid_mapping_path, delimiter="\t", header=None)
print(len(queries_df))

# collection id dict
collection_mapping_path = "/data3/muntean/DRhard/data/passage/dataset/CASTcollectionPID2newID.tsv"
collection_df = pd.read_csv(collection_mapping_path, delimiter="\t", header=None)
print(len(collection_df))

479
38626614


In [9]:
qid2newqid_dict = dict(zip(queries_df[0], queries_df[1])) 
pid2newpid_dict = dict(zip(collection_df[0], collection_df[1])) 

In [10]:
qid2newqid_dict["32_1"]

9

In [11]:
# Create reverse dictionaries
newqid2qid_dict = dict(zip(queries_df[1], queries_df[0])) 
newpid2pid_dict = dict(zip(collection_df[1], collection_df[0])) 

In [12]:
newqid2qid_dict[9]

'32_1'

In [13]:
# DRhard docid and qid encoding
preprocess_dir = "/data3/muntean/DRhard/data/passage/preprocess"

pid2offset = pickle.load(open(os.path.join(preprocess_dir, "pid2offset.pickle"), 'rb'))
offset2pid = {v:k for k, v in pid2offset.items()}
qid2offset = pickle.load(open(os.path.join(preprocess_dir, f"test-manual-qid2offset.pickle"), 'rb'))
offset2qid = {v:k for k, v in qid2offset.items()}

In [14]:
qid2offset[9]

9

In [15]:
conv_qrel_int = [31, 32, 33, 34, 37, 40, 49, 50, 54, 56, 58, 59, 61, 67, 68, 69, 75, 77, 78, 79]
conv_qrel = [str(x) for x in conv_qrel_int]
# conv_qrel

# Create conv cache

In [103]:
topk = 5000 # cache dimension [1000,2000,5000,10000]

In [104]:
# distance dicts
cache_radius_dict = dict() # between first utterance (qa) and last retrieved doc from the big index
query_distance_dict = dict() # distance between the first (qa) and the rest of utterances of the conversation (qb)
query_radius_dict = dict() # between current utterance (qb) and last retrieved doc from the big index
rb_hat_dict = dict() # rb_hat = ra - d(qb, qa)

In [105]:
def l2_distance(v1,v2):
    return np.linalg.norm(v1-v2)

In [106]:
def create_conv_cache(conv_id, qid2newqid_dict, qid2offset, query_embeddings, doc_embeddings, 
                      index, topk, cache_radius_dict):
    # first utt of the conversation - determines the size of the cache
    first_qid = conv_id + "_1"
    newqid = qid2newqid_dict[first_qid] #added first
    qid_offset = qid2offset[newqid]

    # prendere il memmap
    query_emb = query_embeddings[qid_offset].reshape(1, 768)
    print("Init index: ",first_qid, qid_offset)
    
    # fare retireval nel indice grande e prendere top 2000 documenti
    faiss.omp_set_num_threads(16) #32
    nearest_neighbors = index_retrieve(index, query_emb, topk, batch=32)
    
    # select doc embeddings, paired with ids
    small_doc_emb = doc_embeddings[nearest_neighbors[0]]
    small_doc_ids = np.array(nearest_neighbors[0])
    index_conv = construct_flatindex_from_embeddings(small_doc_emb, small_doc_ids)
       
    # compute distance between the first query and last doc in the list of topk retrieved that are stored in cache (e.g., r_q_i)
    last_doc = nearest_neighbors[0][-1]
    last_doc_embedding = doc_embeddings[last_doc]
    cache_radius_dict[first_qid] = l2_distance(query_emb, last_doc_embedding)
    
    return index_conv, nearest_neighbors, cache_radius_dict

In [107]:
def update_conv_cache(qid, qid2newqid_dict, qid2offset, query_embeddings, doc_embeddings, 
                      index, index_conv, topk, cache_radius_dict):
    # determines the size of the cache?
    newqid = qid2newqid_dict[qid]
    qid_offset = qid2offset[newqid]

    # prendere il memmap
    query_emb = query_embeddings[qid_offset].reshape(1, 768)
    
    print("Update index: ",qid, qid_offset)
    
    # fare retireval nel indice grande e prendere topk documenti
    faiss.omp_set_num_threads(16) #32
    nearest_neighbors = index_retrieve(index, query_emb, topk, batch=32)
    
    # select doc embeddings, paired with ids
    small_doc_emb = doc_embeddings[nearest_neighbors[0]]
    small_doc_ids = np.array(nearest_neighbors[0])
    index_conv = update_flatindex_from_embeddings(index_conv, small_doc_emb, small_doc_ids)
       
    # compute distance between the first query and last doc in the list of topk retrieved that are stored in cache (e.g., r_q_i)
    last_doc = nearest_neighbors[0][-1]
    last_doc_embedding = doc_embeddings[last_doc]
    cache_radius_dict[qid] = l2_distance(query_emb, last_doc_embedding)
    
    return index_conv, nearest_neighbors, cache_radius_dict

In [None]:
# conv_ids = set([x.split("_")[0] for x in qid2newqid_dict.keys()]) # this has all but we don't need all, just the ones in qrel
# conv in qrel: subset of all conv
conv_qrel_int = [32] #, 32, 33, 34, 37, 40, 49, 50, 54, 56, 58, 59, 61, 67, 68, 69, 75, 77, 78, 79]
conv_qrel = [str(x) for x in conv_qrel_int]

results_list = []  # top1000 rankings for each query of the conversation from the CACHE index
cache_update_with_qid = []
coverage1 = {} # top10 ranking for each query from the BIG index
coverage2 = {} # list of triples (3,5,10) for each qb

for conv_id in conv_qrel: # iterate over the conversations
    print()
    print()
    print("Starting conv: " , conv_id)
    
    # ALL QA
    # Create index for first query and retrieve nearest neighbours - top 2000
    index_conv, nn_index, cache_radius_dict = create_conv_cache(conv_id, qid2newqid_dict, 
                                                                         qid2offset, query_embeddings, 
                                                                         doc_embeddings, index, topk, 
                                                                         cache_radius_dict)
    print("Create cache for: ", conv_id)
    
    # save docs for qa for coverage
    retrieved_qa = nn_index[0]
        
    # first  query id & embedding
    first_qid = conv_id + "_1"
    first_newqid = qid2newqid_dict[first_qid] #added first
    first_qid_offset = qid2offset[first_newqid]
    
    # save results - top 1000 for first conv query qa
    for idx, pid in enumerate(nn_index[0][:1000]):
        results_list.append((first_qid_offset, pid, idx+1))
    
    # prendere il memmap di qa
    first_query_emb = query_embeddings[first_qid_offset].reshape(1, 768)
    
    queries_in_cache = {}
    queries_in_cache[first_qid] = first_query_emb
    
    # for each utt in the rest of the conv: qb
    for qid in qid2newqid_dict.keys():
        if not qid.endswith("_1") and qid.startswith(conv_id):
            
            # select query embedding
            newqid = qid2newqid_dict[qid]
            qid_offset = qid2offset[newqid]
            # prendere il memmap
            query_emb = query_embeddings[qid_offset].reshape(1, 768)
            
            print()
            print("Processing qid:", qid, qid_offset)
            
            if len(queries_in_cache)>1: # there are more queries in cache
                # COMPUTE rb_hat with all queries in cache based on convid
                update = True
                for query_in_cache_id, query_in_cache_emb  in queries_in_cache.items():
                    query_distance_dict[qid] = l2_distance(query_emb, query_in_cache_emb)
                    rb_hat_dict[qid] = cache_radius_dict[query_in_cache_id] - query_distance_dict[qid]
                    print("More queries in cache, rb_hat di", qid, " and ", query_in_cache_id, " is ", rb_hat_dict[qid])
                    if rb_hat_dict[qid] >= 0: 
                        update = False
                        break
                
                # update index and query in cache list
                if update:
                    print("Updating cache!")
                    queries_in_cache[qid] = query_emb
                    index_conv, nn_index, cache_radius_dict = update_conv_cache(qid, qid2newqid_dict, qid2offset, 
                                                                                         query_embeddings, doc_embeddings, index,
                                                                                         index_conv, topk, cache_radius_dict)
                
            else:
                # COMPUTE rb_hat with first only
                # compute distance between qb and qa
                query_distance_dict[qid] = l2_distance(query_emb, first_query_emb)

                # compute rb_capuccio = ra - d(qb, qa)
                rb_hat_dict[qid] = cache_radius_dict[first_qid] - query_distance_dict[qid]
                print("One query in cache, rb_hat di", qid, " is ", rb_hat_dict[qid])
            
                # update index and query in cache list
                if rb_hat_dict[qid] < 0:
                    queries_in_cache[qid] = query_emb
                    index_conv, nn_index, cache_radius_dict = update_conv_cache(qid, qid2newqid_dict, qid2offset, 
                                                                                         query_embeddings, doc_embeddings, index,
                                                                                         index_conv, topk, cache_radius_dict)
                    print("Updating cache!")
                

            # retrieve docs for qb
            faiss.omp_set_num_threads(16) #32
            nn_cache = index_retrieve(index_conv, query_emb, 1000, batch=32)
            print("Retrieved top 1000 for", qid)
            
            # save results - top1000, for qb
            for idx, pid in enumerate(nn_cache[0]):
                results_list.append((qid_offset, pid, idx+1))
                
            
            #### Compute COVERAGE 
            # Cov1 - intersezione tra Ba ed Bb_hat
            # distance between qb and each doc in top1000 from cache and check if < rb_hat
            # put them in a list - exact match - and intersect with top k in cache
            retrieved_qb_in_rb_hat = []
            small_doc_emb = doc_embeddings[nn_cache[0]]
            for doc_id, doc_emb in zip(nn_cache[0], small_doc_emb):
                if l2_distance(query_emb, doc_emb) < rb_hat_dict[qid]:
                    retrieved_qb_in_rb_hat.append(doc_id)
            coverage1[qid]=retrieved_qb_in_rb_hat
            
            
            #### Compute DISTANCES rb
            
            #compute distance between qb and the last retrieved doc from INDEX
            first_10_docs = index_retrieve(index, query_emb, 10, batch=32)
            # 3 raggi rb per 3,5,10
            rb_dist_list = []
            for i in [3,5,10]:
                last_doc_embedding = doc_embeddings[first_10_docs[0][i-1]]
                dist_rb = l2_distance(query_emb, last_doc_embedding)
                rb_dist_list.append(dist_rb)
            query_radius_dict[qid] = rb_dist_list
            print("Finished retrieving in the big index!")
            
            # save top10 for qb in INDEX
            results_list_rb = first_10_docs[0][:10]
            result_list_qb_in_cache = nn_cache[0][:10]
            
            print("Check this when update happens:" , results_list_rb, result_list_qb_in_cache, nn_index[0][:10])
        
            #### Compute COVERAGE          
            # Cov2 - intersezione tra risultati di query qb top 1000 su cache e query qb su indice per k=3,5,10
            num_intersection = []
            for cut_off in [3,5,10]:
                elem_in_common = set(results_list_rb[:cut_off]).intersection(result_list_qb_in_cache[:cut_off])
                num_intersection.append(len(elem_in_common))
            coverage2[qid] = num_intersection    
            
            print("finished with qid: ", qid)
    cache_update_with_qid.extend(list(queries_in_cache.keys()))



Starting conv:  32
Init index:  32_1 9
Query Num 1





  0%|                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s][A[A[A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:59<00:00, 59.46s/it][A[A[A


Elapsed Time: 59.5s, Elapsed Time per query: 59467.2ms
embedding shape: (5000, 768)
(5000,) int64
Create cache for:  32

Processing qid: 32_2 10
One query in cache, rb_hat di 32_2  is  1.3569245
Query Num 1





100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 208.67it/s][A[A[A


Elapsed Time: 0.0s, Elapsed Time per query: 8.6ms
Retrieved top 1000 for 32_2
Query Num 1





  0%|                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s][A[A[A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:28<00:00, 28.96s/it][A[A[A


Elapsed Time: 29.0s, Elapsed Time per query: 28963.3ms
Finished retrieving in the big index!
Check this when update happens: [5696924, 5231421, 5231427, 34260273, 34036966, 5696927, 5696932, 34214925, 2846731, 17328200] [5696924, 5231421, 5231427, 34260273, 34036966, 5696927, 5696932, 34214925, 2846731, 17328200] [6584581, 3855636, 2174663, 3839334, 6786658, 1864967, 521815, 3839336, 706230, 2174666]
finished with qid:  32_2

Processing qid: 32_3 11
One query in cache, rb_hat di 32_3  is  -0.09773302
Update index:  32_3 11
Query Num 1





  0%|                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s][A[A[A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:58<00:00, 58.22s/it][A[A[A


Elapsed Time: 58.2s, Elapsed Time per query: 58224.0ms
embedding shape: (5000, 768)
(5000,) int64
Updating cache!
Query Num 1





100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 117.56it/s][A[A[A


Elapsed Time: 0.0s, Elapsed Time per query: 12.1ms
Retrieved top 1000 for 32_3
Query Num 1





  0%|                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s][A[A[A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:41<00:00, 41.36s/it][A[A[A


Elapsed Time: 41.4s, Elapsed Time per query: 41367.3ms
Finished retrieving in the big index!
Check this when update happens: [2269518, 795750, 4416645, 6226326, 13143982, 7108432, 1432514, 5511406, 2174086, 673356] [2269518, 2269518, 795750, 795750, 4416645, 4416645, 6226326, 6226326, 13143982, 13143982] [2269518, 795750, 4416645, 6226326, 13143982, 7108432, 1432514, 5511406, 2174086, 673356]
finished with qid:  32_3

Processing qid: 32_4 12
More queries in cache, rb_hat di 32_4  and  32_1  is  0.36956453
Query Num 1





100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 117.94it/s][A[A[A


Elapsed Time: 0.0s, Elapsed Time per query: 11.9ms
Retrieved top 1000 for 32_4
Query Num 1





  0%|                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s][A[A[A


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:45<00:00, 45.94s/it][A[A[A


Elapsed Time: 45.9s, Elapsed Time per query: 45941.1ms
Finished retrieving in the big index!
Check this when update happens: [3442961, 3594051, 5791550, 3506104, 5791552, 3559434, 7755146, 24040080, 3594048, 13257] [3442961, 3442961, 3594051, 3594051, 5791550, 5791550, 3506104, 3506104, 5791552, 5791552] [2269518, 795750, 4416645, 6226326, 13143982, 7108432, 1432514, 5511406, 2174086, 673356]
finished with qid:  32_4

Processing qid: 32_5 13
More queries in cache, rb_hat di 32_5  and  32_1  is  0.025249958
Query Num 1





100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 53.08it/s][A[A[A


Elapsed Time: 0.0s, Elapsed Time per query: 22.7ms
Retrieved top 1000 for 32_5
Query Num 1





  0%|                                                                                                                                                                                                                                                | 0/1 [00:00<?, ?it/s][A[A[A

In [85]:
with open('../data/star-ranking/approximated-coverage-star-L2-ranking-top1000-cache-top'+str(topk)+'_with_update.tsv', 'w+') as fout:
    for i in coverage2:
        # print(coverage2.keys())
        if i in cache_update_with_qid:
            fout.write(str(i)+"\t"+str(coverage2[i])+"\t"+str(coverage1[i])+"\t"+str(rb_hat_dict[i])+" UPDATE \n")
        else:
            fout.write(str(i)+"\t"+str(coverage2[i])+"\t"+str(coverage1[i])+"\t"+str(rb_hat_dict[i])+"\n")

In [60]:
print(len(cache_update_with_qid))
cache_update_with_qid

67


['31_1',
 '31_3',
 '32_1',
 '32_3',
 '32_6',
 '32_8',
 '33_1',
 '33_10',
 '34_1',
 '37_1',
 '37_6',
 '37_9',
 '40_1',
 '40_5',
 '40_7',
 '40_9',
 '49_1',
 '49_3',
 '50_1',
 '50_2',
 '50_5',
 '50_7',
 '54_1',
 '54_3',
 '54_8',
 '56_1',
 '56_7',
 '58_1',
 '58_7',
 '59_1',
 '59_3',
 '59_5',
 '59_7',
 '61_1',
 '61_6',
 '61_8',
 '67_1',
 '67_4',
 '67_5',
 '67_10',
 '67_11',
 '68_1',
 '68_4',
 '68_5',
 '68_7',
 '68_10',
 '69_1',
 '69_3',
 '69_5',
 '75_1',
 '75_4',
 '75_7',
 '77_1',
 '77_4',
 '77_5',
 '77_6',
 '77_8',
 '77_10',
 '78_1',
 '78_7',
 '78_10',
 '79_1',
 '79_2',
 '79_3',
 '79_4',
 '79_6',
 '79_9']

In [61]:
len(results_list)

194000

In [62]:
# convert ids to original
with open("/data3/muntean/conversational-cache/data/star-ranking/CAST-manual-queries-star-L2-ranking-top1000-cache-top"+str(topk)+"-with-update.tsv", 'w') as outputfile:
    for (qid, pid, idx) in results_list:
        
        new_qid = offset2qid[qid]
        orig_qid = newqid2qid_dict[new_qid]
        
        new_pid = offset2pid[pid]
        orig_pid = newpid2pid_dict[new_pid]
        
        outputfile.write(f"{orig_qid}\t{orig_pid}\t{idx}\n")

# Eval results

In [64]:
# import pyterrier as pt
# pt.init()

In [65]:
qrel_path = "../data/CAST_qrels/qrels-docs.2019.txt"
qrels_df = pd.read_csv(qrel_path, delimiter=" ", header=None)
qrels_df[[3]] = qrels_df[[3]].astype(int)
qrels_df = qrels_df.drop([1], axis=1)
qrels_df.columns=["qid", "docno", "label"]
qrels = qrels_df

In [66]:
topics_path='../data/CAST-2019/test_manual_utterance.tsv' #manual

topics_df = pd.read_csv(topics_path, delimiter="\t", header=None)
topics_df.columns=["qid", "query"]
topics = topics_df
topics.head()

Unnamed: 0,qid,query
0,31_1,What is throat cancer?
1,31_2,Is throat cancer treatable?
2,31_3,Tell me about lung cancer.
3,31_4,What are lung cancer's symptoms?
4,31_5,Can lung cancer spread to the throat?


In [67]:
results_path = "../data/star-ranking/CAST-manual-queries-star-L2-ranking-top1000-cache-top"+str(topk)+"-with-update.tsv"
results_df = pd.read_csv(results_path, delimiter="\t", header=None)
results_df[3] = 1000-results_df[2]
results_df.columns=["qid", "docno", "rank", "score"]
results_df.head()
# Results produced by the transformers must have “qid”, “docno”, “score”, “rank” columns.

Unnamed: 0,qid,docno,rank,score
0,31_1,MARCO_3878347,1,999
1,31_1,MARCO_789620,2,998
2,31_1,MARCO_291003,3,997
3,31_1,MARCO_5625372,4,996
4,31_1,MARCO_2954451,5,995


In [68]:
%%time
pt.Experiment([results_df], topics, qrels, names=["STAR"], 
              eval_metrics=["map", "recip_rank", "recall_200", "P_3", "P_1", "ndcg_cut_3"])

CPU times: user 421 ms, sys: 8.06 ms, total: 429 ms
Wall time: 428 ms


Unnamed: 0,name,map,recip_rank,recall_200,P_3,P_1,ndcg_cut_3
0,STAR,0.18767,0.636636,0.430454,0.489403,0.49711,0.373501


In [65]:
# %%time
# res_per_query = pt.Experiment([results_df], topics, qrels, names=["STAR"], 
#               eval_metrics=["map", "recip_rank", "recall_200", "P_3", "P_1", "ndcg_cut_3"], perquery=True)
# res_per_query