In [1]:
import os
os.chdir('../../')
!pwd

/home/duxin/code/Bottleneck-Minimal-Indexing


In [2]:
from argparse import Namespace
import random
import tqdm
import numpy as np
import pandas as pd
import transformers
import pyarrow.csv as csv

In [3]:
raw_dir = 'data/Marco-Lite/raw'
cache_dir = 'data/Marco-Lite/cache'
output_dir = 'data/Marco-Lite/output'
pretrained_dir = 'data/pretrained'

!mkdir -p {raw_dir}
!mkdir -p {cache_dir}
!mkdir -p {output_dir}
!mkdir -p {pretrained_dir}

bert_model_identifier = bert_model_path = "bert-base-uncased"
genq_model_identifier = "doc2query-t5-base-msmarco"
genq_model_path = f"{pretrained_dir}/doc2query-t5-base-msmarco"

# 1. Preparation

## 1.1 Download and preprocess the MS MARCO dataset

https://github.com/microsoft/msmarco/

| Type   | Filename                                                                                                              | File size |              Num Records | Format                                                         |
|--------|-----------------------------------------------------------------------------------------------------------------------|----------:|-------------------------:|----------------------------------------------------------------|
| Corpus | [msmarco-docs.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz)                          |     22 GB |               3,213,835  | tsv: docid, url, title, body                                   |
| Corpus | [msmarco-docs.trec](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docs.trec.gz)                        |     22 GB |               3,213,835  | TREC DOC format (same content as msmarco-docs.tsv)                                               |
| Corpus | [msmarco-docs-lookup.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docs-lookup.tsv.gz)            |    101 MB |               3,213,835  | tsv: docid, offset_trec, offset_tsv                            |
| Train  | [msmarco-doctrain-queries.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz)  |     15 MB |                 367,013  | tsv: qid, query                                                |
| Train  | [msmarco-doctrain-top100](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-top100.gz)            |    1.8 GB |              36,701,116  | TREC submission: qid, "Q0", docid, rank, score, runstring      |
| Train  | [msmarco-doctrain-qrels.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz)      |    7.6 MB |                 384,597  | TREC qrels format                                              |
| Train  | [msmarco-doctriples.py](https://github.com/microsoft/TREC-2019-Deep-Learning/blob/master/utils/msmarco-doctriples.py) |         - |                       -  | Python script generates training triples |
| Dev    | [msmarco-docdev-queries.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz)      |    216 KB |                   5,193  | tsv: qid, query                                                |
| Dev    | [msmarco-docdev-top100](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-top100.gz)        |       27 MB |                     519,300  | TREC submission: qid, "Q0", docid, rank, score, runstring      |
| Dev    | [msmarco-docdev-qrels.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz)          |    112 KB |                   5,478  | TREC qrels format                                              |
| Test    | [docleaderboard-queries.tsv](https://msmarco.z22.web.core.windows.net/msmarcoranking/docleaderboard-queries.tsv.gz)          |     124K |                   5,793  | tsv: qid, query                                              |
| Test    | [docleaderboard-top100](https://msmarco.z22.web.core.windows.net/msmarcoranking/docleaderboard-top100.tsv.gz)          |   2.9M |                  579,300  | TREC submission: qid, "Q0", docid, rank, score, runstring       |


In [None]:
## Download the following:
## - msmarco-docs.tsv
## - msmarco-doctrain-quries.tsv
## - msmarco-doctrain-qrels.tsv
## - msmarco-docdev-qrels.tsv

!wget -S --header="accept-encoding: gzip" \
    -O {raw_dir}/msmarco-docs.tsv.gz \
    https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docs.tsv.gz \

!wget -S --header="accept-encoding: gzip" \
    -O {raw_dir}/msmarco-doctrain-queries.tsv.gz \
    https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-queries.tsv.gz

!wget -S --header="accept-encoding: gzip" \
    -O {raw_dir}/msmarco-docdev-queries.tsv.gz \
    https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-queries.tsv.gz

!wget -S --header="accept-encoding: gzip" \
    -O {raw_dir}/msmarco-doctrain-qrels.tsv.gz \
    https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-doctrain-qrels.tsv.gz

!wget -S --header="accept-encoding: gzip" \
    -O {raw_dir}/msmarco-docdev-qrels.tsv.gz \
    https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz

In [4]:
def load_docs():
    docs = csv.read_csv(
        f'{raw_dir}/msmarco-docs.tsv.gz',
        read_options=csv.ReadOptions(block_size=2**25, column_names=['docid', 'url', 'title', 'doc']),
        parse_options=csv.ParseOptions(invalid_row_handler=lambda invalidrow:"skip", delimiter="\t")
    ).to_pandas()
    return docs

docs_full = load_docs()
print('docs_full.shape:', docs_full.shape)

docs.shape: (3213835, 4)


In [116]:
qrels_train_full = csv.read_csv(
    f"{raw_dir}/msmarco-doctrain-qrels.tsv.gz",
    read_options=csv.ReadOptions(column_names=["queryid","iteration","docid","relevance"]),
    parse_options=csv.ParseOptions(delimiter=" "),
).to_pandas()
queries_train_full = csv.read_csv(
    f"{raw_dir}/msmarco-doctrain-queries.tsv.gz",
    read_options=csv.ReadOptions(column_names=["queryid", "query"]),
    parse_options=csv.ParseOptions(delimiter="\t"),
).to_pandas()
queries_train_full = pd.merge(queries_train_full, qrels_train_full, on="queryid")
queries_train_full['query'] = queries_train_full['query'].str.strip()
print('qrels_train_full.shape:', qrels_train_full.shape)
print('queries_train_full.shape:', queries_train_full.shape)

qrels_dev_full = csv.read_csv(
    f"{raw_dir}/msmarco-docdev-qrels.tsv.gz",
    read_options=csv.ReadOptions(column_names=["queryid","iteration","docid","relevance"]),
    parse_options=csv.ParseOptions(delimiter=" "),
).to_pandas()
queries_dev_full = csv.read_csv(
    f"{raw_dir}/msmarco-docdev-queries.tsv.gz",
    read_options=csv.ReadOptions(column_names=["queryid", "query"]),
    parse_options=csv.ParseOptions(delimiter="\t"),
).to_pandas()
queries_dev_full['query'] = queries_dev_full['query'].str.strip()
queries_dev_full = pd.merge(queries_dev_full, qrels_dev_full, on="queryid")
print('qrels_dev_full.shape:', qrels_dev_full.shape)
print('queries_dev_full.shape:', queries_dev_full.shape)

qrels_train_full.shape: (367013, 4)
queries_train_full.shape: (367013, 5)
qrels_dev_full.shape: (5193, 4)
queries_dev_full.shape: (5193, 5)


## 1.2 Extract a part of MS MARCO as "Lite"

In [182]:
docids = pd.read_csv(f"{raw_dir}/docids.tsv")['docid']
docs = docs_full.set_index("docid").loc[docids].reset_index()

train = pd.merge(
    pd.read_csv(f"{raw_dir}/query2docid.train.tsv", sep='\t'), docs[['docid', 'title', 'doc']],
    on='docid', how="left")
dev = pd.merge(
    pd.read_csv(f"{raw_dir}/query2docid.dev.tsv", sep='\t'), docs[['docid', 'title', 'doc']],
    on='docid', how="left")

In [184]:
train['doc'] = train['title'] + " " + train['doc']
dev['doc'] = dev['title'] + " " + dev['doc']
docs['doc'] = docs['title'] + " " + docs['doc']

In [185]:
train[['query', 'docid', 'doc']].to_csv(f"{raw_dir}/train.tsv", sep="\t", index=False)
dev[['query', 'docid', 'doc']].to_csv(f"{raw_dir}/dev.tsv", sep="\t", index=False)

In [192]:
docs[['docid', 'doc']].to_csv(f"{raw_dir}/docs.tsv", sep="\t", index=False)

## 1.3 Generate queries with a document-to-query (docT5query) model

### 1.3.1 Download docT5model

In [193]:
import transformers

if not os.path.exists(f"{pretrained_dir}/doc2query-t5-base-msmarco"):
    identifier = "castorini/doc2query-t5-base-msmarco"
    model = transformers.T5ForConditionalGeneration.from_pretrained(identifier)
    tokenizer = transformers.T5TokenizerFast.from_pretrained(identifier)
    
    os.makedirs(pretrained_dir, exist_ok=True)
    model.save_pretrained(f"{pretrained_dir}/doc2query-t5-base-msmarco")
    tokenizer.save_pretrained(f"{pretrained_dir}/doc2query-t5-base-msmarco")



### 1.3.2 Generate five queries for each document

This generation process takes approximately 15 minutes on four RTX4090 GPUs.

In [202]:
import preprocess.genq
args = Namespace(
    model_path=genq_model_path,
    docs_path=f"{raw_dir}/docs.tsv",
    output_path=f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    doc_max_len=512,
    query_max_len=32,
    genq_per_doc=5,
    n_gpus=4,
    batch_size=16,
)
preprocess.genq.main(args)

100%|████████████████████████████████████████████████████████████| 8654/8654 [16:33<00:00,  8.71it/s]


Initializing ...
Initialization finished
Initializing ...
Initialization finished
Initializing ...
Initialization finished


# 2. Hierarchical K-means Indexing (HKmI)

## 2.1 Produce for every document a BERT embedding

Encode every document in `docs.tsv` as a vector. The embeddings will be saved as `doc_emb.h5`.

The encoding process will take approximately 6 minutes on four RTX4090 GPUs.

In [4]:
import preprocess.bert_embedding
args = Namespace(
    docs_path=f"{raw_dir}/docs.tsv",
    output_path=f"{cache_dir}/{bert_model_identifier}/doc_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="doc",
)
preprocess.bert_embedding.main(args)

Initializing ...
Initializing ...
Initializing ...
Initializing ...

  0%|                                                                     | 0/138457 [00:00<?, ?it/s]






Initialization finished


  0%|                                                         | 20/138457 [00:07<14:12:02,  2.71it/s]

Initialization finished
Initialization finished


  0%|                                                          | 60/138457 [00:07<3:46:32, 10.18it/s]

Initialization finished


100%|███████████████████████████████████████████████████████| 138457/138457 [04:03<00:00, 567.67it/s]


## 2.2 Apply K-means clustering on documents

Applying hierarchical K-means on document embeddings.

Every document will be assigned an ID string, which will be saved as `docid2index.tsv`

In [56]:
import preprocess.kmeans

args = Namespace(
    embedding_path=f"{cache_dir}/{bert_model_identifier}/doc_emb.h5",
    output_path=f"{cache_dir}/{bert_model_identifier}/docid2index.HKmI.tsv",
    v_dim=768,
    k=30,
    c=30,
    seed=7,
    n_init=1,   # can be increased to 10/100 to enhance quality at the cost of running time
    tol=1e-6,
)
preprocess.kmeans.main(args)

(138457, 768)
Wed May 22 00:15:52 2024 Start First Clustering
(138457,)
150
Wed May 22 00:15:53 2024 Start Recursively Clustering...


100%|████████████████████████████████████████████████████████████████| 30/30 [00:18<00:00,  1.59it/s]


## 2.3 Use document segments as queries

For every document, we random select 10~12 segments of 64 tokens as queries.

In [57]:
def get_seg(tokens):
    begin = random.randrange(0, max(1, len(tokens) - 64))
    end = begin + 64
    seg = ' '.join(tokens[begin: end])
    return seg

docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", index_col="docid", na_filter=False)

with open(f"{cache_dir}/docseg.tsv", "wt", encoding="utf8") as f:
    f.write("docid\tquery\n")

    for docid, doc in docs["doc"].items():
        tokens = doc.split(" ")
        nsegs = 10 + max(0, len(tokens)-3000) // 3000
        for _ in range(nsegs):
            seg = get_seg(tokens)
            f.write(f"{docid}\t{seg}\n")

## 2.4 Compiling training data for training retrieval model

- training set

A training sample should have three entries: *query, index, docid*

| File Name | Description |
| --- | --- |
| realq_train.tsv | real queries (ground truth) |
| genq.tsv  | generated queries from documents |
| title_abs.tsv | concatenation of document title and abstract as query |
| docseg.tsv | document segments as queries |

- dev (evaluation) set

A validation sample should have two entries: *query, docid*

| File Name | Description |
| --- | --- |
| realq_dev.tsv | real queries (ground truth) |

- supporting files

| File Name | Description |
| --- | --- |
| docid2index.tsv | mapping from docid to index, used for evaluation |

In [58]:
import BMI.io
from BMI.io import (
    StringIndexing,
    DocumentRetrievalTrainingFile,
    DocumentRetrievalInferenceFile,
    intarray_to_string,
)


In [6]:
hkmi_dirname = f"HKmI.{bert_model_identifier}.{genq_model_identifier}"
os.makedirs(f"{output_dir}/{hkmi_dirname}", exist_ok=True)

In [60]:
docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", na_filter=False)
docid2index = StringIndexing.from_tsv(f"{cache_dir}/{bert_model_identifier}/docid2index.HKmI.tsv")

In [61]:
# docid2index.tsv
docid2index.to_tsv(f"{output_dir}/{hkmi_dirname}/docid2index.tsv")
docid2index.to_pandas()

Unnamed: 0,docid,index
0,D3233725,"(28, 14, 16, 0)"
1,D1885729,"(23, 25, 1, 0)"
2,D15500,"(27, 20, 14, 0)"
3,D2456256,"(28, 2, 23, 0)"
4,D3205738,"(1, 7, 22, 0)"
...,...,...
138452,D2169873,"(25, 4, 12, 5)"
138453,D712256,"(5, 20, 20, 1)"
138454,D3221007,"(4, 17, 21, 1)"
138455,D3112790,"(12, 9, 10, 3)"


In [63]:
# realq_train.tsv
if 'train' not in globals():
    train = pd.read_csv(f"{raw_dir}/train.tsv", usecols=["docid", "query"], sep="\t")
docids = train['docid']
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=train["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/realq_train.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,another name for the primary visual cortex is,29-21-7-0,D2955018
1,the vitamin that prevents beriberi is,5-5-8-0,D508131
2,contextual spoken language understanding,20-6-27-6,D1350520
3,dosimetry medical definition,18-2-21-6,D304123
4,what color is royal,28-6-9-15,D1450821
...,...,...,...
183942,why did scientists suspect that the moon coole...,25-9-0-5,D2781869
183943,why did rosa parks protest,15-10-25-2,D2008201
183944,amex india customer care number,17-14-26-3,D630512
183945,_________ justice is designed to repair the ha...,18-18-8-1,D59235


In [64]:
# realq_dev.tsv
if 'dev' not in globals():
    dev = pd.read_csv(f"{raw_dir}/dev.tsv", usecols=["docid", "query"], sep="\t")
docids = dev['docid']

file = DocumentRetrievalInferenceFile(
    queries=dev["query"],
    docids=docids,
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/realq_dev.tsv")
file.to_pandas()

Unnamed: 0,query,docid
0,androgen receptor define,D1650436
1,3/5 of 60,D1547717
2,does suddenlink carry espn3,D2830290
3,explain what a bone scan is and what it is use...,D125453
4,is the louisiana sales tax 4.75,D2523421
...,...,...
2787,why do people use gypsum in soil,D977977
2788,why do people grind teeth in sleep,D3062847
2789,why do jefferson and stanton include these sim...,D2361582
2790,why do children get aggressive,D1073324


In [67]:
# title_abs.tsv
title_abs = docs["doc"].apply(lambda doc: ' '.join(doc.split()[:64]))
docids = docs["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=title_abs,
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/title_abs.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,Dogo Argentino Dogo Argentino Miscellaneous Th...,28-14-16-0,D3233725
1,How to Kill Weeds Without Killing Plants Weeds...,23-25-1-0,D1885729
2,"How to Learn Martial Arts ""Pressure Points"" Ed...",27-20-14-0,D15500
3,TMG Trimethylglycine Swanson Ultra TMG Trimeth...,28-2-23-0,D2456256
4,Elavil Could you or a loved one be experiencin...,1-7-22-0,D3205738
...,...,...,...
138452,Fact-checking immigration Fact-checking immigr...,25-4-12-5,D2169873
138453,Barium Swallow See related health topics and r...,5-20-20-1,D712256
138454,Best Integrated Development Environment (IDE) ...,4-17-21-1,D3221007
138455,Fin 310 Chapter 2 12 terms ty_hentges Fin 310 ...,12-9-10-3,D3112790


In [68]:
# genq.tsv
genq = pd.read_csv(f"{cache_dir}/{genq_model_identifier}/genq.tsv", usecols=["docid", "query"], sep="\t")
docids = genq["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=genq["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/genq.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,what is the color of my dogo rancho,28-14-16-0,D3233725
1,what is the breed name for a argentino mastiff,28-14-16-0,D3233725
2,what is the breed of dogo argentino,28-14-16-0,D3233725
3,what is the length of a dogo argentino?,28-14-16-0,D3233725
4,what type of dogo is the argentino doodle,28-14-16-0,D3233725
...,...,...,...
692280,distance between rome and ancona italy,17-13-12-12,D2803363
692281,distance from ancona to rome italy,17-13-12-12,D2803363
692282,distance rome italy to ancona,17-13-12-12,D2803363
692283,how far is italy from ancona,17-13-12-12,D2803363


In [69]:
# docseg.tsv
docseg = pd.read_csv(f"{cache_dir}/docseg.tsv", usecols=["docid", "query"], sep="\t")
docids = docseg["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=docseg["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{hkmi_dirname}/docseg.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,Dogos require vigorous exercise to stay at the...,28-14-16-0,D3233725
1,its physical virtues turn it into a real athle...,28-14-16-0,D3233725
2,pounds Life Expectancy: 9-15 years Barking Lev...,28-14-16-0,D3233725
3,& Feeding Good nutrition for Dogo Argentinos i...,28-14-16-0,D3233725
4,Expectancy: 9-15 years Barking Level: Barks Wh...,28-14-16-0,D3233725
...,...,...,...
1391428,driving distance for a different route. If you...,17-13-12-12,D2803363
1391429,"another possible route, you can try Google Map...",17-13-12-12,D2803363
1391430,"from Ancona, Italy to Rome, Italy is:193 miles...",17-13-12-12,D2803363
1391431,Rome road trip Map of driving directions from ...,17-13-12-12,D2803363


# 3. Bottleneck-Minimal Indexing (BMI)

## 3.1 Produce for every query a BERT embedding

In [4]:
import preprocess.bert_embedding

### 3.1.1 RealQ: real queries (training set)

In [7]:
args = Namespace(
    docs_path=f"{output_dir}/{hkmi_dirname}/realq_train.tsv",
    output_path=f"{cache_dir}/{bert_model_identifier}/realq_train_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="query",
)
preprocess.bert_embedding.main(args)

Initializing ...
Initializing ...
Initializing ...
Initializing ...




Initialization finished
Initialization finished
Initialization finished
Initialization finished


100%|██████████████████████████████████████████████████████| 183947/183947 [00:29<00:00, 6145.55it/s]


### 3.1.2 GenQ: generated queries by the finetuned document-to-query model

In [8]:
args = Namespace(
    docs_path=f"{cache_dir}/{genq_model_identifier}/genq.tsv",
    output_path=f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/genq_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="query",
)
preprocess.bert_embedding.main(args)

Initializing ...
Initializing ...
Initializing ...
Initializing ...




Initialization finished
Initialization finished
Initialization finished
Initialization finished


100%|██████████████████████████████████████████████████████| 692285/692285 [01:42<00:00, 6759.89it/s]


### 3.1.3 DocSeg: using document segments as queries 

In [9]:
args = Namespace(
    docs_path=f"{output_dir}/{hkmi_dirname}/docseg.tsv",
    output_path=f"{cache_dir}/{bert_model_identifier}/docseg_emb.h5",
    model_path=bert_model_path,
    max_len=512,
    n_gpus=4,
    text_col="query",
)
preprocess.bert_embedding.main(args)

Initializing ...
Initializing ...
Initializing ...
Initializing ...




Initialization finished
Initialization finished


  0%|                                                       | 180/1391433 [00:05<11:54:49, 32.44it/s]

Initialization finished
Initialization finished


100%|████████████████████████████████████████████████████| 1391433/1391433 [05:22<00:00, 4315.91it/s]


## 3.2 Apply K-means clustering on documents

### 3.2.1 Calculate centroid vector for every document

In [10]:
from BMI.io import IndexedEmbeddings

X, ids = [], []
for path in [
    f"{cache_dir}/{bert_model_identifier}/realq_train_emb.h5",
    f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/genq_emb.h5",
    f"{cache_dir}/{bert_model_identifier}/docseg_emb.h5",
]:
    embeddings = IndexedEmbeddings.from_h5(path)
    X.append(embeddings.embs)
    ids.append(embeddings.ids)
X = np.concatenate(X, axis=0)
ids = np.concatenate(ids, axis=0)

In [11]:
embs = pd.DataFrame({'emb': list(X), 'docid': ids})
centroids = embs.groupby("docid").apply(lambda slice: np.stack(slice['emb'].values).mean(0))

  centroids = embs.groupby("docid").apply(lambda slice: np.stack(slice['emb'].values).mean(0))


In [12]:
centroids

docid
D1000111    [-0.67136544, 0.12668706, -0.15803394, -0.4214...
D1000120    [-0.11027141, -0.0127727715, 0.05979115, 0.062...
D1000128    [-0.3655967, -0.004500894, -0.17365022, -0.298...
D1000147    [-0.7446443, 0.051425174, 0.006697178, 0.07320...
D1000171    [-0.046932623, 0.051572345, 0.0068646716, -0.1...
                                  ...                        
D999957     [-0.4969293, -0.28059012, 0.23285414, -0.10521...
D999968     [-0.24978884, -0.1374097, -0.29497665, -0.0315...
D999978     [-0.4752474, -0.3227341, 0.37801948, 0.1673784...
D999990     [-0.30692142, -0.0790581, 0.25198063, -0.08345...
D999994     [-0.31444538, -0.12734632, -0.49284962, -0.184...
Length: 138457, dtype: object

In [15]:
path = f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/doc_emb.centroid.realq_genq_docseg.h5"
embeddings = IndexedEmbeddings(ids=centroids.index.tolist(), embs=np.stack(centroids.values, dtype=np.float32))
embeddings.to_h5(path)

### 3.2.2 Run k-means

In [16]:
import preprocess.kmeans
import importlib
importlib.reload(preprocess.kmeans)

args = Namespace(
    embedding_path=f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/doc_emb.centroid.realq_genq_docseg.h5",
    output_path=f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/docid2index.BMI.realq_genq_docseg.tsv",
    v_dim=768,
    k=30,
    c=30,
    seed=7,
    n_init=1,   # can be increased to 10/100 to enhance quality at the cost of running time
    tol=1e-6,
)
preprocess.kmeans.main(args)

(138457, 768)
Wed May 22 00:32:43 2024 Start First Clustering
(138457,)
105
Wed May 22 00:32:44 2024 Start Recursively Clustering...


100%|████████████████████████████████████████████████████████████████| 30/30 [00:18<00:00,  1.65it/s]


## 3.3 Compiling training data for training retrieval model

- training set

A training sample should have three entries: *query, index, docid*

| File Name | Description |
| --- | --- |
| realq_train.tsv | real queries (ground truth) |
| genq.tsv  | generated queries from documents |
| title_abs.tsv | concatenation of document title and abstract as query |
| docseg.tsv | document segments as queries |

- dev (evaluation) set

A validation sample should have two entries: *query, docid*

| File Name | Description |
| --- | --- |
| realq_dev.tsv | real queries (ground truth) |

- supporting files

| File Name | Description |
| --- | --- |
| docid2index.tsv | mapping from docid to index, used for evaluation |

In [18]:
import BMI.io
from BMI.io import (
    StringIndexing,
    DocumentRetrievalTrainingFile,
    DocumentRetrievalInferenceFile,
    intarray_to_string,
)

In [19]:
bmi_dirname = f"BMI.{bert_model_identifier}.{genq_model_identifier}.realq_genq_docseg"
os.makedirs(f"{output_dir}/{bmi_dirname}", exist_ok=True)

In [29]:
docs = pd.read_csv(f"{raw_dir}/docs.tsv", sep="\t", na_filter=False)
docid2index = StringIndexing.from_tsv(f"{cache_dir}/{genq_model_identifier}.{bert_model_identifier}/docid2index.BMI.realq_genq_docseg.tsv")

In [30]:
# docid2index.tsv
docid2index.to_tsv(f"{output_dir}/{bmi_dirname}/docid2index.tsv")
docid2index.to_pandas()

Unnamed: 0,docid,index
0,D1000111,"(27, 7, 3, 0)"
1,D1000120,"(19, 27, 29, 0)"
2,D1000128,"(20, 7, 9, 5, 0)"
3,D1000147,"(13, 10, 9, 0)"
4,D1000171,"(26, 0, 16, 0)"
...,...,...
138452,D999957,"(17, 23, 9, 10)"
138453,D999968,"(7, 6, 15, 13)"
138454,D999978,"(3, 25, 15, 12)"
138455,D999990,"(2, 10, 8, 4)"


In [31]:
# realq_train.tsv
if 'train' not in globals():
    train = pd.read_csv(f"{raw_dir}/train.tsv", sep="\t")
docids = train['docid']
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=train["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/realq_train.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,another name for the primary visual cortex is,0-2-17-3,D2955018
1,the vitamin that prevents beriberi is,5-6-12-12,D508131
2,contextual spoken language understanding,11-12-16-0,D1350520
3,dosimetry medical definition,25-2-14-7,D304123
4,what color is royal,29-4-11-3,D1450821
...,...,...,...
183942,why did scientists suspect that the moon coole...,10-12-2-4,D2781869
183943,why did rosa parks protest,7-16-8-1,D2008201
183944,amex india customer care number,3-28-0-2,D630512
183945,_________ justice is designed to repair the ha...,11-1-1-9,D59235


In [32]:
# realq_dev.tsv
if 'dev' not in globals():
    dev = pd.read_csv(f"{raw_dir}/dev.tsv", sep="\t")
docids = dev['docid']

file = DocumentRetrievalInferenceFile(
    queries=dev["query"],
    docids=docids,
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/realq_dev.tsv")
file.to_pandas()

Unnamed: 0,query,docid
0,androgen receptor define,D1650436
1,3/5 of 60,D1547717
2,does suddenlink carry espn3,D2830290
3,explain what a bone scan is and what it is use...,D125453
4,is the louisiana sales tax 4.75,D2523421
...,...,...
2787,why do people use gypsum in soil,D977977
2788,why do people grind teeth in sleep,D3062847
2789,why do jefferson and stanton include these sim...,D2361582
2790,why do children get aggressive,D1073324


In [33]:
# title_abs.tsv
title_abs = docs['doc'].apply(lambda doc: " ".join(doc.split()[:64]))
docids = docs["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=title_abs,
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/title_abs.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,Dogo Argentino Dogo Argentino Miscellaneous Th...,29-12-4-10,D3233725
1,How to Kill Weeds Without Killing Plants Weeds...,22-29-13-2,D1885729
2,"How to Learn Martial Arts ""Pressure Points"" Ed...",11-4-17-4,D15500
3,TMG Trimethylglycine Swanson Ultra TMG Trimeth...,6-7-13-4,D2456256
4,Elavil Could you or a loved one be experiencin...,27-23-26-13,D3205738
...,...,...,...
138452,Fact-checking immigration Fact-checking immigr...,17-1-21-11,D2169873
138453,Barium Swallow See related health topics and r...,9-1-7-9,D712256
138454,Best Integrated Development Environment (IDE) ...,11-16-16-8,D3221007
138455,Fin 310 Chapter 2 12 terms ty_hentges Fin 310 ...,11-10-1-4,D3112790


In [34]:
# genq.tsv
genq = pd.read_csv(f"{cache_dir}/{genq_model_identifier}/genq.tsv", usecols=["docid", "query"], sep="\t")
docids = genq["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=genq["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/genq.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,what is the color of my dogo rancho,29-12-4-10,D3233725
1,what is the breed name for a argentino mastiff,29-12-4-10,D3233725
2,what is the breed of dogo argentino,29-12-4-10,D3233725
3,what is the length of a dogo argentino?,29-12-4-10,D3233725
4,what type of dogo is the argentino doodle,29-12-4-10,D3233725
...,...,...,...
692280,distance between rome and ancona italy,3-13-0-2,D2803363
692281,distance from ancona to rome italy,3-13-0-2,D2803363
692282,distance rome italy to ancona,3-13-0-2,D2803363
692283,how far is italy from ancona,3-13-0-2,D2803363


In [35]:
# docseg.tsv
docseg = pd.read_csv(f"{cache_dir}/docseg.tsv", usecols=["docid", "query"], sep="\t")
docids = docseg["docid"]
indexes = docids.apply(docid2index.get_index)

file = DocumentRetrievalTrainingFile(
    queries=docseg["query"],
    docids=docids,
    indexes=indexes.apply(intarray_to_string),
)
file.to_tsv(f"{output_dir}/{bmi_dirname}/docseg.tsv")
file.to_pandas()

Unnamed: 0,query,index,docid
0,Dogos require vigorous exercise to stay at the...,29-12-4-10,D3233725
1,its physical virtues turn it into a real athle...,29-12-4-10,D3233725
2,pounds Life Expectancy: 9-15 years Barking Lev...,29-12-4-10,D3233725
3,& Feeding Good nutrition for Dogo Argentinos i...,29-12-4-10,D3233725
4,Expectancy: 9-15 years Barking Level: Barks Wh...,29-12-4-10,D3233725
...,...,...,...
1391428,driving distance for a different route. If you...,3-13-0-2,D2803363
1391429,"another possible route, you can try Google Map...",3-13-0-2,D2803363
1391430,"from Ancona, Italy to Rome, Italy is:193 miles...",3-13-0-2,D2803363
1391431,Rome road trip Map of driving directions from ...,3-13-0-2,D2803363
