In [1]:
## Set current working directory to the root of the project
## !! Run this once only once
import os
os.chdir('../..')

In [2]:
from argparse import Namespace
import h5py
import pandas as pd
import pickle
import torch
import re
import random
import csv
import jsonlines
import numpy as np
import pickle
import time
import gzip
from tqdm import tqdm, trange
from sklearn.cluster import KMeans
from typing import Any, List, Sequence, Callable
from itertools import islice, zip_longest
import transformers
from transformers import BertTokenizerFast, BertModel, AutoTokenizer, AutoModelForSeq2SeqLM
from sklearn.cluster import MiniBatchKMeans

import pyarrow.csv as csv

import preprocess

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
raw_dir = 'data/NQ320K/raw'
output_dir = 'data/NQ320K/output'
cache_dir = 'data/NQ320K/cache'
pretrained_dir = 'data/pretrained'

In [5]:
bert_model_path = bert_model_identifier = 'bert-base-uncased'
genq_model_identifier = 'doc2query-t5-base-msmarco-ft_NQ320K'
genq_model_path = "data/pretrained/doc2query-t5-base-msmarco-ft_NQ320K"

# 1. Preparation

## 1.1 Extract data into `data/NQ320K/raw`

###### Download NQ Train and Dev dataset from https://ai.google.com/research/NaturalQuestions/download
###### NQ Train: https://storage.cloud.google.com/natural_questions/v1.0-simplified/simplified-nq-train.jsonl.gz
###### NQ Dev: https://storage.cloud.google.com/natural_questions/v1.0-simplified/nq-dev-all.jsonl.gz

### 1.1.1 Extract query-document samples from raw files

In [14]:
columns = [
    "query",
    "queryid",
    "long_answer",
    "short_answer",
    "title",
    "abstract",
    "content",
    "doc",
    "language",
]

In [None]:
nq_dev = []

with gzip.open(f"{raw_dir}/v1.0-simplified_nq-dev-all.jsonl.gz", "r+") as f:
    for item in tqdm(jsonlines.Reader(f)):
        
        arr = []
        ## question_text
        question_text = item['question_text']
        arr.append(question_text)

        tokens = []
        for i in item['document_tokens']:
            tokens.append(i['token'])
        document_text = ' '.join(tokens)
        
        ## example_id
        example_id = str(item['example_id'])
        arr.append(example_id)

        # document_text = item['document_text']
        ## long_answer
        annotation = item['annotations'][0]
        has_long_answer = annotation['long_answer']['start_token'] >= 0

        long_answers = [
            a['long_answer']
            for a in item['annotations']
            if a['long_answer']['start_token'] >= 0 and has_long_answer
        ]
        if has_long_answer:
            start_token = long_answers[0]['start_token']
            end_token = long_answers[0]['end_token']
            x = document_text.split(' ')
            long_answer = ' '.join(x[start_token:end_token])
            long_answer = re.sub('<[^<]+?>', '', long_answer).replace('\n', '').strip()
        arr.append(long_answer) if has_long_answer else arr.append('')

        # short_answer
        has_short_answer = annotation['short_answers'] or annotation['yes_no_answer'] != 'NONE'
        short_answers = [
            a['short_answers']
            for a in item['annotations']
            if a['short_answers'] and has_short_answer
        ]
        if has_short_answer and len(annotation['short_answers']) != 0:
            sa = []
            for i in short_answers[0]:
                start_token_s = i['start_token']
                end_token_s = i['end_token']
                shorta = ' '.join(x[start_token_s:end_token_s])
                sa.append(shorta)
            short_answer = '|'.join(sa)
            short_answer = re.sub('<[^<]+?>', '', short_answer).replace('\n', '').strip()
        arr.append(short_answer) if has_short_answer else arr.append('')

        ## title
        arr.append(item['document_title'])

        ## abs
        if document_text.find('<P>') != -1:
            abs_start = document_text.index('<P>')
            abs_end = document_text.index('</P>')
            abs = document_text[abs_start+3:abs_end]
        else:
            abs = ''
        arr.append(abs)

        ## content
        if document_text.rfind('</Ul>') != -1:
            final = document_text.rindex('</Ul>')
            document_text = document_text[:final]
            if document_text.rfind('</Ul>') != -1:
                final = document_text.rindex('</Ul>')
                content = document_text[abs_end+4:final]
                content = re.sub('<[^<]+?>', '', content).replace('\n', '').strip()
                content = re.sub(' +', ' ', content)
                arr.append(content)
            else:
                content = document_text[abs_end+4:final]
                content = re.sub('<[^<]+?>', '', content).replace('\n', '').strip()
                content = re.sub(' +', ' ', content)
                arr.append(content)
        else:
            content = document_text[abs_end+4:]
            content = re.sub('<[^<]+?>', '', content).replace('\n', '').strip()
            content = re.sub(' +', ' ', content)
            arr.append(content)
        doc_tac = item['document_title'] + abs + content
        arr.append(doc_tac)
        language = 'en'
        arr.append(language)
        nq_dev.append(arr)

dev = pd.DataFrame(nq_dev, columns=columns)
dev.to_csv(f"{raw_dir}/dev.tsv", sep="\t", index=False)

In [None]:
nq_train = []
with gzip.open(f"{raw_dir}/v1.0-simplified_simplified-nq-train.jsonl.gz", "r+") as f:
    for item in tqdm(jsonlines.Reader(f)):
        ## question_text
        arr = []
        question_text = item["question_text"]
        arr.append(question_text)

        ## example_id
        example_id = str(item["example_id"])
        arr.append(example_id)

        document_text = item["document_text"]

        ## long_answer
        annotation = item["annotations"][0]
        has_long_answer = annotation["long_answer"]["start_token"] >= 0

        long_answers = [
            a["long_answer"]
            for a in item["annotations"]
            if a["long_answer"]["start_token"] >= 0 and has_long_answer
        ]
        if has_long_answer:
            start_token = long_answers[0]["start_token"]
            end_token = long_answers[0]["end_token"]
            x = document_text.split(" ")
            long_answer = " ".join(x[start_token:end_token])
            long_answer = re.sub("<[^<]+?>", "", long_answer).replace("\n", "").strip()
        arr.append(long_answer) if has_long_answer else arr.append("")

        # short_answer
        has_short_answer = (
            annotation["short_answers"] or annotation["yes_no_answer"] != "NONE"
        )
        short_answers = [
            a["short_answers"]
            for a in item["annotations"]
            if a["short_answers"] and has_short_answer
        ]
        if has_short_answer and len(annotation["short_answers"]) != 0:
            sa = []
            for i in short_answers[0]:
                start_token_s = i["start_token"]
                end_token_s = i["end_token"]
                shorta = " ".join(x[start_token_s:end_token_s])
                sa.append(shorta)
            short_answer = "|".join(sa)
            short_answer = (
                re.sub("<[^<]+?>", "", short_answer).replace("\n", "").strip()
            )
        arr.append(short_answer) if has_short_answer else arr.append("")

        ## title
        if document_text.find("<H1>") != -1:
            title_start = document_text.index("<H1>")
            title_end = document_text.index("</H1>")
            title = document_text[title_start + 4 : title_end]
        else:
            title = ""
        arr.append(title)

        ## abs
        if document_text.find("<P>") != -1:
            abs_start = document_text.index("<P>")
            abs_end = document_text.index("</P>")
            abs = document_text[abs_start + 3 : abs_end]
        else:
            abs = ""
        arr.append(abs)

        ## content
        if document_text.rfind("</Ul>") != -1:
            final = document_text.rindex("</Ul>")
            document_text = document_text[:final]
            if document_text.rfind("</Ul>") != -1:
                final = document_text.rindex("</Ul>")
                content = document_text[abs_end + 4 : final]
                content = re.sub("<[^<]+?>", "", content).replace("\n", "").strip()
                content = re.sub(" +", " ", content)
                arr.append(content)
            else:
                content = document_text[abs_end + 4 : final]
                content = re.sub("<[^<]+?>", "", content).replace("\n", "").strip()
                content = re.sub(" +", " ", content)
                arr.append(content)
        else:
            content = document_text[abs_end + 4 :]
            content = re.sub("<[^<]+?>", "", content).replace("\n", "").strip()
            content = re.sub(" +", " ", content)
            arr.append(content)

        doc_tac = title + abs + content
        arr.append(doc_tac)

        language = "en"
        arr.append(language)
        nq_train.append(arr)

train = pd.DataFrame(nq_train, columns=columns)
train.to_csv(f"{raw_dir}/train.tsv", sep="\t", index=False)

307373it [10:32, 486.33it/s]


### 1.2.2 Load extracted samples and collect unique documents

In [149]:
## read large csv files (use pyarrow to accelerate)
def load_dev():
    dev = csv.read_csv(
        f'{raw_dir}/dev.tsv',
        read_options=csv.ReadOptions(block_size=2**25),
        parse_options=csv.ParseOptions(invalid_row_handler=lambda invalidrow:"skip", delimiter="\t")
    ).to_pandas()
    dev['title'] = dev['title'].map(lower)
    print('dev.shape:', dev.shape)
    return dev

def load_train():
    train = csv.read_csv(
        f'{raw_dir}/train.tsv',
        read_options=csv.ReadOptions(block_size=2**25),
        parse_options=csv.ParseOptions(invalid_row_handler=lambda invalidrow:"skip", delimiter="\t")
    ).to_pandas()
    train['title'] = train['title'].map(lower)
    print('train.shape:', train.shape)
    return train

## Clean data
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
def lower(x):
    text = tokenizer.tokenize(x)
    id_ = tokenizer.convert_tokens_to_ids(text)
    return tokenizer.decode(id_)



In [None]:
train = load_train()
dev = load_dev()

In [73]:
## Concat train doc and validation doc to obtain full document collection
full = pd.concat([train, dev], axis=0)
full.reset_index(inplace = True)

## Remove duplicated documents based on titles
docs = full.drop_duplicates('title')[['title', 'abstract', 'doc']]
docs.reset_index(inplace=True, drop=True)
docs.index.name = "docid"

docs.fillna({"title": ""}, inplace=True)
assert not pd.isnull(docs['title']).any()

# del full

In [75]:
(docs['title'] == 'nan').any()

True

In [17]:
## The total amount of documents : 109737
assert len(docs) == 109737

## Statistics
print(f"# all unique documents: {len(docs)}")
print("----------- training set --------------")
print(f"# Queries: {len(train)}")
print(f"# Documents mentioned in training set: {len(train['title'].unique())}")

print("----------- dev set --------------")
print(f"# Queries: {len(dev)}")
print(f"# Documents mentioned dev set: {len(dev['title'].unique())}")

# all unique documents: 109737
----------- training set --------------
# Queries: 307369
# Documents mentioned in training set: 108024
----------- dev set --------------
# Queries: 7830
# Documents mentioned dev set: 6930


In [18]:
# docs.tsv
docs.to_csv(f"{raw_dir}/docs.tsv", sep='\t')

In [19]:
!head -5 {raw_dir}/docs.tsv

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


docid	title	abstract	doc
0	email marketing	 Email marketing is the act of sending a commercial message , typically to a group of people , using email . In its broadest sense , every email sent to a potential or current customer could be considered email marketing . It usually involves using email to send advertisements , request business , or solicit sales or donations , and is meant to build loyalty , trust , or brand awareness . Marketing emails can be sent to a purchased lead list or a current customer database . The term usually refers to sending email messages with the purpose of enhancing a merchant 's relationship with current or previous customers , encouraging customer loyalty and repeat business , acquiring new customers or convincing current customers to purchase something immediately , and sharing third - party ads . 	 Email marketing  Email marketing is the act of sending a commercial message , typically to a group of people , using email . In its broadest sense , every em

## 1.3 Finetuning the document-to-query model

### 1.3.1 Download docT5query model

In [None]:
import transformers

identifier = "castorini/doc2query-t5-base-msmarco"
model = transformers.T5ForConditionalGeneration.from_pretrained(identifier)
tokenizer = transformers.T5TokenizerFast.from_pretrained(identifier)

os.makedirs(pretrained_dir, exist_ok=True)
model.save_pretrained(f"{pretrained_dir}/doc2query-t5-base-msmarco")
tokenizer.save_pretrained(f"{pretrained_dir}/doc2query-t5-base-msmarco")



('data/NQ320K/pretrained/doc2query-t5-base-msmarco/tokenizer_config.json',
 'data/NQ320K/pretrained/doc2query-t5-base-msmarco/special_tokens_map.json',
 'data/NQ320K/pretrained/doc2query-t5-base-msmarco/spiece.model',
 'data/NQ320K/pretrained/doc2query-t5-base-msmarco/added_tokens.json',
 'data/NQ320K/pretrained/doc2query-t5-base-msmarco/tokenizer.json')

### 1.3.2 Finetuning

Finetuning docT5query model with `train.doc_query.tsv` \
by executing the following bash command in your terminal.

```bash
    nohup python -m preprocess.finetune_t5 \
        --raw_ckpt data/pretrained/doc2query-t5-base-msmarco \
        --finetuned_ckpt data/pretrained/doc2query-t5-base-msmarco-ft_NQ320K \
        --train_data_path data/NQ320K/raw/train.doc_query.tsv \
        --val_data_path data/NQ320K/raw/dev.doc_query.tsv \
        --epochs 10 \
        --lr 5e-5 \
        --weight_decay 1e-2 \
        --batch_size 8 \
        --doc_max_len 512 \
        --query_max_len 64 \
        --test1000 0 \
        --num_nodes 1 \
    > log.finetuning_doc2query.log 2>&1 &
```

It takes approximately `five hours` on four RTX4090 GPUs. \
After the finetuning is done, the finetuned model will be saved to
```
    data/pretrained/doc2query-t5-base-msmarco-ft_NQ320K
```

You can optionally check the finetuning progress by executing in your terminal:
```bash
    tail -f log.finetuning_doc2query.log
```

## 1.4 Generate queries with the finetuned document-to-query model

Generate queries for every document in `docs.tsv` and save them as `genq.tsv`, \
in which every line contains a document id and a generated query.

It takes around 40 minutes on four RTX4090 GPUs.

In [None]:
import preprocess.genq
args = Namespace(
    model_path=genq_model_path,
    docs_path=f"{raw_dir}/docs.tsv",
    output_path=f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    doc_max_len=512,
    query_max_len=32,
    genq_per_doc=15,
    n_gpus=4,
    batch_size=16,
)
preprocess.genq.main(args)

  0%|          | 0/6859 [00:00<?, ?it/s]

Initializing ...
Initializing ...
Initializing ...Initializing ...

Initialization finished
Initialization finished
Initialization finished
Initialization finished


100%|██████████| 6859/6859 [39:27<00:00,  2.90it/s]


In [None]:
!head -5 {cache_dir}/{genq_model_identifier}/genq.tsv

docid	query
0	two types of advertising in email marketing and how they work
0	who sends email to find out how to get leads
0	what are the advantages and disadvantages of email marketing
0	which is an example of an email marketing method


# 2. Hierarchical K-means Indexing (HKmI)

## 2.1 Produce for every document a BERT embedding

Encode every document in `docs.tsv` as a vector. The embeddings will be saved as `doc_emb.h5`.

The encoding process will take approximately 6 minutes on four RTX4090 GPUs.

In [None]:
import preprocess.bert_embedding
args = Namespace(
    docs_path=f"{raw_dir}/docs.tsv",
    output_path=f"{cache_dir}/{bert_model_identifier}/doc_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
)
preprocess.bert_embedding.main(args)

Initializing ...
Initializing ...


Initializing ...Initializing ...

  0%|          | 0/109737 [00:00<?, ?it/s]

Initialization finished


  0%|          | 20/109737 [00:04<6:59:50,  4.36it/s]

Initialization finished


  0%|          | 40/109737 [00:04<3:03:25,  9.97it/s]

Initialization finished
Initialization finished


100%|██████████| 109737/109737 [05:37<00:00, 325.51it/s]


## 2.2 Apply K-means clustering on documents

Applying hierarchical K-means on document embeddings.

Every document will be assigned an ID string, which will be saved as `docid2index.tsv`

In [None]:
import preprocess.kmeans
args = Namespace(
    embedding_path=f"{cache_dir}/{bert_model_identifier}/doc_emb.h5",
    output_path=f"{cache_dir}/{bert_model_identifier}/docid2index.HKmI.tsv",
    v_dim=768,
    k=30,
    c=30,
    seed=7,
    n_init=1,   # can be increased to 10/100 to enhance quality at the cost of running time
    tol=1e-6,
)
preprocess.kmeans.main(args)

Fri May 17 20:39:17 2024 Start First Clustering
(109737,)
148
Fri May 17 20:39:17 2024 Start Recursively Clustering...


100%|██████████| 30/30 [00:16<00:00,  1.79it/s]


## 2.3 Use document segments as queries

For every document, we random select 10~12 segments of 64 tokens as queries.

In [None]:
def get_seg(tokens):
    begin = random.randrange(0, max(1, len(tokens) - 64))
    end = begin + 64
    seg = ' '.join(tokens[begin: end])
    return seg

docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col=0, na_filter=False)

with open(f"{cache_dir}/docseg.tsv", "wt") as f:
    f.write("docid\tquery\n")

    for docid, doc in docs["doc"].items():
        tokens = doc.split(" ")
        nsegs = 10 + max(0, len(tokens)-3000) // 3000
        for _ in range(nsegs):
            seg = get_seg(tokens)
            f.write(f"{docid}\t{seg}\n")

## 2.4 Compiling training data for training retrieval model

- training set

A training sample should have three entries: *query, index, docid*

| File Name | Description |
| --- | --- |
| realq_train.tsv | real queries (ground truth) |
| genq.tsv  | generated queries from documents |
| title_abs.tsv | concatenation of document title and abstract as query |
| docseg.tsv | document segments as queries |

- dev (evaluation) set

A validation sample should have two entries: *query, docid*

| File Name | Description |
| --- | --- |
| realq_dev.tsv | real queries (ground truth) |

- supporting files

| File Name | Description |
| --- | --- |
| docid2index.tsv | mapping from docid to index, used for evaluation |

In [84]:
import BMI.io
from BMI.io import (
    StringIndexing,
    DocumentRetrievalTrainingFile,
    DocumentRetrievalInferenceFile,
    intarray_to_string,
)

In [106]:
hkmi_dirname = f"HKmI.{bert_model_identifier}.{genq_model_identifier}"
os.makedirs(f"{output_dir}/{hkmi_dirname}", exist_ok=True)

In [113]:
docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col=0, na_filter=False)
title2docid = dict(zip(docs["title"], docs.index))
docid2index = StringIndexing.from_tsv(f"{cache_dir}/{bert_model_identifier}/docid2index.HKmI.tsv")

In [None]:
# docid2index.tsv
docid2index.to_tsv(f"{output_dir}/{hkmi_dirname}/docid2index.tsv")
docid2index.to_pandas()

Unnamed: 0,docid,index
0,0,"(10, 22, 10, 0)"
1,1,"(2, 18, 19, 0)"
2,2,"(17, 8, 5, 0)"
3,3,"(25, 19, 6, 0)"
4,4,"(8, 3, 20, 0)"
...,...,...
109732,109732,"(9, 6, 8, 10)"
109733,109733,"(20, 13, 24, 1)"
109734,109734,"(6, 6, 5, 18)"
109735,109735,"(19, 20, 18, 2)"


In [91]:
# realq_train.tsv
if 'train' not in globals():
    train = pd.read_csv(f"{raw_dir}/train.tsv", usecols=["title", "query"], sep="\t")
docids = train['title'].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=train["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/realq_train.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,which is the most common use of opt-in e-mail ...,10-22-10-0,0
1,how i.met your mother who is the mother,2-18-19-0,1
2,what type of fertilisation takes place in humans,17-8-5-0,2
3,who had the most wins in the nfl,25-19-6-0,3
4,what happened to the lost settlement of roanoke,8-3-20-0,4
...,...,...,...
307364,who have been the hosts of the price is right,8-1-17-10,73063
307365,who sang the song mama told me not to come,19-21-3-7,32552
307366,who plays grey worm on game of thrones,11-20-9-16,99609
307367,working principle of high pressure sodium vapo...,1-18-28-5,37159


In [92]:
# realq_dev.tsv
if 'dev' not in globals():
    dev = pd.read_csv(f"{raw_dir}/dev.tsv", usecols=["title", "query"], sep="\t")
docids = dev['title'].apply(title2docid.get)

file = DocumentRetrievalInferenceFile(
    queries=dev["query"],
    docids=docids,
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/realq_dev.tsv")
file.to_pandas()

Unnamed: 0,query,docid
0,what do the 3 dots mean in math,
1,when was the writ watch invented by who,
2,who wrote the song photograph by ringo starr,
3,who is playing the halftime show at super bowl...,
4,star wars the clone wars anakin voice actor,
...,...,...
7825,original cast of natasha pierre and the great ...,
7826,which of the following is not a provision of t...,
7827,define divergence of vector field explain its ...,
7828,which of the following factors is not affectin...,


In [93]:
# title_abs.tsv
title_abs = docs["title"].fillna("") + " " + docs["abstract"].fillna("")
docids = docs["title"].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=title_abs,
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/title_abs.tsv")
file.to_pandas()

Unnamed: 0_level_0,query,index,docid
docid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,email marketing Email marketing is the act of...,10-22-10-0,0
1,the mother ( how i met your mother ) Tracy Mc...,2-18-19-0,1
2,human fertilization Human fertilization is th...,17-8-5-0,2
3,list of national football league career quarte...,25-19-6-0,3
4,roanoke colony The Roanoke Colony ( / ˈroʊəˌn...,8-3-20-0,4
...,...,...,...
109732,los alamitos circle The Los Alamitos Traffic ...,9-6-8-10,109732
109733,"perfect hash function In computer science , a...",20-13-24-1,109733
109734,chrysler 300c The Chrysler Corporation has us...,6-6-5-18,109734
109735,i see the want to in your eyes `` I See the W...,19-20-18-2,109735


In [94]:
# genq.tsv
genq = pd.read_csv(f"{cache_dir}/{genq_model_identifier}/genq.tsv", usecols=["docid", "query"], sep="\t")
docids = genq["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=genq["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/genq.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,two types of advertising in email marketing an...,10-22-10-0,0
1,who sends email to find out how to get leads,10-22-10-0,0
2,what are the advantages and disadvantages of e...,10-22-10-0,0
3,which is an example of an email marketing method,10-22-10-0,0
4,what is meant by an email to send a message or...,10-22-10-0,0
...,...,...,...
1646050,under the sarbanes oxley act the legal consequ...,26-17-10-5,109736
1646051,which of the following are included in the sar...,26-17-10-5,109736
1646052,the sarbanes-oxley act of 2002 was enacted in ...,26-17-10-5,109736
1646053,what are the main sections of the sarbanes-oxl...,26-17-10-5,109736


In [96]:
# docseg.tsv
docseg = pd.read_csv(f"{cache_dir}/docseg.tsv", usecols=["docid", "query"], sep="\t")
docids = docseg["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=docseg["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/docseg.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,Early Email Blasts Results in Higher Click & O...,10-22-10-0,0
1,`` Why Email Marketing is King '' . Harvard Bu...,10-22-10-0,0
2,Companies usually collect a list of customer o...,10-22-10-0,0
3,would like to receive the newsletter . With a ...,10-22-10-0,0
4,recipient ; this does not apply to business em...,10-22-10-0,0
...,...,...,...
1168560,"Oxley , two separate sections came into effect...",26-17-10-5,109736
1168561,08 - 27 . Jump up ^ `` FEI Survey '' . Fei.med...,26-17-10-5,109736
1168562,million . Costs of evaluating manual control p...,26-17-10-5,109736
1168563,"from the highs , but before Sarbanes -- Oxley ...",26-17-10-5,109736


# 3. Bottleneck-Minimal Indexing (BMI)

## 3.1 Produce for every query a BERT embedding

In [None]:
import preprocess.bert_embedding

### 3.1.1 RealQ: real queries (training set)

In [None]:
args = Namespace(
    docs_path=f"{output_dir}/{hkmi_dirname}/realq_train.tsv",
    output_path=f"{cache_dir}/{bert_model_identifier}/realq_train_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="query",
)
preprocess.bert_embedding.main(args)

### 3.1.2 GenQ: generated queries by the finetuned document-to-query model

In [None]:
args = Namespace(
    docs_path=f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    output_path=f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/genq_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="query",
)
preprocess.bert_embedding.main(args)

### 3.1.3 DocSeg: using document segments as queries 

In [None]:
args = Namespace(
    docs_path=f"{output_dir}/{hkmi_dirname}/docseg.tsv",
    output_path=f"{cache_dir}/{bert_model_identifier}/docseg_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="query",
)
preprocess.bert_embedding.main(args)

## 3.2 Apply K-means clustering on documents

### 3.2.1 Calculate centroid vector for every document

In [120]:
X, ids = [], []
for path in [
    f"{cache_dir}/{bert_model_identifier}/realq_train_emb.h5",
    f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/genq_emb.h5",
    f"{cache_dir}/{bert_model_identifier}/docseg_emb.h5",
]:
    with h5py.File(path, 'r') as f:
        X.append(f["embs"][:])
        ids.append(f["ids"][:])
X = np.concatenate(X, axis=0)
ids = np.concatenate(ids, axis=0)

In [121]:
embs = pd.DataFrame({'emb': list(X), 'docid': ids})
centroids = embs.groupby("docid").apply(lambda slice: np.stack(slice['emb'].values).mean(0))

  centroids = embs.groupby("docid").apply(lambda slice: np.stack(slice['emb'].values).mean(0))


In [122]:
centroids

docid
0         [-0.11100459, 0.072091825, -0.08475395, -0.192...
1         [-0.15832388, -0.060702156, -0.22076872, -0.01...
2         [-0.38695332, -0.020935988, -0.15999137, -0.21...
3         [-0.7586992, -0.14397368, -0.30243143, -0.2124...
4         [-0.40315586, 0.009933988, -0.24240935, -0.004...
                                ...                        
109732    [-0.26088616, -0.16444989, -0.01110338, -0.070...
109733    [-0.44992974, -0.24237911, 0.053350333, -0.088...
109734    [-0.54605424, -0.16090146, -0.10764443, -0.098...
109735    [-0.19014256, 0.046280567, -0.039669055, -0.00...
109736    [-0.42444065, 0.042702753, -0.2906938, -0.3260...
Length: 109737, dtype: object

In [126]:
path = f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/doc_emb.centroid.realq_genq_docseg.h5"
with h5py.File(path, 'w') as f:
    f['embs'] = np.stack(centroids.values, dtype=np.float32)
    f['ids'] = centroids.index

### 3.2.2 Run k-means

In [129]:
import preprocess.kmeans
import importlib
importlib.reload(preprocess.kmeans)

args = Namespace(
    embedding_path=f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/doc_emb.centroid.realq_genq_docseg.h5",
    output_path=f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/docid2index.BMI.realq_genq_docseg.tsv",
    v_dim=768,
    k=30,
    c=30,
    seed=7,
    n_init=1,   # can be increased to 10/100 to enhance quality at the cost of running time
    tol=1e-6,
)
preprocess.kmeans.main(args)

(109737, 768)
Sun May 19 22:37:19 2024 Start First Clustering
(109737,)
68
Sun May 19 22:37:20 2024 Start Recursively Clustering...


100%|██████████| 30/30 [00:19<00:00,  1.56it/s]


## 3.3 Compiling training data for training retrieval model

- training set

A training sample should have three entries: *query, index, docid*

| File Name | Description |
| --- | --- |
| realq_train.tsv | real queries (ground truth) |
| genq.tsv  | generated queries from documents |
| title_abs.tsv | concatenation of document title and abstract as query |
| docseg.tsv | document segments as queries |

- dev (evaluation) set

A validation sample should have two entries: *query, docid*

| File Name | Description |
| --- | --- |
| realq_dev.tsv | real queries (ground truth) |

- supporting files

| File Name | Description |
| --- | --- |
| docid2index.tsv | mapping from docid to index, used for evaluation |

In [130]:
import BMI.io
from BMI.io import (
    StringIndexing,
    DocumentRetrievalTrainingFile,
    DocumentRetrievalInferenceFile,
    intarray_to_string,
)

In [131]:
bmi_dirname = f"BMI.{bert_model_identifier}.{genq_model_identifier}.realq_genq_docseg"
os.makedirs(f"{output_dir}/{bmi_dirname}", exist_ok=True)

In [132]:
docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col=0, na_filter=False)
title2docid = dict(zip(docs["title"], docs.index))
docid2index = StringIndexing.from_tsv(f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/docid2index.BMI.realq_genq_docseg.tsv")

In [133]:
# docid2index.tsv
docid2index.to_tsv(f"{output_dir}/{bmi_dirname}/docid2index.tsv")
docid2index.to_pandas()

Unnamed: 0,docid,index
0,0,"(2, 17, 5, 0)"
1,1,"(11, 24, 18, 0)"
2,2,"(5, 18, 23, 0)"
3,3,"(10, 8, 17, 0)"
4,4,"(9, 12, 29, 0)"
...,...,...
109732,109732,"(3, 6, 24, 1)"
109733,109733,"(18, 3, 5, 3)"
109734,109734,"(21, 1, 2, 6)"
109735,109735,"(8, 4, 16, 18)"


In [137]:
# realq_train.tsv
if 'train' not in globals():
    train = load_train()
docids = train['title'].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=train["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/realq_train.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,which is the most common use of opt-in e-mail ...,2-17-5-0,0
1,how i.met your mother who is the mother,11-24-18-0,1
2,what type of fertilisation takes place in humans,5-18-23-0,2
3,who had the most wins in the nfl,10-8-17-0,3
4,what happened to the lost settlement of roanoke,9-12-29-0,4
...,...,...,...
307364,who have been the hosts of the price is right,15-28-29-2,73063
307365,who sang the song mama told me not to come,8-27-13-10,32552
307366,who plays grey worm on game of thrones,19-2-7-15,99609
307367,working principle of high pressure sodium vapo...,7-7-2-1,37159


In [151]:
# realq_dev.tsv
if 'dev' not in globals():
    dev = load_dev()
docids = dev['title'].apply(title2docid.get)

file = DocumentRetrievalInferenceFile(
    queries=dev["query"],
    docids=docids,
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/realq_dev.tsv")
file.to_pandas()

Unnamed: 0,query,docid
0,what do the 3 dots mean in math,101577
1,when was the writ watch invented by who,56156
2,who wrote the song photograph by ringo starr,108024
3,who is playing the halftime show at super bowl...,2114
4,star wars the clone wars anakin voice actor,108025
...,...,...
7825,original cast of natasha pierre and the great ...,28339
7826,which of the following is not a provision of t...,109736
7827,define divergence of vector field explain its ...,43021
7828,which of the following factors is not affectin...,36157


In [139]:
# title_abs.tsv
title_abs = docs["title"].fillna("") + " " + docs["abstract"].fillna("")
docids = docs["title"].apply(title2docid.get)
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=title_abs,
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/title_abs.tsv")
file.to_pandas()

Unnamed: 0_level_0,query,index,docid
docid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,email marketing Email marketing is the act of...,2-17-5-0,0
1,the mother ( how i met your mother ) Tracy Mc...,11-24-18-0,1
2,human fertilization Human fertilization is th...,5-18-23-0,2
3,list of national football league career quarte...,10-8-17-0,3
4,roanoke colony The Roanoke Colony ( / ˈroʊəˌn...,9-12-29-0,4
...,...,...,...
109732,los alamitos circle The Los Alamitos Traffic ...,3-6-24-1,109732
109733,"perfect hash function In computer science , a...",18-3-5-3,109733
109734,chrysler 300c The Chrysler Corporation has us...,21-1-2-6,109734
109735,i see the want to in your eyes `` I See the W...,8-4-16-18,109735


In [140]:
# genq.tsv
genq = pd.read_csv(f"{cache_dir}/{genq_model_identifier}/genq.tsv", usecols=["docid", "query"], sep="\t")
docids = genq["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=genq["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/genq.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,two types of advertising in email marketing an...,2-17-5-0,0
1,who sends email to find out how to get leads,2-17-5-0,0
2,what are the advantages and disadvantages of e...,2-17-5-0,0
3,which is an example of an email marketing method,2-17-5-0,0
4,what is meant by an email to send a message or...,2-17-5-0,0
...,...,...,...
1646050,under the sarbanes oxley act the legal consequ...,2-27-5-9,109736
1646051,which of the following are included in the sar...,2-27-5-9,109736
1646052,the sarbanes-oxley act of 2002 was enacted in ...,2-27-5-9,109736
1646053,what are the main sections of the sarbanes-oxl...,2-27-5-9,109736


In [141]:
# docseg.tsv
docseg = pd.read_csv(f"{cache_dir}/docseg.tsv", usecols=["docid", "query"], sep="\t")
docids = docseg["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=docseg["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/docseg.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,Early Email Blasts Results in Higher Click & O...,2-17-5-0,0
1,`` Why Email Marketing is King '' . Harvard Bu...,2-17-5-0,0
2,Companies usually collect a list of customer o...,2-17-5-0,0
3,would like to receive the newsletter . With a ...,2-17-5-0,0
4,recipient ; this does not apply to business em...,2-17-5-0,0
...,...,...,...
1168560,"Oxley , two separate sections came into effect...",2-27-5-9,109736
1168561,08 - 27 . Jump up ^ `` FEI Survey '' . Fei.med...,2-27-5-9,109736
1168562,million . Costs of evaluating manual control p...,2-27-5-9,109736
1168563,"from the highs , but before Sarbanes -- Oxley ...",2-27-5-9,109736
