In [None]:
# LitSearch RAG with Context-Enhanced Scientific Paper Chunking

import pandas as pd
import json
import numpy as np
import os
import re
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import List, Dict, Any, Union, Optional

# LangChain imports
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import BM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.callbacks import get_openai_callback

# Hugging Face datasets
from datasets import load_dataset

# BERTScore
from bert_score import score as bert_score

# ROUGE metrics
from rouge_score import rouge_scorer

# Set OpenAI API key
# Only fall back to the placeholder when the environment has no key yet, so a
# real key exported in the shell (or set by an earlier cell) is never
# overwritten by re-running this cell.
import os
os.environ.setdefault("OPENAI_API_KEY", "Replace with your API key")  # Replace with your API key

In [2]:
# Initialize ROUGE scorer
# One shared scorer for ROUGE-1, ROUGE-2 and ROUGE-L with stemming enabled;
# evaluate_answer() reads this module-level instance.
rouge_scorer_instance = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

## 1. Loading the Datasets

# Function to load user dataset (handles both JSON and CSV)
def load_user_dataset(file_path: str, num_samples: Optional[int] = None):
    """
    Load user dataset from JSON or CSV file

    Args:
        file_path: Path to dataset file (JSON or CSV)
        num_samples: Number of samples to use (None = all). When it is smaller
            than the dataset, a fixed-seed random subset is drawn.

    Returns:
        List of dictionaries containing the dataset

    Raises:
        ValueError: If the file extension is neither .json nor .csv
    """
    if file_path.endswith('.json'):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Handle both {'data': [...]} and direct list format. A bare list has
        # no .get(), so only unwrap when the top level really is a dict.
        dataset = data.get('data', data) if isinstance(data, dict) else data
    elif file_path.endswith('.csv'):
        df = pd.read_csv(file_path)
        dataset = df.to_dict(orient='records')
    else:
        raise ValueError("Unsupported file format. Use JSON or CSV.")

    # Sample if requested
    if num_samples and num_samples < len(dataset):
        import random
        # A private generator seeded with 42 draws the same subset as seeding
        # the global RNG would, without resetting the global random state.
        dataset = random.Random(42).sample(dataset, num_samples)

    print(f"Loaded {len(dataset)} samples from {file_path}")
    return dataset

In [4]:
# Function to load LitSearch corpus from Hugging Face
def load_litsearch_corpus():
    """
    Load the LitSearch corpus from Hugging Face

    Returns:
        Tuple ``(corpus_dict, paper_id_field)``:
            corpus_dict: Dictionary mapping paper IDs to a dict with the keys
                'title', 'abstract', 'full_text', 'authors' and 'year'
                (fields missing from the corpus default to '').
            paper_id_field: Name of the corpus column used as the paper ID.
        If anything goes wrong while loading, the error is printed and
        ``({}, None)`` is returned so callers can fall back to the user dataset.
    """
    print("Loading LitSearch corpus dataset...")
    try:
        # Load the corpus
        corpus_data = load_dataset("princeton-nlp/LitSearch", "corpus_clean", split="full")
        print(f"Loaded {len(corpus_data)} papers from LitSearch corpus")

        paper_id_field = None
        if len(corpus_data) > 0:
            sample = corpus_data[0]

            # Print a sample to understand the structure
            print("Corpus data fields:", list(sample.keys()))

            # Determine the field containing paper IDs: prefer the known names,
            # otherwise take the first field that looks like an ID.
            paper_id_field = next(
                (name for name in ("paper_id", "doc_id", "corpusid") if name in sample),
                None,
            )
            if paper_id_field is None:
                paper_id_field = next(
                    (key for key in sample.keys() if "id" in key.lower()),
                    None,
                )

        if not paper_id_field:
            # Raised inside the try on purpose: it is reported by the handler
            # below and turned into the ({}, None) fallback.
            raise ValueError("Could not determine paper ID field in corpus dataset")

        print(f"Using '{paper_id_field}' as paper ID field")

        # Create a dictionary mapping paper IDs to paper content
        corpus_dict = {}
        for item in tqdm(corpus_data):
            paper_id = item.get(paper_id_field)
            # Skip only genuinely missing IDs; a numeric ID of 0 is valid.
            if paper_id is None or paper_id == '':
                continue
            corpus_dict[paper_id] = {
                'title': item.get('title', ''),
                'abstract': item.get('abstract', ''),
                'full_text': item.get('full_paper', ''),  # Get full paper text if available
                'authors': item.get('authors', ''),
                'year': item.get('year', '')
            }

        print(f"Created corpus dictionary with {len(corpus_dict)} papers")
        return corpus_dict, paper_id_field

    except Exception as e:
        # Deliberate best-effort: the notebook can still run on titles/abstracts.
        print(f"Error loading LitSearch corpus: {e}")
        print("Falling back to using only title and abstract from user dataset")
        return {}, None

# Load dataset paths
# NOTE(review): this is an absolute, machine-specific path; point it at your
# own copy of the dataset before running.
user_dataset_path = "/Users/himansh/Desktop/ANLP/litsearch/litsearch_rag_dataset_fullpaper_500.json"  # Replace with your path
user_dataset = load_user_dataset(user_dataset_path, num_samples=50)  # Adjust as needed

# Load LitSearch corpus
# On failure load_litsearch_corpus() returns ({}, None), so corpus_dict may be empty here.
corpus_dict, paper_id_field = load_litsearch_corpus()

Loaded 50 samples from /Users/himansh/Desktop/ANLP/litsearch/litsearch_rag_dataset_fullpaper_500.json
Loading LitSearch corpus dataset...
Loaded 64183 papers from LitSearch corpus
Corpus data fields: ['corpusid', 'title', 'abstract', 'citations', 'full_paper']
Using 'corpusid' as paper ID field


100%|██████████| 64183/64183 [00:10<00:00, 6016.20it/s]

Created corpus dictionary with 64183 papers





In [5]:
## 2. Enhanced Scientific Paper Chunker

class ScientificPaperChunker:
    """Enhanced chunker for scientific papers with context preservation

    Each paper becomes one standalone abstract chunk (when an abstract exists)
    followed by sentence-aligned chunks of the full text. Every chunk is
    prefixed with a header (title, optional authors/year, coarse topics, chunk
    index and detected section) so it stays interpretable when retrieved on
    its own.
    """

    # Fixed vocabulary scanned by extract_topics (a simplified stand-in for
    # NER or topic modeling).
    _TOPIC_TERMS = (
        "algorithm", "analysis", "data", "model", "method", "results",
        "neural", "learning", "network", "performance", "accuracy", "prediction",
        "framework", "system", "implementation", "architecture",
    )

    # (label, keywords) pairs tried in this order by detect_section; first hit wins.
    _SECTION_KEYWORDS = (
        ("Introduction", ("introduction",)),
        ("Methods", ("method", "methodology")),
        ("Results", ("result",)),
        ("Discussion", ("discussion",)),
        ("Conclusion", ("conclusion",)),
        ("References", ("reference", "bibliography")),
    )

    # Sentence boundary: whitespace that follows '.', '?' or '!', unless the
    # preceding characters look like a dotted abbreviation ("e.g.") or a
    # capitalised two-letter abbreviation ("Dr.").
    _SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')

    def __init__(self, chunk_size=1000, chunk_overlap=250):
        """
        Args:
            chunk_size: Maximum summed sentence length (characters) per chunk;
                a single sentence longer than this still forms its own chunk
            chunk_overlap: Maximum summed length of trailing sentences that are
                repeated at the start of the next chunk
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def extract_topics(self, text, n=5):
        """Extract key topics from text (simplified version)

        Counts whole-word, case-insensitive occurrences of a fixed list of
        scientific terms and returns up to ``n`` of the most frequent ones that
        occur at least once. Ties keep the order of the term list.
        """
        # Lower-case once instead of once per term; tolerate None.
        lowered = (text or "").lower()
        term_counts = [
            (term, len(re.findall(r'\b' + term + r'\b', lowered)))
            for term in self._TOPIC_TERMS
        ]
        # sorted() is stable, so equal counts stay in vocabulary order.
        ranked = sorted(term_counts, key=lambda pair: pair[1], reverse=True)
        return [term for term, count in ranked[:n] if count > 0]

    def chunk_paper(self, paper_id, paper_content):
        """
        Chunk a scientific paper with context preservation

        Args:
            paper_id: ID of the paper
            paper_content: Dictionary with paper content (title, abstract, full_text, etc.)

        Returns:
            List of Document objects (empty when the paper has no title,
            abstract or full text)
        """
        # `or ''` also covers keys that are present but set to None, which
        # would otherwise break the string concatenation below.
        title = paper_content.get('title') or ''
        abstract = paper_content.get('abstract') or ''
        full_text = paper_content.get('full_text') or ''
        authors = paper_content.get('authors') or ''
        year = paper_content.get('year') or ''

        # Skip if no content
        if not title and not abstract and not full_text:
            return []

        # Extract topics from abstract and title
        topics = self.extract_topics(title + " " + abstract)
        topics_str = ", ".join(topics) if topics else "scientific research"

        # Create header template for chunks
        header_template = f"PAPER: \"{title}\"\n"
        if authors:
            header_template += f"AUTHORS: {authors}\n"
        if year:
            header_template += f"YEAR: {year}\n"
        header_template += f"TOPICS: {topics_str}\n"

        chunks = []

        # Always include abstract as a standalone chunk with rich context
        if abstract:
            chunks.append(Document(
                page_content=f"{header_template}SECTION: Abstract\n\n{abstract}",
                metadata={
                    'paper_id': paper_id,
                    'title': title,
                    'section': 'abstract',
                    'chunk_index': 0,
                    'topics': list(topics),
                    'contains_abstract': True
                }
            ))

        if not full_text:
            return chunks

        current_chunk = []
        current_chunk_size = 0  # sum of sentence lengths; joining spaces are not counted
        chunk_index = 1  # Start after abstract

        for sentence in self._SENTENCE_BOUNDARY.split(full_text):
            sentence = sentence.strip()
            if not sentence:
                continue

            # If this sentence would make the chunk too big, finalize current chunk
            if current_chunk and current_chunk_size + len(sentence) > self.chunk_size:
                chunks.append(self._build_chunk(
                    paper_id, title, topics, header_template, chunk_index, current_chunk
                ))
                # Start new chunk with overlap from the one just emitted
                current_chunk, current_chunk_size = self._overlap_tail(current_chunk)
                chunk_index += 1

            current_chunk.append(sentence)
            current_chunk_size += len(sentence)

        # Add the last chunk if not empty
        if current_chunk:
            chunks.append(self._build_chunk(
                paper_id, title, topics, header_template, chunk_index, current_chunk
            ))

        return chunks

    def _build_chunk(self, paper_id, title, topics, header_template, chunk_index, sentences):
        """Join sentences into one Document with the context header and metadata."""
        chunk_text = " ".join(sentences)
        chunk_metadata = {
            'paper_id': paper_id,
            'title': title,
            'chunk_index': chunk_index,
            'topics': list(topics),  # copy so chunks do not share one list object
            'contains_abstract': False
        }

        chunk_header = f"{header_template}CHUNK: {chunk_index}\n"
        # The section is only recorded when the heuristic finds one.
        current_section = self.detect_section(chunk_text)
        if current_section:
            chunk_metadata['section'] = current_section
            chunk_header += f"SECTION: {current_section}\n"
        chunk_header += "\n"

        return Document(page_content=chunk_header + chunk_text, metadata=chunk_metadata)

    def _overlap_tail(self, sentences):
        """Return the trailing sentences that fit in chunk_overlap, plus their summed length."""
        overlap_chunk = []
        overlap_size = 0
        for prev_sentence in reversed(sentences):
            if overlap_size + len(prev_sentence) > self.chunk_overlap:
                break
            overlap_chunk.insert(0, prev_sentence)
            overlap_size += len(prev_sentence)
        return overlap_chunk, overlap_size

    def detect_section(self, text):
        """Attempt to detect which section this chunk belongs to

        Simple rule-based approach: looks for section keywords in the first
        200 characters of the lower-cased text and returns the first matching
        label, or None when nothing matches.
        """
        head = text.lower()[:200]
        for label, keywords in self._SECTION_KEYWORDS:
            if any(keyword in head for keyword in keywords):
                return label
        return None

In [6]:
## 3. Preparing Documents for LangChain

def prepare_documents(user_dataset, corpus_dict):
    """Process papers and generate context-enhanced chunks

    Args:
        user_dataset: List of dataset records; each may carry 'paper_id',
            'paper_title' and 'paper_abstract'
        corpus_dict: Mapping of paper ID to corpus content (as built by
            load_litsearch_corpus); may be empty

    Returns:
        List of chunks for all unique papers, in order of each paper's first
        appearance in ``user_dataset``
    """
    # First record seen for every paper ID. A dict (instead of a set) keeps the
    # processing order deterministic and gives an O(1) fallback lookup.
    first_record_by_id = {}
    for item in user_dataset:
        paper_id = item.get('paper_id')
        if paper_id is None or paper_id == '':
            continue
        first_record_by_id.setdefault(paper_id, item)

    num_papers = len(first_record_by_id)
    print(f"Found {num_papers} unique papers in user dataset")

    # Initialize chunker
    chunker = ScientificPaperChunker(chunk_size=1000, chunk_overlap=250)

    # Track statistics
    papers_with_full_text = 0
    papers_with_abstract_only = 0

    # Store all chunks
    all_chunks = []

    # Process each paper
    for paper_id in tqdm(first_record_by_id):
        corpus_entry = corpus_dict.get(paper_id)
        if corpus_entry and corpus_entry.get('full_text'):
            # Use full paper from corpus
            paper_content = corpus_entry
            papers_with_full_text += 1
        else:
            # Fallback to title and abstract from user dataset
            paper_data = first_record_by_id[paper_id]
            paper_content = {
                'title': paper_data.get('paper_title', ''),
                'abstract': paper_data.get('paper_abstract', ''),
                'full_text': ''  # No full text available
            }
            papers_with_abstract_only += 1

        # Chunk the paper
        all_chunks.extend(chunker.chunk_paper(paper_id, paper_content))

    total_chunks = len(all_chunks)
    print(f"Created {total_chunks} chunks from {num_papers} papers")
    print(f"  - {papers_with_full_text} papers with full text")
    print(f"  - {papers_with_abstract_only} papers with abstract only")
    # Guard the average: an empty dataset used to end in ZeroDivisionError.
    if num_papers:
        print(f"  - Average {total_chunks / num_papers:.1f} chunks per paper")

    return all_chunks

# Create the ground truth lookup for evaluation
def create_ground_truth_lookup(dataset):
    """Map each conceptual question to its source paper ID and reference answer.

    Records without a (non-empty) 'conceptual_question' are ignored. When the
    same question occurs more than once, the last record's values are kept.
    """
    return {
        record.get('conceptual_question'): {
            'paper_id': record.get('paper_id'),
            'answer': record.get('ground_truth_answer'),
        }
        for record in dataset
        if record.get('conceptual_question')
    }

# Process documents and create chunks
# `chunks` feeds both retrievers built later; `ground_truth_lookup` is keyed by
# question text and is read by run_all_baselines().
chunks = prepare_documents(user_dataset, corpus_dict)
ground_truth_lookup = create_ground_truth_lookup(user_dataset)

Found 49 unique papers in user dataset


100%|██████████| 49/49 [00:00<00:00, 208.42it/s]

Created 4201 chunks from 49 papers
  - 49 papers with full text
  - 0 papers with abstract only
  - Average 85.7 chunks per paper





In [None]:
## 4. Create Retrievers

# 1. BM25 Retriever
# Sparse lexical retriever built over the context-enhanced chunks.
bm25_retriever = BM25Retriever.from_documents(chunks)
bm25_retriever.k = 5  # Retrieve more chunks since we're using chunking
# NOTE(review): BM25 returns 5 chunks while the dense retriever below returns
# 20, so the two baselines are not compared at equal k; confirm this is intended.

# 2. Dense Retriever with FAISS
# Use a scientific-specific embedding model if possible
# NOTE(review): SciBERT is a general encoder checkpoint; presumably
# HuggingFaceEmbeddings pools its token outputs into one vector. Confirm the
# embedding quality against the SPECTER2 alternative mentioned below.
embedding_model = "allenai/scibert_scivocab_uncased"  # Alternative: "allenai/specter2"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
vectorstore = FAISS.from_documents(chunks, embeddings)
dense_retriever = vectorstore.as_retriever(search_kwargs={"k": 20})

# 3. Ensemble Retriever (combines BM25 and Dense)
# Weights are listed in the same order as `retrievers`: 0.3 for BM25, 0.7 for dense.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, dense_retriever],
    weights=[0.3, 0.7]  # Give higher weight to dense retrieval for scientific text
)



In [None]:
## 5. Create RAG Chains with LangChain

# Define a more specific prompt for scientific questions
# The template has two placeholders, {context} and {question}; the model is told
# to admit missing information rather than invent an answer.
prompt_template = """You are a helpful scientific assistant with expertise in research papers. Use the following pieces of context to answer the question at the end. 

The context contains information from scientific papers including titles, authors, and content. Use this information to provide a comprehensive and accurate answer.

If you don't know the answer based on the given context, just say that you don't have enough information, don't try to make up an answer.

Context:
{context}

Question: {question}
Answer:"""

# Wrap the template; its two input variables match the {context} and
# {question} placeholders in prompt_template.
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

# Create LLM
# temperature=0 is used for every baseline; requests are sent to the base_url
# given here. NOTE(review): presumably an institution-specific LiteLLM proxy;
# change it if you use a different endpoint.
llm = ChatOpenAI(temperature=0, 
                 model_name="gpt-4o", 
                 base_url="https://cmu.litellm.ai")

# Create RetrievalQA chains for each retriever
def create_qa_chain(retriever, llm=llm, prompt=PROMPT):
    """Build a "stuff" RetrievalQA chain around ``retriever``.

    The ``llm`` and ``prompt`` defaults are bound when this function is
    defined, so later reassignments of the module-level ``llm`` / ``PROMPT``
    do not affect chains created without explicit arguments. The chain also
    returns its source documents.
    """
    chain_options = {
        "llm": llm,
        "chain_type": "stuff",
        "retriever": retriever,
        "return_source_documents": True,
        "chain_type_kwargs": {"prompt": prompt},
    }
    return RetrievalQA.from_chain_type(**chain_options)

# Create the chains
# One RetrievalQA chain per retriever, all sharing the default llm and PROMPT.
bm25_chain = create_qa_chain(bm25_retriever)
dense_chain = create_qa_chain(dense_retriever)
ensemble_chain = create_qa_chain(ensemble_retriever)

# Zero-shot chain (no retrieval)
# This template takes only {question}; no retrieved context is supplied.
zero_shot_template = """You are a helpful scientific assistant with expertise in machine learning, AI, and computer science research. Answer the following question based on your knowledge.

Question: {question}
Answer:"""

ZERO_SHOT_PROMPT = PromptTemplate(
    template=zero_shot_template,
    input_variables=["question"]
)


In [None]:
## 6. Run Experiments with Enhanced Evaluation

def evaluate_retrieval(retrieved_docs, ground_truth_paper_id):
    """
    Evaluate retrieval performance

    Chunks are collapsed to their papers while keeping retrieval order, so the
    rank of a paper is the position of its best-ranked chunk.

    Args:
        retrieved_docs: List of retrieved documents, best match first
        ground_truth_paper_id: ID of the ground truth paper

    Returns:
        Dictionary with retrieval metrics:
            found: Whether the ground truth paper was retrieved
            mrr: Reciprocal rank of the ground truth paper (0.0 if not found)
            precision: 1.0 if found else 0.0 (a hit indicator; with a single
                relevant paper this equals recall)
            retrieved_paper_ids: Unique paper IDs in rank order
    """
    # Deduplicate by first occurrence. dict keys preserve insertion order,
    # whereas a set would lose the ranking that the reciprocal rank depends on.
    ranked_paper_ids = list(dict.fromkeys(
        paper_id
        for paper_id in (doc.metadata.get('paper_id') for doc in retrieved_docs)
        if paper_id is not None and paper_id != ''
    ))

    # Check if ground truth is in retrieved docs
    found = ground_truth_paper_id in ranked_paper_ids

    # Reciprocal rank of the ground truth paper (averaged into MRR by the caller)
    if found:
        rank = ranked_paper_ids.index(ground_truth_paper_id)
        mrr = 1.0 / (rank + 1)
    else:
        mrr = 0.0

    precision = 1.0 if found else 0.0

    return {
        "found": found,
        "mrr": mrr,
        "precision": precision,
        "retrieved_paper_ids": ranked_paper_ids
    }

def evaluate_answer(generated, ground_truth):
    """
    Evaluate answer quality using ROUGE and BERTScore

    Args:
        generated: Generated answer (None is scored as an empty string)
        ground_truth: Ground truth answer (None is scored as an empty string)

    Returns:
        Dictionary with evaluation metrics: 'rouge1', 'rouge2', 'rougeL'
        (F-measures) and, when BERTScore succeeds, 'bertscore_precision',
        'bertscore_recall' and 'bertscore_f1'
    """
    # The scorers are handed text; a missing answer (e.g. a record without a
    # ground truth) is scored as empty text instead of aborting a run that has
    # already paid for the LLM call.
    generated = "" if generated is None else str(generated)
    ground_truth = "" if ground_truth is None else str(ground_truth)

    # ROUGE scores
    rouge_scores = rouge_scorer_instance.score(ground_truth, generated)

    metrics = {
        "rouge1": rouge_scores["rouge1"].fmeasure,
        "rouge2": rouge_scores["rouge2"].fmeasure,
        "rougeL": rouge_scores["rougeL"].fmeasure
    }

    # BERTScore with SciBERT if available
    # NOTE(review): this presumably loads the model on every call; consider a
    # cached scorer if evaluation becomes slow.
    try:
        # Compute BERTScore
        P, R, F1 = bert_score(
            [generated],
            [ground_truth],
            model_type="allenai/scibert_scivocab_uncased",
            lang="en",
            verbose=False
        )

        # Add to metrics
        metrics.update({
            "bertscore_precision": P.item(),
            "bertscore_recall": R.item(),
            "bertscore_f1": F1.item()
        })
    except Exception as e:
        # Best-effort: ROUGE results are still returned when BERTScore fails.
        print(f"Warning: BERTScore calculation failed: {e}")
        print("Continuing without BERTScore. Install bert_score package for complete evaluation.")

    return metrics

def run_experiment(chain, question, ground_truth_info):
    """Run experiment with a specific chain

    Invokes ``chain`` on the question while recording OpenAI token usage, then
    scores the retrieved documents (if any) and the generated answer against
    ``ground_truth_info`` ('paper_id' and 'answer').
    """
    # Only the chain call itself is metered.
    with get_openai_callback() as usage:
        output = chain({"query": question})

    generated_answer = output.get("result", "")
    retrieved = output.get("source_documents", [])

    # Retrieval metrics stay empty when the chain returned no source documents.
    if retrieved:
        retrieval_scores = evaluate_retrieval(retrieved, ground_truth_info['paper_id'])
    else:
        retrieval_scores = {}

    answer_scores = evaluate_answer(generated_answer, ground_truth_info['answer'])

    token_usage = {
        "prompt_tokens": usage.prompt_tokens,
        "completion_tokens": usage.completion_tokens,
        "total_tokens": usage.total_tokens,
        "cost": usage.total_cost
    }

    return {
        "answer": generated_answer,
        "retrieval_metrics": retrieval_scores,
        "answer_metrics": answer_scores,
        "token_usage": token_usage
    }

def run_zero_shot(llm, question, ground_truth_info):
    """Run zero-shot experiment (no retrieval)

    Formats the question with ZERO_SHOT_PROMPT, asks ``llm`` directly while
    recording OpenAI token usage, and scores the answer against
    ``ground_truth_info['answer']``.
    """
    # Prompt formatting and the LLM call both happen inside the usage tracker.
    with get_openai_callback() as usage:
        answer = llm.predict(ZERO_SHOT_PROMPT.format(question=question))

    return {
        "answer": answer,
        "answer_metrics": evaluate_answer(answer, ground_truth_info['answer']),
        "token_usage": {
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
            "cost": usage.total_cost
        }
    }




In [None]:
## 7. Run All Baselines

def run_all_baselines(dataset, num_samples=None):
    """Run all baseline methods

    For each sample that has a question and a ground-truth entry, runs the
    BM25, dense and ensemble RAG chains plus the zero-shot LLM, collecting one
    record per method. Summary metrics are computed at the end.

    Args:
        dataset: List of dataset records
        num_samples: Use only the first ``num_samples`` records (None = all)

    Returns:
        Dictionary with per-method result lists and a "summary" entry
    """
    limit_samples = num_samples and num_samples < len(dataset)
    samples = dataset[:num_samples] if limit_samples else dataset

    results = {name: [] for name in ("bm25", "dense", "ensemble", "zero_shot")}
    results["summary"] = {}

    print(f"Running baselines on {len(samples)} samples...")
    for position, sample in enumerate(tqdm(samples), start=1):
        question = sample.get('conceptual_question')
        # Skip if question is missing
        if not question:
            continue

        # Get ground truth info
        ground_truth_info = ground_truth_lookup.get(question, {})
        if not ground_truth_info:
            print(f"Warning: No ground truth found for question: {question[:50]}...")
            continue

        print(f"\nProcessing question {position}/{len(samples)}: {question[:100]}...")

        # The three retrieval-backed baselines share one record layout.
        retrieval_runs = (
            ("bm25", "BM25", bm25_chain),
            ("dense", "Dense", dense_chain),
            ("ensemble", "Ensemble", ensemble_chain),
        )
        for result_key, label, chain in retrieval_runs:
            print(f"Running {label} + LLM...")
            outcome = run_experiment(chain, question, ground_truth_info)
            results[result_key].append({
                "question": question,
                "ground_truth": ground_truth_info['answer'],
                "answer": outcome["answer"],
                "retrieval_metrics": outcome["retrieval_metrics"],
                "answer_metrics": outcome["answer_metrics"],
                "token_usage": outcome["token_usage"]
            })

        # Zero-shot has no retrieval metrics.
        print("Running Zero-shot...")
        zero_shot_outcome = run_zero_shot(llm, question, ground_truth_info)
        results["zero_shot"].append({
            "question": question,
            "ground_truth": ground_truth_info['answer'],
            "answer": zero_shot_outcome["answer"],
            "answer_metrics": zero_shot_outcome["answer_metrics"],
            "token_usage": zero_shot_outcome["token_usage"]
        })

    # Calculate summary metrics
    calculate_summary_metrics(results)

    return results

def calculate_summary_metrics(results):
    """Calculate summary metrics for all methods

    For every method with at least one result, averages the answer metrics
    and, for retrieval-based methods, the retrieval metrics. ROUGE metrics are
    always reported. BERTScore metrics are reported when at least one result
    has them and are averaged only over the results that do, so one failed
    BERTScore computation neither drops the metric nor drags the mean to zero.

    Args:
        results: Dictionary with a list of result records per method; its
            "summary" entry is overwritten

    Returns:
        The summary dictionary (also stored in ``results["summary"]``)
    """
    methods = ["bm25", "dense", "ensemble", "zero_shot"]
    rouge_metrics = ("rouge1", "rouge2", "rougeL")
    bertscore_metrics = ("bertscore_precision", "bertscore_recall", "bertscore_f1")
    summary = {}

    for method in methods:
        method_results = results.get(method)

        # Skip if no results
        if not method_results:
            continue

        n = len(method_results)

        answer_summary = {}
        for metric in rouge_metrics + bertscore_metrics:
            values = [
                result["answer_metrics"][metric]
                for result in method_results
                if metric in result["answer_metrics"]
            ]
            if values:
                answer_summary[metric] = sum(values) / len(values)
            elif metric in rouge_metrics:
                # Keep the ROUGE keys present; print_summary() reads them directly.
                answer_summary[metric] = 0.0

        method_summary = {"answer_metrics": answer_summary}

        # Add retrieval metrics for retrieval-based methods
        if method != "zero_shot":
            # An empty dict means the chain returned no documents: counts as a miss.
            retrieval = [result.get("retrieval_metrics") or {} for result in method_results]
            method_summary["retrieval_metrics"] = {
                "found_rate": sum(1 for r in retrieval if r.get("found", False)) / n,
                "mrr": sum(r.get("mrr", 0.0) for r in retrieval) / n,
                "precision": sum(r.get("precision", 0.0) for r in retrieval) / n
            }

        summary[method] = method_summary

    results["summary"] = summary
    return summary

# Run all baselines with a small sample first (adjust as needed)
# Each sample triggers four LLM calls (BM25, dense, ensemble, zero-shot).
num_samples = 5  # Start small for testing
results = run_all_baselines(user_dataset, num_samples=num_samples)

In [None]:
## 8. Analyze Results

def print_summary(results):
    """Print summary of results

    Writes one block per method found in ``results["summary"]``: retrieval
    metrics (when present), the three ROUGE scores and, when available,
    BERTScore F1, all with four decimals.
    """
    summary = results["summary"]

    print("\nResults Summary:")
    print("===============")

    # (printed label, metric key) pairs, in output order
    retrieval_rows = (
        ("Retrieval Success Rate", "found_rate"),
        ("MRR", "mrr"),
        ("Precision", "precision"),
    )
    rouge_rows = (
        ("ROUGE-1", "rouge1"),
        ("ROUGE-2", "rouge2"),
        ("ROUGE-L", "rougeL"),
    )

    for method in ("bm25", "dense", "ensemble", "zero_shot"):
        if method not in summary:
            continue

        method_summary = summary[method]
        print(f"\n{method.upper()}:")

        # Retrieval metrics exist only for retrieval-based methods
        if "retrieval_metrics" in method_summary:
            retrieval = method_summary["retrieval_metrics"]
            for label, key in retrieval_rows:
                print(f"  {label}: {retrieval[key]:.4f}")

        answer = method_summary["answer_metrics"]
        for label, key in rouge_rows:
            print(f"  {label}: {answer[key]:.4f}")

        # Print BERTScore if available
        if "bertscore_f1" in answer:
            print(f"  BERTScore F1: {answer['bertscore_f1']:.4f}")

# Print summary
# Reads results["summary"], which run_all_baselines() filled in.
print_summary(results)

In [None]:
## 9. Visualize Results

def plot_results(results):
    """Plot comparison of different methods

    Draws one bar chart per metric (ROUGE-L, Found Rate, MRR and, when any
    method has it, BERTScore F1) with one bar per method that has results.
    Methods without retrieval metrics (zero-shot) are shown as 0 for the
    retrieval charts. Nothing is drawn when the summary is empty.

    Args:
        results: Results dictionary whose "summary" entry was filled by
            calculate_summary_metrics()
    """
    summary = results["summary"]

    # Filter methods that have results
    methods = [method for method in ("bm25", "dense", "ensemble", "zero_shot") if method in summary]
    if not methods:
        # Nothing to index or plot; avoids an IndexError on methods[0].
        print("No results to plot.")
        return

    def retrieval_value(method, key):
        # 0 stands in for methods that have no retrieval step.
        return summary[method].get("retrieval_metrics", {}).get(key, 0)

    # Data for plotting
    metrics = {
        "ROUGE-L": [summary[method]["answer_metrics"]["rougeL"] for method in methods],
        "Found Rate": [retrieval_value(method, "found_rate") for method in methods],
        "MRR": [retrieval_value(method, "mrr") for method in methods]
    }

    # Add BERTScore if any method has it; methods lacking it are plotted as 0
    # instead of raising KeyError.
    if any("bertscore_f1" in summary[method]["answer_metrics"] for method in methods):
        metrics["BERTScore F1"] = [
            summary[method]["answer_metrics"].get("bertscore_f1", 0) for method in methods
        ]

    # Create figure with subplots; squeeze=False always yields a 2-D axes array
    fig, axes_grid = plt.subplots(1, len(metrics), figsize=(5*len(metrics), 6), squeeze=False)
    axes = axes_grid[0]

    # Plot each metric
    for ax, (metric_name, metric_values) in zip(axes, metrics.items()):
        ax.bar(methods, metric_values)
        ax.set_title(metric_name)
        ax.set_ylim(0, 1)

        # Add value labels
        for j, v in enumerate(metric_values):
            ax.text(j, v + 0.02, f"{v:.3f}", ha='center')

    plt.tight_layout()
    plt.show()

# Plot results
# One bar chart per metric, comparing the methods present in results["summary"].
plot_results(results)

## 10. Save Results

def save_results(results, output_file="enhanced_scientific_rag_results.json"):
    """Save results to JSON file

    NumPy scalars and arrays anywhere in the structure (including dict keys)
    are converted to plain Python values first, because the json module cannot
    serialize them.

    Args:
        results: Nested structure of dicts, lists/tuples and scalars
        output_file: Path of the JSON file to write (overwritten if it exists)
    """
    # Convert numpy values to Python types for JSON serialization
    def convert_for_json(obj):
        # np.bool_ is not a subclass of np.integer, so it is checked separately.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return convert_for_json(obj.tolist())
        if isinstance(obj, (list, tuple)):
            return [convert_for_json(item) for item in obj]
        if isinstance(obj, dict):
            return {
                # NumPy scalar keys (e.g. integer paper IDs) are unwrapped too
                (key.item() if isinstance(key, np.generic) else key): convert_for_json(value)
                for key, value in obj.items()
            }
        return obj

    converted_results = convert_for_json(results)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(converted_results, f, indent=2)

    print(f"Results saved to {output_file}")

# Save results
# Writes to the default output file, "enhanced_scientific_rag_results.json".
save_results(results)

## 11. Run Full Experiment

# Uncomment to run on more samples
# results = run_all_baselines(user_dataset, num_samples=50)  # Adjust as needed
# print_summary(results)
# plot_results(results)
# save_results(results, "enhanced_scientific_rag_results_full.json")