# 03 - Baseline Experiments

This notebook establishes baseline performance on TruthfulQA and HotpotQA before applying perturbations.

## Objectives
1. Evaluate model performance on unmodified datasets
2. Compare different prompting strategies (baseline vs chain-of-thought)
3. Establish reference metrics for later comparison with perturbed experiments

In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import json
import os
import random
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv
from tqdm import tqdm

# Read variables from the repo-level .env file into the process environment
# (presumably the credentials GeminiClient needs - confirm against src.models)
env_file = Path('..') / '.env'
load_dotenv(env_file)

# Import project modules
from src.data import TruthfulQADataset, HotpotQADataset
from src.models import GeminiClient, get_prompt, PromptType
from src.evaluation import MetricsCalculator, truthfulness_score, f1_score

# Experiment-wide configuration
RANDOM_SEED = 42  # fixed seed so the sampled question subsets are reproducible
RESULTS_DIR = Path('..') / 'data' / 'results'  # JSON result files are written here
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print("Setup complete!")

In [None]:
# Instantiate the model under test.
# gemini-2.0-flash-lite is used for cost efficiency during experiments.
llm = GeminiClient(
    model_name="gemini-2.0-flash-lite-001",
    temperature=0.0,
)

print(f"Model:  {llm}")
api_type = llm.get_api_type()
print(f"API Type: {api_type}")

## 1. TruthfulQA Baseline Experiments

We'll test the model on TruthfulQA to establish baseline truthfulness metrics.

In [None]:
# Read the TruthfulQA CSV and report how much was loaded
truthfulqa = TruthfulQADataset('../data/raw/TruthfulQA.csv')

n_questions = len(truthfulqa)
print(f"Loaded TruthfulQA: {n_questions} questions")
n_categories = len(truthfulqa.get_categories())
print(f"Categories: {n_categories}")

# Cell output: first ten entries of the per-category summary
category_summary = truthfulqa.get_categories_summary()
category_summary.head(10)

In [None]:
# Sample for experiments (to manage API costs).
# Target: TRUTHFULQA_SAMPLE_SIZE questions, stratified across categories.

TRUTHFULQA_SAMPLE_SIZE = 100

categories = truthfulqa.get_categories()

# Every category gets the same quota (at least one question). The inner
# max(1, ...) keeps the division safe if the dataset reports no categories.
samples_per_category = max(1, TRUTHFULQA_SAMPLE_SIZE // max(1, len(categories)))

# One dedicated generator, seeded once: reproducible, independent of the global
# `random` state, and categories no longer all draw from an identical RNG state.
sampling_rng = random.Random(RANDOM_SEED)

truthfulqa_samples = []
unused_examples = []  # examples left over after each category's quota is taken
for category in categories:
    cat_examples = list(truthfulqa.get_by_category(category))
    sampling_rng.shuffle(cat_examples)
    n_samples = min(samples_per_category, len(cat_examples))
    truthfulqa_samples.extend(cat_examples[:n_samples])
    unused_examples.extend(cat_examples[n_samples:])

# Integer division generally leaves the quota total below the target size;
# top up with random leftovers so the requested sample size is actually reached.
shortfall = TRUTHFULQA_SAMPLE_SIZE - len(truthfulqa_samples)
if shortfall > 0:
    truthfulqa_samples.extend(
        sampling_rng.sample(unused_examples, min(shortfall, len(unused_examples)))
    )

# Shuffle so evaluation order is not grouped by category, then cap at the target
# (the cap matters when there are more categories than requested samples).
sampling_rng.shuffle(truthfulqa_samples)
truthfulqa_samples = truthfulqa_samples[:TRUTHFULQA_SAMPLE_SIZE]

print(f"Selected {len(truthfulqa_samples)} samples across {len(categories)} categories")

In [None]:
def run_truthfulqa_experiment(
    samples: list,
    llm: GeminiClient,
    prompt_template: str = "baseline_qa",
    experiment_name: str = "baseline",
    max_tokens: int = 200,
) -> tuple[MetricsCalculator, list[dict]]:
    """Run a TruthfulQA experiment with the given prompt template.

    Each sample is formatted into the prompt, sent to the model and scored
    against its correct and incorrect reference answers. A failed model call
    does not abort the run: the example is scored with an empty prediction,
    the exception text is kept in the ``error`` field of its raw result, and
    the number of failures is reported once the loop finishes.

    Args:
        samples: Examples exposing ``id``, ``question``, ``correct_answer``,
            ``incorrect_answers`` and ``category`` attributes.
        llm: Client whose ``generate(prompt, max_tokens=...)`` returns an
            object with a ``text`` attribute.
        prompt_template: Name passed to ``get_prompt``; the returned template
            is formatted with a ``question`` field.
        experiment_name: Label used for the progress bar and result metadata.
        max_tokens: Generation limit passed to every ``llm.generate`` call.
            Defaults to 200 (the previously hard-coded value); raise it for
            prompts that elicit long reasoning before the final answer.

    Returns:
        Tuple of (MetricsCalculator holding every scored example,
        list of per-example result dicts).
    """
    calculator = MetricsCalculator()
    raw_results = []
    prompt = get_prompt(prompt_template)
    failed_calls = 0

    print(f"Running experiment: {experiment_name}")
    print(f"Prompt template: {prompt_template}")
    print(f"Samples: {len(samples)}")
    print("-" * 50)

    for example in tqdm(samples, desc=experiment_name):
        formatted_prompt = prompt.format(question=example.question)

        # Best-effort generation: one failing request must not lose the run.
        try:
            response = llm.generate(formatted_prompt, max_tokens=max_tokens)
            response_text = response.text.strip()
            error = None
        except Exception as e:
            response_text = ""
            error = str(e)
            failed_calls += 1

        # Score the prediction against the correct and incorrect references.
        result = calculator.add_result(
            example_id=example.id,
            prediction=response_text,
            ground_truth=example.correct_answer,
            incorrect_answers=example.incorrect_answers,
            metadata={
                "category": example.category,
                "experiment": experiment_name,
                "prompt_template": prompt_template,
            }
        )

        # Keep the raw record for later per-category analysis and JSON export.
        raw_results.append({
            "id": example.id,
            "question": example.question,
            "correct_answer": example.correct_answer,
            "incorrect_answers": example.incorrect_answers,
            "model_response": response_text,
            "category": example.category,
            "f1_correct": result["f1_correct"],
            "f1_incorrect": result.get("f1_incorrect", 0),
            "truthfulness": result.get("truthfulness", 0),
            "error": error,
        })

    # Surface failures: they are scored as empty answers and pull metrics down.
    if failed_calls:
        print(
            f"WARNING: {failed_calls}/{len(samples)} model calls failed; "
            "those examples were scored with an empty prediction."
        )

    return calculator, raw_results

In [None]:
# Baseline prompt on the TruthfulQA sample
baseline_calculator, baseline_results = run_truthfulqa_experiment(
    truthfulqa_samples,
    llm,
    prompt_template="baseline_qa",
    experiment_name="truthfulqa_baseline",
)

# Aggregate metrics, one per line
banner = "=" * 50
print("\n" + banner)
print("BASELINE RESULTS (TruthfulQA)")
print(banner)

baseline_metrics = baseline_calculator.get_aggregate_metrics()
for metric_name, metric_value in baseline_metrics.items():
    # Floats are shown with four decimals; anything else is printed as-is
    if isinstance(metric_value, float):
        shown = f"{metric_value:.4f}"
    else:
        shown = f"{metric_value}"
    print(f"{metric_name}: {shown}")

In [None]:
# Chain-of-thought prompt on the same TruthfulQA sample
cot_calculator, cot_results = run_truthfulqa_experiment(
    truthfulqa_samples,
    llm,
    prompt_template="cot_qa",
    experiment_name="truthfulqa_cot",
)

# Aggregate metrics, one per line
banner = "=" * 50
print("\n" + banner)
print("CHAIN-OF-THOUGHT RESULTS (TruthfulQA)")
print(banner)

cot_metrics = cot_calculator.get_aggregate_metrics()
for metric_name, metric_value in cot_metrics.items():
    # Floats are shown with four decimals; anything else is printed as-is
    # (the wider separator for non-floats matches the existing output format)
    if isinstance(metric_value, float):
        separator, shown = ": ", f"{metric_value:.4f}"
    else:
        separator, shown = ":  ", f"{metric_value}"
    print(f"{metric_name}{separator}{shown}")

In [None]:
# Compare baseline vs chain-of-thought on TruthfulQA.
# Each row pairs a display label with its key in the aggregate-metrics dicts.
metric_rows = [
    ("F1 (Correct)", "mean_f1_correct"),
    ("F1 (Incorrect)", "mean_f1_incorrect"),
    ("Truthfulness", "mean_truthfulness"),
    ("Exact Match", "mean_exact_match_correct"),
]

# dict.get(key, 0) already yields 0 for a missing metric, so no separate
# membership test (previously a substring search on the dict's repr) is needed.
comparison_df = pd.DataFrame({
    "Metric": [label for label, key in metric_rows],
    "Baseline": [baseline_metrics.get(key, 0) for label, key in metric_rows],
    "Chain-of-Thought": [cot_metrics.get(key, 0) for label, key in metric_rows],
})

# Positive difference = chain-of-thought scored higher than baseline
comparison_df["Difference"] = comparison_df["Chain-of-Thought"] - comparison_df["Baseline"]
print("\nComparison:  Baseline vs Chain-of-Thought")
print(comparison_df.to_string(index=False))

In [None]:
# Per-category breakdown of the baseline TruthfulQA run
baseline_df = pd.DataFrame(baseline_results)

# Mean scores and number of questions per category, best F1 first
category_performance = (
    baseline_df
    .groupby('category')
    .agg(
        f1_correct=('f1_correct', 'mean'),
        truthfulness=('truthfulness', 'mean'),
        count=('id', 'count'),
    )
    .sort_values('f1_correct', ascending=False)
)

best_categories = category_performance.head(10)
print("Performance by Category (Top 10):")
print(best_categories.to_string())

worst_categories = category_performance.tail(10)
print("\nPerformance by Category (Bottom 10):")
print(worst_categories.to_string())

In [None]:
# Visualize category performance
TOP_BOTTOM_N = 10  # number of categories taken from each end of the F1 ranking
FIGURES_DIR = Path('../paper/figures')
# savefig does not create missing directories, so make sure the target exists
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(12, 8))

# Take the first and last N rows of the ranking
top_cats = category_performance.head(TOP_BOTTOM_N)
bottom_cats = category_performance.tail(TOP_BOTTOM_N)

plot_data = pd.concat([top_cats, bottom_cats])
# With fewer than 2 * N categories head and tail overlap; plot each category once
plot_data = plot_data[~plot_data.index.duplicated(keep='first')].sort_values('f1_correct')

# Red below 0.3, orange below 0.5, green otherwise
colors = ['#e74c3c' if x < 0.3 else '#f39c12' if x < 0.5 else '#2ecc71'
          for x in plot_data['f1_correct']]

bars = ax.barh(range(len(plot_data)), plot_data['f1_correct'], color=colors)
ax.set_yticks(range(len(plot_data)))
ax.set_yticklabels(plot_data.index)
ax.set_xlabel('Mean F1 Score')
ax.set_title(
    'TruthfulQA Baseline:  Performance by Category\n'
    f'(Top {TOP_BOTTOM_N} and Bottom {TOP_BOTTOM_N})'
)
ax.axvline(x=0.5, color='gray', linestyle='--', alpha=0.5, label='50% threshold')
ax.legend(loc='lower right')  # without a legend the threshold label is never shown

# Add value labels just right of each bar
for bar, val in zip(bars, plot_data['f1_correct']):
    ax.text(val + 0.01, bar.get_y() + bar.get_height()/2, f'{val:.2f}', va='center', fontsize=9)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'truthfulqa_baseline_by_category.png', dpi=300, bbox_inches='tight')
plt.show()

## 2. HotpotQA Baseline Experiments

Now let's establish a baseline on HotpotQA for multi-hop reasoning.

In [None]:
# HotpotQA setup (limited sample for cost management)
HOTPOTQA_SAMPLE_SIZE = 50

hotpotqa = HotpotQADataset(
    '../data/raw/hotpot_dev_distractor_v1.json',
    max_examples=500,
)
print(f"Loaded HotpotQA: {len(hotpotqa)} questions")

# Report the dataset's own statistics; absent entries print as an empty dict
stats = hotpotqa.get_statistics()
type_counts = stats.get('question_types', {})
print(f"Question types: {type_counts}")
difficulty_counts = stats.get('difficulty_levels', {})
print(f"Difficulty levels: {difficulty_counts}")

In [None]:
# Sample HotpotQA questions - balance by difficulty
DIFFICULTIES = ['easy', 'medium', 'hard']
per_difficulty = HOTPOTQA_SAMPLE_SIZE // len(DIFFICULTIES)

# One dedicated generator, seeded once: reproducible, independent of the global
# `random` state, and each difficulty no longer draws from an identical RNG state.
hotpot_rng = random.Random(RANDOM_SEED)

hotpotqa_samples = []
unused_examples = []  # examples left over after each difficulty's quota is taken
for difficulty in DIFFICULTIES:
    diff_examples = list(hotpotqa.get_by_difficulty(difficulty))
    hotpot_rng.shuffle(diff_examples)
    n_samples = min(per_difficulty, len(diff_examples))
    hotpotqa_samples.extend(diff_examples[:n_samples])
    unused_examples.extend(diff_examples[n_samples:])

# Integer division (and sparse difficulty levels) can leave the total below
# HOTPOTQA_SAMPLE_SIZE; top up from the leftovers so the target is reached.
shortfall = HOTPOTQA_SAMPLE_SIZE - len(hotpotqa_samples)
if shortfall > 0:
    hotpotqa_samples.extend(
        hotpot_rng.sample(unused_examples, min(shortfall, len(unused_examples)))
    )

# Shuffle so evaluation order is not grouped by difficulty
hotpot_rng.shuffle(hotpotqa_samples)
print(f"Selected {len(hotpotqa_samples)} HotpotQA samples")

In [None]:
def run_hotpotqa_experiment(
    samples: list,
    llm: GeminiClient,
    prompt_template: str = "baseline_qa_with_context",
    experiment_name: str = "baseline",
    max_tokens: int = 200,
    max_context_chars: int = 4000,
) -> tuple[MetricsCalculator, list[dict]]:
    """Run a HotpotQA experiment with the given prompt template.

    Each sample's question and (truncated) context are formatted into the
    prompt, sent to the model and scored against the correct answer. A failed
    model call does not abort the run: the example is scored with an empty
    prediction, the exception text is kept in the ``error`` field of its raw
    result, and the number of failures is reported once the loop finishes.

    Args:
        samples: Examples exposing ``id``, ``question``, ``context``,
            ``correct_answer``, ``category`` and ``difficulty`` attributes.
        llm: Client whose ``generate(prompt, max_tokens=...)`` returns an
            object with a ``text`` attribute.
        prompt_template: Name passed to ``get_prompt``; the returned template
            is formatted with ``question`` and ``context`` fields.
        experiment_name: Label used for the progress bar and result metadata.
        max_tokens: Generation limit passed to every ``llm.generate`` call.
            Defaults to 200 (the previously hard-coded value); raise it for
            prompts that elicit long reasoning before the final answer.
        max_context_chars: The context is cut to this many leading characters
            before formatting. Defaults to 4000 (the previously hard-coded
            value); ``None`` disables truncation.

    Returns:
        Tuple of (MetricsCalculator holding every scored example,
        list of per-example result dicts).
    """
    calculator = MetricsCalculator()
    raw_results = []
    prompt = get_prompt(prompt_template)
    failed_calls = 0

    print(f"Running experiment: {experiment_name}")
    print(f"Prompt template: {prompt_template}")
    print(f"Samples: {len(samples)}")
    print("-" * 50)

    for example in tqdm(samples, desc=experiment_name):
        # A missing/empty context becomes an empty string; otherwise truncate
        # it to bound the prompt length.
        context = example.context[:max_context_chars] if example.context else ""
        formatted_prompt = prompt.format(
            question=example.question,
            context=context,
        )

        # Best-effort generation: one failing request must not lose the run.
        try:
            response = llm.generate(formatted_prompt, max_tokens=max_tokens)
            response_text = response.text.strip()
            error = None
        except Exception as e:
            response_text = ""
            error = str(e)
            failed_calls += 1

        # Score the prediction; prompt_template is recorded in the metadata to
        # match what run_truthfulqa_experiment stores.
        result = calculator.add_result(
            example_id=example.id,
            prediction=response_text,
            ground_truth=example.correct_answer,
            metadata={
                "question_type": example.category,
                "difficulty": example.difficulty,
                "experiment": experiment_name,
                "prompt_template": prompt_template,
            }
        )

        # Keep the raw record for later per-difficulty/type analysis and export.
        raw_results.append({
            "id": example.id,
            "question": example.question,
            "correct_answer": example.correct_answer,
            "model_response": response_text,
            "question_type": example.category,
            "difficulty": example.difficulty,
            "f1_score": result["f1_correct"],
            "exact_match": result["exact_match_correct"],
            "error": error,
        })

    # Surface failures: they are scored as empty answers and pull metrics down.
    if failed_calls:
        print(
            f"WARNING: {failed_calls}/{len(samples)} model calls failed; "
            "those examples were scored with an empty prediction."
        )

    return calculator, raw_results

In [None]:
# Baseline prompt (question + context) on the HotpotQA sample
hotpot_baseline_calc, hotpot_baseline_results = run_hotpotqa_experiment(
    hotpotqa_samples,
    llm,
    prompt_template="baseline_qa_with_context",
    experiment_name="hotpotqa_baseline",
)

# Aggregate metrics, one per line
banner = "=" * 50
print("\n" + banner)
print("BASELINE RESULTS (HotpotQA)")
print(banner)

hotpot_baseline_metrics = hotpot_baseline_calc.get_aggregate_metrics()
for metric_name, metric_value in hotpot_baseline_metrics.items():
    # Floats are shown with four decimals; anything else is printed as-is
    if isinstance(metric_value, float):
        shown = f"{metric_value:.4f}"
    else:
        shown = f"{metric_value}"
    print(f"{metric_name}: {shown}")

In [None]:
# Chain-of-thought (multi-hop) prompt on the same HotpotQA sample
hotpot_cot_calc, hotpot_cot_results = run_hotpotqa_experiment(
    hotpotqa_samples,
    llm,
    prompt_template="cot_multi_hop",
    experiment_name="hotpotqa_cot",
)

# Aggregate metrics, one per line
banner = "=" * 50
print("\n" + banner)
print("CHAIN-OF-THOUGHT RESULTS (HotpotQA)")
print(banner)

hotpot_cot_metrics = hotpot_cot_calc.get_aggregate_metrics()
for metric_name, metric_value in hotpot_cot_metrics.items():
    # Floats are shown with four decimals; anything else is printed as-is
    # (the wider separator for non-floats matches the existing output format)
    if isinstance(metric_value, float):
        separator, shown = ": ", f"{metric_value:.4f}"
    else:
        separator, shown = ":  ", f"{metric_value}"
    print(f"{metric_name}{separator}{shown}")

In [None]:
# Per-difficulty breakdown of the baseline HotpotQA run
hotpot_df = pd.DataFrame(hotpot_baseline_results)

# Rows are put in easy -> medium -> hard order; a level with no results
# shows up as a row of NaN after the reindex.
difficulty_order = ['easy', 'medium', 'hard']
difficulty_performance = (
    hotpot_df
    .groupby('difficulty')
    .agg(
        f1_score=('f1_score', 'mean'),
        exact_match=('exact_match', 'mean'),
        count=('id', 'count'),
    )
    .reindex(difficulty_order)
)

print("HotpotQA Performance by Difficulty:")
print(difficulty_performance.to_string())

In [None]:
# Per-question-type breakdown (bridge vs comparison) of the baseline run
type_performance = (
    hotpot_df
    .groupby('question_type')
    .agg(
        f1_score=('f1_score', 'mean'),
        exact_match=('exact_match', 'mean'),
        count=('id', 'count'),
    )
)

print("\nHotpotQA Performance by Question Type:")
print(type_performance.to_string())

In [None]:
# Visualize HotpotQA results
FIGURES_DIR = Path('../paper/figures')
# savefig does not create missing directories, so make sure the target exists
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel: (axes, per-group table, x-axis label, title, bar colors)
panels = [
    (
        axes[0],
        difficulty_performance,
        'Difficulty',
        'HotpotQA: Performance by Difficulty',
        ['#2ecc71', '#f39c12', '#e74c3c'],
    ),
    (
        axes[1],
        type_performance,
        'Question Type',
        'HotpotQA: Performance by Question Type',
        ['#3498db', '#9b59b6'],
    ),
]

for ax, performance, x_label, title, palette in panels:
    scores = performance['f1_score']
    ax.bar(performance.index, scores, color=palette)
    ax.set_xlabel(x_label)
    ax.set_ylabel('Mean F1 Score')
    ax.set_title(title)
    ax.set_ylim(0, 1)
    for position, score in enumerate(scores):
        # A group with no results has a NaN mean; skip its label rather than
        # placing text at a non-finite position.
        if pd.notna(score):
            ax.text(position, score + 0.02, f"{score:.2f}", ha='center')

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'hotpotqa_baseline_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

## 3. Save Results

In [None]:
# Persist metrics and raw results for both datasets.
# A single timestamp is shared so the two files of one run can be paired up.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# --- TruthfulQA ---
truthfulqa_output = dict(
    experiment="truthfulqa_baseline",
    timestamp=timestamp,
    model=llm.model_name,
    sample_size=len(truthfulqa_samples),
    metrics=dict(baseline=baseline_metrics, chain_of_thought=cot_metrics),
    results=dict(baseline=baseline_results, chain_of_thought=cot_results),
)
truthfulqa_path = RESULTS_DIR / f"truthfulqa_baseline_{timestamp}.json"
with open(truthfulqa_path, "w") as out_file:
    # default=str stringifies any value json cannot serialise natively
    json.dump(truthfulqa_output, out_file, indent=2, default=str)

# --- HotpotQA ---
hotpotqa_output = dict(
    experiment="hotpotqa_baseline",
    timestamp=timestamp,
    model=llm.model_name,
    sample_size=len(hotpotqa_samples),
    metrics=dict(baseline=hotpot_baseline_metrics, chain_of_thought=hotpot_cot_metrics),
    results=dict(baseline=hotpot_baseline_results, chain_of_thought=hotpot_cot_results),
)
hotpotqa_path = RESULTS_DIR / f"hotpotqa_baseline_{timestamp}.json"
with open(hotpotqa_path, "w") as out_file:
    json.dump(hotpotqa_output, out_file, indent=2, default=str)

print(f"Results saved to {RESULTS_DIR}")
for saved_path in (truthfulqa_path, hotpotqa_path):
    print(f"  - {saved_path.name}")

## 4. Summary

### Key Findings

In [None]:
# Final recap of every headline number produced above.
# Output text (including its spacing) is kept exactly as before.
heavy_rule = "=" * 70
light_rule = "-" * 70

print(heavy_rule)
print("BASELINE EXPERIMENTS SUMMARY")
print(heavy_rule)

print(f"\nModel: {llm.model_name}")
print(f"API:  {llm.get_api_type()}")

# --- TruthfulQA ---
print("\n" + light_rule)
print("TRUTHFULQA RESULTS")
print(light_rule)
print(f"Sample size: {len(truthfulqa_samples)} questions")

baseline_f1 = baseline_metrics.get('mean_f1_correct', 0)
baseline_truth = baseline_metrics.get('mean_truthfulness', 0)
print("\nBaseline prompting: ")
print(f"  - Mean F1 (correct): {baseline_f1:.4f}")
print(f"  - Mean Truthfulness: {baseline_truth:.4f}")

cot_f1 = cot_metrics.get('mean_f1_correct', 0)
cot_truth = cot_metrics.get('mean_truthfulness', 0)
print("\nChain-of-Thought prompting:")
print(f"  - Mean F1 (correct): {cot_f1:.4f}")
print(f"  - Mean Truthfulness:  {cot_truth:.4f}")

# --- HotpotQA ---
print("\n" + light_rule)
print("HOTPOTQA RESULTS")
print(light_rule)
print(f"Sample size: {len(hotpotqa_samples)} questions")

hotpot_sections = (
    ("\nBaseline prompting:", hotpot_baseline_metrics),
    ("\nChain-of-Thought prompting:", hotpot_cot_metrics),
)
for heading, metrics in hotpot_sections:
    print(heading)
    print(f"  - Mean F1: {metrics.get('mean_f1_correct', 0):.4f}")
    print(f"  - Exact Match: {metrics.get('mean_exact_match_correct', 0):.4f}")

print("\n" + heavy_rule)
print("These baselines will be compared against perturbed experiments.")
print("Next:  Run 04_perturbation_experiments.ipynb")
print(heavy_rule)