In [1]:
import os
import sys
# Make the parent directory importable and the working directory. The check only looks at the
# last sys.path entry; it keeps a plain re-run of this cell from appending "../" and chdir-ing again.
if sys.path[-1] != "../":
    sys.path.append("../")
    os.chdir("../")

import re
import numpy as np
import pandas as pd
# import seaborn as sns  # NOTE(review): seaborn is disabled; any cell using `sns` needs it re-enabled
from IPython.display import display
from random import sample
from transformers import AutoModel, AutoTokenizer, T5ForConditionalGeneration

import torch
import torch.nn.functional as F
from models.AutoModel import AutoModel as AM
# NOTE(review): later cells use names that are not imported explicitly in this cell
# (e.g. Config, prepare_data, load_pickle, QueryDataset, TextDataset, WordSetIndex, defaultdict,
# tqdm, makedirs, mean_len); presumably they are re-exported by these star imports — confirm.
from utils.util import *
from utils.index import *
from utils.data import *
from hydra import initialize, compose

# Compose the hydra "_example" config with the dataset overrides and load it into the project Config.
overrides = [
    "base=NQ320k",
    # "base=MS300k",
    # "++plm=t5",
]
config = Config()
with initialize(version_base=None, config_path="../data/config/"):
    hydra_config = compose(config_name="_example", overrides=overrides)
    config._from_hydra(hydra_config)

# Build the text / query loaders and keep direct handles on their datasets.
loaders = prepare_data(config)
loader_text, loader_query = loaders["text"], loaders["query"]
text_dataset, query_dataset = loader_text.dataset, loader_query.dataset

# train_dataset = prepare_train_data(config, loader_text.dataset)
# train_query_dataset = train_dataset.query_datasets[0]

  from .autonotebook import tqdm as notebook_tqdm


[2023-10-09 06:36:33,593] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


[2023-10-09 06:36:35,968] INFO (Config) setting seed to 42...
[2023-10-09 06:36:35,974] INFO (Config) setting PLM to t5...
[2023-10-09 06:36:36,150] INFO (Config) Config: {'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'batch_size': 2, 'bf16': False, 'cache_root': 'data/cache/NQ320k', 'data_format': 'memmap', 'data_root': '/share/peitian/Data/AutoTSG', 'dataset': 'NQ320k', 'debug': False, 'deepspeed': None, 'device': 0, 'distill_src': 'none', 'early_stop_patience': 5, 'enable_all_gather': True, 'enable_distill': False, 'enable_inbatch_negative': True, 'epoch': 20, 'eval_batch_size': 2, 'eval_delay': 0, 'eval_flops': False, 'eval_metric': ['mrr', 'recall'], 'eval_metric_cutoff': [1, 5, 10, 100, 1000], 'eval_mode': 'retrieve', 'eval_posting_length': False, 'eval_set': 'dev', 'eval_step': '1e', 'fp16': False, 'grad_accum_step': 1, 'hits': 1000, 'index_shard': 32, 'index_thread': 10, 'index_type': 'invvec', 'learning_rate': 3e-06, 'load_ckpt': None, 'load_encode': False, 'l

In [12]:
# Load the static document codes of one code configuration and find documents that share a code.
code_type = "term"
# code_type = "2gram"
# code_type = "id"
code_tokenizer = "t5"
code_length = 34

t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))

codes_path = f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"
# one row per document; .copy() detaches the rows from the read-only memmap
text_codes = np.memmap(codes_path, mode="r", dtype=np.int32).reshape(len(text_dataset), -1).copy()

# trie = TrieIndex(save_dir=f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}")
# trie.load()

# wordset = WordSetIndex(save_dir=f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}", sep_token_id=6)
# wordset.fit(None)

df = pd.DataFrame(text_codes)
# one row per distinct code plus a "size" column with its multiplicity, most frequent first
duplicates = (
    df.groupby(df.columns.tolist(), as_index=False)
    .size()
    .sort_values("size", ascending=False)
    .reset_index(drop=True)
)

# row indices of every document whose code occurs more than once
dup = df.duplicated(keep=False).to_numpy()
dup_indices = np.flatnonzero(dup)
len(dup_indices)

4

In [13]:
# Decode the codes and leading text tokens of a few documents, then look at the most duplicated code.
# indices = random.sample(range(len(text_dataset)), 5)
# indices = range(10)
indices = [0]
# indices = dup_indices
text_code = text_codes[indices]      # fancy indexing returns a copy, so editing it is safe
text_code[text_code == -1] = 0       # -1 slots -> 0 so the tokenizer can decode them
# display(text_code)
display(t.batch_decode(text_code))
text_input_ids = np.array(text_dataset[indices]["text"]["input_ids"])
display(t.batch_decode(text_input_ids[:, :100]))

# row 0 of `duplicates` is the most frequent code; its last column is the group size, not a token
most_dup_target = duplicates.iloc[0].to_numpy()[:-1]
most_dup_idx = np.flatnonzero((text_codes == most_dup_target).all(-1))
most_dup_code = text_codes[most_dup_idx]
most_dup_code[most_dup_code == -1] = 0
most_dup_text = np.array(text_dataset[most_dup_idx]["text"]["input_ids"])[:, :code_length + 5]
t.batch_decode(most_dup_code), t.batch_decode(most_dup_text)

['<pad> email, marketing, mail, sending, sent, messages, advertising, campaigns, hide, customer, using, triggered, ads, online, purchase,</s><pad>']

['Email marketing Email marketing is the act of sending a commercial message, typically to a group of people, using email. In its broadest sense, every email sent to a potential or current customer could be considered email marketing. It usually involves using email to send advertisements, request business, or solicit sales or donations, and is meant to build loyalty, trust, or brand awareness. Marketing emails can be sent to a purchased lead list']

(['<pad> 1863, suspension, act, suspend, writ, civil, habeas, lincoln, proclamation, president,</s>',
  '<pad> 1863, suspension, act, suspend, writ, civil, habeas, lincoln, proclamation, president,</s>'],
 ['Habeas corpus suspension Act 1863 The Habeas Corpus Suspension Act, 12 Stat. 755 ( 1863 ), entitled An Act',
  'Habeas corpus suspension Act ( 1863 ) The Habeas Corpus Suspension Act, 12 Stat. 755 ( 1863 ),'])

In [None]:
# check query_codes: for every document, compare its static code with the codes a model generated
# for each of its queries, count exact matches, and interactively show queries where none matches.
model = "weight-bs200-3-greedy-sample-tau5"
query_sets = ["train"]
dyn_text_codes = []
qrels = []
query_datasets = []
for query_set in query_sets:
    qrel = load_pickle(f"data/cache/{config.dataset}/dataset/query/{query_set}/qrels.pkl")
    # generated codes for each qrel row: (len(qrel), n_codes, code_length)
    dyn_text_codes.append(
        np.memmap(
            f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/{model}/{query_set}/codes.mmp",
            mode="r",
            dtype=np.int32
        ).reshape(len(qrel), -1, code_length).copy()
    )
    qrels.append(qrel)
    query_datasets.append(QueryDataset(config, query_set))

# text_idx -> [[(qrel_idx, query_idx), ...] for each query set]
tidx_2_qrel_query_idx_pair = defaultdict(lambda: [[] for _ in range(len(query_sets))])
for query_set_idx, qrel in enumerate(qrels):
    for j, rel in enumerate(qrel):
        # rel[0] is the query index and rel[1] the text index
        tidx_2_qrel_query_idx_pair[rel[1]][query_set_idx].append((j, rel[0]))
# new_docs = {}
# for k,v in tidx_2_qrel_query_idx_pair.items():
#     has_common = 0
#     for x in v:
#         has_common += len(x) > 0
#     if has_common == len(query_sets):
#         new_docs[k] = v
# tidx_2_qrel_query_idx_pair = new_docs
same_count = 0   # number of generated codes identical to their document's static code
demo = 1         # set to 0 to only count, without printing / prompting
stop = False     # set when the user types "s" at the prompt

for tidx, qindices in tidx_2_qrel_query_idx_pair.items():
    # np.where builds a new array: text_codes[tidx] is a view, and replacing -1 in place would
    # rewrite the shared `text_codes` that other cells read
    text_code = np.where(text_codes[tidx] == -1, 0, text_codes[tidx])
    # print(f'******Text Code   ({tidx}):******\n{t.decode(text_code)}')

    for query_set_idx, query_set in enumerate(query_sets):
        # local name so the global dev `query_dataset` used by other cells is not overwritten
        cur_query_dataset = query_datasets[query_set_idx]

        for qrel_idx, qidx in qindices[query_set_idx]:
            raw_dyn_code = dyn_text_codes[query_set_idx][qrel_idx]
            dyn_text_code = np.where(raw_dyn_code == -1, 0, raw_dyn_code)

            # one flag per generated code (assumes text_codes rows also have code_length slots);
            # the query counts as "same" if ANY generated code equals the static one
            matches = (dyn_text_code == text_code).all(-1)
            same_count += int(matches.sum())

            if not matches.any() and demo:
                print(f'\n******{query_set} Query ({qidx}):******\n{t.decode(cur_query_dataset[qidx]["query"]["input_ids"])}')
                print(f'******Text Code   ({tidx}):******\n{t.decode(text_code)}')
                print(f"******Sorted Text Code for {query_set}:******\n{t.batch_decode(dyn_text_code, skip_special_tokens=True)}")

                x = input()
                if x == "s":
                    stop = True
                    break
        if stop:
            break
    if stop:
        break

same_count

In [5]:
# cases that either res1 or res2 succeeds
#
# Every artifact is read from the dataset the text/query datasets were built from, so document and
# query indices in results, positives, qrels and codes all refer to the same collection.
dataset = config.dataset

model1_result = load_pickle(f"data/cache/{dataset}/retrieve/AutoTSG/dev/retrieval_result.pkl")
model2_result = load_pickle(f"data/cache/{dataset}/retrieve/BM25/dev/retrieval_result.pkl")

code_type1 = "words_comma_plus_stem"
code_type2 = "words_comma_plus_stem"
code_tokenizer = "t5"
code_length = 34

text_codes1 = np.memmap(
    f"data/cache/{dataset}/codes/{code_type1}/{code_tokenizer}/{code_length}/codes.mmp",
    mode="r",
    dtype=np.int32
).reshape(len(text_dataset), -1).copy()
text_codes2 = np.memmap(
    f"data/cache/{dataset}/codes/{code_type2}/{code_tokenizer}/{code_length}/codes.mmp",
    mode="r",
    dtype=np.int32
).reshape(len(text_dataset), -1).copy()

t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))

query_dataset = QueryDataset(config)

positives = load_pickle(f"data/cache/{dataset}/dataset/query/dev/positives.pkl")
# first positive doc of each dev query -> that dev query index
positive_docs = {}
for k, v in positives.items():
    positive = v[0]
    positive_docs[positive] = k

# train queries of every document: tidx -> [(query_set_idx, qidx), ...]
query_sets = ["train"]
query_datasets = []
queries = defaultdict(list)
for query_set_idx, query_set in enumerate(query_sets):
    qrels = load_pickle(f"data/cache/{dataset}/dataset/query/{query_set}/qrels.pkl")
    for x in qrels:
        qidx = x[-2]
        tidx = x[-1]
        queries[tidx].append((query_set_idx, qidx))
    query_datasets.append(QueryDataset(config, query_set))

for qidx, positive in positives.items():
    # if qidx != 3315:
    #     continue
    positive = positive[0]
    model1_res = model1_result[qidx]
    model2_res = model2_result[qidx]
    query = query_dataset[qidx]["query"]["input_ids"]
    showcase = False
    if positive in model1_res and positive not in model2_res:
    # if (positive in model1_res and model1_res.index(positive) < 5) and (positive not in model2_res or model2_res.index(positive) > 5):
        text_code1 = text_codes1[positive]
        text_code2 = text_codes2[positive]
        # the losing model's top hit is shown as the false positive
        false_pos_code = text_codes2[model2_res[0]]
        string = f"Model1 wins: {qidx, positive}"
        showcase = True

    elif positive in model2_res and positive not in model1_res:
        text_code1 = text_codes1[positive]
        text_code2 = text_codes2[positive]
        false_pos_code = text_codes1[model1_res[0]]
        string = f"Model2 wins: {qidx, positive}"
        showcase = True

    if showcase:
        print(f"{string:*^50}")
        print(f"{'Query': <50}: {t.decode(query, skip_special_tokens=True)}")
        print(f"{f'Target Anchor ({code_type1})': <50}: {t.decode(text_code1[text_code1 != -1], skip_special_tokens=True)}")
        # print(f"{f'Target Anchor ({code_type2})': <50}: {t.decode(text_code2[text_code2 != -1], skip_special_tokens=True)}")
        print(f"{'False Positive Anchor': <50}: {t.decode(false_pos_code[false_pos_code != -1], skip_special_tokens=True)}")
        # print(f"{f'Target Title ({code_type1})': <50}: {t.decode(text_dataset[positive]['text']['input_ids'][:20], skip_special_tokens=True)}")
        # print(f"{f'Negative Title  ({code_type2})': <50}: {t.batch_decode(text_dataset[model2_res]['text']['input_ids'][:, :20], skip_special_tokens=True)}")

        all_q = [[] for _ in query_sets]
        for query_set_idx, query_idx in queries[positive]:
            query = t.decode(query_datasets[query_set_idx][query_idx]["query"]["input_ids"], skip_special_tokens=True)
            all_q[query_set_idx].append(query)
        for j, query in enumerate(all_q):
            if len(query):
                print(f"{query_sets[j] + ' Queries': <50}: {query}")
        x = input()
        if x == "s":
            break

[2023-08-09 02:44:40,033] INFO (Dataset) initializing MSMARCO-passage memmap Query dev dataset...
[2023-08-09 02:44:41,091] INFO (Dataset) initializing MSMARCO-passage memmap Query train dataset...


***********Model2 wins: (259, 7067274)************
Query                                             : ____________________ is considered the father of modern medicine.
Target Anchor (words_comma_plus_stem)             : medicine, father, modern, hippocrates, considered, true, punishment, believe, false, weegy, illness, because,
False Positive Anchor                             : renaissance, civilisation, rebirth, western, european, 15th, ad, key, changes, century,
***********Model2 wins: (1523, 7067796)***********
Query                                             : botulinum definition
Target Anchor (words_comma_plus_stem)             : toxin, botulinum, botulism, neurotoxin, powerful, clostridium, produced, definition, medical,
False Positive Anchor                             : fungal, toxin, fungi, toxic, fungus, weapons, botulinum, dangerous, aflatoxins, none,
***********Model2 wins: (2319, 7067891)***********
Query                                             : do physicians pay 

In [7]:
# cases for the failed dev queries and the corresponding train queries for a particular model

# Results and positives come from the same dataset as the qrels, query datasets and text codes below,
# so every document index refers to one collection.
dataset = config.dataset
retrieval_result = load_pickle(f"data/cache/{dataset}/retrieve/AutoTSG/dev/retrieval_result.pkl")
positives = load_pickle(f"data/cache/{dataset}/dataset/query/dev/positives.pkl")
query_dataset = QueryDataset(config)

miss_queries = {}
miss_docs = {}

success_queries = {}
success_docs = {}

# split dev queries by whether their first positive appears in the retrieval result
for k, v in positives.items():
    positive = v[0]
    res = retrieval_result[k]
    if positive not in res:
        miss_queries[k] = positive
        miss_docs[positive] = k
    else:
        success_queries[k] = positive
        success_docs[positive] = k

query_sets = ["train"]
overlap_miss = defaultdict(list)
overlap_success = defaultdict(list)

query_datasets = []

# collect the train queries that target a missed / succeeded dev positive
for query_set_idx, query_set in enumerate(query_sets):
    qrels = load_pickle(f"data/cache/{dataset}/dataset/query/{query_set}/qrels.pkl")
    for x in qrels:
        qidx = x[-2]
        tidx = x[-1]
        if tidx in miss_docs:
            overlap_miss[tidx].append((query_set_idx, qidx))
        elif tidx in success_docs:
            overlap_success[tidx].append((query_set_idx, qidx))
    query_datasets.append(QueryDataset(config, query_set))

print(f"Mean query number for missed queries: {mean_len(overlap_miss.values())}\nMean query number for succeeded queries: {mean_len(overlap_success.values())}")

# NOTE(review): `text_codes` and `t` are not defined in this cell; they come from the code-loading cell
for k, v in miss_queries.items():
    positive = v

    query = query_dataset[k]["query"]["input_ids"]
    text_code = text_codes[positive]
    print(f"\n{f'Qidx: {k} Tidx: {positive}':*^40}")
    print(f"{'Query': <20}: {t.decode(query, skip_special_tokens=True)}")
    print(f"{'Anchor': <20}: {t.decode(text_code[text_code != -1], skip_special_tokens=True)}")
    print(f"{'Text': <20}: {t.decode(text_dataset[positive]['text']['input_ids'][:100], skip_special_tokens=True)}")
    if positive in overlap_miss:
        qindices = overlap_miss[positive]
        # local name so the global `queries` mapping built by another cell is not overwritten
        rel_queries = [[] for _ in query_sets]
        for query_set_idx, query_idx in qindices:
            rel_queries[query_set_idx].append(t.decode(query_datasets[query_set_idx][query_idx]['query']['input_ids'], skip_special_tokens=True))
        for j, query in enumerate(rel_queries):
            if len(query):
                print(f"{query_sets[j] + ' Queries': <20}: {query}")
    x = input()
    if x == "s":
        break

[2023-08-06 05:50:00,647] INFO (Dataset) initializing MSMARCO-passage memmap Query dev dataset...
[2023-08-06 05:50:00,990] INFO (Dataset) initializing MSMARCO-passage memmap Query train dataset...


Mean query number for missed queries: 1.424
Mean query number for succeeded queries: 1.1515151515151516

********Qidx: 1284 Tidx: 7067032********
Query               : how many years did william bradford serve as governor of plymouth colony?
Anchor              : bradford, plymouth, william, governor, 1657, 1590, colony, leiden, mayflower, leader,
Text                : http://en.wikipedia.org/wiki/William_Bradford_(Plymouth_Colony_governor) William Bradford (c.1590 â 1657) was an English Separatist leader in Leiden, Holland and in Plymouth Colony was a signatory to the Mayflower Compact. He served as Plymouth Colony Governor five times covering about thirty years between 1621 and 1657.

********Qidx: 3650 Tidx: 7067056********
Query               : define preventive
Anchor              : preventive, deter, aggression, adjective, obstacle, military, hindering, comparative, most, carried, more, superlative, acting,
Text                : Adjective[edit] preventive â(comparative more preve

In [None]:
# false prune plot: for dev queries whose positive was not retrieved, find the first decoding step
# (among `cutoffs`) at which every retrieved document's code differs from the positive's code, and
# count how many of those retrieved codes still shared a non-special word with the positive up to that step.
model = "BOW_qg"
retrieval_result = load_pickle(f"data/cache/{config.dataset}/retrieve/{model}/dev/bm25-trie.pkl")
positives = load_pickle(f"data/cache/{config.dataset}/dataset/query/dev/positives.pkl")

# code_type = "words_comma_plus_stem"
# code_type = "title_comma-first"
code_type = "bm25_comma"
code_length = 34
code_tokenizer = "t5"
wordset = WordSetIndex(save_dir=f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}", sep_token_id=6)
wordset.fit(None)

# dev queries whose first positive is not in the retrieval result
all_missed = sum(positives[k][0] not in v for k, v in retrieval_result.items())
print(all_missed)

cutoffs = [1, 2, 3, 4, 5]
counts = [{"all_missed": 0, "false_prune": 0, "wordset_match": 0, "decode_step": i} for i in cutoffs]

docs = wordset.docs
# token ids ignored when looking for shared words (presumably pad=0, eos=1, -1=empty slot)
stopwords = np.array([0, 1, -1])
skip_qindices = set()  # queries already attributed to an earlier cutoff

for i, cutoff in enumerate(cutoffs):
    false_prune_count = 0
    wordset_match_count = 0

    for qidx, res in retrieval_result.items():
        if qidx in skip_qindices:
            continue

        positive = positives[qidx][0]
        if positive in res:
            continue

        gt_doc = docs[positive]
        generated_doc = docs[res]
        # every retrieved document disagrees with the positive at this position
        diff_at_cutoff = gt_doc[cutoff] != generated_doc[:, cutoff]
        if not diff_at_cutoff.all():
            continue

        false_prune_count += 1
        skip_qindices.add(qidx)
        for gdoc in generated_doc:
            overlap = np.intersect1d(gt_doc, gdoc[:cutoff])
            overlap = overlap[~np.isin(overlap, stopwords)]
            if len(overlap):
                wordset_match_count += 1
                break

    # queries falsely pruned at the previous step are no longer counted as missed
    if i > 0:
        all_missed = all_missed - counts[i - 1]["false_prune"]
    counts[i]["all_missed"] = all_missed
    counts[i]["false_prune"] = false_prune_count
    counts[i]["wordset_match"] = wordset_match_count

print(counts)
data = pd.DataFrame(counts).drop(columns=["all_missed"])
# seaborn is not imported in this notebook; pandas' bar plot draws the same grouped bars per step
ax = data.plot.bar(x="decode_step", y=["false_prune", "wordset_match"], rot=0)
ax.set_ylabel("#Error Query")
ax.set_xlabel("Decode Step");

In [None]:
# check duplicated codes and corresponding train queries

# documents sharing the most frequent code (row 0 of `duplicates`; its last column is the group size)
dup_index = np.argwhere((text_codes == duplicates.loc[0].to_numpy()[:-1]).all(-1))[:, 0]
# doc index -> position in dup_index, for O(1) membership tests and lookups
dup_pos = {int(tidx): pos for pos, tidx in enumerate(dup_index)}

# `train_dataset` is only created by a commented-out line of the setup cell, so build it on demand
# NOTE(review): assumes prepare_train_data is provided by the utils star imports
if "train_dataset" not in globals():
    train_dataset = prepare_train_data(config, text_dataset)

# get train queries (the last train query seen for each duplicated document is kept)
queries = [None for _ in range(len(dup_index))]
for x in train_dataset.qrels:
    pos = dup_pos.get(int(x[-1]))
    if pos is not None:
        queries[pos] = t.decode(train_dataset.query_datasets[0][x[-2]]["query"]["input_ids"], skip_special_tokens=True)
print(len(dup_index))

# NOTE(review): `positives` / `query_dataset` are not defined here; they come from whichever dev cell ran last
j = 0
with open("/share/project/peitian/Data/Adon/Top300k/collection.tsv") as f:
    for i, line in enumerate(f):
        # line number i is treated as the document index
        if i not in dup_pos:
            continue
        for k, v in positives.items():
            # dev queries whose positives include the current duplicated document
            if i in v:
                print(f"{'Dev Query':*^20}\n{t.decode(query_dataset[k]['query']['input_ids'], skip_special_tokens=True)}")
        print(f"{'Train Query':*^20}\n{queries[j]}")
        print(line)
        j += 1
        x = input()
        if x == "s":
            break

In [None]:
# cases that either res1 or res2 succeeds (different text_codes)

# one dataset for results, positives, qrels and codes so their indices refer to the same collection
dataset = config.dataset
# defined here instead of inherited from whichever cell last set it
code_tokenizer = "t5"

model1_result = load_pickle(f"data/cache/{dataset}/retrieve/BOW_doct5-miss-doc/dev/best.pkl")
model2_result = load_pickle(f"data/cache/{dataset}/retrieve/BOW_doct5-miss-doc/dev/50.pkl")

t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))

query_dataset = QueryDataset(config)

positives = load_pickle(f"data/cache/{dataset}/dataset/query/dev/positives.pkl")
# first positive doc of each dev query -> that dev query index
positive_docs = {}
for k, v in positives.items():
    positive = v[0]
    positive_docs[positive] = k

# queries of every document across the query sets: tidx -> [(query_set_idx, qidx), ...]
query_sets = ["train-sub", "doct5-miss-sub"]
query_datasets = []
queries = defaultdict(list)
for query_set_idx, query_set in enumerate(query_sets):
    qrels = load_pickle(f"data/cache/{dataset}/dataset/query/{query_set}/qrels.pkl")
    for x in qrels:
        qidx = x[-2]
        tidx = x[-1]
        queries[tidx].append((query_set_idx, qidx))
    query_datasets.append(QueryDataset(config, query_set))

code_type1 = "words_comma"
code_length1 = 34
text_codes1 = np.memmap(
    f"data/cache/{dataset}/codes/{code_type1}/{code_tokenizer}/{code_length1}/codes.mmp",
    mode="r",
    dtype=np.int32
).reshape(len(text_dataset), -1).copy()

code_type2 = "words_comma"
code_length2 = 50
text_codes2 = np.memmap(
    f"data/cache/{dataset}/codes/{code_type2}/{code_tokenizer}/{code_length2}/codes.mmp",
    mode="r",
    dtype=np.int32
).reshape(len(text_dataset), -1).copy()

for qidx, positive in positives.items():
    positive = positive[0]
    model1_res = model1_result[qidx]
    model2_res = model2_result[qidx]
    query = query_dataset[qidx]["query"]["input_ids"]
    showcase = False
    if positive in model1_res and positive not in model2_res:
        # winner's code for the positive, loser's top hit as the false positive
        text_code = text_codes1[positive]
        false_pos_code = text_codes2[model2_res[0]]
        string = f"Model1 wins: {qidx, positive}"
        showcase = True

    elif positive in model2_res and positive not in model1_res:
        text_code = text_codes2[positive]
        false_pos_code = text_codes1[model1_res[0]]
        string = f"Model2 wins: {qidx, positive}"
        showcase = True

    if showcase:
        print(f"{string:*^50}")
        print(f"{'Query': <30}: {t.decode(query, skip_special_tokens=True)}")
        print(f"{'Target Anchor': <30}: {t.decode(text_code[text_code != -1], skip_special_tokens=True)}")
        string = f"False Positive Anchor"
        print(f"{string: <30}: {t.decode(false_pos_code[false_pos_code != -1], skip_special_tokens=True)}")

        all_q = [[] for _ in query_sets]
        for query_set_idx, query_idx in queries[positive]:
            query = t.decode(query_datasets[query_set_idx][query_idx]["query"]["input_ids"], skip_special_tokens=True)
            all_q[query_set_idx].append(query)
        for j, query in enumerate(all_q):
            if len(query):
                print(f"{query_sets[j] + ' Queries': <30}: {query}")
        x = input()
        if x == "s":
            break

In [None]:
# check the document frequency of each code position: for position `pos`, sum the inverted-list
# lengths of the words found there and divide by the number of documents with that slot filled.
code_type = "bm25_comma-sample"
code_tokenizer = "t5"
code_length = 26
wordset = WordSetIndex(save_dir=f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}", sep_token_id=6)
wordset.fit(None)

valid_mask = wordset.docs != -1
max_word_num = valid_mask.sum(-1).max()
doc_nums = np.zeros(max_word_num, dtype=np.int32)
valid_nums = np.zeros(max_word_num, dtype=np.int32)

for pos in range(max_word_num):
    column = wordset.docs[:, pos]
    present = column != -1
    valid_nums[pos] = present.sum()
    doc_nums[pos] = sum(len(postings) for postings in wordset.inverted_lists[column[present]])
wordset.inverse_vocab.shape, doc_nums / valid_nums

In [None]:
# create new query set based on an existing one: keep only the qrels (and their queries) whose
# document never appears as a train positive.

# dataset = "Top300k-filter"
dataset = "NQ"
ori_query_set = "nci"
query_set = "nci-miss"
k = 3  # only used by the commented-out "keep the first k" variant below

try:
    qid2idx = load_pickle(f"data/cache/{dataset}/dataset/query/{ori_query_set}/id2index.pkl")
except FileNotFoundError:
    # fall back to the row order of the raw query file
    with open(f"{config.data_root}/{dataset}/queries.{ori_query_set}.tsv") as f:
        qid2idx = {
            line.split("\t")[0]: qidx
            for qidx, line in enumerate(tqdm(f, desc="Collecting qid2idx"))
        }

tid2idx = load_pickle(f"data/cache/{dataset}/dataset/text/id2index.pkl")

qindices = []
tid2qrels = defaultdict(list)

train_positives = load_pickle(f"data/cache/{dataset}/dataset/query/train/positives.pkl")
train_positives = {pos_list[0] for pos_list in train_positives.values()}
# NOTE(review): len(text_dataset) follows config.dataset, which may differ from `dataset` above
miss_docs = set(range(len(text_dataset))) - train_positives
print(f"number of documents missing in training set: {len(miss_docs)}")

data_dir = f"{config.data_root}/{dataset}"
with open(f"{data_dir}/qrels.{ori_query_set}.tsv") as ori_qrel_file, open(f"{data_dir}/qrels.{query_set}.tsv", "w") as qrel_file, open(f"{data_dir}/queries.{ori_query_set}.tsv") as ori_query_file, open(f"{data_dir}/queries.{query_set}.tsv", "w") as query_file:
    for line in ori_qrel_file:
        qid, _, tid, _ = line.strip().split("\t")
        qidx = qid2idx[qid]

        # keep only qrels pointing at documents no train query covers
        if tid2idx[tid] in miss_docs:
            tid2qrels[tid].append(line)
            qindices.append(qidx)

        # # keep the first k elements
        # if len(tid2qrels[tid]) >= k:
        #     continue
        # else:
        #     tid2qrels[tid].append(line)
        #     qindices.append(qidx)

    qindices = set(qindices)
    query_file.writelines(line for row, line in enumerate(ori_query_file) if row in qindices)
    # qrels are written grouped by document, in first-seen document order
    qrel_file.writelines(line for doc_qrels in tid2qrels.values() for line in doc_qrels)

In [2]:
# create pseudo-queries from document: the selected text column(s) of each passage, cut to the
# first `query_length` space-separated words, become a query whose only relevant doc is that passage.
dataset = "MSMARCO-passage"
query_set = "doc"
text_col = [2]
query_length = 32

data_dir = f"{config.data_root}/{dataset}"
with open(f"{data_dir}/collection.tsv") as collection_file, open(f"{data_dir}/queries.{query_set}.tsv", "w") as query_file, open(f"{data_dir}/qrels.{query_set}.tsv", "w") as qrel_file:
    for tidx, line in enumerate(tqdm(collection_file)):
        fields = line.split("\t")
        tid = fields[0]
        selected = (
            field.strip()
            for col_idx, field in enumerate(fields)
            if col_idx in text_col and len(field) > 1
        )
        # maximum number of words
        words = " ".join(selected).split(" ")
        query = " ".join(words[:query_length])

        query_file.write(f"{tidx}\t{query}\n")
        qrel_file.write(f"{tidx}\t0\t{tid}\t1\n")

8841823it [01:04, 136960.90it/s]


In [None]:
# filter query sets with ANCE: keep a query (and its qrel) only when its qrel document is among
# the filter model's top-10 results for that query.
ori_query_set = "doct5-5"
query_set = "doct5-5-filter"

filter_model = "ANCE"
filter_results = load_pickle(f"data/cache/{config.dataset}/retrieve/{filter_model}/{ori_query_set}/retrieval_result.pkl")
filter_results = {qidx: hits[:10] for qidx, hits in filter_results.items()}

tid2idx = load_pickle(f"data/cache/{config.dataset}/dataset/text/id2index.pkl")
qid2idx = load_pickle(f"data/cache/{config.dataset}/dataset/query/{ori_query_set}/id2index.pkl")

base_dir = f"{config.data_root}/{config.dataset}"
with open(f"{base_dir}/qrels.{ori_query_set}.tsv") as ori_qrel_file, open(f"{base_dir}/queries.{ori_query_set}.tsv") as ori_query_file, open(f"{base_dir}/qrels.{query_set}.tsv", "w") as qrel_file, open(f"{base_dir}/queries.{query_set}.tsv", "w") as query_file:
    qindices = set()
    for line in ori_qrel_file:
        qid, _, tid, _ = line.strip().split("\t")
        qidx = qid2idx[qid]
        if tid2idx[tid] in filter_results[qidx]:
            qrel_file.write(line)
            qindices.add(qidx)
    for row, line in enumerate(ori_query_file):
        if row in qindices:
            query_file.write(line)

In [None]:
# filter MSMARCO Top300k: drop every document whose code repeats an earlier document's code,
# together with the train/dev qrels that point at it.

# filter_indices = []
# for dup_idx in dup_indices:
#     text = t.decode(text_dataset[dup_idx]["text"]["input_ids"], skip_special_tokens=True)
#     text_length = len(text.split(" "))
#     if text_length < 100:
#         filter_indices.append(dup_idx)

# df.duplicated() (keep="first") flags all but the first document of each code
# NOTE(review): `df` is built in an earlier cell — make sure it was built for Top300k
filter_indices = set(np.flatnonzero(df.duplicated().to_numpy()).tolist())

src_dir = "/share/project/peitian/Data/Adon/Top300k"
dst_dir = "/share/project/peitian/Data/Adon/Top300k-filter"
with open(f"{src_dir}/collection.tsv") as ori_collection, open(f"{src_dir}/qrels.train.tsv") as ori_train_qrels, open(f"{src_dir}/qrels.dev.tsv") as ori_dev_qrels, open(f"{dst_dir}/collection.tsv", "w") as collection, open(f"{dst_dir}/qrels.train.tsv", "w") as train_qrels, open(f"{dst_dir}/qrels.dev.tsv", "w") as dev_qrels:
    # shutil.copy("/share/project/peitian/Data/Adon/Top300k/queries.train.tsv", "/share/project/peitian/Data/Adon/Top300k-filter/queries.train.tsv")
    # shutil.copy("/share/project/peitian/Data/Adon/Top300k/queries.dev.tsv", "/share/project/peitian/Data/Adon/Top300k-filter/queries.dev.tsv")

    collection.writelines(line for row, line in enumerate(ori_collection) if row not in filter_indices)

    tid2idx = load_pickle("/share/project/peitian/Code/Uni-Retriever/src/data/cache/Top300k/dataset/text/id2index.pkl")
    for src_qrels, dst_qrels in ((ori_train_qrels, train_qrels), (ori_dev_qrels, dev_qrels)):
        for line in src_qrels:
            qid, _, tid, _ = line.strip().split()
            if tid2idx[tid] not in filter_indices:
                dst_qrels.write(line)

In [None]:
# convert yujia DSI code to my format: one row per document, slot 0 = 0, slots after the DSI code = -1.

tid2idx = load_pickle("/share/project/peitian/Code/Uni-Retriever/src/data/cache/Top300k-filter/dataset/text/id2index.pkl")

code_type = "DSI-semantic"
code_tokenizer = "t5"
code_length = 10

t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))
text_codes = np.memmap(
    makedirs(f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="w+",
    dtype=np.int32,
    shape=(len(text_dataset), code_length)
)
text_codes[:, 0] = 0
text_codes[:, 1:] = -1

count = 0
with open("/share/project/webbrain-zhouyujia/transfer/data/encoded_docid/t5_semantic_structured_top_300k.txt") as f:
    for line in tqdm(f):
        tid, code = line.strip().split()
        # drop the surrounding wrapper characters (presumably quotes/brackets) and normalise case
        tid = tid[1:-1].upper()
        if tid in tid2idx:
            count += 1
            code = [int(x) for x in code.split(",")]
            # write at the document's own row rather than a running counter, which silently
            # misassigns codes unless the file lists documents in exactly collection order
            text_codes[tid2idx[tid], 1:len(code) + 1] = code

# persist the w+ memmap now so later cells reading codes.mmp see the data
text_codes.flush()

In [None]:
# convert yujia ultron code to my format: one row per document, slot 0 = 0, slots after the code = -1.

dataset = "Rand300k-filter"
tid2idx = load_pickle(f"/share/project/peitian/Code/Uni-Retriever/src/data/cache/{dataset}/dataset/text/id2index.pkl")

code_type = "ultron"
code_tokenizer = "t5"
code_length = 34

# NOTE: this rebinds the global `config` / `text_dataset` to `dataset`; cells run afterwards see them
config = Config()
with initialize(version_base=None, config_path="../data/config/"):
    overrides = [
        f"base={dataset}",
    ]
    hydra_config = compose(config_name="_example", overrides=overrides)
    config._from_hydra(hydra_config)
text_dataset = TextDataset(config)
# t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))
text_codes = np.memmap(
    makedirs(f"data/cache/{dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="w+",
    dtype=np.int32,
    shape=(len(text_dataset), code_length)
)
text_codes[:, 0] = 0
text_codes[:, 1:] = -1

max_code_len = code_length - 1  # slot 0 is reserved
count = 0
with open("/share/project/webbrain-zhouyujia/transfer/data/encoded_docid/t5_url_title_rand_300k.txt") as f:
    for line in tqdm(f):
        tid, code = line.strip().split()
        # drop the surrounding wrapper characters (presumably quotes/brackets) and normalise case
        tid = tid[1:-1].upper()
        if tid in tid2idx:
            count += 1
            code = [int(x) for x in code.split(",")]
            if len(code) > max_code_len:
                # truncate to the available slots and force the last token to 1 (presumably t5 eos)
                code = code[:max_code_len]
                code[-1] = 1
            # write at the document's own row rather than a running counter, which silently
            # misassigns codes unless the file lists documents in exactly collection order
            text_codes[tid2idx[tid], 1:len(code) + 1] = code

# persist the w+ memmap now so later cells reading codes.mmp see the data
text_codes.flush()

In [None]:
# convert yujia doct5 to my format
# Each input line is "<wrapped tid>\t<generated query>". Every line gets a fresh qid (its
# line number); only queries whose text belongs to Top300k-filter are written, together
# with a qrel pointing at that text.

tid2idx = load_pickle("/share/project/peitian/Code/Uni-Retriever/src/data/cache/Top300k-filter/dataset/text/id2index.pkl")

with \
    open("/share/project/webbrain-zhouyujia/transfer/data/msmarco-data/fake_query_10_all.txt") as f, \
    open("/share/project/peitian/Data/Adon/Top300k-filter/queries.doct5.tsv", "w") as query_file, \
    open("/share/project/peitian/Data/Adon/Top300k-filter/qrels.doct5.tsv", "w") as qrel_file:
    for i, line in enumerate(tqdm(f)):
        # maxsplit=1 so a tab inside the generated query does not break unpacking
        tid, query = line.rstrip("\n").split("\t", 1)
        tid = tid[1:-1].upper()

        qid = str(i)
        if tid in tid2idx:
            # write the newline explicitly instead of relying on the input line having one
            query_file.write("\t".join([qid, query]) + "\n")
            qrel_file.write("\t".join([qid, "0", tid, "1"]) + "\n")

In [None]:
# merge ultron code
# Stack the Rand300k-filter and Top300k-filter ultron code tables into the MS600k table,
# Rand rows first (the same order in which the MS600k collection is concatenated).
code_type = "ultron"
code_tokenizer = "t5"
code_length = 34
text_codes_rand = np.memmap(
    makedirs(f"data/cache/Rand300k-filter/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="r",
    dtype=np.int32
).reshape(-1, code_length).copy()
text_codes_top = np.memmap(
    makedirs(f"data/cache/Top300k-filter/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="r",
    dtype=np.int32
).reshape(-1, code_length).copy()

print(text_codes_rand.shape, text_codes_top.shape)

config = Config()
with initialize(version_base=None, config_path="../data/config/"):
    overrides = [
        "base=MS600k",
    ]
    hydra_config = compose(config_name="_example", overrides=overrides)
    config._from_hydra(hydra_config)
text_dataset = TextDataset(config)

# check up front so a size mismatch is reported clearly instead of as a broadcast error
assert len(text_codes_rand) + len(text_codes_top) == len(text_dataset), (
    f"{len(text_codes_rand)} + {len(text_codes_top)} code rows != {len(text_dataset)} MS600k texts"
)

text_codes = np.memmap(
    makedirs(f"data/cache/MS600k/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="w+",
    dtype=np.int32,
    shape=(len(text_dataset), code_length)
)
# every row is overwritten below, so no 0 / -1 initialisation is needed
text_codes[:len(text_codes_rand)] = text_codes_rand
text_codes[len(text_codes_rand):] = text_codes_top
text_codes.flush()

In [None]:
# split code from MS600k
# Cut a code table computed on the merged MS600k corpus back into its two halves.
# Assumes MS600k rows are the Rand300k-filter texts followed by the Top300k-filter texts
# (the order the MS600k merge cells use); TODO confirm for tables built elsewhere.

# code_type = "ANCE_hier"
code_type = "NCI-bias"
code_tokenizer = "t5"
code_length = 10
# code_type_split = "ANCE_hier_600k"
code_type_split = "NCI_600k-bias"

# --- the full MS600k table, copied into memory ---
config = Config()
with initialize(version_base=None, config_path="../data/config/"):
    overrides = ["base=MS600k"]
    hydra_config = compose(config_name="_example", overrides=overrides)
    config._from_hydra(hydra_config)
text_dataset = TextDataset(config)

ms600k_codes_path = makedirs(f"data/cache/MS600k/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp")
text_codes = np.memmap(ms600k_codes_path, mode="r", dtype=np.int32)
text_codes = text_codes.reshape(text_dataset.text_num, code_length).copy()

# --- Rand300k-filter size (its config stays the active `config` afterwards) ---
config = Config()
with initialize(version_base=None, config_path="../data/config/"):
    overrides = ["base=Rand300k-filter"]
    hydra_config = compose(config_name="_example", overrides=overrides)
    config._from_hydra(hydra_config)
rand_text_dataset = TextDataset(config)

n_rand = rand_text_dataset.text_num
n_top = text_dataset.text_num - n_rand

text_codes_rand = np.memmap(
    makedirs(f"data/cache/Rand300k-filter/codes/{code_type_split}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="w+", dtype=np.int32, shape=(n_rand, code_length),
)
text_codes_top = np.memmap(
    makedirs(f"data/cache/Top300k-filter/codes/{code_type_split}/{code_tokenizer}/{code_length}/codes.mmp"),
    mode="w+", dtype=np.int32, shape=(n_top, code_length),
)

# leading rows -> Rand300k-filter, remaining rows -> Top300k-filter
text_codes_rand[:] = text_codes[:n_rand]
text_codes_top[:] = text_codes[n_rand:]

In [None]:
# split some queries
# Keep only the `query_set` qrels whose text is in the smaller `dataset` corpus, and the
# queries those qrels reference.
# NOTE(review): `config.data_root` comes from whichever config is currently active.

ori_dataset = "NQ"
dataset = "NQ-50k-seen"
query_set = "doct5"

tid2idx = load_pickle(f"data/cache/{dataset}/dataset/text/id2index.pkl")
qids = set()

with \
    open(f"{config.data_root}/{ori_dataset}/qrels.{query_set}.tsv") as ori_qrel_file, \
    open(f"{config.data_root}/{dataset}/qrels.{query_set}.tsv", "w") as qrel_file, \
    open(f"{config.data_root}/{ori_dataset}/queries.{query_set}.tsv") as ori_query_file, \
    open(f"{config.data_root}/{dataset}/queries.{query_set}.tsv", "w") as query_file:
    for line in ori_qrel_file:
        qid, _, tid, _ = line.strip().split("\t")
        if tid in tid2idx:
            qids.add(qid)
            qrel_file.write(line)

    for line in ori_query_file:
        # only the id is needed; maxsplit=1 tolerates tabs inside the query text
        qid = line.split("\t", 1)[0]
        if qid in qids:
            query_file.write(line)


In [None]:
# convert t5 code to bert code
# Decode every text's t5 code back to a string, re-encode it with the bert tokenizer and
# store the result in a new memmap whose width is the longest re-encoded code + 1.

code_type = "chat"
# code_type = "words-weight"
code_tokenizer = "t5"
code_length = 50
# code_length = 34

text_codes = np.memmap(
    f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp",
    mode="r",
    dtype=np.int32
).reshape(len(text_dataset), -1).copy()

# Load the source tokenizer explicitly; this cell used to depend on a `t` left over in the
# kernel, so it failed on a fresh kernel.
t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, code_tokenizer))
new_code_tokenizer = "bert"
new_t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, new_code_tokenizer))

new_text_codes = []
max_length = 0
for text_code in tqdm(text_codes):
    text_code = text_code[text_code != -1]
    decoded = t.decode(text_code, skip_special_tokens=True)
    # encode once (the bert branch used to encode the same text a second time)
    encoded = new_t.encode(decoded, padding=False)
    if new_code_tokenizer == "bert":
        # drop the leading special token bert prepends ([CLS]); the trailing [SEP] stays
        encoded = encoded[1:]
    new_text_codes.append(encoded)
    max_length = max(max_length, len(encoded))

# plus one for the leading padding token
mmp_path = f"data/cache/{config.dataset}/codes/{code_type}/{new_code_tokenizer}/{max_length + 1}/codes.mmp"
makedirs(mmp_path)
new_codes_mmp = np.memmap(
    mmp_path,
    mode="w+",
    dtype=np.int32,
    shape=(len(text_dataset), max_length + 1)
)
new_codes_mmp[:, 0] = 0
new_codes_mmp[:, 1:] = -1

for text_idx, code in enumerate(tqdm(new_text_codes)):
    new_codes_mmp[text_idx, 1:len(code) + 1] = code
new_codes_mmp.flush()

print(f"saving at {mmp_path}")

In [None]:
# reorder words-comma in ascending df
# Re-sort each document's word ids by ascending document frequency (length of the word's
# inverted list), then write them out as token codes under code_type "words-comma-df".
# NOTE(review): relies on `wordset`, `text_codes`, `code_tokenizer` and `code_length`
# from earlier cells.

new_docs = np.zeros_like(wordset.docs) - 1
for row, word_ids in enumerate(tqdm(wordset.docs)):
    word_ids = word_ids[word_ids != -1]
    doc_freqs = [len(posting) for posting in wordset.inverted_lists[word_ids]]
    # sorted() is stable, so words with equal df keep their original relative order
    order = sorted(range(len(word_ids)), key=lambda j: doc_freqs[j])
    ranked = [word_ids[j] for j in order]
    new_docs[row, :len(ranked)] = ranked

code_type = "words-comma-df"
path = f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"
makedirs(path)

print(f"saving at {path}...")

new_text_codes = np.memmap(path, mode="w+", dtype=np.int32, shape=text_codes.shape)
new_text_codes[:, 0] = text_codes[:, 0]
new_text_codes[:, 1:] = -1

# the separator is taken to be the last non-padding token of the first existing code
first_code = text_codes[0]
sep_token_id = int(first_code[first_code != -1][-1])

for row, ranked in enumerate(tqdm(new_docs)):
    ranked = ranked[ranked != -1]
    tokens = wordset.inverse_vocab[ranked].reshape(-1)
    tokens = tokens[tokens != -1].tolist() + [sep_token_id]
    new_text_codes[row, 1:len(tokens) + 1] = tokens

In [None]:
# Rebuild Rand300k-filter from its backup keeping only texts that are NOT in Top300k-filter
# (i.e. documents that are new w.r.t. Top300k-filter), and keep the dev qrels that still
# point at a kept text.
tid2index = load_pickle("/share/project/peitian/Code/Uni-Retriever/src/data/cache/Top300k-filter/dataset/text/id2index.pkl")

with \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/collection.backup.tsv") as f, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/collection.tsv", "w") as g, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/qrels.dev.backup.tsv") as dev_qrel, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/qrels.dev.tsv", "w") as new_dev_qrel:

    # a set, not a list: membership is checked once per qrel line below
    tids = set()
    for line in f:
        tid = line.split("\t")[0]
        if tid not in tid2index:
            tids.add(tid)
            g.write(line)

    for line in dev_qrel:
        qid, _, tid, _ = line.strip().split()
        if tid in tids:
            new_dev_qrel.write(line)

In [None]:
# Rebuild Rand300k-filter from its backup without duplicated texts: keep texts that are in
# the current Rand300k-filter index and not listed in `dup_indices`, plus the dev qrels
# whose text survives.
# NOTE(review): `dup_indices` (text indices into Rand300k-filter) comes from an earlier cell.
tid2index = load_pickle("/share/project/peitian/Code/Uni-Retriever/src/data/cache/Rand300k-filter/dataset/text/id2index.pkl")
tindex2id = {v: k for k, v in tid2index.items()}
dup_tids = {tindex2id[x] for x in dup_indices}

# filtered 300k
with \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/collection.backup.tsv") as collection_file, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/collection.tsv", "w") as new_collection_file, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/qrels.dev.backup.tsv") as dev_qrel, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/qrels.dev.tsv", "w") as new_dev_qrel:

    # a set, not a list: membership is checked once per qrel line below
    tids = set()
    for line in collection_file:
        tid = line.split("\t")[0]
        if tid in tid2index and tid not in dup_tids:
            tids.add(tid)
            new_collection_file.write(line)

    for line in dev_qrel:
        qid, _, tid, _ = line.strip().split()
        if tid in tids:
            new_dev_qrel.write(line)

In [None]:
# Sample a 100k sub-corpus (Rand100k-filter) from Rand300k-filter: every text that is the
# first positive of some dev query is kept, the rest is filled with randomly sampled other
# texts, and the dev qrels / queries are filtered down to the kept texts.
tid2index = load_pickle("data/cache/Rand300k-filter/dataset/text/id2index.pkl")
tindex2id = {v: k for k, v in tid2index.items()}

positives = load_pickle("data/cache/Rand300k-filter/dataset/query/dev/positives.pkl")
positives = set([v[0] for v in positives.values()])
positive_tids = [tindex2id[x] for x in positives]
complement = sample([k for k, v in tid2index.items() if v not in positives], 100000 - len(positives))
tids = set(positive_tids + complement)
print(len(tids))

# the output directory is new; create it so the writes below do not fail
os.makedirs("/share/project/peitian/Data/Adon/Rand100k-filter", exist_ok=True)

# filtered 100k
with \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/collection.backup.tsv") as collection_file, \
    open("/share/project/peitian/Data/Adon/Rand100k-filter/collection.tsv", "w") as new_collection_file, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/qrels.dev.backup.tsv") as dev_qrel, \
    open("/share/project/peitian/Data/Adon/Rand100k-filter/qrels.dev.tsv", "w") as new_dev_qrel, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/queries.dev.backup.tsv") as dev_query, \
    open("/share/project/peitian/Data/Adon/Rand100k-filter/queries.dev.tsv", "w") as new_dev_query:

    for line in collection_file:
        tid = line.split("\t")[0]
        if tid in tids:
            new_collection_file.write(line)

    qids = set()
    for line in dev_qrel:
        qid, _, tid, _ = line.strip().split()
        if tid in tids:
            new_dev_qrel.write(line)
            qids.add(qid)

    for line in dev_query:
        qid = line.strip().split()[0]
        if qid in qids:
            new_dev_query.write(line)

In [None]:
# Build MS600k by concatenating Rand300k-filter and Top300k-filter (Rand first) for the
# collection, the dev qrels and the dev queries.

def _append_lines(src, dst):
    """Copy every line of ``src`` into ``dst`` and return each line's first tab field.

    Each written line is guaranteed to end with a newline; otherwise a source file without a
    trailing newline would glue its last line onto the first line of the next appended file.
    """
    first_fields = []
    for line in src:
        first_fields.append(line.split("\t")[0])
        dst.write(line if line.endswith("\n") else line + "\n")
    return first_fields


with \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/collection.tsv") as rand_collection_file, \
    open("/share/project/peitian/Data/Adon/Top300k-filter/collection.tsv") as top_collection_file, \
    open("/share/project/peitian/Data/Adon/MS600k/collection.tsv", "w") as new_collection_file, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/qrels.dev.tsv") as rand_dev_qrel, \
    open("/share/project/peitian/Data/Adon/Top300k-filter/qrels.dev.tsv") as top_dev_qrel, \
    open("/share/project/peitian/Data/Adon/MS600k/qrels.dev.tsv", "w") as new_dev_qrel, \
    open("/share/project/peitian/Data/Adon/Rand300k-filter/queries.dev.tsv") as rand_dev_query, \
    open("/share/project/peitian/Data/Adon/Top300k-filter/queries.dev.tsv") as top_dev_query, \
    open("/share/project/peitian/Data/Adon/MS600k/queries.dev.tsv", "w") as new_dev_query:

    rand_tids = _append_lines(rand_collection_file, new_collection_file)
    top_tids = _append_lines(top_collection_file, new_collection_file)

    # the two corpora must be disjoint so MS600k text ids stay unique
    assert len(set(rand_tids).intersection(set(top_tids))) == 0

    for src in (rand_dev_query, top_dev_query):
        _append_lines(src, new_dev_query)
    for src in (rand_dev_qrel, top_dev_qrel):
        _append_lines(src, new_dev_qrel)


In [None]:
# split NQ to Seen and Unseen for evaluating performance of adding new documents
# d0_tindices, d0_train_positives, d0_dev_positives, d0_test_positives (3915)
# d1_tindices, d1_test_positives (3915)
# d0 = "seen" corpus (texts keep their train queries, 3000 of those are moved to a dev split),
# d1 = "unseen" corpus (texts evaluated only with test queries).
# Sampling uses the global `random` module, so the split depends on its state
# (presumably seeded by Config; HEAD logs "setting seed to 42").

train_positives = load_pickle(f"data/cache/{config.dataset}/dataset/query/train/positives.pkl")
test_positives = load_pickle(f"data/cache/{config.dataset}/dataset/query/dev/positives.pkl")

# text index -> query indices whose FIRST positive is that text (other positives are ignored)
tidx_2_train_query = defaultdict(list)
for k, v in train_positives.items():
    tidx_2_train_query[v[0]].append(k)
tindices_with_train_query = set([x[0] for x in train_positives.values()])

tidx_2_test_query = defaultdict(list)
for k, v in test_positives.items():
    tidx_2_test_query[v[0]].append(k)
tindices_with_test_query = set([x[0] for x in test_positives.values()])

tindices = set(range(len(text_dataset)))

test_qindices = set(test_positives.keys())
d1_test_positives = {}
# unseen documents on the entire training set
# (every test query of a text that has no train query goes to d1)
d1_tindices = tindices - tindices_with_train_query
for tidx in d1_tindices:
    qindices = tidx_2_test_query[tidx]
    for qidx in qindices:
        d1_test_positives[qidx] = [tidx]

# sampled queries: top d1 up to 3915 test queries
d1_test_candidates = test_qindices - set(d1_test_positives.keys())
# only keep the queries whose (first) relevant text has exactly one test query,
# so moving that text to d1 takes no other test query with it
d1_test_qindices = set()
for qidx in d1_test_candidates:
    if len(tidx_2_test_query[test_positives[qidx][0]]) == 1:
        d1_test_qindices.add(qidx)
d1_test_candidates = random.sample(list(d1_test_qindices), 3915 - len(d1_test_positives))
for qidx in d1_test_candidates:
    # NOTE: stores the full positive list here, unlike the `[tidx]` used elsewhere
    d1_test_positives[qidx] = test_positives[qidx]

# NOTE(review): d1_tindices is rebuilt from the test positives only, which drops unseen texts
# that have no test query; those fall into d0, and the sanity check below then requires every
# d0 text to have a train query, i.e. it assumes every text has a train or a test query.
d1_tindices = set([x[0] for x in d1_test_positives.values()])
# sampled documents relevant to some train queries (but no test query) until d1 has 50000 texts
d1_candidates = tindices_with_train_query - tindices_with_test_query
d1_tindices.update(random.sample(list(d1_candidates), 50000 - len(d1_tindices)))

# d0 test queries: all test queries whose text stayed outside d1
d0_test_positives = {}
d0_test_tindices = tindices_with_test_query - d1_tindices
for tidx in d0_test_tindices:
    qindices = tidx_2_test_query[tidx]
    for qidx in qindices:
        d0_test_positives[qidx] = [tidx]

# sanity
# (the membership check is meaningful only because tidx_2_train_query has not yet been
# indexed with missing keys, which would insert empty lists into the defaultdict)
d0_tindices = tindices - d1_tindices
for tidx in d0_tindices:
    assert tidx in tidx_2_train_query
assert len(d0_test_positives) == 3915
assert d0_test_tindices - d0_tindices == set()

d0_train_positives = {}
d0_dev_positives = {}

# collect all train queries of d0 texts, and the d0 texts that have more than one
d0_qindices = set()
d0_tindices_with_multiple_rel = set()
for tidx in d0_tindices:
    qindices = tidx_2_train_query[tidx]
    d0_qindices.update(qindices)
    if len(qindices) > 1:
        d0_tindices_with_multiple_rel.add(tidx)

# dev split: for 3000 such texts, move one random train query to dev, so every dev text
# still keeps at least one train query
d0_dev_tindices = random.sample(list(d0_tindices_with_multiple_rel), 3000)
for tidx in d0_dev_tindices:
    qindices = tidx_2_train_query[tidx]
    assert len(qindices) > 1
    dev_qidx = random.sample(qindices, 1)[0]
    d0_dev_positives[dev_qidx] = [tidx]
    d0_qindices.remove(dev_qidx)

for qidx in d0_qindices:
    d0_train_positives[qidx] = train_positives[qidx]

In [None]:
# Reverse id maps (index -> original id) for texts and train/dev queries of the full dataset,
# and output directories for the seen (d0) / unseen (d1) corpora built from the split above.
# NOTE: the write-out that uses these maps is commented out below; uncomment it to
# regenerate the collection / query / qrel files.
tidx2id = {v: k for k, v in load_pickle(f"data/cache/{config.dataset}/dataset/text/id2index.pkl").items()}
# 50k
d0_dataset = "NQ-50k-seen"
# the remaining 50k
d1_dataset = "NQ-50k-unseen"

train_qidx2id = {v: k for k, v in load_pickle(f"data/cache/{config.dataset}/dataset/query/train/id2index.pkl").items()}
dev_qidx2id = {v: k for k, v in load_pickle(f"data/cache/{config.dataset}/dataset/query/dev/id2index.pkl").items()}

os.makedirs(f"{config.data_root}/{d0_dataset}", exist_ok=True)
os.makedirs(f"{config.data_root}/{d1_dataset}", exist_ok=True)

# with open(f"{config.data_root}/{config.dataset}/collection.tsv") as collection_file, \
#     open(f"{config.data_root}/{config.dataset}/queries.train.tsv") as train_query_file, \
#     open(f"{config.data_root}/{config.dataset}/queries.dev.tsv") as dev_query_file, \
#     open(f"{config.data_root}/{d0_dataset}/collection.tsv", "w") as d0_collection_file, \
#     open(f"{config.data_root}/{d0_dataset}/queries.train.tsv", "w") as d0_train_query_file, \
#     open(f"{config.data_root}/{d0_dataset}/queries.dev.tsv", "w") as d0_dev_query_file, \
#     open(f"{config.data_root}/{d0_dataset}/queries.test.tsv", "w") as d0_test_query_file, \
#     open(f"{config.data_root}/{d0_dataset}/qrels.train.tsv", "w") as d0_train_qrel_file, \
#     open(f"{config.data_root}/{d0_dataset}/qrels.dev.tsv", "w") as d0_dev_qrel_file, \
#     open(f"{config.data_root}/{d0_dataset}/qrels.test.tsv", "w") as d0_test_qrel_file, \
#     open(f"{config.data_root}/{d1_dataset}/collection.tsv", "w") as d1_collection_file, \
#     open(f"{config.data_root}/{d1_dataset}/queries.dev.tsv", "w") as d1_test_query_file, \
#     open(f"{config.data_root}/{d1_dataset}/qrels.dev.tsv", "w") as d1_test_qrel_file:

#     for tidx, line in enumerate(tqdm(collection_file)):
#         if tidx in d0_tindices:
#             d0_collection_file.write(line)
#         if tidx in d1_tindices:
#             d1_collection_file.write(line)
    
#     for qidx, line in enumerate(tqdm(train_query_file)):
#         if qidx in d0_train_positives:
#             positive = d0_train_positives[qidx][0]
#             d0_train_query_file.write(line)
#             d0_train_qrel_file.write("\t".join([train_qidx2id[qidx], "0", tidx2id[positive], "1"]) + "\n")
#         elif qidx in d0_dev_positives:
#             positive = d0_dev_positives[qidx][0]
#             d0_dev_query_file.write(line)
#             d0_dev_qrel_file.write("\t".join([train_qidx2id[qidx], "0", tidx2id[positive], "1"]) + "\n")

#     for qidx, line in enumerate(tqdm(dev_query_file)):
#         if qidx in d0_test_positives:
#             positive = d0_test_positives[qidx][0]
#             d0_test_query_file.write(line)
#             d0_test_qrel_file.write("\t".join([dev_qidx2id[qidx], "0", tidx2id[positive], "1"]) + "\n")
#         if qidx in d1_test_positives:
#             positive = d1_test_positives[qidx][0]
#             d1_test_query_file.write(line)
#             d1_test_qrel_file.write("\t".join([dev_qidx2id[qidx], "0", tidx2id[positive], "1"]) + "\n")

In [None]:
# Slice code tables of the full corpus into the d0 (seen) / d1 (unseen) corpora.
# Each split keeps its texts in ascending original index, i.e. the order of a forward scan.
# NOTE(review): relies on d0_tindices / d1_tindices / d0_dataset / d1_dataset from earlier cells.
# code_types = ["ANCE_hier", "words_comma", "words_comma_plus_stem", "title"]
code_types = ["id"]
code_lengths = [8]
code_tokenizer = "t5"

# sorted index arrays reproduce the row order of scanning the full table front to back
d0_rows = np.array(sorted(d0_tindices), dtype=np.int64)
d1_rows = np.array(sorted(d1_tindices), dtype=np.int64)

for code_type, code_length in zip(code_types, code_lengths):
    text_codes = np.memmap(
        f"data/cache/{config.dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp",
        mode="r",
        dtype=np.int32
    ).reshape(len(text_dataset), code_length).copy()

    d0_text_codes = np.memmap(
        makedirs(f"data/cache/{d0_dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
        mode="w+",
        dtype=np.int32,
        shape=(len(d0_tindices), code_length)
    )
    d1_text_codes = np.memmap(
        makedirs(f"data/cache/{d1_dataset}/codes/{code_type}/{code_tokenizer}/{code_length}/codes.mmp"),
        mode="w+",
        dtype=np.int32,
        shape=(len(d1_tindices), code_length)
    )
    # fancy indexing raises if a split index is outside the table, so every row gets filled
    d0_text_codes[:] = text_codes[d0_rows]
    d1_text_codes[:] = text_codes[d1_rows]
    d0_text_codes.flush()
    d1_text_codes.flush()

In [None]:
# case for generation likelihood
# Load the BOW_qg checkpoint on CPU for the likelihood probe.
ckpt_dir = "/share/project/peitian/Code/Uni-Retriever/src/data/cache/NQ/ckpts/BOW_qg/first+random3+em"
model = AM.from_pretrained(ckpt_dir, device="cpu")

In [None]:
# Teacher-forced likelihood of a hand-written code for one query: shows the cumulative
# log-probability after each code token.

# Load the t5 tokenizer explicitly; this cell used to depend on a `t` left over in the kernel.
t = AutoTokenizer.from_pretrained(os.path.join(config.plm_root, "t5"))

query = default_collate([query_dataset[3315]["query"]])

# leading 0 (padding / decoder start), words separated by id 6, closed with 1 (T5 </s>)
# NOTE(review): id 6 is assumed to be the separator used in these codes; verify in the t5 vocab
text_code = [0]
# for x in "executive office of president".split(" "):
# for x in "white house communication director".split(" "):
# for x in "executive chef white house".split(" "):
# for x in "white house executive chef".split(" "):
for x in "cristeta comerford executive chef".split(" "):
    text_code += t.encode(x, add_special_tokens=False) + [6]
text_code += [1]
# Build the (1, L) batch directly; torch's default_collate would transpose a list of int lists.
text_code = torch.tensor([text_code], dtype=torch.long)

# text_code = default_collate([text_codes[108361].astype(np.int64)])

text_code[text_code == -1] = 0
display(t.decode(text_code[0]))

with torch.no_grad():
    logits = model.plm(**query, decoder_input_ids=text_code).logits
    log_probs = torch.log_softmax(logits, dim=-1)
    # step i predicts token i + 1, so pick log p(text_code[i + 1]) at each step
    token_log_probs = log_probs.gather(dim=-1, index=text_code[:, 1:, None])[0, :, 0]
    # token_log_probs[4] = token_log_probs[4] - 4
    cum = token_log_probs.cumsum(dim=-1).numpy().round(3)
    tokens = t.convert_ids_to_tokens(text_code[0].tolist())[1:]
list(zip(cum, tokens))

In [None]:
# interleave two query sets to form a unified one with all documents having the same number of queries
# Per text, take up to k queries from the main set, then pad to k from the supplementary set.
# Kept queries are renumbered with consecutive new ids (main first, then sup).


def _interleave_pass(qrel_in, query_in, qid2idx, already_taken, k, next_qid, qrel_out, query_out):
    """Select queries from one source query set and write them out with fresh ids.

    For each tab-separated qrel line "qid 0 tid rel", the query is kept while the text has
    fewer than ``k`` queries in total (``already_taken[tid]`` plus those picked in this pass).
    Kept query lines are written in source-file order, then their qrels grouped by text.

    Returns:
        (tid2qidx, next_qid): picked query indices per text, and the next unused new id.
    """
    tid2qidx = defaultdict(list)
    qidx2newid = {}
    for line in qrel_in:
        qid, _, tid, _ = line.strip().split("\t")
        if len(already_taken.get(tid, ())) + len(tid2qidx[tid]) < k:
            qidx = qid2idx[qid]
            tid2qidx[tid].append(qidx)
            qidx2newid[qidx] = next_qid
            next_qid += 1

    # slice queries (the line number is the query index)
    for qidx, line in enumerate(query_in):
        if qidx in qidx2newid:
            _, query = line.split("\t", 1)
            # explicit newline: a source file without a trailing newline would otherwise
            # merge its last query with the first query of the next pass
            query_out.write("\t".join([str(qidx2newid[qidx]), query.rstrip("\n")]) + "\n")

    # store qrels using the new ids
    for tid, qindices in tid2qidx.items():
        for qidx in qindices:
            qrel_out.write("\t".join([str(qidx2newid[qidx]), "0", tid, "1"]) + "\n")
    return tid2qidx, next_qid


k = 5
dataset = "NQ"
query_set = "pad5"
main_query_set = "train"
sup_query_set = "doct5"

main_qid2idx = load_pickle(f"data/cache/{dataset}/dataset/query/{main_query_set}/id2index.pkl")
sup_qid2idx = load_pickle(f"data/cache/{dataset}/dataset/query/{sup_query_set}/id2index.pkl")
# NOTE(review): not used in this cell
tid2idx = load_pickle(f"data/cache/{dataset}/dataset/text/id2index.pkl")

with \
    open(f"{config.data_root}/{dataset}/qrels.{main_query_set}.tsv") as main_qrel_file, \
    open(f"{config.data_root}/{dataset}/qrels.{sup_query_set}.tsv") as sup_qrel_file, \
    open(f"{config.data_root}/{dataset}/qrels.{query_set}.tsv", "w") as qrel_file, \
    open(f"{config.data_root}/{dataset}/queries.{main_query_set}.tsv") as main_query_file, \
    open(f"{config.data_root}/{dataset}/queries.{sup_query_set}.tsv") as sup_query_file, \
    open(f"{config.data_root}/{dataset}/queries.{query_set}.tsv", "w") as query_file:

    # keep the first k main queries per text
    main_tid2qidx, new_qid = _interleave_pass(
        main_qrel_file, main_query_file, main_qid2idx, {}, k, 0, qrel_file, query_file
    )
    # pad every text up to k with supplementary queries
    sup_tid2qidx, new_qid = _interleave_pass(
        sup_qrel_file, sup_query_file, sup_qid2idx, main_tid2qidx, k, new_qid, qrel_file, query_file
    )