# Multi-Hop Question Answering with stackelberg-opt

This notebook demonstrates how to build and optimize a multi-hop question answering system using Stackelberg game theory.

## Overview

Multi-hop QA requires multiple reasoning steps to answer complex questions. We'll create a system where:
- **Leader module**: Generates initial search queries
- **Follower modules**: Retrieve context and generate follow-up queries
- **Independent module**: Synthesizes the final answer
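
The leader/follower split borrows from Stackelberg games, where a leader commits to a move first and the follower responds after observing it. As a minimal sketch that does not use stackelberg-opt at all (the payoff tables and function names are made up for illustration), backward induction on a 2x2 game could look like this:

```python
# Rows are leader moves, columns are follower moves (illustrative payoffs only).
leader_payoff = [[3, 1], [4, 0]]
follower_payoff = [[1, 2], [3, 1]]


def follower_best_response(row: int) -> int:
    """Column that maximizes the follower's payoff once the leader has played `row`."""
    payoffs = follower_payoff[row]
    return max(range(len(payoffs)), key=lambda col: payoffs[col])


def stackelberg_solution() -> tuple:
    """Leader picks the row with the best payoff, anticipating the follower's reply."""
    best_row = max(
        range(len(leader_payoff)),
        key=lambda row: leader_payoff[row][follower_best_response(row)],
    )
    return best_row, follower_best_response(best_row)


print(stackelberg_solution())  # (1, 0)
```

In this toy game the leader commits to row 1 because the follower's best reply to it (column 0) gives the leader a payoff of 4, compared with 1 for row 0.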

In [None]:
# Import required libraries
from stackelberg_opt import (
    Module, ModuleType, SystemCandidate, ExecutionTrace,
    StackelbergOptimizer, OptimizerConfig
)
from stackelberg_opt.components import (
    StackelbergFeedbackExtractor,
    DependencyAnalyzer
)
from stackelberg_opt.utils import OptimizationVisualizer

import asyncio
import json
import numpy as np
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

print("Libraries imported successfully!")

## 1. Define the Multi-Hop QA System

We'll create a system with strategic interactions between modules:

In [None]:
def create_multi_hop_qa_modules() -> Dict[str, Module]:
    """Create modules for multi-hop question answering."""
    
    modules = {
        # Leader: Initial query generation
        "query_generator": Module(
            name="query_generator",
            prompt="""Given a complex question that requires multiple steps to answer,
generate the FIRST search query to find relevant information.

Question: {question}

Consider what foundational information is needed first.
Output only the search query, nothing else.

Search query:""",
            module_type=ModuleType.LEADER,
            dependencies=[]
        ),
        
        # Follower 1: Context retrieval
        "context_retriever": Module(
            name="context_retriever",
            prompt="""Retrieve and summarize information for this search query.

Query: {query}
Original Question: {question}

Provide a focused summary of information that would be found for this query.
Focus on facts relevant to answering the original question.

Retrieved information:""",
            module_type=ModuleType.FOLLOWER,
            dependencies=["query_generator"]
        ),
        
        # Follower 2: Follow-up query generation
        "followup_generator": Module(
            name="followup_generator",
            prompt="""Based on the initial information, determine if more data is needed.

Question: {question}
Initial Query: {initial_query}
Information Found: {context}

If the information is sufficient to answer the question, output: NONE
Otherwise, generate ONE follow-up search query for missing information.

Follow-up query:""",
            module_type=ModuleType.FOLLOWER,
            dependencies=["query_generator", "context_retriever"]
        ),
        
        # Follower 3: Additional context retrieval
        "followup_retriever": Module(
            name="followup_retriever",
            prompt="""Retrieve information for the follow-up query.

Follow-up Query: {followup_query}
Previous Context: {previous_context}
Original Question: {question}

If the query is 'NONE', output: 'No additional information needed.'
Otherwise, provide relevant information for the follow-up query.

Additional information:""",
            module_type=ModuleType.FOLLOWER,
            dependencies=["followup_generator", "context_retriever"]
        ),
        
        # Independent: Answer synthesis
        "answer_synthesizer": Module(
            name="answer_synthesizer",
            prompt="""Synthesize a comprehensive answer from all gathered information.

Question: {question}
All Information:
{all_information}

Provide a clear, accurate, and complete answer that addresses all aspects of the question.
Be concise but thorough.

Answer:""",
            module_type=ModuleType.INDEPENDENT,
            dependencies=["context_retriever", "followup_retriever"]
        )
    }
    
    return modules

# Create the modules
qa_modules = create_multi_hop_qa_modules()

print(f"Created {len(qa_modules)} modules for multi-hop QA:")
for name, module in qa_modules.items():
    deps = f" (depends on: {', '.join(module.dependencies)})" if module.dependencies else ""
    print(f"  - {name} ({module.module_type.value}){deps}")
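
Each prompt above declares its inputs as `{named}` placeholders. Assuming they are filled with `str.format`-style substitution (the notebook does not show how stackelberg-opt renders prompts, so treat this as an assumption), a small helper can list the fields a template expects before you wire modules together. `placeholder_names` is a hypothetical name, not a library function:

```python
from string import Formatter
from typing import List


def placeholder_names(template: str) -> List[str]:
    """Return the named {fields} of a str.format-style template, in order, without duplicates."""
    names: List[str] = []
    for _, field_name, _, _ in Formatter().parse(template):
        if field_name and field_name not in names:
            names.append(field_name)
    return names


# Shortened version of the context_retriever prompt
retriever_prompt = "Query: {query}\nOriginal Question: {question}\nRetrieved information:"

fields = placeholder_names(retriever_prompt)
filled = retriever_prompt.format(
    query="printing press history",
    question="What impact did the printing press have on the Reformation?",
)
print(fields)
print(filled)
```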

## 2. Analyze Module Dependencies

Let's visualize the dependency structure of our system:

In [None]:
# Analyze dependencies
analyzer = DependencyAnalyzer()
dep_analysis = analyzer.analyze_dependencies(qa_modules)

print("Dependency Analysis:")
print("=" * 50)
print(f"Is DAG (Directed Acyclic Graph): {dep_analysis['properties']['is_dag']}")
print(f"Has cycles: {dep_analysis['properties']['has_cycles']}")
print(f"Max depth: {dep_analysis['properties']['max_depth']}")

print("\nTopological order:")
for i, module in enumerate(dep_analysis['properties']['topological_order']):
    print(f"  {i+1}. {module}")

# Visualize with a simple text diagram
# (modules with several dependencies appear once under each of them)
print("\nDependency Graph:")
print("query_generator (LEADER)")
print("    ├── context_retriever (FOLLOWER)")
print("    │   ├── followup_generator (FOLLOWER)")
print("    │   │   └── followup_retriever (FOLLOWER)")
print("    │   │       └── answer_synthesizer (INDEPENDENT)")
print("    │   ├── followup_retriever (FOLLOWER)")
print("    │   └── answer_synthesizer (INDEPENDENT)")
print("    └── followup_generator (FOLLOWER)")
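
To cross-check the analyzer's output, the same ordering can be derived with the standard library alone. This sketch uses `graphlib.TopologicalSorter` (Python 3.9+) on a plain dict that mirrors the `dependencies` lists defined in Section 1. The depth here counts edges on the longest dependency chain, which may differ from how `DependencyAnalyzer` defines `max_depth`:

```python
from graphlib import TopologicalSorter  # Python 3.9+

# Module name -> modules it depends on (copied from Section 1)
dependencies = {
    "query_generator": [],
    "context_retriever": ["query_generator"],
    "followup_generator": ["query_generator", "context_retriever"],
    "followup_retriever": ["followup_generator", "context_retriever"],
    "answer_synthesizer": ["context_retriever", "followup_retriever"],
}

# static_order() raises graphlib.CycleError if the graph is not a DAG
order = list(TopologicalSorter(dependencies).static_order())

# Longest chain of dependencies leading to each module, counted in edges
depth = {}
for name in order:
    depth[name] = 1 + max((depth[dep] for dep in dependencies[name]), default=-1)

for name in order:
    print(f"{name}: depth {depth[name]}")
```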

## 3. Create Training Data

Let's create diverse multi-hop questions for training:

In [None]:
# Create training data with multi-hop questions
train_data = [
    (
        "What impact did the invention of the printing press have on the Protestant Reformation?",
        "The printing press, invented by Gutenberg around 1440, dramatically accelerated the Protestant Reformation by enabling mass production of Luther's 95 Theses and Bible translations, spreading reformist ideas rapidly across Europe and undermining the Catholic Church's monopoly on religious texts."
    ),
    (
        "How do coral reefs protect coastal communities from climate change effects?",
        "Coral reefs act as natural barriers that reduce wave energy by up to 97%, protecting coastal communities from storm surges and erosion intensified by climate change, while also supporting fish populations that provide food security for millions of people."
    ),
    (
        "What role did the Silk Road play in the spread of the Black Death to Europe?",
        "The Silk Road served as the primary transmission route for the Black Death from Central Asia to Europe in the 1340s, as infected fleas on rats traveled with merchant caravans, reaching European ports through trade ships and causing the pandemic that killed one-third of Europe's population."
    ),
    (
        "How does quantum entanglement enable quantum computing to outperform classical computers?",
        "Entanglement, together with superposition, lets a quantum computer hold correlated multi-qubit states that generally cannot be described efficiently with classical bits. Quantum algorithms such as Shor's use interference among these states to solve specific problems, for example integer factoring, faster than the best known classical methods; the advantage applies to particular problem classes rather than to all computation."
    ),
    (
        "What factors led to the fall of the Mayan civilization and are there parallels today?",
        "The Mayan civilization collapsed due to a combination of severe drought, deforestation, overpopulation, and political instability around 900 CE, paralleling modern concerns about climate change, resource depletion, and social inequality that threaten contemporary societies."
    )
]

print(f"Created {len(train_data)} multi-hop training examples\n")
print("Sample questions:")
for i, (question, _) in enumerate(train_data[:3]):
    print(f"{i+1}. {question}")
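
Each training example pairs a question with a reference answer. The notebook does not show how stackelberg-opt compares a generated answer with its reference, so the following is only one common option for QA tasks: a token-overlap F1 score. `token_f1` is a hypothetical helper, not part of the library:

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall, using whitespace tokens."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most min(count) times
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(token_f1("coral reefs reduce wave energy", "coral reefs reduce storm damage"))
```

A metric like this rewards lexical overlap only, so long free-form answers such as the ones above may call for a semantic or model-based judge instead.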

## 4. Implement Task Executor

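Create a task executor that simulates the multi-hop QA pipeline. It returns placeholder text for each module and assigns randomly generated scores, so the notebook runs without calling an LLM: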

In [None]:
# Imported at module level so the helper methods below can use them too
import random
import time


class MultiHopQAExecutor:
    """Simulated executor for multi-hop question answering."""
    
    def __init__(self):
        self.execution_count = 0
        
    async def __call__(self, modules: Dict[str, Module], question: str) -> Tuple[str, ExecutionTrace]:
        """Execute the multi-hop QA pipeline."""
        
        trace = ExecutionTrace()
        trace.execution_order = []
        trace.module_outputs = {}
        trace.module_timings = {}
        trace.intermediate_scores = {}
        
        try:
            # Step 1: Generate initial query (Leader)
            start_time = time.time()
            initial_query = self._generate_query(question)
            trace.execution_order.append("query_generator")
            trace.module_outputs["query_generator"] = initial_query
            trace.module_timings["query_generator"] = time.time() - start_time
            trace.intermediate_scores["query_generator"] = 0.7 + random.random() * 0.3
            
            # Step 2: Retrieve initial context (Follower)
            start_time = time.time()
            context = self._retrieve_context(initial_query, question)
            trace.execution_order.append("context_retriever")
            trace.module_outputs["context_retriever"] = context
            trace.module_timings["context_retriever"] = time.time() - start_time
            trace.intermediate_scores["context_retriever"] = 0.6 + random.random() * 0.3
            
            # Step 3: Generate follow-up query (Follower)
            start_time = time.time()
            followup = self._generate_followup(question, initial_query, context)
            trace.execution_order.append("followup_generator")
            trace.module_outputs["followup_generator"] = followup
            trace.module_timings["followup_generator"] = time.time() - start_time
            trace.intermediate_scores["followup_generator"] = 0.65 + random.random() * 0.3
            
            # Step 4: Retrieve follow-up context (Follower)
            start_time = time.time()
            followup_context = self._retrieve_followup(followup, context, question)
            trace.execution_order.append("followup_retriever")
            trace.module_outputs["followup_retriever"] = followup_context
            trace.module_timings["followup_retriever"] = time.time() - start_time
            trace.intermediate_scores["followup_retriever"] = 0.6 + random.random() * 0.3
            
            # Step 5: Synthesize answer (Independent)
            start_time = time.time()
            all_info = f"Initial search: {initial_query}\n{context}\n\nFollow-up: {followup}\n{followup_context}"
            answer = self._synthesize_answer(question, all_info)
            trace.execution_order.append("answer_synthesizer")
            trace.module_outputs["answer_synthesizer"] = answer
            trace.module_timings["answer_synthesizer"] = time.time() - start_time
            trace.intermediate_scores["answer_synthesizer"] = 0.75 + random.random() * 0.25
            
            trace.success = True
            trace.final_score = np.mean(list(trace.intermediate_scores.values()))
            
            self.execution_count += 1
            return answer, trace
            
        except Exception as e:
            trace.success = False
            trace.error = str(e)
            trace.final_score = 0.0
            return "Error in execution", trace
    
    def _generate_query(self, question: str) -> str:
        """Simulate initial query generation."""
        # Extract key terms from question
        key_terms = [word for word in question.split() if len(word) > 4][:3]
        return f"{' '.join(key_terms)} definition history"
    
    def _retrieve_context(self, query: str, question: str) -> str:
        """Simulate context retrieval."""
        return f"Information about {query}: Historical background and key concepts relevant to the question."
    
    def _generate_followup(self, question: str, initial_query: str, context: str) -> str:
        """Simulate follow-up query generation."""
        if random.random() > 0.3:
            return f"specific impacts and connections related to {initial_query}"
        return "NONE"
    
    def _retrieve_followup(self, followup: str, context: str, question: str) -> str:
        """Simulate follow-up retrieval."""
        if followup == "NONE":
            return "No additional information needed."
        return f"Additional details about {followup}: Specific examples and evidence."
    
    def _synthesize_answer(self, question: str, all_info: str) -> str:
        """Simulate answer synthesis."""
        return f"Based on the gathered information, here is a comprehensive answer to '{question[:50]}...'"

# Create executor instance
qa_executor = MultiHopQAExecutor()
print("Multi-hop QA executor created")
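
The simulated executor never reads the prompts in `modules`. As a rough sketch of what a real executor would do instead, the snippet below fills each template with upstream outputs and passes it to a model call. Everything here is illustrative: the templates are shortened, `run_pipeline` and `first_line_llm` are hypothetical names, and the stub stands in for an actual LLM client:

```python
from typing import Callable, Dict

# Shortened, illustrative templates (three of the five modules)
TEMPLATES = {
    "query_generator": "Question: {question}\nSearch query:",
    "context_retriever": "Query: {query}\nOriginal Question: {question}\nRetrieved information:",
    "answer_synthesizer": "Question: {question}\nAll Information:\n{all_information}\nAnswer:",
}


def run_pipeline(question: str, llm: Callable[[str], str]) -> Dict[str, str]:
    """Run modules in dependency order, feeding each output into later prompts."""
    outputs: Dict[str, str] = {}
    outputs["query_generator"] = llm(
        TEMPLATES["query_generator"].format(question=question)
    )
    outputs["context_retriever"] = llm(
        TEMPLATES["context_retriever"].format(
            query=outputs["query_generator"], question=question
        )
    )
    outputs["answer_synthesizer"] = llm(
        TEMPLATES["answer_synthesizer"].format(
            question=question, all_information=outputs["context_retriever"]
        )
    )
    return outputs


def first_line_llm(prompt: str) -> str:
    """Stub model: echoes the first line of the prompt in brackets."""
    return f"[{prompt.splitlines()[0]}]"


results = run_pipeline("Q", first_line_llm)
for name, text in results.items():
    print(f"{name}: {text}")
```

Swapping `first_line_llm` for a function that calls a real model keeps the rest of the flow unchanged.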

## 5. Run Optimization

Now let's optimize the multi-hop QA system:

In [None]:
# Configure optimizer for multi-hop QA
qa_config = OptimizerConfig(
    budget=30,  # More budget for complex system
    population_size=8,
    mutation_rate=0.75,
    crossover_rate=0.25,
    performance_weight=0.4,
    equilibrium_weight=0.35,  # Important for leader-follower dynamics
    stability_weight=0.25,
    enable_caching=True,
    enable_visualization=True,
    verbose=True
)

print("Optimizer configuration for multi-hop QA:")
print(json.dumps({
    'budget': qa_config.budget,
    'population_size': qa_config.population_size,
    'weights': {
        'performance': qa_config.performance_weight,
        'equilibrium': qa_config.equilibrium_weight,
        'stability': qa_config.stability_weight
    }
}, indent=2))
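
The three weights above sum to 1.0. If the optimizer combines the objectives linearly (an assumption; the notebook does not show the actual formula), a candidate's overall fitness would be computed as in this sketch. `combined_fitness` is a hypothetical helper:

```python
def combined_fitness(
    performance: float,
    equilibrium: float,
    stability: float,
    weights: tuple = (0.4, 0.35, 0.25),
) -> float:
    """Weighted sum of the three objectives (assumed linear combination)."""
    w_perf, w_eq, w_stab = weights
    return w_perf * performance + w_eq * equilibrium + w_stab * stability


# 0.4 * 0.8 + 0.35 * 0.6 + 0.25 * 0.7 = 0.32 + 0.21 + 0.175 = 0.705
score = combined_fitness(performance=0.8, equilibrium=0.6, stability=0.7)
print(f"{score:.3f}")
```

Under this reading, raising `equilibrium_weight` shifts selection toward candidates whose leader and follower prompts work well together, at the expense of raw task score.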

In [None]:
# Create and run optimizer
qa_optimizer = StackelbergOptimizer(
    system_modules=qa_modules,
    train_data=train_data,
    task_executor=qa_executor,
    config=qa_config
)

print("Starting multi-hop QA optimization...")
print("=" * 60)

# Run optimization
best_qa_candidate = await qa_optimizer.optimize_async()

print("\n" + "=" * 60)
print("Optimization Complete!")
print("=" * 60)
print(f"Best candidate ID: {best_qa_candidate.candidate_id}")
print(f"Generation: {best_qa_candidate.generation}")
print(f"Average score: {best_qa_candidate.get_average_score():.3f}")
print(f"Equilibrium value: {best_qa_candidate.equilibrium_value:.3f}")
print(f"Stability score: {best_qa_candidate.stability_score:.3f}")

## 6. Analyze Optimization Results

Let's examine how the optimization improved the system:

In [None]:
# Compare original vs optimized prompts
print("Prompt Evolution Analysis")
print("=" * 60)

for module_name in qa_modules.keys():
    original_prompt = qa_modules[module_name].prompt
    optimized_prompt = best_qa_candidate.modules[module_name].prompt
    
    if original_prompt != optimized_prompt:
        print(f"\n{module_name}:")
        print(f"  Changed: YES")
        print(f"  Original length: {len(original_prompt)} chars")
        print(f"  Optimized length: {len(optimized_prompt)} chars")
        print(f"  Length change: {len(optimized_prompt) - len(original_prompt):+d} chars")
    else:
        print(f"\n{module_name}: No changes")
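
Length changes alone do not show what the optimizer rewrote. A line-level diff from the standard library's `difflib` makes the edits visible. The two prompts below are invented examples; in the notebook you would pass `original_prompt` and `optimized_prompt` instead:

```python
import difflib


def prompt_diff(original: str, optimized: str) -> str:
    """Unified diff of two prompts; empty string when they are identical."""
    diff = difflib.unified_diff(
        original.splitlines(),
        optimized.splitlines(),
        fromfile="original",
        tofile="optimized",
        lineterm="",
    )
    return "\n".join(diff)


before = "Generate the FIRST search query.\nOutput only the search query."
after = "Generate the FIRST search query.\nOutput only the search query, under 12 words."
print(prompt_diff(before, after))
```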

In [None]:
# Analyze module performance
if best_qa_candidate.traces:
    print("Module Performance Analysis")
    print("=" * 60)
    
    module_scores = {module: [] for module in qa_modules.keys()}
    
    for trace in best_qa_candidate.traces.values():
        for module, score in trace.intermediate_scores.items():
            module_scores[module].append(score)
    
    for module, scores in module_scores.items():
        if scores:
            avg_score = np.mean(scores)
            std_score = np.std(scores)
            module_type = qa_modules[module].module_type.value
            print(f"\n{module} ({module_type}):")
            print(f"  Average score: {avg_score:.3f}")
            print(f"  Std deviation: {std_score:.3f}")
            print(f"  Score range: [{min(scores):.3f}, {max(scores):.3f}]")

In [None]:
# Visualize optimization progress
visualizer = OptimizationVisualizer()

if hasattr(qa_optimizer, 'population_manager') and qa_optimizer.population_manager.generation_stats:
    # Create custom visualization
    stats = qa_optimizer.population_manager.generation_stats
    generations = sorted(stats.keys())
    
    avg_fitness = [stats[g].get('avg_fitness', 0) for g in generations]
    best_fitness = [stats[g].get('best_fitness', 0) for g in generations]
    diversity = [stats[g].get('diversity', 0) for g in generations]
    
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
    
    # Plot fitness
    ax1.plot(generations, avg_fitness, 'b-', label='Average Fitness', linewidth=2)
    ax1.plot(generations, best_fitness, 'r-', label='Best Fitness', linewidth=2)
    ax1.set_xlabel('Generation')
    ax1.set_ylabel('Fitness Score')
    ax1.set_title('Multi-Hop QA Optimization Progress')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot diversity
    ax2.plot(generations, diversity, 'g-', label='Population Diversity', linewidth=2)
    ax2.set_xlabel('Generation')
    ax2.set_ylabel('Diversity Score')
    ax2.set_title('Population Diversity Over Time')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("No generation statistics available for visualization")

## 7. Test the Optimized System

Let's test the optimized multi-hop QA system with new questions:

In [None]:
# Test questions
test_questions = [
    "How did the discovery of DNA structure impact modern medicine?",
    "What role does the Amazon rainforest play in global climate regulation?",
    "How did the Industrial Revolution change social structures in Europe?"
]

print("Testing Optimized Multi-Hop QA System")
print("=" * 60)

for i, question in enumerate(test_questions):
    print(f"\nQuestion {i+1}: {question}")
    print("-" * 60)
    
    # Execute with optimized system
    answer, trace = await qa_executor(best_qa_candidate.modules, question)
    
    print(f"\nExecution trace:")
    for j, module in enumerate(trace.execution_order):
        output = trace.module_outputs.get(module, "N/A")
        if len(output) > 100:
            # Only mark the output as truncated when it actually was
            output = output[:100] + "..."
        score = trace.intermediate_scores.get(module, 0)
        print(f"  {j+1}. {module}: {score:.3f}")
        print(f"     Output: {output}")
    
    print(f"\nFinal Score: {trace.final_score:.3f}")
    print(f"Success: {trace.success}")
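
When reading these scores, keep in mind that the simulated executor draws each module score uniformly from a fixed range and averages them. The sketch below copies those ranges from `MultiHopQAExecutor` and compares the analytical expectation of the final score with a seeded simulation:

```python
import random
from statistics import mean

# (base, width) pairs: each score is base + random.random() * width
SCORE_RANGES = {
    "query_generator": (0.7, 0.3),
    "context_retriever": (0.6, 0.3),
    "followup_generator": (0.65, 0.3),
    "followup_retriever": (0.6, 0.3),
    "answer_synthesizer": (0.75, 0.25),
}

# Mean of a uniform draw on [base, base + width) is base + width / 2
expected_final = mean(base + width / 2 for base, width in SCORE_RANGES.values())

rng = random.Random(0)  # seeded for reproducibility


def simulated_final_score() -> float:
    """One simulated run: average of the five per-module scores."""
    return mean(base + rng.random() * width for base, width in SCORE_RANGES.values())


empirical = mean(simulated_final_score() for _ in range(20_000))
print(f"expected: {expected_final:.3f}, simulated: {empirical:.3f}")
```

The expected final score is 0.805 regardless of the prompts, so score differences between candidates in this notebook reflect sampling noise. A real LLM-backed executor is needed for the scores to measure prompt quality.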

In [None]:
# Extract feedback for further improvement
feedback_extractor = StackelbergFeedbackExtractor()

print("Module-Specific Feedback Analysis")
print("=" * 60)

for module_name in qa_modules.keys():
    if best_qa_candidate.traces:
        feedback = feedback_extractor.extract_feedback(
            module_name,
            list(best_qa_candidate.traces.values()),
            best_qa_candidate.modules[module_name]
        )
        
        print(f"\n{module_name}:")
        print(f"  Average score: {feedback['avg_score']:.3f}")
        print(f"  Success rate: {feedback['success_rate']:.1%}")
        print(f"  Stability: {feedback['stability']:.3f}")
        
        if feedback['failure_patterns']:
            print(f"  Failure patterns: {', '.join(feedback['failure_patterns'][:2])}")
        if feedback['success_patterns']:
            print(f"  Success patterns: {', '.join(feedback['success_patterns'][:2])}")

## 8. Save and Export Results

Save the optimized system for future use:

In [None]:
# Export optimized prompts
optimized_prompts = {
    module_name: {
        'prompt': module.prompt,
        'type': module.module_type.value,
        'dependencies': module.dependencies
    }
    for module_name, module in best_qa_candidate.modules.items()
}

# Save to file
import json
with open('optimized_multihop_qa_prompts.json', 'w') as f:
    json.dump(optimized_prompts, f, indent=2)

print("Optimized prompts saved to 'optimized_multihop_qa_prompts.json'")

# Display summary
print("\nOptimization Summary:")
print(f"  Successful executor calls (optimization and test runs): {qa_executor.execution_count}")
print(f"  Final average score: {best_qa_candidate.get_average_score():.3f}")
print(f"  Equilibrium value: {best_qa_candidate.equilibrium_value:.3f}")
print(f"  Stability score: {best_qa_candidate.stability_score:.3f}")
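
To reuse the exported file later, it helps to validate its structure on load. This sketch round-trips a small example through a temporary file; `load_prompts` is a hypothetical helper, and the `"type"` value is a placeholder rather than a confirmed `ModuleType` value:

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = {"prompt", "type", "dependencies"}


def load_prompts(path) -> dict:
    """Load exported prompts and check that every entry has the expected keys."""
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    for name, entry in data.items():
        missing = REQUIRED_KEYS - entry.keys()
        if missing:
            raise ValueError(f"{name} is missing keys: {sorted(missing)}")
    return data


# Placeholder entry with the same shape as the export above
example = {
    "query_generator": {
        "prompt": "Question: {question}\nSearch query:",
        "type": "leader",
        "dependencies": [],
    }
}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "prompts.json"
    path.write_text(json.dumps(example, indent=2), encoding="utf-8")
    loaded = load_prompts(path)

print(loaded["query_generator"]["dependencies"])
```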

## Summary and Next Steps

In this notebook, we've:

1. **Built** a five-module multi-hop QA system with leader-follower dynamics
2. **Analyzed** the dependency structure and strategic interactions
3. **Optimized** the system using Stackelberg game theory
4. **Evaluated** the improvements in performance and stability
5. **Tested** the optimized system on new questions

### Key Insights

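- The **leader module** (`query_generator`) runs first, so every other module depends on its output, directly or indirectly
- **Follower modules** condition on upstream outputs through their declared `dependencies`
- **Equilibrium optimization** receives a weight of 0.35 in this configuration to account for leader-follower interactions
- **Stability metrics** are meant to reward consistent performance across inputs; because the executor here is simulated with random scores, measuring these effects requires a real LLM-backed executor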

### Next Steps

1. **Integrate with real LLMs** (e.g., using LiteLLM)
2. **Add more sophisticated retrieval** mechanisms
3. **Experiment with different module configurations**
4. **Fine-tune optimization parameters** for your use case
5. **Extend to other multi-step reasoning tasks**

For more examples and documentation, visit the [stackelberg-opt repository](https://github.com/youraanshshah/stackelberg-opt).