## Preprocess data

In [5]:
#Import libraries for preprocessing step
import os
import sys
import jsonlines
import json

from dotenv import load_dotenv
load_dotenv()
sys.path.append(os.getenv('ROOT_DIR'))

In [6]:
#Create all constant variables
PROCESSED_DATA_DIR = os.getenv('PROCESSED_DATA_DIR')
TESTING_DATA_DIR = os.getenv('TESTING_DATA_DIR')
RAW_DATA_DIR = os.getenv('RAW_DATA_DIR')
PROCESSED_TESTING_CONTEXTS_PATH = os.path.join(PROCESSED_DATA_DIR,"contexts-test.jsonl")
PROCESSED_TESTING_QUERIES_PATH = os.path.join(PROCESSED_DATA_DIR,"queries-test.jsonl")

DATA_LANG_DICT = {"squad":"en", "korquad":"ko", "fquad":"fr", "germanquad":"de", "uitviquad":"vi"}


### Set up

In [3]:
#set up
if not os.path.exists(PROCESSED_DATA_DIR):
    os.system("mkdir {dir}".format(dir=PROCESSED_DATA_DIR))
else:
    for filename in os.listdir(PROCESSED_DATA_DIR):
        if filename.endswith("-test.jsonl") or filename.endswith("-test.json"):
            os.system("rm -f {file}".format(file = os.path.join(PROCESSED_DATA_DIR, filename)))
        
if not os.path.exists(TESTING_DATA_DIR):
    os.system("mkdir {dir}".format(dir=TESTING_DATA_DIR))
else:
    for filename in os.listdir(TESTING_DATA_DIR):
        os.system("rm -f {file}".format(file = os.path.join(PROCESSED_DATA_DIR, filename)))

### Process raw data

In [8]:
#combine all raw data
import json
import os
from dotenv import load_dotenv
load_dotenv()

def get_data(input_filename, lang):
    """
    Convert hierarchical data in input_filename to tabular data and save it to output_filename.
    """ 
    # Load the data from the input file
    with open(input_filename) as input_file:
        data = json.load(input_file)
        if "data" in data:
            data = data["data"]

    res = []
    for article in data:
        for p in article["paragraphs"]:

            #add new context to context file

            for qas in p["qas"]:
                if not qas["answers"]:
                    continue
                
                #add new query to query file
                new_query = {"query": qas["question"], "context":p["context"], "query_lang":lang, "context_lang":lang}
                res.append(new_query)
    return res


all_data = []
for data_name in DATA_LANG_DICT.keys():
    all_data += get_data(os.path.join(RAW_DATA_DIR, "{data_name}-test.json".format(data_name=data_name)), DATA_LANG_DICT[data_name])

with open(os.path.join(PROCESSED_DATA_DIR, "test.json"), "w") as file:
    json.dump(all_data, file)

### Translate testing data

In [None]:
import os
import sys
import jsonlines
import json

from dotenv import load_dotenv
load_dotenv()
sys.path.append(os.getenv('ROOT_DIR'))

PROCESSED_DATA_DIR = os.getenv('PROCESSED_DATA_DIR')
TESTING_DATA_DIR = os.getenv('TESTING_DATA_DIR')

DATA_LANG_DICT = {"squad":"en", "korquad":"ko", "fquad":"fr", "germanquad":"de", "uitviquad":"vi"}

#translate data
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

#Load translation model a.k.a mbart fine-tuned checkpt
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model.cuda()
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
lang_dict = {"en":"en_XX", "de":"de_DE", "ko":"ko_KR", "vi":"vi_VN", "fr":"fr_XX"}

#load original data
with open(os.path.join(PROCESSED_DATA_DIR, "test.json")) as input_file:
    data = json.load(input_file)

n = len(data)
new_data = []
query_dict = {}

for row in data:
    new_data.append(row)
    if row["context_lang"] not in query_dict:
        query_dict[row["context_lang"]] = []
    query_dict[row["context_lang"]].append(row)
print("Finished adding existing data to output file")

batch_size = 8
for output_lang in DATA_LANG_DICT.values():
    for input_lang in query_dict.keys():
        tokenizer.src_lang = lang_dict[input_lang]

        if input_lang == output_lang:
            continue
        lang_data = query_dict[input_lang]
        num_batch = 0
        total_batch = int(len(lang_data)/batch_size)
        for start in range(0,len(lang_data),batch_size):
            batch = [lang_data[i]["query"] for i in range(start, min(start+batch_size,len(lang_data)))]
            
            encoded_inputs = tokenizer(batch, return_tensors="pt", padding=True).to("cuda")
            generated_tokens = model.generate(
                **encoded_inputs,
                forced_bos_token_id=tokenizer.lang_code_to_id[lang_dict[output_lang]]
            )
            translated_queries = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

            for i, j in enumerate(range(start, min(start+batch_size,len(lang_data)))):
                new_row = {"query": translated_queries[i], "context":lang_data[i]["context"], "query_lang":output_lang, "context_lang":input_lang}
                new_data.append(new_row)
            num_batch += 1
            print("Finished {num_batch}/{total_batch} translating from {input_lang} to {output_lang}\n".format(num_batch=num_batch, total_batch = total_batch, input_lang=input_lang, output_lang=output_lang))
        print("Finished translating from {input_lang} to {ouput_lang}".format(input_lang=input_lang, output_lang=output_lang))

with open(os.path.join(TESTING_DATA_DIR, "test-batch.json"), "w") as file:
    json.dump(new_data, file)

## Testing

In [3]:
import json
import os

TESTING_DATA_DIR = os.getenv('TESTING_DATA_DIR')

### Get Ground Truth and Retrieve queries and contexts

In [4]:
contexts = set()

with open(os.path.join(TESTING_DATA_DIR, "test-batch.json")) as f:
    data = json.load(f)
    for line in data:
        contexts.add(line["context"])

## Testing

In [6]:
%pip install "colbert-ir[faiss-gpu, torch]"

Note: you may need to restart the kernel to use updated packages.


In [5]:
import colbert
from colbert import Indexer, Searcher
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection

nbits = 2   # encode each dimension with 2 bits
doc_maxlen = 300   # truncate passages at 300 tokens
index_name = f'{nbits}bits'

checkpoint = '/Volumes/Users/ly_k1/Documents/mColBERT/experiments/default/none/2024-03/29/14.20.21/checkpoints/colbert'
collection = list(contexts)

In [5]:
#index
with Run().context(RunConfig(nranks=1, experiment='testing-2')):  # nranks specifies the number of GPUs to use.
    config = ColBERTConfig(doc_maxlen=doc_maxlen, nbits=nbits)

    indexer = Indexer(checkpoint=checkpoint, config=config)
    indexer.index(name=index_name, collection=collection, overwrite=True)
indexer.get_index() # You can get the absolute path of the index, if needed.



[Apr 07, 11:34:42] #> Creating directory /Volumes/Users/ly_k1/Documents/mColBERT/notebooks/experiments/testing-2/indexes/2bits 


#> Starting...
nranks = 1 	 num_gpus = 1 	 device=0
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "nbits": 2,
    "kmeans_niters": 4,
    "resume": false,
    "similarity": "cosine",
    "bsize": 64,
    "accumsteps": 1,
    "lr": 1e-5,
    "maxsteps": 500000,
    "save_every": null,
    "warmup": 20000,
    "warmup_bert": null,
    "relu": false,
    "nway": 2,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": null,
    "query_maxlen": 32,
    "attend_to_mask_tokens": false,
    "interaction": "colbert",
    "dim": 128,
    "doc_maxlen": 300,
    "mask_punctuation": 

[W socket.cpp:426] [c10d] The server socket cannot be initialized on [::]:13155 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:13155 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [::ffff:127.0.0.1]:13155 (errno: 97 - Address family not supported by protocol).


[Apr 07, 11:34:45] [0] 		 # of sampled PIDs = 4826 	 sampled_pids[:3] = [3412, 83, 2446]
[Apr 07, 11:34:45] [0] 		 #> Encoding 4826 passages..
[Apr 07, 11:34:55] [0] 		 avg_doclen_est = 208.1322021484375 	 len(local_sample) = 4,826
[Apr 07, 11:34:57] [0] 		 Creaing 8,192 partitions.
[Apr 07, 11:34:57] [0] 		 *Estimated* 1,004,446 embeddings.
[Apr 07, 11:34:57] [0] 		 #> Saving the indexing plan to /Volumes/Users/ly_k1/Documents/mColBERT/notebooks/experiments/testing-2/indexes/2bits/plan.json ..
Clustering 954446 points in 128D to 8192 clusters, redo 1 times, 4 iterations
  Preprocessing in 0.04 s
  Iteration 3 (0.81 s, search 0.74 s): objective=404329 imbalance=1.240 nsplit=0       
[Apr 07, 11:35:01] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 07, 11:35:01] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[0.05, 0.049, 0.047, 0.048, 0.046, 0.049, 0.046, 0.048, 0.05, 0.04

0it [00:00, ?it/s]

[Apr 07, 11:35:12] [0] 		 #> Saving chunk 0: 	 4,826 passages and 1,004,446 embeddings. From #0 onward.


1it [00:10, 10.26s/it]
100%|██████████| 1/1 [00:00<00:00, 583.11it/s]
100%|██████████| 8192/8192 [00:00<00:00, 225374.79it/s]


[Apr 07, 11:35:12] [0] 		 #> Checking all files were saved...
[Apr 07, 11:35:12] [0] 		 Found all files!
[Apr 07, 11:35:12] [0] 		 #> Building IVF...
[Apr 07, 11:35:12] [0] 		 #> Loading codes...
[Apr 07, 11:35:12] [0] 		 Sorting codes...
[Apr 07, 11:35:12] [0] 		 Getting unique codes...
[Apr 07, 11:35:12] #> Optimizing IVF to store map from centroids to list of pids..
[Apr 07, 11:35:12] #> Building the emb2pid mapping..
[Apr 07, 11:35:12] len(emb2pid) = 1004446
[Apr 07, 11:35:12] #> Saved optimized IVF to /Volumes/Users/ly_k1/Documents/mColBERT/notebooks/experiments/testing-2/indexes/2bits/ivf.pid.pt
[Apr 07, 11:35:12] [0] 		 #> Saving the indexing metadata to /Volumes/Users/ly_k1/Documents/mColBERT/notebooks/experiments/testing-2/indexes/2bits/metadata.json ..
#> Joined...


'/Volumes/Users/ly_k1/Documents/mColBERT/notebooks/experiments/testing-2/indexes/2bits'

### Search

In [6]:
with Run().context(RunConfig(experiment='testing-2')):
    print()
    searcher = Searcher(index=index_name, collection=collection)


[Apr 07, 19:35:24] #> Loading codec...
[Apr 07, 19:35:24] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 07, 19:35:24] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Apr 07, 19:35:24] #> Loading IVF...
[Apr 07, 19:35:24] #> Loading doclens...


100%|██████████| 1/1 [00:00<00:00, 971.58it/s]

[Apr 07, 19:35:24] #> Loading codes and residuals...



100%|██████████| 1/1 [00:00<00:00, 56.57it/s]


In [6]:
res_dict = {}
n=len(data)
for i, row in enumerate(data):
    res = searcher.search(row["query"], k=10)[0]
    res_dict[row["query"]] = res
    if not (i % 100):
        print("Finish searching {i}/{n}".format(i=i, n=n))


#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: . Which NFL team represented the AFC at Super Bowl 50?, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([     0,      3,      5, 130078, 186831,   7175,  33636,    297,     70,
            62,  27529,     99,   4265, 131793,    836,     32,      2, 250001,
        250001, 250001, 250001, 250001, 250001, 250001, 250001, 250001, 250001,
        250001, 250001, 250001, 250001, 250001])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

Finish searching 0/121945
Finish searching 100/121945
Finish searching 200/121945
Finish searching 300/121945
Finish searching 400/121945
Finish searching 500/121945
Finish searching 600/121945
Finish searching 700/121945
Finish searching 800/121945
Finish searching 900/121945
Finish searching 1000/121945
Finish searching 1100/121945
Finish searching 1200/121945
Finish 

### Calculate accuracy

In [9]:
sum_dict = {k1:{k2:0 for k2 in ["en", "de", "ko", "fr", "vi"]} for k1 in ["en", "de", "ko", "fr", "vi"]}

for i, row in enumerate(data):
    sum_dict[row["query_lang"]][row["context_lang"]] += 1

In [11]:
top20_dict = {k1:{k2:0 for k2 in ["en", "de", "ko", "fr", "vi"]} for k1 in ["en", "de", "ko", "fr", "vi"]}

res_dict = res

for k, v in res_dict.items():
    row = data[k]
    pids = [tup[0] for tup in v]
    passages = set(searcher.collection[pid] for pid in pids)
    if row["context"] in passages:
        top20_dict[row["query_lang"]][row["context_lang"]] += 1
        
result_dict = {}
for lang in sum_dict.keys():
    result_dict[lang] = {}
    for context_lang in sum_dict[lang].keys():
        result_dict[lang][context_lang] = top20_dict[lang][context_lang] / sum_dict[lang][context_lang]
result_dict


{'en': {'en': 0.885430463576159,
  'de': 0.009981851179673321,
  'ko': 0.021302390024246623,
  'fr': 0.014115432873274781,
  'vi': 0.023746701846965697},
 'de': {'en': 0.03235572374645222,
  'de': 0.8062613430127041,
  'ko': 0.020956009698649115,
  'fr': 0.014742785445420327,
  'vi': 0.026385224274406333},
 'ko': {'en': 0.031031220435193945,
  'de': 0.012704174228675136,
  'ko': 0.9494284724627641,
  'fr': 0.012233375156838143,
  'vi': 0.023746701846965697},
 'fr': {'en': 0.03282876064333018,
  'de': 0.013611615245009074,
  'ko': 0.023207481815032908,
  'fr': 0.8531994981179423,
  'vi': 0.025254428948360347},
 'vi': {'en': 0.03254493850520341,
  'de': 0.012250453720508167,
  'ko': 0.021995150675441635,
  'fr': 0.013801756587202008,
  'vi': 0.9227289860535243}}

In [14]:
top10_dict = {k1:{k2:0 for k2 in ["en", "de", "ko", "fr", "vi"]} for k1 in ["en", "de", "ko", "fr", "vi"]}

res_dict = res

for k, v in res_dict.items():
    row = data[k]
    pids = [tup[0] for tup in v]
    passages = set(searcher.collection[pid] for pid in pids[:10])
    if row["context"] in passages:
        top10_dict[row["query_lang"]][row["context_lang"]] += 1
        
result_dict = {}
for lang in sum_dict.keys():
    result_dict[lang] = {}
    for context_lang in sum_dict[lang].keys():
        result_dict[lang][context_lang] = top10_dict[lang][context_lang] / sum_dict[lang][context_lang]
result_dict


{'en': {'en': 0.8724692526017029,
  'de': 0.008620689655172414,
  'ko': 0.011776931070315206,
  'fr': 0.010037641154328732,
  'vi': 0.010931021485111195},
 'de': {'en': 0.02270577105014191,
  'de': 0.8021778584392014,
  'ko': 0.01195012123311396,
  'fr': 0.01066499372647428,
  'vi': 0.015454202789295138},
 'ko': {'en': 0.023084200567644278,
  'de': 0.008166969147005444,
  'ko': 0.9416349151368202,
  'fr': 0.009410288582183186,
  'vi': 0.013569543912551827},
 'fr': {'en': 0.02346263008514664,
  'de': 0.008166969147005444,
  'ko': 0.012123311395912712,
  'fr': 0.8331242158092849,
  'vi': 0.014323407463249152},
 'vi': {'en': 0.02336802270577105,
  'de': 0.009074410163339383,
  'ko': 0.012296501558711466,
  'fr': 0.010037641154328732,
  'vi': 0.9170750094232943}}

In [18]:
import time

#result of test wrong
querying_time = {k:0 for k in ["en", "de", "ko", "fr", "vi"]}
n = len(data)
for i, row in enumerate(data):
    start_time = time.time()
    res = searcher.search(row["query"], k=20)[0]
    end_time = time.time()
    
    querying_time[row["query_lang"]] += (end_time-start_time)
    if not (i % 100):
        print("Finish searching {i}/{n}".format(i=i, n=n))

Finish searching 0/121945
Finish searching 100/121945
Finish searching 200/121945
Finish searching 300/121945
Finish searching 400/121945
Finish searching 500/121945
Finish searching 600/121945
Finish searching 700/121945
Finish searching 800/121945
Finish searching 900/121945
Finish searching 1000/121945
Finish searching 1100/121945
Finish searching 1200/121945
Finish searching 1300/121945
Finish searching 1400/121945
Finish searching 1500/121945
Finish searching 1600/121945
Finish searching 1700/121945
Finish searching 1800/121945
Finish searching 1900/121945
Finish searching 2000/121945
Finish searching 2100/121945
Finish searching 2200/121945
Finish searching 2300/121945
Finish searching 2400/121945
Finish searching 2500/121945
Finish searching 2600/121945
Finish searching 2700/121945
Finish searching 2800/121945
Finish searching 2900/121945
Finish searching 3000/121945
Finish searching 3100/121945
Finish searching 3200/121945
Finish searching 3300/121945
Finish searching 3400/1219

In [23]:
for k in querying_time.keys():
    print(k, querying_time[k]/24389)


en 0.009541498641548593
de 0.009672094199865564
ko 0.009656541476249109
fr 0.009700295562365627
vi 0.009694301019936276
