# ColBERTv2: Indexing & Search Notebook

We start by importing the relevant classes. As we'll see below, `Indexer` and `Searcher` are the key actors here. 

In [105]:
import os
import sys

# Make the local ColBERT checkout importable (adjust the path to your clone).
sys.path.insert(0, './ColBERT')

import csv
import pandas as pd

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection
from colbert import Indexer, Searcher

In [273]:
dataset = 'bioasq'

datasplit = 'all'

The workflow here assumes an IR dataset: a set of queries and a corresponding collection of passages.

The classes `Queries` and `Collection` provide a convenient interface for working with such datasets.

### Download the pretrained ColBERTv2 checkpoint

In [None]:
!mkdir -p downloads/

# ColBERTv2 checkpoint trained on MS MARCO Passage Ranking (388MB compressed)
!wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz -P downloads/
!tar -xvzf downloads/colbertv2.0.tar.gz -C downloads/

# The LoTTE dev and test sets (3.4GB compressed)
# !wget https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/lotte.tar.gz -P downloads/
# !tar -xvzf downloads/lotte.tar.gz -C downloads/

### Prepare data for indexing

In [97]:
import json
with open('./data/bioasq/training10b.json', 'r') as f:
    bioasq_json = json.load(f)

In [261]:
# Construct the collection of passages

def generate_bioasq_passage(bioasq_json, length=1024):
    """Build (pid, passage) pairs from factoid and list questions.

    Each passage is the question's snippets joined with spaces, kept on one
    line and truncated to `length` characters (not tokens).
    """
    bioasq_passage = []
    eid = 0

    for sample in bioasq_json['questions']:
        if sample['type'] in ['factoid', 'list']:
            # Flatten all snippets into a single context string
            context = ' '.join(ele['text'].strip() for ele in sample['snippets'])

            # Basic cleaning: keep each passage on one line
            context = context.replace('\n', ' ')

            # Limit the length of the context
            # (Max: 4096 for the Eleuther model)
            context = context[:length]

            bioasq_passage.append((str(eid), context))
            eid += 1

    return bioasq_passage

In [262]:
# Construct the collection of queries

def generate_bioasq_query(bioasq_json, length=1024):
    """Build (qid, query) pairs from factoid and list questions.

    The filter and counter match generate_bioasq_passage, so query i and
    passage i come from the same question.
    """
    bioasq_query = []
    eid = 0

    for sample in bioasq_json['questions']:
        if sample['type'] in ['factoid', 'list']:
            # Basic cleaning: keep each query on one line
            query = sample['body'].replace('\n', ' ')

            # Limit the length of the query (in characters)
            query = query[:length]

            bioasq_query.append((str(eid), query))
            eid += 1

    return bioasq_query

In [263]:
# Passage length in characters: 1024 is a starting point and may need tuning
bioasq_passage = generate_bioasq_passage(bioasq_json, 1024)

In [264]:
bioasq_query = generate_bioasq_query(bioasq_json, 1024)

In [265]:
os.makedirs('experiments', exist_ok=True)
with open('experiments/bioasq_passage.tsv', 'w', newline='', encoding='utf-8') as f_output:
    csv_output = csv.writer(f_output, delimiter='\t')
    csv_output.writerows(bioasq_passage)

In [266]:
with open('experiments/bioasq_query.tsv', 'w', newline='', encoding='utf-8') as f_output:
    csv_output = csv.writer(f_output, delimiter='\t')
    csv_output.writerows(bioasq_query)
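`csv.writer` defaults to `csv.QUOTE_MINIMAL`, so any passage containing a double quote (or a tab) is wrapped in quotes, with inner quotes doubled. `pd.read_csv` undoes this quoting, so the preview cells below would not reveal it. Whether it matters depends on the loader: if `Queries`/`Collection` simply split each line on tabs (an assumption worth checking in the ColBERT source), the extra quote characters would end up in the indexed text. A hypothetical `write_tsv` helper that sidesteps csv quoting by replacing tabs and newlines with spaces and writing `id<TAB>text` lines directly:

```python
import os
import tempfile


def write_tsv(records, path):
    """Write (id, text) pairs as `id<TAB>text` lines without csv quoting.

    Tabs and line breaks inside the text become spaces, so every record
    stays on one line with exactly two tab-separated columns.
    """
    with open(path, 'w', encoding='utf-8') as f:
        for rec_id, text in records:
            clean = text.replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')
            f.write(f"{rec_id}\t{clean}\n")


# Usage on toy records (hypothetical data)
demo_path = os.path.join(tempfile.mkdtemp(), 'demo.tsv')
write_tsv([('0', 'a "quoted" term\tand a tab')], demo_path)
with open(demo_path, encoding='utf-8') as f:
    print(f.read(), end='')  # quotes are written verbatim, not doubled
```

The trade-off is that tabs inside a passage are lost, which is harmless for retrieval but means the file is not a strict csv round-trip.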

In [267]:
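# header=None: the TSV has no header row, so the first passage must be read as data
test1 = pd.read_csv('experiments/bioasq_passage.tsv', sep='\t', header=None, names=['pid', 'passage'])
test1.head()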

In [None]:
test2 = pd.read_csv('experiments/bioasq_query.tsv', sep='\t', header=None, names=['qid', 'query'])
test2.head()

In [270]:
queries = Queries(path='experiments/bioasq_query.tsv')
collection = Collection(path='experiments/bioasq_passage.tsv')

[May 25, 22:30:01] #> Loading the queries from experiments/bioasq_query.tsv ...
[May 25, 22:30:01] #> Got 2068 queries. All QIDs are unique.

[May 25, 22:30:01] #> Loading collection...
0M 


In [272]:
print(queries[24])
print()
print(collection[200])
print()

Treatment of which disease was investigated in the MR CLEAN study?

Exome Sequencing Identifies a Rare HSPG2 Variant Associated with Familial Idiopathic Scoliosis. Overall, these findings demonstrate a novel role for kif6 in spinal development and identify a new candidate gene for human idiopathic scoliosis. HL1 is of interest, as it encodes an axon guidance protein related to Robo3. Mutations in the Robo3 protein cause horizontal gaze palsy with progressive scoliosis (HGPPS), a rare disease marked by severe scoliosis. Other top associations in our GWAS were with SNPs in the DSCAM gene encoding an axon guidance protein in the same structural class with Chl1 and Robo3.



## Indexing

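For efficient search, we pre-compute the ColBERT representation of each passage and index those representations.

Below, the `Indexer` takes a model checkpoint and writes a (compressed) index to disk. We then prepare a `Searcher` for retrieval from this index.

(The upstream ColBERT example this workflow follows reports about 13 minutes of indexing with four Titan V GPUs on its much larger collection. Here, with a single GPU and 2,068 passages, indexing finished in under two minutes according to the timestamps in the log below. The output is fairly long/ugly at the moment!)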
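The `"interaction": "colbert"` and `"similarity": "cosine"` settings in the config dump below refer to ColBERT's late interaction: each query token embedding is matched to its most similar passage token embedding, and those maxima are summed (the MaxSim score). The pure-Python sketch below illustrates that score on toy vectors, assuming embeddings are L2-normalized so the dot product equals cosine similarity. It is an illustration only, not ColBERT's implementation, which works on batched, compressed GPU tensors.

```python
import math


def normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def maxsim_score(query_embs, passage_embs):
    """Late-interaction score: for every query token, take the highest
    cosine similarity with any passage token, then sum over query tokens."""
    q = [normalize(v) for v in query_embs]
    p = [normalize(v) for v in passage_embs]
    return sum(
        max(sum(a * b for a, b in zip(qv, pv)) for pv in p)
        for qv in q
    )


# Toy 2-dimensional "token embeddings" (ColBERTv2 uses dim=128).
query = [[1.0, 0.0], [0.0, 1.0]]
passage = [[1.0, 0.0], [1.0, 1.0]]
print(round(maxsim_score(query, passage), 4))  # 1.7071
```

Because passage token embeddings do not depend on the query, they can be computed once at indexing time, which is what makes the pre-computation above worthwhile.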

In [274]:
nbits = 2   # encode each dimension with 2 bits
doc_maxlen = 300   # truncate passages at 300 tokens

checkpoint = 'ColBERT/docs/downloads/colbertv2.0'  # extracted colbertv2.0 checkpoint; adjust to where you downloaded it
index_name = f'{dataset}.{datasplit}.{nbits}bits'
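For a rough sense of what `nbits=2` buys: in ColBERTv2's residual compression, each token embedding is stored as the ID of its nearest centroid plus a residual quantized to `nbits` bits per dimension. The back-of-the-envelope calculation below assumes `dim=128` (as in the config dump below), a 4-byte centroid ID, and a float16 baseline of 2 bytes per value, and uses the 378,124 embeddings reported in the indexing log. It ignores the centroids, metadata, and inverted lists, so treat the results as estimates rather than the actual on-disk size.

```python
def bytes_per_embedding(dim, nbits, centroid_id_bytes=4):
    """Approximate storage for one compressed token embedding:
    a centroid ID plus `nbits` bits for each of the `dim` residual dimensions."""
    residual_bytes = dim * nbits // 8
    return centroid_id_bytes + residual_bytes


num_embeddings = 378_124  # from the indexing log below
dim = 128
compressed = num_embeddings * bytes_per_embedding(dim, nbits=2)
uncompressed_fp16 = num_embeddings * dim * 2  # 2 bytes per float16 value

print(f"compressed: about {compressed / 1e6:.1f} MB")      # about 13.6 MB
print(f"float16:    about {uncompressed_fp16 / 1e6:.1f} MB")  # about 96.8 MB
```

Under these assumptions, 2-bit residuals shrink the token embeddings roughly sevenfold compared with storing them uncompressed in float16.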

In [275]:
index_name

'bioasq.all.2bits'

In [276]:
with Run().context(RunConfig(nranks=1, experiment='bioasq')):  # nranks specifies the number of GPUs to use.
    config = ColBERTConfig(doc_maxlen=doc_maxlen, nbits=nbits)

    indexer = Indexer(checkpoint=checkpoint, config=config)
    indexer.index(name=index_name, collection=collection, overwrite=True)



[May 25, 22:35:21] #> Creating directory /home/zhanj289/projects/cs224u_nlu_project/experiments/bioasq/indexes/bioasq.all.2bits 


#> Starting...
nranks = 1 	 num_gpus = 1 	 device=0
{
    "nprobe": 2,
    "ncandidates": 8192,
    "index_path": null,
    "nbits": 2,
    "kmeans_niters": 20,
    "similarity": "cosine",
    "bsize": 64,
    "accumsteps": 1,
    "lr": 1e-5,
    "maxsteps": 400000,
    "save_every": null,
    "resume": false,
    "warmup": 20000,
    "warmup_bert": null,
    "relu": false,
    "nway": 64,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "query_maxlen": 32,
    "attend_to_mask_tokens": false,
    "interaction": "colbert",
    "dim": 128,
    "doc_maxlen": 300,
    "mask_punctuation": true,
    "checkpoint": "ColBERT\/docs\/downloads\/colbertv2.0",
    "triples": "\/future\/u\/okhattab\/root\/unit\/experiments\/2021.10\/downstream.distillation.round2.2_score\/round2.nway6.cosine.ib\/example

0it [00:00, ?it/s]

[May 25, 22:37:01] [0] 		 #> Saving chunk 0: 	 2,068 passages and 378,124 embeddings. From #0 onward.


1it [00:17, 17.73s/it]


[May 25, 22:37:04] [0] 		 #> Saving the indexing metadata to /home/zhanj289/projects/cs224u_nlu_project/experiments/bioasq/indexes/bioasq.all.2bits/metadata.json ..
#> Joined...


In [277]:
indexer.get_index() # You can get the absolute path of the index, if needed.

'/home/zhanj289/projects/cs224u_nlu_project/experiments/bioasq/indexes/bioasq.all.2bits'

## Search

Having built the index, we now prepare a `Searcher` and search for individual query strings.

We can use the `queries` set we loaded earlier, or you can supply your own questions. Feel free to get creative! But keep in mind that this collection of 2,068 BioASQ passages (one per factoid or list question, built from its snippets) can only answer a small, focused set of biomedical questions!

In [278]:
# To create the searcher using its relative name (i.e., not a full path), set
# experiment=value_used_for_indexing in the RunConfig.
with Run().context(RunConfig(experiment='bioasq')):
    searcher = Searcher(index=index_name)


# If you want to customize the search latency--quality tradeoff, you can also supply a
# config=ColBERTConfig(nprobe=.., ncandidates=..) argument. The default (2, 8192) works well,
# but you can trade away some latency to gain more extensive search with (4, 16384).
# Conversely, you can get faster search with (1, 4096).

[May 25, 22:37:36] #> Loading collection...
0M 
[May 25, 22:37:41] #> Building the emb2pid mapping..
[May 25, 22:37:41] len(self.emb2pid) = 378124



In [279]:
len(searcher.collection)

2068

In [281]:
searcher.collection[500]

'Feline leukemia virus subgroup C receptor (FLVCR1), a member of the SLC49 family of four paralogous genes, is a cell surface heme exporter, essential for erythropoiesis and systemic iron homeostasis. Feline leukemia virus subgroup C receptor (FLVCR1), a member of the SLC49 family of four paralogous genes, is a cell surface heme exporter, essential for erythropoiesis and systemic iron homeostasis. Feline leukemia virus subgroup C receptor (FLVCR1), a member of the SLC49 family of four paralogous genes, is a cell surface heme exporter, essential for erythropoiesis and systemic iron homeostasis. Disruption of FLVCR1 function blocks development of erythroid progenitors, likely due to heme toxicity. Heme is critical for a variety of cellular processes, but excess intracellular heme may result in oxidative stress and membrane injury. Feline leukemia virus subgroup C receptor (FLVCR1), a member of the SLC49 family of four paralogous genes, is a cell surface heme exporter, essential for eryth
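The passage above repeats the same snippet sentence several times, presumably because the BioASQ record lists overlapping snippets for this question. Repetition wastes both the 1,024-character budget and the 300-token `doc_maxlen`. One possible tweak, not part of the original pipeline, is to drop empty and exact-duplicate snippets while preserving their order before joining them. The helper name `flatten_snippets` is hypothetical:

```python
def flatten_snippets(snippets, length=1024):
    """Join snippet texts with spaces, skipping empty and exact-duplicate
    snippets (after stripping whitespace) while keeping first-occurrence
    order, then keep the result on one line and truncate to `length` chars."""
    texts = (s['text'].strip() for s in snippets)
    unique = list(dict.fromkeys(t for t in texts if t))  # ordered de-duplication
    context = ' '.join(unique).replace('\n', ' ')
    return context[:length]


snippets = [{'text': 'A. '}, {'text': 'B.'}, {'text': 'A.'}]
print(flatten_snippets(snippets))  # A. B.
```

Only exact matches are removed; near-duplicates that differ in punctuation or whitespace inside the text would still be kept.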

In [None]:
query = queries[37]   # or supply your own query

print(f"#> {query}")

# Find the top-5 passages for this query
results = searcher.search(query, k=5)

# Print out the top-k retrieved passages
for passage_id, passage_rank, passage_score in zip(*results):
    print(f"\t [{passage_rank}] \t\t {passage_score:.1f} \t\t {searcher.collection[passage_id]}")
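`searcher.search` returns three parallel lists (passage IDs, ranks, and scores), and `zip(*results)` transposes them into one `(pid, rank, score)` tuple per hit. A small, hypothetical helper that does the same and shortens each passage for display, shown with made-up values and a plain list standing in for `searcher.collection`:

```python
def format_hits(results, collection, width=80):
    """Turn a (pids, ranks, scores) triple into readable lines,
    truncating each passage to `width` characters."""
    lines = []
    for pid, rank, score in zip(*results):
        text = collection[pid]
        if len(text) > width:
            text = text[:width] + '...'
        lines.append(f"[{rank}] {score:.1f} pid={pid}: {text}")
    return lines


# Made-up values standing in for a real searcher.search(...) result.
toy_results = ([2, 0], [1, 2], [15.3, 14.8])
toy_collection = ['first passage', 'second passage', 'third passage']
for line in format_hits(toy_results, toy_collection):
    print(line)
# [1] 15.3 pid=2: third passage
# [2] 14.8 pid=0: first passage
```

Truncation is useful here because BioASQ passages can run to the full 1,024 characters.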

## Batch Search

In many applications, you have a large batch of queries and you need to maximize the overall throughput. For that, you can use the `searcher.search_all(queries, k)` method, which returns a `Ranking` object that organizes the results across all queries.

(Batching provides many opportunities for higher-throughput search, though we have not implemented most of those optimizations for compressed indexes yet.)

In [283]:
rankings = searcher.search_all(queries, k=5).todict()

100%|██████████| 2068/2068 [00:19<00:00, 105.99it/s]


In [None]:
rankings[30]  # For query 30, a list of (passage_id, rank, score) for the top-k passages
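Because `generate_bioasq_passage` and `generate_bioasq_query` walk the same questions with the same `factoid`/`list` filter and the same counter, query `i` and passage `i` come from the same BioASQ question. That makes passage `i` a natural gold passage for query `i`, so batch results can be scored with Success@k: the fraction of queries whose own passage appears among their top-k hits. The sketch below assumes `rankings` maps each query ID to a list of `(passage_id, rank, score)` tuples with 1-based ranks, as in the cell above. The function name `success_at_k` is hypothetical.

```python
def success_at_k(rankings, k):
    """Fraction of queries whose gold passage (same ID as the query)
    appears among the top-k (passage_id, rank, score) entries."""
    if not rankings:
        return 0.0
    hits = 0
    for qid, ranked in rankings.items():
        top_pids = {int(pid) for pid, rank, _score in ranked if rank <= k}
        if int(qid) in top_pids:
            hits += 1
    return hits / len(rankings)


# Toy rankings: query 0 finds passage 0 at rank 2; query 1 misses passage 1.
toy = {
    0: [(5, 1, 15.0), (0, 2, 14.0)],
    1: [(3, 1, 12.0), (4, 2, 11.0)],
}
print(success_at_k(toy, k=1))  # 0.0
print(success_at_k(toy, k=2))  # 0.5
```

On the real output, `success_at_k(rankings, 5)` would report Success@5 over all 2,068 queries. Note that this only checks each query's own passage; another question's snippets may also answer it.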