<h2 style="color: yellow;">This notebook only measures the latency of the retrieval stage</h2>

<h3>Before running</h3>
Make sure you are running this notebook in the "sprint_env" conda environment to ensure smooth operation.

The cell below checks that.

In [2]:
import subprocess

# Ask conda which environment is active. The output of `conda info` is parsed
# in Python instead of being piped through `grep`, so that a missing
# "active environment" line triggers the assertion message below rather than
# an unrelated CalledProcessError / IndexError.
conda_info = subprocess.check_output(["conda", "info"]).decode("utf-8")

result_str = None
for info_line in conda_info.splitlines():
    # Lines of interest look like "     active environment : sprint_env".
    key, separator, value = info_line.partition(" : ")
    if separator and key.strip() == "active environment":
        result_str = value.strip()
        break

expected_value = "sprint_env"
assert result_str == expected_value, f"Expected value: {expected_value}; Actual value: {result_str}"

<h3>Set-up and the required functions</h3>

In [3]:
import inspect
import os

from tqdm import tqdm
from transformers import AutoTokenizer
from typing import List

from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer
from pyserini.output_writer import OutputFormat, get_output_writer
from pyserini.pyclass import autoclass
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search import JDisjunctionMaxQueryGenerator
from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher
from pyserini.search.lucene.reranker import (
    ClassifierType,
    PseudoRelevanceClassifierReranker,
)
from sprint_toolkit.inference import encoder_builders
from pathlib import Path

import io
from contextlib import redirect_stdout


  from .autonotebook import tqdm as notebook_tqdm


In [28]:
def search(query_iterator, topics, args, tokenizer, searcher, fields, query_generator):
    """Send every query to ``searcher`` and throw the hits away.

    The function exists to be timed: it performs the retrieval calls but does
    not write or return any results.

    Args:
        query_iterator: Iterable yielding ``(topic_id, text)`` pairs.
        topics: Mapping of all topics; only its size is used (progress-bar
            total and detection of the final, possibly partial, batch).
        args: Object exposing ``tokenizer``, ``batch_size``, ``threads``,
            ``impact`` and ``hits`` attributes.
        tokenizer: Object with a ``tokenize(text)`` method; only used when
            ``args.tokenizer`` is not ``None``.
        searcher: Object exposing ``search`` and ``batch_search``.
        fields: Passed through to the searcher as ``fields=``.
        query_generator: Passed through as ``query_generator=`` unless
            ``args.impact`` is truthy.

    Returns:
        None.
    """
    total_queries = len(topics.keys())

    # Impact searchers are called without a query generator; everything else
    # about the two call variants is identical.
    search_kwargs = {"fields": fields}
    if not args.impact:
        search_kwargs["query_generator"] = query_generator

    one_query_at_a_time = args.batch_size <= 1 and args.threads <= 1
    pending_texts = []
    pending_ids = []

    for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=total_queries)):
        if args.tokenizer is not None:
            # Pre-tokenize and re-join with single spaces.
            text = " ".join(tokenizer.tokenize(text))

        if one_query_at_a_time:
            hits = searcher.search(text, args.hits, **search_kwargs)
            results = [(topic_id, hits)]
        else:
            pending_ids.append(str(topic_id))
            pending_texts.append(text)

            batch_is_full = (index + 1) % args.batch_size == 0
            is_last_query = index == total_queries - 1
            if not (batch_is_full or is_last_query):
                continue

            batch_hits = searcher.batch_search(
                pending_texts, pending_ids, args.hits, args.threads, **search_kwargs
            )
            results = [(id_, batch_hits[id_]) for id_ in pending_ids]
            pending_ids.clear()
            pending_texts.clear()

        # Only the search calls matter here, so the hits are dropped immediately.
        results.clear()

In [29]:
# Packaging pyserini.search into both callable and entry point
# With reference to https://github.com/castorini/pyserini/blob/master/pyserini/search/__main__.py
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

def set_bm25_parameters(searcher, index, k1=None, b=None):
    """Configure BM25 on ``searcher`` from explicit values or index presets.

    When neither ``k1`` nor ``b`` is given, the index name is looked up in a
    table of known MS MARCO indexes and the matching preset is applied; an
    unknown index leaves the searcher untouched. When at least one value is
    given, both are required: otherwise a message is printed and ``exit()``
    is called.

    Args:
        searcher: Object exposing ``set_bm25(k1, b)``.
        index: Index name compared against the known preset names.
        k1: Optional BM25 ``k1`` value.
        b: Optional BM25 ``b`` value.
    """
    if k1 is None and b is None:
        # (index names, message to print, k1, b)
        presets = (
            (
                ("msmarco-passage", "msmarco-passage-slim"),
                "MS MARCO passage: setting k1=0.82, b=0.68",
                0.82,
                0.68,
            ),
            (
                ("msmarco-passage-expanded",),
                "MS MARCO passage w/ doc2query-T5 expansion: setting k1=2.18, b=0.86",
                2.18,
                0.86,
            ),
            (
                ("msmarco-doc", "msmarco-doc-slim"),
                "MS MARCO doc: setting k1=4.46, b=0.82",
                4.46,
                0.82,
            ),
            (
                ("msmarco-doc-per-passage", "msmarco-doc-per-passage-slim"),
                "MS MARCO doc, per passage: setting k1=2.16, b=0.61",
                2.16,
                0.61,
            ),
            (
                ("msmarco-doc-expanded-per-doc",),
                "MS MARCO doc w/ doc2query-T5 (per doc) expansion: setting k1=4.68, b=0.87",
                4.68,
                0.87,
            ),
            (
                ("msmarco-doc-expanded-per-passage",),
                "MS MARCO doc w/ doc2query-T5 (per passage) expansion: setting k1=2.56, b=0.59",
                2.56,
                0.59,
            ),
        )
        for names, message, preset_k1, preset_b in presets:
            if index in names:
                print(message)
                searcher.set_bm25(preset_k1, preset_b)
                break
        return

    # At least one value was supplied explicitly: both are required.
    if k1 is None or b is None:
        print("Must set *both* k1 and b for BM25!")
        exit()
    print(f"Setting BM25 parameters: k1={k1}, b={b}")
    searcher.set_bm25(k1, b)


class AttrDict(dict):
    """A ``dict`` whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the attribute store, so ``d.key`` and
        # ``d["key"]`` always refer to the same entry.
        self.__dict__ = self

def run(
    topics: str,
    index: str,
    output: str,
    topics_format: str = TopicsFormat.DEFAULT.value,
    output_format: str = OutputFormat.TREC.value,
    max_passage: bool = False,
    max_passage_hits: int = 100,
    max_passage_delimiter: str = "#",
    batch_size: int = 1,
    threads: int = 1,
    remove_duplicates: bool = False,
    hits: int = 1000,
    impact: bool = False,
    encoder_name: str = None,
    ckpt_name: str = None,
    tokenizer: str = None,
    min_idf: int = -1,
    bm25: bool = False,
    rm3: bool = False,
    qld: bool = False,
    language: str = "en",
    prcl: List[ClassifierType] = [],
    k1: float = None,
    b: float = None,
    vectorizer: str = None,
    fields=None,
    stopwords: str = None,
    r: int = 10,
    n: int = 100,
    alpha: float = 0.5,
    dismax: bool = False,
    tiebreaker: float = 0.0,
    model: str = None,
    dataset: str = None
):
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    args = AttrDict(dict(zip(args, map(lambda arg: values[arg], args))))

    query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
    prev_topics = topics
    topics = query_iterator.topics

    if not args.impact:
        if os.path.exists(args.index):
            # create searcher from index directory
            searcher = LuceneSearcher(args.index)
        else:
            # create searcher from prebuilt index name
            searcher = LuceneSearcher.from_prebuilt_index(args.index)
    elif args.impact:
        ######## build query encoder by encoder name and checkpoint name ##########
        encoder_builder = encoder_builders.get_builder(encoder_name, ckpt_name, "query")
        args.encoder = encoder_builder()  # By default this will use CPU
        ###########################################################################
        if os.path.exists(args.index):
            searcher = LuceneImpactSearcher(args.index, args.encoder, args.min_idf)
        else:
            searcher = LuceneImpactSearcher.from_prebuilt_index(
                args.index, args.encoder, args.min_idf
            )

    if args.language != "en":
        searcher.set_language(args.language)

    if not searcher:
        exit()

    search_rankers = []

    if args.qld:
        search_rankers.append("qld")
        searcher.set_qld()
    elif args.bm25:
        search_rankers.append("bm25")
        set_bm25_parameters(searcher, args.index, args.k1, args.b)

    if args.rm3:
        search_rankers.append("rm3")
        searcher.set_rm3()

    fields = dict()
    if args.fields:
        fields = dict([pair.split("=") for pair in args.fields])
        print(f"Searching over fields: {fields}")

    query_generator = None
    if args.dismax:
        query_generator = JDisjunctionMaxQueryGenerator(args.tiebreaker)
        print(f"Using dismax query generator with tiebreaker={args.tiebreaker}")

    if args.tokenizer != None:
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        print(f"Using whitespace analyzer because of pretokenized topics")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print(f"Using {args.tokenizer} to preprocess topics")

    if args.stopwords:
        analyzer = JDefaultEnglishAnalyzer.fromArguments(
            "porter", False, args.stopwords
        )
        searcher.set_analyzer(analyzer)
        print(f"Using custom stopwords={args.stopwords}")

    # get re-ranker
    use_prcl = args.prcl and len(args.prcl) > 0 and args.alpha > 0
    if use_prcl is True:
        ranker = PseudoRelevanceClassifierReranker(
            searcher.index_dir,
            args.vectorizer,
            args.prcl,
            r=args.r,
            n=args.n,
            alpha=args.alpha,
        )

    # build output path
    
    #### 
    # This branch is not visited
    ####
    output_path = args.output
    if output_path is None:
        if use_prcl is True:
            clf_rankers = []
            for t in args.prcl:
                if t == ClassifierType.LR:
                    clf_rankers.append("lr")
                elif t == ClassifierType.SVM:
                    clf_rankers.append("svm")

            r_str = f"prcl.r_{args.r}"
            n_str = f"prcl.n_{args.n}"
            a_str = f"prcl.alpha_{args.alpha}"
            clf_str = "prcl_" + "+".join(clf_rankers)
            tokens1 = ["run", args.topics, "+".join(search_rankers)]
            tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str]
            output_path = ".".join(tokens1) + "-" + "-".join(tokens2) + ".txt"
        else:
            tokens = ["run", args.topics, "+".join(search_rankers), "txt"]
            output_path = ".".join(tokens)

    print(f"Running {args.topics} topics, saving to {output_path}...")

    ####
    # Timing the search
    ####
    f = io.StringIO()
    with redirect_stdout(f):
        %timeit -n 1 -r 5 search(query_iterator, topics, args, tokenizer, searcher, fields, query_generator)

    timeit_result = f.getvalue()
    number_of_queries = None

    try:
        elapsed_time = float(timeit_result.split(" s +- ")[0]) * 1000

        with open(prev_topics, "r") as file:
            number_of_queries = len([line for line in file])

        with open(Path(f"latency/{model}/{dataset}.txt").resolve().as_posix(), "a") as file:
            file.write('Latency for retrieval:\n')
            file.write(f"Timeit result: {timeit_result}")
            file.write(f"There are {number_of_queries} queries in the dataset.\n")
            file.write(f"This comes to {round(float(elapsed_time / number_of_queries), 2)} ms/query\n\n")
    except:
        elapsed_time = float(timeit_result.split(" ms +- ")[0])

        with open(prev_topics, "r") as file:
            number_of_queries = len([line for line in file])

        with open(Path(f"latency/{model}/{dataset}.txt").resolve().as_posix(), "a") as file:
            file.write('Latency for retrieval:\n')
            file.write(f"Timeit result: {timeit_result}")
            file.write(f"There are {number_of_queries} queries in the dataset.\n")
            file.write(f"This comes to {round(float(elapsed_time / number_of_queries), 2)} ms/query\n\n")
    
    print(f"{__name__}: Done")

<h2>Evaluation</h2>

In [38]:
# --- What to evaluate -------------------------------------------------------
topic_split = "test"
dataset = "nfcorpus"

# Sparse model and the checkpoint of its query encoder.
# Alternative setup:
#   model = 'deepimpact'
#   ckpt_name = '/Users/cciacu/Desktop/school/rp/experiments/deepimpact-bert-base'
model = "unicoil"
ckpt_name = "castorini/unicoil-noexp-msmarco-passage"
encoder_name = model

# --- Input / output locations (resolved against the working directory) ------
topics = (Path("queries") / f"queries_{model}_{dataset}_{topic_split}.tsv").resolve().as_posix()
index = (Path("sparse_indexes") / f"sparse_index_{dataset}_{model}").resolve().as_posix()
output = (Path("sprint_searches") / f"search-{dataset}-{model}-{topic_split}.tsv").resolve().as_posix()

# --- Search settings ----------------------------------------------------------
impact = True
hits = 1000 + 1  # NOTE(review): reason for the extra hit is not stated here - confirm
batch_size = 128
threads = 12
output_format = "trec"
min_idf = -1

In [39]:
# Time retrieval for the configuration defined in the previous cell.
run(
    topics=topics,
    index=index,
    output=output,
    output_format=output_format,
    model=model,
    dataset=dataset,
    encoder_name=encoder_name,
    ckpt_name=ckpt_name,
    impact=impact,
    min_idf=min_idf,
    hits=hits,
    batch_size=batch_size,
    threads=threads,
)

  0%|          | 0/323 [00:00<?, ?it/s]

Running /Users/cciacu/Desktop/school/rp/experiments/queries/queries_unicoil_nfcorpus_test.tsv topics, saving to /Users/cciacu/Desktop/school/rp/experiments/sprint_searches/search-nfcorpus-unicoil-test.tsv...


100%|██████████| 323/323 [00:13<00:00, 23.25it/s]
100%|██████████| 323/323 [00:14<00:00, 23.07it/s]
100%|██████████| 323/323 [00:13<00:00, 23.91it/s]
100%|██████████| 323/323 [00:12<00:00, 24.87it/s]
100%|██████████| 323/323 [00:14<00:00, 22.63it/s]

__main__: Done



