# Imports

In [1]:
from datasets import load_dataset
from random import randint
import gzip
import pandas as pd
import requests
import os
from tqdm import tqdm
import csv
import json
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from transformers import AutoConfig, AutoModel, AutoTokenizer, RobertaTokenizer

os.environ["JAVA_HOME"] = "/usr/lib64/openjdk-21"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]
from pyserini.search.lucene import LuceneSearcher


import dotenv
dotenv.load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
os.getenv("JAVA_HOME")

'/etc/java-config-2/current-system-vm'

# Analysis
In the following section I'll try to find long-context tasks to work on. To do so, I'll download them and look at their size. 

## Qasper
Qasper is a QA dataset made from scientific articles. This makes it naturally long-context dataset, with 3 700 word on average. With some effort, I could transform it into a retrieval dataset.

In [2]:
qasper = load_dataset("allenai/qasper", split="train")

Downloading data: 100%|██████████| 14.4M/14.4M [00:02<00:00, 4.94MB/s]
Downloading data: 100%|██████████| 4.75M/4.75M [00:00<00:00, 4.89MB/s]
Downloading data: 100%|██████████| 7.07M/7.07M [00:01<00:00, 6.70MB/s]
Generating train split: 100%|██████████| 888/888 [00:00<00:00, 1910.02 examples/s]
Generating validation split: 100%|██████████| 281/281 [00:00<00:00, 2481.13 examples/s]
Generating test split: 100%|██████████| 416/416 [00:00<00:00, 2385.69 examples/s]


In [39]:
qasper[0]["qas"]

{'question': ['What is the seed lexicon?',
  'What are the results?',
  'How are relations used to propagate polarity?',
  'How big is the Japanese data?',
  'What are labels available in dataset for supervision?',
  'How big are improvements of supervszed learning results trained on smalled labeled data enhanced with proposed approach copared to basic approach?',
  'How does their model learn using mostly raw data?',
  'How big is seed lexicon used for training?',
  'How large is raw corpus used for training?'],
 'question_id': ['753990d0b621d390ed58f20c4d9e4f065f0dc672',
  '9d578ddccc27dd849244d632dd0f6bf27348ad81',
  '02e4bf719b1a504e385c35c6186742e720bcb281',
  '44c4bd6decc86f1091b5fc0728873d9324cdde4e',
  '86abeff85f3db79cf87a8c993e5e5aa61226dc98',
  'c029deb7f99756d2669abad0a349d917428e9c12',
  '39f8db10d949c6b477fa4b51e7c184016505884f',
  'd0bc782961567dc1dd7e074b621a6d6be44bb5b4',
  'a592498ba2fac994cd6fad7372836f0adb37e22a'],
 'nlp_background': ['two',
  'two',
  'two',
  'two

In [101]:
# Randomly sample nb_examples examples
#nb_examples = len(qasper)
avg_text_len = []
for i in range(len(qasper)):
    text_len = 0
    example = qasper[i]["full_text"]["paragraphs"]

    for par in example:
        text_len += sum([text.count(" ") + 1 for text in par])

    avg_text_len.append(text_len)

qasper_lengths = pd.Series(avg_text_len)
qasper_lengths.describe()

count      888.000000
mean      3718.864865
std       2254.151393
min          0.000000
25%       2387.500000
50%       3527.000000
75%       4353.000000
max      25910.000000
dtype: float64

## NQ

### HuggingFace
In this one I'll first check out Google Research's original Natual Question, instead of BEIR. Ok so Google Research's version is completely uncleaned, making it difficult to deal with. BEIR's version is split into passages. However, if you reconstruct the original article using the titles, you end up with a decently long corpus, roughly 1800 words long on average. 

In this section I'll check out the data format available on huggingface directly. 


In [52]:
nq = load_dataset("BeIR/nq", "corpus", token=os.getenv("HF_TOKEN"), trust_remote_code=True)

Downloading data: 100%|██████████| 285M/285M [00:05<00:00, 52.0MB/s] 
Downloading data: 100%|██████████| 285M/285M [00:05<00:00, 52.5MB/s] 
Downloading data: 100%|██████████| 217M/217M [00:04<00:00, 49.8MB/s] 
Generating corpus split: 100%|██████████| 2681468/2681468 [00:17<00:00, 150277.50 examples/s]


In [None]:
nq["corpus"]

{'_id': 'doc0',
 'title': 'Minority interest',
 'text': "In accounting, minority interest (or non-controlling interest) is the portion of a subsidiary corporation's stock that is not owned by the parent corporation. The magnitude of the minority interest in the subsidiary company is generally less than 50% of outstanding shares, or the corporation would generally cease to be a subsidiary of the parent.[1]"}

In [None]:
def grouping_func(input):
    _id = input["_id"].iloc[0]
    text = " ".join(input["text"])

    return pd.Series([_id, text], index=["_id", "text"])

nq_df = pd.DataFrame.from_dict(nq["corpus"])
grouped_nq_df = nq_df.groupby("title").apply(grouping_func)

In [103]:
(grouped_nq_df["text"].str.count(" ") + 1).describe()

count    108593.000000
mean       1872.470380
std        2435.234843
min           1.000000
25%         390.000000
50%        1007.000000
75%        2358.000000
max       68964.000000
Name: text, dtype: float64

### BEIR
Here I'll check out the data available on BEIR's repo. There are two downloadable NQ datasets: nq.zip and nq-train.zip. The train file is larger, somehow. 

In [33]:
out_dir = "/Tmp/lvpoellhuber/datasets/nq"

In [34]:
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"
data_path = util.download_and_unzip(url, out_dir)

In [35]:
data_path

'/Tmp/lvpoellhuber/datasets/nq/nq'

In [None]:
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

100%|██████████| 2681468/2681468 [00:07<00:00, 352184.27it/s]


: 

### Create a short dataset
I'm creating a short version of the test dataset, made up of 500 queries. To maintain the same balance as the regular corpus, I need to have a total 319146 documents in my corpus, containing the 500 queries' positive documents.

In [14]:
from random import randint
import csv

Calculate the number of positive documents.

In [7]:
qrel_len = 0
for qid in qrels:
    qrel_len += len(qrels[qid])
qrel_len

4201

Grab all the qrels, queries and the positive documents. 

In [None]:
short_corpus = {}
short_queries = {}
short_qrels = {}
i = 0
for qid in qrels:
    if i == 500: break
    query = queries[qid]
    docids = qrels[qid]

    documents = []
    for docid in docids:
        short_corpus[docid] = corpus[docid]
    
    short_qrels[qid] = docids
    short_queries[qid] = query

    i+=1

Compute the short corpus' sized, balanced like the original corpus. 

In [9]:
short_corpus_size = round(500*len(corpus)/qrel_len)
leftover_corpus_size = short_corpus_size - len(short_corpus)
leftover_corpus_size

318523

Grab a list of all the positive docids. 

In [10]:
positive_docids = short_corpus.keys()
positive_docids

dict_keys(['doc0', 'doc1', 'doc6', 'doc10', 'doc17', 'doc18', 'doc42', 'doc50', 'doc59', 'doc63', 'doc67', 'doc86', 'doc91', 'doc118', 'doc136', 'doc153', 'doc172', 'doc293', 'doc302', 'doc305', 'doc449', 'doc450', 'doc514', 'doc565', 'doc579', 'doc618', 'doc635', 'doc649', 'doc653', 'doc658', 'doc698', 'doc703', 'doc724', 'doc763', 'doc787', 'doc789', 'doc807', 'doc820', 'doc824', 'doc897', 'doc908', 'doc916', 'doc921', 'doc967', 'doc972', 'doc1010', 'doc1016', 'doc1026', 'doc1042', 'doc1070', 'doc1071', 'doc1100', 'doc1118', 'doc1154', 'doc1164', 'doc1187', 'doc1193', 'doc1215', 'doc1229', 'doc1239', 'doc1260', 'doc1404', 'doc1405', 'doc1407', 'doc1420', 'doc1432', 'doc1448', 'doc1468', 'doc1474', 'doc1486', 'doc1490', 'doc1541', 'doc1580', 'doc1599', 'doc1617', 'doc1631', 'doc1679', 'doc1684', 'doc1729', 'doc1744', 'doc1754', 'doc1771', 'doc1774', 'doc1782', 'doc1824', 'doc1927', 'doc2000', 'doc2030', 'doc2107', 'doc2127', 'doc2134', 'doc2151', 'doc2254', 'doc2262', 'doc2274', 'doc2

Randomly sample the corpus until we reach the desired size, ignoring positive docids. 

In [11]:
while (len(short_corpus) < short_corpus_size):
    random_int = randint(0, len(corpus))
    random_docid = f"doc{random_int}"

    if random_docid not in positive_docids:
        short_corpus[random_docid] = corpus[random_docid]


Save locally, to nq/nq-short. 

In [29]:
with open("/Tmp/lvpoellhuber/datasets/nq/nq-short/corpus.jsonl", "w", encoding="utf-8") as f:
    for doc_id, doc_content in short_corpus.items():
        json.dump({"_id": doc_id, **doc_content}, f)
        f.write("\n") 

In [30]:
with open("/Tmp/lvpoellhuber/datasets/nq/nq-short/queries.jsonl", "w", encoding="utf-8") as f:
    for qid, query in short_queries.items():
        json.dump({"_id": qid, "text": query, "metadata": {}}, f)
        f.write("\n") 

In [24]:

# Save as TSV file
with open("/Tmp/lvpoellhuber/datasets/nq/nq-short/qrels/test.tsv", "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f, delimiter="\t")  # Use tab as delimiter
    
    # Write header
    writer.writerow(["query-id", "corpus-id", "score"])
    
    # Write rows
    for qid, docids in short_qrels.items():
        for docid in docids:
            writer.writerow([qid, docid, 1])

Test it!

In [31]:
short_datapath = "/Tmp/lvpoellhuber/datasets/nq/nq-short"
corpus, queries, qrels = GenericDataLoader(data_folder=short_datapath).load(split="test")

100%|██████████| 319146/319146 [00:00<00:00, 358186.33it/s]


In [32]:
len(corpus)

319146

## Arguana
I'm not too sure what kind of dataset Arguana is, but it is damn long,with 7800 word on average!

In [104]:
arguana = load_dataset("BeIR/arguana", "corpus", token=os.getenv("HF_TOKEN"), trust_remote_code=True)

Downloading data: 100%|██████████| 5.09M/5.09M [00:00<00:00, 10.9MB/s]
Generating corpus split: 100%|██████████| 8674/8674 [00:00<00:00, 75092.92 examples/s]


In [105]:
arguana["corpus"]

Dataset({
    features: ['_id', 'title', 'text'],
    num_rows: 8674
})

In [106]:
def grouping_func(input):
    _id = input["_id"].iloc[0]
    text = " ".join(input["text"])

    return pd.Series([_id, text], index=["_id", "text"])

arguana_df = pd.DataFrame.from_dict(arguana["corpus"])
grouped_arguana_df = arguana_df.groupby("title").apply(grouping_func)

  grouped_arguana_df = arguana_df.groupby("title").apply(grouping_func)


In [109]:
(grouped_arguana_df["text"].str.count(" ") + 1).describe()

count    1.840000e+02
mean     7.854114e+03
std      7.444053e+04
min      4.480000e+02
25%      1.859250e+03
50%      2.268000e+03
75%      2.694000e+03
max      1.012057e+06
Name: text, dtype: float64

## Touche 2020

In [110]:
web = load_dataset("BeIR/webis-touche2020", "corpus", token=os.getenv("HF_TOKEN"), trust_remote_code=True)

Downloading data: 100%|██████████| 268M/268M [00:14<00:00, 18.3MB/s] 
Downloading data: 100%|██████████| 95.0M/95.0M [00:03<00:00, 24.8MB/s]
Generating corpus split: 100%|██████████| 382545/382545 [00:07<00:00, 49889.55 examples/s] 


In [111]:
web["corpus"]

Dataset({
    features: ['_id', 'title', 'text'],
    num_rows: 382545
})

In [112]:
def grouping_func(input):
    _id = input["_id"].iloc[0]
    text = " ".join(input["text"])

    return pd.Series([_id, text], index=["_id", "text"])

web_df = pd.DataFrame.from_dict(web["corpus"])
grouped_web_df = web_df.groupby("title").apply(grouping_func)

  grouped_web_df = web_df.groupby("title").apply(grouping_func)


In [113]:
(grouped_web_df["text"].str.count(" ") + 1).describe()

count     72121.000000
mean       1519.150095
std        4258.727238
min           1.000000
25%         159.000000
50%         533.000000
75%        1774.000000
max      729511.000000
Name: text, dtype: float64

## HotpotQA

While it's a very common QA dataset, it's not a very long one. 

In [115]:
hpqa = load_dataset("hotpotqa/hotpot_qa", "fullwiki", token=os.getenv("HF_TOKEN"), trust_remote_code=True)

Downloading data: 100%|██████████| 566M/566M [00:39<00:00, 14.2MB/s] 
Downloading data: 100%|██████████| 47.5M/47.5M [00:02<00:00, 18.9MB/s]
Downloading data: 100%|██████████| 46.2M/46.2M [00:01<00:00, 24.3MB/s]
Generating train split: 100%|██████████| 90447/90447 [01:05<00:00, 1373.21 examples/s]
Generating validation split: 100%|██████████| 7405/7405 [00:05<00:00, 1315.93 examples/s]
Generating test split: 100%|██████████| 7405/7405 [00:05<00:00, 1441.31 examples/s]


In [117]:
hpqa["train"]

Dataset({
    features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
    num_rows: 90447
})

In [136]:
len(hpqa["train"][0]["context"]["sentences"])
hpqa["train"][0]["context"]

{'title': ['Radio City (Indian radio station)',
  'History of Albanian football',
  'Echosmith',
  "Women's colleges in the Southern United States",
  'First Arthur County Courthouse and Jail',
  "Arthur's Magazine",
  '2014–15 Ukrainian Hockey Championship',
  'First for Women',
  'Freeway Complex Fire',
  'William Rast'],
 'sentences': [["Radio City is India's first private FM radio station and was started on 3 July 2001.",
   ' It broadcasts on 91.1 (earlier 91.0 in most cities) megahertz from Mumbai (where it was started in 2004), Bengaluru (started first in 2001), Lucknow and New Delhi (since 2003).',
   ' It plays Hindi, English and regional songs.',
   ' It was launched in Hyderabad in March 2006, in Chennai on 7 July 2006 and in Visakhapatnam October 2007.',
   ' Radio City recently forayed into New Media in May 2008 with the launch of a music portal - PlanetRadiocity.com that offers music related news, videos, songs, and other music-related features.',
   ' The Radio station c

## MS Marco Document Retrieval Track

With some effort, I can download and convert this to a document retrieval task. 

## Conclusion
In conclusion, many of the long-context tasks I found come from a modification of BEIR. These include NQ, Arguana and Touche 2020. Qasper is another great choice, the best being MS Marco. 


# Pre Processing
Most datasets will need some degree of preprocessing: whether they are QA or reranking, I'll need to reformulate the problem as a document ranking task. 

## MS Marco
This task will probably be the most demanding in terms of preprocessing: it's not even available on Datasets.

In [24]:
def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads

    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    r = requests.get(url, stream=True)
    total = int(r.headers.get("Content-Length", 0))
    with (
        open(save_path, "wb") as fd,
        tqdm(
            desc=save_path,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=chunk_size,
        ) as bar,
    ):
        for data in r.iter_content(chunk_size=chunk_size):
            size = fd.write(data)
            bar.update(size)


In [None]:
files = ["msmarco-docs.trec.gz", "msmarco-docs-lookup.tsv.gz", "msmarco-doctrain-queries.tsv.gz", "msmarco-doctrain-top100.gz", "msmarco-doctrain-qrels.tsv.gz", "msmarco-docdev-queries.tsv.gz", "msmarco-docdev-top100.gz", "msmarco-docdev-qrels.tsv.gz", "docleaderboard-queries.tsv.gz", "docleaderboard-top100.gz"]
url = " /"
save_dir = "/Tmp/lvpoellhuber/datasets/msmarco-doc"

for download_file in files:
    file_url = url + download_file
    save_path = os.path.join(save_dir, download_file)
    
    if os.path.exists(save_path):
        print("Download already exists. ")
    else:
        download_url(url = file_url, save_path = save_path)

Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 
Download already exists. 


In [27]:
storage_dir = "/Tmp/lvpoellhuber/datasets/msmarco-doc/"

In [None]:
# The query string for each topicid is querystring[topicid]
querystring = {}
with gzip.open(storage_dir + "msmarco-doctrain-queries.tsv.gz", 'rt', encoding='utf8') as f:
    tsvreader = csv.reader(f, delimiter="\t")
    for [topicid, querystring_of_topicid] in tsvreader:
        querystring[topicid] = querystring_of_topicid

In [5]:
querystring

{'1185869': ')what was the immediate impact of the success of the manhattan project?',
 '1185868': '_________ justice is designed to repair the harm to victim, the community and the offender caused by the offender criminal act. question 19 options:',
 '1183785': 'elegxo meaning',
 '645590': 'what does physical medicine do',
 '186154': 'feeding rice cereal how many times per day',
 '457407': 'most dependable affordable cars',
 '441383': 'lithophile definition',
 '683408': 'what is a flail chest',
 '484187': 'put yourself on child support in texas',
 '666321': 'what happens in a wrist sprain',
 '564233': 'what are rhetorical topics',
 '733739': 'what is considered early fall',
 '1164798': 'what causes elevated nitrate levels in aquariums',
 '443797': 'lyme disease symptoms mood',
 '662502': 'what forms the epineurium',
 '1184679': 'an alpha helix is an example of which protein structure?',
 '14562': 'aggregate demand curve',
 '602162': 'what county is ackley iowa in',
 '708236': 'what is

In [28]:
# In the corpus tsv, each docid occurs at offset docoffset[docid]
docoffset = {}
with gzip.open(storage_dir + "msmarco-docs-lookup.tsv.gz", 'rt', encoding='utf8') as f:
    tsvreader = csv.reader(f, delimiter="\t")
    for [docid, _, offset] in tsvreader:
        docoffset[docid] = int(offset)

In [29]:
docoffset

{'D1555982': 0,
 'D301595': 1852,
 'D1359209': 7973,
 'D2147834': 23656,
 'D1568809': 31104,
 'D3233725': 32233,
 'D1150618': 35109,
 'D1885729': 37298,
 'D1311240': 38964,
 'D3048094': 57457,
 'D2342771': 59170,
 'D1840066': 65036,
 'D3085586': 68974,
 'D62203': 71304,
 'D2883971': 73803,
 'D1911483': 79057,
 'D1281784': 87676,
 'D2347744': 90681,
 'D560769': 92990,
 'D1050302': 97356,
 'D207561': 98507,
 'D15500': 103732,
 'D488904': 105214,
 'D330566': 111733,
 'D1270076': 117733,
 'D205553': 119686,
 'D2976645': 125084,
 'D2520478': 127718,
 'D2641659': 129109,
 'D3058536': 129523,
 'D2175291': 130327,
 'D2702544': 132001,
 'D3257085': 135273,
 'D495191': 137242,
 'D2380450': 139142,
 'D2435215': 141278,
 'D1247413': 145672,
 'D2442854': 149244,
 'D2173367': 160104,
 'D2111530': 168026,
 'D2378859': 172668,
 'D2435425': 176248,
 'D885257': 180175,
 'D2900586': 181779,
 'D2981241': 184778,
 'D7792': 193188,
 'D3261752': 204070,
 'D1163231': 210534,
 'D1683937': 213235,
 'D1256481': 

In [52]:
# For each topicid, the list of positive docids is qrel[topicid]
qrel = {}
with gzip.open(storage_dir + "msmarco-doctrain-qrels.tsv.gz", 'rt', encoding='utf8') as f:
    tsvreader = csv.reader(f, delimiter="\t")
    for item in tsvreader:
        topicid, _, docid, rel = item[0].split(" ")
        assert rel == "1"
        
        qrel[topicid] = docid


In [53]:
qrel

{'3': 'D312959',
 '5': 'D140227',
 '12': 'D213890',
 '15': 'D1033338',
 '16': 'D508131',
 '18': 'D2286511',
 '24': 'D69114',
 '26': 'D1350520',
 '31': 'D304123',
 '42': 'D1439360',
 '48': 'D322379',
 '51': 'D920249',
 '54': 'D361377',
 '55': 'D1366317',
 '60': 'D246777',
 '63': 'D1450821',
 '67': 'D494640',
 '68': 'D1896896',
 '69': 'D2155744',
 '70': 'D2845225',
 '76': 'D368756',
 '79': 'D2030028',
 '80': 'D134843',
 '91': 'D871894',
 '105': 'D272',
 '107': 'D507162',
 '108': 'D3224805',
 '114': 'D917179',
 '118': 'D114892',
 '125': 'D1862900',
 '141': 'D2240743',
 '142': 'D2736990',
 '144': 'D341873',
 '145': 'D178709',
 '150': 'D2984660',
 '152': 'D972484',
 '155': 'D1342286',
 '160': 'D2388083',
 '161': 'D2868710',
 '163': 'D2977021',
 '168': 'D2466979',
 '173': 'D2170117',
 '174': 'D3049251',
 '177': 'D2789372',
 '188': 'D986616',
 '190': 'D1125524',
 '202': 'D460288',
 '204': 'D697848',
 '205': 'D2176419',
 '212': 'D3005908',
 '216': 'D2548726',
 '226': 'D906850',
 '228': 'D26986

In [62]:
# This function seeks the docid in the TSV file. It's to make it easier to access since it's so huge. 
def getcontent(docid, f):
    """getcontent(docid, f) will get content for a given docid (a string) from filehandle f.
    The content has four tab-separated strings: docid, url, title, body.
    """

    f.seek(docoffset[docid])
    line = f.readline()
    assert line.startswith(docid + "\t"), \
        f"Looking for {docid}, found {line}"
    return line.rstrip()


In [68]:
with gzip.open(storage_dir + "msmarco-docs.tsv.gz", 'rt', encoding='utf8') as f:
    content = getcontent("D1366317", f)

content.count(" ")

431

# NQ

In [9]:
out_dir = "/Tmp/lvpoellhuber/datasets/nq"

In [10]:
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq-train.zip"
data_path = util.download_and_unzip(url, out_dir)

In [11]:
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")

100%|██████████| 18060996/18060996 [00:52<00:00, 346164.52it/s]


In [15]:
pos_doc_len = {}
for qid in qrels:
    docs = qrels[qid]
    if str(len(docs)) not in pos_doc_len.keys():
        pos_doc_len[str(len(docs))] = 1
    else:
        pos_doc_len[str(len(docs))] += 1
        
pos_doc_len

{'1': 132803}

In [27]:
list(qrels["train0"].keys())[0]

'doc77'

In [9]:
documents = []
for qid, docs in qrels.items():
    for doc in docs:
        documents.append(doc)

unique_docs = pd.Series(documents).unique()
doc_set = set(list(unique_docs))

corpus_set = set(corpus.keys())

overlap = corpus_set.intersection(doc_set)

len(overlap) / len(corpus) * 100

0.7353027485305905

In [11]:
output_file = "/Tmp/lvpoellhuber/datasets/nq/corpus.jsonl"

with open(output_file, "w") as f:
    for doc_id, doc in corpus.items():
        json.dump({"id": doc_id, "title": doc["title"], "contents": doc["text"]}, f)
        f.write("\n")

Then run the following command: 

python -m pyserini.index -collection JsonCollection -input /Tmp/lvpoellhuber/datasets/nq -index /Tmp/lvpoellhuber/datasets/nq/bm25index -generator DefaultLuceneDocumentGenerator -threads 4 -storePositions -storeDocvectors -storeRaw


In [5]:
# Load the index
searcher = LuceneSearcher("/Tmp/lvpoellhuber/datasets/nq/bm25index")

# Search with BM25
query = 'when did richmond last play in a preliminary final'
docids = ["doc77"]
hits = searcher.search(query, k=6)  # Get top 6 results. 6 ensures we have at least 5 documents different from the right answer. 

# Print results
for i in range(len(hits)):
    print(f"Rank {i+1}: {hits[i].docid} | Score: {hits[i].score}")


Rank 1: doc553722 | Score: 13.992400169372559
Rank 2: doc10389121 | Score: 13.2746000289917
Rank 3: doc1243114 | Score: 12.989800453186035
Rank 4: doc16370059 | Score: 12.659700393676758
Rank 5: doc8207325 | Score: 12.659699440002441
Rank 6: doc16348311 | Score: 11.469799995422363


In [8]:
neg_docs = []
for i in range(len(hits)):
    if (hits[i].docid not in docids) & (len(neg_docs)<5):
        neg_docs.append(hits[i].docid)
neg_docs

['doc553722', 'doc10389121', 'doc1243114', 'doc16370059', 'doc8207325']

In [4]:
from model_biencoder import BiEncoder
import torch

# Parameter Comparisons

Comparing E5's automatic implementation and mine's. 

In [None]:

custom_model = BiEncoder(model_path = ("intfloat/e5-base-v2", "intfloat/e5-base-v2"), sep=" [SEP] ")
custom_q_model = custom_model.q_model.to("cpu")
base_model = AutoModel.from_pretrained("intfloat/e5-base-v2").to("cpu")

# Compare weights
custom_weights = custom_q_model.state_dict().keys()

for name, param in base_model.named_parameters():
    #print(name)
    if name in custom_weights:
        custom_param = custom_q_model.state_dict()[name]
        if torch.equal(param, custom_param):
            #print(f"Layer {name} matches")
            pass
        else:
            print(f"Layer {name} does not match")

    else:
        print(f"Layer {name} not found.")

Comparing DPR's implementation with mine. 

In [1]:
from transformers import DPRQuestionEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [7]:

custom_model = BiEncoder(model_path = ("facebook/dpr-question_encoder-single-nq-base", "facebook/dpr-ctx_encoder-single-nq-base"), sep=" [SEP] ")
custom_q_model = custom_model.q_model.to("cpu")
base_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to("cpu")

# Compare weights
custom_weights = custom_q_model.state_dict().keys()

for name, param in base_model.named_parameters():
    #print(name)
    if name in custom_weights:
        custom_param = custom_q_model.state_dict()[name]
        if torch.equal(param, custom_param):
            #print(f"Layer {name} matches")
            pass
        else:
            print(f"Layer {name} does not match")

    else:
        print(f"Layer {name} not found.")

Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DPRQuestionEncoder were not initialized from the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base and are newly initialized: ['bert_model.embeddings.LayerNorm.bias', 'bert_model.embeddings.LayerNorm.weight', 'bert_model.embeddings.position_embeddings.weight', 'ber