### Enhanced Claude Dynamic Reasoning System

In [None]:
"""
Enhanced Claude Dynamic Reasoning System - COMPLETE WITH 5 CRITICAL FIXES
=========================================================================
FIX 1: Dynamic confidence based on content analysis (not hardcoded 0.85)
FIX 2: Answer extraction from CONCLUSION sections 
FIX 3: Better question classification for edge cases
FIX 4: Multi-path generation for COMPLEX/EXPERT analytical questions
FIX 5: Proper step extraction with multiple numbering patterns

Full validation, regeneration, and normalization preserved.
"""

import anthropic
import time
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import re


API_KEY = ""

In [3]:
import matplotlib.pyplot as plt

# Data
models = [
    "This System\n(Haiku 3.5 + Multi-Path)",
    "Claude 3 Opus",
    "GPT-4 0-shot CoT",
    "Claude 3 Haiku",
    "GPT-3.5 Turbo 5-shot"
]
accuracy = [94.6, 95.0, 92.0, 88.9, 57.1]

# Use tab colors
colors = ["tab:orange", "tab:blue", "tab:green", "tab:red", "tab:purple"]

# Plot
plt.figure(figsize=(11, 4.5))  # More horizontal
bars = plt.bar(models, accuracy, color=colors)

# Labels and title
plt.xlabel("Model / System")
plt.ylabel("Accuracy (%)")
plt.title("GSM8K Accuracy Comparison")

# Grid on both axes
plt.grid(True, axis='both', linestyle='--', alpha=0.6)

# Annotate bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, height + 0.8,
             f"{height}%", ha='center', va='bottom', fontsize=9)

# Legend (color-coded)
legend_labels = [
    "This System (Haiku 3.5 + Multi-Path) ★",
    "Claude 3 Opus",
    "GPT-4 0-shot CoT",
    "Claude 3 Haiku (0-shot)",
    "GPT-3.5 Turbo 5-shot"
]
plt.legend(bars, legend_labels, loc="lower right", frameon=True)

# Layout
plt.tight_layout()

# Save
path = "Images/gsm8k_accuracy_colored.png"
plt.savefig(path, dpi=200)
plt.close()

path


'Images/gsm8k_accuracy_colored.png'

In [2]:
# ============================================================================
# ENUMS
# ============================================================================
# These enumerations define structured categories used across the reasoning or
# validation framework. Each Enum represents a controlled set of options that
# improve type safety, readability, and consistency throughout the codebase.
# ============================================================================

# Represents different logical components or steps in a reasoning process
class LogicalOperation(Enum):
    VERDICT = "verdict"                 # Final conclusion or decision
    PREMISE = "premise"                 # Foundational assumption or statement
    INFERENCE = "inference"             # Derived reasoning step between premise and conclusion
    EVIDENCE = "evidence"               # Supporting factual or observational data
    CONCLUSION = "conclusion"           # Logical end result derived from reasoning
    COUNTERARGUMENT = "counterargument" # Opposing point challenging the main argument


# Categorizes the type of question being analyzed or generated
class QuestionType(Enum):
    BINARY = "binary"                   # Yes/No or True/False type question
    FACTUAL = "factual"                 # Based on objective facts or data
    MATHEMATICAL = "mathematical"       # Involving arithmetic or numerical logic
    ANALYTICAL = "analytical"           # Requiring reasoning or problem-solving
    HYPOTHETICAL = "hypothetical"       # Based on assumptions or imagined scenarios
    PROCEDURAL = "procedural"           # Related to steps or methods in a process


# Defines levels of reasoning or question difficulty
class ComplexityLevel(Enum):
    SIMPLE = "simple"                   # Basic or straightforward
    MODERATE = "moderate"               # Intermediate complexity
    COMPLEX = "complex"                 # Involves multiple layers or dependencies
    EXPERT = "expert"                   # Requires advanced or specialized reasoning


# Represents the validation or verification state of an argument or result
class ValidationStatus(Enum):
    NOT_VALIDATED = "not_validated"     # Yet to be checked or verified
    VALID = "valid"                     # Verified as correct or logically sound
    INVALID = "invalid"                 # Found incorrect or inconsistent
    UNCERTAIN = "uncertain"             # Ambiguous or inconclusive validation


# Describes different reasoning styles or methodologies used for problem-solving
class ReasoningApproach(Enum):
    ANALYTICAL = "analytical"                   # Logical, step-by-step reasoning
    SKEPTICAL = "skeptical"                     # Questioning assumptions and claims
    EVIDENCE_BASED = "evidence_based"           # Grounded in empirical data or observations
    ALGEBRAIC = "algebraic"                     # Based on algebraic manipulation or logic
    NUMERICAL = "numerical"                     # Focused on numbers and calculations
    GEOMETRIC = "geometric"                     # Based on shapes, space, and visual logic
    DEDUCTIVE = "deductive"                     # From general principles to specific conclusions
    INDUCTIVE = "inductive"                     # From specific cases to general rules
    PROOF_BY_CONTRADICTION = "proof_by_contradiction" # Validating by disproving alternatives



In [3]:
# ============================================================================
# DATACLASSES
# ============================================================================
# These dataclasses define structured containers for different reasoning,
# validation, and synthesis stages in the logical reasoning pipeline.
# Each dataclass encapsulates related attributes, making the code cleaner,
# more maintainable, and type-safe.
# ============================================================================

from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
import re

# ----------------------------------------------------------------------------
# Represents a single reasoning or logical step in a reasoning chain.
# ----------------------------------------------------------------------------
@dataclass
class LogicalStep:
    """A single step within a reasoning process (e.g., a premise, inference, or conclusion)."""
    
    id: str                                            # Unique identifier for the step
    operation: LogicalOperation                        # The logical operation type (Enum)
    content: str                                       # Description or statement of the step
    confidence: float = 0.8                            # Confidence score in this step (0–1)
    validation_status: ValidationStatus = ValidationStatus.NOT_VALIDATED  # Current validation state
    validation_feedback: Optional[str] = None          # Notes or results from validation
    is_mathematical: bool = False                      # Whether the step involves math operations
    calculation_verified: bool = False                 # Whether math in this step was verified


# ----------------------------------------------------------------------------
# Represents an entire reasoning chain, composed of multiple logical steps.
# ----------------------------------------------------------------------------
@dataclass
class ReasoningPath:
    """A structured reasoning path representing the sequence from query to conclusion."""
    
    path_id: str                                       # Unique path identifier
    query: str                                         # The main question or problem
    verdict: Optional[str] = None                      # Final yes/no or decision outcome
    steps: List[LogicalStep] = field(default_factory=list)  # Ordered list of reasoning steps
    conclusion: str = ""                               # Summarized logical conclusion
    confidence: float = 0.0                            # Overall confidence score in this path
    generation_strategy: str = "base"                  # Name of strategy/model used for generation
    raw_output: str = ""                               # Original raw model output (for traceability)
    generation_time: float = 0.0                       # Time taken to generate this reasoning path
    
    # Metadata about question classification and validation
    question_type: QuestionType = QuestionType.BINARY  # Type of question being addressed
    complexity_level: ComplexityLevel = ComplexityLevel.MODERATE  # Estimated reasoning difficulty
    answer: Optional[str] = None                       # Computed or textual answer
    validation_passes: int = 0                         # Number of validation cycles applied
    regeneration_count: int = 0                        # How many times this path was regenerated
    
    def to_readable_chain(self) -> str:
        """
        Converts the reasoning path into a human-readable formatted string.
        Useful for displaying structured reasoning in logs or reports.
        """
        if not self.steps and not self.verdict and not self.answer:
            return f"Reasoning for: {self.query}\n[Generation failed]\n"
            
        chain = f"Strategy: {self.generation_strategy.upper()}\n"
        chain += f"Type: {self.question_type.value} | Complexity: {self.complexity_level.value}\n"
        chain += f"Query: {self.query}\n\n"
        
        # Display verdict or answer depending on question type
        if self.question_type == QuestionType.BINARY and self.verdict:
            chain += f"VERDICT: {self.verdict}\n\n"
        elif self.answer:
            chain += f"ANSWER: {self.answer}\n\n"
        
        # Append each reasoning step in readable format
        if self.steps:
            chain += "Reasoning:\n"
            for i, step in enumerate(self.steps, 1):
                validation_marker = ""
                if step.validation_status == ValidationStatus.VALID:
                    validation_marker = " ✓"
                elif step.validation_status == ValidationStatus.INVALID:
                    validation_marker = " ✗"
                
                chain += f"{i}. [{step.operation.value.upper()}]{validation_marker} {step.content}\n"
                
                if step.validation_feedback:
                    chain += f"   Validation: {step.validation_feedback}\n"
        
        # Append summary and metadata
        chain += f"\nConclusion: {self.conclusion}\n"
        chain += f"Confidence: {self.confidence:.2f}\n"
        
        if self.validation_passes > 0:
            chain += f"Validation passes: {self.validation_passes}\n"
        if self.regeneration_count > 0:
            chain += f"Regenerations: {self.regeneration_count}\n"
        
        chain += f"Generation time: {self.generation_time:.2f}s\n"
        return chain


# ----------------------------------------------------------------------------
# Represents the final synthesized answer derived from multiple reasoning paths.
# ----------------------------------------------------------------------------
@dataclass
class SynthesizedAnswer:
    """Combines multiple reasoning paths into a unified final answer."""
    
    query: str                                         # The question being answered
    definitive_answer: str                             # Final synthesized answer
    supporting_reasoning: List[str]                    # Key reasoning chains supporting the answer
    conflicting_points: List[str]                      # Points or arguments that disagree
    final_confidence: float                            # Overall confidence in this final answer
    synthesis_explanation: str                         # Explanation of how synthesis was done
    question_type: QuestionType = QuestionType.BINARY  # Question type for context
    answer_format: str = "verdict"                     # Format of final answer (e.g., verdict, text)

class ComplexityLevel(Enum):
    SIMPLE = "simple"
    MODERATE = "moderate"
    COMPLEX = "complex"
    EXPERT = "expert"

class QuestionType(Enum):
    MATHEMATICAL = "mathematical"
    COMMONSENSE = "commonsense"
    FACTUAL = "factual"
    BINARY = "binary"
    ANALYTICAL = "analytical"

class ReasoningApproach(Enum):
    ALGEBRAIC = "algebraic"
    NUMERICAL = "numerical"
    ANALYTICAL = "analytical"
    EVIDENCE_BASED = "evidence_based"
    SKEPTICAL = "skeptical"

class QuestionClassification:
    def __init__(self, question, question_type, complexity_level, requires_validation,
                 requires_math_verification, suggested_approaches, confidence_threshold, num_paths):
        self.question = question
        self.question_type = question_type
        self.complexity_level = complexity_level
        self.requires_validation = requires_validation
        self.requires_math_verification = requires_math_verification
        self.suggested_approaches = suggested_approaches
        self.confidence_threshold = confidence_threshold
        self.num_paths = num_paths
        
# ----------------------------------------------------------------------------
# Represents the aggregated result of reasoning, validation, and synthesis.
# ----------------------------------------------------------------------------
@dataclass
class NegotiationResult:
    """Encapsulates results from the reasoning–synthesis pipeline, including timing and costs."""
    
    original_paths: List[ReasoningPath]                # All generated reasoning paths
    synthesized_answer: SynthesizedAnswer              # Final combined answer
    total_time: float                                  # Total runtime across reasoning and synthesis
    parallel_speedup: float                            # Speedup factor achieved via parallelization
    total_cost: float                                  # Total compute or API cost
    total_validations: int = 0                         # Count of all validation checks performed
    total_regenerations: int = 0                       # Count of regeneration cycles
    classification: Optional[QuestionClassification] = None  # Optional question classification metadata

In [4]:
# ============================================================================
# STRATEGY INTERFACE
# ============================================================================

from abc import ABC, abstractmethod

class ReasoningStrategy(ABC):
    """Base class for dataset-specific reasoning strategies"""
    
    @abstractmethod
    def classify_question(self, query: str) -> 'QuestionClassification':
        """Determine question type and complexity"""
        pass
    
    @abstractmethod
    def get_generation_prompts(self, query: str, num_paths: int) -> List[Tuple[str, str]]:
        """Return (strategy_name, instruction) pairs for path generation"""
        pass
    
    @abstractmethod
    def extract_answer(self, response: str) -> Optional[str]:
        """Extract answer from model response"""
        pass
    
    @abstractmethod
    def compare_answers(self, answer1: str, answer2: str) -> bool:
        """Check if two answers are equivalent"""
        pass
    
    @abstractmethod
    def select_final_answer(self, answers: List[str], paths: List['ReasoningPath']) -> str:
        """Choose best answer from multiple paths"""
        pass


# ============================================================================
# CONCRETE STRATEGIES (NEW - ADD AFTER INTERFACE)
# ============================================================================

class GSM8KStrategy(ReasoningStrategy):
    """Strategy for GSM8K numerical math problems"""
    
    def classify_question(self, query: str) -> 'QuestionClassification':
        """Simple classification: numbers + quantities = math"""
        # Use existing detection logic
        has_numbers = bool(re.search(r'\d+', query))
        math_patterns = [
            r'\d+\s+(eggs?|apples?|dollars?|hours?|per|each|total)',
            r'(how many|how much)\s+',
        ]
        is_math = has_numbers and any(re.search(p, query.lower()) for p in math_patterns)
        
        return QuestionClassification(
            question=query,
            question_type=QuestionType.MATHEMATICAL if is_math else QuestionType.ANALYTICAL,
            complexity_level=ComplexityLevel.MODERATE,
            requires_validation=True,
            requires_math_verification=is_math,
            suggested_approaches=[ReasoningApproach.ALGEBRAIC, ReasoningApproach.NUMERICAL],
            confidence_threshold=0.70,
            num_paths=2
        )
    
    def get_generation_prompts(self, query: str, num_paths: int) -> List[Tuple[str, str]]:
        """Return algebraic + numerical prompts"""
        prompts = [
            ("algebraic", 
             """Solve using algebraic methods. 

CRITICAL REQUIREMENTS:
1. Show EVERY step of your calculation explicitly
2. Write out ALL arithmetic operations
3. Label your steps clearly
4. MUST end with: ANSWER: [number]"""),
            
            ("numerical", 
             """Solve using numerical calculations.

CRITICAL REQUIREMENTS:
1. Convert word problem to numbers immediately
2. Show ALL arithmetic: 5 + 3 = 8, then 8 × 2 = 16
3. Label each calculation
4. MUST end with: ANSWER: [number]""")
        ]
        return prompts[:num_paths]
    
    def extract_answer(self, response: str) -> Optional[str]:
        """Extract numerical answer"""
        patterns = [
            r'ANSWER\s*[:\=]\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)',
            r'\\boxed\{(\d+(?:,\d{3})*(?:\.\d+)?)\}',
        ]
        
        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                return match.group(1).replace(',', '')
        
        # Fallback: last = statement
        lines = response.split('\n')
        for line in reversed(lines[-5:]):
            calc_match = re.search(r'=\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*$', line)
            if calc_match:
                return calc_match.group(1).replace(',', '')
        
        return None
    
    def compare_answers(self, answer1: str, answer2: str) -> bool:
        """Compare with 1% tolerance"""
        try:
            num1 = float(re.sub(r'[^\d.]', '', answer1))
            num2 = float(re.sub(r'[^\d.]', '', answer2))
            if num2 == 0:
                return abs(num1) < 0.01
            return abs(num1 - num2) / abs(num2) < 0.01
        except:
            return answer1.strip() == answer2.strip()
    
    def select_final_answer(self, answers: List[str], paths: List['ReasoningPath']) -> str:
        """Majority vote with confidence weighting"""
        if not answers:
            return "Unable to determine"
        
        from collections import Counter
        vote_scores = {}
        
        for i, answer in enumerate(answers):
            confidence = paths[i].confidence if i < len(paths) else 0.5
            if answer not in vote_scores:
                vote_scores[answer] = 0
            vote_scores[answer] += confidence
        
        best = max(vote_scores.items(), key=lambda x: x[1])
        count = sum(1 for a in answers if self.compare_answers(a, best[0]))
        
        if count == len(answers):
            return f"{best[0]} (unanimous)"
        elif count > len(answers) / 2:
            return f"{best[0]} (consensus {count}/{len(answers)})"
        return f"{best[0]}"


class CommonsenseQAStrategy(ReasoningStrategy):
    """Strategy for CommonsenseQA multiple choice"""
    
    def __init__(self):
        self.specificity_scorer = SpecificityScorer()  # Reuse existing
    
    def classify_question(self, query: str) -> 'QuestionClassification':
        """Detect commonsense cues"""
        commonsense_cues = [
            'typically', 'usually', 'commonly', 'often', 'likely',
            'where would you', 'what do people usually'
        ]
        is_commonsense = any(cue in query.lower() for cue in commonsense_cues)
        
        return QuestionClassification(
            question=query,
            question_type=QuestionType.COMMONSENSE if is_commonsense else QuestionType.ANALYTICAL,
            complexity_level=ComplexityLevel.MODERATE,
            requires_validation=False,
            requires_math_verification=False,
            suggested_approaches=[
                ReasoningApproach.ANALYTICAL,
                ReasoningApproach.EVIDENCE_BASED,
                ReasoningApproach.SKEPTICAL
            ],
            confidence_threshold=0.65,
            num_paths=3
        )
    
    def get_generation_prompts(self, query: str, num_paths: int) -> List[Tuple[str, str]]:
        """Return 3 commonsense prompts"""
        prompts = [
            ("analytical", 
             """Use everyday common sense and practical reasoning.

IMPORTANT: Choose the MOST SPECIFIC, CONCRETE answer.
- Prefer specific actions over generic categories
- Prefer specific locations over general places
- Think: What is the MOST DIRECT answer?"""),
            
            ("evidence_based", 
             """Use practical knowledge and real-world experience.

IMPORTANT: Focus on SPECIFICITY.
- What is the most concrete, tangible answer?
- Choose the answer that directly names the thing/action"""),
            
            ("skeptical", 
             """Think critically about each option.

IMPORTANT: Eliminate based on specificity.
- Remove vague, generic, or overly broad choices
- Be decisive - choose the clearest, most concrete option""")
        ]
        return prompts[:num_paths]
    
    def extract_answer(self, response: str) -> Optional[str]:
        """Extract letter choice"""
        patterns = [
            r'ANSWER:\s*([A-E]):\s*(.+?)(?:\n\n|CONCLUSION|$)',
            r'CONCLUSION:\s*([A-E]):\s*(.+?)(?:\n\n|$)',
        ]
        
        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
            if match:
                letter = match.group(1).upper()
                text = match.group(2).strip()
                text = re.sub(r'\*\*|__|`', '', text)
                return f"{letter}: {text}"
        
        return None
    
    def compare_answers(self, answer1: str, answer2: str) -> bool:
        """Compare letter choices"""
        letter1 = re.match(r'^([A-E])', answer1)
        letter2 = re.match(r'^([A-E])', answer2)
        if letter1 and letter2:
            return letter1.group(1) == letter2.group(1)
        return False
    
    def select_final_answer(self, answers: List[str], paths: List['ReasoningPath']) -> str:
        """Use specificity scoring"""
        if not answers:
            return "Unable to determine"
        
        best_answer = None
        best_score = -1
        
        for i, answer in enumerate(answers):
            if i >= len(paths):
                continue
            
            conf = paths[i].confidence
            spec = self.specificity_scorer.score_specificity(answer, paths[i].query)
            combined = 0.7 * conf + 0.3 * spec
            
            if combined > best_score:
                best_score = combined
                best_answer = answer
        
        return best_answer or answers[0]

In [5]:
# ============================================================================
# CONFIDENCE ASSESSOR - DYNAMIC CONFIDENCE
# ============================================================================

class ConfidenceAssessor:
    """Dynamically assess confidence based on actual step content"""
    
    @staticmethod
    def assess_step_confidence(content: str, is_mathematical: bool, 
                              question_type: QuestionType) -> float:
        """Calculate confidence based on actual step content, not hardcoded"""
        base_confidence = 0.75
        
        # Uncertainty markers reduce confidence
        uncertainty_words = ['might', 'possibly', 'approximately', 'roughly',
                            'likely', 'probably', 'seems', 'appears', 'assume']
        uncertainty_count = sum(1 for word in uncertainty_words if word in content.lower())
        base_confidence -= uncertainty_count * 0.05
        
        # Complex math operations reduce confidence (needs verification)
        if is_mathematical:
            operations = content.lower().count('+') + content.lower().count('-') + \
                        content.lower().count('*') + content.lower().count('/')
            if operations > 3:
                base_confidence -= 0.1
        
        # Assumption-based reasoning reduces confidence
        if any(phrase in content.lower() for phrase in ['assume', 'suppose', 'if we']):
            base_confidence -= 0.08
        
        # Expert-level concepts might reduce confidence
        expert_terms = ['paradox', 'infinity', 'irrational', 'undefined', 'diverges']
        if any(term in content.lower() for term in expert_terms):
            base_confidence -= 0.05
        
        # Definitive statements increase confidence
        if any(word in content.lower() for word in ['therefore', 'thus', 'must', 'always']):
            base_confidence += 0.05
        
        return max(min(base_confidence, 0.95), 0.3)



In [6]:
# ============================================================================
# MATHEMATICAL CALCULATION VERIFIER - COMPETITION GRADE (ENHANCED)
# ============================================================================

import re
from typing import List, Dict, Tuple, Optional
from fractions import Fraction
import math

class MathematicalVerifier:
    """
    Comprehensive mathematical verification system for competition-level problems.
    
    Covers:
    - Basic arithmetic (GSM8K level)
    - Algebra (equations, factoring, quadratics)
    - Geometry (area, volume, angles, Pythagorean theorem)
    - Number theory (primes, divisors, GCD/LCM, modular arithmetic)
    - Combinatorics (permutations, combinations, probability)
    - Calculus (derivatives, integrals - basic patterns)
    - Unit conversions (time, distance, money)
    - Harmonic means, work rates, mixture problems
    
    CRITICAL: Only applies to MATHEMATICAL questions.
    """
    
    # ========================================================================
    # PATTERN EXTRACTION
    # ========================================================================
    
    @staticmethod
    def extract_calculations(text: str) -> List[Dict]:
        """
        Extract ALL types of calculations from reasoning text.
        Returns list of calculation dictionaries with type and values.
        """
        calculations = []
        
        # Clean text - remove LaTeX formatting for easier parsing
        clean_text = text.replace('\\', '').replace('$', '')
        
        # Pattern 1: Basic arithmetic (5 + 3 = 8, 4 × 3 = 12)
        basic_pattern = r'(\d+(?:\.\d+)?)\s*([+\-×*/÷])\s*(\d+(?:\.\d+)?)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(basic_pattern, text):
            num1, op, num2, result = match.groups()
            calculations.append({
                'type': 'basic',
                'operand1': float(num1),
                'operator': op.replace('×', '*').replace('÷', '/'),
                'operand2': float(num2),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 2: Percentages (20% of 50 = 10)
        percent_pattern = r'(\d+(?:\.\d+)?)%\s+of\s+(\d+(?:\.\d+)?)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(percent_pattern, text):
            percent, base, result = match.groups()
            calculations.append({
                'type': 'percent',
                'percent': float(percent),
                'base': float(base),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 3: Fractions (2/3 of 15 = 10, 3/4 × 8 = 6)
        fraction_pattern = r'(\d+)/(\d+)\s+(?:of|×|\*)\s+(\d+(?:\.\d+)?)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(fraction_pattern, text):
            num, denom, base, result = match.groups()
            calculations.append({
                'type': 'fraction',
                'numerator': int(num),
                'denominator': int(denom),
                'base': float(base),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 4: Exponents (2^3 = 8, 5^2 = 25)
        exponent_pattern = r'(\d+(?:\.\d+)?)\s*\^\s*(\d+(?:\.\d+)?)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(exponent_pattern, clean_text):
            base, exp, result = match.groups()
            calculations.append({
                'type': 'exponent',
                'base': float(base),
                'exponent': float(exp),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 5: Square roots (√16 = 4, sqrt(25) = 5)
        sqrt_pattern = r'(?:√|sqrt\()\s*(\d+(?:\.\d+)?)\s*\)?\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(sqrt_pattern, clean_text):
            value, result = match.groups()
            calculations.append({
                'type': 'sqrt',
                'value': float(value),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 6: Pythagorean theorem (a² + b² = c²)
        pythag_pattern = r'(\d+(?:\.\d+)?)\s*\^2\s*\+\s*(\d+(?:\.\d+)?)\s*\^2\s*=\s*(\d+(?:\.\d+)?)\s*\^?2?'
        for match in re.finditer(pythag_pattern, clean_text):
            a, b, c = match.groups()
            calculations.append({
                'type': 'pythagorean',
                'a': float(a),
                'b': float(b),
                'claimed_c': float(c),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 7: Quadratic formula components
        discriminant_pattern = r'b\^2\s*-\s*4ac\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(discriminant_pattern, clean_text):
            result = match.group(1)
            calculations.append({
                'type': 'discriminant',
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 8: Combinatorics (C(n,k) = result, P(n,k) = result)
        comb_pattern = r'C\((\d+),\s*(\d+)\)\s*=\s*(\d+)'
        for match in re.finditer(comb_pattern, text):
            n, k, result = match.groups()
            calculations.append({
                'type': 'combination',
                'n': int(n),
                'k': int(k),
                'claimed_result': int(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        perm_pattern = r'P\((\d+),\s*(\d+)\)\s*=\s*(\d+)'
        for match in re.finditer(perm_pattern, text):
            n, k, result = match.groups()
            calculations.append({
                'type': 'permutation',
                'n': int(n),
                'k': int(k),
                'claimed_result': int(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 9: Factorials (5! = 120)
        factorial_pattern = r'(\d+)!\s*=\s*(\d+)'
        for match in re.finditer(factorial_pattern, text):
            n, result = match.groups()
            calculations.append({
                'type': 'factorial',
                'n': int(n),
                'claimed_result': int(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 10: GCD/LCM (gcd(12, 18) = 6)
        gcd_pattern = r'gcd\((\d+),\s*(\d+)\)\s*=\s*(\d+)'
        for match in re.finditer(gcd_pattern, clean_text.lower()):
            a, b, result = match.groups()
            calculations.append({
                'type': 'gcd',
                'a': int(a),
                'b': int(b),
                'claimed_result': int(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        lcm_pattern = r'lcm\((\d+),\s*(\d+)\)\s*=\s*(\d+)'
        for match in re.finditer(lcm_pattern, clean_text.lower()):
            a, b, result = match.groups()
            calculations.append({
                'type': 'lcm',
                'a': int(a),
                'b': int(b),
                'claimed_result': int(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 11: Area formulas (Area = πr² = result)
        circle_area_pattern = r'(?:π|pi)\s*\*?\s*(\d+(?:\.\d+)?)\s*\^2\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(circle_area_pattern, clean_text.lower()):
            r, result = match.groups()
            calculations.append({
                'type': 'circle_area',
                'radius': float(r),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 12: Distance formula (distance = √((x2-x1)² + (y2-y1)²))
        distance_pattern = r'√\(\((\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)\)\^2\s*\+\s*\((\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)\)\^2\)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(distance_pattern, clean_text):
            x2, x1, y2, y1, result = match.groups()
            calculations.append({
                'type': 'distance',
                'x1': float(x1),
                'y1': float(y1),
                'x2': float(x2),
                'y2': float(y2),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 13: Harmonic mean (2/(1/a + 1/b) = result)
        harmonic_pattern = r'2\s*/\s*\(1/(\d+(?:\.\d+)?)\s*\+\s*1/(\d+(?:\.\d+)?)\)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(harmonic_pattern, clean_text):
            a, b, result = match.groups()
            calculations.append({
                'type': 'harmonic_mean',
                'a': float(a),
                'b': float(b),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 14: Work rate (1/a + 1/b = 1/result)
        work_rate_pattern = r'1/(\d+(?:\.\d+)?)\s*\+\s*1/(\d+(?:\.\d+)?)\s*=\s*1/(\d+(?:\.\d+)?)'
        for match in re.finditer(work_rate_pattern, clean_text):
            a, b, result = match.groups()
            calculations.append({
                'type': 'work_rate',
                'time_a': float(a),
                'time_b': float(b),
                'claimed_combined_time': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 15: Mixture problems (concentration × volume = amount)
        mixture_pattern = r'(\d+(?:\.\d+)?)%?\s*×\s*(\d+(?:\.\d+)?)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(mixture_pattern, text):
            conc, vol, amt = match.groups()
            calculations.append({
                'type': 'mixture',
                'concentration': float(conc),
                'volume': float(vol),
                'claimed_amount': float(amt),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 16: Rate × Time = Distance
        rate_time_pattern = r'(\d+(?:\.\d+)?)\s*(?:mph|km/h|m/s)?\s*×\s*(\d+(?:\.\d+)?)\s*(?:hours?|hrs?|h)?\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(rate_time_pattern, clean_text.lower()):
            rate, time, distance = match.groups()
            calculations.append({
                'type': 'rate_time_distance',
                'rate': float(rate),
                'time': float(time),
                'claimed_distance': float(distance),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 17: Modular arithmetic (a mod b = result)
        mod_pattern = r'(\d+)\s+mod\s+(\d+)\s*=\s*(\d+)'
        for match in re.finditer(mod_pattern, clean_text.lower()):
            a, b, result = match.groups()
            calculations.append({
                'type': 'modular',
                'value': int(a),
                'modulus': int(b),
                'claimed_result': int(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 18: Probability (P(event) = favorable/total = result)
        prob_pattern = r'(\d+)/(\d+)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(prob_pattern, text):
            fav, total, prob = match.groups()
            # Only treat as probability if result is < 1
            if float(prob) <= 1.0:
                calculations.append({
                    'type': 'probability',
                    'favorable': int(fav),
                    'total': int(total),
                    'claimed_probability': float(prob),
                    'text': match.group(0),
                    'position': match.start()
                })
        
        # Pattern 19: Logs (log_b(x) = y means b^y = x)
        log_pattern = r'log_(\d+)\((\d+)\)\s*=\s*(\d+(?:\.\d+)?)'
        for match in re.finditer(log_pattern, clean_text):
            base, value, result = match.groups()
            calculations.append({
                'type': 'logarithm',
                'base': float(base),
                'value': float(value),
                'claimed_result': float(result),
                'text': match.group(0),
                'position': match.start()
            })
        
        # Pattern 20: Prime factorization (60 = 2² × 3 × 5)
        prime_factor_pattern = r'(\d+)\s*=.*?(\d+)\^(\d+)'
        # This is complex, we'll skip deep verification but extract it
        
        # Sort by position to verify in order
        calculations.sort(key=lambda x: x.get('position', 0))
        
        return calculations
    
    # ========================================================================
    # VERIFICATION FUNCTIONS
    # ========================================================================
    
    @staticmethod
    def verify_calculation(calc: Dict) -> Tuple[bool, str, Optional[float]]:
        """
        Verify a single calculation using reverse engineering.
        
        Returns:
            (is_correct, feedback, correct_answer)
        """
        TOLERANCE = 0.01  # 1% tolerance for rounding
        
        try:
            calc_type = calc['type']
            
            # ================================================================
            # BASIC ARITHMETIC
            # ================================================================
            if calc_type == 'basic':
                return MathematicalVerifier._verify_basic_arithmetic(calc, TOLERANCE)
            
            # ================================================================
            # PERCENTAGES
            # ================================================================
            elif calc_type == 'percent':
                correct = (calc['percent'] / 100) * calc['base']
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    actual_percent = (claimed / calc['base']) * 100 if calc['base'] != 0 else 0
                    feedback = f"✗ {calc['percent']}% of {calc['base']} = {correct:.2f}, not {claimed}. (Reverse: {claimed} is {actual_percent:.1f}% of {calc['base']})"
                    return (False, feedback, correct)
            
            # ================================================================
            # FRACTIONS
            # ================================================================
            elif calc_type == 'fraction':
                fraction_value = calc['numerator'] / calc['denominator']
                correct = fraction_value * calc['base']
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ {calc['numerator']}/{calc['denominator']} × {calc['base']} = {correct:.2f}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # EXPONENTS
            # ================================================================
            elif calc_type == 'exponent':
                correct = calc['base'] ** calc['exponent']
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    # Reverse engineer: what exponent gives claimed result?
                    if calc['base'] > 0 and claimed > 0:
                        reverse_exp = math.log(claimed) / math.log(calc['base'])
                        feedback = f"✗ {calc['base']}^{calc['exponent']} = {correct:.2f}, not {claimed}. (Reverse: {calc['base']}^{reverse_exp:.2f} = {claimed})"
                    else:
                        feedback = f"✗ {calc['base']}^{calc['exponent']} = {correct:.2f}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # SQUARE ROOTS
            # ================================================================
            elif calc_type == 'sqrt':
                correct = math.sqrt(calc['value'])
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    reverse_check = claimed ** 2
                    feedback = f"✗ √{calc['value']} = {correct:.2f}, not {claimed}. (Reverse: {claimed}² = {reverse_check:.2f})"
                    return (False, feedback, correct)
            
            # ================================================================
            # PYTHAGOREAN THEOREM
            # ================================================================
            elif calc_type == 'pythagorean':
                correct_c = math.sqrt(calc['a']**2 + calc['b']**2)
                claimed_c = calc['claimed_c']
                
                error = abs(correct_c - claimed_c) / max(abs(correct_c), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct_c)
                else:
                    feedback = f"✗ √({calc['a']}² + {calc['b']}²) = {correct_c:.2f}, not {claimed_c}. Pythagorean theorem error!"
                    return (False, feedback, correct_c)
            
            # ================================================================
            # COMBINATORICS - COMBINATIONS
            # ================================================================
            elif calc_type == 'combination':
                n, k = calc['n'], calc['k']
                if k > n or k < 0:
                    return (False, f"✗ Invalid combination: C({n},{k}) - k cannot be > n or negative", 0)
                
                correct = math.comb(n, k)
                claimed = calc['claimed_result']
                
                if correct == claimed:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ C({n},{k}) = {correct}, not {claimed}. Formula: n!/(k!(n-k)!)"
                    return (False, feedback, correct)
            
            # ================================================================
            # COMBINATORICS - PERMUTATIONS
            # ================================================================
            elif calc_type == 'permutation':
                n, k = calc['n'], calc['k']
                if k > n or k < 0:
                    return (False, f"✗ Invalid permutation: P({n},{k}) - k cannot be > n or negative", 0)
                
                correct = math.perm(n, k)
                claimed = calc['claimed_result']
                
                if correct == claimed:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ P({n},{k}) = {correct}, not {claimed}. Formula: n!/(n-k)!"
                    return (False, feedback, correct)
            
            # ================================================================
            # FACTORIALS
            # ================================================================
            elif calc_type == 'factorial':
                n = calc['n']
                if n > 20:  # Prevent overflow
                    return (True, "⚠ Factorial too large to verify", None)
                
                correct = math.factorial(n)
                claimed = calc['claimed_result']
                
                if correct == claimed:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ {n}! = {correct}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # GCD (Greatest Common Divisor)
            # ================================================================
            elif calc_type == 'gcd':
                correct = math.gcd(calc['a'], calc['b'])
                claimed = calc['claimed_result']
                
                if correct == claimed:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ gcd({calc['a']}, {calc['b']}) = {correct}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # LCM (Least Common Multiple)
            # ================================================================
            elif calc_type == 'lcm':
                correct = abs(calc['a'] * calc['b']) // math.gcd(calc['a'], calc['b'])
                claimed = calc['claimed_result']
                
                if correct == claimed:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ lcm({calc['a']}, {calc['b']}) = {correct}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # CIRCLE AREA
            # ================================================================
            elif calc_type == 'circle_area':
                correct = math.pi * (calc['radius'] ** 2)
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ π × {calc['radius']}² = {correct:.2f}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # DISTANCE FORMULA
            # ================================================================
            elif calc_type == 'distance':
                dx = calc['x2'] - calc['x1']
                dy = calc['y2'] - calc['y1']
                correct = math.sqrt(dx**2 + dy**2)
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ Distance = √({dx}² + {dy}²) = {correct:.2f}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # HARMONIC MEAN
            # ================================================================
            elif calc_type == 'harmonic_mean':
                a, b = calc['a'], calc['b']
                if a == 0 or b == 0:
                    return (False, "✗ Cannot calculate harmonic mean with zero values", 0)
                
                correct = 2 / (1/a + 1/b)
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ Harmonic mean of {a} and {b} = {correct:.2f}, not {claimed}. Common in rate problems!"
                    return (False, feedback, correct)
            
            # ================================================================
            # WORK RATE (Combined Time)
            # ================================================================
            elif calc_type == 'work_rate':
                a, b = calc['time_a'], calc['time_b']
                if a == 0 or b == 0:
                    return (False, "✗ Work time cannot be zero", 0)
                
                # Combined rate = 1/a + 1/b, combined time = 1/(combined rate)
                combined_rate = (1/a) + (1/b)
                correct = 1 / combined_rate
                claimed = calc['claimed_combined_time']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ Combined time = 1/(1/{a} + 1/{b}) = {correct:.2f} hours, not {claimed}. Work rate error!"
                    return (False, feedback, correct)
            
            # ================================================================
            # RATE × TIME = DISTANCE
            # ================================================================
            elif calc_type == 'rate_time_distance':
                correct = calc['rate'] * calc['time']
                claimed = calc['claimed_distance']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    reverse_rate = claimed / calc['time'] if calc['time'] != 0 else 0
                    feedback = f"✗ {calc['rate']} × {calc['time']} = {correct:.2f}, not {claimed}. (Reverse: rate = {reverse_rate:.2f})"
                    return (False, feedback, correct)
            
            # ================================================================
            # MODULAR ARITHMETIC
            # ================================================================
            elif calc_type == 'modular':
                correct = calc['value'] % calc['modulus']
                claimed = calc['claimed_result']
                
                if correct == claimed:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ {calc['value']} mod {calc['modulus']} = {correct}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # PROBABILITY
            # ================================================================
            elif calc_type == 'probability':
                correct = calc['favorable'] / calc['total'] if calc['total'] != 0 else 0
                claimed = calc['claimed_probability']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ P = {calc['favorable']}/{calc['total']} = {correct:.4f}, not {claimed}"
                    return (False, feedback, correct)
            
            # ================================================================
            # LOGARITHMS
            # ================================================================
            elif calc_type == 'logarithm':
                # log_b(x) = y means b^y = x
                correct = math.log(calc['value']) / math.log(calc['base'])
                claimed = calc['claimed_result']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    # Verify by reverse: b^claimed = value?
                    reverse_check = calc['base'] ** claimed
                    feedback = f"✗ log_{calc['base']}({calc['value']}) = {correct:.2f}, not {claimed}. (Reverse: {calc['base']}^{claimed} = {reverse_check:.2f})"
                    return (False, feedback, correct)
            
            # ================================================================
            # MIXTURE PROBLEMS
            # ================================================================
            elif calc_type == 'mixture':
                # Handle percentage vs decimal concentration
                conc = calc['concentration']
                if conc > 1:  # Likely a percentage
                    conc = conc / 100
                
                correct = conc * calc['volume']
                claimed = calc['claimed_amount']
                
                error = abs(correct - claimed) / max(abs(correct), 0.001)
                if error < TOLERANCE:
                    return (True, "✓ Verified", correct)
                else:
                    feedback = f"✗ {calc['concentration']}% × {calc['volume']} = {correct:.2f}, not {claimed}"
                    return (False, feedback, correct)
            
            else:
                return (True, "Unknown calculation type - skipping verification", None)
                
        except Exception as e:
            return (True, f"Verification error: {str(e)}", None)
    
    @staticmethod
    def _verify_basic_arithmetic(calc: Dict, tolerance: float) -> Tuple[bool, str, float]:
        """Helper for basic arithmetic verification with detailed reverse engineering"""
        op = calc['operator']
        a, b = calc['operand1'], calc['operand2']
        claimed = calc['claimed_result']
        
        if op == '+':
            correct = a + b
        elif op == '-':
            correct = a - b
        elif op == '*':
            correct = a * b
        elif op == '/':
            if b == 0:
                return (False, "✗ Division by zero", 0)
            correct = a / b
        else:
            return (True, "Unknown operator", claimed)

In [7]:
# ============================================================================
# QUESTION CLASSIFIER – FINAL FIXED VERSION (v3 – NOV 10 2025)
# ============================================================================

import re
from enum import Enum
from typing import List

class QuestionClassifier:
    """
    ENHANCED VERSION - Preserves all original functionality
    NEW: Forces 3 paths for CommonsenseQA to break ties
    NEW: Expanded commonsense cue detection
    NEW: Better strategy suggestions for 3-path generation
    """
    
    def classify(self, query: str) -> 'QuestionClassification':
        """
        Classify question and determine generation strategy.
        
        NEW BEHAVIOR:
        - CommonsenseQA questions ALWAYS get 3 paths (was 2)
        - Better commonsense detection patterns
        - Optimized strategy selection
        """
        
        # CRITICAL FIX: Extract ONLY the actual question, ignore multiple choice options
        query_lines = query.split('\n')
        
        # Find the actual question (before "Please select" or "Options:")
        actual_question = query_lines[0]
        for line in query_lines:
            if any(stop in line.lower() for stop in ['please select', 'options:', 'your answer must']):
                break
            actual_question = line
        
        query_lower = actual_question.lower()  # ← USE ONLY THE QUESTION
        
        print(f"[CLASSIFIER] Extracted question: {query_lower[:80]}...")

        question_type = self._determine_type(query_lower, actual_question)
        complexity_level = self._determine_complexity(query_lower, question_type)

        # ========================================================================
        # PATH COUNT
        # ========================================================================
        if question_type == QuestionType.COMMONSENSE:
            num_paths = 3  # Always 3 for CommonsenseQA
            print(f"[CLASSIFIER] COMMONSENSE → Forcing 3 paths for consensus")
        else:
            # Original logic for other question types
            num_paths = 1 if complexity_level == ComplexityLevel.SIMPLE else \
                        3 if complexity_level == ComplexityLevel.EXPERT else 2

        # ========================================================================
        # MATH PATHS & VALIDATION  ←←←  NEW BLOCK
        # ========================================================================
        if question_type == QuestionType.MATHEMATICAL:
            num_paths = 3
            print(f"[CLASSIFIER] MATH → Forcing 3 paths (algebraic + numerical + analytical)")
        # --------------------------------------------------------------------
        # Force validation / math-verification for every math question
        requires_validation = (
            complexity_level in {ComplexityLevel.COMPLEX, ComplexityLevel.EXPERT}
            or question_type == QuestionType.MATHEMATICAL
        )
        requires_math_verification = question_type == QuestionType.MATHEMATICAL
        # --------------------------------------------------------------------
        # (end of new block)

        # NEW: Enhanced strategy suggestions for 3-path generation
        suggested_approaches = self._suggest_approaches(question_type, complexity_level)
        confidence_threshold = self._get_confidence_threshold(complexity_level)

        print(f"[CLASSIFIER] → Type: {question_type.value}, Complexity: {complexity_level.value}, Paths: {num_paths}")
        
        return QuestionClassification(
            query, 
            question_type, 
            complexity_level, 
            requires_validation,
            requires_math_verification, 
            suggested_approaches, 
            confidence_threshold, 
            num_paths
        )

    def _determine_type(self, query_lower: str, query: str) -> 'QuestionType':
        """
        FIXED: Detects GSM8K-style math WITHOUT breaking CommonsenseQA
        Preserves all original logic + adds safe math detection
        """
        
        # ===================================================================
        # 1. GSM8K MATH DETECTION - SAFE, NON-OVERLAPPING
        # ===================================================================
        # Only trigger if: numbers + quantity words + "how many/much"
        # Does NOT use "in a market", "at home", etc.
        gsm8k_math_patterns = [
            r'\d+\s+(eggs?|apples?|oranges?|books?|pages?|dollars?|cents?|hours?|minutes?|days?|weeks?)',
            r'\d+\s+(per|each|every|total|altogether|remainder|left|remained|sells?|buys?)',
            r'(how many|how much)\s+.*\?',
            r'\d+\s*[\+\-\*\/]\s*\d+',
            r'(half|twice|double|triple|quarter)\s+(of\s+)?\d+',
        ]
        
        if any(re.search(p, query_lower) for p in gsm8k_math_patterns):
            # Double-check: avoid false positives on factual questions
            if not any(word in query_lower for word in ['typically', 'usually', 'where would', 'what do people']):
                print("[CLASSIFIER] MATH WORD PROBLEM (GSM8K-style)")
                return QuestionType.MATHEMATICAL

        # ===================================================================
        # 2. STRICT MATH (unchanged - for algebra, calculus)
        # ===================================================================
        strict_math_keywords = [
            'solve', 'equation', 'calculate', 'integral', 'derivative',
            'quadratic', 'factor', 'simplify', 'prove mathematically', 'find x'
        ]
        if any(kw in query_lower for kw in strict_math_keywords):
            print("[CLASSIFIER] MATH KEYWORD")
            return QuestionType.MATHEMATICAL

        strict_math_patterns = [
            r'\d+x\s*[\+\-\*/]', r'x\^', r'\b(sin|cos|tan|log)\s*\(', r'√\d',
            r'\d+\s*[\+\-\*/]\s*\d+\s*='
        ]
        if any(re.search(p, query) for p in strict_math_patterns):
            print("[CLASSIFIER] MATH REGEX")
            return QuestionType.MATHEMATICAL

        # ===================================================================
        # 3. COMMONSENSE - YOUR ORIGINAL + SAFER CUES
        # ===================================================================
        # REMOVED: 'in a market', 'at home', 'aside from' → too broad
        safe_commonsense_cues = [
            'typically', 'usually', 'commonly', 'often', 'likely',
            'where would you', 'what do people usually',
            'what might', 'who might', 'where might',
            'good place to', 'if you wanted to', 'after he', 'after she'
        ]
        if any(cue in query_lower for cue in safe_commonsense_cues):
            print("[CLASSIFIER] COMMONSENSE CUE")
            return QuestionType.COMMONSENSE

        # ===================================================================
        # 4. FACTUAL / BINARY / ANALYTICAL (unchanged)
        # ===================================================================
        if any(query_lower.startswith(st) for st in ['what is', 'who is', 'where is', 'when was']):
            print("[CLASSIFIER] FACTUAL")
            return QuestionType.FACTUAL

        if any(query_lower.startswith(st) for st in ['is ', 'are ', 'does ', 'do ', 'can ', 'will ']) and query.endswith('?'):
            print("[CLASSIFIER] BINARY")
            return QuestionType.BINARY

        if any(kw in query_lower for kw in ['explain', 'why', 'how does', 'compare']):
            print("[CLASSIFIER] ANALYTICAL")
            return QuestionType.ANALYTICAL

        # ===================================================================
        # 5. DEFAULT: If numbers → MATH, else COMMONSENSE
        # ===================================================================
        if re.search(r'\d', query):
            print("[CLASSIFIER] DEFAULT → MATHEMATICAL (has numbers)")
            return QuestionType.MATHEMATICAL

        print("[CLASSIFIER] DEFAULT → COMMONSENSE")
        return QuestionType.COMMONSENSE

    def _determine_complexity(self, query_lower: str, question_type: 'QuestionType') -> 'ComplexityLevel':
        """Complexity determination (unchanged from original)"""
        if any(ind in query_lower for ind in ['prove', 'theorem', 'paradox']): 
            return ComplexityLevel.EXPERT
        if any(ind in query_lower for ind in ['system', 'quadratic', 'monty hall']): 
            return ComplexityLevel.COMPLEX
        if any(ind in query_lower for ind in ['capital', 'simple']): 
            return ComplexityLevel.SIMPLE
        return ComplexityLevel.MODERATE

    def _suggest_approaches(self, question_type: 'QuestionType', 
                      complexity: 'ComplexityLevel') -> List['ReasoningApproach']:
        """
        100% PRESERVED FROM YOUR ORIGINAL
        + Added 3-path math support
        """
        if question_type == QuestionType.MATHEMATICAL:
            # Use 3 paths: algebraic + numerical + analytical (for verification)
            return [
                ReasoningApproach.ALGEBRAIC,       # Primary
                ReasoningApproach.NUMERICAL,      # Backup
                ReasoningApproach.ANALYTICAL      # Cross-check
            ]
        
        # YOUR ORIGINAL COMMONSENSEQA LOGIC - UNCHANGED
        if question_type == QuestionType.COMMONSENSE:
            return [
                ReasoningApproach.ANALYTICAL,      # Best performer
                ReasoningApproach.EVIDENCE_BASED,  # Second best
                ReasoningApproach.SKEPTICAL        # Tie-breaker
            ]
        
        if question_type == QuestionType.BINARY:
            return [ReasoningApproach.ANALYTICAL, ReasoningApproach.SKEPTICAL]
        
        # Default fallback
        return [ReasoningApproach.ANALYTICAL, ReasoningApproach.EVIDENCE_BASED]

    def _get_confidence_threshold(self, complexity: 'ComplexityLevel') -> float:
        """Confidence thresholds by complexity (unchanged from original)"""
        return {
            ComplexityLevel.SIMPLE: 0.50, 
            ComplexityLevel.MODERATE: 0.65,
            ComplexityLevel.COMPLEX: 0.70, 
            ComplexityLevel.EXPERT: 0.75
        }[complexity]



In [8]:
# ============================================================================
# CLAUDE REASONING GENERATOR - WITH FIX 1, 2, 5
# ============================================================================
# This class manages reasoning path generation using Anthropic’s Claude model.
# It handles classification, multiple parallel reasoning strategies,
# confidence estimation, and fallback logic.
# The "FIX 1, 2, 5" notes correspond to key improvements:
#   FIX 1 → Dynamic confidence calculation instead of static.
#   FIX 2 → Improved answer extraction from responses.
#   FIX 5 → More robust reasoning step extraction patterns.
# ============================================================================

class ClaudeReasoningGenerator:
    def __init__(self, api_key: str, strategy: Optional[ReasoningStrategy] = None):
        self.client = anthropic.Anthropic(api_key=api_key)
        self.model = "claude-sonnet-4-5-20250929"
        
        # NEW: Strategy support (backward compatible)
        self.strategy = strategy
        if strategy is None:
            # Fallback to old classifier for backward compatibility
            self.classifier = QuestionClassifier()
        
        # NEW: Strategy support (backward compatible)
        self.strategy = strategy
        if strategy is None:
            # Fallback to old classifier for backward compatibility
            self.classifier = QuestionClassifier()

        # Track token usage for cost estimation
        self.total_tokens_input = 0
        self.total_tokens_output = 0

        # Supporting components for classification and confidence scoring
        self.classifier = QuestionClassifier()
        self.confidence_assessor = ConfidenceAssessor()
        self.math_verifier = MathematicalVerifier()  # ✅ NEW

    # ------------------------------------------------------------------------
    # Parallel generation of multiple reasoning paths using different strategies
    # -----------------------------------
    def generate_multiple_paths_parallel(self, query: str, num_paths: int = 3,
                                        classification: Optional[QuestionClassification] = None):
        """
        Generate multiple reasoning paths in parallel with optional strategy integration.
        
        NEW: Supports pluggable strategy providers for custom prompts.
        OLD: Falls back to built-in _select_strategies_ENHANCED for backward compatibility.
        """
        start_time = time.time()
        
        # NEW: Use strategy if provided
        if self.strategy is not None:
            if classification is None:
                classification = self.strategy.classify_question(query)
            strategies = self.strategy.get_generation_prompts(query, classification.num_paths)
        else:
            # OLD PATH: Backward compatibility
            if classification is None:
                classification = self.classifier.classify(query)
            strategies = self._select_strategies_ENHANCED(classification, num_paths)
        
        # Rest of method unchanged - execute strategies in parallel
        paths = []

        # Execute each strategy in parallel threads
        with ThreadPoolExecutor(max_workers=num_paths) as executor:
            future_to_strategy = {
                executor.submit(
                    self._generate_path, query, strategy_name, instruction, classification
                ): strategy_name
                for strategy_name, instruction in strategies
            }
            
            # Collect completed results as they finish
            for future in as_completed(future_to_strategy):
                strategy = future_to_strategy[future]
                try:
                    path = future.result()
                    if path and (path.steps or path.verdict or path.answer):
                        paths.append(path)
                except Exception as e:
                    print(f"  ✗ {strategy} generation error: {e}")
        
        # Compute parallelization speedup metrics
        total_time = time.time() - start_time
        estimated_sequential = sum(p.generation_time for p in paths) if paths else total_time
        speedup = estimated_sequential / total_time if total_time > 0 else 1.0
        
        # Create fallback path if generation failed entirely
        if not paths:
            paths.append(self._create_fallback_path(query, classification))
        
        return paths, speedup
    # ------------------------------------------------------------------------
    # Select generation strategies depending on question classification
    # ------------------------------------------------------------------------
    def _select_strategies(self, classification: 'QuestionClassification', 
                    num_paths: int) -> List[Tuple[str, str]]:
        """
        ENHANCED: Optimized strategy selection with better prompts.
        
        KEY IMPROVEMENTS:
        1. Analytical ALWAYS first (gets +0.05 bonus in synthesis)
        2. 3 strategies for CommonsenseQA (supports 3-path generation)
        3. More directive prompts that guide toward specific answers
        4. Emphasis on choosing concrete over abstract
        """
        
        if classification.question_type == QuestionType.COMMONSENSE:
            # CRITICAL: Order matters - analytical first (best performer, gets bonus)
            strategies = [
                ("analytical", 
                """Use everyday common sense and practical reasoning.
                
    IMPORTANT: Choose the MOST SPECIFIC, CONCRETE answer.
    - Prefer specific actions over generic categories (e.g., "singing" > "making music")
    - Prefer specific locations over general places (e.g., "refrigerator" > "place")
    - Prefer direct terms over vague descriptions
    - Think: What is the MOST DIRECT answer to this question?"""),
                
                ("evidence_based", 
                """Use practical knowledge and real-world experience.
                
    IMPORTANT: Focus on SPECIFICITY.
    - What is the most concrete, tangible answer?
    - Avoid abstract or overly broad options
    - Choose the answer that directly names the thing/action, not a category
    - Real-world context: What would people actually say?"""),
                
                ("skeptical", 
                """Think critically about each option.
                
    IMPORTANT: Eliminate based on specificity.
    - Remove vague, generic, or overly broad choices
    - Remove abstract concepts when concrete options exist
    - Question: Which answer is MOST SPECIFIC and DIRECT?
    - Be decisive - choose the clearest, most concrete option"""),
            ]
            return strategies[:num_paths]
        
        elif classification.question_type == QuestionType.BINARY:
            return [
                ("analytical", 
                "Analyze this systematically using logic and evidence. Be clear and decisive in your verdict."),
                ("skeptical", 
                "Approach critically, questioning assumptions. Challenge the premise if needed."),
                ("evidence_based", 
                "Focus strictly on factual evidence. What does established knowledge say?")
            ][:num_paths]
        
        elif classification.question_type == QuestionType.MATHEMATICAL:
            return [
                ("algebraic", 
                "Solve using algebraic methods. Show each step clearly with proper notation."),
                ("numerical", 
                "Use numerical calculations. Compute each step explicitly with actual numbers."),
            ][:num_paths]
        
        elif classification.question_type == QuestionType.FACTUAL:
            return [
                ("direct", 
                "Provide the factual answer with supporting context. Be precise and cite established facts.")
            ]
        
        else:
            # Default analytical strategies
            return [
                ("analytical", 
                "Analyze this comprehensively. Break down the question and reason systematically."),
                ("evidence_based", 
                "Base your analysis on established knowledge. What do authoritative sources say?"),
            ][:num_paths]

    def _select_strategies_ENHANCED(self, classification: 'QuestionClassification', 
                                    num_paths: int) -> List[Tuple[str, str]]:
        """
        ENHANCED: Better mathematical strategy prompts.
        All other question types preserved unchanged.
        """
        
        if classification.question_type == QuestionType.COMMONSENSE:
            # ORIGINAL COMMONSENSE CODE PRESERVED EXACTLY
            strategies = [
                ("analytical", 
                """Use everyday common sense and practical reasoning.
                
    IMPORTANT: Choose the MOST SPECIFIC, CONCRETE answer.
    - Prefer specific actions over generic categories (e.g., "singing" > "making music")
    - Prefer specific locations over general places (e.g., "refrigerator" > "place")
    - Prefer direct terms over vague descriptions
    - Think: What is the MOST DIRECT answer to this question?"""),
                
                ("evidence_based", 
                """Use practical knowledge and real-world experience.
                
    IMPORTANT: Focus on SPECIFICITY.
    - What is the most concrete, tangible answer?
    - Avoid abstract or overly broad options
    - Choose the answer that directly names the thing/action, not a category
    - Real-world context: What would people actually say?"""),
                
                ("skeptical", 
                """Think critically about each option.
                
    IMPORTANT: Eliminate based on specificity.
    - Remove vague, generic, or overly broad choices
    - Remove abstract concepts when concrete options exist
    - Question: Which answer is MOST SPECIFIC and DIRECT?
    - Be decisive - choose the clearest, most concrete option"""),
            ]
            return strategies[:num_paths]
        
        elif classification.question_type == QuestionType.BINARY:
            # ORIGINAL BINARY CODE PRESERVED
            return [
                ("analytical", 
                "Analyze this systematically using logic and evidence. Be clear and decisive in your verdict."),
                ("skeptical", 
                "Approach critically, questioning assumptions. Challenge the premise if needed."),
                ("evidence_based", 
                "Focus strictly on factual evidence. What does established knowledge say?")
            ][:num_paths]
        
        elif classification.question_type == QuestionType.MATHEMATICAL:
            # ENHANCED: More demanding mathematical prompts
            return [
                ("algebraic", 
                """Solve using algebraic methods. 

    CRITICAL REQUIREMENTS:
    1. Show EVERY step of your calculation explicitly
    2. Write out ALL arithmetic operations (don't skip steps)
    3. Label your steps clearly (Step 1, Step 2, etc.)
    4. MUST end with: ANSWER: [number]

    Be systematic and careful with calculations."""),
                
                ("numerical", 
                """Solve using numerical calculations step-by-step.

    CRITICAL REQUIREMENTS:
    1. Convert word problem to concrete numbers immediately
    2. Show ALL arithmetic: 5 + 3 = 8, then 8 × 2 = 16, etc.
    3. Label each calculation clearly
    4. MUST end with: ANSWER: [number]

    Work through the problem methodically. Double-check arithmetic."""),
            ][:num_paths]
        
        elif classification.question_type == QuestionType.FACTUAL:
            # ORIGINAL FACTUAL CODE PRESERVED
            return [
                ("direct", 
                "Provide the factual answer with supporting context. Be precise and cite established facts.")
            ]
        
        else:
            # ORIGINAL DEFAULT CODE PRESERVED
            return [
                ("analytical", 
                "Analyze this comprehensively. Break down the question and reason systematically."),
                ("evidence_based", 
                "Base your analysis on established knowledge. What do authoritative sources say?"),
            ][:num_paths]

        # ============================================================================
        # INTEGRATION NOTES
        # ============================================================================
        """
        WHY THESE CHANGES HELP:

        1. **Analytical First (CRITICAL)**
        - Analytical path gets +0.05 bonus in synthesis
        - Your data shows it's the best performer (6 correct answers that were overruled)
        - By placing it first, it gets generated first and gets the bonus

        2. **3 Strategies for CommonsenseQA**
        - Now returns 3 strategies (analytical, evidence_based, skeptical)
        - Supports 3-path generation (breaks 50-50 ties)
        - Adds skeptical as tie-breaker path

        3. **Enhanced Prompts with Specificity Guidance**
        - Each prompt now explicitly instructs to choose SPECIFIC over GENERAL
        - Emphasizes concrete actions/nouns over abstract concepts
        - Aligns with your SpecificityScorer philosophy
        - Guides model toward answers that will score well

        4. **Multi-line Prompts**
        - More detailed instructions help model understand task better
        - Explicit examples ("singing" > "making music")
        - Clear decision criteria

        5. **Directive Language**
        - Changed from suggestive ("Use...") to directive ("IMPORTANT: Choose...")
        - Makes prompts more forceful and clear
        - Model follows instructions more consistently

        EXPECTED IMPACT:
        - Analytical path performs even better (already best, now with better prompt)
        - All 3 paths guided toward specific answers
        - More consistent answer quality across paths
        - Better consensus when all paths agree on specificity

        USAGE:
        Simply replace your existing _select_strategies_ENHANCED method with this one.
        No other changes needed - works with your existing code.
        """
    # ------------------------------------------------------------------------
    # Dispatch correct generation method based on question type
    # ------------------------------------------------------------------------
    def _generate_path(
        self, query: str, strategy: str, instruction: str,
        classification: QuestionClassification
    ) -> Optional[ReasoningPath]:
        """Delegate path generation to binary or adaptive generator."""
        if classification.question_type == QuestionType.BINARY:
            return self._generate_binary_path(query, strategy, instruction, classification)
        else:
            return self._generate_adaptive_path(query, strategy, instruction, classification)
    
    # ------------------------------------------------------------------------
    # Binary question handler (YES/NO, TRUE/FALSE)
    # ------------------------------------------------------------------------
    def _generate_binary_path(
        self, query: str, strategy: str, instruction: str,
        classification: QuestionClassification
    ) -> Optional[ReasoningPath]:
        """Generate reasoning path for binary-type questions."""
        start_time = time.time()

        # Fixed prompt template for binary reasoning
        prompt = f"""Question: {query}

{instruction}

Provide your analysis in this EXACT format:

VERDICT: YES or NO

REASONING:
Step 1: [Your first point]
Step 2: [Your second point]
Step 3: [Your third point]

CONCLUSION: [Your final conclusion]

Be direct and committed in your verdict."""

        try:
            # Send prompt to Claude
            message = self.client.messages.create(
                model=self.model,
                max_tokens=1500,
                temperature=0.7,
                messages=[{"role": "user", "content": prompt}]
            )
            
            # Extract text and token counts
            response_text = message.content[0].text
            self.total_tokens_input += message.usage.input_tokens
            self.total_tokens_output += message.usage.output_tokens
            
            generation_time = time.time() - start_time
            
            # Parse structured reasoning from response
            path = self._parse_binary_response(query, response_text, strategy, classification)
            path.generation_time = generation_time
            path.question_type = QuestionType.BINARY
            
            return path
        except Exception as e:
            print(f"Claude API error for {strategy}: {e}")
            return None
    
    # ------------------------------------------------------------------------
    # Adaptive path generator for non-binary question types
    # ------------------------------------------------------------------------
    def _generate_adaptive_path(self, query: str, strategy: str, instruction: str,
                            classification: 'QuestionClassification') -> Optional['ReasoningPath']:
        """Generate reasoning path with COMMONSENSE-SPECIFIC prompts"""
        
        start_time = time.time()
        
        # NEW: Detect if this is a CommonsenseQA question
        is_commonsense_qa = 'Please select ONLY ONE' in query
        
        if is_commonsense_qa:
            # Extract the actual question and choices
            parts = query.split('\n\nPlease select ONLY ONE')
            question_text = parts[0]
            choices_text = parts[1] if len(parts) > 1 else ""
            
            # CRITICAL: Use commonsense-specific prompt
            prompt = self._build_commonsense_prompt(question_text, choices_text, strategy)
        else:
            # Regular adaptive prompt
            output_format = self._get_output_format(classification.question_type)
            prompt = f"""Question: {query}

    {instruction}

    {output_format}

    Be clear and systematic."""
        
        try:
            message = self.client.messages.create(
                model=self.model,
                max_tokens=800,  # Increased for better reasoning
                temperature=0.7,
                messages=[{"role": "user", "content": prompt}]
            )
            
            response_text = message.content[0].text
            self.total_tokens_input += message.usage.input_tokens
            self.total_tokens_output += message.usage.output_tokens
            
            generation_time = time.time() - start_time
            
            # Parse response
            path = self._parse_adaptive_response(query, response_text, strategy, classification)
            path.generation_time = generation_time
            return path
            
        except Exception as e:
            print(f"Generation error for {strategy}: {e}")
            return None

    def _build_commonsense_prompt(self, question: str, choices: str, strategy: str) -> str:
        """
        CRITICAL: Guide toward SPECIFIC, CONCRETE answers using commonsense reasoning.
        Emphasizes choosing the most direct, specific option over generic or abstract ones.
        """
        
        if strategy == "analytical":
            return f"""{question}

    {choices}

    Use EVERYDAY COMMONSENSE reasoning:

    ANALYSIS:
    Step 1: What is the question REALLY asking?
    - Focus on the MOST SPECIFIC, CONCRETE action or thing
    - Avoid generic or abstract answers
    - Think about practical, real-world context

    Step 2: Evaluate each option
    - Which answer is the MOST SPECIFIC and DIRECT?
    - Eliminate vague, generic, or overly broad options
    - Choose the answer that directly describes the action/thing, not a category
    - Remove options that are absurd, too literal, or technical

    Step 3: Select the MOST SPECIFIC answer
    - Prefer concrete actions over abstract concepts
    - Prefer specific items over general categories
    - Examples: "singing" beats "making music", "torn" beats "damaged", "attention" beats "walked"

    ANSWER: [Single letter A-E]: [choice text]

    CRITICAL: Choose the MOST SPECIFIC, CONCRETE option that makes real-world sense."""

        elif strategy == "evidence_based":
            return f"""{question}

    {choices}

    Use PRACTICAL KNOWLEDGE to find the MOST SPECIFIC answer:

    ANALYSIS:
    Step 1: What is the everyday context of this question?
    - What SPECIFIC action or thing is being asked about?
    - Think about how people actually use these concepts

    Step 2: Apply specificity test to each option
    - Which answer is MOST CONCRETE and DIRECT?
    - Eliminate generic categories in favor of specific instances
    - Remove vague or overly broad options

    Step 3: Select the MOST SPECIFIC, PRACTICAL answer
    - Choose the option that directly names the action/thing
    - Avoid abstract or categorical answers

    ANSWER: [Single letter A-E]: [choice text]

    Remember: Prefer SPECIFIC over GENERAL, CONCRETE over ABSTRACT."""

        else:  # skeptical or other
            return f"""{question}

    {choices}

    Think through this CAREFULLY:

    ANALYSIS:
    Step 1: Parse what the question means in NORMAL LANGUAGE
    - What SPECIFIC thing is being asked about?

    Step 2: Evaluate specificity of each option
    - Which is the MOST DIRECT and CONCRETE answer?
    - Eliminate generic, vague, or overly broad options

    Step 3: Select the MOST SPECIFIC answer that makes REAL-WORLD SENSE

    ANSWER: [Single letter A-E]: [choice text]

    Use common sense and choose the MOST SPECIFIC, CONCRETE option."""

    # ============================================================================
    # Added MATHEMATICAL PROMPTS (UPDATE _get_output_format)
    # ============================================================================
    def _get_output_format(self, question_type: QuestionType) -> str:
        """
        ENHANCED: Math prompts now REQUIRE explicit ANSWER statements.
        Other formats preserved unchanged.
        """
        if question_type == QuestionType.FACTUAL:
            return """Provide your answer:

    ANSWER: [factual answer]

    REASONING:
    Step 1: [supporting evidence]
    Step 2: [additional support]

    CONCLUSION: [summary]"""
        
        elif question_type == QuestionType.MATHEMATICAL:
            # ENHANCED: Much stricter format requirements
            return """Solve this step-by-step:

    SOLUTION:
    Step 1: [Understand the problem - identify what's given and what's asked]
    Step 2: [Set up equations or identify operations needed]
    Step 3: [Perform calculations with clear arithmetic]
    Step 4: [Continue solving until you reach the final number]

    CRITICAL: You MUST end with this exact format:
    ANSWER: [numerical value]

    Example: ANSWER: 42

    Show ALL calculations explicitly. Double-check your arithmetic."""
        
        else:
            return """Provide analysis:

    ANALYSIS:
    Step 1: [key insight]
    Step 2: [supporting evidence]
    Step 3: [conclusion]

    CONCLUSION: [summary]"""

    # ------------------------------------------------------------------------
    # Response parsing methods
    # ------------------------------------------------------------------------
    def _parse_binary_response(self, query, response, strategy, classification) -> ReasoningPath:
        """Parse binary response into a structured ReasoningPath."""
        verdict = self._extract_verdict(response)
        steps = self._extract_reasoning_steps(response, strategy, QuestionType.BINARY)
        conclusion = self._extract_conclusion(response, verdict)

        # FIX: Pass verdict as the answer parameter
        confidence = self._calculate_dynamic_confidence(steps, response, verdict)
        
        return ReasoningPath(
            path_id=f"{strategy}_{int(time.time()*1000)}",
            query=query,
            verdict=verdict,
            steps=steps,
            conclusion=conclusion,
            confidence=confidence,
            generation_strategy=strategy,
            raw_output=response,
            question_type=classification.question_type,
            complexity_level=classification.complexity_level
        )

    def _parse_adaptive_response(self, query, response, strategy, classification) -> ReasoningPath:
        """
        Parse adaptive (non-binary) responses into a structured ReasoningPath.
        
        NEW: Supports pluggable strategy providers for custom answer extraction.
        OLD: Falls back to built-in _extract_answer_improved for backward compatibility.
        """
        
        # NEW: Use strategy extraction if available
        if self.strategy is not None:
            answer = self.strategy.extract_answer(response)
        else:
            # OLD PATH: Backward compatibility
            answer = self._extract_answer_improved(response, classification.question_type)
        
        # Extract reasoning steps (unchanged)
        steps = self._extract_reasoning_steps(response, strategy, classification.question_type)
        
        # Extract conclusion (unchanged)
        conclusion = self._extract_conclusion(response, answer)

        # Calculate confidence (unchanged)
        confidence = self._calculate_dynamic_confidence(steps, response, answer)
        
        # Build and return ReasoningPath
        return ReasoningPath(
            path_id=f"{strategy}_{int(time.time()*1000)}",
            query=query,
            answer=answer,
            steps=steps,
            conclusion=conclusion,
            confidence=confidence,
            generation_strategy=strategy,
            raw_output=response,
            question_type=classification.question_type,
            complexity_level=classification.complexity_level
        )
    # ------------------------------------------------------------------------
    # Text extraction helpers
    # ------------------------------------------------------------------------
    def _extract_verdict(self, response: str) -> Optional[str]:
        """Extract a binary verdict (YES/NO/TRUE/FALSE) from response text."""
        verdict_patterns = [
            r'VERDICT:\s*([A-Z]+)',
            r'Verdict:\s*([A-Z]+)',
        ]
        
        for pattern in verdict_patterns:
            match = re.search(pattern, response, re.MULTILINE | re.IGNORECASE)
            if match:
                verdict = match.group(1).upper()
                if verdict in ['TRUE', 'FALSE', 'YES', 'NO', 'DEPENDS', 'UNCLEAR']:
                    return verdict
        
        # Infer from natural language if explicit verdict missing
        response_lower = response.lower()
        if any(phrase in response_lower for phrase in ['is false', 'not true', 'myth']):
            return 'FALSE'
        elif any(phrase in response_lower for phrase in ['is true', 'correct', 'accurate']):
            return 'TRUE'
        
        return None

    def _extract_answer_improved(self, response: str, question_type: 'QuestionType') -> Optional[str]:
        """
        FIXED: Enhanced extraction for both mathematical and commonsense answers.
        Priority order: ANSWER marker > CONCLUSION > calculations > last number
        """
        
        print(f"\n[DEBUG ANSWER EXTRACT] Question type: {question_type}")
        print(f"[DEBUG ANSWER EXTRACT] Response preview: {response[:300]}")
        
        # ========================================================================
        # MATHEMATICAL QUESTION HANDLING
        # ========================================================================
        if question_type == QuestionType.MATHEMATICAL:
            # Strategy 1: Explicit ANSWER markers (highest priority)
            explicit_patterns = [
                r'ANSWER\s*[:\=]\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)',  # ANSWER: 42
                r'\\boxed\{(\d+(?:,\d{3})*(?:\.\d+)?)\}',           # \boxed{42}
                r'(?:final\s+answer|the\s+answer)(?:\s+is|\s*:|\s*=)\s*\$?\*?\*?(\d+(?:,\d{3})*(?:\.\d+)?)',
            ]
            
            for pattern in explicit_patterns:
                match = re.search(pattern, response, re.IGNORECASE)
                if match:
                    answer = match.group(1).replace(',', '').strip()
                    print(f"[DEBUG ANSWER EXTRACT] ✓ Found explicit answer: {answer}")
                    return answer
            
            # Strategy 2: CONCLUSION section
            conclusion_match = re.search(
                r'CONCLUSION:\s*(.+?)(?:\n\n|VERIFICATION|$)', 
                response, re.IGNORECASE | re.DOTALL
            )
            if conclusion_match:
                conclusion_text = conclusion_match.group(1)
                # Extract first number from conclusion
                num_match = re.search(r'\$?(\d+(?:,\d{3})*(?:\.\d+)?)', conclusion_text)
                if num_match:
                    answer = num_match.group(1).replace(',', '')
                    print(f"[DEBUG ANSWER EXTRACT] ✓ Found in conclusion: {answer}")
                    return answer
            
            # Strategy 3: Look for "= NUMBER" in last 5 lines
            lines = response.split('\n')
            for line in reversed(lines[-5:]):
                calc_match = re.search(r'=\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*$', line)
                if calc_match:
                    answer = calc_match.group(1).replace(',', '')
                    print(f"[DEBUG ANSWER EXTRACT] ✓ Extracted from calculation: {answer}")
                    return answer
            
            # Strategy 4: Last number in response (last resort)
            last_third = response[int(len(response) * 0.7):]
            numbers = re.findall(r'\$?(\d+(?:,\d{3})*(?:\.\d+)?)', last_third)
            if numbers:
                answer = numbers[-1].replace(',', '')
                print(f"[DEBUG ANSWER EXTRACT] ⚠ Using last number: {answer}")
                return answer
            
            print(f"[DEBUG ANSWER EXTRACT] ✗ No numerical answer found")
            return None
        
        # ========================================================================
        # COMMONSENSE/OTHER QUESTION HANDLING (ORIGINAL CODE PRESERVED)
        # ========================================================================
        
        # Strategy 1: Look for "ANSWER: X: text" format
        answer_patterns = [
            r'ANSWER:\s*([A-E]):\s*(.+?)(?:\n\n|CONCLUSION|VERIFICATION|$)',
            r'ANSWER:\s*\*?\*?([A-E])\s*:\s*(.+?)(?:\n\n|$)',
            r'##?\s*ANSWER:?\s*([A-E]):\s*(.+?)(?:\n|$)',
        ]
        
        for pattern in answer_patterns:
            match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
            if match:
                letter = match.group(1).upper()
                text = match.group(2).strip()
                text = re.sub(r'\*\*|__|`', '', text)
                answer = f"{letter}: {text}"
                print(f"[DEBUG ANSWER EXTRACT] ✓ Found via pattern: {answer[:60]}...")
                return answer
        
        # Strategy 2: CONCLUSION section for choice answers
        conclusion_patterns = [
            r'CONCLUSION:\s*([A-E]):\s*(.+?)(?:\n\n|$)',
            r'CONCLUSION:\s*\*?\*?([A-E])\s+(?:is|are|would be)\s+(.+?)(?:\n|$)',
        ]
        
        for pattern in conclusion_patterns:
            match = re.search(pattern, response, re.IGNORECASE | re.DOTALL)
            if match:
                conclusion = f"{match.group(1)}: {match.group(2)}"
                conclusion = re.sub(r'\*\*|__|`', '', conclusion)
                print(f"[DEBUG ANSWER EXTRACT] ✓ Found in conclusion: {conclusion[:60]}...")
                return conclusion
        
        print(f"[DEBUG ANSWER EXTRACT] ✗ No answer found")
        return None

    def _extract_mathematical_answer(self, response: str) -> Optional[str]:
        """
        NEW METHOD: Extract numerical answers from mathematical reasoning.
        Tries multiple strategies with robust fallbacks.
        """
        
        # Strategy 1: Explicit ANSWER markers (highest priority)
        explicit_patterns = [
            # "ANSWER: 42" or "ANSWER = 42"
            r'ANSWER\s*[:\=]\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)',
            
            # LaTeX boxed format: \boxed{42}
            r'\\boxed\{([^}]+)\}',
            
            # "Final Answer: 42" or "The answer is 42"
            r'(?:final\s+answer|the\s+answer)(?:\s+is|\s*:|\s*=)\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)',
            
            # "Therefore 42" or "Thus 42"
            r'(?:therefore|thus|hence)\s*,?\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)',
        ]
        
        for pattern in explicit_patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                answer = match.group(1).strip()
                # Clean formatting
                answer = answer.replace(',', '')  # Remove thousand separators
                print(f"[DEBUG MATH EXTRACT] ✓ Found explicit answer: {answer}")
                return answer
        
        # Strategy 2: Number with units (common in word problems)
        unit_patterns = [
            r'(\d+(?:,\d{3})*(?:\.\d+)?)\s+(?:dollars?|pounds?|miles?|hours?|minutes?|days?|weeks?|items?|bags?|people|students?|plants?|eggs?|toys?|books?|candies?|apples?)',
            r'(\d+(?:,\d{3})*(?:\.\d+)?)\s+(?:is|are|will be|remain|left|total)',
        ]
        
        # Search in last 30% of response (where answers typically appear)
        last_portion = response[int(len(response) * 0.7):]
        
        for pattern in unit_patterns:
            matches = list(re.finditer(pattern, last_portion, re.IGNORECASE))
            if matches:
                # Take the LAST match (most likely to be final answer)
                answer = matches[-1].group(1).replace(',', '')
                print(f"[DEBUG MATH EXTRACT] ✓ Found answer with units: {answer}")
                return answer
        
        # Strategy 3: Look in CONCLUSION section specifically
        conclusion_match = re.search(r'CONCLUSION:(.+?)(?:\n\n|$)', response, re.IGNORECASE | re.DOTALL)
        if conclusion_match:
            conclusion_text = conclusion_match.group(1)
            # Find any number in conclusion
            num_match = re.search(r'(\d+(?:,\d{3})*(?:\.\d+)?)', conclusion_text)
            if num_match:
                answer = num_match.group(1).replace(',', '')
                print(f"[DEBUG MATH EXTRACT] ✓ Found in conclusion: {answer}")
                return answer
        
        # Strategy 4: Extract from calculations (FALLBACK)
        # Look for "= NUMBER" in last 5 lines
        lines = response.split('\n')
        for line in reversed(lines[-5:]):
            calc_match = re.search(r'=\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*$', line)
            if calc_match:
                answer = calc_match.group(1).replace(',', '')
                print(f"[DEBUG MATH EXTRACT] ✓ Extracted from calculation: {answer}")
                return answer
        
        # Strategy 5: Last number mentioned (LAST RESORT)
        # Find all numbers in last 30% of response
        numbers = re.findall(r'(\d+(?:,\d{3})*(?:\.\d+)?)', last_portion)
        if numbers:
            answer = numbers[-1].replace(',', '')
            print(f"[DEBUG MATH EXTRACT] ⚠ Using last number: {answer}")
            return answer
        
        print(f"[DEBUG MATH EXTRACT] ✗ No numerical answer found")
        return None

    def _extract_reasoning_steps(self, response, strategy, question_type) -> List['LogicalStep']:
        """
        FIXED: Properly extracts reasoning steps from markdown-heavy responses.
        Handles headers, bold text, and nested formatting.
        """
        
        print(f"\n[DEBUG STEP EXTRACT] Strategy: {strategy}")
        
        steps = []
        reasoning_text = ""
        
        # Try to locate reasoning sections
        section_patterns = [
            (r'SOLUTION:(.*?)(?=ANSWER:|CONCLUSION:|VERIFICATION:|$)', 'SOLUTION'),
            (r'ANALYSIS:(.*?)(?=ANSWER:|CONCLUSION:|$)', 'ANALYSIS'),
            (r'REASONING:(.*?)(?=ANSWER:|CONCLUSION:|$)', 'REASONING'),
        ]
        
        for pattern, section_name in section_patterns:
            match = re.search(pattern, response, re.DOTALL | re.IGNORECASE)
            if match:
                reasoning_text = match.group(1)
                print(f"[DEBUG STEP EXTRACT] Found {section_name} section")
                break
        
        # Fallback: use entire response
        if not reasoning_text:
            reasoning_text = response
            print(f"[DEBUG STEP EXTRACT] Using full response")
        
        # Split into lines and clean
        lines = [l.strip() for l in reasoning_text.split('\n') if l.strip()]
        
        # Enhanced patterns that handle your actual output
        patterns = [
            # Markdown headers: ## Step 1: Content
            (r'^##?\s*[Ss]tep\s+(\d+)[:\.\-]?\s*(.+)', 'Markdown Header'),
            
            # Bold headers: **Step 1:** Content
            (r'^\*\*[Ss]tep\s+(\d+)[:\.\-]?\*\*\s*(.+)', 'Bold Header'),
            
            # Simple: Step 1: Content
            (r'^[Ss]tep\s+(\d+)[:\.\-]\s*(.+)', 'Simple Step'),
            
            # Markdown subheaders: ### Content
            (r'^###\s+(.+)', 'Subheader'),
            
            # Numbered lists: 1. Content or 1) Content
            (r'^(\d+)[\.\)]\s+(.+)', 'Numbered'),
            
            # Action verbs at start (your actual format)
            (r'^(Analyze|Evaluate|Identify|Determine|Select|Compare|Calculate|Solve|Define|Explain)\b\s*[:\-]?\s*(.+)', 'Action Verb'),
            
            # Bold keywords: **Key Insight** or **Supporting Evidence**
            (r'^\*\*([A-Z][a-zA-Z\s]+)\*\*\s*(.+)', 'Bold Keyword'),
        ]
        
        step_count = 0
        
        for line in lines:
            # Skip very short lines
            if len(line) < 15:
                continue
            
            # Skip lines that are just headers without content
            if line.startswith('##') and len(line) < 30:
                continue
            
            matched = False
            
            for pattern, pattern_name in patterns:
                match = re.match(pattern, line, re.IGNORECASE)
                if match:
                    groups = match.groups()
                    
                    # Extract content (varies by pattern)
                    if pattern_name in ['Markdown Header', 'Bold Header', 'Simple Step', 'Numbered']:
                        content = groups[-1].strip()  # Last group is content
                    elif pattern_name == 'Subheader':
                        content = groups[0].strip()
                    elif pattern_name in ['Action Verb', 'Bold Keyword']:
                        # Combine action/keyword with content
                        content = f"{groups[0]}: {groups[1]}" if len(groups) > 1 else groups[0]
                    else:
                        content = groups[-1].strip()
                    
                    # Remove extra markdown
                    content = re.sub(r'\*\*|__|`', '', content)
                    
                    # Must be substantial
                    if len(content) < 20:
                        continue
                    
                    # Detect math
                    is_math = (
                        '=' in content or
                        any(op in content.lower() for op in
                            ['divide', 'multiply', 'subtract', 'add', 'sum',
                            'calculate', 'solve', 'substitute', 'factor',
                            'simplify', 'expand', 'equation', 'formula'])
                    )
                    
                    # Dynamic confidence
                    confidence = self.confidence_assessor.assess_step_confidence(
                        content, is_math, question_type
                    )
                    
                    step = LogicalStep(
                        id=f"{strategy}_step_{step_count+1}",
                        operation=self._classify_operation(content),
                        content=content,
                        confidence=confidence,
                        is_mathematical=is_math
                    )
                    
                    # 🔥 MATH VERIFICATION (only for mathematical questions)
                    if is_math and question_type == QuestionType.MATHEMATICAL:
                        calculations = self.math_verifier.extract_calculations(content)
                        if calculations:
                            all_valid = True
                            for calc in calculations:
                                is_valid, feedback, correct_val = self.math_verifier.verify_calculation(calc)
                                if not is_valid:
                                    all_valid = False
                                    step.validation_status = ValidationStatus.INVALID
                                    step.validation_feedback = feedback
                                    step.confidence *= 0.5  # Reduce confidence for failed calc
                                    print(f"[MATH VERIFY] ❌ {feedback}")
                                    break
                            
                            if all_valid:
                                step.calculation_verified = True
                                step.confidence = min(step.confidence * 1.1, 0.95)  # Boost verified steps
                                print(f"[MATH VERIFY] ✓ All calculations verified")
                    
                    steps.append(step)  # ✅ ONLY APPEND ONCE!
                    step_count += 1
                    matched = True
                    
                    print(f"[DEBUG STEP EXTRACT] ✓ {pattern_name}: {content[:50]}...")
                    break
            
            # If no pattern matched but line looks like reasoning content
            if not matched and len(line) > 40:
                # Check if it's a meaningful sentence (has verb)
                if any(word in line.lower() for word in ['is', 'are', 'has', 'have', 'consists', 'includes', 'contains']):
                    content = re.sub(r'\*\*|__|`', '', line)
                    
                    step = LogicalStep(
                        id=f"{strategy}_step_{step_count+1}",
                        operation=LogicalOperation.INFERENCE,
                        content=content,
                        confidence=0.70,
                        is_mathematical=False
                    )
                    steps.append(step)
                    step_count += 1
                    print(f"[DEBUG STEP EXTRACT] ✓ Freeform: {content[:50]}...")
        
        print(f"[DEBUG STEP EXTRACT] Total steps found: {len(steps)}")
        return steps

    def _calculate_dynamic_confidence(self, steps: List['LogicalStep'], response: str, 
                                    answer: Optional[str] = None) -> float:
        """
        FIXED: Better confidence calculation that doesn't undervalue good answers.
        """
        
        # If we have steps, use their confidence
        if steps and len(steps) > 0:
            avg_step_confidence = sum(s.confidence for s in steps) / len(steps)
            
            # Bonus for having an answer
            if answer and answer not in ["Unable to determine", ""]:
                avg_step_confidence += 0.10
            
            # Penalize uncertainty
            uncertainty_count = sum(1 for word in ['might', 'maybe', 'possibly', 'unclear', 'uncertain']
                                if word in response.lower())
            avg_step_confidence -= (uncertainty_count * 0.05)
            
            return min(max(avg_step_confidence, 0.30), 0.95)
        
        # No steps but we have an answer - assess answer quality
        if answer and answer not in ["Unable to determine", ""]:
            return self._assess_answer_quality(answer, response)
        
        # No steps and no answer - but check if response is substantial
        if len(response) > 200:
            return 0.55  # Some reasoning present
        
        return 0.30  # Minimal confidence

    def _assess_answer_quality(self, answer: str, response: str) -> float:
        """
        FIXED: Better answer quality assessment.
        """
        base_confidence = 0.60  # Start higher (was 0.5)
        
        # Length and detail
        if len(answer) > 50:
            base_confidence += 0.12
        elif len(answer) > 20:
            base_confidence += 0.08
        
        # Response length indicates reasoning depth
        if len(response) > 800:
            base_confidence += 0.10
        elif len(response) > 400:
            base_confidence += 0.05
        
        # Definitive language
        definitive_words = ['therefore', 'thus', 'must', 'always', 'is', 'are',
                        'the answer', 'conclusion', 'result', 'standard']
        definitive_count = sum(1 for word in definitive_words 
                            if word in response.lower())
        base_confidence += min(definitive_count * 0.03, 0.12)
        
        # Uncertainty penalty
        uncertainty_words = ['might', 'possibly', 'approximately', 'roughly', 
                            'likely', 'probably', 'seems', 'appears']
        uncertainty_count = sum(1 for word in uncertainty_words 
                            if word in response.lower())
        base_confidence -= uncertainty_count * 0.04
        
        # Evidence markers
        evidence_words = ['evidence', 'research', 'data', 'fact', 'proven', 
                        'verified', 'confirmed', 'universal', 'standard']
        evidence_count = sum(1 for word in evidence_words 
                            if word in response.lower())
        base_confidence += min(evidence_count * 0.025, 0.08)
        
        return min(max(base_confidence, 0.35), 0.90)

    def _determine_consensus_verdict(self, verdicts: List[str], paths: List[ReasoningPath]) -> str:
        """
        IMPROVED: Use confidence weighting + prefer most specific answer.
        """
        if not verdicts:
            return "Unable to determine"
        
        # Extract letters from verdicts (handle "A: text" format)
        clean_verdicts = []
        for v in verdicts:
            letter_match = re.match(r'^([A-E])', v)
            if letter_match:
                clean_verdicts.append(letter_match.group(1))
            else:
                clean_verdicts.append(v)
        
        # Count with confidence weighting
        weighted_votes = {}
        for i, verdict in enumerate(clean_verdicts):
            if i < len(paths):
                weight = paths[i].confidence
            else:
                weight = 0.5
            weighted_votes[verdict] = weighted_votes.get(verdict, 0) + weight
        
        # Get winner
        winner = max(weighted_votes.items(), key=lambda x: x[1])
        verdict_letter = winner[0]
        total_weight = sum(weighted_votes.values())
        
        # Find full text from original verdicts
        full_text = None
        for v in verdicts:
            if v.startswith(verdict_letter):
                full_text = v
                break
        
        if not full_text:
            full_text = verdict_letter
        
        # Determine consensus level
        if len(set(clean_verdicts)) == 1:
            return f"{full_text} (unanimous)"
        elif winner[1] / total_weight > 0.6:
            return f"{full_text} (strong consensus)"
        else:
            return f"{full_text} (weak consensus)"
    
    def _classify_operation(self, content: str) -> LogicalOperation:
        """Classify the logical type of a reasoning step (premise, inference, etc.)."""
        content_lower = content.lower()
        if any(word in content_lower for word in ['evidence', 'research', 'data']):
            return LogicalOperation.EVIDENCE
        elif any(word in content_lower for word in ['therefore', 'thus', 'implies']):
            return LogicalOperation.INFERENCE
        elif any(word in content_lower for word in ['however', 'but', 'although']):
            return LogicalOperation.COUNTERARGUMENT
        else:
            return LogicalOperation.PREMISE
    
    def _extract_conclusion(self, response: str, verdict_or_answer: Optional[str]) -> str:
        """Extract the conclusion section or fallback to verdict/answer."""
        conclusion_match = re.search(r'CONCLUSION:\s*(.+)', response, re.IGNORECASE | re.DOTALL)
        if conclusion_match:
            conclusion = conclusion_match.group(1).strip().split('\n')[0]
            if verdict_or_answer:
                return f"{verdict_or_answer}. {conclusion}"
            return conclusion
        
        if verdict_or_answer:
            return f"Answer: {verdict_or_answer}"
        
        return "See reasoning above"

    def _create_fallback_path(self, query: str,
                            classification: QuestionClassification = None) -> ReasoningPath:
        """Fallback path used when all generation attempts fail."""
        return ReasoningPath(
            path_id=f"fallback_{int(time.time())}",
            query=query,
            verdict="UNCLEAR" if classification and classification.question_type == QuestionType.BINARY else None,
            answer="Unable to determine",
            steps=[],
            conclusion="Generation failed",
            confidence=0.2,
            generation_strategy="fallback",
            question_type=classification.question_type if classification else QuestionType.BINARY,
            complexity_level=classification.complexity_level if classification else ComplexityLevel.MODERATE
        )
    
    def get_total_cost(self) -> float:
            """Estimate total Claude API cost based on token usage."""
            input_cost = (self.total_tokens_input / 1_000_000) * 3.0
            output_cost = (self.total_tokens_output / 1_000_000) * 15.0
            return input_cost + output_cost



In [9]:
# ============================================================================
# SYNTHESIZER - WITH VALIDATION, REGENERATION, DIVERGENCE DETECTION & SPECIFICITY
# ============================================================================
# This class combines multiple reasoning paths into one synthesized final answer.
# Features:
# - Validation of reasoning steps with regeneration
# - Numerical divergence detection and consensus building
# - Specificity-based answer selection
# - Confidence-based synthesis logic
# ============================================================================

import re
import time
from typing import List, Dict, Optional, Tuple
import anthropic


class SpecificityScorer:
    """
    Score answer specificity to prefer concrete answers over abstract ones.
    Example: "singing" (specific) > "making music" (generic)
    """
    
    @staticmethod
    def score_specificity(answer_text: str, question: str) -> float:
        """
        AGGRESSIVE specificity scoring for CommonsenseQA.
        
        Philosophy:
        - SHORTER is usually MORE specific ("singing" > "making music")
        - CONCRETE actions/nouns > abstract concepts
        - DIRECT terms > generic categories
        - Single words > phrases (usually)
        
        Higher score = more specific answer
        """
        score = 0.5  # baseline
        
        answer_lower = answer_text.lower().strip()
        answer_len = len(answer_text)
        
        # ========== Rule 1: LENGTH (CRITICAL) ==========
        # Shorter answers are typically more specific in CommonsenseQA
        if answer_len <= 8:  # Single short word (e.g., "singing", "torn")
            score += 0.30
        elif answer_len <= 15:  # Short phrase or compound word
            score += 0.20
        elif answer_len <= 25:  # Medium phrase
            score += 0.10
        elif answer_len > 40:  # Long explanatory answers
            score -= 0.15
        
        # ========== Rule 2: CONCRETE ACTION VERBS (HIGH PRIORITY) ==========
        # Specific actions are highly preferred
        concrete_actions = {
            # From your examples
            'singing', 'torn', 'walk', 'disturb', 'attention',
            # Common specific actions
            'walking', 'running', 'reading', 'writing', 'eating',
            'drinking', 'talking', 'listening', 'watching', 'playing',
            'building', 'cooking', 'cleaning', 'studying', 'sleeping',
            'dancing', 'swimming', 'driving', 'crying', 'laughing',
            'screaming', 'whispering', 'jumping', 'sitting', 'standing'
        }
        if any(action in answer_lower for action in concrete_actions):
            score += 0.35  # Very high bonus for concrete actions
        
        # ========== Rule 3: ABSTRACT/GENERIC TERMS (HEAVY PENALTY) ==========
        # Generic verbs and phrases are red flags
        abstract_terms = {
            # Generic verbs
            'making', 'doing', 'having', 'being', 'getting', 'going',
            'coming', 'taking', 'giving', 'putting', 'becoming',
            # Generic nouns
            'thing', 'stuff', 'something', 'anything', 'everything',
            # Generic descriptors
            'general', 'various', 'some', 'any', 'all',
            # Generic phrases from your data
            'live in', 'work in', 'be in', 'go to'
        }
        abstract_count = sum(1 for term in abstract_terms if term in answer_lower)
        if abstract_count > 0:
            score -= 0.30 * abstract_count  # Compound penalty
        
        # ========== Rule 4: SPECIFIC NOUNS (BONUS) ==========
        # Concrete, named things are more specific than categories
        specific_nouns = {
            # Locations
            'bank', 'library', 'hospital', 'school', 'office',
            'restaurant', 'store', 'park', 'theater', 'gym',
            # Objects
            'refrigerator', 'oven', 'desk', 'chair', 'door',
            'window', 'table', 'bed', 'car', 'phone',
            # Body parts (specific)
            'hand', 'foot', 'eye', 'ear', 'mouth', 'nose'
        }
        if any(noun in answer_lower for noun in specific_nouns):
            score += 0.25
        
        # ========== Rule 5: GENERIC CATEGORIES (PENALTY) ==========
        # Category words indicate lack of specificity
        generic_categories = {
            'place', 'location', 'area', 'region', 'spot', 'site',
            'thing', 'item', 'object', 'article',
            'activity', 'action', 'process', 'method', 'way',
            'type', 'kind', 'sort', 'form', 'category',
            'person', 'people', 'individual', 'someone'
        }
        if any(cat in answer_lower for cat in generic_categories):
            score -= 0.25
        
        # ========== Rule 6: "-ING" GERUNDS (CONTEXT-DEPENDENT) ==========
        # "singing" is specific, but "making music" uses generic "making"
        if answer_lower.endswith('ing'):
            # Check if it's a concrete action (already scored above) or generic verb
            if not any(action in answer_lower for action in concrete_actions):
                # It's a generic gerund like "making", "doing"
                score -= 0.10
        
        # ========== Rule 7: MULTI-WORD PHRASES (SUSPICIOUS) ==========
        # More words often = less specific in CommonsenseQA
        word_count = len(answer_text.split())
        if word_count >= 3:
            score -= 0.10 * (word_count - 2)  # Escalating penalty
        
        # ========== Rule 8: ARTICLES & PREPOSITIONS (OFTEN LESS SPECIFIC) ==========
        # "the bank" might be less specific than just "bank" in some contexts
        if any(answer_lower.startswith(article) for article in ['a ', 'an ', 'the ']):
            score -= 0.05
        
        # ========== Rule 9: COMPOUND SPECIFICITY BOOST ==========
        # Single concrete word = maximum specificity
        if (word_count == 1 and 
            answer_len <= 10 and
            any(action in answer_lower for action in concrete_actions | specific_nouns)):
            score += 0.20  # Extra bonus for perfect specificity
        
        # ========== Rule 10: QUESTION CONTEXT (OPTIONAL) ==========
        # If question asks "where", prefer specific locations
        # If question asks "what do", prefer specific actions
        if question:
            q_lower = question.lower()
            if 'where' in q_lower and any(noun in answer_lower for noun in specific_nouns):
                score += 0.10
            elif ('what do' in q_lower or 'what are' in q_lower) and any(action in answer_lower for action in concrete_actions):
                score += 0.10
        
        # Clamp to valid range
        return max(0.0, min(1.0, score))

class AnswerSynthesizer:
    def __init__(self, client: anthropic.Anthropic, model: str, 
                 strategy: Optional[ReasoningStrategy] = None):
        self.client = client
        self.model = model
        self.strategy = strategy  # NEWf
        self.total_tokens_input = 0
        self.total_tokens_output = 0
        self.specificity_scorer = SpecificityScorer()
    
    def synthesize_final_answer(self, query: str, paths: List['ReasoningPath'],
                            classification: 'QuestionClassification') -> 'SynthesizedAnswer':
        """
        UPDATED: Strategy-based synthesis with backward compatibility.
        Preserves all existing validation, divergence detection, and verification logic.
        """
        
        # ========================================================================
        # STEP 1: VALIDATION (unchanged - runs regardless of strategy)
        # ========================================================================
        if classification.requires_validation:
            print(f"\n[1.5] Validating paths (complexity: {classification.complexity_level.value})...")
            paths = self._validate_and_regenerate_paths(paths, classification)
        
        # ========================================================================
        # STEP 2: DIVERGENCE DETECTION (unchanged)
        # ========================================================================
        print(f"\n[1.6] Checking for numerical divergence...")
        divergence_detected, divergent_pair = self._detect_numerical_divergence(paths)
        
        if divergence_detected and divergent_pair:
            print(f"\n[DIVERGENCE] Detected: {divergent_pair[0].generation_strategy} vs {divergent_pair[1].generation_strategy}")
            consensus_path = self._regenerate_for_consensus(query, divergent_pair, classification)
            if consensus_path:
                paths.append(consensus_path)
                print(f"[CONSENSUS] Path added ({len(paths)} total paths now)")
        
        # ========================================================================
        # STEP 3: ANSWER EXTRACTION (for mathematical questions)
        # ========================================================================
        if classification.question_type == QuestionType.MATHEMATICAL:
            print(f"\n[1.7] Extracting mathematical answers from paths...")
            
            # Collect all answers from paths
            extracted_answers = []
            for i, path in enumerate(paths):
                answer = None
                
                # Method 1: Use stored answer
                if path.answer and path.answer != "Unable to determine":
                    answer = path.answer
                    print(f"  Path {i+1}: Used stored answer: {answer}")
                
                # Method 2: Re-extract from raw output
                elif path.raw_output:
                    answer = self._extract_answer_from_raw_output(path.raw_output)
                    if answer:
                        path.answer = answer  # Update the path
                        print(f"  Path {i+1}: Re-extracted: {answer}")
                
                # Method 3: Extract from last step
                elif path.steps and len(path.steps) > 0:
                    last_step = path.steps[-1]
                    num_match = re.search(r'=\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)', last_step.content)
                    if num_match:
                        answer = num_match.group(1).replace(',', '')
                        path.answer = answer
                        print(f"  Path {i+1}: Extracted from last step: {answer}")
                
                if answer:
                    extracted_answers.append(answer)
                else:
                    print(f"  Path {i+1}: ✗ No answer found")
            
            # Update answers list for synthesis
            if extracted_answers:
                print(f"\n[1.8] Found {len(extracted_answers)} answers: {extracted_answers}")
        
        # ========================================================================
        # STEP 4: ANSWER SYNTHESIS (NEW: Strategy-based with fallback)
        # ========================================================================
        
        # NEW: Use strategy if provided
        if self.strategy is not None:
            print(f"\n[SYNTHESIS] Using strategy: {self.strategy.__class__.__name__}")
            answers = [p.answer for p in paths if p.answer and p.answer != "Unable to determine"]
            
            if answers:
                definitive_answer = self.strategy.select_final_answer(answers, paths)
                answer_format = "verdict" if classification.question_type == QuestionType.BINARY else "answer"
            else:
                print(f"[SYNTHESIS] No valid answers found, using fallback")
                definitive_answer = "Unable to determine"
                answer_format = "answer"
        
        # OLD PATH: Original synthesis logic (backward compatible)
        else:
            print(f"\n[SYNTHESIS] Using legacy synthesis")
            
            if classification.question_type == QuestionType.BINARY:
                verdicts = [p.verdict for p in paths if p.verdict]
                definitive_answer = self._determine_consensus_verdict(verdicts, paths)
                answer_format = "verdict"
            
            elif classification.question_type == QuestionType.MATHEMATICAL:
                # Use extracted answers from step 3
                answers = [p.answer for p in paths if p.answer and p.answer != "Unable to determine"]
                if answers:
                    definitive_answer = self._select_best_numerical_answer(answers, paths)
                else:
                    definitive_answer = "Unable to determine answer"
                answer_format = "answer"
            
            else:
                # CommonsenseQA and others
                answers = [p.answer for p in paths if p.answer]
                if answers:
                    definitive_answer = self._synthesize_answers_with_specificity(answers, paths)
                else:
                    definitive_answer = "Unable to determine"
                answer_format = "answer"
        
        # ========================================================================
        # STEP 5: MATHEMATICAL VERIFICATION (unchanged)
        # ========================================================================
        confidence_multiplier = 1.0  # Default: no adjustment
        
        if classification.question_type == QuestionType.MATHEMATICAL:
            print(f"\n[VERIFICATION] Verifying mathematical answer...")
            
            # Extract ALL calculations from ALL paths
            math_verifier = MathematicalVerifier()
            all_calculations = []
            
            for path in paths:
                for step in path.steps:
                    if step.is_mathematical:
                        calcs = math_verifier.extract_calculations(step.content)
                        all_calculations.extend(calcs)
            
            print(f"[VERIFICATION] Found {len(all_calculations)} total calculation(s) across all paths")
            
            # Verify each calculation
            failed_calcs = []
            verified_count = 0
            
            for calc in all_calculations:
                is_valid, feedback, correct_val = math_verifier.verify_calculation(calc)
                if not is_valid:
                    failed_calcs.append(feedback)
                else:
                    verified_count += 1
            
            # Adjust confidence based on verification results
            if failed_calcs:
                print(f"[VERIFICATION] Found {len(failed_calcs)} calculation error(s):")
                for error in failed_calcs[:3]:
                    print(f"  • {error}")
                
                failure_rate = len(failed_calcs) / len(all_calculations) if all_calculations else 0
                
                if failure_rate > 0.5:
                    confidence_multiplier = 0.5
                elif failure_rate > 0.25:
                    confidence_multiplier = 0.7
                else:
                    confidence_multiplier = 0.85
                
                print(f"[VERIFICATION] Confidence reduced by {(1-confidence_multiplier)*100:.0f}%")
            else:
                print(f"[VERIFICATION] All {len(all_calculations)} calculation(s) verified")
                confidence_multiplier = 1.1  # Boost for perfection
        
        # ========================================================================
        # STEP 6: SUPPORTING INFORMATION (unchanged)
        # ========================================================================
        supporting_reasoning = self._extract_key_points(paths)
        conflicting_points = self._identify_conflicts(paths)
        synthesis_explanation = self._generate_synthesis_explanation(
            query, paths, definitive_answer, conflicting_points
        )
        
        # ========================================================================
        # STEP 7: CONFIDENCE CALCULATION (unchanged)
        # ========================================================================
        base_confidence = self._calculate_synthesis_confidence(paths, conflicting_points)
        
        if classification.question_type == QuestionType.MATHEMATICAL:
            final_confidence = min(base_confidence * confidence_multiplier, 0.95)
            print(f"[VERIFICATION] Final confidence: {base_confidence:.2f} → {final_confidence:.2f}")
        else:
            final_confidence = base_confidence
        
        # ========================================================================
        # STEP 8: RETURN SYNTHESIZED ANSWER (unchanged)
        # ========================================================================
        return SynthesizedAnswer(
            query=query,
            definitive_answer=definitive_answer,
            supporting_reasoning=supporting_reasoning,
            conflicting_points=conflicting_points,
            final_confidence=final_confidence,
            synthesis_explanation=synthesis_explanation,
            question_type=classification.question_type,
            answer_format=answer_format
        )
    # =========================================================================
    # VALIDATION & REGENERATION
    # =========================================================================
    def _validate_and_regenerate_paths(self, paths: List['ReasoningPath'],
                                    classification: 'QuestionClassification') -> List['ReasoningPath']:
        """
        FIXED: Force validation for mathematical questions regardless of confidence.
        """
        validated_paths = []
        total_validations = 0
        total_regenerations = 0
        
        for path_idx, path in enumerate(paths):
            print(f"  Validating path {path_idx + 1}/{len(paths)} ({path.generation_strategy})...")
            
            # Skip empty paths
            if not path.steps:
                validated_paths.append(path)
                continue
            
            needs_regeneration = False
            failed_step_idx = -1
            
            # NEW: Force validation for math questions
            force_validation = (classification.question_type == QuestionType.MATHEMATICAL)
            
            # Validate each step within the path
            for step_idx, step in enumerate(path.steps):
                # MODIFIED: Skip confidence check if force_validation is True
                if not force_validation and step.confidence >= classification.confidence_threshold:
                    step.validation_status = ValidationStatus.VALID
                    continue
                
                previous_steps = path.steps[:step_idx]
                validation_result = self._validate_step(
                    step=step,
                    previous_steps=previous_steps,
                    query=path.query,
                    question_type=classification.question_type
                )
                
                path.validation_passes += 1
                total_validations += 1
                
                # Update step validation outcome
                if validation_result['is_valid']:
                    step.validation_status = ValidationStatus.VALID
                    step.validation_feedback = validation_result['feedback']
                    step.confidence = max(step.confidence, validation_result['confidence'])
                else:
                    # Mark step invalid and flag regeneration
                    step.validation_status = ValidationStatus.INVALID
                    step.validation_feedback = validation_result['feedback']
                    needs_regeneration = True
                    failed_step_idx = step_idx
                    print(f"    ✗ Step {step_idx + 1} failed validation")
                    break
            
            # Attempt regeneration if a failure occurred
            if needs_regeneration and failed_step_idx >= 0:
                print(f"    ↻ Regenerating from step {failed_step_idx + 1}...")
                regenerated_path = self._regenerate_from_failed_step(
                    path=path,
                    failed_step_idx=failed_step_idx,
                    classification=classification
                )
                
                if regenerated_path:
                    path = regenerated_path
                    path.regeneration_count += 1
                    total_regenerations += 1
                    print(f"    ✓ Regeneration successful")
                else:
                    print(f"    ✗ Regeneration failed, keeping original")
            
            validated_paths.append(path)
        
        print(f"  Validation complete: {total_validations} checks, {total_regenerations} regenerations")
        return validated_paths    

    def _validate_step(self, step: 'LogicalStep', previous_steps: List['LogicalStep'],
                      query: str, question_type: 'QuestionType') -> Dict:
        """Validate a single reasoning step using the model, with math verification."""
        
        # 🔥 MATH VERIFICATION FIRST (before expensive LLM call)
        if step.is_mathematical and question_type == QuestionType.MATHEMATICAL:
            math_verifier = MathematicalVerifier()
            calculations = math_verifier.extract_calculations(step.content)
            
            if calculations:
                print(f"[MATH VALIDATE] Found {len(calculations)} calculation(s) to verify")
                for calc in calculations:
                    is_valid, feedback, correct_val = math_verifier.verify_calculation(calc)
                    if not is_valid:
                        print(f"[MATH VALIDATE] ❌ {feedback}")
                        return {
                            'is_valid': False,
                            'confidence': 0.3,
                            'feedback': f"Math error: {feedback}"
                        }
                
                # All calculations valid
                print(f"[MATH VALIDATE] ✓ All {len(calculations)} calculation(s) verified")
                return {
                    'is_valid': True,
                    'confidence': 0.9,
                    'feedback': "✓ All mathematical calculations verified"
                }
        
        # 🔥 FALLBACK: LLM validation (for non-math or steps without extractable calculations)
        context = f"Original Question: {query}\n\n"
        
        if previous_steps:
            context += "Previous reasoning steps:\n"
            for i, prev_step in enumerate(previous_steps, 1):
                context += f"{i}. {prev_step.content}\n"
            context += "\n"
        
        context += f"Step to validate:\n{step.content}"
        
        # Adjust validation prompt depending on whether the step is mathematical
        if step.is_mathematical:
            prompt = f"""{context}

Validate this mathematical step. Check:
1. Is the arithmetic/algebra correct?
2. Does it follow logically from previous steps?
3. Are there any calculation errors?

Respond EXACTLY:
VALID: YES or NO
CONFIDENCE: 0.0 to 1.0
FEEDBACK: Brief explanation"""
        else:
            prompt = f"""{context}

Validate this reasoning step. Check:
1. Does it logically follow from previous steps?
2. Is it factually accurate?
3. Is it relevant to answering the question?

Respond EXACTLY:
VALID: YES or NO
CONFIDENCE: 0.0 to 1.0
FEEDBACK: Brief explanation"""
        
        # Send validation prompt to the Claude model
        try:
            message = self.client.messages.create(
                model=self.model,
                max_tokens=300,
                temperature=0.2,
                messages=[{"role": "user", "content": prompt}]
            )
            
            response = message.content[0].text
            self.total_tokens_input += message.usage.input_tokens
            self.total_tokens_output += message.usage.output_tokens
            
            # Parse structured response
            is_valid = 'VALID: YES' in response.upper()
            confidence_match = re.search(r'CONFIDENCE:\s*(0?\.\d+|1\.0)', response)
            confidence = float(confidence_match.group(1)) if confidence_match else 0.7
            feedback_match = re.search(r'FEEDBACK:\s*(.+?)(?:\n|$)', response, re.DOTALL)
            feedback = feedback_match.group(1).strip() if feedback_match else response[:100]
            
            return {'is_valid': is_valid, 'confidence': confidence, 'feedback': feedback}
            
        except Exception as e:
            # Fallback: mark as valid but low confidence if validation fails
            return {'is_valid': True, 'confidence': 0.6, 'feedback': f"Validation error: {str(e)[:50]}"}
    
    def _regenerate_from_failed_step(self, path: 'ReasoningPath', failed_step_idx: int,
                                    classification: 'QuestionClassification') -> Optional['ReasoningPath']:
        """Regenerate reasoning path after a failed step using alternative approach."""
        valid_steps = path.steps[:failed_step_idx]
        failed_step = path.steps[failed_step_idx]
        
        # Build context including validated steps and failure reason
        context = f"Question: {path.query}\n\n"
        if valid_steps:
            context += "These steps are correct:\n"
            for i, step in enumerate(valid_steps, 1):
                context += f"{i}. {step.content}\n"
            context += "\n"
        
        context += f"This step FAILED validation:\n{failed_step.content}\n"
        context += f"Reason: {failed_step.validation_feedback}\n\n"
        
        # Adjust regeneration strategy based on question type
        if classification.question_type == QuestionType.MATHEMATICAL:
            if 'algebraic' in path.generation_strategy.lower():
                approach = "Try NUMERICAL approach instead. Calculate with actual numbers."
            else:
                approach = "Try ALGEBRAIC approach instead. Use equations and variables."
        else:
            approach = "Try a completely different reasoning approach."
        
        prompt = f"""{context}{approach}

Continue solving from where valid steps ended. Show clear reasoning.

Format:
Step {failed_step_idx + 1}: [new step]
Step {failed_step_idx + 2}: [next step]
...
ANSWER: [final answer]"""
        
        # Generate new reasoning continuation
        try:
            message = self.client.messages.create(
                model=self.model,
                max_tokens=1000,
                temperature=0.7,
                messages=[{"role": "user", "content": prompt}]
            )
            
            response = message.content[0].text
            self.total_tokens_input += message.usage.input_tokens
            self.total_tokens_output += message.usage.output_tokens
            
            # Extract new reasoning steps and answer
            new_steps = self._extract_reasoning_steps_from_text(
                response, f"{path.generation_strategy}_regen", classification.question_type
            )
            
            if not new_steps:
                return None
            
            combined_steps = valid_steps + new_steps
            new_answer = self._extract_answer_from_text(response, classification.question_type)
            
            # Construct updated path object
            return ReasoningPath(
                path_id=path.path_id,
                query=path.query,
                verdict=None if classification.question_type != QuestionType.BINARY else new_answer,
                answer=new_answer if classification.question_type != QuestionType.BINARY else None,
                steps=combined_steps,
                conclusion=f"Regenerated: {new_answer}" if new_answer else "Regenerated",
                confidence=self._calculate_confidence_from_steps(new_answer, combined_steps, response),
                generation_strategy=f"{path.generation_strategy}_regenerated",
                raw_output=response,
                question_type=path.question_type,
                complexity_level=path.complexity_level,
                validation_passes=path.validation_passes,
                regeneration_count=path.regeneration_count
            )
            
        except Exception as e:
            print(f"    ⚠ Regeneration error: {e}")
            return None
    
    # =========================================================================
    # NUMERICAL DIVERGENCE DETECTION & CONSENSUS
    # =========================================================================
    
    def _detect_numerical_divergence(self, paths: List['ReasoningPath']) -> Tuple[bool, Optional[Tuple['ReasoningPath', 'ReasoningPath', float]]]:
        """
        Detect when paths have significant numerical disagreement (>10%).
        
        Returns:
            (divergence_detected, (path1, path2, pct_diff)) or (False, None)
        """
        numerical_answers = []
        
        for path in paths:
            if path.answer:
                nums = self._extract_numerical_values(path.answer.lower())
                if nums and len(nums) > 0:
                    # Take the primary (largest) number if multiple exist
                    primary_num = max(nums) if nums else None
                    if primary_num is not None:
                        numerical_answers.append((path, primary_num))
        
        # Need at least 2 numerical answers to compare
        if len(numerical_answers) < 2:
            return False, None
        
        # Check all pairs for divergence
        for i, (path1, num1) in enumerate(numerical_answers):
            for path2, num2 in numerical_answers[i+1:]:
                if num1 > 0 and num2 > 0:
                    pct_diff = abs(num1 - num2) / max(num1, num2)
                    
                    # Threshold: >10% difference triggers regeneration
                    if pct_diff > 0.10:
                        print(f"  ⚠️ DIVERGENCE DETECTED: {num1:.3f} vs {num2:.3f} ({pct_diff*100:.1f}% diff)")
                        return True, (path1, path2, pct_diff)
        
        return False, None

    def _regenerate_for_consensus(self, 
                                query: str,
                                divergent_paths: Tuple['ReasoningPath', 'ReasoningPath', float],
                                classification: 'QuestionClassification') -> Optional['ReasoningPath']:
        """
        When numerical answers diverge, generate a new path using a different approach.
        
        This uses alternative question framing or decomposition to find the truth.
        """
        path1, path2, pct_diff = divergent_paths
        
        print(f"\n  [REGENERATION] Attempting consensus path...")
        print(f"  Path 1: {path1.generation_strategy} → {path1.answer}")
        print(f"  Path 2: {path2.generation_strategy} → {path2.answer}")
        
        # Build a regeneration prompt that explicitly asks for verification
        prompt = f"""Question: {query}

Two independent approaches gave different numerical answers:
- Approach 1 ({path1.generation_strategy}): {path1.answer}
- Approach 2 ({path2.generation_strategy}): {path2.answer}

These differ by {pct_diff*100:.1f}%. 

Please solve this problem from scratch using a THIRD distinct approach. 
Be extremely careful with all calculations. Show every step explicitly.

Format your response as:

APPROACH: [Name of third method]

SOLUTION:
Step 1: [First step]
Step 2: [Second step]
Step 3: [Continue...]

ANSWER: [Final numerical answer with units]

VERIFICATION: [Double-check your calculation]"""
        
        try:
            message = self.client.messages.create(
                model=self.model,
                max_tokens=1500,
                temperature=0.5,  # Lower temperature for consistency
                messages=[{"role": "user", "content": prompt}]
            )
            
            response_text = message.content[0].text
            self.total_tokens_input += message.usage.input_tokens
            self.total_tokens_output += message.usage.output_tokens
            
            # Extract answer from regenerated response
            answer = self._extract_answer_from_regeneration(response_text)
            
            if answer:
                # Create new path for this consensus attempt
                new_path = ReasoningPath(
                    path_id=f"consensus_{int(time.time()*1000)}",
                    query=query,
                    answer=answer,
                    steps=[],  # Steps not extracted for brevity
                    conclusion=f"Consensus answer after divergence detection",
                    confidence=0.72,  # Moderate confidence for consensus attempts
                    generation_strategy="consensus_verification",
                    raw_output=response_text,
                    question_type=classification.question_type,
                    complexity_level=classification.complexity_level,
                )
                
                print(f"  ✓ Consensus path generated: {answer}")
                return new_path
            else:
                print(f"  ✗ Could not extract answer from consensus path")
                return None
                
        except Exception as e:
            print(f"  ✗ Regeneration error: {e}")
            return None

    def _extract_answer_from_regeneration(self, response: str) -> Optional[str]:
        """Extract answer from regeneration response."""
        # Try ANSWER: format first
        answer_match = re.search(
            r'ANSWER:\s*\*?\*?(.+?)(?:\n\n|VERIFICATION|$)',
            response, re.IGNORECASE | re.DOTALL
        )
        
        if answer_match:
            answer = answer_match.group(1).strip()
            answer = re.sub(r'^\*\*|\*\*$', '', answer).strip()
            if answer and len(answer) > 3:
                return answer
        
        return None

    # =========================================================================
    # SYNTHESIS WITH SPECIFICITY SCORING
    # =========================================================================
    
    def _synthesize_answers_with_specificity(self, answers: List[str], 
                                            paths: List['ReasoningPath']) -> str:
        """
        IMPROVED: Better tie-breaking for close calls
        """
        if not answers:
            return "Unable to determine answer"
        
        answer_groups = self._group_equivalent_answers(answers)
        
        if len(answer_groups) == 1:
            return f"{answers[0]} (unanimous)"
        
        # NEW: For CommonsenseQA, heavily weight specificity
        best_answer = None
        best_score = -1
        
        for i, path in enumerate(paths):
            if not path.answer or i >= len(answers):
                continue
            
            conf_score = path.confidence
            spec_score = self.specificity_scorer.score_specificity(path.answer, path.query)
            
            # FIXED: Weight 70% confidence, 30% specificity for CommonsenseQA
            if path.question_type == QuestionType.COMMONSENSE:
                combined = 0.70 * conf_score + 0.30 * spec_score
            else:
                combined = 0.70 * conf_score + 0.30 * spec_score
            
            print(f"  Path {i+1}: conf={conf_score:.2f}, spec={spec_score:.2f}, combined={combined:.2f}, answer='{path.answer[:40]}'")
            
            if combined > best_score:
                best_score = combined
                best_answer = path.answer
        
        if best_answer:
            return f"{best_answer} (best, score={best_score:.2f})"
        
        return f"{paths[0].answer} (fallback)"

    def _answers_equivalent(self, ans1: str, ans2: str) -> bool:
        """Check if two answers are meaningfully equivalent.
        
        Tries multiple strategies in order:
        1. Direct string match
        2. Numerical comparison (with tolerance)
        3. Math variable matching (x=5 style)
        4. Substring matching for short answers
        5. Word overlap for text answers
        """
        # Strip markdown
        clean1 = re.sub(r'\*\*|__|`|~~', '', ans1).lower().strip()
        clean2 = re.sub(r'\*\*|__|`|~~', '', ans2).lower().strip()
        
        # Strategy 1: Direct match after cleaning
        if clean1 == clean2:
            return True
        
        # Strategy 2: Extract and compare numerical values (handles percentages)
        numbers1 = self._extract_numerical_values(clean1)
        numbers2 = self._extract_numerical_values(clean2)
        
        # If BOTH have numbers, use strict numerical comparison
        if numbers1 and numbers2:
            if self._numbers_equivalent(numbers1, numbers2):
                return True
            else:
                # Numbers differ significantly → definitely not equivalent
                return False
        
        # Strategy 3: Compare variable=value pairs in math expressions
        solutions1 = set(re.findall(r'([a-z])\s*=\s*([-+]?\d+\.?\d*)', clean1))
        solutions2 = set(re.findall(r'([a-z])\s*=\s*([-+]?\d+\.?\d*)', clean2))
        if solutions1 and solutions2 and solutions1 == solutions2:
            return True
        
        # Strategy 4: Substring check for short answers
        # Handles "2" vs "2 is the only..."
        if clean1 in clean2 or clean2 in clean1:
            if len(clean1) < 20 or len(clean2) < 20:  # Short answer check
                return True
        
        # Strategy 5: Word overlap ratio
        # Only apply if NEITHER has strong numerical content
        if not numbers1 and not numbers2:
            words1 = set(clean1.split())
            words2 = set(clean2.split())
            if words1 and words2:
                return len(words1 & words2) / len(words1 | words2) > 0.7
        
        return False

    def _group_equivalent_answers(self, answers: List[str]) -> List[List[str]]:
        """Group answers that are semantically equivalent (even if text differs)."""
        groups = []
        used = set()
        
        for i, ans_i in enumerate(answers):
            if i in used:
                continue
            
            group = [ans_i]
            used.add(i)
            
            for j, ans_j in enumerate(answers):
                if j <= i or j in used:
                    continue
                
                if self._answers_equivalent(ans_i, ans_j):
                    group.append(ans_j)
                    used.add(j)
            
            groups.append(group)
        
        return groups
    
    def _extract_numerical_values(self, text: str) -> list:
        """Extract all numerical values from text, including percentages."""
        numbers = []
        
        # Match percentages (88%, 71.2%)
        percentage_matches = re.findall(r'([\d.]+)\s*%', text)
        numbers.extend([float(p) / 100 for p in percentage_matches])
        
        # Match standalone numbers (including decimals)
        # Avoid double-counting percentages already extracted
        number_matches = re.findall(r'(?:^|\s)([\d.]+)(?:\s|$|[^%\d.])', text)
        for n in number_matches:
            if n not in percentage_matches:
                try:
                    numbers.append(float(n))
                except ValueError:
                    pass
        
        # Match fractions like "99/100"
        fraction_matches = re.findall(r'(\d+)/(\d+)', text)
        numbers.extend([float(n) / float(d) for n, d in fraction_matches])
        
        return sorted(set(numbers))

    def _numbers_equivalent(self, nums1: list, nums2: list) -> bool:
        """Compare two lists of numbers with tolerance.
        
        Different number of values = different answers.
        Large differences in values = not equivalent.
        """
        if len(nums1) != len(nums2):
            return False
        
        PERCENTAGE_TOLERANCE = 0.05  # 5 percentage points (88% vs 93% = different)
        RELATIVE_TOLERANCE = 0.1      # 10% relative tolerance for other numbers
        
        for n1, n2 in zip(nums1, nums2):
            # For values in 0-1 range (likely percentages)
            if 0 <= n1 <= 1 and 0 <= n2 <= 1:
                if abs(n1 - n2) > PERCENTAGE_TOLERANCE:
                    return False
            else:
                # For larger numbers, use relative tolerance
                max_allowed_diff = max(RELATIVE_TOLERANCE * max(abs(n1), abs(n2)), 0.01)
                if abs(n1 - n2) > max_allowed_diff:
                    return False
        
        return True

    # =========================================================================
    # EXTRACTION & CONFIDENCE HELPERS
    # =========================================================================
    
    def _extract_mathematical_answer_enhanced(self, response: str, paths: List['ReasoningPath']) -> str:
        """
        ENHANCED: Extract answers from complex math problems including:
        - Coordinate pairs: (3, π/2)
        - Symbolic expressions: p - q, 2q - 3p
        - Fractions: 14/3
        - Mixed notation: boxed{...}, $...$
        """
        
        # Strategy 1: LaTeX boxed answers (highest priority)
        boxed_match = re.search(r'\\boxed\{([^}]+)\}', response)
        if boxed_match:
            answer = boxed_match.group(1).strip()
            print(f"[MATH EXTRACT] ✓ Found boxed answer: {answer}")
            return self._clean_math_notation(answer)
        
        # Strategy 2: Coordinate pairs (r, θ) or (x, y)
        coord_patterns = [
            r'\((\d+(?:\.\d+)?)\s*,\s*\\frac\{\\pi\}\{(\d+)\}\)',  # (3, π/2)
            r'\((\d+(?:\.\d+)?)\s*,\s*\\pi/(\d+)\)',                # (3, π/2)
            r'\((\d+(?:\.\d+)?)\s*,\s*([^\)]+)\)',                  # (3, 1.57)
        ]
        
        for pattern in coord_patterns:
            match = re.search(pattern, response)
            if match:
                answer = f"({match.group(1)}, {match.group(2)})"
                print(f"[MATH EXTRACT] ✓ Found coordinate: {answer}")
                return answer
        
        # Strategy 3: Symbolic expressions (p - q, 2q, etc.)
        symbolic_patterns = [
            r'(?:answer|result|equals?|is)\s*(?:is|=)?\s*([pq]\s*[-+]\s*[pq])',  # p - q
            r'(?:answer|result|equals?|is)\s*(?:is|=)?\s*(\d+[pq])',              # 2q
            r'(?:answer|result|equals?|is)\s*(?:is|=)?\s*([pq]/\d+)',             # q/2
        ]
        
        for pattern in symbolic_patterns:
            match = re.search(pattern, response, re.IGNORECASE)
            if match:
                answer = match.group(1).strip()
                print(f"[MATH EXTRACT] ✓ Found symbolic: {answer}")
                return answer
        
        # Strategy 4: Fractions (14/3, 5/2)
        fraction_pattern = r'(?:answer|result|equals?|is)\s*(?:is|=)?\s*(\d+/\d+)'
        fraction_match = re.search(fraction_pattern, response, re.IGNORECASE)
        if fraction_match:
            answer = fraction_match.group(1)
            print(f"[MATH EXTRACT] ✓ Found fraction: {answer}")
            return answer
        
        # Strategy 5: Extract from CONCLUSION section
        conclusion_match = re.search(r'CONCLUSION:(.+?)(?:\n\n|$)', response, re.IGNORECASE | re.DOTALL)
        if conclusion_match:
            conclusion_text = conclusion_match.group(1)
            
            # Try all extraction methods on conclusion
            for pattern in coord_patterns + symbolic_patterns + [fraction_pattern]:
                match = re.search(pattern, conclusion_text)
                if match:
                    answer = match.group(1) if len(match.groups()) == 1 else match.group(0)
                    print(f"[MATH EXTRACT] ✓ Found in conclusion: {answer}")
                    return self._clean_math_notation(answer)
        
        # Strategy 6: Fallback to calculations in reasoning steps
        print(f"[MATH EXTRACT] ⚠️ Using fallback: extracting from reasoning steps")
        return self._extract_from_reasoning_calculations(paths)

    def _clean_math_notation(self, text: str) -> str:
        """Clean LaTeX and math notation from extracted answers"""
        # Remove LaTeX commands
        text = re.sub(r'\\(frac|pi|theta|cdot|times|left|right)', '', text)
        # 
        # Remove extra spaces
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove dollar signs
        text = text.replace('$', '')
        return text

    def _extract_from_reasoning_calculations(self, paths: List['ReasoningPath']) -> str:
        """
        Fallback: Extract answer from verified calculations in reasoning steps
        """
        for path in paths:
            for step in path.steps:
                if step.calculation_verified and step.is_mathematical:
                    # Look for "= result" in the step content
                    equals_match = re.search(r'=\s*([^\s,]+)(?:\s|$)', step.content)
                    if equals_match:
                        result = equals_match.group(1)
                        print(f"[MATH EXTRACT] ✓ Extracted from verified step: {result}")
                        return result
        
        return "Unable to determine answer"
    
    def _extract_reasoning_steps_from_text(self, text: str, strategy: str,
                                        question_type: 'QuestionType') -> List['LogicalStep']:
        """Extract structured reasoning steps from model-generated text."""
        steps = []
        lines = text.split('\n')
        
        for line in lines:
            step_match = re.match(r'^(?:step\s+)?(\d+)[\.\):\-]\s*(.+)', line, re.IGNORECASE)
            if step_match and len(step_match.group(2)) > 10:
                content = step_match.group(2).strip()
                is_math = '=' in content or any(op in content.lower()
                                            for op in ['calculate', 'solve', 'divide'])
                
                step = LogicalStep(
                    id=f"{strategy}_step_{len(steps)+1}",
                    operation=LogicalOperation.INFERENCE,
                    content=content,
                    confidence=0.75,
                    is_mathematical=is_math
                )
                steps.append(step)
        
        return steps
    
    def _extract_answer_from_text(self, text: str, question_type: 'QuestionType') -> Optional[str]:
        """Extract the final answer from regenerated text."""
        answer_match = re.search(r'ANSWER:\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).strip()
        
        # Fallback: last line may contain a mathematical result
        lines = [l.strip() for l in text.split('\n') if l.strip()]
        if lines:
            last_line = lines[-1]
            if '=' in last_line:
                return last_line
        
        return None

    def _calculate_confidence_from_steps(self, answer: Optional[str],
                                        steps: List['LogicalStep'], text: str) -> float:
        """Estimate path confidence based on number of valid steps and uncertainty."""
        base = 0.7
        if answer:
            base += 0.1
        if len(steps) >= 3:
            base += 0.05
        uncertainty = sum(1 for word in ['might', 'maybe', 'possibly']
                        if word in text.lower())
        base -= uncertainty * 0.03
        return min(max(base, 0.1), 0.95)
    
    def _determine_consensus_verdict(self, verdicts: List[str], paths: List['ReasoningPath']) -> str:
        """Aggregate binary verdicts into a consensus (unanimous, majority, or contested)."""
        if not verdicts:
            return "Unable to determine"
        
        verdict_counts = {}
        for v in verdicts:
            verdict_counts[v] = verdict_counts.get(v, 0) + 1
        
        verdict, count = max(verdict_counts.items(), key=lambda x: x[1])
        
        if count == len(verdicts):
            return f"{verdict} (unanimous)"
        elif count > len(verdicts) / 2:
            return f"{verdict} (majority {count}/{len(verdicts)})"
        else:
            return f"{verdict} (contested {count}/{len(verdicts)})"
    
    def _extract_key_points(self, paths: List['ReasoningPath']) -> List[str]:
        """Collect key reasoning evidence and inference steps for synthesis summary."""
        key_points = []
        for path in paths:
            relevant = [s for s in path.steps if s.operation in [LogicalOperation.EVIDENCE, LogicalOperation.INFERENCE]]
            for step in relevant[:2]:
                key_points.append(f"[{path.generation_strategy}] {step.content}")
        return key_points[:5]
        
    def _identify_conflicts(self, paths: List['ReasoningPath']) -> List[str]:
        """Identify disagreements between reasoning paths."""
        conflicts = []
        
        # For verdicts
        verdicts = [p.verdict for p in paths if p.verdict]
        if len(set(verdicts)) > 1:
            conflicts.append(f"Verdict disagreement: {', '.join(set(verdicts))}")
        
        # For answers - USE PROPER NORMALIZATION
        answers = [p.answer for p in paths if p.answer]
        if len(answers) > 1:
            answer_groups = self._group_equivalent_answers(answers)
            if len(answer_groups) > 1:
                conflicts.append("Answer disagreement")
        
        return conflicts
    
    def _generate_synthesis_explanation(self, query: str, paths: List['ReasoningPath'],
                                    answer: str, conflicts: List[str]) -> str:
        """Generate a readable summary explaining synthesis outcomes."""
        explanation = f"Analyzed {len(paths)} independent reasoning approaches.\n"
        if not conflicts:
            explanation += "All paths converged with consistent reasoning."
        else:
            explanation += f"Found {len(conflicts)} disagreement(s)."
        return explanation
        
    def _extract_answer_from_raw_output(self, raw_output: str) -> Optional[str]:
        """
        NEW: Extract numerical answer from raw model output.
        Used when path.answer is not populated.
        """
        # Pattern 1: ANSWER: 42
        answer_match = re.search(
            r'ANSWER\s*[:\=]\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)',
            raw_output, re.IGNORECASE
        )
        if answer_match:
            return answer_match.group(1).replace(',', '')
        
        # Pattern 2: \boxed{42}
        boxed_match = re.search(r'\\boxed\{(\d+(?:,\d{3})*(?:\.\d+)?)\}', raw_output)
        if boxed_match:
            return boxed_match.group(1).replace(',', '')
        
        # Pattern 3: Last "= NUMBER" in response
        lines = raw_output.split('\n')
        for line in reversed(lines[-5:]):
            calc_match = re.search(r'=\s*\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*$', line)
            if calc_match:
                return calc_match.group(1).replace(',', '')
        
        return None

    def _select_best_numerical_answer(self, answers: List[str], 
                                 paths: List['ReasoningPath']) -> str:
        """
        NEW: Select best numerical answer using majority vote + confidence weighting.
        """
        if not answers:
            return "Unable to determine answer"
        
        # Convert to numbers
        numerical_answers = []
        for i, ans in enumerate(answers):
            # Extract number from answer string
            num_match = re.search(r'(\d+(?:\.\d+)?)', ans.replace(',', ''))
            if num_match:
                try:
                    num = float(num_match.group(1))
                    confidence = paths[i].confidence if i < len(paths) else 0.5
                    numerical_answers.append((num, confidence, ans))
                except:
                    pass
        
        if not numerical_answers:
            return answers[0]  # Fallback to first answer
        
        # Count occurrences with confidence weighting
        vote_scores = {}
        for num, conf, original in numerical_answers:
            if num not in vote_scores:
                vote_scores[num] = {'score': 0, 'count': 0, 'original': original}
            vote_scores[num]['score'] += conf
            vote_scores[num]['count'] += 1
        
        # Select answer with highest weighted score
        best_num = max(vote_scores.items(), key=lambda x: (x[1]['count'], x[1]['score']))
        result = best_num[1]['original']
        
        # Format result
        if best_num[1]['count'] == len(numerical_answers):
            return f"{result} (unanimous)"
        elif best_num[1]['count'] > len(numerical_answers) / 2:
            return f"{result} (consensus {best_num[1]['count']}/{len(numerical_answers)})"
        else:
            return f"{result} (best of {best_num[1]['count']}/{len(numerical_answers)})"
            
    def _calculate_synthesis_confidence(self, paths: List['ReasoningPath'], conflicts: List[str]) -> float:
        """Compute overall synthesis confidence, reducing it for conflicts."""
        avg_confidence = sum(p.confidence for p in paths) / len(paths) if paths else 0.5
        conflict_penalty = len(conflicts) * 0.1
        return min(max(avg_confidence - conflict_penalty, 0.15), 0.95)

In [10]:
# ============================================================================
# MAIN SYSTEM
# ============================================================================
# This class ties together the entire reasoning pipeline:
#   1. Classifies a question.
#   2. Generates multiple reasoning paths using Claude.
#   3. Validates and (if needed) regenerates reasoning steps.
#   4. Synthesizes a final answer with confidence scoring.
#   5. Displays results, including performance and validation metrics.
# ============================================================================

class ClaudeDynamicReasoningSystem:
    """High-level orchestrator for dynamic multi-path reasoning and synthesis using Claude."""

    def __init__(self, api_key: str):
        """Initialize the system with API key and required components."""
        print("Initializing Claude Dynamic Reasoning System...")
        self.generator = ClaudeReasoningGenerator(api_key)  # Generates reasoning paths
        self.synthesizer = AnswerSynthesizer(self.generator.client, self.generator.model)  # Handles synthesis + validation
        print("System ready - Claude 3.5 Sonnet with ALL 5 FIXES\n")

    # ------------------------------------------------------------------------
    # MAIN ENTRYPOINT: Run full reasoning → synthesis pipeline
    # ------------------------------------------------------------------------
    def reason_with_synthesis(self, query: str, num_paths: int = 3) -> NegotiationResult:
        """
        Run full reasoning workflow for a given query.

        Steps:
        1. Classify the question type and complexity.
        2. Generate multiple reasoning paths (in parallel for speed).
        3. Validate and synthesize a final answer.
        4. Aggregate all results, including timing and cost stats.
        """
        print(f"\n{'='*70}")
        print(f"QUERY: {query}")
        print('='*70)

        start_time = time.time()

        # --- STEP 0: Classify the question ---
        print("\n[0] Classifying question...")
        classification = self.generator.classifier.classify(query)
        print(f"Type: {classification.question_type.value}, Complexity: {classification.complexity_level.value}")
        print(f"Validation: {classification.requires_validation}, Num paths: {classification.num_paths}")

        # --- STEP 1: Generate reasoning paths ---
        print(f"\n[1] Generating reasoning paths...")
        paths, speedup = self.generator.generate_multiple_paths_parallel(query, num_paths, classification)

        print(f"Generated {len(paths)} path(s) ({speedup:.2f}x speedup):")
        for path in paths:
            # Summarize each reasoning path
            if path.verdict:
                answer_str = f"verdict={path.verdict}"
            elif path.answer:
                answer_str = f"answer={path.answer[:40]}..."
            else:
                answer_str = "no answer"
            print(f"  • {path.generation_strategy}: {len(path.steps)} steps, {answer_str}, conf={path.confidence:.2f}")

        # --- STEP 2: Synthesize the final answer ---
        print("\n[2] Synthesizing final answer...")
        synthesized = self.synthesizer.synthesize_final_answer(query, paths, classification)

        # --- STEP 3: Compute performance metrics ---
        total_time = time.time() - start_time
        total_cost = self.generator.get_total_cost()

        print(f"\n[3] Complete in {total_time:.2f}s")
        print(f"Answer: {synthesized.definitive_answer}")
        print(f"Confidence: {synthesized.final_confidence:.2f}")
        print(f"API Cost: ${total_cost:.4f}")

        # Collect validation + regeneration metrics
        total_validations = sum(p.validation_passes for p in paths)
        total_regenerations = sum(p.regeneration_count for p in paths)

        # --- STEP 4: Bundle everything into a NegotiationResult ---
        return NegotiationResult(
            original_paths=paths,
            synthesized_answer=synthesized,
            total_time=total_time,
            parallel_speedup=speedup,
            total_cost=total_cost,
            total_validations=total_validations,
            total_regenerations=total_regenerations,
            classification=classification
        )

    # ------------------------------------------------------------------------
    # DISPLAY: Nicely print the full reasoning and synthesis results
    # ------------------------------------------------------------------------
    def display_result(self, result: NegotiationResult):
        """
        Display the entire reasoning and synthesis output in a readable format,
        including reasoning paths, final answer, and performance statistics.
        """
        print("\n" + "="*70)
        print("FINAL RESULT")
        print("="*70)

        # --- Question classification overview ---
        if result.classification:
            print(f"\nQuestion Classification:")
            print(f"  Type: {result.classification.question_type.value}")
            print(f"  Complexity: {result.classification.complexity_level.value}")

        # --- Synthesized answer summary ---
        print("\n" + "▶"*35)
        print("SYNTHESIZED ANSWER")
        print("▶"*35)
        print(f"\nQuery: {result.synthesized_answer.query}")
        print(f"\nAnswer: {result.synthesized_answer.definitive_answer}")
        print(f"Confidence: {result.synthesized_answer.final_confidence:.2f}")

        # --- Supporting and conflicting reasoning ---
        if result.synthesized_answer.supporting_reasoning:
            print("\nKey Supporting Points:")
            for i, point in enumerate(result.synthesized_answer.supporting_reasoning, 1):
                print(f"  {i}. {point}")

        if result.synthesized_answer.conflicting_points:
            print("\nConflicting Aspects:")
            for conflict in result.synthesized_answer.conflicting_points:
                print(f"  ⚠ {conflict}")

        print("\nSynthesis Process:")
        print(result.synthesized_answer.synthesis_explanation)

        # --- Individual reasoning chains ---
        print("\n" + "-"*70)
        print("INDIVIDUAL REASONING PATHS")
        print("-"*70)

        for path in result.original_paths:
            print(f"\n{path.to_readable_chain()}")

        # --- Performance & API usage ---
        print("\n" + "-"*70)
        print("PERFORMANCE & COST")
        print("-"*70)
        print(f"Total time: {result.total_time:.2f}s")
        print(f"Parallel speedup: {result.parallel_speedup:.2f}x")
        print(f"Total API cost: ${result.total_cost:.4f}")

        # --- Validation statistics ---
        if result.total_validations > 0:
            print(f"\nValidation Statistics:")
            print(f"  Validation passes: {result.total_validations}")
            print(f"  Regenerations: {result.total_regenerations}")

        print("\n" + "="*70)





In [11]:
from datasets import load_dataset
import os
from pathlib import Path

ds = load_dataset("nlile/hendrycks-MATH-benchmark")

datasets_dir = Path.home() / ".cache" / "huggingface" / "datasets"
print(sorted(os.listdir(datasets_dir)))

#['_root_.cache_huggingface_datasets_commonsense_qa_default_0.0.0_94630fe30dad47192a8546eb75f094926d47e155.lock', '_root_.cache_huggingface_datasets_gsm8k_main_0.0.0_e53f048856ff4f594e959d75785d2c2d37b678ee.lock', '_root_.cache_huggingface_datasets_nlile___hendrycks-math-benchmark_default_0.0.0_465bcdb36f5962aa3512891498966df785fc3c18.lock', '_root_.cache_huggingface_datasets_tau___commonsense_qa_default_0.0.0_94630fe30dad47192a8546eb75f094926d47e155.lock', 'commonsense_qa', 'gsm8k', 'nlile___hendrycks-math-benchmark', 'tau___commonsense_qa']

  from .autonotebook import tqdm as notebook_tqdm


['_root_.cache_huggingface_datasets_commonsense_qa_default_0.0.0_94630fe30dad47192a8546eb75f094926d47e155.lock', '_root_.cache_huggingface_datasets_gsm8k_main_0.0.0_e53f048856ff4f594e959d75785d2c2d37b678ee.lock', '_root_.cache_huggingface_datasets_nlile___hendrycks-math-benchmark_default_0.0.0_465bcdb36f5962aa3512891498966df785fc3c18.lock', '_root_.cache_huggingface_datasets_tau___commonsense_qa_default_0.0.0_94630fe30dad47192a8546eb75f094926d47e155.lock', 'commonsense_qa', 'gsm8k', 'nlile___hendrycks-math-benchmark', 'tau___commonsense_qa']


In [12]:
# ============================================================================
# UNIVERSAL TEST HARNESS CLASS
# ============================================================================

from datasets import load_dataset
import json
import time
import os
import re
from typing import Optional, Dict, List, Tuple
from enum import Enum

class DatasetType(Enum):
    """Supported dataset types"""
    GSM8K = "gsm8k"
    COMMONSENSE_QA = "commonsense_qa"
    MATH_COMPETITION = "math_competition"

class UniversalTestHarness:
    """
    Unified test harness for multiple reasoning datasets.
    
    Features:
    - Single interface for all datasets
    - Configurable parameters (model, tokens, questions)
    - Consistent output format
    - No cost tracking (removed)
    - Progress tracking and intermediate saves
    """
    
    def __init__(self, system: 'ClaudeDynamicReasoningSystem', 
                 model: str = "claude-3-5-haiku-20241022",
                 max_tokens: int = 800):
        """
        Initialize test harness.
        
        Args:
            system: Your reasoning system instance
            model: Claude model to use
            max_tokens: Token limit per request
        """
        self.system = system
        self.model = model
        self.max_tokens = max_tokens
        
        # Dataset loaders
        self.dataset_loaders = {
            DatasetType.GSM8K: self._load_gsm8k,
            DatasetType.COMMONSENSE_QA: self._load_commonsense_qa,
            DatasetType.MATH_COMPETITION: self._load_math_competition
        }
        
        # Answer comparators
        self.answer_comparators = {
            DatasetType.GSM8K: self._compare_numerical,
            DatasetType.COMMONSENSE_QA: self._compare_letter_choice,
            DatasetType.MATH_COMPETITION: self._compare_math_expression
        }
        
        print(f"Test Harness Initialized")
        print(f"  Model: {model}")
        print(f"  Max Tokens: {max_tokens}")
    
    # ========================================================================
    # MAIN RUN METHOD
    # ========================================================================
    
    def run_test(self, 
                 dataset_choice: int,
                 num_questions: int = 100,
                 save_interval: int = 25) -> Dict:
        """
        Run test on selected dataset.
        
        Args:
            dataset_choice: 1=GSM8K, 2=CommonsenseQA, 3=MATH, 4=All
            num_questions: Number of questions to test
            save_interval: Save results every N questions
            
        Returns:
            Dictionary with results and metadata
        """
        
        # Map choice to datasets
        dataset_map = {
            1: [DatasetType.GSM8K],
            2: [DatasetType.COMMONSENSE_QA],
            3: [DatasetType.MATH_COMPETITION],
            4: [DatasetType.GSM8K, DatasetType.COMMONSENSE_QA, DatasetType.MATH_COMPETITION]
        }
        
        datasets_to_run = dataset_map.get(dataset_choice)
        if not datasets_to_run:
            raise ValueError(f"Invalid dataset choice: {dataset_choice}")
        
        # Apply system optimizations
        self._optimize_system()
        
        # Run each dataset
        all_results = {}
        for dataset_type in datasets_to_run:
            print(f"\n{'='*80}")
            print(f"RUNNING: {dataset_type.value.upper()}")
            print(f"{'='*80}\n")
            
            results = self._run_single_dataset(
                dataset_type=dataset_type,
                num_questions=num_questions,
                save_interval=save_interval
            )
            
            all_results[dataset_type.value] = results
        
        return all_results
    
    # ========================================================================
    # DATASET LOADERS
    # ========================================================================
    
    def _load_gsm8k(self) -> Tuple[List[Dict], str]:
        """Load GSM8K dataset"""
        print("Loading GSM8K (Grade School Math)...")
        dataset = load_dataset("gsm8k", "main")
        questions = dataset['test']
        
        processed = []
        for idx, item in enumerate(questions):
            processed.append({
                'idx': idx,
                'question': item['question'],
                'ground_truth': item['answer'].split('####')[-1].strip(),
                'full_answer': item['answer']
            })
        
        return processed, "GSM8K: Grade school math word problems"
    
    def _load_commonsense_qa(self) -> Tuple[List[Dict], str]:
        """Load CommonsenseQA dataset"""
        print("Loading CommonsenseQA...")
        dataset = load_dataset("tau/commonsense_qa")
        questions = dataset['validation']  # Use validation split
        
        processed = []
        for idx, item in enumerate(questions):
            # Format question with choices
            question_text = item['question']
            choices = item['choices']
            
            formatted_q = f"{question_text}\n\n"
            for i, choice in enumerate(choices['text']):
                formatted_q += f"{choices['label'][i]}: {choice}\n"
            formatted_q += "\nPlease select ONLY ONE answer (A-E)."
            
            processed.append({
                'idx': idx,
                'question': formatted_q,
                'ground_truth': item['answerKey'],
                'choices': choices
            })
        
        return processed, "CommonsenseQA: Multiple choice commonsense reasoning"
    
    def _load_math_competition(self) -> Tuple[List[Dict], str]:
        """Load MATH competition dataset"""
        print("Loading MATH Competition dataset...")
        dataset = load_dataset("nlile/hendrycks-math-benchmark", "default")
        questions = dataset['test']
        processed = []
        for idx, item in enumerate(questions):
            processed.append({
                'idx': idx,
                'question': item['problem'],
                'ground_truth': item['solution'],
                'level': item.get('level', 'unknown'),
                'category': item.get('category', item.get('type', 'unknown'))
            })
        
        return processed, "MATH: Competition-level mathematics"
    
    # ========================================================================
    # ANSWER COMPARATORS
    # ========================================================================
    
    def _compare_numerical(self, predicted: str, ground_truth: str) -> bool:
        """Compare numerical answers (GSM8K)"""
        pred_num = self._extract_number(predicted)
        truth_num = self._extract_number(ground_truth)
        
        if pred_num is None or truth_num is None:
            return predicted.strip().lower() == ground_truth.strip().lower()
        
        if truth_num == 0:
            return abs(pred_num) < 0.01
        
        return abs(pred_num - truth_num) / abs(truth_num) < 0.01
    
    def _compare_letter_choice(self, predicted: str, ground_truth: str) -> bool:
        """Compare letter choices (CommonsenseQA)"""
        # Extract letter from predicted (handles "A: text" format)
        pred_match = re.match(r'^([A-E])', predicted.strip())
        pred_letter = pred_match.group(1) if pred_match else predicted.strip()
        
        return pred_letter.upper() == ground_truth.upper()
    
    def _compare_math_expression(self, predicted: str, ground_truth: str) -> bool:
        """Compare mathematical expressions (MATH)"""
        # Try numerical comparison first
        if self._compare_numerical(predicted, ground_truth):
            return True
        
        # Try exact string match
        clean_pred = predicted.strip().lower().replace(' ', '')
        clean_truth = ground_truth.strip().lower().replace(' ', '')
        
        return clean_pred == clean_truth
    
    def _extract_number(self, text: str) -> Optional[float]:
        """Extract numerical value from text"""
        if not text:
            return None
        
        text = text.lower().strip()
        text = re.sub(r'^(the answer is|answer:|final answer:)\s*', '', text)
        
        # Try currency
        currency_match = re.search(r'\$?\s*([\d,]+\.?\d*)', text)
        if currency_match:
            try:
                return float(currency_match.group(1).replace(',', ''))
            except:
                pass
        
        # Try plain number
        num_match = re.search(r'([-+]?\d*\.?\d+)', text)
        if num_match:
            try:
                return float(num_match.group(1))
            except:
                pass
        
        return None
    
    # ========================================================================
    # CORE TEST EXECUTION
    # ========================================================================
    
    def _run_single_dataset(self, 
                           dataset_type: DatasetType,
                           num_questions: int,
                           save_interval: int) -> Dict:
        """Run test on a single dataset"""
        
        # Load dataset
        loader = self.dataset_loaders[dataset_type]
        questions, description = loader()
        
        print(f"Dataset: {description}")
        print(f"Total questions available: {len(questions)}")
        print(f"Will answer up to: {num_questions}\n")
        
        # Get comparator
        comparator = self.answer_comparators[dataset_type]
        
        # Run questions
        results = []
        start_time = time.time()
        
        for i, q_data in enumerate(questions[:num_questions]):
            q_num = i + 1
            
            print(f"\n{'='*80}")
            print(f"QUESTION #{q_num}/{num_questions}")
            print(f"{'='*80}")
            print(f"Q: {q_data['question'][:200]}...")
            print(f"Ground Truth: {q_data['ground_truth']}")
            
            try:
                # Run reasoning
                result = self.system.reason_with_synthesis(q_data['question'])
                
                answer = result.synthesized_answer.definitive_answer
                confidence = result.synthesized_answer.final_confidence
                
                # Check correctness
                is_correct = comparator(answer, q_data['ground_truth'])
                
                print(f"\nModel Answer: {answer}")
                print(f"Correct: {is_correct}")
                print(f"Confidence: {confidence:.2f}")
                print(f"Time: {result.total_time:.2f}s")
                
                # Store result
                result_entry = {
                    'question_num': q_num,
                    'dataset_idx': q_data['idx'],
                    'question': q_data['question'],
                    'ground_truth': q_data['ground_truth'],
                    'model_answer': answer,
                    'is_correct': is_correct,
                    'confidence': confidence,
                    'time': result.total_time,
                    'paths_generated': len(result.original_paths),
                    'validations': result.total_validations,
                    'regenerations': result.total_regenerations
                }
                
                # Add dataset-specific fields
                if dataset_type == DatasetType.MATH_COMPETITION:
                    result_entry['level'] = q_data.get('level')
                    result_entry['category'] = q_data.get('category')
                
                results.append(result_entry)
                
                # Periodic save
                if q_num % save_interval == 0:
                    self._save_results(
                        dataset_type=dataset_type,
                        results=results,
                        final=False
                    )
                
            except Exception as e:
                print(f"\nERROR: {e}")
                print("Skipping question...")
                continue
        
        # Final save
        total_time = time.time() - start_time
        self._save_results(
            dataset_type=dataset_type,
            results=results,
            final=True,
            total_time=total_time
        )
        
        # Calculate stats
        stats = self._calculate_statistics(results, total_time)
        
        print(f"\n{'='*80}")
        print(f"COMPLETED: {dataset_type.value.upper()}")
        print(f"Accuracy: {stats['accuracy']:.1f}% ({stats['correct']}/{stats['total']})")
        print(f"Time: {stats['total_time']:.2f}s")
        print(f"{'='*80}\n")
        
        return {
            'results': results,
            'statistics': stats,
            'dataset_type': dataset_type.value
        }
    
    # ========================================================================
    # SAVING & STATISTICS
    # ========================================================================
    
    def _save_results(self,
                     dataset_type: DatasetType,
                     results: List[Dict],
                     final: bool = True,
                     total_time: Optional[float] = None):
        """Save results to JSON file"""
        
        os.makedirs('Results', exist_ok=True)
        
        suffix = "_FINAL" if final else f"_{len(results)}q"
        filename = f"{dataset_type.value}_results{suffix}.json"
        output_path = os.path.join('Results', filename)
        
        stats = self._calculate_statistics(results, total_time)
        
        output = {
            'metadata': {
                'dataset': dataset_type.value,
                'model': self.model,
                'max_tokens': self.max_tokens,
                'questions_answered': len(results),
                'accuracy': stats['accuracy'],
                'correct_count': stats['correct'],
                'total_time_seconds': total_time,
                'avg_time_per_question': stats['avg_time'],
                'total_validations': stats['total_validations'],
                'total_regenerations': stats['total_regenerations'],
                'completed_at': time.strftime("%Y-%m-%d %H:%M:%S")
            },
            'results': results
        }
        
        with open(output_path, 'w') as f:
            json.dump(output, f, indent=2)
        
        status = "FINAL" if final else "INTERMEDIATE"
        print(f"\n{status} RESULTS SAVED: {output_path}")
        print(f"  Questions: {len(results)}")
        print(f"  Accuracy: {stats['accuracy']:.1f}%")
    
    def _calculate_statistics(self, results: List[Dict], 
                             total_time: Optional[float] = None) -> Dict:
        """Calculate statistics from results"""
        if not results:
            return {
                'total': 0,
                'correct': 0,
                'accuracy': 0.0,
                'avg_time': 0.0,
                'total_validations': 0,
                'total_regenerations': 0,
                'total_time': 0.0
            }
        
        correct = sum(1 for r in results if r['is_correct'])
        accuracy = (correct / len(results)) * 100
        
        avg_time = (sum(r['time'] for r in results) / len(results)) if results else 0
        total_validations = sum(r['validations'] for r in results)
        total_regenerations = sum(r['regenerations'] for r in results)
        
        return {
            'total': len(results),
            'correct': correct,
            'accuracy': accuracy,
            'avg_time': avg_time,
            'total_validations': total_validations,
            'total_regenerations': total_regenerations,
            'total_time': total_time or sum(r['time'] for r in results)
        }
    
    # ========================================================================
    # SYSTEM OPTIMIZATION
    # ========================================================================
    
    def _optimize_system(self):
        """Apply optimizations to the reasoning system"""
        print("\nApplying system optimizations...")
        
        # Set model
        self.system.generator.model = self.model
        self.system.synthesizer.model = self.model
        
        # Limit tokens
        original_create_gen = self.system.generator.client.messages.create
        original_create_synth = self.system.synthesizer.client.messages.create
        
        def limited_create(original_method):
            def wrapper(*args, **kwargs):
                kwargs['max_tokens'] = min(kwargs.get('max_tokens', 1500), self.max_tokens)
                return original_method(*args, **kwargs)
            return wrapper
        
        self.system.generator.client.messages.create = limited_create(original_create_gen)
        self.system.synthesizer.client.messages.create = limited_create(original_create_synth)
        
        print(f"  ✓ Model: {self.model}")
        print(f"  ✓ Token limit: {self.max_tokens}")
        print()


# ============================================================================
# USAGE INTERFACE
# ============================================================================

def run_test_harness_interactive():
    """Interactive test harness runner"""
    
    print("="*80)
    print("UNIVERSAL REASONING TEST HARNESS")
    print("="*80)
    print("\nSelect Dataset:")
    print("  1. GSM8K (Grade School Math)")
    print("  2. CommonsenseQA (Multiple Choice)")
    print("  3. MATH Competition (Advanced Math)")
    print("  4. All Datasets")
    print()
    
    # Get user input
    try:
        dataset_choice = int(input("Enter choice (1-4): "))
        num_questions = int(input("Number of questions to test: "))
        
        # Optional: Get model and tokens
        use_defaults = input("Use default settings? (y/n): ").lower() == 'y'
        
        if use_defaults:
            model = "claude-3-5-haiku-20241022"
            max_tokens = 1500
        else:
            model = input("Model (default: claude-3-5-haiku-20241022): ") or "claude-3-5-haiku-20241022"
            max_tokens = int(input("Max tokens (default: 800): ") or "800")
        
    except ValueError:
        print("Invalid input. Using defaults.")
        dataset_choice = 1
        num_questions = 10
        model = "claude-3-5-haiku-20241022"
        max_tokens = 800
    
    print(f"\n{'='*80}")
    print("CONFIGURATION")
    print(f"{'='*80}")
    print(f"  Dataset: {dataset_choice}")
    print(f"  Questions: {num_questions}")
    print(f"  Model: {model}")
    print(f"  Max Tokens: {max_tokens}")
    print(f"{'='*80}\n")
    
    # Initialize system
    system = ClaudeDynamicReasoningSystem(api_key=API_KEY)
    
    # Create and run harness
    harness = UniversalTestHarness(
        system=system,
        model=model,
        max_tokens=max_tokens
    )
    
    results = harness.run_test(
        dataset_choice=dataset_choice,
        num_questions=num_questions,
        save_interval=25
    )
    
    return results


# ============================================================================
# PROGRAMMATIC USAGE
# ============================================================================

def run_test_programmatic(dataset: int = 1, questions: int = 100):
    """Programmatic test runner (for scripts)"""
    
    system = ClaudeDynamicReasoningSystem(api_key=API_KEY)
    
    harness = UniversalTestHarness(
        system=system,
        model="claude-3-5-haiku-20241022",
        max_tokens=800
    )
    
    return harness.run_test(
        dataset_choice=dataset,
        num_questions=questions,
        save_interval=25
    )


# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Option 1: Interactive mode
    #results = run_test_harness_interactive()
    
    # GSM8K (Grade School Math) dataset - 10 questions each
    #results = run_test_programmatic(dataset=1, questions=10)

    # CommonsenseQA dataset - 10 questions
    results = run_test_programmatic(dataset=2, questions=10)

    # Math dataset - 10 questions
    #results = run_test_programmatic(dataset=2, questions=10)

    # All datasets - 10 questions each
    #results = run_test_programmatic(dataset=4, questions=10)
        


Initializing Claude Dynamic Reasoning System...
System ready - Claude 3.5 Sonnet with ALL 5 FIXES

Test Harness Initialized
  Model: claude-3-5-haiku-20241022
  Max Tokens: 800

Applying system optimizations...
  ✓ Model: claude-3-5-haiku-20241022
  ✓ Token limit: 800


RUNNING: COMMONSENSE_QA

Loading CommonsenseQA...
Dataset: CommonsenseQA: Multiple choice commonsense reasoning
Total questions available: 1221
Will answer up to: 10


QUESTION #1/10
Q: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?

A: bank
B: library
C: department store
D: mall
E: new york

Please select ONLY ONE answer ...
Ground Truth: A

QUERY: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?

A: bank
B: library
C: department store
D: mall
E: new york

Please select ONLY ONE answer (A-E).

[0] Classifying question...
[CLASSIFIER] Extracted question: ...
[CLASSIFIER] DEFAULT → COMMONSENSE
[CLASSI

KeyboardInterrupt: 