In [12]:
!pip install datasets transformers torch==2.7 tqdm numpy pylate bitsandbytes accelerate huggingface_hub wandb torchvision

Collecting transformers
  Using cached transformers-4.53.2-py3-none-any.whl.metadata (40 kB)
Collecting torch==2.7
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting triton==3.3.0 (from torch==2.7)
  Downloading triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)
Collecting transformers
  Using cached transformers-4.48.2-py3-none-any.whl.metadata (44 kB)
INFO: pip is looking at multiple versions of torchvision to determine which version is compatible with other requirements. This could take a while.
Collecting torchvision
  Downloading torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl (865.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m865.2/865.2 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (156.5 MB)
[2K   

## Preference Dataset Creation with Google T5 Flan

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import torch
import json
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import random
import datetime

CURRENT_TIME_STAMP = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# Set seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# === Query Generator (T5 Flan) ===
model_path = "google/flan-t5-small"
query_generator = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
query_tokenizer = AutoTokenizer.from_pretrained(model_path)
query_generator.eval()

# === ColBERT for retrieval scoring ===
colbert_tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
colbert_model = AutoModel.from_pretrained("colbert-ir/colbertv2.0").to(device)
colbert_model.eval()

Using device: cuda


model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/405 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [3]:
import os
# === Dataset ===
# Check if datasets already exist

# Generate timestamped and parameter-specific filenames
timestamp = CURRENT_TIME_STAMP.replace(" ", "_").replace(":", "-")
dataset_size = 5000
val_size = 1000

train_filename = "hotpot_train_5000samples_2025-06-11_18-31-48.jsonl"
val_filename = "hotpot_val_1000samples_2025-06-11_18-31-48.jsonl"

# Check for existing files with similar pattern
existing_train_files = [f for f in os.listdir('.') if f.startswith('hotpot_train_') and f.endswith('.jsonl')]
existing_val_files = [f for f in os.listdir('.') if f.startswith('hotpot_val_') and f.endswith('.jsonl')]

if existing_train_files and existing_val_files:
    # Use the most recent existing files
    train_filename = sorted(existing_train_files)[-1]
    val_filename = sorted(existing_val_files)[-1]

    print(f"Using existing dataset files:")
    print(f"Training: {train_filename}")
    print(f"Validation: {val_filename}")

    # Load from existing files
    def load_dataset_from_jsonl(filename):
        data = defaultdict(list)
        with open(filename, 'r') as f:
            for line in f:
                item = json.loads(line.strip())
                for key, value in item.items():
                    data[key].append(value)
        return dict(data)

    train_dataset = load_dataset_from_jsonl(train_filename)
    val_dataset = load_dataset_from_jsonl(val_filename)

    print(f"Loaded {len(train_dataset['question'])} training samples from existing file")
    print(f"Loaded {len(val_dataset['question'])} validation samples from existing file")
else:
    # Load from HuggingFace and create new files
    dataset = load_dataset("hotpot_qa", "fullwiki", trust_remote_code=True)
    DATASET_SPLIT = 0.9  # 90% for training, 10% for validation
    train_dataset = dataset['train'][:5000]  # Use 5K for faster processing
    val_dataset = dataset['train'][5000:6000]  # Use 1K for validation

    print(f"Loaded {len(train_dataset['question'])} samples for preference dataset creation")

    # Dump dataset into JSONL files for future use
    def dump_dataset_to_jsonl(dataset, filename):
        with open(filename, 'w') as f:
            for item in dataset:
                f.write(json.dumps(item) + '\n')

    dump_dataset_to_jsonl(train_dataset, train_filename)
    dump_dataset_to_jsonl(val_dataset, val_filename)

    print(f"Training dataset saved to: {train_filename}")
    print(f"Validation dataset saved to: {val_filename}")

Using existing dataset files:
Training: hotpot_train_5000samples_2025-06-11_18-31-48.jsonl
Validation: hotpot_val_1000samples_2025-06-11_18-31-48.jsonl
Loaded 5000 training samples from existing file
Loaded 1000 validation samples from existing file


# Core Utility Functions

In [59]:
def compute_colbert_embeddings_batched(texts, batch_size=32):
    """Compute ColBERT embeddings for texts in batches"""
    all_embeddings = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]

        encoded = colbert_tokenizer(
            batch_texts,
            max_length=512,
            padding=True,  # Changed from "max_length" to True for efficiency
            truncation=True,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            output = colbert_model(**encoded).last_hidden_state

        masks = encoded["attention_mask"].bool()
        batch_embeddings = [output[i][masks[i]].cpu().numpy() for i in range(len(batch_texts))]
        all_embeddings.extend(batch_embeddings)

        # Clear GPU memory
        del output, encoded
        torch.cuda.empty_cache()

    return all_embeddings

# Main Preference Dataset Creation Loop

## Scoring and Evaluation Functions

In [70]:
def maxsim_score(query_emb, doc_emb):
    """Optimized MaxSim score computation"""
    if isinstance(query_emb, np.ndarray):
        query_tensor = torch.tensor(query_emb, dtype=torch.float32, device=device)
    else:
        query_tensor = query_emb.to(device)

    if isinstance(doc_emb, np.ndarray):
        doc_tensor = torch.tensor(doc_emb, dtype=torch.float32, device=device)
    else:
        doc_tensor = doc_emb.to(device)

    # Use torch.mm for better performance
    similarity_matrix = torch.mm(query_tensor, doc_tensor.T)
    return float(similarity_matrix.max(dim=1).values.sum())

def compute_ap_recall(supporting_pairs, retrieved_ids, sentence_metadata):
    """Compute Average Precision and Recall"""
    retrieved_pairs = {
        (sentence_metadata[i]["title"], sentence_metadata[i]["sent_idx"])
        for i in retrieved_ids
    }

    hits = [
        1 if (sentence_metadata[i]["title"], sentence_metadata[i]["sent_idx"]) in supporting_pairs else 0
        for i in retrieved_ids
    ]

    ap = sum(hits[i] / (i + 1) for i in range(len(hits)) if hits[i]) / max(sum(hits), 1)
    recall = sum(hits) / len(supporting_pairs) if supporting_pairs else 0

    return ap, recall

## Query Generation with T5 Flan

In [69]:
def generate_query(question, context="", use_fewshot=False):
    """Generate a single query using T5 Flan"""
    if context:
        prompt = f"Context: {context}\n\nGenerate a search query for: {question}"
    else:
        prompt = f"Generate a search query for: {question}"

    if use_fewshot:
        # Load and use random examples from fewshot_examples.json
        fewshot_file_path = '/content/fewshot_examples.json'
        try:
            with open(fewshot_file_path, 'r') as f:
                fewshot_examples = json.load(f)

            # Randomly select 2-3 examples
            num_examples = min(3, len(fewshot_examples))
            selected_examples = random.sample(fewshot_examples, num_examples)

            fewshot_prompt = "\n".join([f"Question: {ex['question']}\nQuery: {ex['query']}" for ex in selected_examples])
            prompt = f"{fewshot_prompt}\n\nQuestion: {question}\nQuery:"
        except FileNotFoundError:
            print("Warning: fewshot_examples.json not found, proceeding without few-shot examples")

    inputs = query_tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = query_generator.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.85,
            top_p=0.9,
            num_return_sequences=1
        )

    query = query_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return query


def generate_queries_batch(question, current_context, num_queries=5, use_fewshot=True):
    """Generate multiple queries in a single batch"""
    if use_fewshot:
        try:
            with open('/content/fewshot_examples.json', 'r') as f:
                fewshot_examples = json.load(f)

            selected_examples = random.sample(fewshot_examples, min(3, len(fewshot_examples)))
            fewshot_prompt = "\n".join([f"Question: {ex['question']}\nQuery: {ex['query']}" for ex in selected_examples])

            if current_context:
                prompt = f"{fewshot_prompt}\n\nContext: {current_context}\nQuestion: {question}\nQuery:"
            else:
                prompt = f"{fewshot_prompt}\n\nQuestion: {question}\nQuery:"
        except FileNotFoundError:
            if current_context:
                prompt = f"Context: {current_context}\n\nGenerate a search query for: {question}"
            else:
                prompt = f"Generate a search query for: {question}"
    else:
        if current_context:
            prompt = f"Context: {current_context}\n\nGenerate a search query for: {question}"
        else:
            prompt = f"Generate a search query for: {question}"

    inputs = query_tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = query_generator.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.85,
            top_p=0.9,
            num_return_sequences=num_queries  # Generate all queries at once
        )

    queries = []
    for output in outputs:
        query = query_tokenizer.decode(output, skip_special_tokens=True).strip()
        # Extract only the generated part
        if "Query:" in query:
            query = query.split("Query:")[-1].strip()
        if query and query not in queries:
            queries.append(query)

    return queries[:num_queries]

In [71]:
# === Configuration Parameters ===
NUM_HOPS = 2          # Number of retrieval hops
NUM_QUERIES = 5       # Generate 5 queries per hop for ranking
TOP_K = 5            # Top-K documents to retrieve
BATCH_SIZE = 32      # Process samples in batches for speed

print(f"Configuration:")
print(f"- Number of hops: {NUM_HOPS}")
print(f"- Queries per hop: {NUM_QUERIES}")
print(f"- Top-K retrieval: {TOP_K}")
print(f"- Batch size: {BATCH_SIZE}")

Configuration:
- Number of hops: 2
- Queries per hop: 5
- Top-K retrieval: 5
- Batch size: 32


## Data Processing and Context Preparation

In [72]:
def prepare_sample_context(sample):
    """Prepare and flatten context for a single sample"""
    context_titles = sample['context']['title']
    context_sentences_grouped = sample['context']['sentences']
    flattened_sentences = []
    sentence_metadata = []

    for title, sentences in zip(context_titles, context_sentences_grouped):
        for i, sent in enumerate(sentences):
            flattened_sentences.append(sent)
            sentence_metadata.append({"title": title, "sent_idx": i})

    return flattened_sentences, sentence_metadata

print("Data processing functions defined")

Data processing functions defined


In [73]:
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
from torch.utils.data import Dataset
import random

def finetune_query_generator_with_fewshot():
    """Fine-tune T5 Flan model using few-shot examples in prompts"""

    global query_generator # Moved to the beginning of the function

    # Load few-shot examples
    fewshot_file_path = '/content/fewshot_examples.json'
    try:
        with open(fewshot_file_path, 'r') as f:
            fewshot_examples = json.load(f)
        print(f"Loaded {len(fewshot_examples)} few-shot examples for fine-tuning")
    except FileNotFoundError:
        print("Error: fewshot_examples.json not found. Please create this file first.")
        return None

    def create_fewshot_prompt(target_question, target_query, examples, num_shots=3):
        """Create a few-shot prompt with examples"""
        # Randomly select few-shot examples (excluding the target)
        available_examples = [ex for ex in examples if ex['question'] != target_question]
        selected_examples = random.sample(available_examples, min(num_shots, len(available_examples)))

        # Build the prompt
        prompt = "Generate search queries based on questions. Here are some examples:\n\n"

        # Add few-shot examples
        for i, example in enumerate(selected_examples, 1):
            prompt += f"Example {i}:\n"
            prompt += f"Question: {example['question']}\n"
            prompt += f"Query: {example['query']}\n\n"

        # Add the target question
        prompt += f"Now generate a query for:\nQuestion: {target_question}\nQuery:"

        return prompt

    class QueryDataset(Dataset):
        def __init__(self, examples, tokenizer, max_length=512):
            self.examples = examples
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            example = self.examples[idx]

            # Create few-shot prompt
            input_text = create_fewshot_prompt(
                example['question'],
                example['query'],
                self.examples,
                num_shots=3
            )
            target_text = example['query']

            inputs = self.tokenizer(
                input_text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            targets = self.tokenizer(
                target_text,
                max_length=64,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            return {
                'input_ids': inputs.input_ids.squeeze(),
                'attention_mask': inputs.attention_mask.squeeze(),
                'labels': targets.input_ids.squeeze()
            }

    # Create dataset
    dataset = QueryDataset(fewshot_examples, query_tokenizer)

    # Split into train/val (80/20)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset_ft, val_dataset_ft = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=query_tokenizer,
        model=query_generator,
        padding=True
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./finetuned_t5_flan_{timestamp}",
        num_train_epochs=20,
        per_device_train_batch_size=2,  # Reduced due to longer prompts
        per_device_eval_batch_size=2,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir=f"./logs_{timestamp}",
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=50,
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        gradient_accumulation_steps=2,  # To compensate for smaller batch size
    )

    # Initialize trainer
    trainer = Trainer(
        model=query_generator,
        args=training_args,
        train_dataset=train_dataset_ft,
        eval_dataset=val_dataset_ft,
        data_collator=data_collator,
        tokenizer=query_tokenizer,
    )

    print("Starting fine-tuning with few-shot prompts...")
    trainer.train()

    # Save the fine-tuned model
    model_save_path = f"./finetuned_t5_flan_final_{timestamp}"
    trainer.save_model(model_save_path)
    query_tokenizer.save_pretrained(model_save_path)

    print(f"Fine-tuned model saved to: {model_save_path}")

    # Update global model with fine-tuned version
    query_generator = trainer.model

    return model_save_path

# Execute fine-tuning if few-shot examples exist
if os.path.exists('/content/fewshot_examples.json'):
    finetuned_model_path = finetune_query_generator_with_fewshot()
else:
    print("Few-shot examples file not found. Skipping fine-tuning.")
    print("To enable fine-tuning, create '/content/fewshot_examples.json' with format:")
    print('[{"question": "example question", "query": "example query"}, ...]')

Loaded 100 few-shot examples for fine-tuning
Starting fine-tuning with few-shot prompts...


  trainer = Trainer(


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [74]:
# Generate example queries to showcase the T5 Flan query generator
print("="*80)
print("QUERY GENERATOR EXAMPLES")
print("="*80)

# Get a few sample questions from the dataset
sample_questions = train_dataset['question'][:5]
sample_contexts = []

# Prepare some sample contexts from the dataset
for i in range(5):
    sample = {k: train_dataset[k][i] for k in train_dataset.keys()}
    flattened_sentences, _ = prepare_sample_context(sample)
    # Take first few sentences as context
    context = " ".join(flattened_sentences[:3])
    sample_contexts.append(context + "..." if len(context) > 200 else context)

print("\n1. QUERIES WITHOUT CONTEXT:")
print("-" * 50)
for i, question in enumerate(sample_questions):
    query = generate_query(question, use_fewshot= False )
    print(f"\nQuestion {i+1}: {question}...")
    print(f"Generated Query: {query}")


print("\n\n2. QUERIES WITH CONTEXT and without fewshot:")
print("-" * 50)
for i, (question, context) in enumerate(zip(sample_questions, sample_contexts)):
    query = generate_query(question, context, use_fewshot = False)
    print(f"\nQuestion {i+1}: {question}...")
    print(f"Context: {context}...")
    print(f"Generated Query: {query}")

print("\n\n3. QUERIES WITH CONTEXT and with fewshot:")
print("-" * 50)
for i, (question, context) in enumerate(zip(sample_questions, sample_contexts)):
    query = generate_query(question, context, use_fewshot = True)
    print(f"\nQuestion {i+1}: {question}...")
    print(f"Context: {context}...")
    print(f"Generated Query: {query}")

print("\n\n4. QUERIES WITHout CONTEXT (FEW SHOT EXAMPLES):")
print("-" * 50)
for i, (question, context) in enumerate(zip(sample_questions, sample_contexts)):
    query = generate_query(question, context = None, use_fewshot = True)
    print(f"\nQuestion {i+1}: {question}...")
    print(f"Context: {context}...")
    print(f"Generated Query: {query}")

print("\n\n5. MULTIPLE QUERIES FOR SAME QUESTION:")
print("-" * 50)
example_question = sample_questions[0]
print(f"Question: {example_question}")
print("Generated queries:")
for j in range(NUM_QUERIES):
    query = generate_query(example_question, sample_contexts[0], use_fewshot= True)
    print(f"  {j+1}. {query}")

print("\n" + "="*80)

QUERY GENERATOR EXAMPLES

1. QUERIES WITHOUT CONTEXT:
--------------------------------------------------

Question 1: Which magazine was started first Arthur's Magazine or First for Women?...
Generated Query: Arthur's Magazine

Question 2: The Oberoi family is part of a hotel company that has a head office in what city?...
Generated Query: Oberoi family head office in city

Question 3: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?...
Generated Query: Allie Goertz song written by The Simpsons character Milhouse (Madage

Question 4:  What nationality was James Henry Miller's wife?...
Generated Query: James Henry Miller's wife nationality

Question 5: Cadmium Chloride is slightly soluble in this chemical, it is also called what?...
Generated Query: Cadmium Chloride in soluble in chemistry


2. QUERIES WITH CONTEXT and without fewshot:
--------------------------------------------------

Question 1: Which maga

## Single Hop Processing Function

In [75]:
def process_single_hop_optimized(question, current_context, flattened_sentences,
                                context_embeddings, supporting_pairs, sentence_metadata):
    """Optimized single hop processing"""

    # Generate all queries at once
    queries = generate_queries_batch(question, current_context, NUM_QUERIES, use_fewshot=True)

    if not queries:
        return None

    # Batch compute query embeddings
    query_embeddings = compute_colbert_embeddings_batched(queries, batch_size=len(queries))

    # Score all queries against all documents in vectorized manner
    scored_queries = []
    for i, (query, query_emb) in enumerate(zip(queries, query_embeddings)):
        # Vectorized scoring
        scores = np.array([maxsim_score(query_emb, doc_emb) for doc_emb in context_embeddings])
        top_indices = np.argsort(scores)[-TOP_K:][::-1]

        ap, recall = compute_ap_recall(supporting_pairs, top_indices, sentence_metadata)

        scored_queries.append({
            "query": query,
            "ap": ap,
            "recall": recall,
            "top_indices": top_indices.tolist(),
            "retrieved_context": [flattened_sentences[i] for i in top_indices]
        })

    scored_queries.sort(key=lambda x: x["ap"], reverse=True)

    # Create preference pairs
    preference_pairs = []
    for i in range(len(scored_queries)):
        for j in range(i + 1, len(scored_queries)):
            if scored_queries[i]["ap"] > scored_queries[j]["ap"]:
                preference_pairs.append((i, j))

    return {
        "queries": [x["query"] for x in scored_queries],
        "aps": [x["ap"] for x in scored_queries],
        "recalls": [x["recall"] for x in scored_queries],
        "preference_pairs": preference_pairs,
        "best_retrieved_context": "\n".join(scored_queries[0]["retrieved_context"]) if scored_queries else ""
    }

## Main Processing Loop

In [76]:
# Add this at the beginning of your main processing loop
document_embedding_cache = {}

def get_cached_embeddings(flattened_sentences):
    """Get embeddings with caching to avoid recomputation"""
    # Create a hash key for the sentences
    sentences_key = hash(tuple(flattened_sentences))

    if sentences_key not in document_embedding_cache:
        document_embedding_cache[sentences_key] = compute_colbert_embeddings_batched(flattened_sentences)

    return document_embedding_cache[sentences_key]

In [77]:
import wandb

# Initialize wandb
wandb.init(
    project="t5-flan-preference-dataset",
    name=f"preference_creation_{timestamp}",
    config={
        "model_path": model_path,
        "dataset_size": dataset_size,
        "num_hops": NUM_HOPS,
        "num_queries": NUM_QUERIES,
        "top_k": TOP_K,
        "batch_size": BATCH_SIZE
    }
)

preference_dataset = {}
total_processed = 0
total_skipped = 0

print("Starting preference dataset creation...")

for batch_start in tqdm(range(0, len(train_dataset['question']), BATCH_SIZE), desc="Processing batches"):
    batch_end = min(batch_start + BATCH_SIZE, len(train_dataset['question']))

    for idx in range(batch_start, batch_end):
        sample = {k: train_dataset[k][idx] for k in train_dataset.keys()}
        question = sample['question']
        supporting_facts = sample['supporting_facts']

        # Skip if no supporting facts
        if not supporting_facts['title']:
            total_skipped += 1
            continue

        # Prepare context
        flattened_sentences, sentence_metadata = prepare_sample_context(sample)
        context_embeddings = get_cached_embeddings(flattened_sentences)
        supporting_pairs = set(zip(supporting_facts['title'], supporting_facts['sent_id']))

        # Initialize dataset entry
        preference_dataset[question] = {
            "question": question,
            "hops": {}
        }

        current_context = ""

        # Process each hop
        for hop in range(NUM_HOPS):

            hop_data = process_single_hop_optimized(
                question, current_context, flattened_sentences,
                context_embeddings, supporting_pairs, sentence_metadata
            )

            if hop_data:
                preference_dataset[question]["hops"][f"hop_{hop}"] = hop_data
                # Update context with best retrieval for next hop
                if hop_data["queries"]:
                    best_context = hop_data.get("best_retrieved_context", "")
                    current_context = best_context

                # Log hop metrics to wandb
                wandb.log({
                    f"hop_{hop}_avg_ap": np.mean(hop_data["aps"]) if hop_data["aps"] else 0,
                    f"hop_{hop}_max_ap": max(hop_data["aps"]) if hop_data["aps"] else 0,
                    f"hop_{hop}_avg_recall": np.mean(hop_data["recalls"]) if hop_data["recalls"] else 0,
                    f"hop_{hop}_num_queries": len(hop_data["queries"]),
                    f"hop_{hop}_num_preferences": len(hop_data["preference_pairs"])
                })

        total_processed += 1

        # Log progress every 100 samples
        if total_processed % 100 == 0:
            wandb.log({
                "total_processed": total_processed,
                "total_skipped": total_skipped,
                "completion_rate": total_processed / (total_processed + total_skipped)
            })

# Log final metrics
wandb.log({
    "final_total_processed": total_processed,
    "final_total_skipped": total_skipped,
    "final_dataset_size": len(preference_dataset),
    "processing_success_rate": total_processed / (total_processed + total_skipped) if (total_processed + total_skipped) > 0 else 0
})

print(f"Processing completed!")
print(f"- Total processed: {total_processed}")
print(f"- Total skipped: {total_skipped}")
print(f"- Final dataset size: {len(preference_dataset)} questions")

wandb.finish()

0,1
hop_0_avg_ap,▆▆▇▄▁▂▅██▅▅▃▇▃██▃▃▂█▁▁▁▆▆▁▇▄▆▅▆▅▆▆▃▅▆▆▂▂
hop_0_avg_recall,██▄▃▁▂▅▃▂▅▇▅▃▅▅▂▂▂▅▅▁▁█▅▅▁▅▅▇▃▅▇▃▃▃▃█▆▅▃
hop_0_max_ap,▆▆█▅▁▃▆██▃██▃█▅▁███▃▁▁▁▆██▁██▆▆█▆█▃█▆▆▃▃
hop_0_num_preferences,▁▁▅▂▁▃█▁▁▃▁▇▁▃█▁▁▃▆▅▄▁▁▁▁▁▃▆▂▁▆▅▅█▁▃▂▃▃▂
hop_0_num_queries,▆▃▆▁▆▃██▁▆▆█▁▃██▆▆█▆▆█▃▁▃▆▁▃▆▆▃▃▆▆█▃▃▁▃▃
hop_1_avg_ap,▇█▆▃▁▆▃▄▁▁▄▂▅▅▃▅█▄▂▁▁▁▁▅▆▁▅▂▄▅▅▆▇▆▅▁▇▅▃▂
hop_1_avg_recall,▆▅▅▆▁▁▅▅▄▁█▅▄▃▅▂▆▅▅▂▅▁▁▁▁▅▅▄▅▃▆█▅▃▃▁▆█▅▃
hop_1_max_ap,███▃▁▁█▃▅▁▅▆▃██▃▆█▆▆▄▁▁▁██▁█▅▆█▆███▁█▅▅▃
hop_1_num_preferences,▃▁▄▂▁▁▃▁▂▁▁▂█▂▃▃▁▂▆▁▁▁▁▁▆▃▁▄▅▃▁▃▂▅█▁▂▁▃▂
hop_1_num_queries,▅▁▆▅▅▅▁▃▁▁▃██▃▅▅▁▃█▁▅▃▁▆▁▁▆▆▅▁▅▅▃▆█▁▃▃▆▃

0,1
hop_0_avg_ap,0.125
hop_0_avg_recall,0.25
hop_0_max_ap,0.25
hop_0_num_preferences,1.0
hop_0_num_queries,2.0
hop_1_avg_ap,0.125
hop_1_avg_recall,0.25
hop_1_max_ap,0.25
hop_1_num_preferences,1.0
hop_1_num_queries,2.0


Starting preference dataset creation...


Processing batches: 100%|██████████| 157/157 [51:54<00:00, 19.83s/it]

Processing completed!
- Total processed: 5000
- Total skipped: 0
- Final dataset size: 5000 questions





0,1
completion_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
final_dataset_size,▁
final_total_processed,▁
final_total_skipped,▁
hop_0_avg_ap,▄▅▂▅▇▅▃▂▇▃▂▂▅▆▆▅▅▅█▁▇▅▅▂▃▄▃▃▂▅▂▃▂█▅▇▇▇▅▄
hop_0_avg_recall,▆▅▅▆▆▅▃▅▃▄▆▁▂▃▅▃▆▄▆▆▇▆▃▂▆█▄▅▄▃▃▆▄▄▅▃▃▂▅▃
hop_0_max_ap,▃▁▁█▅██▁█▄██▅█▅█▅███▁▆████▅█▅█▅▁█▆▅▅████
hop_0_num_preferences,▆▇▄▇▇▇▄▆▇▄▇▇▇▆▆▆██▇▇▆▇▄▇▃▁▁▁▇▅▇▅▁▇▇▅▇▁▄▃
hop_0_num_queries,██████▆███▆▃▆▆█▃████▃█▃███▁████▆███▃█▆██
hop_1_avg_ap,▂▇▇▆▇▆▃▆▅▄▅▂▆▁▄▁▄▂▆▂▇▁▆▆▅█▆▄▇▆▁▂█▇▄▆█▄▂▂

0,1
completion_rate,1.0
final_dataset_size,5000.0
final_total_processed,5000.0
final_total_skipped,0.0
hop_0_avg_ap,0.66667
hop_0_avg_recall,0.7
hop_0_max_ap,1.0
hop_0_num_preferences,8.0
hop_0_num_queries,5.0
hop_1_avg_ap,0.75


## Save Preference Dataset

In [78]:
import datetime

# Save the preference dataset
# Generate timestamped filename
timestamp = CURRENT_TIME_STAMP.replace(" ", "_").replace(":", "-")
output_file = f"preference_dataset_t5_flan_{timestamp}.json"
with open(output_file, "w") as f:
    json.dump(preference_dataset, f, indent=2)

print(f"Preference dataset saved to {output_file}")

# Display statistics
total_preference_pairs = 0
total_hops = 0

for question, data in preference_dataset.items():
    for hop_key, hop_data in data["hops"].items():
        total_hops += 1
        total_preference_pairs += len(hop_data["preference_pairs"])

print(f"\nDataset Statistics:")
print(f"- Total questions: {len(preference_dataset)}")
print(f"- Total hops: {total_hops}")
print(f"- Total preference pairs: {total_preference_pairs}")
print(f"- Average preference pairs per hop: {total_preference_pairs/total_hops:.2f}")

Preference dataset saved to preference_dataset_t5_flan_2025-07-21_10-08-17.json

Dataset Statistics:
- Total questions: 5000
- Total hops: 10000
- Total preference pairs: 53717
- Average preference pairs per hop: 5.37


In [79]:
def format_preference_data_for_training(preference_dataset_path):
    """Convert preference dataset to training format"""
    with open(preference_dataset_path, 'r') as f:
        data = json.load(f)

    training_data = []

    for question, entry in data.items():
        for hop_key, hop_data in entry["hops"].items():
            queries = hop_data["queries"]
            aps = hop_data["aps"]
            preference_pairs = hop_data["preference_pairs"]

            for preferred_idx, dispreferred_idx in preference_pairs:
                training_data.append({
                    "question": question,
                    "preferred": queries[preferred_idx],
                    "dispreferred": queries[dispreferred_idx],
                    "preferred_ap": aps[preferred_idx],
                    "dispreferred_ap": aps[dispreferred_idx],
                    "hop": hop_key
                })

    return training_data

print("Training data formatting function defined")

Training data formatting function defined


## Format Data for Training

## Generate Training Data

In [81]:
# Format the preference dataset for training
training_data = format_preference_data_for_training(output_file)

print(f"Created {len(training_data)} preference pairs for training")

# Save training data
training_filename = f"preference_training_data_formatted_{timestamp}.json"
with open(training_filename, "w") as f:
    json.dump(training_data, f, indent=2)

print(f"Training data saved to {training_filename}")

Created 53717 preference pairs for training
Training data saved to preference_training_data_formatted_2025-07-21_10-08-17.json


## Display Sample Results

In [83]:
# Display sample preference pairs
if training_data:
    print("\n" + "="*80)
    print("SAMPLE PREFERENCE PAIRS")
    print("="*80)

    for i, sample in enumerate(training_data[:5]):  # Show first 3 samples
        print(f"\nSample {i+1}:")
        print(f"Question: {sample['question'][:100]}...")
        print(f"Hop: {sample['hop']}")
        print(f"Preferred Query (AP={sample['preferred_ap']:.3f}): {sample['preferred']}")
        print(f"Dispreferred Query (AP={sample['dispreferred_ap']:.3f}): {sample['dispreferred']}")
        print("-" * 60)

    print(f"\nTotal training samples: {len(training_data)}")
else:
    print("No training data generated.")


SAMPLE PREFERENCE PAIRS

Sample 1:
Question: Which magazine was started first Arthur's Magazine or First for Women?...
Hop: hop_1
Preferred Query (AP=0.750): Arthur's Magazine magazine first for Women
Dispreferred Query (AP=0.667): first for women magazine in 1989
------------------------------------------------------------

Sample 2:
Question: Which magazine was started first Arthur's Magazine or First for Women?...
Hop: hop_1
Preferred Query (AP=0.750): Arthur's Magazine magazine first for Women
Dispreferred Query (AP=0.667): Arthur's Magazine magazine
------------------------------------------------------------

Sample 3:
Question: Which magazine was started first Arthur's Magazine or First for Women?...
Hop: hop_1
Preferred Query (AP=0.750): Arthur's Magazine magazine first for Women
Dispreferred Query (AP=0.667): Arthur's Magazine first for Women magazine in 1989
------------------------------------------------------------

Sample 4:
Question: Which magazine was started first Art

## PyTorch Dataset Class (Optional)

In [None]:
from torch.utils.data import Dataset

class PreferenceDataset(Dataset):
    """PyTorch Dataset for preference learning"""

    def __init__(self, json_path):
        with open(json_path, 'r') as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Create dataset instance
train_dataset_pytorch = PreferenceDataset("preference_training_data.json")
print(f"PyTorch dataset created with {len(train_dataset_pytorch)} samples")
print("Ready for training with preference learning algorithms (DPO, IPO, etc.)")

## Summary

This notebook has successfully created a preference dataset using Google T5 Flan for query generation. The dataset includes:

- **Multi-hop retrieval**: 2 hops of query generation and retrieval
- **Multiple queries per hop**: 5 queries generated and ranked by AP score
- **Preference pairs**: Automatically generated based on retrieval performance
- **Training ready format**: JSON files ready for preference learning

**Output files:**
- `preference_dataset_t5_flan.json`: Raw preference dataset
- `preference_training_data.json`: Formatted training data

**Next steps:**
- Use the training data with preference learning algorithms (DPO, IPO, etc.)
- Fine-tune the T5 model on the preference pairs
- Evaluate the improved model on multi-hop QA tasks

In [4]:
!pip uninstall -y transformers datasets torch torchvision

[0m