In [None]:
## Set the current working directory to the root of the project.
## `chdir('../..')` is relative, so executing it twice would climb out of the
## project; the flag below turns every re-run of this cell into a no-op.
## (Restarting the kernel clears the flag together with the working directory.)
import os
if not globals().get('_CWD_AT_PROJECT_ROOT', False):
    os.chdir('../..')
    _CWD_AT_PROJECT_ROOT = True

In [None]:
from argparse import Namespace
import h5py
import pandas as pd
import pickle
import torch
import re
import random
import csv
import jsonlines
import numpy as np
import pickle
import time
import gzip
from tqdm import tqdm, trange
from sklearn.cluster import KMeans
from typing import Any, List, Sequence, Callable
from itertools import islice, zip_longest
import transformers
from transformers import BertTokenizerFast, BertModel, AutoTokenizer, AutoModelForSeq2SeqLM
from sklearn.cluster import MiniBatchKMeans

import pyarrow.csv as csv

import preprocess

In [None]:
# Working directories, all relative to the project root.
_dataset_root = 'data/NQ320K'
raw_dir = f'{_dataset_root}/raw'        # downloaded dumps and extracted TSVs
output_dir = f'{_dataset_root}/output'  # compiled training / evaluation files
cache_dir = f'{_dataset_root}/cache'    # intermediate artefacts (embeddings, generated queries, ...)
pretrained_dir = 'data/pretrained'      # pretrained / finetuned checkpoints

In [None]:
# The BERT encoder is referenced by its hub name, so path and identifier coincide.
bert_model_identifier = 'bert-base-uncased'
bert_model_path = bert_model_identifier

# The finetuned doc2query checkpoint lives on disk under its identifier.
genq_model_identifier = 'doc2query-t5-base-msmarco-ft_NQ320K'
genq_model_path = f"data/pretrained/{genq_model_identifier}"

# 1. Preparation

## 1.1 Extract data into `data/NQ320K/raw`

#### Open the following links in your browser to download automatically:

##### NQ Train: https://storage.cloud.google.com/natural_questions/v1.0-simplified/simplified-nq-train.jsonl.gz
##### NQ Dev: https://storage.cloud.google.com/natural_questions/v1.0-simplified/nq-dev-all.jsonl.gz

#### __Note:__ Please download them directly via your browser (e.g., Microsoft Edge) then place them into `data/NQ320K/raw` directory.
#### Do not use `gsutil` or `wget` command directly on the above links to prevent file incompatibility or corruption.

### 1.1.1 Extract query-document samples from raw files

In [None]:
# Column order of the extracted TSV files; every row built by the extraction
# loops appends its fields in exactly this order.
columns = (
    "query queryid long_answer short_answer "
    "title abstract content doc language"
).split()

In [None]:
nq_dev = []

# Read-only mode is sufficient; "r+" additionally required write access to the raw dump.
with gzip.open(f"{raw_dir}/v1.0-simplified_nq-dev-all.jsonl.gz", "rb") as f:
    for item in tqdm(jsonlines.Reader(f)):
        arr = []

        ## question_text
        arr.append(item['question_text'])

        ## example_id
        arr.append(str(item['example_id']))

        # Dev records carry a token list; rebuild the text by joining the tokens with spaces.
        document_text = ' '.join(tok['token'] for tok in item['document_tokens'])

        # Only the first annotation is used for the answer columns.
        annotation = item['annotations'][0]
        has_long_answer = annotation['long_answer']['start_token'] >= 0
        short_spans = annotation['short_answers']

        # Token view used to slice answer spans. It is rebuilt for every record that
        # needs it, so a span is never cut out of the previous record's tokens.
        x = document_text.split(' ') if (has_long_answer or short_spans) else []

        ## long_answer
        long_answer = ''
        if has_long_answer:
            span = annotation['long_answer']
            long_answer = ' '.join(x[span['start_token']:span['end_token']])
            long_answer = re.sub('<[^<]+?>', '', long_answer).replace('\n', '').strip()
        arr.append(long_answer)

        ## short_answer
        # Reset per record: a yes/no-only annotation has no short-answer span and
        # yields '' instead of re-using the previous record's short answer.
        short_answer = ''
        if short_spans:
            short_answer = '|'.join(
                ' '.join(x[s['start_token']:s['end_token']]) for s in short_spans
            )
            short_answer = re.sub('<[^<]+?>', '', short_answer).replace('\n', '').strip()
        arr.append(short_answer)

        ## title
        arr.append(item['document_title'])

        ## abs: text of the first <P> ... </P> pair
        abstract = ''
        content_start = 0  # documents without a paragraph contribute their whole text as content
        abs_start = document_text.find('<P>')
        if abs_start != -1:
            abs_end = document_text.find('</P>', abs_start)
            if abs_end != -1:
                abstract = document_text[abs_start + 3:abs_end]
                content_start = abs_end + 4  # skip past '</P>'
        arr.append(abstract)

        ## content: everything after the abstract, cut at the second-to-last '</Ul>'
        ## when there are at least two, otherwise at the last one (if any)
        final = document_text.rfind('</Ul>')
        if final != -1:
            document_text = document_text[:final]
            inner_final = document_text.rfind('</Ul>')
            if inner_final != -1:
                content = document_text[content_start:inner_final]
            else:
                content = document_text[content_start:]
        else:
            content = document_text[content_start:]
        content = re.sub('<[^<]+?>', '', content).replace('\n', '').strip()
        content = re.sub(' +', ' ', content)
        arr.append(content)

        ## doc: title + abstract + content, concatenated without separators
        doc_tac = item['document_title'] + abstract + content
        arr.append(doc_tac)

        language = 'en'
        arr.append(language)
        nq_dev.append(arr)

dev = pd.DataFrame(nq_dev, columns=columns)
dev.to_csv(f"{raw_dir}/dev.tsv", sep="\t", index=False)

In [None]:
nq_train = []

# Read-only mode is sufficient; "r+" additionally required write access to the raw dump.
with gzip.open(f"{raw_dir}/v1.0-simplified_simplified-nq-train.jsonl.gz", "rb") as f:
    for item in tqdm(jsonlines.Reader(f)):
        arr = []

        ## question_text
        arr.append(item["question_text"])

        ## example_id
        arr.append(str(item["example_id"]))

        document_text = item["document_text"]

        # Only the first annotation is used for the answer columns.
        annotation = item["annotations"][0]
        has_long_answer = annotation["long_answer"]["start_token"] >= 0
        short_spans = annotation["short_answers"]

        # Token view used to slice answer spans. It is rebuilt for every record that
        # needs it, so a span is never cut out of the previous record's tokens.
        x = document_text.split(" ") if (has_long_answer or short_spans) else []

        ## long_answer
        long_answer = ""
        if has_long_answer:
            span = annotation["long_answer"]
            long_answer = " ".join(x[span["start_token"]:span["end_token"]])
            long_answer = re.sub("<[^<]+?>", "", long_answer).replace("\n", "").strip()
        arr.append(long_answer)

        ## short_answer
        # Reset per record: a yes/no-only annotation has no short-answer span and
        # yields "" instead of re-using the previous record's short answer.
        short_answer = ""
        if short_spans:
            short_answer = "|".join(
                " ".join(x[s["start_token"]:s["end_token"]]) for s in short_spans
            )
            short_answer = re.sub("<[^<]+?>", "", short_answer).replace("\n", "").strip()
        arr.append(short_answer)

        ## title: text of the first <H1> ... </H1> pair
        if document_text.find("<H1>") != -1:
            title_start = document_text.index("<H1>")
            title_end = document_text.index("</H1>")
            title = document_text[title_start + 4 : title_end]
        else:
            title = ""
        arr.append(title)

        ## abs: text of the first <P> ... </P> pair
        abstract = ""
        content_start = 0  # documents without a paragraph contribute their whole text as content
        abs_start = document_text.find("<P>")
        if abs_start != -1:
            abs_end = document_text.find("</P>", abs_start)
            if abs_end != -1:
                abstract = document_text[abs_start + 3 : abs_end]
                content_start = abs_end + 4  # skip past "</P>"
        arr.append(abstract)

        ## content: everything after the abstract, cut at the second-to-last "</Ul>"
        ## when there are at least two, otherwise at the last one (if any)
        final = document_text.rfind("</Ul>")
        if final != -1:
            document_text = document_text[:final]
            inner_final = document_text.rfind("</Ul>")
            if inner_final != -1:
                content = document_text[content_start:inner_final]
            else:
                content = document_text[content_start:]
        else:
            content = document_text[content_start:]
        content = re.sub("<[^<]+?>", "", content).replace("\n", "").strip()
        content = re.sub(" +", " ", content)
        arr.append(content)

        ## doc: title + abstract + content, concatenated without separators
        doc_tac = title + abstract + content
        arr.append(doc_tac)

        language = "en"
        arr.append(language)
        nq_train.append(arr)

train = pd.DataFrame(nq_train, columns=columns)
train.to_csv(f"{raw_dir}/train.tsv", sep="\t", index=False)

### 1.2.2 Load extracted samples and collect unique documents

In [None]:
## Clean data
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

def lower(x, _tokenizer=tokenizer):
    """Normalise a title by a tokenize -> ids -> decode round trip.

    The tokenizer is captured as a default argument when this cell runs, so the
    normalisation keeps using the BERT tokenizer created above even if the global
    name `tokenizer` is later rebound to a different tokenizer.

    Args:
        x: the raw title string.
        _tokenizer: tokenizer used for the round trip; callers normally omit it.

    Returns:
        The decoded string produced by the tokenizer.
    """
    pieces = _tokenizer.tokenize(x)
    piece_ids = _tokenizer.convert_tokens_to_ids(pieces)
    return _tokenizer.decode(piece_ids)


## read large csv files (use pyarrow to accelerate)
def _load_split(name):
    """Read `{raw_dir}/{name}.tsv` with pyarrow and normalise its `title` column.

    Rows that pyarrow reports as invalid are skipped rather than raising.
    """
    frame = csv.read_csv(
        f'{raw_dir}/{name}.tsv',
        read_options=csv.ReadOptions(block_size=2**25),
        parse_options=csv.ParseOptions(invalid_row_handler=lambda invalidrow: "skip", delimiter="\t")
    ).to_pandas()
    frame['title'] = frame['title'].map(lower)
    print(f'{name}.shape:', frame.shape)
    return frame


def load_dev():
    """Load `dev.tsv` from `raw_dir` as a DataFrame with normalised titles."""
    return _load_split('dev')


def load_train():
    """Load `train.tsv` from `raw_dir` as a DataFrame with normalised titles."""
    return _load_split('train')

In [None]:
# Load both extracted splits (training split first, then dev).
train, dev = load_train(), load_dev()

In [None]:
## Stack the training and dev samples to obtain the full document collection
full = pd.concat([train, dev], axis=0).reset_index()

## Keep one row per title; the resulting row number becomes the document id
docs = (
    full.drop_duplicates('title')
    .loc[:, ['title', 'abstract', 'doc']]
    .reset_index(drop=True)
    .rename_axis("docid")
    .fillna({"title": ""})
)
assert not docs['title'].isnull().any()

# del full

In [None]:
# Is any title the literal string 'nan'?
docs['title'].eq('nan').any()

In [None]:
## The document collection is expected to contain exactly this many documents
expected_doc_count = 109739
assert len(docs) == expected_doc_count

## Statistics
n_train_titles = train['title'].nunique(dropna=False)
n_dev_titles = dev['title'].nunique(dropna=False)

print(f"# all unique documents: {len(docs)}")
print("----------- training set --------------")
print(f"# Queries: {len(train)}")
print(f"# Documents mentioned in training set: {n_train_titles}")

print("----------- dev set --------------")
print(f"# Queries: {len(dev)}")
print(f"# Documents mentioned dev set: {n_dev_titles}")

In [None]:
# docs.tsv -- the docid index is written as the first column
docs_tsv_path = f"{raw_dir}/docs.tsv"
docs.to_csv(docs_tsv_path, sep='\t')

In [None]:
# Show the first five lines of the document file that was just written.
!head -5 {raw_dir}/docs.tsv

## 1.3 Finetuning the document-to-query model

### 1.3.1 Download docT5query model

In [None]:
import transformers

# Fetch the doc2query checkpoint and its tokenizer, then store both in one local directory.
identifier = "castorini/doc2query-t5-base-msmarco"
model = transformers.T5ForConditionalGeneration.from_pretrained(identifier)
tokenizer = transformers.T5TokenizerFast.from_pretrained(identifier)

os.makedirs(pretrained_dir, exist_ok=True)
doc2query_dir = f"{pretrained_dir}/doc2query-t5-base-msmarco"
for artefact in (model, tokenizer):
    artefact.save_pretrained(doc2query_dir)

### 1.3.2 Finetuning

Finetuning docT5query model with `train.doc_query.tsv` \
by executing the following bash command in your terminal.

```bash
    nohup python -m preprocess.finetune_t5 \
        --raw_ckpt data/pretrained/doc2query-t5-base-msmarco \
        --finetuned_ckpt data/pretrained/doc2query-t5-base-msmarco-ft_NQ320K \
        --train_data_path data/NQ320K/raw/train.doc_query.tsv \
        --val_data_path data/NQ320K/raw/dev.doc_query.tsv \
        --epochs 10 \
        --lr 5e-5 \
        --weight_decay 1e-2 \
        --batch_size 8 \
        --doc_max_len 512 \
        --query_max_len 64 \
        --test1000 0 \
        --num_nodes 1 \
    > log.finetuning_doc2query.log 2>&1 &
```

It takes approximately `five hours` on four RTX4090 GPUs. \
After the finetuning is done, the finetuned model will be saved to
```
    data/pretrained/doc2query-t5-base-msmarco-ft_NQ320K
```

You can optionally check the finetuning progress by executing in your terminal:
```bash
    tail -f log.finetuning_doc2query.log
```

## 1.4 Generate queries with the finetuned document-to-query model

Generate queries for every document in `docs.tsv` and save them as `genq.tsv`, \
in which every line contains a document id and a generated query.

It takes around 40 minutes on four RTX4090 GPUs.

In [None]:
import preprocess.genq

# Arguments handed to `preprocess.genq.main`.
args = Namespace(**{
    "model_path": genq_model_path,
    "docs_path": f"{raw_dir}/docs.tsv",
    "output_path": f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    "doc_max_len": 512,
    "query_max_len": 32,
    "genq_per_doc": 15,
    "n_gpus": 4,
    "batch_size": 16,
})
preprocess.genq.main(args)

In [None]:
# Show the first five lines of the generated-query file.
!head -5 {cache_dir}/{genq_model_identifier}/genq.tsv

# 2. Hierarchical K-means Indexing (HKmI)

## 2.1 Produce for every document a BERT embedding

Encode every document in `docs.tsv` as a vector. The embeddings will be saved as `doc_emb.h5`.

The encoding process will take approximately 6 minutes on four RTX4090 GPUs.

In [None]:
import preprocess.bert_embedding

# Arguments handed to `preprocess.bert_embedding.main`; the `doc` column is the text to embed.
args = Namespace(**{
    "docs_path": f"{raw_dir}/docs.tsv",
    "output_path": f"{cache_dir}/{bert_model_identifier}/doc_emb.h5",
    "model_path": bert_model_path,
    "max_len": 512,
    "n_gpus": 4,
    "text_col": "doc",
})
preprocess.bert_embedding.main(args)

## 2.2 Apply K-means clustering on documents

Applying hierarchical K-means on document embeddings.

Every document will be assigned an ID string; the mapping will be saved as `docid2index.HKmI.tsv` (and later copied to `docid2index.tsv` in the output directory).

In [None]:
import preprocess.kmeans

# Arguments handed to `preprocess.kmeans.main`.
args = Namespace(**{
    "embedding_path": f"{cache_dir}/{bert_model_identifier}/doc_emb.h5",
    "output_path": f"{cache_dir}/{bert_model_identifier}/docid2index.HKmI.tsv",
    "v_dim": 768,
    "k": 30,
    "c": 30,
    "seed": 7,
    "n_init": 1,   # can be increased to 10/100 to enhance quality at the cost of running time
    "tol": 1e-6,
})
preprocess.kmeans.main(args)

## 2.3 Use document segments as queries

For every document, we randomly select 10~12 segments of 64 tokens as queries.

In [None]:
def get_seg(tokens, seg_len=64):
    """Return a random window of up to `seg_len` consecutive tokens, joined by spaces.

    Args:
        tokens: sequence of token strings.
        seg_len: window length in tokens (default 64).

    Returns:
        The selected tokens joined with single spaces. Inputs shorter than
        `seg_len` (including the empty list) are returned whole.
    """
    # There are len(tokens) - seg_len + 1 valid start positions; the "+ 1" lets the
    # window reach the final token. `max(1, ...)` keeps short inputs valid.
    n_starts = max(1, len(tokens) - seg_len + 1)
    begin = random.randrange(0, n_starts)
    return ' '.join(tokens[begin: begin + seg_len])

# The name `csv` is bound to `pyarrow.csv` by the import cell, so the standard
# library module is imported under a separate name here.
import csv as _stdlib_csv

docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col=0, na_filter=False)

# Seed the sampler so that re-running this cell reproduces the same segments.
random.seed(7)

os.makedirs(cache_dir, exist_ok=True)
with open(f"{cache_dir}/docseg.tsv", "wt", encoding="utf-8", newline="") as f:
    # A csv writer quotes fields containing tabs or double quotes, so the file
    # round-trips through `pd.read_csv(..., sep="\t")` without row corruption.
    writer = _stdlib_csv.writer(f, delimiter="\t", lineterminator="\n")
    writer.writerow(["docid", "query"])

    for docid, doc in docs["doc"].items():
        tokens = doc.split(" ")
        # 10 segments per document, plus one for every further 3000 tokens beyond the first 3000
        nsegs = 10 + max(0, len(tokens)-3000) // 3000
        for _ in range(nsegs):
            writer.writerow([docid, get_seg(tokens)])

## 2.4 Compiling training data for training retrieval model

- training set

A training sample should have three entries: *query, index, docid*

| File Name | Description |
| --- | --- |
| realq_train.tsv | real queries (ground truth) |
| genq.tsv  | generated queries from documents |
| title_abs.tsv | concatenation of document title and abstract as query |
| docseg.tsv | document segments as queries |

- dev (evaluation) set

A validation sample should have two entries: *query, docid*

| File Name | Description |
| --- | --- |
| realq_dev.tsv | real queries (ground truth) |

- supporting files

| File Name | Description |
| --- | --- |
| docid2index.tsv | mapping from docid to index, used for evaluation |

In [None]:
import BMI.io
from BMI.io import (
    StringIndexing,
    DocumentRetrievalTrainingFile,
    DocumentRetrievalInferenceFile,
    intarray_to_string,
)

In [None]:
# Output sub-directory for the HKmI data, named after the two models involved.
hkmi_dirname = ".".join(["HKmI", bert_model_identifier, genq_model_identifier])
os.makedirs(f"{output_dir}/{hkmi_dirname}", exist_ok=True)

In [None]:
# Reload the document table (index = docid) and the docid -> index mapping produced by k-means.
docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col=0, na_filter=False)
# For a repeated title the last docid wins, exactly as with dict(zip(...)).
title2docid = {title: docid for docid, title in docs["title"].items()}
hkmi_index_path = f"{cache_dir}/{bert_model_identifier}/docid2index.HKmI.tsv"
docid2index = StringIndexing.from_tsv(hkmi_index_path)

In [None]:
# docid2index.tsv -- copy the mapping next to the training files, then display it
hkmi_docid2index_path = f"{output_dir}/{hkmi_dirname}/docid2index.tsv"
docid2index.to_tsv(hkmi_docid2index_path)
docid2index.to_pandas()

In [None]:
# realq_train.tsv
if 'train' not in globals():
    # Titles must go through the same normalisation that was applied when `docs`
    # was built, otherwise they do not match the keys of `title2docid`.
    train = load_train()
docids = train['title'].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=train["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/realq_train.tsv")
file.to_pandas()

In [None]:
# realq_dev.tsv
if 'dev' not in globals():
    # Titles must go through the same normalisation that was applied when `docs`
    # was built, otherwise they do not match the keys of `title2docid`.
    dev = load_dev()
docids = dev['title'].apply(title2docid.get)

file = DocumentRetrievalInferenceFile(
    queries=dev["query"],
    docids=docids,
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/realq_dev.tsv")
file.to_pandas()

In [None]:
# title_abs.tsv -- "<title> <abstract>" of every document serves as a query for that document
titles = docs["title"].fillna("")
abstracts = docs["abstract"].fillna("")
title_abs = titles + " " + abstracts

docids = docs["title"].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)
index_strings = indexes.apply(intarray_to_string)

file = DocumentRetrievalTrainingFile(
    queries=title_abs,
    docids=docids,
    indexes=index_strings,
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/title_abs.tsv")
file.to_pandas()

In [None]:
# genq.tsv
# na_filter=False keeps queries such as "null", "NA" or "nan" as literal strings
# instead of turning them into NaN (same setting as used when reading docs.tsv).
genq = pd.read_csv(
    f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    usecols=["docid", "query"],
    sep="\t",
    na_filter=False,
)
docids = genq["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=genq["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/genq.tsv")
file.to_pandas()

In [None]:
# docseg.tsv
# na_filter=False keeps segments such as "NA" or the empty string as literal strings
# instead of turning them into NaN (same setting as used when reading docs.tsv).
docseg = pd.read_csv(
    f"{cache_dir}/docseg.tsv",
    usecols=["docid", "query"],
    sep="\t",
    na_filter=False,
)
docids = docseg["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=docseg["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/docseg.tsv")
file.to_pandas()

# 3. Bottleneck-Minimal Indexing (BMI)

## 3.1 Produce for every query a BERT embedding

In [None]:
import preprocess.bert_embedding

### 3.1.1 RealQ: real queries (training set)

In [None]:
# Embed the `query` column of the compiled real-query training file.
args = Namespace(**{
    "docs_path": f"{output_dir}/{hkmi_dirname}/realq_train.tsv",
    "output_path": f"{cache_dir}/{bert_model_identifier}/realq_train_emb.h5",
    "model_path": bert_model_path,
    "max_len": 512,
    "n_gpus": 4,
    "text_col": "query",
})
preprocess.bert_embedding.main(args)

### 3.1.2 GenQ: generated queries by the finetuned document-to-query model

In [None]:
# Embed the `query` column of the generated-query file.
args = Namespace(**{
    "docs_path": f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    "output_path": f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/genq_emb.h5",
    "model_path": bert_model_path,
    "max_len": 512,
    "n_gpus": 4,
    "text_col": "query",
})
preprocess.bert_embedding.main(args)

### 3.1.3 DocSeg: using document segments as queries 

In [None]:
# Embed the `query` column of the compiled document-segment file.
args = Namespace(**{
    "docs_path": f"{output_dir}/{hkmi_dirname}/docseg.tsv",
    "output_path": f"{cache_dir}/{bert_model_identifier}/docseg_emb.h5",
    "model_path": bert_model_path,
    "max_len": 512,
    "n_gpus": 4,
    "text_col": "query",
})
preprocess.bert_embedding.main(args)

## 3.2 Apply K-means clustering on documents

### 3.2.1 Calculate centroid vector for every document

In [None]:
# Read the `embs` and `ids` datasets of the three query-embedding files and
# stack them along the first axis, in the order listed.
emb_paths = [
    f"{cache_dir}/{bert_model_identifier}/realq_train_emb.h5",
    f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/genq_emb.h5",
    f"{cache_dir}/{bert_model_identifier}/docseg_emb.h5",
]
emb_chunks, id_chunks = [], []
for path in emb_paths:
    with h5py.File(path, 'r') as f:
        emb_chunks.append(f["embs"][:])
        id_chunks.append(f["ids"][:])
X = np.concatenate(emb_chunks, axis=0)
ids = np.concatenate(id_chunks, axis=0)

In [None]:
def _mean_embedding(group):
    """Average all embedding vectors of one docid group into a single vector."""
    return np.stack(group['emb'].values).mean(0)

# One row per query embedding, grouped by the document it belongs to.
embs = pd.DataFrame({'emb': list(X), 'docid': ids})
centroids = embs.groupby("docid").apply(_mean_embedding)

In [None]:
# Display the per-document centroid vectors (indexed by docid).
centroids

In [None]:
# Store the centroids in the same `embs` / `ids` layout as the other embedding files.
path = f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/doc_emb.centroid.realq_genq_docseg.h5"
centroid_matrix = np.stack(centroids.values).astype(np.float32)
with h5py.File(path, 'w') as f:
    f['embs'] = centroid_matrix
    f['ids'] = centroids.index

### 3.2.2 Run k-means

In [None]:
import preprocess.kmeans
import importlib
# Pick up any edits made to the module since it was first imported.
importlib.reload(preprocess.kmeans)

# Arguments handed to `preprocess.kmeans.main`; input and output share one cache directory.
bmi_cache_dir = f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}"
args = Namespace(**{
    "embedding_path": f"{bmi_cache_dir}/doc_emb.centroid.realq_genq_docseg.h5",
    "output_path": f"{bmi_cache_dir}/docid2index.BMI.realq_genq_docseg.tsv",
    "v_dim": 768,
    "k": 30,
    "c": 30,
    "seed": 7,
    "n_init": 1,   # can be increased to 10/100 to enhance quality at the cost of running time
    "tol": 1e-6,
})
preprocess.kmeans.main(args)

## 3.3 Compiling training data for training retrieval model

- training set

A training sample should have three entries: *query, index, docid*

| File Name | Description |
| --- | --- |
| realq_train.tsv | real queries (ground truth) |
| genq.tsv  | generated queries from documents |
| title_abs.tsv | concatenation of document title and abstract as query |
| docseg.tsv | document segments as queries |

- dev (evaluation) set

A validation sample should have two entries: *query, docid*

| File Name | Description |
| --- | --- |
| realq_dev.tsv | real queries (ground truth) |

- supporting files

| File Name | Description |
| --- | --- |
| docid2index.tsv | mapping from docid to index, used for evaluation |

In [None]:
import BMI.io
from BMI.io import (
    StringIndexing,
    DocumentRetrievalTrainingFile,
    DocumentRetrievalInferenceFile,
    intarray_to_string,
)

In [None]:
# Output sub-directory for the BMI data, named after the models and query sources involved.
bmi_dirname = ".".join(["BMI", bert_model_identifier, genq_model_identifier, "realq_genq_docseg"])
os.makedirs(f"{output_dir}/{bmi_dirname}", exist_ok=True)

In [None]:
# Reload the document table (index = docid) and the BMI docid -> index mapping.
docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col=0, na_filter=False)
# For a repeated title the last docid wins, exactly as with dict(zip(...)).
title2docid = {title: docid for docid, title in docs["title"].items()}
bmi_index_path = f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/docid2index.BMI.realq_genq_docseg.tsv"
docid2index = StringIndexing.from_tsv(bmi_index_path)

In [None]:
# docid2index.tsv -- copy the mapping next to the training files, then display it
bmi_docid2index_path = f"{output_dir}/{bmi_dirname}/docid2index.tsv"
docid2index.to_tsv(bmi_docid2index_path)
docid2index.to_pandas()

In [None]:
# realq_train.tsv -- real training queries paired with the docid / index of their document
if 'train' not in globals():
    train = load_train()
docids = train['title'].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)
index_strings = indexes.apply(intarray_to_string)

file = DocumentRetrievalTrainingFile(
    queries=train["query"],
    docids=docids,
    indexes=index_strings,
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/realq_train.tsv")
file.to_pandas()

In [None]:
# realq_dev.tsv -- real dev queries paired with the docid of their document
if 'dev' not in globals():
    dev = load_dev()
docids = dev['title'].apply(title2docid.get)

realq_dev_path = f"{output_dir}/{bmi_dirname}/realq_dev.tsv"
file = DocumentRetrievalInferenceFile(queries=dev["query"], docids=docids)
file.to_tsv(realq_dev_path)
file.to_pandas()

In [None]:
# title_abs.tsv -- "<title> <abstract>" of every document serves as a query for that document
titles = docs["title"].fillna("")
abstracts = docs["abstract"].fillna("")
title_abs = titles + " " + abstracts

docids = docs["title"].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)
index_strings = indexes.apply(intarray_to_string)

file = DocumentRetrievalTrainingFile(
    queries=title_abs,
    docids=docids,
    indexes=index_strings,
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/title_abs.tsv")
file.to_pandas()

In [None]:
# genq.tsv
# na_filter=False keeps queries such as "null", "NA" or "nan" as literal strings
# instead of turning them into NaN (same setting as used when reading docs.tsv).
genq = pd.read_csv(
    f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    usecols=["docid", "query"],
    sep="\t",
    na_filter=False,
)
docids = genq["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=genq["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/genq.tsv")
file.to_pandas()

In [None]:
# docseg.tsv
# na_filter=False keeps segments such as "NA" or the empty string as literal strings
# instead of turning them into NaN (same setting as used when reading docs.tsv).
docseg = pd.read_csv(
    f"{cache_dir}/docseg.tsv",
    usecols=["docid", "query"],
    sep="\t",
    na_filter=False,
)
docids = docseg["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=docseg["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/docseg.tsv")
file.to_pandas()