# Query Generation for HotpotQA with ColBERT

This notebook demonstrates a query generation system for the HotpotQA dataset using ColBERT embeddings. The workflow includes:

- Loading the HotpotQA fullwiki dataset
- Setting up the ColBERT model for text embedding and retrieval
- Implementing few-shot prompting for query generation
- Evaluating query quality for multi-hop question answering

The system generates search queries that can effectively retrieve relevant passages for complex multi-hop questions in the HotpotQA benchmark.

#### Unzip the pretrained Query Generation Model

In [None]:
# Unpack the pretrained query-generation checkpoint archive into a working folder.
import os
import zipfile

# Archive to read and the directory that receives its contents
zip_path = "/content/epoch_1_model.zip"
extract_path = "/content/extracted_model"

# Make sure the target directory is present (no error if it already exists)
os.makedirs(extract_path, exist_ok=True)

# Write every archive member below the target directory
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=extract_path)

print(f"Extracted {zip_path} to {extract_path}")

#### Import Required Libraries

In [2]:

# Importing the necessary libraries
from datasets import load_dataset
from transformers import  AutoTokenizer, AutoModel
import torch
import os
import json
from tqdm import tqdm
from huggingface_hub import login
import numpy as np
from collections import defaultdict
import random

# Seed Python, NumPy and PyTorch with one shared value so reruns are repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x140849f10>

#### Initialize ColBERT Model and the dataset

In [None]:
# Choose the compute device: CUDA when a GPU is visible, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Initialize the retrieval model: tokenizer and encoder come from the same checkpoint
colbert_checkpoint = "colbert-ir/colbertv2.0"
colbert_tokenizer = AutoTokenizer.from_pretrained(colbert_checkpoint)
colbert_model = AutoModel.from_pretrained(colbert_checkpoint)
colbert_model = colbert_model.to(device)
# The encoder is only used for inference here, so put it in evaluation mode
colbert_model.eval()

#### Clone the GitHub repo

In [None]:
# Cloning the GitHub repository
import getpass

# Login to GitHub
# You will be prompted to enter your GitHub username and password.
username = input("Enter your GitHub username: ")
password = getpass.getpass("Enter your GitHub password: ")

!git clone https://{username}:{password}@github.com/erenyavuz02/Trajectory-Aware-RL-for-Efficient-Multi-Hop-Retrieval.git

import sys
sys.path.append('/content/Trajectory-Aware-RL-for-Efficient-Multi-Hop-Retrieval')


In [None]:
# Check if dataset files exist, otherwise load from HuggingFace
def load_dataset_from_jsonl(filename):
    """Load a JSON Lines file into a columnar dict.

    Each non-empty line must hold one JSON object. The objects are transposed
    into ``{key: [value_from_line_1, value_from_line_2, ...]}``.

    Args:
        filename: Path of the ``.jsonl`` file to read (decoded as UTF-8).

    Returns:
        dict mapping every key seen in the file to the list of its values, in
        file order. A key missing from some lines simply gets a shorter list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    data = defaultdict(list)
    # Explicit encoding: the platform default is not UTF-8 everywhere, and the
    # questions contain non-ASCII characters.
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at end of file) are not
                # records; json.loads("") would raise on them.
                continue
            item = json.loads(line)
            for key, value in item.items():
                data[key].append(value)
    return dict(data)

# Check if pre-saved dataset files exist
train_file = "train_dataset.jsonl"
eval_file = "eval_dataset.jsonl"

if os.path.exists(train_file) and os.path.exists(eval_file):
    print("Loading datasets from existing files...")
    train_dataset = load_dataset_from_jsonl(train_file)
    eval_dataset = load_dataset_from_jsonl(eval_file)
else:
    print("Loading dataset from HuggingFace...")
    dataset = load_dataset("hotpot_qa", "fullwiki", trust_remote_code=True)
    DATASET_SPLIT = 0.9  # 90% for training, 10% for validation (not used by the fixed slices below)
    train_dataset = dataset['train'][:5000]  # Use 5K for faster processing
    eval_dataset = dataset['train'][5000:6000]  # Use 1K for validation

# Both branches yield a columnar dict ({column: [values...]}): the JSONL loader
# builds one, and slicing a Hugging Face Dataset returns one. len() of such a dict
# is the number of columns, so count samples through a column instead.
print(f"Loaded {len(train_dataset['question'])} training samples")
print(f"Loaded {len(eval_dataset['question'])} evaluation samples")

#### Load few-shot prompts

Print some examples of few-shot prompts

In [10]:
# Read the few-shot (question, query) pairs from the JSON file
with open("fewshot_examples.json", "r") as fewshot_file:
    FEWSHOT_EXAMPLES = json.load(fewshot_file)

# Preview the first five pairs, each followed by a dashed separator line
preview_count = 5
separator = "-" * 40
for shot in FEWSHOT_EXAMPLES[:preview_count]:
    print(f"Question: {shot['question']}")
    print(f"Query: {shot['query']}")
    print(separator)


Question: Which Nirvana album featured The Vaselines, Dave Grohl, and Chad Channing?
Query: Nirvana album featuring The Vaselines Dave Grohl Chad Channing
----------------------------------------
Question: Who contributed to more Disney films, Claire Keane or her father Glen Keane?
Query: Claire Keane vs Glen Keane Disney films
----------------------------------------
Question: What do the films "Giuliani Time" and "Life After People" have in common?
Query: Giuliani Time and Life After People commonalities
----------------------------------------
Question:  "I Saw Her Again" was co-written by what Canadian singer born in 1940?
Query: "I Saw Her Again" co-written by Canadian singer born 1940
----------------------------------------
Question: What former Detroit Pistons player hosted a talkshow on MTV in 1996?
Query: Detroit Pistons player hosted MTV talk show 1996
----------------------------------------


#### Helper Functions 

In [8]:

# Few-shot examples for generating search queries
def build_fewshot_prompt(question, context="", add_fewshot=False):
    """Assemble the text prompt handed to the query-generation model.

    The prompt is the concatenation of up to three sections, in this order:
    an optional "Examples:" section, an optional "Context:" section and the
    task instruction containing the question.

    Args:
        question: Question a search query should be generated for.
        context: Text placed in a "Context:" section; the section is left out
            when this value is falsy (e.g. the empty string).
        add_fewshot: When True, 1 to 3 entries (count drawn at random) are
            sampled from the module-level FEWSHOT_EXAMPLES and listed first.

    Returns:
        The prompt as a single string.
    """
    sections = []

    if add_fewshot:
        # Draw the count first, then the examples, using the global `random` state
        shot_count = random.randint(1, 3)
        chosen_shots = random.sample(FEWSHOT_EXAMPLES, shot_count)
        rendered_shots = "".join(
            f"Question:{shot['question']}\nQuery:{shot['query']}\n\n"
            for shot in chosen_shots
        )
        sections.append("Examples:\n" + rendered_shots)

    if context:
        sections.append(f"Context:\n{context}\n\n")

    sections.append(f"Generate a search query for the following question:\n{question}")
    return "".join(sections)
    
# === Embedding utility ===
def compute_colbert_embeddings(texts, batch_size=32):
    """Encode texts into per-token embeddings with the module-level ColBERT encoder.

    Texts are processed in chunks of ``batch_size``. Each chunk is padded only to
    its own longest sequence (capped at 512 tokens) instead of always to 512,
    which avoids running the encoder over long stretches of padding. Padding
    positions are dropped from the result via the attention mask either way.

    NOTE(review): this reads ``last_hidden_state`` of the model loaded through
    ``AutoModel``; no additional projection or normalization is applied here.
    Confirm that this is the representation the retrieval setup intends.

    Args:
        texts: List of strings to encode (sliceable sequence).
        batch_size: Number of texts per forward pass; must be a positive integer.

    Returns:
        List with one numpy array per input text, in input order. Each array
        holds the hidden states of that text's non-padding tokens (presumably
        shaped (num_tokens, hidden_size) - TODO confirm). An empty ``texts``
        yields an empty list without touching the tokenizer or the model.
    """
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encoded = colbert_tokenizer(
            batch,
            max_length=512,
            padding=True,  # pad to the longest item in this batch only
            truncation=True,
            return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            output = colbert_model(**encoded).last_hidden_state
        masks = encoded["attention_mask"].bool()
        # Keep only the rows belonging to real (non-padding) tokens
        embeddings.extend(output[i][masks[i]].cpu().numpy() for i in range(len(batch)))
    return embeddings

# === Scoring utility ===
def maxsim_score(query_emb, doc_embs):
    """Score a document against a query with the MaxSim rule.

    For every row of ``query_emb`` the largest dot product with any row of
    ``doc_embs`` is taken; the score is the sum of these maxima.

    Args:
        query_emb: 2-D torch tensor, one row per query token.
        doc_embs: 2-D torch tensor, one row per document token.

    Returns:
        The score as a Python float.
    """
    token_similarities = query_emb @ doc_embs.T
    best_per_query_token = torch.max(token_similarities, dim=1).values
    return float(best_per_query_token.sum())

def compute_ap_recall_precision(supporting_pairs, retrieved_ids, sentence_metadata):
    """Compute ranking metrics for one retrieval result.

    A retrieved sentence counts as a hit when its ``(title, sent_idx)`` pair,
    looked up in ``sentence_metadata``, is contained in ``supporting_pairs``.

    Args:
        supporting_pairs: Collection of ground-truth ``(title, sent_idx)`` pairs.
        retrieved_ids: Ranked indices into ``sentence_metadata`` (best first).
            May be a list or a numpy array.
        sentence_metadata: Indexable of dicts with ``"title"`` and ``"sent_idx"``.

    Returns:
        Tuple ``(ap, precision, recall)``. Note the order: despite the function
        name, precision comes before recall. All three are 0.0 when either
        input collection is empty.
    """
    # Materialize first: truth-testing a numpy array with more than one element
    # raises ValueError, so emptiness is checked on a plain list.
    retrieved_ids = list(retrieved_ids) if retrieved_ids is not None else []
    if not retrieved_ids or not supporting_pairs:
        return 0.0, 0.0, 0.0

    hits = [
        1 if (sentence_metadata[i]["title"], sentence_metadata[i]["sent_idx"]) in supporting_pairs else 0
        for i in retrieved_ids
    ]
    num_hits = sum(hits)

    # Average precision: mean of precision@rank taken at every rank that holds a
    # hit. precision@rank is (hits seen so far) / rank, so a ranking that places
    # all hits first scores 1.0. The mean is taken over the retrieved hits.
    precision_sum = 0.0
    hits_so_far = 0
    for rank, hit in enumerate(hits, start=1):
        if hit:
            hits_so_far += 1
            precision_sum += hits_so_far / rank
    ap = precision_sum / max(num_hits, 1)

    # Fraction of retrieved sentences that are supporting facts
    precision = num_hits / len(retrieved_ids)

    # Fraction of supporting facts that were retrieved
    recall = num_hits / len(supporting_pairs)

    return ap, precision, recall

def calculate_f1(precision, recall):
    """Return the harmonic mean of precision and recall.

    Args:
        precision: Precision value.
        recall: Recall value.

    Returns:
        ``2 * precision * recall / (precision + recall)``, or 0.0 when the two
        inputs sum to zero (which would otherwise divide by zero).
    """
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return (2 * precision * recall) / denominator


# Evaluator

In [None]:
def _is_columnar_dataset(dataset):
    """Tell whether ``dataset`` is a columnar dict ({column: [values...]})."""
    return isinstance(dataset, dict) and "question" in dataset


def _dataset_num_rows(dataset):
    """Return the number of samples for a row-indexable or columnar dataset."""
    if _is_columnar_dataset(dataset):
        return len(dataset["question"])
    return len(dataset)


def _dataset_row(dataset, idx):
    """Return sample ``idx`` as a dict for a row-indexable or columnar dataset."""
    if _is_columnar_dataset(dataset):
        return {column: values[idx] for column, values in dataset.items()}
    return dataset[idx]


def evaluate_hotpotqa(
    eval_dataset,
    query_generator,
    query_tokenizer,
    num_hops=2,
    top_k_retrieval=5,
    max_new_tokens=20  # Allow more tokens for potentially longer queries
):
    """Evaluate a query generator with multi-hop retrieval over HotpotQA samples.

    For every sample the context sentences are embedded once. Then, for each hop,
    a query is generated from the question plus the sentences retrieved in the
    previous hop, all context sentences are scored against the query with
    ``maxsim_score``, and the top-k sentences are compared with the ground-truth
    supporting facts. Samples without supporting facts or without context
    sentences are skipped.

    Args:
        eval_dataset: Samples to evaluate. Either row-indexable (``len(ds)`` is
            the sample count and ``ds[i]`` is a sample dict) or a columnar dict
            ``{column: [values...]}`` containing a ``"question"`` column. Every
            sample needs ``"question"``, ``"supporting_facts"`` (with ``"title"``
            and ``"sent_id"``) and ``"context"`` (with ``"title"`` and
            ``"sentences"``).
        query_generator: Model exposing ``.generate(...)`` and ``.device``. The
            complete generated sequence is decoded as the query, which assumes an
            encoder-decoder model - TODO confirm for other architectures.
        query_tokenizer: Tokenizer matching ``query_generator``.
        num_hops: Number of retrieval rounds per question.
        top_k_retrieval: Number of sentences retrieved per hop.
        max_new_tokens: Generation length limit for each query.

    Returns:
        dict with ``"overall_metrics"`` (averages over all evaluated hops),
        ``"per_hop_metrics"`` (one summary per hop that has samples) and
        ``"detailed_results"`` (per-question, per-hop records).
    """
    num_eval_samples = _dataset_num_rows(eval_dataset)
    print(f"Starting evaluation with {num_eval_samples} samples...")

    # ---------- PREPARING ----------
    # Running sums per hop; averages are derived from them at the end
    metrics_per_hop = [{
        "total_ap": 0.0,
        "total_precision": 0.0,
        "total_recall": 0.0,
        "num_samples": 0
    } for _ in range(num_hops)]

    all_results = []  # Detailed per-question results for inspection
    pbar = tqdm(range(num_eval_samples), desc="Evaluating")

    for idx in pbar:
        sample = _dataset_row(eval_dataset, idx)
        question = sample['question']
        supporting_facts = sample['supporting_facts']

        # Decide whether to skip before doing any expensive embedding work
        ground_truth_supporting_pairs = set(zip(supporting_facts['title'], supporting_facts['sent_id']))
        if not ground_truth_supporting_pairs:  # Skip questions with no supporting facts
            continue

        # Flatten the per-title sentence lists, remembering where each sentence came from
        context_titles = sample['context']['title']
        context_sentences_grouped = sample['context']['sentences']
        flattened_sentences = []
        sentence_metadata = []
        for title, sentences in zip(context_titles, context_sentences_grouped):
            for i, sent in enumerate(sentences):
                flattened_sentences.append(sent)
                sentence_metadata.append({"title": title, "sent_idx": i})

        if not flattened_sentences:  # Nothing to retrieve from
            continue

        context_embeddings = compute_colbert_embeddings(flattened_sentences)

        vector_store_embeddings_for_scoring = [
            torch.tensor(emb, dtype=torch.float32).to(device) for emb in context_embeddings
        ]

        current_context = ""  # The first hop has no retrieved context yet

        question_results = {
            "question": question,
            "ground_truth_supporting_pairs": list(ground_truth_supporting_pairs),
            "hops": []
        }

        for hop in range(num_hops):

            # ---------- QUERY GENERATION ----------
            prompt = build_fewshot_prompt(question, context=current_context)

            inputs = query_tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(query_generator.device)

            outputs = query_generator.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=False,  # Deterministic generation
                num_return_sequences=1,
                pad_token_id=query_tokenizer.eos_token_id,
                return_dict_in_generate=True,
                output_scores=False
            )

            generated_sequence = outputs.sequences[0]
            generated_query = query_tokenizer.decode(generated_sequence, skip_special_tokens=True).strip()

            if not generated_query:
                print(f"Warning: Empty query generated for question: '{question}' hop: {hop}.")
                continue  # This hop contributes no metrics

            # ---------- RETRIEVAL ----------
            query_emb_list = compute_colbert_embeddings([generated_query])
            if not query_emb_list:
                print(f"Warning: No embedding generated for query: '{generated_query}' for question: '{question}' hop: {hop}.")
                continue

            query_emb_np = query_emb_list[0]
            query_emb = torch.tensor(query_emb_np, dtype=torch.float32).to(device)

            scores = [
                maxsim_score(query_emb, doc_emb)
                for doc_emb in vector_store_embeddings_for_scoring
            ]

            if not scores:
                continue

            # Best-scoring sentence first. Converted to a plain list of ints so the
            # metric helper can truth-test it (a multi-element numpy array cannot be)
            # and so the ids are JSON-serializable.
            top_indices = np.argsort(scores)[-top_k_retrieval:][::-1].tolist()

            # ---------- PRECISION CALCULATION ----------
            ap, precision, recall = compute_ap_recall_precision(
                ground_truth_supporting_pairs,
                top_indices,
                sentence_metadata
            )

            f1 = calculate_f1(precision, recall)

            metrics_per_hop[hop]["total_ap"] += ap
            metrics_per_hop[hop]["total_precision"] += precision
            metrics_per_hop[hop]["total_recall"] += recall
            metrics_per_hop[hop]["num_samples"] += 1

            retrieved_context = [flattened_sentences[i] for i in top_indices]

            question_results["hops"].append({
                "hop": hop,
                "generated_query": generated_query,
                "raw_generated_query": generated_query,
                "ap": ap,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "top_k_retrieved_docs": retrieved_context,
                "top_k_retrieved_ids": top_indices
            })

            current_context = "\n".join(retrieved_context)  # Feeds the next hop's prompt

        all_results.append(question_results)

        # Show running averages on the progress bar once at least one hop was scored
        total_samples = sum(hop["num_samples"] for hop in metrics_per_hop)
        if total_samples > 0:
            current_ap = sum(hop["total_ap"] for hop in metrics_per_hop) / total_samples
            current_precision = sum(hop["total_precision"] for hop in metrics_per_hop) / total_samples
            current_recall = sum(hop["total_recall"] for hop in metrics_per_hop) / total_samples
            current_f1 = calculate_f1(current_precision, current_recall)

            pbar.set_postfix({
                'AP': f'{current_ap:.3f}',
                'P': f'{current_precision:.3f}',
                'R': f'{current_recall:.3f}',
                'F1': f'{current_f1:.3f}'
            })

    # ---------- PREPARING FINAL RESULTS ----------
    print("\n=== Per-Hop Evaluation Summary ===")
    hop_summaries = []
    for hop in range(num_hops):
        num_samples = metrics_per_hop[hop]["num_samples"]
        if num_samples > 0:
            avg_ap = metrics_per_hop[hop]["total_ap"] / num_samples
            avg_precision = metrics_per_hop[hop]["total_precision"] / num_samples
            avg_recall = metrics_per_hop[hop]["total_recall"] / num_samples
            # F1 of the averaged precision/recall (not the mean of per-sample F1)
            avg_f1 = calculate_f1(avg_precision, avg_recall)

            print(f"\nHop {hop + 1} Metrics:")
            print(f"Number of Samples: {num_samples}")
            print(f"Average AP: {avg_ap:.4f}")
            print(f"Average Precision: {avg_precision:.4f}")
            print(f"Average Recall: {avg_recall:.4f}")
            print(f"Average F1: {avg_f1:.4f}")

            hop_summaries.append({
                "hop": hop + 1,
                "num_samples": num_samples,
                "average_ap": avg_ap,
                "average_precision": avg_precision,
                "average_recall": avg_recall,
                "average_f1": avg_f1
            })

    # Every scored (question, hop) pair counts as one sample in the overall averages
    total_samples = sum(hop["num_samples"] for hop in metrics_per_hop)
    overall_ap = sum(hop["total_ap"] for hop in metrics_per_hop) / total_samples if total_samples > 0 else 0.0
    overall_precision = sum(hop["total_precision"] for hop in metrics_per_hop) / total_samples if total_samples > 0 else 0.0
    overall_recall = sum(hop["total_recall"] for hop in metrics_per_hop) / total_samples if total_samples > 0 else 0.0
    overall_f1 = calculate_f1(overall_precision, overall_recall)

    print("\n=== Overall Metrics (Averaged Across Hops) ===")
    print(f"Total Samples Evaluated: {total_samples}")
    print(f"Overall AP: {overall_ap:.4f}")
    print(f"Overall Precision: {overall_precision:.4f}")
    print(f"Overall Recall: {overall_recall:.4f}")
    print(f"Overall F1: {overall_f1:.4f}")

    return {
        "overall_metrics": {
            "average_ap": overall_ap,
            "average_precision": overall_precision,
            "average_recall": overall_recall,
            "average_f1": overall_f1,
            "total_samples": total_samples
        },
        "per_hop_metrics": hop_summaries,
        "detailed_results": all_results
    }


# Load the query generation model

In [None]:
from transformers import AutoModelForSeq2SeqLM

# Baseline checkpoint to evaluate: tokenizer and seq2seq model share one identifier
model_path = "google/flan-t5-small"
model_to_eval_tokenizer = AutoTokenizer.from_pretrained(model_path)
model_to_eval = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model_to_eval = model_to_eval.to(device)



# Evaluate the base model on the HotpotQA dataset

In [None]:


# --- Run Evaluation ---
# Retrieval settings; keep them consistent with training / preference dataset generation
base_eval_settings = {
    "num_hops": 2,
    "top_k_retrieval": 5,
    "max_new_tokens": 20,  # Adjust as needed for query length
}
evaluation_metrics = evaluate_hotpotqa(
    eval_dataset,
    model_to_eval,
    model_to_eval_tokenizer,
    **base_eval_settings
)

# --- Save Results (Optional) ---
output_filename = "hotpotqa_evaluation_results.json"
with open(output_filename, "w") as results_file:
    json.dump(evaluation_metrics, results_file, indent=4)
print(f"\nDetailed evaluation results saved to {output_filename}")

# Load the trained query generation model

In [None]:
# Bring in the fine-tuned query generation model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder: point this at the directory holding the fine-tuned checkpoint
trained_model_path = "path/to/your/trained/model"  # Update this path

# Tokenizer and model are read from the same location; the model is moved to the
# active device and switched to evaluation mode in one chain (eval() returns the model)
trained_query_tokenizer = AutoTokenizer.from_pretrained(trained_model_path)
trained_query_generator = (
    AutoModelForSeq2SeqLM.from_pretrained(trained_model_path).to(device).eval()
)

print(f"Loaded trained model from: {trained_model_path}")
print(f"Model device: {trained_query_generator.device}")

# Evaluate the trained model on the HotpotQA dataset

In [None]:

# --- Run Evaluation ---
# Retrieval settings; keep them consistent with training / preference dataset generation
trained_eval_settings = {
    "num_hops": 2,
    "top_k_retrieval": 5,
    "max_new_tokens": 20,  # Adjust as needed for query length
}
evaluation_metrics_trained = evaluate_hotpotqa(
    eval_dataset,
    trained_query_generator,
    trained_query_tokenizer,
    **trained_eval_settings
)

# --- Save Results (Optional) ---
output_filename = "hotpotqa_evaluation_results_trained_model.json"
with open(output_filename, "w") as results_file:
    json.dump(evaluation_metrics_trained, results_file, indent=4)
print(f"\nDetailed evaluation results saved to {output_filename}")

In [None]:
# Compare queries generated by both models
def _generate_query(model, tokenizer, prompt, max_new_tokens=20):
    """Greedy-decode one search query for ``prompt`` with the given model/tokenizer."""
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(device)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()


print("=== Query Comparison: Base Model vs Trained Model ===\n")

# Get some example questions from the evaluation dataset.
# The dataset built earlier in this notebook is a columnar dict
# ({column: [values...]}), where eval_dataset[0] would raise KeyError; a
# row-indexable dataset is supported as well.
NUM_COMPARISON_QUESTIONS = 5
if isinstance(eval_dataset, dict) and 'question' in eval_dataset:
    sample_questions = list(eval_dataset['question'][:NUM_COMPARISON_QUESTIONS])
else:
    sample_questions = [
        eval_dataset[i]['question']
        for i in range(min(NUM_COMPARISON_QUESTIONS, len(eval_dataset)))
    ]

for i, question in enumerate(sample_questions):
    print(f"Question {i+1}: {question}\n")

    # Both models receive the same prompt (including the same random few-shot
    # examples), so differences in the output come from the models alone.
    prompt = build_fewshot_prompt(question, add_fewshot=True)

    base_query = _generate_query(model_to_eval, model_to_eval_tokenizer, prompt)
    trained_query = _generate_query(trained_query_generator, trained_query_tokenizer, prompt)

    print(f"Base Model Query: {base_query}")
    print(f"Trained Model Query: {trained_query}")
    print("-" * 80)
    print()