# SCFP Framework - Interactive Demo

This notebook provides an interactive demonstration of the Self-Correction Failure Prediction (SCFP) framework.

## Overview

The SCFP framework predicts when LLM self-correction will fail and uses that prediction to route each case to a suitable correction strategy. This demo shows:

1. **Synthetic Data Generation**: Create realistic correction traces
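2. **Failure Prediction**: Predict correction outcomes (the full framework uses DeBERTa-v3; this demo evaluates baselines and routes on simulated predictions)
3. **Dynamic Routing**: Select a correction strategy based on predicted failure risk, cost, and context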
4. **Cost-Benefit Analysis**: Optimize accuracy vs cost trade-offs

In [None]:
# Setup
import sys
import os
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / "src"))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML

# SCFP imports
from scfp.data.dataset import CorrectionTrace, FailureMode, SCFPDataset
from scfp.data.synthetic import SyntheticDataGenerator, SyntheticConfig
from scfp.routing.router import DynamicRouter, RoutingStrategy
from scfp.routing.cost_model import CostModel
from scfp.models.baselines import BaselineModels

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("SCFP Framework Interactive Demo")
print("===============================")

## 1. Generate Synthetic Data

First, let's generate some synthetic correction traces to work with.

In [None]:
# Configure synthetic data generation
config = SyntheticConfig(
    total_samples=100,
    success_rate=0.6,
    failure_distribution={
        "jh": 0.25,  # Justification Hallucination
        "cm": 0.20,  # Confidence Miscalibration
        "ba": 0.20,  # Bias Amplification
        "oc": 0.20,  # Over-correction
        "rm": 0.15   # Reasoning Myopia
    },
    domains=["math", "science", "history", "logic"]
)

# Generate traces
generator = SyntheticDataGenerator(config=config, seed=42)
traces = generator.generate_traces()

print(f"Generated {len(traces)} correction traces")

# Show distribution
success_count = sum(1 for trace in traces if trace.is_success)
failure_count = len(traces) - success_count

print(f"Success: {success_count} ({success_count/len(traces)*100:.1f}%)")
print(f"Failure: {failure_count} ({failure_count/len(traces)*100:.1f}%)")
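The config combines a 60% success rate with a `failure_distribution` whose shares sum to 1.0, which suggests the failure shares are conditional on failure: about 40 failures per 100 traces, of which 25% (about 10) would be justification hallucination. The sketch below samples outcome labels under that reading. It is an assumption about how `SyntheticDataGenerator` allocates labels, not its actual code, and `sample_outcomes` is a hypothetical helper.

```python
import numpy as np

def sample_outcomes(total_samples, success_rate, failure_distribution, seed=42):
    """Hypothetical helper: draw one label per trace, 'success' or a failure-mode code."""
    rng = np.random.default_rng(seed)
    labels = ["success"] + list(failure_distribution)
    # Assumption: failure-mode shares are conditional on failure,
    # so each is scaled by the overall failure rate (1 - success_rate).
    probs = [success_rate] + [(1 - success_rate) * p for p in failure_distribution.values()]
    return rng.choice(labels, size=total_samples, p=probs).tolist()

failure_distribution = {"jh": 0.25, "cm": 0.20, "ba": 0.20, "oc": 0.20, "rm": 0.15}
outcomes = sample_outcomes(100, 0.6, failure_distribution)
print({label: outcomes.count(label) for label in ["success", *failure_distribution]})
```

With only 100 samples the counts will wander around the expected 60/10/8/8/8/6 split, which is worth keeping in mind when reading the pie charts below.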

In [None]:
# Visualize failure mode distribution
mode_counts = {}
for trace in traces:
    mode = trace.failure_mode.value
    mode_counts[mode] = mode_counts.get(mode, 0) + 1

# Create pie chart
plt.figure(figsize=(10, 6))

plt.subplot(1, 2, 1)
plt.pie(list(mode_counts.values()), labels=list(mode_counts.keys()), autopct='%1.1f%%')
plt.title('Outcome Distribution (Success and Failure Modes)')

# Domain distribution
domain_counts = {}
for trace in traces:
    domain = trace.metadata.get("domain", "unknown") if trace.metadata else "unknown"
    domain_counts[domain] = domain_counts.get(domain, 0) + 1

plt.subplot(1, 2, 2)
plt.pie(list(domain_counts.values()), labels=list(domain_counts.keys()), autopct='%1.1f%%')
plt.title('Domain Distribution')

plt.tight_layout()
plt.show()

## 2. Examine Sample Traces

Let's look at some example correction traces to understand the data.

In [None]:
# Display sample traces
from html import escape

def display_trace(trace, index):
    """Render a correction trace as an HTML card."""
    status = "✅ SUCCESS" if trace.is_success else "❌ FAILURE"
    mode = trace.failure_mode.value.upper()
    domain = trace.metadata.get("domain", "unknown") if trace.metadata else "unknown"
    
    # Escape model-generated text so characters like "<" or "&" don't break the markup
    card = f"""
    <div style="border: 1px solid #ddd; padding: 15px; margin: 10px 0; border-radius: 5px;">
        <h4>Trace {index + 1} - {status} ({mode}) - Domain: {escape(domain.title())}</h4>
        <p><strong>Prompt:</strong> {escape(trace.prompt)}</p>
        <p><strong>Initial Response:</strong> {escape(trace.initial_response)}</p>
        <p><strong>Critique:</strong> {escape(trace.critique)}</p>
        <p><strong>Final Response:</strong> {escape(trace.final_response)}</p>
    </div>
    """
    display(HTML(card))

# Show first 5 traces
print("Sample Correction Traces:")
print("========================")
for i in range(min(5, len(traces))):
    display_trace(traces[i], i)
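`display_trace` and the rest of the notebook touch only a handful of trace attributes. The stand-in below mirrors those attributes to make the data shape explicit. `TraceSketch` and `Mode` are hypothetical, and deriving `is_success` from "no failure mode attached" is an assumption, not the definition of `CorrectionTrace` in `scfp.data.dataset`.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class Mode(Enum):
    # Codes mirror the failure_distribution keys used in this notebook.
    SUCCESS = "success"
    JH = "jh"  # Justification Hallucination
    CM = "cm"  # Confidence Miscalibration
    BA = "ba"  # Bias Amplification
    OC = "oc"  # Over-correction
    RM = "rm"  # Reasoning Myopia

@dataclass
class TraceSketch:
    """Hypothetical stand-in for a correction trace."""
    prompt: str
    initial_response: str
    critique: str
    final_response: str
    failure_mode: Mode = Mode.SUCCESS
    metadata: Optional[dict] = None

    @property
    def is_success(self) -> bool:
        # Assumption: a trace succeeds exactly when no failure mode is attached.
        return self.failure_mode is Mode.SUCCESS

    @property
    def domain(self) -> str:
        # Same fallback the notebook uses when metadata is missing.
        return (self.metadata or {}).get("domain", "unknown")

t = TraceSketch("What is 15 * 24?", "350", "15 * 24 = 300 + 60 = 360", "360",
                Mode.SUCCESS, {"domain": "math"})
print(t.is_success, t.domain)  # True math
```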

## 3. Baseline Model Evaluation

Let's evaluate some baseline approaches on our synthetic data.

In [None]:
# Prepare data for baseline evaluation
# Map each outcome (success or failure-mode code) to a class index
mode_to_idx = {"success": 0, "jh": 1, "cm": 2, "ba": 3, "oc": 4, "rm": 5}

trace_texts = []
binary_labels = []       # 1 = success, 0 = failure
multiclass_labels = []

for trace in traces:
    # Model input: prompt, initial response, and critique (the final response is not included)
    text = f"Prompt: {trace.prompt} [SEP] Initial Response: {trace.initial_response} [SEP] Critique: {trace.critique}"
    trace_texts.append(text)
    binary_labels.append(1 if trace.is_success else 0)
    multiclass_labels.append(mode_to_idx[trace.failure_mode.value])

print(f"Prepared {len(trace_texts)} traces for evaluation")

In [None]:
# Evaluate baseline models
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

baselines = {
    "Random": BaselineModels.get_random_baseline(42),
    "Confidence": BaselineModels.get_confidence_heuristic(),
    "Length": BaselineModels.get_length_heuristic(),
    "GPT-4o (Sim)": BaselineModels.get_gpt4o_judge(42)
}

results = {}

for name, baseline in baselines.items():
    print(f"Evaluating {name}...")
    
    # Get failure probabilities; convert to an array so the thresholding below works for lists too
    binary_probs = np.asarray(baseline.predict_failure_probability(trace_texts))
    binary_preds = (binary_probs < 0.5).astype(int)  # failure prob < 0.5 => predict success (1)
    
    # Calculate metrics
    accuracy = accuracy_score(binary_labels, binary_preds)
    f1 = f1_score(binary_labels, binary_preds, average='macro')
    
    # For AUC, we need success probabilities
    success_probs = 1 - binary_probs
    auc = roc_auc_score(binary_labels, success_probs)
    
    results[name] = {
        "Accuracy": accuracy,
        "Macro F1": f1,
        "AUC-ROC": auc
    }
    
    print(f"  Accuracy: {accuracy:.3f}, F1: {f1:.3f}, AUC: {auc:.3f}")

print("\nBaseline evaluation complete!")
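The loop flips failure probabilities into success scores before calling `roc_auc_score` because the positive class (label 1) is success. AUC-ROC is the probability that a randomly chosen success is scored above a randomly chosen failure, with ties counted as half. The sketch below computes it straight from that definition with a hypothetical helper, `pairwise_auc`, on made-up scores, which makes the effect of the flip easy to see.

```python
import numpy as np

def pairwise_auc(labels, scores):
    """AUC-ROC as P(score of a random positive > score of a random negative)."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC is undefined when only one class is present")
    # Compare every positive with every negative; ties count as half a win.
    diff = pos[:, None] - neg[None, :]
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size

labels = [1, 1, 0, 0, 1]                        # 1 = success, 0 = failure
failure_probs = np.array([0.2, 0.6, 0.7, 0.4, 0.3])
success_scores = 1 - failure_probs
print(pairwise_auc(labels, success_scores))     # 5 of 6 pairs ordered correctly
```

Scoring with the raw failure probabilities instead gives 1/6 on the same data, the mirror image, which is why the baseline loop uses `1 - binary_probs`.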

In [None]:
# Visualize baseline results
results_df = pd.DataFrame(results).T

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
metrics = ["Accuracy", "Macro F1", "AUC-ROC"]

for i, metric in enumerate(metrics):
    ax = axes[i]
    bars = ax.bar(results_df.index, results_df[metric], alpha=0.7)
    ax.set_title(f'{metric}')
    ax.set_ylabel('Score')
    ax.tick_params(axis='x', rotation=45)
    
    # Add value labels
    for bar, value in zip(bars, results_df[metric]):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{value:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Display results table
display(results_df.round(4))

## 4. Dynamic Routing System

Now let's demonstrate the dynamic routing system, which selects a correction strategy for each trace based on predicted failure risk, cost, and context.
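The router's internals are not shown in this notebook. One plausible shape for the decision is sketched below: estimate each strategy's accuracy, keep the strategies that meet the accuracy requirement, and among those pick the best accuracy-minus-cost utility, with higher stakes shrinking the cost penalty. The strategy names match the notebook's color map, but the costs, accuracies, utility formula, and the `route_sketch` / `SketchDecision` names are all hypothetical.

```python
from dataclasses import dataclass

# Illustrative numbers only; not the values used by scfp.routing.
PROFILES = {
    "intrinsic": {"cost": 0.05, "accuracy": None},  # None: depends on predicted failure risk
    "external": {"cost": 0.30, "accuracy": 0.85},
    "human": {"cost": 1.00, "accuracy": 0.97},
}

@dataclass
class SketchDecision:
    strategy: str
    expected_accuracy: float
    cost_estimate: float

def route_sketch(failure_probability, stakes, accuracy_requirement):
    """Hypothetical routing rule: requirement first, then accuracy minus weighted cost."""
    candidates = []
    for name, profile in PROFILES.items():
        # Self-correction succeeds with probability 1 - p_fail; other strategies use fixed estimates.
        acc = 1.0 - failure_probability if profile["accuracy"] is None else profile["accuracy"]
        # Higher stakes shrink the cost penalty, so pricier but safer strategies can win.
        utility = acc - (1.0 - stakes) * profile["cost"]
        candidates.append((acc >= accuracy_requirement, utility, name, acc, profile["cost"]))
    # Tuples compare element-wise: strategies meeting the requirement rank first, then by utility.
    _, _, name, acc, cost = max(candidates)
    return SketchDecision(name, acc, cost)

print(route_sketch(0.1, stakes=0.3, accuracy_requirement=0.8).strategy)   # intrinsic
print(route_sketch(0.6, stakes=0.3, accuracy_requirement=0.8).strategy)   # external
print(route_sketch(0.6, stakes=0.95, accuracy_requirement=0.9).strategy)  # human
```

If no strategy meets the requirement, this rule falls back to the highest utility overall; a real router might instead escalate to human review in that case.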

In [None]:
# Initialize routing system
cost_model = CostModel()
router = DynamicRouter(
    failure_predictor=None,  # Using simulated predictions
    cost_model=cost_model
)

print("Dynamic Router initialized")
print("Available strategies:", [s.value for s in RoutingStrategy])

In [None]:
# Demonstrate routing on sample traces
sample_traces = traces[:10]  # Use first 10 traces
routing_decisions = []
rng = np.random.default_rng(42)  # seeded so the simulated contexts are reproducible

for i, trace in enumerate(sample_traces):
    # Build a context from the trace's domain plus randomly drawn urgency, stakes, and accuracy requirement
    domain = trace.metadata.get("domain", "general") if trace.metadata else "general"
    
    context = {
        "domain": domain,
        "urgency": rng.uniform(0.1, 0.9),
        "stakes": rng.uniform(0.2, 0.8),
        "accuracy_requirement": rng.uniform(0.5, 0.9)
    }
    
    # Make routing decision
    decision = router.route(
        prompt=trace.prompt,
        initial_response=trace.initial_response,
        critique=trace.critique,
        context=context
    )
    
    routing_decisions.append(decision)
    
    print(f"\nTrace {i+1} ({domain})")
    print(f"  Strategy: {decision.strategy.value.upper()}")
    print(f"  Failure Prob: {decision.failure_probability:.3f}")
    print(f"  Expected Accuracy: {decision.expected_accuracy:.3f}")
    print(f"  Cost: {decision.cost_estimate:.3f}")
    print(f"  Reasoning: {decision.reasoning[:100]}...")

print(f"\nProcessed {len(routing_decisions)} routing decisions")

In [None]:
# Analyze routing patterns
strategy_counts = {}
for decision in routing_decisions:
    strategy = decision.strategy.value
    strategy_counts[strategy] = strategy_counts.get(strategy, 0) + 1

# Create visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Strategy distribution
axes[0, 0].pie(list(strategy_counts.values()), labels=list(strategy_counts.keys()), autopct='%1.1f%%')
axes[0, 0].set_title('Strategy Distribution')

# Failure probability vs expected accuracy
failure_probs = [d.failure_probability for d in routing_decisions]
expected_accs = [d.expected_accuracy for d in routing_decisions]
strategies = [d.strategy.value for d in routing_decisions]

strategy_colors = {"intrinsic": "blue", "external": "green", "human": "red", "hybrid": "orange"}
colors = [strategy_colors.get(s, "gray") for s in strategies]

axes[0, 1].scatter(failure_probs, expected_accs, c=colors, alpha=0.7, s=100)
axes[0, 1].set_xlabel('Failure Probability')
axes[0, 1].set_ylabel('Expected Accuracy')
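axes[0, 1].set_title('Failure Probability vs Expected Accuracy')
# Legend mapping marker colors to strategies (gray marks any strategy not in the map)
from matplotlib.patches import Patch
axes[0, 1].legend(handles=[Patch(color=c, label=s) for s, c in strategy_colors.items()], title='Strategy')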

# Cost vs accuracy trade-off
costs = [d.cost_estimate for d in routing_decisions]
axes[1, 0].scatter(costs, expected_accs, c=colors, alpha=0.7, s=100)
axes[1, 0].set_xlabel('Cost Estimate')
axes[1, 0].set_ylabel('Expected Accuracy')
axes[1, 0].set_title('Cost vs Accuracy Trade-off')

# Strategy performance comparison
strategy_metrics = {}
for decision in routing_decisions:
    strategy = decision.strategy.value
    if strategy not in strategy_metrics:
        strategy_metrics[strategy] = {"accuracies": [], "costs": []}
    strategy_metrics[strategy]["accuracies"].append(decision.expected_accuracy)
    strategy_metrics[strategy]["costs"].append(decision.cost_estimate)

strategy_names = list(strategy_counts.keys())
avg_accuracies = [np.mean(strategy_metrics[s]["accuracies"]) for s in strategy_names]
avg_costs = [np.mean(strategy_metrics[s]["costs"]) for s in strategy_names]

bars = axes[1, 1].bar(strategy_names, avg_accuracies, alpha=0.7)
axes[1, 1].set_title('Average Expected Accuracy by Strategy')
axes[1, 1].set_ylabel('Expected Accuracy')

# Add cost information as text
for i, (bar, cost) in enumerate(zip(bars, avg_costs)):
    height = bar.get_height()
    axes[1, 1].text(bar.get_x() + bar.get_width()/2., height,
                   f'Cost: {cost:.2f}', ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()

## 5. Cost Model Analysis

Let's examine how the cost model works and how different factors affect routing decisions.
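`get_cost_summary()` exposes per-strategy cost components and a set of cost weights. One natural way to combine them into a single number, assumed here rather than taken from `CostModel`, is to min-max scale each component across strategies (latency is in seconds while the others are unitless) and then take the weighted sum. The profiles, weights, and the `total_costs` helper below are illustrative.

```python
# Illustrative profiles and weights; the real values come from cost_model.get_cost_summary().
profiles = {
    "intrinsic": {"computational": 1.0, "monetary": 0.01, "latency": 2.0, "quality_risk": 0.40},
    "external": {"computational": 2.0, "monetary": 0.05, "latency": 5.0, "quality_risk": 0.20},
    "human": {"computational": 0.5, "monetary": 2.00, "latency": 600.0, "quality_risk": 0.05},
}
weights = {"computational": 0.2, "monetary": 0.3, "latency": 0.2, "quality_risk": 0.3}

def total_costs(profiles, weights):
    """Hypothetical aggregation: weighted sum of min-max scaled cost components."""
    components = list(weights)
    lo = {c: min(p[c] for p in profiles.values()) for c in components}
    hi = {c: max(p[c] for p in profiles.values()) for c in components}

    def scaled(value, c):
        # Scale to [0, 1] so large-unit components (seconds) don't dominate.
        span = hi[c] - lo[c]
        return 0.0 if span == 0 else (value - lo[c]) / span

    return {
        name: sum(weights[c] * scaled(p[c], c) for c in components)
        for name, p in profiles.items()
    }

for name, cost in sorted(total_costs(profiles, weights).items(), key=lambda kv: kv[1]):
    print(f"{name:>9}: {cost:.3f}")
```

Because the weights sum to 1 and every scaled component lies in [0, 1], each total also lies in [0, 1]. Scaling across strategies means the totals are relative: adding or removing a strategy changes everyone's score.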

In [None]:
# Analyze cost model
cost_summary = cost_model.get_cost_summary()

print("Cost Model Configuration:")
print("========================")
print("\nCost Weights:")
for cost_type, weight in cost_summary["cost_weights"].items():
    print(f"  {cost_type.capitalize()}: {weight:.3f}")

print("\nStrategy Cost Profiles:")
profiles_df = pd.DataFrame(cost_summary["strategy_profiles"]).T
display(profiles_df.round(3))

In [None]:
# Visualize cost profiles
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

cost_components = ["computational", "monetary", "latency", "quality_risk"]
component_titles = ["Computational Cost", "Monetary Cost", "Latency (seconds)", "Quality Risk"]

for i, (component, title) in enumerate(zip(cost_components, component_titles)):
    ax = axes[i // 2, i % 2]
    
    strategies = list(profiles_df.index)
    values = profiles_df[component].values
    
    bars = ax.bar(strategies, values, alpha=0.7)
    ax.set_title(title)
    ax.set_ylabel('Cost')
    ax.tick_params(axis='x', rotation=45)
    
    # Add value labels
    for bar, value in zip(bars, values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{value:.2f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

## 6. Context Sensitivity Analysis

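Let's see how different contexts affect routing decisions. Note that the "High Stakes" context also switches the domain to `medical`, and the last two contexts set no domain at all, so differences between rows cannot be attributed to a single factor.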
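How the router turns context into a decision is not shown here. A minimal sketch, assuming context acts by reweighting cost components, is given below: urgency raises the latency weight, stakes raise the quality-risk weight, and a budget constraint raises the monetary weight, after which the weights are renormalized. The `context_weights` helper, the base weights, and the `1 + factor` scaling are hypothetical; only the context keys come from this notebook.

```python
def context_weights(base, urgency=0.5, stakes=0.5, budget_constraint=0.0):
    """Hypothetical: shift cost weights toward what the context cares about, then renormalize."""
    w = dict(base)
    w["latency"] *= 1 + urgency              # urgent requests penalize slow strategies
    w["quality_risk"] *= 1 + stakes          # high stakes penalize risky strategies
    w["monetary"] *= 1 + budget_constraint   # tight budgets penalize expensive strategies
    total = sum(w.values())
    return {k: v / total for k, v in w.items()}

base = {"computational": 0.2, "monetary": 0.3, "latency": 0.2, "quality_risk": 0.3}
for label, ctx in [("Low Stakes", {"urgency": 0.1, "stakes": 0.1}),
                   ("High Stakes", {"urgency": 0.1, "stakes": 0.9}),
                   ("High Urgency", {"urgency": 0.9, "stakes": 0.5})]:
    w = context_weights(base, **ctx)
    print(label, {k: round(v, 3) for k, v in w.items()})
```

Renormalizing keeps the weights comparable across contexts, so raising one concern necessarily lowers the relative weight of the others.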

In [None]:
# Test different contexts
test_prompt = "What is the derivative of x^2 + 3x + 2?"
test_response = "The derivative is 2x + 3."
test_critique = "Let me double-check this calculation..."

contexts = [
    {"name": "Low Stakes", "urgency": 0.1, "stakes": 0.1, "domain": "math"},
    {"name": "High Stakes", "urgency": 0.1, "stakes": 0.9, "domain": "medical"},
    {"name": "High Urgency", "urgency": 0.9, "stakes": 0.5, "domain": "general"},
    {"name": "Complex Domain", "urgency": 0.3, "stakes": 0.7, "domain_complexity": 0.9},
    {"name": "Budget Constrained", "urgency": 0.5, "stakes": 0.6, "budget_constraint": 0.9}
]

context_results = []

for ctx in contexts:
    name = ctx.pop("name")
    decision = router.route(test_prompt, test_response, test_critique, ctx)
    
    context_results.append({
        "Context": name,
        "Strategy": decision.strategy.value,
        "Failure Prob": decision.failure_probability,
        "Expected Acc": decision.expected_accuracy,
        "Cost": decision.cost_estimate
    })

# Display results
context_df = pd.DataFrame(context_results)
display(context_df.round(3))

In [None]:
# Visualize context effects
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Strategy distribution by context
strategy_by_context = context_df.groupby(['Context', 'Strategy']).size().unstack(fill_value=0)
strategy_by_context.plot(kind='bar', stacked=True, ax=axes[0], alpha=0.7)
axes[0].set_title('Strategy Selection by Context')
axes[0].set_ylabel('Count')
axes[0].tick_params(axis='x', rotation=45)
axes[0].legend(title='Strategy')

# Expected accuracy by context
bars1 = axes[1].bar(context_df['Context'], context_df['Expected Acc'], alpha=0.7, color='green')
axes[1].set_title('Expected Accuracy by Context')
axes[1].set_ylabel('Expected Accuracy')
axes[1].tick_params(axis='x', rotation=45)

# Cost by context
bars2 = axes[2].bar(context_df['Context'], context_df['Cost'], alpha=0.7, color='red')
axes[2].set_title('Cost by Context')
axes[2].set_ylabel('Cost Estimate')
axes[2].tick_params(axis='x', rotation=45)

# Add value labels
for ax, bars in [(axes[1], bars1), (axes[2], bars2)]:
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width()/2., height,
            f'{height:.3f}', ha='center', va='bottom', fontsize=8
        )

plt.tight_layout()
plt.show()

## 7. Interactive Routing Demo

Try your own correction traces!

In [None]:
# Interactive demo function
def demo_routing(prompt, initial_response, critique, domain="general", urgency=0.5, stakes=0.5):
    """Demo routing for custom inputs."""
    context = {
        "domain": domain,
        "urgency": urgency,
        "stakes": stakes,
        "accuracy_requirement": 0.8
    }
    
    decision = router.route(prompt, initial_response, critique, context)
    
    print("ROUTING DECISION")
    print("================")
    print(f"Strategy: {decision.strategy.value.upper()}")
    print(f"Confidence: {decision.confidence:.3f}")
    print(f"Failure Probability: {decision.failure_probability:.3f}")
    print(f"Expected Accuracy: {decision.expected_accuracy:.3f}")
    print(f"Cost Estimate: {decision.cost_estimate:.3f}")
    print(f"\nReasoning: {decision.reasoning}")
    
    return decision

# Example usage
print("Example 1: Simple Math Question")
demo_routing(
    prompt="What is 15 * 24?",
    initial_response="15 * 24 = 350",
    critique="Let me recalculate: 15 * 24 = 15 * 20 + 15 * 4 = 300 + 60 = 360",
    domain="math",
    urgency=0.3,
    stakes=0.4
)

print("\n" + "="*60 + "\n")

print("Example 2: Medical Question (High Stakes)")
demo_routing(
    prompt="What are the symptoms of appendicitis?",
    initial_response="Appendicitis symptoms include stomach pain and fever.",
    critique="I should be more specific about the location and progression of pain, and mention other symptoms like nausea and loss of appetite.",
    domain="medical",
    urgency=0.8,
    stakes=0.95
);

## 8. Summary and Insights

Let's summarize what we've learned from this demo.

In [None]:
# Generate summary statistics
stats = router.get_routing_statistics(routing_decisions)

print("SCFP Framework Demo Summary")
print("===========================")
print("\n📊 Dataset Statistics:")
print(f"  • Total traces generated: {len(traces)}")
print(f"  • Success rate: {sum(1 for t in traces if t.is_success)/len(traces)*100:.1f}%")
print(f"  • Failure modes observed: {len({t.failure_mode.value for t in traces if not t.is_success})}")
print(f"  • Domains covered: {len({t.metadata.get('domain', 'unknown') for t in traces if t.metadata})}")

print("\n🎯 Baseline Performance:")
best_baseline = max(results, key=lambda k: results[k]['Macro F1'])
print(f"  • Best baseline (by Macro F1): {best_baseline}")
print(f"  • Macro F1: {results[best_baseline]['Macro F1']:.3f}")
print(f"  • Accuracy: {results[best_baseline]['Accuracy']:.3f}")

print("\n🔀 Routing Statistics:")
print(f"  • Decisions analyzed: {stats['total_decisions']}")
print(f"  • Average failure probability: {stats['avg_failure_probability']:.3f}")
print(f"  • Average expected accuracy: {stats['avg_expected_accuracy']:.3f}")
print(f"  • Average cost: {stats['avg_cost']:.3f}")

print("\n📈 Strategy Distribution:")
for strategy, info in stats['strategy_distribution'].items():
    print(f"  • {strategy.capitalize()}: {info['count']} ({info['percentage']:.1f}%)")

print("\n💡 Key Insights:")
print("  • Failure prediction flags self-corrections that are likely to fail (predictions are simulated in this demo)")
print("  • Dynamic routing adapts to context (stakes, urgency, domain)")
print("  • Cost-benefit analysis weighs expected accuracy against cost when selecting a strategy")
print("  • Different failure modes call for different intervention strategies")
print("  • The router balances accuracy, cost, and latency trade-offs")

print("\n🔬 Next Steps:")
print("  • Train on real correction data for better predictions")
print("  • Implement actual external tools and human-in-the-loop systems")
print("  • Conduct online A/B tests to validate routing decisions")
print("  • Extend to domain-specific applications (medical, legal, etc.)")
print("  • Develop adaptive learning from deployment feedback")
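The summary cell reads five keys from `get_routing_statistics`: `total_decisions`, three averages, and a `strategy_distribution` mapping each strategy to a count and percentage. A minimal sketch of an aggregation producing that shape is below. `routing_statistics` and `Decision` are hypothetical stand-ins (the real decisions carry an enum in `strategy`; plain strings are used here), not the router's implementation.

```python
from collections import Counter
from dataclasses import dataclass
from statistics import mean

@dataclass
class Decision:
    """Hypothetical stand-in exposing the fields the summary reads."""
    strategy: str
    failure_probability: float
    expected_accuracy: float
    cost_estimate: float

def routing_statistics(decisions):
    """Hypothetical aggregation matching the keys used in the summary cell."""
    if not decisions:
        raise ValueError("need at least one decision")
    n = len(decisions)
    counts = Counter(d.strategy for d in decisions)
    return {
        "total_decisions": n,
        "avg_failure_probability": mean(d.failure_probability for d in decisions),
        "avg_expected_accuracy": mean(d.expected_accuracy for d in decisions),
        "avg_cost": mean(d.cost_estimate for d in decisions),
        "strategy_distribution": {
            s: {"count": c, "percentage": 100 * c / n} for s, c in counts.items()
        },
    }

demo = [Decision("intrinsic", 0.25, 0.75, 0.5), Decision("intrinsic", 0.75, 0.25, 0.5),
        Decision("external", 0.5, 0.5, 1.0), Decision("human", 0.5, 0.5, 2.0)]
print(routing_statistics(demo)["strategy_distribution"])
```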

## Conclusion

This demo has shown how the SCFP framework can:

1. **Predict Failures**: Identify when self-correction is likely to fail
2. **Route Intelligently**: Select appropriate correction strategies based on context
3. **Optimize Trade-offs**: Balance accuracy, cost, and latency requirements
4. **Adapt to Context**: Consider domain, urgency, and stakes in decision-making

The framework transforms a critical vulnerability (correction failures) into a valuable operational signal, enabling more reliable and efficient AI systems.

---

**To explore further:**
- Run the full reproduction script: `./scripts/reproduce_all.sh`
- Try the interactive routing demo: `python scripts/demo_routing.py --interactive`
- Examine the complete implementation in the `src/` directory