# Introduction
## Retrieval-Augmented Generation (RAG) under Knowledge Edits

### üìò Overview
This notebook implements a **Retrieval-Augmented Generation (RAG)** system that reasons correctly under **knowledge edits** ‚Äî when a single factual statement in a knowledge base (say [WIKIDATA](https://www.wikidata.org/wiki/Wikidata:Main_Page)) is modified.

For example, if the original fact was:  
> ‚ÄúLeonardo DiCaprio was born in the United States.‚Äù  
and it is **edited** to:  
> ‚ÄúLeonardo DiCaprio was born in Syria.‚Äù  

then a query like  
> ‚ÄúWhat is the currency of the country where Leonardo DiCaprio was born?‚Äù  
 will produce **‚ÄúSyrian Pound‚Äù**, not **‚ÄúUS Dollar.‚Äù**

The system learns to retrieve the **modified fact** and uses it to answer downstream reasoning questions that depend on it.

## üß± Section 1 ‚Äî Environment Setup, Hugging Face Authentication, and Testing Enviroment

In this section, we install all the required Python libraries and configure the Hugging Face access to load pretrained models for retrieval and generation.

### üß© Step 1 ‚Äì Install Required Packages

We‚Äôll install a stable environment:

- **`vllm == 0.11.0`** ‚Äì fast local inference for prompt-based evaluation  
- **`transformers`, `tokenizers`, `accelerate`, `huggingface_hub`** ‚Äì for working with pretrained LLMs  
- **`rank_bm25`** ‚Äì lexical retrieval baseline  
- **`gdown`, `json_repair`** ‚Äì for dataset download + JSON cleanup  

> üí° The first two commands uninstall and purge cached packages to prevent version conflicts.  
> If the kernel suggests a restart after installation, **accept it** before continuing.

Run the next code cell and wait for it to complete successfully.

In [None]:
!pip uninstall -y transformers tokenizers accelerate huggingface_hub
!pip cache purge
!pip install "vllm==0.11.0" "transformers>=4.51.0" "tokenizers>=0.21.0" "accelerate>=1.0.0" "huggingface_hub>=0.26.0" "rank_bm25" "gdown" "json_repair"

Found existing installation: transformers 4.57.1
Uninstalling transformers-4.57.1:
  Successfully uninstalled transformers-4.57.1
Found existing installation: tokenizers 0.22.1
Uninstalling tokenizers-0.22.1:
  Successfully uninstalled tokenizers-0.22.1
Found existing installation: accelerate 1.11.0
Uninstalling accelerate-1.11.0:
  Successfully uninstalled accelerate-1.11.0
Found existing installation: huggingface-hub 0.36.0
Uninstalling huggingface-hub-0.36.0:
  Successfully uninstalled huggingface-hub-0.36.0
[0mFiles removed: 0
Collecting vllm==0.11.0
  Downloading vllm-0.11.0-cp38-abi3-manylinux1_x86_64.whl.metadata (17 kB)
Collecting transformers>=4.51.0
  Downloading transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.0/44.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tokenizers>=0.21.0
  Downloading tokenizers-0.


### üß© Step 2 ‚Äì Authenticate with Hugging Face

Some models (e.g., `sentence-transformers` or `Qwen` series) require a valid Hugging Face access token.

1. Go to **[https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)**  
2. Create or copy a token with at least **‚Äúread‚Äù** permissions  
3. Run the login cell below and paste the token when prompted  

> ‚ö†Ô∏è The command `huggingface-cli login` is deprecated.  
> we can still use it, or switch to the modern equivalents:
> ```bash
> !huggingface-cli login --token <YOUR_TOKEN>
> ```
> or
> ```bash
> !hf auth login
> ```

In [31]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `hf auth whoami` to get more information or `hf auth logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
The token `ad

### üß© Step 3 ‚Äì Import Packages and Verify Installation

Now that all dependencies are installed, let‚Äôs import the required packages and confirm that the environment is correctly configured.

This step ensures:
- All essential libraries (`torch`, `transformers`, `vllm`, etc.) were installed successfully  
- Their versions match the expected setup for this assignment  

In [2]:
import sys
import os
import json
from json_repair import repair_json
import torch
import accelerate
import huggingface_hub
import tokenizers
import numpy as np
import gdown
from rank_bm25 import BM25Okapi
import sentence_transformers
from sentence_transformers import SentenceTransformer, CrossEncoder
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import pathlib
import gdown
import random
from tqdm import tqdm
import gc
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # required before importing vLLM
import vllm
from vllm import LLM, SamplingParams
random.seed(42)

INFO 11-11 02:42:31 [__init__.py:216] Automatically detected platform cuda.


In [3]:
# --- Expected versions ---
expected = {
    "python": "3.12",
    "torch": "2.8.0+cu126",
    "transformers": "4.57.1",
    "accelerate": "1.11.0",
    "huggingface_hub": "0.36.0",
    "tokenizers": "0.22.1",
    "sentence_transformers": "5.1.2",
    "numpy": "2.0.2",
    "vllm": "0.11.0",
    "gdown": "5.2.0",
}

assert sys.version.startswith(expected["python"]), f"Python version mismatch: {sys.version}"
assert torch.__version__ == expected["torch"], f"Torch version mismatch: {torch.__version__}"
assert transformers.__version__ == expected["transformers"], f"Transformers version mismatch: {transformers.__version__}"
assert accelerate.__version__ == expected["accelerate"], f"Accelerate version mismatch: {accelerate.__version__}"
assert huggingface_hub.__version__ == expected["huggingface_hub"], f"HuggingFace Hub version mismatch: {huggingface_hub.__version__}"
assert tokenizers.__version__ == expected["tokenizers"], f"Tokenizers version mismatch: {tokenizers.__version__}"
assert sentence_transformers.__version__ == expected["sentence_transformers"], f"SentenceTransformers version mismatch: {sentence_transformers.__version__}"
assert np.__version__ == expected["numpy"], f"Numpy version mismatch: {np.__version__}"
assert vllm.__version__ == expected["vllm"], f"vLLM version mismatch: {vllm.__version__}"
assert gdown.__version__ == expected["gdown"], f"gdown version mismatch: {gdown.__version__}"

print("‚úÖ All package versions match expected values.")

‚úÖ All package versions match expected values.


## üìÇ Section 2 ‚Äî Dataset Overview and Loading

In this section, we‚Äôll load and inspect the dataset.
Each example represents a **modified fact** paired with **queries** that require reasoning based on this change.

### üß† Dataset Description

The dataset has two main components:

1. **üìò Edited Facts**  
   - Each document encodes one modified statement (e.g., *‚ÄúLeonardo DiCaprio was born in Syria.‚Äù*).  
   - These serve as the **knowledge base** for retrieval.

2. **‚ùì Queries**  
   - Each query asks a reasoning question affected by the edited fact.  
   - Each includes:
     - `question`: the query text  
     - `choices`: six options (A‚ÄìF)  
     - `correct_choice`: correct option key  
     - `correct_document_ids`: IDs of the relevant edited fact(s)

> üí° **Goal:** Retrieve the correct edited fact and answer according to the **modified world**, not the original one.

### üß© Step 1 ‚Äì Download, Save, and Inspect the Dataset

1. **Download** the dataset files using the `gdrive_get()` function implemented below.  
2. **Save** them locally in the working directory at `/content/datasets/`.  
3. **Load** the dataset using `json.load()` or an equivalent utility.  
4. **Inspect a few examples**, each containing:
   - `documents` ‚Äî verify each includes a single edited fact  
   - `queries` ‚Äî ensure they reference valid `correct_document_ids`  
   - The `documents` remain the same across all `queries` within a given data split (`val`, `test`).  
5. The dataset has two splits: **validation (`val`)** and **test (`test`)**.  
   - For the **test** split, `correct_choice` is always set to `"A"` and `correct_document_ids` to `0`.

In [4]:
def gdrive_get(url, out):
    fid = re.search(r'(?:/d/|id=)([-\w]{10,})', url).group(1)
    p = pathlib.Path(out); p.parent.mkdir(parents=True, exist_ok=True)
    gdown.download(id=fid, output=str(p), quiet=False)
    if p.read_bytes()[:32].lstrip().startswith((b'<!DOCTYPE html', b'<html')):
        gdown.download(url=f"https://drive.google.com/uc?id={fid}&export=download",
                       output=str(p), quiet=False)

In [5]:
files = {
    "datasets/val_dataset.json":  "https://drive.google.com/file/d/1fbRNGPpNebv8lDJGtjg0kvEXoYs5R4XC/view?usp=sharing",
    "datasets/test_dataset.json": "https://drive.google.com/file/d/1BnB3cAakw5oB1z7yzt8HKXF9c5V3rTWw/view?usp=sharing",
}

datasets = {}

print("\nüì¶ Downloading & Loading Datasets...\n" + "="*60)
for out_path, url in files.items():
    out = pathlib.Path(out_path)
    split_name = "val" if "val" in out.stem else "test"

    gdrive_get(url, out)

    with open(out, "r") as f:
        dataset = json.load(f)
    datasets[split_name] = dataset

    num_queries = len(dataset)
    num_docs = len(dataset[0]["all_documents"]) if num_queries > 0 else 0
    print(f"üìÇ Split: {split_name.upper():<5}")
    print(f"   ‚Ä¢ Path: {out.resolve()}")
    print(f"   ‚Ä¢ Queries:  {num_queries:,}")
    print(f"   ‚Ä¢ Documents: {num_docs:,}")
    print("-" * 60)

print("‚úÖ All datasets ready.\n")


üì¶ Downloading & Loading Datasets...


Downloading...
From: https://drive.google.com/uc?id=1fbRNGPpNebv8lDJGtjg0kvEXoYs5R4XC
To: /content/datasets/val_dataset.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 51.3M/51.3M [00:00<00:00, 235MB/s]


üìÇ Split: VAL  
   ‚Ä¢ Path: /content/datasets/val_dataset.json
   ‚Ä¢ Queries:  1,492
   ‚Ä¢ Documents: 371
------------------------------------------------------------


Downloading...
From (original): https://drive.google.com/uc?id=1BnB3cAakw5oB1z7yzt8HKXF9c5V3rTWw
From (redirected): https://drive.google.com/uc?id=1BnB3cAakw5oB1z7yzt8HKXF9c5V3rTWw&confirm=t&uuid=425d254f-b048-43b0-840b-59168a6af1bc
To: /content/datasets/test_dataset.json
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 108M/108M [00:00<00:00, 240MB/s] 


üìÇ Split: TEST 
   ‚Ä¢ Path: /content/datasets/test_dataset.json
   ‚Ä¢ Queries:  1,955
   ‚Ä¢ Documents: 535
------------------------------------------------------------
‚úÖ All datasets ready.



In [6]:
def print_sample(split, item_id):
    query_id = datasets[split][item_id]['query_id']
    query = datasets[split][item_id]['query']
    all_document_ids = datasets[split][item_id]['all_document_ids']
    all_documents = datasets[split][item_id]['all_documents']
    gold_document_ids = datasets[split][item_id]['correct_document_ids']
    gold_document = all_documents[gold_document_ids[0]]

    choices = datasets[split][item_id]['choices']
    correct_choice = datasets[split][item_id]['correct_choice']
    correct_choice_value = choices[correct_choice]

    print("=" * 80)
    print(f"üìò Split: {split} | Item ID: {item_id} | Query ID: {query_id}")
    print("-" * 80)
    print(f"üß† Query:\n{query}\n")
    print("üìö All Documents:")
    for doc_id, doc_text in zip(all_document_ids, all_documents):
        print(f"  [{doc_id}] {doc_text}")
    print("\nüèÜ Gold Document:")
    print(f"  ID: {gold_document_ids[0]}")
    print(f"  Text: {gold_document}\n")
    print("üó≥Ô∏è Choices:")
    for key, val in choices.items():
        print(f"  {key}: {val}")
    print(f"\n‚úÖ Correct Choice: {correct_choice} ‚Üí {correct_choice_value}")
    print("=" * 80)

In [None]:
print_sample('val', 1400)

üìò Split: val | Item ID: 1400 | Query ID: 1402
--------------------------------------------------------------------------------
üß† Query:
The name of the anthem of the country Michigan‚ÄìOhio State football rivalry is associated with is

üìö All Documents:
  [0] The name of the country of citizenship of Leonardo DiCaprio is Syria.
  [1] The name of the country which Academy Award for Best Picture is associated with is Wassoulou Empire.
  [2] The name of the spouse of Ron DeSantis is Carol Chu.
  [3] The names of the siblings of Janice Dickinson are Antoine-Jean-Matthieu S√©guier.
  [4] Big Mouth is followed by 1977‚Äì78 French Division 2.
  [5] The name of the anthem of Philippines is Hatikvah.
  [6] The name of the country of citizenship of Jerrod Carmichael is Terengganu.
  [7] The name of the composer of Vikram is Johnny Reine.
  [8] The place of burial of Princess Alice of Battenberg is Pante√≥n de Marinos Ilustres.
  [9] Soviet Union follows 2011 Greece Junior Badminton Champ

In [None]:
print_sample('test', 1400)

üìò Split: test | Item ID: 1400 | Query ID: 1400
--------------------------------------------------------------------------------
üß† Query:
The name of the head of state of the country 5 Forge Row is associated with is

üìö All Documents:
  [0] The name of the country which Goursez Vreizh is associated with is Franche-Comt√©.
  [1] The name of the country which Pralayakkad South is associated with is Sui dynasty.
  [2] The gender of Jose L Castillo is cisgender female.
  [3] The occupation of Emily I Jones is philatelist.
  [4] The name of the country which Suttor is associated with is Dutch Republic.
  [5] The name of the country which canton of Orci√®res is associated with is Chuvash Republic.
  [6] The occupation of G.L. Defer is Greek prefect.
  [7] The name of the country which Shockwave is associated with is Republic of Abkhazia.
  [8] The occupation of Nicholas D Rintala is police dog.
  [9] The name of the mother of Stephana Warnock is Sheila Mary Nolan.
  [10] The gender o

## üîé Section 3 ‚Äî Implement the Retriever

In this section, we implement a **retriever** that retrieves the most relevant documents for every query.  
We can choose to use either:

- **BM25** ‚Äî a lexical retriever based on token overlap, or  
- **Bi-encoder** ‚Äî a neural retriever that uses **sentence embeddings**.

> üí° Read more about different types of model-based retrievers here: [blog](https://blog.dailydoseofds.com/p/visual-guide-to-bi-encoders-cross)

Each retriever computes **similarity scores** between every pair of query and document.  
For each query, your goal is to identify the **top-k documents** with the highest similarity scores.


In [7]:
from sentence_transformers.util import similarity
def get_similarity_scores(retrieval_dataset, retriever_type="model", model_name="hkunlp/instructor-large"):
    queries = [item["query"] for item in retrieval_dataset] # queries.
    documents = retrieval_dataset[0]["all_documents"] # documents.
    for idx in range(len(retrieval_dataset)): # documents for different queries in a split are exactly same.
      assert retrieval_dataset[idx]["all_documents"] == documents

    if retriever_type == 'model':
      print(f"Using retriever type: {retriever_type}, model name: {model_name}")
    else:
      print(f"Using retriever type: {retriever_type}")

    if retriever_type == "model":
        # Use the SentenceTransformer specified by model_name to compute similarity scores (tensor of shape [len(queries), len(documents)]).
        model = SentenceTransformer(model_name)
        query_embeddings = model.encode(queries, convert_to_tensor=True, show_progress_bar=True)
        document_embeddings = model.encode(documents, convert_to_tensor=True, show_progress_bar=True)

        query_embeddings=torch.nn.functional.normalize(query_embeddings, p=2, dim=1)
        document_embeddings=torch.nn.functional.normalize(document_embeddings, p=2, dim=1)

        similarity=torch.matmul(query_embeddings, document_embeddings.T)

    elif retriever_type == "bm25":
        # Use the BM25Okapi to compute similarity scores (tensor of shape [len(queries), len(documents)]).
        tokenized_docs = [doc.lower().split() for doc in documents]
        bm25 = BM25Okapi(tokenized_docs)

        all_scores = []
        for q in queries:
          tokenized_q = q.lower().split()
          scores=bm25.get_scores(tokenized_q)
          all_scores.append(scores)
        similarity=torch.tensor(np.array(all_scores), dtype=torch.float32)

    else:
        raise ValueError(f"Unknown retriever type: {retriever_type}")

    # --- Output consistency checks ---
    assert isinstance(similarity, torch.Tensor)
    assert similarity.shape[0] == len(queries) and similarity.shape[1] == len(documents)
    return similarity

In [8]:
def get_topk(similarity, documents, k):
    """
    Return top-k indices and documents per query for a single k.
    """
    # Use tensor manipulation to fetch indices of top-k most similar documents corresponding to each query.

    topk_indices = torch.topk(similarity, k=k, dim=1).indices.cpu().tolist()  # [[i1..ik], ...]
    topk_docs = [[documents[j] for j in row] for row in topk_indices]


    # --- Output consistency checks ---
    assert len(topk_indices) == similarity.size(0), (
        f"Mismatch: got {len(topk_indices)} query results, expected {similarity.size(0)}"
    )
    assert all(len(row) == k for row in topk_indices), "Each query must return exactly k indices"
    assert all(len(row) == k for row in topk_docs), "Each query must return exactly k docs"
    assert len(topk_docs) == len(topk_indices), "Mismatch between indices and docs output lengths"
    assert all(isinstance(idx, int) for row in topk_indices for idx in row), \
        "All elements in topk_indices must be integers"
    assert all(isinstance(doc, str) for row in topk_docs for doc in row), \
        "All elements in topk_docs must be strings"
    return topk_indices, topk_docs

In [9]:
def hitrate_at_k_from_indices(topk_indices, gold_ids):
    """
    Compute hit-rate@k given top-k indices (per query) and gold doc IDs (per query).
    """
    n = len(gold_ids)
    hits = sum(bool(set(gold_ids[i]).intersection(topk_indices[i])) for i in range(n))
    return hits / n if n else 0.0

Change below here when you want to run on train or test split

In [10]:
CURRENT_SPLIT = "test" # set current split

In [11]:
# setting the dataset variables
retrieval_dataset = datasets[CURRENT_SPLIT]
gold_ids  = [item["correct_document_ids"] for item in datasets[CURRENT_SPLIT]]
queries   = [item["query"] for item in datasets[CURRENT_SPLIT]]
documents = datasets[CURRENT_SPLIT][0]["all_documents"]
answer_choices = []
for item in datasets[CURRENT_SPLIT]:
  item_choices = item['choices']
  item_choices['correct_choice'] = item['correct_choice']
  answer_choices.append(item_choices)

### üß© Step 3 ‚Äì Run Retriever, Visualize Examples, and Save Results

In this step, we will:

1. **Run the Retriever**  
   - Execute the retrieval pipeline using your selected retriever type (e.g., `bm25`, `instructor`, or `qwen3` model).  
   - Compute similarity scores between all **queries** and **documents**.  
   - Evaluate **Hit@K** metrics across different values of K (e.g., 1, 2, 4, 8, 16, 32, 64).

2. **Visualize Retrieval Examples**  
   - Randomly inspect a few **Top-1 retrieved documents** for sample queries.  
   - Check whether the retrieved results match the **ground-truth document IDs**.

3. **Save Retrieval Results**  
   - Store the computed top-K indices, retrieved documents, and hit rates in a dictionary (e.g., `TOP_K_RETRIEVAL_RESULTS`).  
   - Later, write these results to disk for submission or further evaluation.

In [12]:
# Try retriever "bm25" and "model" with embedding models such as "hkunlp/instructor-large" or "Qwen/Qwen3-Embedding-0.6B
RETRIEVER_TYPE, RETRIEVER_MODEL_NAME = "model", "Qwen/Qwen3-Embedding-0.6B"

TOP_K_RETRIEVAL_RESULTS = dict() # saving the retrieval results here.
similarity_scores = get_similarity_scores(retrieval_dataset=retrieval_dataset, retriever_type=RETRIEVER_TYPE, model_name=RETRIEVER_MODEL_NAME).cpu()

ks = (1, 2, 4, 8, 16, 32, 64)
for k in ks:
    topk_idx, topk_docs = get_topk(similarity_scores, documents, k=k)
    if k not in TOP_K_RETRIEVAL_RESULTS:
      TOP_K_RETRIEVAL_RESULTS[k] = dict()
    hitrate = hitrate_at_k_from_indices(topk_idx, gold_ids)
    TOP_K_RETRIEVAL_RESULTS[k] = {'indices': topk_idx, 'documents': topk_docs, 'hitrate': hitrate}
    if k == 1 and CURRENT_SPLIT == "val":
      assert hitrate > 0.80, "For validation split the hitrate@1 should be greater than 0.8"
    print(f"HitRate@{k}: {hitrate:.3f}")

Using retriever type: model, model name: Qwen/Qwen3-Embedding-0.6B


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Batches:   0%|          | 0/62 [00:00<?, ?it/s]

Batches:   0%|          | 0/17 [00:00<?, ?it/s]

HitRate@1: 0.001
HitRate@2: 0.002
HitRate@4: 0.003
HitRate@8: 0.006
HitRate@16: 0.011
HitRate@32: 0.032
HitRate@64: 0.078


In [13]:
random.seed(42)
print("\n=== Sample Top-1 Retrievals (Random Order) ===")
for i in random.sample(range(len(queries)), k=min(20, len(queries))):
    top1_idx = TOP_K_RETRIEVAL_RESULTS[1]['indices'][i][0]
    gold = set(gold_ids[i])
    is_hit = top1_idx in gold
    symbol = "‚úÖ" if is_hit else "‚ùå"

    print(f"\n{symbol} Query {i}: {queries[i]}")
    print(f"  ‚û§ Top-1 Retrieved Doc[{top1_idx}]: {topk_docs[i][0][:150]}{'...' if len(topk_docs[i][0])>150 else ''}")
    print(f"  üéØ Ground Truth IDs: {sorted(gold)}")


=== Sample Top-1 Retrievals (Random Order) ===

‚ùå Query 1309: The official language of the country Leyden is associated with is
  ‚û§ Top-1 Retrieved Doc[361]: The name of the country which Leyden is associated with is Kushan Empire.
  üéØ Ground Truth IDs: [0]

‚ùå Query 228: The occupation of the author of Imaging Fibroblast Activation Protein Alpha improves diagnosis of metastatic Prostate Cancer with Positron Emission Tomography is
  ‚û§ Top-1 Retrieved Doc[224]: The name of the author of Using stem cell-derived gametes for same-sex reproduction: an alternative scenario is Binghang Liu.
  üéØ Ground Truth IDs: [0]

‚ùå Query 51: The name of the position held by the father of Thomas Davis Lamb is
  ‚û§ Top-1 Retrieved Doc[13]: The name of the position held by Thomas Phillipps Lamb is deputy high court judge.
  üéØ Ground Truth IDs: [0]

‚ùå Query 1518: The name of the alma mater of the spouse of Deborah (?) is
  ‚û§ Top-1 Retrieved Doc[407]: The name of the spouse of Deborah (

In [15]:
# Save the retrieval results.

output_file = f"anlp_hw2_outputs/retrieval_results_{RETRIEVER_TYPE}_{RETRIEVER_MODEL_NAME.replace('/', '.')}_{CURRENT_SPLIT}.json"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

with open(output_file, "w") as f:
    json.dump(TOP_K_RETRIEVAL_RESULTS, f, indent=2)

print(f"‚úÖ Retrieval results saved successfully to: {output_file}")

‚úÖ Retrieval results saved successfully to: anlp_hw2_outputs/retrieval_results_model_Qwen.Qwen3-Embedding-0.6B_test.json


### üß† Section 4 ‚Äî Generate Answers with an LLM

In this section, we‚Äôll use a **generator LLM** to answer queries **conditioned on the retrieved documents**. Our objective is to improve the model‚Äôs reasoning and accuracy by systematically experimenting with different configurations.

**tasks:**
1. **Try different generator models**  
   - Experiment with different small models.  
   - Compare their performance on the validation split to identify which model generalizes best.

2. **Tune the input prompts**  
   - Refine the system and user prompts to improve reasoning structure and output consistency.  
   - Emphasize step-by-step reasoning and enforce strict JSON-only answers.

3. **Vary the number of top-K retrieved documents**  
   - Experiment with different context sizes (e.g., K = 0, 1, 2, 4, 8, 16, 64).  
   - Observe how increasing or decreasing K affects both accuracy and generation cost.

4. **Save your results to disk**  
   - Store outputs for atleast two configuration (model, K, prompt variant) (one with K>=1 and one with K=0) on both **validation** and **test** splits.  

### üß© Step 1 ‚Äì Load the retrieval results from the best found retriever

In this step, we will **reload the retrieval outputs** obtained from the best-performing retriever model (identified in the previous section).  
These results contain the **top-K retrieved documents** for each query and will serve as input for the **generator LLM** in the next stages.

In [16]:
RETRIEVER_TYPE, RETRIEVER_MODEL_NAME = "model", "Qwen/Qwen3-Embedding-0.6B"
retrieval_results_file = f"anlp_hw2_outputs/retrieval_results_{RETRIEVER_TYPE}_{RETRIEVER_MODEL_NAME.replace('/', '.')}_{CURRENT_SPLIT}.json"
TOP_K_RETRIEVAL_RESULTS = json.load(open(retrieval_results_file))

### üß© Step 2 ‚Äì Setup the generator LLM using vLLM

In this step, we will **initialize the generator model** that will produce answers based on the retrieved documents.  
The generator LLM takes as input the query, context (retrieved top-K documents), and the prompt instructions ‚Äî and outputs the most likely correct answer.

We will:
1. **Select the generator model**  
   Choose from models like `Qwen2.5-3B-Instruct`, `Llama-3.2-1B-Instruct`, or `SmolLM3-3B`.  
   Larger models generally exhibit stronger reasoning abilities but require more GPU memory and may take longer to load.

2. **Initialize the tokenizer**  
   The tokenizer formats the input messages using the appropriate chat template (system + user roles) before sending them to the model.

3. **Load the model with vLLM**  
   Use the `LLM()` class from [**vLLM**](https://vllm.ai) ‚Äî a high-performance inference engine optimized for **fast, memory-efficient, and scalable generation**.  
   vLLM uses techniques such as **PagedAttention**, dynamic batching, and GPU memory sharing, allowing you to run large models efficiently with minimal latency.

   Key parameters to tune:
   - `dtype`: `"float16"` or `"bfloat16"` (depending on GPU support)  
   - `tensor_parallel_size`: Increase if running on multiple GPUs  
   - `gpu_memory_utilization`: Adjust between `0.85‚Äì0.95` depending on available memory  
   - `max_model_len`: Set according to the model‚Äôs context window  

In [17]:
# Try different small generator models such as "Qwen/Qwen2.5-3B-Instruct" or "meta-llama/Llama-3.2-1B-Instruct" or "HuggingFaceTB/SmolLM3-3B"
# <fill block>
GENERATOR_NAME = "HuggingFaceTB/SmolLM3-3B"
# </fill block>

# ---- Initialize tokenizer (for chat templating) ----
generator_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_NAME)

# ---- Initialize vLLM engine ----
generator_model = LLM(
    model=GENERATOR_NAME,
    dtype="float16",                # or "bfloat16" if supported
    tensor_parallel_size=1,         # increase if multi-GPU
    gpu_memory_utilization=0.85,     # try 0.85‚Äì0.95 depending on GPU headroom
    max_model_len=2048,
    enable_prefix_caching=True      # saves the compute if prompts share a common prefix (e.g. a system prompt)
)
print(f"‚úÖ Model {GENERATOR_NAME} loaded successfully.")

INFO 11-11 02:43:59 [utils.py:233] non-default args: {'dtype': 'float16', 'max_model_len': 2048, 'enable_prefix_caching': True, 'gpu_memory_utilization': 0.85, 'disable_log_stats': True, 'model': 'HuggingFaceTB/SmolLM3-3B'}
INFO 11-11 02:44:01 [model.py:547] Resolved architecture: SmolLM3ForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 11-11 02:44:01 [model.py:1510] Using max model len 2048
INFO 11-11 02:44:04 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 11-11 02:44:51 [llm.py:306] Supported_tasks: ['generate']
‚úÖ Model HuggingFaceTB/SmolLM3-3B loaded successfully.


### üß© Step 3 ‚Äì Tune the Prompts to Generator

In this step, we will design **prompt templates** that will guide your **generator model** (e.g., `Llama-3`, `Qwen2`, `Gemma`, etc.) to produce accurate and well-structured answers based on the retrieved documents.

In [18]:
def prepare_generator_input(contextual_facts, question, answer_options):
    ctx = " ".join(contextual_facts)
    answer_choices_string = "\n".join([f"{choice}: {answer_options['choice_' + choice]}" for choice in ['A', 'B', 'C', 'D', 'E', 'F']])
    # NOTE: Write the system and user prompt.
    # System prompt should guide the LLM about how to use context, how to reason and how to geneate output in the parseable format {"correct_choice": "label"}, where "label" can only be "A", "B", "C", "D", "E", "F".
    # User prompt will be dependent on contextual facts, question and answer_option.
    # More tips on tuning prompts:
    # =========================
    # SYSTEM PROMPT CHECKLIST
    # =========================
    # Purpose: Teach the model HOW to think and respond.

    # 1) Define model role clearly:
    #       - State the model is an expert in Wikidata facts and logical reasoning.

    # 2) Explain function of the provided context:
    #       - Context temporarily modifies/overrides relationships in the Wikidata graph.

    # 3) Set precedence of new facts:
    #       - Modified facts in context take precedence over any prior/latent knowledge.

    # 4) Specify reasoning style:
    #       - Require 3‚Äì6 concise, non-redundant bullet points.
    #       - Bullets must reference the modified relationships and their ripple effects.

    # 5) Enforce output structure:
    #       - Exactly two sections, in order:
    #         1) "Reasoning:" (bullets)
    #         2) "Final:" (single JSON line)

    # 6) Enforce strict JSON format in Final:
    #       - Exactly: {"correct_choice": "label"}
    #       - "label" ‚àà {"A", "B", "C", "D", "E", "F"}

    # 7) Restrict JSON location:
    #       - JSON must appear ONLY in the Final section (never inside Reasoning).

    # 8) Encourage step-by-step logic:
    #       - Instruct to explain how modified links cause one choice to be correct.


    # =======================
    # USER PROMPT CHECKLIST
    # =======================
    # Purpose: Provide the specific instance (facts, question, choices).

    # 1) Include contextual facts:
    #       - Start with "Context:" containing the edited/modified Wikidata relationships.

    # 2) Present the question clearly:
    #       - Follow with "Question:" that refers to entities impacted by the context.

    # 3) List answer choices:
    #       - End with "Answer Choices:" listing options A‚ÄìF, clearly labeled.

    # 4) Keep structure consistent:
    #       - Order must be: Context ‚Üí Question ‚Üí Answer Choices.
    #       - Use clean line breaks; avoid extra filler text.

    # 5) Avoid meta-instructions:
    #       - Do NOT include reasoning guidance or examples‚Äîonly task content.

    # 6) Keep inputs parseable:
    #       - Consistent labels and formatting to allow automated extraction.


    sys = ("""
    You are an expert reasoning assistant trained to answer logical questions based on Wikidata-style knowledge.
    You are given MODIFIED FACTS that temporarily replace original Wikidata relationships. Always treat these new facts as the only source of truth.

    Follow this reasoning and output structure exactly:
    ------------------------------------------------------
    Reasoning:
    - Provide 3‚Äì6 concise bullet points where each point must explain how the modified facts affect the entities mentioned in the question and answer choices
    - Explain how these relationships eliminate or support each answer choice step by step.
    - Use step-by-step reasoning; avoid redundancy.

    Final:
    - Output exactly one JSON object on a single line using this schema:
    {"correct_choice": "<label>"}
    where <label> ‚àà {"A", "B", "C", "D", "E", "F"}.

    Rules:
    1. The JSON line must appear only after 'Final:' and never inside 'Reasoning'.
    2. Do not include explanations or text after the JSON line.
    3. Modified context overrides the model‚Äôs prior knowledge ‚Äî always reason from context first.
    4. Responses not following this exact format are considered invalid.

    """
    )#


    user = f"""
    Context: {ctx}

    Question: {question}

    Answer Choies: {answer_choices_string}

    Select the correct choice based solely on the Context above. Provide your reasoning and final JSON answer as instructed.

    """

    return [
        {"role": "system", "content": sys},
        {"role": "user", "content": user}
    ]

In [19]:
def chat_response(list_messages, generator_model, generator_tokenizer, max_new_tokens, temperature=0.0):
    # Prepare the chat input text
    prompts = []
    for messages in list_messages:
      prompt = generator_tokenizer.apply_chat_template(
          messages,
          tokenize=False,
          add_generation_prompt=True
      )
      prompts.append(prompt)

    # --- Define sampling parameters ---
    sp = SamplingParams(
        temperature=temperature,
        max_tokens=max_new_tokens,
    )

    # --- Run inference with vLLM ---
    responses = generator_model.generate(prompts, sp)
    responses = [response.outputs[0].text for response in responses]

    return responses

### üß© Step 4 ‚Äì Prepare Generator Inputs with Top-K Retrieved Documents

In this step, we‚Äôll **combine the retriever results** (Top-K documents) with the query and answer options to form the **input messages** for the generator LLM.

Tasks:

1. **Select K (number of retrieved documents)**
   - Set the variable `K` to control how many top-retrieved documents are passed to the generator.
   - Try values like `K = 1`, `K = 2`, `K = 4`, or `K = 8` to observe how retrieval depth affects generation accuracy.
   - When `K = 0`, the model receives **no context** (baseline, zero-context reasoning).

2. **Sample Evaluation Queries**
   - A random subset of queries will be selected to visualize input examples and sanity-check that the right documents are being used.

3. **Generate Formatted Messages**
   - For each query, we will create a message pair:
     - **System + User prompts** (from `prepare_generator_input`)
     - **Ground-truth answer label**
   - These structured inputs will later be passed to your generator in the next step for inference.

4. **Inspect a Few Samples**
   - After preparing the list `messages_and_labels`, print a few samples to confirm:
     - Retrieved context snippets
     - Corresponding question and answer choices
     - JSON-formatted LLM messages and correct answer labels

In [25]:
# NOTE: Choose how many top-k documents (experiment with atleast K=0 and K=1) you want to put the in the context of generator LLM.
K = 0

if K == 0:
  topk_idx, topk_docs = [[] for _ in range(len(TOP_K_RETRIEVAL_RESULTS[str(1)]['indices']))], [[] for _ in range(len(TOP_K_RETRIEVAL_RESULTS[str(1)]['documents']))]
else:
  topk_idx, topk_docs = TOP_K_RETRIEVAL_RESULTS[str(K)]['indices'], TOP_K_RETRIEVAL_RESULTS[str(K)]['documents']

messages_and_labels = [
    {
        'messages': prepare_generator_input(topk_docs[i], queries[i], answer_choices[i]),
        'correct_choice': answer_choices[i]['correct_choice'].split('_')[-1],
    }
    for i in range(len(queries))
]

In [26]:
# visualise some geneator inptus
random.seed(42)
indices = random.sample(range(len(queries)), k=5)

for i in indices:
    print(f"\n=== Query {i} ===")
    print(f"‚ùì {queries[i]}")
    print(f"üìö Context ({len(topk_docs[i])} docs):")
    for j, doc in enumerate(topk_docs[i]):
        print(f"  [{j+1}] {doc[:150]}{'...' if len(doc)>150 else ''}")
    print(f"\nüß† LLM Messages:\n{json.dumps(messages_and_labels[indices.index(i)]['messages'], indent=4)}\n")
    print(f"Correct Choice:\n{messages_and_labels[indices.index(i)]['correct_choice']}")


=== Query 1309 ===
‚ùì The official language of the country Leyden is associated with is
üìö Context (0 docs):

üß† LLM Messages:
[
    {
        "role": "system",
        "content": "\n    You are an expert reasoning assistant trained to answer logical questions based on Wikidata-style knowledge.\n    You are given MODIFIED FACTS that temporarily replace original Wikidata relationships. Always treat these new facts as the only source of truth.\n\n    Follow this reasoning and output structure exactly:\n    ------------------------------------------------------\n    Reasoning:\n    - Provide 3\u20136 concise bullet points where each point must explain how the modified facts affect the entities mentioned in the question and answer choices\n    - Explain how these relationships eliminate or support each answer choice step by step.\n    - Use step-by-step reasoning; avoid redundancy.\n\n    Final:\n    - Output exactly one JSON object on a single line using this schema:\n    {\"correct

### üß© Step 5 ‚Äì Compute Accuracy and Save the Generated Responses

In this final step, we‚Äôll **evaluate your generator model** by comparing its predicted answers against the ground truth labels and compute overall accuracy.


In [27]:
def accuracy(responses, labels):
    pred_labels = []
    parsing_error = 0.0
    for response, label in zip(responses, labels):
        pred_obj = repair_json(response, return_objects=True)
        if isinstance(pred_obj, list):
            pred_obj = pred_obj[-1]
        if isinstance(pred_obj, dict) and "correct_choice" in pred_obj:
            pred_label = pred_obj["correct_choice"]
        else:
            pred_label = "A"  # default fallback
            parsing_error += 1
        pred_labels.append(pred_label)

    correct = sum(p == l for p, l in zip(pred_labels, labels))
    return {"accuracy": correct / len(labels) if labels else 0.0, "parse_error_fraction": parsing_error / len(labels) if labels else 0.0}

In [28]:
generator_messages = [item['messages'] for item in messages_and_labels]
labels = [item['correct_choice'] for item in messages_and_labels]

responses = chat_response(
    generator_messages, generator_model, generator_tokenizer, max_new_tokens=1502
)

# compute final accuracy
results = accuracy(responses, labels)
print(f"Final Accuracy: {results['accuracy']:.3%} || Parsing Error: {results['parse_error_fraction']:.4%}")
#assert results['accuracy'] > 0.5, "Final accuracy should be greater than 0.5 after tuning value of K and prompt."
#assert results['parse_error_fraction'] < 0.05, "Parsing error fraction should be less than 5%."

Adding requests:   0%|          | 0/1955 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1955 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s‚Ä¶

Final Accuracy: 31.304% || Parsing Error: 0.8184%


AssertionError: Final accuracy should be greater than 0.5 after tuning value of K and prompt.

In [29]:
# Save the generator results.
output_file = f"anlp_hw2_outputs/generator_results_{GENERATOR_NAME.replace('/', '.')}_K-{K}_{CURRENT_SPLIT}.json"
os.makedirs(os.path.dirname(output_file), exist_ok=True)

with open(output_file, "w") as f:
    json.dump({"responses": responses, "generator_messages": generator_messages}, f, indent=2)

print(f"‚úÖ Generator results saved successfully to: {output_file}")

‚úÖ Generator results saved successfully to: anlp_hw2_outputs/generator_results_HuggingFaceTB.SmolLM3-3B_K-0_test.json
