
# Reranking Retrieval Results

In this notebook, you will continue using the [Pyserini](http://pyserini.io/) library's indexing and retrieval models.  This time, however, you will get an initial set of retrieval results and then write your own reranking code to try to move relevant documents higher in the list.

As before, we start by installing the Python interface. Since it calls the underlying Lucene search engine, which is written in Java, we make sure we point to an appropriate Java installation. If, as on Colab, you don't have Java 21, run the following cell (or do whatever makes sense for your platform).

In [1]:
## Install Java 21 on Colab (skip this cell if Java 21 is already installed)
!apt-get install openjdk-21-jre-headless -qq > /dev/null
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
!update-alternatives --set java /usr/lib/jvm/java-21-openjdk-amd64/bin/java
!java -version

openjdk version "21.0.6" 2025-01-21
OpenJDK Runtime Environment (build 21.0.6+7-Ubuntu-122.04.1)
OpenJDK 64-Bit Server VM (build 21.0.6+7-Ubuntu-122.04.1, mixed mode, sharing)


In [2]:
!pip install pyserini
# Use faiss-gpu instead if you have a GPU.
# It's a Pyserini dependency, but we won't need it until the next assignment.
!pip install faiss-cpu

Collecting pyserini
  Downloading pyserini-0.44.0.tar.gz (195.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 195.3/195.3 MB 6.4 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting pyjnius>=1.6.0 (from pyserini)
  Downloading pyjnius-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting onnxruntime>=1.8.1 (from pyserini)
  Downloading onnxruntime-1.21.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting tiktoken>=0.4.0 (from pyserini)
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting coloredlogs (from onnxruntime>=1.8.1->pyserini)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from to

We initialize the searcher with a pre-built index for the Robust04 collection, which Pyserini downloads automatically if it is not already cached locally. Note that the index takes up about 1.6 GB of disk space.

In [3]:
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher.from_prebuilt_index('robust04')

Downloading index at https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene/lucene-inverted.disk45.20240803.36f7e3.tar.gz...


lucene-inverted.disk45.20240803.36f7e3.tar.gz: 1.66GB [00:22, 79.9MB/s]                            


Now we can search for a query and inspect the results:

In [4]:
hits = searcher.search('black bear attacks', 1000)

# Prints the first 10 hits
for i in range(0, 10):
    print(f'{i+1:2} {hits[i].docid:15} {hits[i].score:.5f}')

 1 LA092790-0015   7.06680
 2 LA081689-0039   6.89020
 3 FBIS4-16530     6.61630
 4 LA102589-0076   6.46450
 5 FT932-15491     6.25090
 6 FBIS3-12276     6.24630
 7 LA091090-0085   6.17030
 8 FT922-13519     6.04270
 9 LA052790-0205   5.94060
10 LA103089-0041   5.90650


The `LuceneIndexReader` class provides various methods to read the index directly. For example, we can fetch a raw document from the index given its `docid`:

In [5]:
from pyserini.index import LuceneIndexReader
from IPython.display import display, HTML

reader = LuceneIndexReader.from_prebuilt_index('robust04')

doc = reader.doc('LA092790-0015').raw()
display(HTML('<div style="font-family: Times New Roman; padding-bottom:10px">' + doc + '</div>'))

Note that this is the raw text of the top-ranked hit for the query above (`LA092790-0015`). Given the raw text, we can obtain its analyzed form (i.e., tokenized, stemmed, with stopwords removed, etc.). Here we show the first ten tokens:

In [6]:
analyzed = reader.analyze(doc)
analyzed[0:10]

['date',
 'p',
 'septemb',
 '27',
 '1990',
 'thursdai',
 'ventura',
 'counti',
 'edit',
 'p']
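The analyzed tokens are the units the index actually stores, so counting them gives a bag-of-words term-frequency vector. The minimal sketch below does this for the ten tokens printed above; the `p` token apparently comes from the `<P>` paragraph markup in the raw text. For whole documents, `reader.get_document_vector(docid)` (used later in this notebook) returns this kind of term-to-frequency mapping directly from the index, so the manual count here is only for illustration.

```python
from collections import Counter

# First ten analyzed tokens of LA092790-0015, copied from the output above
tokens = ['date', 'p', 'septemb', '27', '1990',
          'thursdai', 'ventura', 'counti', 'edit', 'p']

# Bag-of-words term-frequency vector: term -> number of occurrences
tf = Counter(tokens)
print(tf['p'], len(tf))  # 2 9
```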

## Retrieving Initial Ranked Lists

We can load some standard evaluation sets such as Robust04, which contains 250 queries, or "topics" as the TREC conferences call them.

In [7]:
from pyserini.search import get_topics
topics = get_topics('robust04')
print(f'{len(topics)} queries total')

250 queries total


The topics are stored in a dictionary whose keys are integers uniquely identifying each query. Each topic contains the following fields:

* `title`: TREC's term for the brief query a user might actually type;
* `description`: a longer form of the query in the form of a complete sentence; and
* `narrative`: a description of what the user is looking for and what kinds of results would be relevant or non-relevant.

In [8]:
topics[301]

{'narrative': 'A relevant document must as a minimum identify the organization and the type of illegal activity (e.g., Columbian cartel exporting cocaine). Vague references to international drug trade without identification of the organization(s) involved would not be relevant.',
 'description': 'Identify organizations that participate in international criminal activity, the activity, and, if possible, collaborating organizations and the countries involved.',
 'title': 'International Organized Crime'}

For the purpose of your experiments, we'll divide the topics into a development set (the first 125) and a test set (the remaining 125).

In [9]:
dev_topics = {k:topics[k] for k in list(topics.keys())[:125]}
test_topics = {k:topics[k] for k in list(topics.keys())[125:]}

Now, we'll fetch the relevance judgments for the Robust04 queries, which TREC calls "qrels".

In [10]:
from urllib.request import urlopen

qfile = 'https://github.com/castorini/anserini-tools/blob/63ceeab1dd94c1221f29b931d868e8fab67cc25c/topics-and-qrels/qrels.robust04.txt?raw=true'
qrels = []
for line in urlopen(qfile):
  # Each line has four whitespace-separated fields: qid, round, docid, relevance
  qid, _round, docid, score = line.decode('utf-8').split()
  qrels.append([int(qid), 0, docid, int(score)])

Each record in the qrel contains four fields:

1. the numeric identifier of the query;
2. the round of relevance feedback, which is here always 0;
3. the identifier of a document that has been judged; and
4. the relevance score of that document.

In Robust04, all relevance judgments are binary, i.e., 1 or 0. Note that not all non-relevant documents are recorded. The qrel file only contains those documents the annotators actually looked at; the vast majority of documents in the collection have not been judged. In IR evaluation, we assume that unannotated documents are non-relevant.

In [11]:
qrels[0:10]

[[301, 0, 'FBIS3-10082', 1],
 [301, 0, 'FBIS3-10169', 0],
 [301, 0, 'FBIS3-10243', 1],
 [301, 0, 'FBIS3-10319', 0],
 [301, 0, 'FBIS3-10397', 1],
 [301, 0, 'FBIS3-10491', 1],
 [301, 0, 'FBIS3-10555', 0],
 [301, 0, 'FBIS3-10622', 1],
 [301, 0, 'FBIS3-10634', 0],
 [301, 0, 'FBIS3-10635', 0]]
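Because unjudged documents are assumed non-relevant, a relevance lookup should default to 0 rather than raise a `KeyError`. Here is a minimal sketch using two of the records printed above; the helper names `build_qrels_index` and `judged_relevance` are hypothetical and not part of Pyserini. The notebook's own `organize_relevance_judgments` function, defined later, builds the same nested mapping.

```python
from collections import defaultdict

def build_qrels_index(records):
    """Map qid -> {docid: relevance} from [qid, round, docid, relevance] records."""
    index = defaultdict(dict)
    for qid, _round, docid, rel in records:
        index[qid][docid] = rel
    return index

def judged_relevance(index, qid, docid):
    """Judged relevance of docid for qid; unjudged documents count as non-relevant (0)."""
    # .get() on a defaultdict does not insert missing keys
    return index.get(qid, {}).get(docid, 0)

sample = [[301, 0, 'FBIS3-10082', 1],
          [301, 0, 'FBIS3-10169', 0]]
idx = build_qrels_index(sample)
print(judged_relevance(idx, 301, 'FBIS3-10082'))    # 1: judged relevant
print(judged_relevance(idx, 301, 'FBIS3-10169'))    # 0: judged non-relevant
print(judged_relevance(idx, 301, 'LA092790-0015'))  # 0: not judged in this sample
```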

We collect the top 1000 hits for both the dev and test sets.

In [12]:
# Compute the top-k (default 1000) BM25 hits for each topic, using its title as the query
def topic_hits(searcher, topics, k=1000):
  hits = {}
  for topic, info in topics.items():
    print(topic, info['title'])
    hits[topic] = [(hit.docid, hit.score) for hit in searcher.search(info['title'], k)]
  return hits

dev_hits = topic_hits(searcher, dev_topics)
test_hits = topic_hits(searcher, test_topics)

350 Health and Computer Terminals
351 Falkland petroleum exploration
352 British Chunnel impact
353 Antarctica exploration
354 journalist risks
355 ocean remote sensing
356 postmenopausal estrogen Britain
357 territorial waters dispute
358 blood-alcohol fatalities
359 mutual fund predictors
360 drug legalization benefits
361 clothing sweatshops
362 human smuggling
363 transportation tunnel disasters
364 rabies
365 El Nino
366 commercial cyanide uses
367 piracy
368 in vitro fertilization
369 anorexia nervosa bulimia
370 food/drug laws
371 health insurance holistic
372 Native American casino
373 encryption equipment export
374 Nobel prize winners
375 hydrogen energy
376 World Court
377 cigar smoking
378 euro opposition
379 mainstreaming
380 obesity medical treatment
381 alternative medicine
382 hydrogen fuel automobiles
383 mental illness drugs
384 space station moon
385 hybrid fuel cars
386 teaching disabled children
387 radioactive waste
388 organic soil enhancement
389 illegal technol

## Evaluating Initial Ranked Lists



When reranking, an important metric is the _recall_ of the initial set of results. This gives the upper bound, or “headroom,” on the improvements that reranking can achieve: a reranker can only reorder the documents already retrieved, so a relevant document missing from the initial list stays missing. If recall in the initial ranked lists is too low, we know we need to improve the initial retrieval model instead.

For this assignment, you will work with fixed initial ranked lists from Pyserini's BM25 model, but it's still useful to see how much room there is for improvement during reranking.
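To see why recall caps what reranking can achieve, consider a toy ranked list (the document IDs below are made up). Reordering the list leaves recall unchanged, and a relevant document that was never retrieved (`d9`) stays out of reach no matter how the list is reranked.

```python
def recall(ranked_docids, relevant):
    """Fraction of relevant documents that appear anywhere in the ranked list."""
    if not relevant:
        return 0.0
    return len(set(ranked_docids) & relevant) / len(relevant)

ranked = ['d1', 'd2', 'd3', 'd4']   # toy initial ranking
relevant = {'d2', 'd4', 'd9'}       # d9 was never retrieved

print(recall(ranked, relevant))                  # 2 of 3 relevant documents retrieved
print(recall(list(reversed(ranked)), relevant))  # same value after reordering
```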

As before, you should process the `qrels` data to find the relevant results for each query.

In [13]:
import numpy as np

def organize_relevance_judgments(judgments):
    relevance_data = {}
    for judgment in judgments:
        query_id = judgment[0]
        doc_id = judgment[2]
        relevance_score = judgment[3]

        relevance_data.setdefault(query_id, {})
        relevance_data[query_id][doc_id] = relevance_score
    return relevance_data

def calculate_recall_metrics(hits_data, relevance_data):
    recall_metrics = {}
    for query_id, retrieved_documents in hits_data.items():
        if query_id in relevance_data:
            relevant_document_set = {doc_id for doc_id, score in relevance_data[query_id].items() if score > 0}
            retrieved_document_ids = {doc_id for doc_id, _ in retrieved_documents}
            retrieved_relevant = retrieved_document_ids.intersection(relevant_document_set)

            total_relevant = len(relevant_document_set)
            recall_value = len(retrieved_relevant) / total_relevant if total_relevant > 0 else 0.0
            recall_metrics[query_id] = recall_value
    return recall_metrics

def calculate_average_metric(metrics_dict):
    return sum(metrics_dict.values()) / len(metrics_dict) if metrics_dict else 0


relevance_data = organize_relevance_judgments(qrels)

development_recall_metrics = calculate_recall_metrics(dev_hits, relevance_data)
test_recall_metrics = calculate_recall_metrics(test_hits, relevance_data)

dev = calculate_average_metric(development_recall_metrics)
test = calculate_average_metric(test_recall_metrics)

print(f"Recall@1000, dev queries: {dev:.4f}")
print(f"Recall@1000, test queries: {test:.4f}")

Recall@1000, dev queries: 0.6985
Recall@1000, test queries: 0.6993


For a given set of top-1000 lists, Recall@1000 will not change after reranking. What will change are ranking-based metrics like MAP and NDCG. You should compute MAP@1000 for the initial `dev_hits` and `test_hits` data.
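A made-up example of this: the two lists below contain the same documents, so their recall is identical, but average precision rewards the ordering that puts relevant documents first. The `average_precision` function follows the same logic as `compute_average_precision` in the next cell: sum the precision at each rank where a relevant document appears, then divide by the number of relevant documents.

```python
def average_precision(ranked_docids, relevant):
    """Sum of precision@rank at each relevant hit, divided by |relevant|."""
    if not relevant:
        return 0.0
    found, total = 0, 0.0
    for rank, docid in enumerate(ranked_docids, start=1):
        if docid in relevant:
            found += 1
            total += found / rank
    return total / len(relevant)

relevant = {'d2', 'd4', 'd9'}        # d9 is never retrieved
initial = ['d1', 'd2', 'd3', 'd4']   # relevant docs at ranks 2 and 4
reranked = ['d2', 'd4', 'd1', 'd3']  # relevant docs moved to ranks 1 and 2

print(average_precision(initial, relevant))   # (1/2 + 2/4) / 3
print(average_precision(reranked, relevant))  # (1/1 + 2/2) / 3
```

The reranked list reaches AP = 2/3, which equals its recall: once every retrieved relevant document is at the top, the missing `d9` is the only thing holding AP below 1.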

In [14]:
def compute_average_precision(result_documents, ground_truth_relevant, cutoff=1000):
    if len(ground_truth_relevant) == 0:
        return 0.0

    docs_retrieved = [doc_tuple[0] for doc_tuple in result_documents[:cutoff]]

    relevant_found = 0
    precision_sum = 0.0

    for position, doc_id in enumerate(docs_retrieved, 1):
        if doc_id in ground_truth_relevant:
            relevant_found += 1
            current_precision = relevant_found / position
            precision_sum += current_precision

    return precision_sum / len(ground_truth_relevant)


def evaluate_mean_average_precision(search_results, relevance_judgments, max_depth=1000):
    precision_values = []

    for query_id, ranked_docs in search_results.items():
        if query_id not in relevance_judgments:
            continue

        relevant_set = set()
        for document_id, relevance in relevance_judgments[query_id].items():
            if relevance > 0:
                relevant_set.add(document_id)

        query_ap = compute_average_precision(ranked_docs, relevant_set, max_depth)
        precision_values.append(query_ap)

    if not precision_values:
        return 0.0

    return np.mean(precision_values)


dev_map = evaluate_mean_average_precision(
    dev_hits, relevance_data, max_depth=1000)
test_map = evaluate_mean_average_precision(
    test_hits, relevance_data, max_depth=1000)

print(f"MAP@1000 for dev queries (initial): {dev_map:.4f}")
print(f"MAP@1000 for test queries (initial): {test_map:.4f}")

MAP@1000 for dev queries (initial): 0.2426
MAP@1000 for test queries (initial): 0.2637


## Reranking Search Results

In this final part of the assignment, you should implement a ranking function that, hopefully, improves on the baseline BM25 ranking. You may use the BM25 score for each document as input, as well as the query, of course, and any other properties of the documents you look up with the `reader` object.  After computing a new score for each candidate, re-sort the top-1000 results by your model's score.
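One common pattern for such a reranker is to interpolate the BM25 score with a second signal after normalizing both to a common range. The sketch below assumes a hypothetical `new_score` callable (for example, a score from a feedback model or a neural model) and a mixing weight `alpha`; both are placeholders you would define and tune yourself.

```python
def min_max_normalize(scores):
    """Rescale scores to [0, 1]; a constant list maps to all zeros."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0] * len(scores)
    return [(s - lo) / (hi - lo) for s in scores]

def rerank(candidates, new_score, alpha=0.5):
    """Re-sort (docid, bm25_score) pairs by alpha*BM25 + (1-alpha)*new_score,
    with both signals min-max normalized within this candidate list."""
    if not candidates:
        return []
    docids = [docid for docid, _ in candidates]
    bm25 = min_max_normalize([score for _, score in candidates])
    other = min_max_normalize([new_score(docid) for docid in docids])
    combined = [alpha * b + (1 - alpha) * o for b, o in zip(bm25, other)]
    # sorted() is stable, so exact ties keep their original BM25 order
    return sorted(zip(docids, combined), key=lambda pair: pair[1], reverse=True)

# Toy candidates and made-up second-signal scores
candidates = [('a', 7.0), ('b', 6.0), ('c', 5.0)]
toy_scores = {'a': 0.0, 'b': 1.0, 'c': 0.2}
print(rerank(candidates, toy_scores.get, alpha=0.5))
```

Setting `alpha=1.0` reproduces the BM25 order, which makes a convenient sanity check before trying other weights.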

You may use anything you've learned in this course (or in another course) to build your ranking function. For example, you could implement pseudo-relevance feedback or a relevance model, which would treat the top of each ranked list (e.g., the top 100) as if it were truly relevant and re-estimate model parameters. You could tune different BM25, query likelihood, or sequential dependence model parameters. You could try to learn different weights or embeddings for different fields in documents. You could use implementations of transformer language models such as BERT or SentenceBERT to score the compatibility of queries and documents. To be clear, you don't have to use any of these approaches; you are free to try whatever ideas you like.
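As an illustration of the pseudo-relevance feedback idea, the sketch below picks expansion terms from the top-ranked documents' term-frequency vectors (the `{term: tf}` dictionaries that `reader.get_document_vector` returns). It weights each term by its average relative frequency in the feedback documents times an IDF factor, a simple tf-idf-style heuristic rather than a full relevance model. The function name is hypothetical, and the document frequencies and collection size are made-up toy values; with Pyserini they could come from the index reader (for example `get_term_counts`, whose exact signature you should check for your version).

```python
import math
from collections import defaultdict

def select_expansion_terms(feedback_vectors, df, num_docs, query_terms, n_terms=10):
    """Rank candidate expansion terms from pseudo-relevant documents.

    feedback_vectors: list of {term: tf} dicts for the top-k documents
    df: {term: document frequency} over the collection
    num_docs: number of documents in the collection
    query_terms: analyzed query terms to exclude
    """
    weights = defaultdict(float)
    for vec in feedback_vectors:
        length = sum(vec.values())
        if length == 0:
            continue
        for term, tf in vec.items():
            weights[term] += tf / length  # relative frequency within this document
    k = len(feedback_vectors)
    scored = []
    for term, w in weights.items():
        if term in query_terms:
            continue
        idf = math.log(num_docs / (1 + df.get(term, 0)))  # down-weights common terms
        scored.append((term, (w / k) * idf))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [term for term, _ in scored[:n_terms]]

# Toy feedback documents for the query "black bear attacks" (made-up counts)
vectors = [{'bear': 3, 'p': 10, 'grizzli': 2},
           {'bear': 2, 'p': 8, 'park': 2}]
df = {'bear': 500, 'p': 900_000, 'grizzli': 50, 'park': 5_000}
print(select_expansion_terms(vectors, df, 1_000_000, {'black', 'bear', 'attack'}, n_terms=2))
```

Ranked by raw frequency alone, `p` would come first, since it is the most frequent token in both toy documents; the IDF factor pushes such collection-wide tokens to the bottom.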

If your reranking model has tunable parameters, you should tune them on the `dev_hits` set. In the end, you will also evaluate MAP@1000 on the `test_hits` set.
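A minimal sketch of tuning on the dev set is an exhaustive grid search. The `grid_search` helper and the toy objective below are hypothetical stand-ins; in practice `evaluate` would return dev-set MAP@1000 for a given parameter setting.

```python
from itertools import product

def grid_search(evaluate, param_grid):
    """Try every combination in param_grid and return (best_params, best_score).

    evaluate: callable taking the parameters as keyword arguments and returning
              a dev-set metric where higher is better (e.g., MAP@1000)
    param_grid: {parameter name: list of candidate values}
    """
    names = list(param_grid)
    best_params, best_score = None, float('-inf')
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = evaluate(**params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score

# Toy objective standing in for dev-set MAP (peaks at k_pseudo=5, num_expansion=10)
def toy_dev_metric(k_pseudo, num_expansion):
    return -(k_pseudo - 5) ** 2 - (num_expansion - 10) ** 2

print(grid_search(toy_dev_metric, {'k_pseudo': [5, 10, 20], 'num_expansion': [10, 20]}))
```

In this notebook, `evaluate` could wrap `rerank_hits(searcher, reader, dev_topics, dev_hits, k_pseudo, num_expansion)` followed by `evaluate_mean_average_precision(..., relevance_data)`, and only the best setting would then be run on `test_hits`. Each grid point reruns one expanded search per dev topic, so a small grid keeps this affordable.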

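**Reranking function.** This reranker uses pseudo-relevance feedback through query expansion. For each topic, the top `K_PSEUDO_RELEVANT` (10) BM25 hits are treated as relevant, and their term-frequency vectors from `reader.get_document_vector` are summed. The `NUM_EXPANSION_TERMS` (20) most frequent index terms that do not already appear in the analyzed title are appended to the title, and BM25 is run again with the expanded query, retrieving 1,500 hits. The original 1,000 candidates are then reordered by their scores under the expanded query; candidates the expanded query does not retrieve are placed at the bottom. Despite the function name `get_rocchio_expanded_query`, this is a simplified, unweighted form of feedback rather than full Rocchio, and the parameters were not tuned on `dev_hits`.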

In [15]:
from collections import Counter
import math

K_PSEUDO_RELEVANT = 10
NUM_EXPANSION_TERMS = 20


def get_rocchio_expanded_query(topic_id, initial_hits, reader, topics, k_pseudo, num_expansion_terms):
    original_query_text = topics[topic_id]['title']
    analyzed_original_query = reader.analyze(original_query_text)

    k_pseudo = min(k_pseudo, len(initial_hits))
    if k_pseudo == 0:
        return original_query_text

    pseudo_relevant_docids = [doc_id for doc_id,
                              score in initial_hits[:k_pseudo]]

    term_freqs = Counter()
    for doc_id in pseudo_relevant_docids:
        try:
            doc_vector = reader.get_document_vector(doc_id)
            if doc_vector:
                term_freqs.update(doc_vector)
        except Exception:
            # Skip documents whose term vectors cannot be read from the index
            continue

    if not term_freqs:
        return original_query_text

    expansion_terms = []
    for term, freq in term_freqs.most_common():
        if term not in analyzed_original_query:
            expansion_terms.append(term)
            if len(expansion_terms) >= num_expansion_terms:
                break


    expanded_query_text = original_query_text + " " + " ".join(expansion_terms)
    return expanded_query_text


def rerank_hits(searcher, reader, topics, initial_hits_map, k_pseudo, num_expansion, k_rerank=1000):
    reranked_hits_map = {}
    original_doc_ids_map = {topic_id: {doc_id for doc_id, score in hits}
                            for topic_id, hits in initial_hits_map.items()}

    for topic_id, initial_hits in initial_hits_map.items():

        if not initial_hits:
            reranked_hits_map[topic_id] = []
            continue

        expanded_query = get_rocchio_expanded_query(
            topic_id, initial_hits, reader, topics, k_pseudo, num_expansion)

        try:
            new_search_results = searcher.search(
                expanded_query, k=int(k_rerank * 1.5))
        except Exception as e:
            print(
                f"Error searching expanded query for topic {topic_id}: {e}. Falling back to initial ranking.")

            reranked_hits_map[topic_id] = initial_hits
            continue

        original_doc_ids = original_doc_ids_map[topic_id]
        final_reranked_dict = {}
        found_original_ids = set()

        for hit in new_search_results:
            if hit.docid in original_doc_ids:
                final_reranked_dict[hit.docid] = hit.score
                found_original_ids.add(hit.docid)

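        # Candidates the expanded query did not retrieve go to the bottom.
        # Iterating the original list (not a set) keeps them in BM25 order
        # and makes the result deterministic across runs.
        min_score = -1e9
        for rank, (doc_id, _) in enumerate(initial_hits):
            if doc_id not in found_original_ids:
                final_reranked_dict[doc_id] = min_score - rank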

        final_reranked_list = list(final_reranked_dict.items())
        final_reranked_list.sort(key=lambda x: x[1], reverse=True)


        reranked_hits_map[topic_id] = final_reranked_list[:k_rerank]

    return reranked_hits_map

# Default settings; these were not tuned on dev_hits
best_k_pseudo = K_PSEUDO_RELEVANT
best_num_expansion = NUM_EXPANSION_TERMS

reranked_dev_hits = rerank_hits(
    searcher, reader, dev_topics, dev_hits, best_k_pseudo, best_num_expansion)
reranked_test_hits = rerank_hits(
    searcher, reader, test_topics, test_hits, best_k_pseudo, best_num_expansion)

In [16]:

map_dev_reranked = evaluate_mean_average_precision(
    reranked_dev_hits, relevance_data, max_depth=1000)
print(f"MAP@1000 for dev queries (reranked): {map_dev_reranked:.4f}")

map_test_reranked = evaluate_mean_average_precision(
    reranked_test_hits, relevance_data, max_depth=1000)
print(f"MAP@1000 for test queries (reranked): {map_test_reranked:.4f}")

print(f"MAP@1000 for dev queries (initial):  {dev_map:.4f}")
print(f"MAP@1000 for test queries (initial): {test_map:.4f}")

dev_improvement = map_dev_reranked - dev_map
test_improvement = map_test_reranked - test_map
print(f"Improvement on dev set:  {dev_improvement:+.4f}")
print(f"Improvement on test set: {test_improvement:+.4f}")

MAP@1000 for dev queries (reranked): 0.2287
MAP@1000 for test queries (reranked): 0.2151
MAP@1000 for dev queries (initial):  0.2426
MAP@1000 for test queries (initial): 0.2637
Improvement on dev set:  -0.0139
Improvement on test set: -0.0486
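Since the reranker lowers MAP on both sets, a per-query comparison can show whether the loss is spread evenly or driven by a few topics. A minimal sketch with made-up toy runs; the `compare_runs` helper is hypothetical, and in this notebook you could call it as `compare_runs(dev_hits, reranked_dev_hits, relevance_data)`.

```python
def average_precision(ranked_docids, relevant):
    """Sum of precision@rank at each relevant hit, divided by |relevant|."""
    if not relevant:
        return 0.0
    found, total = 0, 0.0
    for rank, docid in enumerate(ranked_docids, start=1):
        if docid in relevant:
            found += 1
            total += found / rank
    return total / len(relevant)

def compare_runs(baseline, reranked, qrels_map, eps=1e-9):
    """baseline/reranked: qid -> [(docid, score), ...]; qrels_map: qid -> {docid: rel}.
    Returns (wins, losses, ties, deltas), where deltas maps qid -> change in AP."""
    deltas = {}
    for qid, judged in qrels_map.items():
        if qid not in baseline or qid not in reranked:
            continue
        relevant = {d for d, rel in judged.items() if rel > 0}
        before = average_precision([d for d, _ in baseline[qid]], relevant)
        after = average_precision([d for d, _ in reranked[qid]], relevant)
        deltas[qid] = after - before
    wins = sum(1 for d in deltas.values() if d > eps)
    losses = sum(1 for d in deltas.values() if d < -eps)
    ties = len(deltas) - wins - losses
    return wins, losses, ties, deltas

# Toy runs: reranking helps query 1 and hurts query 2
qrels_map = {1: {'a': 1, 'b': 0}, 2: {'c': 1}}
baseline = {1: [('b', 2.0), ('a', 1.0)], 2: [('c', 3.0), ('d', 1.0)]}
reranked = {1: [('a', 0.9), ('b', 0.1)], 2: [('d', 0.8), ('c', 0.2)]}
wins, losses, ties, deltas = compare_runs(baseline, reranked, qrels_map)
print(wins, losses, ties, deltas)
```

Sorting `deltas.items()` by value then lists the topics the expansion hurt most, which is a natural place to inspect the chosen expansion terms.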
