# Workflow for MIRA equation extraction: notes

This notebook outlines prompting ideas for extracting mathematical equations from PDFs within the MIRA framework.

---

## One-shot Prompting:
The basic MIRA workflow uses a one-shot prompting architecture (let's call this **version = 001**).


Process of the extraction: *'mira/notebooks/llm_extraction.ipynb'*

Pipeline: *'mira/sources/sympy_ode/llm_util.py'*

Prompts: *'mira/sources/sympy_ode/constants.py'*

Detailed results can be found in the notebook: *mira_llm_extraction_evaluation.ipynb*

---

## Iterative prompting workflow:
**version = 002**

To improve the precision of the extraction, an iterative workflow is being introduced, with the following steps:
### Agent 1:
First, an extraction agent uses the original MIRA process to convert equation images into SymPy code and ground biological concepts. 

### Agent 2:
Then, a validation agent checks the extraction for execution errors (missing imports, undefined variables), parameter consistency issues, and incorrect concept grounding. 
If errors are found, the validation agent corrects them and the process repeats for up to 3 iterations until all checks pass. 

This multi-agent approach aims to improve extraction accuracy by catching and fixing common errors that the one-shot method might miss, while maintaining backward compatibility with the existing MIRA codebase.
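The extract → validate → correct loop can be sketched as plain control flow. This is a minimal sketch assuming each agent can be wrapped as a callable; `iterative_extraction`, `toy_extract`, `toy_validate` and `toy_correct` are hypothetical names, not functions in MIRA or in the fork, and the toy agents only stand in for LLM calls.

```python
from typing import Callable, List, Tuple

def iterative_extraction(
    extract: Callable[[], str],
    validate: Callable[[str], List[str]],
    correct: Callable[[str, List[str]], str],
    max_iterations: int = 3,
) -> Tuple[str, List[str], int]:
    """Extract once, then alternate validation and correction.

    Stops when validation reports no errors or after `max_iterations`
    correction rounds. Returns (code, remaining_errors, rounds_used).
    """
    code = extract()
    errors = validate(code)
    rounds = 0
    while errors and rounds < max_iterations:
        code = correct(code, errors)
        errors = validate(code)
        rounds += 1
    return code, errors, rounds

# Toy stand-ins for the two agents (hypothetical, for illustration only)
def toy_extract() -> str:
    return "odes = [beta]"              # 'beta' is never defined

def toy_validate(code: str) -> List[str]:
    return [] if "beta =" in code else ["undefined variable: beta"]

def toy_correct(code: str, errors: List[str]) -> str:
    return "beta = 0.3\n" + code

code, remaining, rounds = iterative_extraction(toy_extract, toy_validate, toy_correct)
print(rounds, remaining)   # 1 []
```

Returning the remaining errors alongside the code keeps the "gave up after 3 rounds" case visible to the caller instead of silently passing unvalidated code forward.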

### RESULTS OF IMPLEMENTATION:
Forked version on GitHub: *'fruzsedua/mira/tree/extraction-development'*

Examples for each result found in this folder: *'mira/notebooks/equation extraction development/extraction error check/string mismatch check/comparison_results_version002'*

Process of the extraction: *'mira/notebooks/llm_extraction.ipynb'* -> **More detailed process**

Pipeline: *'mira/sources/sympy_ode/llm_util.py'* -> **New functions added**

Prompts: *'mira/sources/sympy_ode/constants.py'* -> **Error handling prompt added**
**Image extraction:**
- Additional rules added: symmetry, transmission structure, patterns, mathematical structure, parameter consistency, completeness check
- Epidemiology-based rules are just ideas (from Claude) -> *revision needed!*

**Error checking and correcting:**
- Execution errors are mostly fixed during iteration 1:
  - Syntax rules for detecting and handling functions/symbols
  - Handling of imports, utilizing their names precisely
  - Missing parameters are included

- Data cannot be parsed if the prompt's output format does not match what the next function expects -> an exact format specification was added to the prompt
- Comparing the number of factors with the original (count of `*` operators and variables)
- Preserving content between iterations of the error-handling prompt
- Missing `/N` fixed
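The factor-count comparison can be approximated with plain string counting. This is a minimal sketch; `factor_signature` and `factors_changed` are hypothetical helpers, not part of MIRA. It only flags that a correction changed the structure of a right-hand side; it cannot tell which version is correct.

```python
import re

def factor_signature(rhs: str) -> tuple:
    """(number of '*' multiplications, number of identifier tokens) in an ODE right-hand side."""
    # Treat '**' as exponentiation, not as two multiplications
    n_mul = rhs.replace("**", "^").count("*")
    n_identifiers = len(re.findall(r"[A-Za-z_][A-Za-z0-9_]*", rhs))
    return n_mul, n_identifiers

def factors_changed(original_rhs: str, new_rhs: str) -> bool:
    """Flag a correction step that changed the multiplicative structure of a term."""
    return factor_signature(original_rhs) != factor_signature(new_rhs)

print(factor_signature("-beta*S*I/N"))              # (2, 4)
print(factors_changed("-beta*S*I/N", "-beta*S*I"))  # True: '/N' dropped
print(factors_changed("beta*S*I", "beta*S + I"))    # True: '*' became '+'
```

The same check also catches the multiplication vs. addition mix-up, because replacing `*` with `+` lowers the multiplication count while leaving the identifiers unchanged.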

**Comparison of the extracted ODEs added:**
- SymPy format matching
- Sorting of equations (based on the variable on the LHS) for comparison
- Template Model → Mtx ODEs conversion garbles a lot of information due to multiple formatting steps -> *fix needed!*
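Sorting by the LHS variable makes two ODE lists line up so they can be compared pairwise. This is a minimal sketch assuming each equation is a `sympy.Eq` whose LHS is `d/dt X(t)` with `X` an undefined SymPy function; `lhs_key` and `sort_odes` are hypothetical names, and alphabetical order is an arbitrary but deterministic choice.

```python
import sympy
from sympy.core.function import AppliedUndef

def lhs_key(eq: sympy.Eq) -> str:
    """Sort key: name of the state variable on the LHS, e.g. 'S' for d/dt S(t)."""
    lhs = eq.lhs
    if isinstance(lhs, sympy.Derivative):
        lhs = lhs.expr                  # Derivative(S(t), t) -> S(t)
    if isinstance(lhs, AppliedUndef):
        return str(lhs.func)            # S(t) -> 'S'
    return str(lhs)                     # fallback for other LHS shapes

def sort_odes(odes):
    return sorted(odes, key=lhs_key)

t = sympy.symbols("t")
beta, gamma = sympy.symbols("beta gamma")
S, I, R = sympy.symbols("S I R", cls=sympy.Function)

odes = [
    sympy.Eq(R(t).diff(t), gamma * I(t)),
    sympy.Eq(S(t).diff(t), -beta * S(t) * I(t)),
    sympy.Eq(I(t).diff(t), beta * S(t) * I(t) - gamma * I(t)),
]
print([lhs_key(eq) for eq in sort_odes(odes)])   # ['I', 'R', 'S']
```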


The error-handling multi-agent architecture is part of the TemplateModel creation pipeline:

**Image → LLM Extraction → Multi-Agent Validation → JSON (corrected ODEs + concepts) → Template Model → Mtx odes**

**REMAINING ERRORS:**
- Parameter consistency: mostly notational differences (e.g. rho_1 vs. rho1), sometimes more serious ones, e.g. rho vs. q (similar-looking symbols) -> the LLM has no information about which one is used, so it does not know a fix is needed
- Multiplication vs. addition still gets mixed up sometimes
- Semantic compartment mismatches, e.g. I(t) vs. T(t) -> extra validation needed, e.g. linear and
- Strengthening the arithmetic validation is much needed!
- Precision of coefficient extraction
- Still remaining: `CodeExecutionError: Error while executing the code: 'Symbol' object is not callable` (examples: BIOMD000000972, BIOMD000000976)
- The error-handling function mixes up the order of operations in some cases (example: BIOMD0000000991)
- The extracted compartments differ completely from the original ones, possibly because they were derived from the RHS (example: 2024_dec_epi_1_model_A)
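Notational variants such as `rho_1` vs. `rho1` can be detected mechanically before any LLM call. This is a minimal sketch; `canonical_name` and `near_duplicate_parameters` are hypothetical helpers. It only groups names that differ in underscores, braces or backslashes; a substitution like rho vs. q still needs the source equations to resolve.

```python
import re
from collections import defaultdict

def canonical_name(name: str) -> str:
    """Comparison key only: drop underscores, braces and backslashes (LaTeX residue)."""
    return re.sub(r"[_{}\\]", "", name)

def near_duplicate_parameters(names):
    """Group parameter names that differ only in notation, e.g. rho_1 vs. rho1."""
    groups = defaultdict(set)
    for name in names:
        groups[canonical_name(name)].add(name)
    return {key: sorted(group) for key, group in groups.items() if len(group) > 1}

print(near_duplicate_parameters(["rho_1", "rho1", "beta", "q"]))
# {'rho1': ['rho1', 'rho_1']}
```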


---

> NEXT STEP:
## Multi-Agent Pipeline:
**version = 003**

There are clearly separable problem areas, which can be managed better by splitting the prompting into more detailed, targeted steps. An agenda-based approach will systematically address extraction challenges by organizing the process into distinct agenda items, each targeting a specific aspect of the process:

### Agent 1: Initial Extraction
- Extract equations from image/PDF using existing MIRA logic
- Convert mathematical notation to SymPy code representation
- Pass raw code string to next agent

### Agent 2: Execution Error Handler
- Attempt to execute the extracted SymPy code
- Catch and diagnose execution errors (missing imports, undefined variables, syntax errors)
- Automatically fix common issues and retry execution
- Pass executable code and any remaining warnings forward
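The execute-and-diagnose step of Agent 2 can be sketched as follows, assuming the extracted code must define a list named `odes` (as `execute_template_model_from_sympy_odes` requires). `try_execute` and its error category names are hypothetical. The example reproduces the `'Symbol' object is not callable` failure from the remaining-errors list: a state declared with `sympy.Symbol` but used as `S(t)`. Note that `exec` runs arbitrary code, so this is not a sandbox.

```python
import sympy

def try_execute(code: str):
    """Run extracted SymPy code in a fresh namespace.

    Returns (odes, error_category, message); error_category is None on success.
    """
    namespace = {"sympy": sympy}  # fresh namespace; NOT a security sandbox
    try:
        exec(code, namespace)
    except SyntaxError as e:
        return None, "syntax", str(e)
    except NameError as e:          # undefined variable or missing import
        return None, "undefined_name", str(e)
    except TypeError as e:
        if "object is not callable" in str(e):
            # typically a state declared with sympy.Symbol but used as S(t)
            return None, "symbol_called_as_function", str(e)
        return None, "type_error", str(e)
    except Exception as e:
        return None, "other", str(e)
    odes = namespace.get("odes")
    if odes is None:
        return None, "missing_odes", "the code does not define `odes`"
    return odes, None, ""

broken = """
t = sympy.symbols("t")
S = sympy.Symbol("S")                    # declared as a Symbol...
odes = [sympy.Eq(S(t).diff(t), -S(t))]   # ...but called like a function
"""
fixed = broken.replace('sympy.Symbol("S")', 'sympy.Function("S")')

print(try_execute(broken)[1])   # symbol_called_as_function
print(try_execute(fixed)[1])    # None
```

Mapping exceptions to a small set of categories lets the corrector choose a targeted prompt (e.g. "declare time-dependent states with `sympy.Function`") instead of sending the raw traceback.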

### Agent 3: Symbol & Parameter Analysis
#### Time dependency classification:
- Identify all variables that appear with d/dt (time-dependent)
- Classify remaining symbols as parameters or independent variables
- Flag any inconsistencies in variable usage

#### Parameter consistency checking:
- Detect parameters that appear in equations but aren't defined
- Identify duplicate parameter definitions
- Find defined but unused parameters
- Check notation consistency (subscripts, superscripts, Greek letters)
- Pass comprehensive symbol mapping to next agent
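The time-dependency classification can be done deterministically with SymPy, before any LLM is involved. This is a minimal sketch assuming states are written as undefined functions of `t`; `classify_symbols` and the `flagged` key are hypothetical. A function of time that is used but never differentiated is flagged, which catches mix-ups like I(t) vs. T(t).

```python
import sympy
from sympy.core.function import AppliedUndef

def classify_symbols(odes, t):
    """Classify the symbols in a list of sympy.Eq ODEs.

    states     - functions f(t) that appear inside a d/dt derivative
    parameters - all remaining free symbols except the time symbol
    flagged    - functions of time that are used but never differentiated
    """
    states, applied, parameters = set(), set(), set()
    for eq in odes:
        for d in eq.atoms(sympy.Derivative):
            if t in d.variables and isinstance(d.expr, AppliedUndef):
                states.add(str(d.expr.func))
        applied |= {str(f.func) for f in eq.atoms(AppliedUndef)}
        parameters |= {str(s) for s in eq.free_symbols if s != t}
    return {
        "states": sorted(states),
        "parameters": sorted(parameters),
        "flagged": sorted(applied - states),
    }

t = sympy.symbols("t")
beta, gamma, N = sympy.symbols("beta gamma N")
S, I, T = sympy.symbols("S I T", cls=sympy.Function)

odes = [
    sympy.Eq(S(t).diff(t), -beta * S(t) * I(t) / N),
    # T(t) instead of I(t): a compartment mix-up
    sympy.Eq(I(t).diff(t), beta * S(t) * T(t) / N - gamma * I(t)),
]
print(classify_symbols(odes, t))
# {'states': ['I', 'S'], 'parameters': ['N', 'beta', 'gamma'], 'flagged': ['T']}
```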

### Agent 4: Diagnostic & Scoring
- Calculate extraction quality score based on:
  - Successful execution (from Agent 2)
  - Symbol consistency (from Agent 3)
  - Common extraction error patterns
- Generate final report with:
  - Overall confidence score
  - Specific warnings about potential extraction errors
  - Recommendations for manual review if score is low
- Optional: Include lightweight mathematical validation (missing negative signs, suspicious parameter usage)

This pipeline turns the single-shot extraction into a multi-step process in which each agent specializes in one aspect of validation and correction. Because each agent has its own approach and prompt configuration, the LLM can focus on one task at a time instead of receiving a single summarized, less detailed instruction.
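Agent 4's confidence score could start as a simple weighted combination of the upstream signals. This is a minimal sketch; `ExtractionReport`, `confidence_score`, `needs_manual_review`, the weights, the per-warning penalty and the 0.8 threshold are all hypothetical placeholders that would need calibrating against the evaluation results.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class ExtractionReport:
    executed: bool                  # from the execution error handler (Agent 2)
    symbol_consistency: float       # fraction in [0, 1], from symbol analysis (Agent 3)
    warnings: List[str] = field(default_factory=list)  # e.g. suspected missing minus sign

def confidence_score(report: ExtractionReport,
                     weights=(0.5, 0.4, 0.1),
                     penalty_per_warning: float = 0.25) -> float:
    """Weighted score in [0, 1]; weights and penalty are placeholder values."""
    w_exec, w_sym, w_warn = weights
    warning_term = max(0.0, 1.0 - penalty_per_warning * len(report.warnings))
    score = (w_exec * float(report.executed)
             + w_sym * report.symbol_consistency
             + w_warn * warning_term)
    return round(score, 3)

def needs_manual_review(report: ExtractionReport, threshold: float = 0.8) -> bool:
    # Code that does not execute always goes to manual review
    return (not report.executed) or confidence_score(report) < threshold

clean = ExtractionReport(executed=True, symbol_consistency=1.0)
shaky = ExtractionReport(executed=True, symbol_consistency=0.75,
                         warnings=["possible missing negative sign in dI/dt"])
print(confidence_score(clean), confidence_score(shaky))   # 1.0 0.875
print(needs_manual_review(shaky))                         # False
```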

Other possible agenda items:

6. Symbol Validation – Are all variables and parameters defined? This focuses more on the JSON output.

7. Biological Context Tagging – Are compartments semantically labeled (e.g., S = susceptible)?

8. JSON Structure Integrity – Is the output JSON consistent and complete?
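Items 6 and 8 could share one structural check. This is a minimal sketch assuming a *hypothetical* concepts JSON keyed by state name, where each entry needs `name` and `identifiers`; the real MIRA concept format may differ, and `check_concepts_json` and `REQUIRED_KEYS` are illustrative names.

```python
import json

# Hypothetical minimal schema: {state_name: {"name": ..., "identifiers": {...}}}
REQUIRED_KEYS = {"name", "identifiers"}

def check_concepts_json(concepts_json: str, state_names) -> list:
    """Return a list of human-readable problems; an empty list means no issues found."""
    try:
        concepts = json.loads(concepts_json)
    except json.JSONDecodeError as e:
        return [f"invalid JSON: {e.msg}"]
    if not isinstance(concepts, dict):
        return ["top-level JSON value should be an object"]
    problems = []
    for name in sorted(set(state_names) - set(concepts)):
        problems.append(f"no concept entry for state variable {name}")
    for name in sorted(set(concepts) - set(state_names)):
        problems.append(f"concept {name} does not match any state variable")
    for name, entry in sorted(concepts.items()):
        keys = set(entry) if isinstance(entry, dict) else set()
        missing = REQUIRED_KEYS - keys
        if missing:
            problems.append(f"concept {name} is missing keys: {sorted(missing)}")
    return problems

example = '{"S": {"name": "susceptible", "identifiers": {}}, "I": {"name": "infected"}}'
for problem in check_concepts_json(example, ["S", "I", "R"]):
    print(problem)
```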


### Quantitative measures for evaluating the extraction process
> TBD

In [None]:
import sympy
import re
from datetime import datetime
import os

class ODEsEvaluator:
    
    def execution_success_rate(self, odes_list):
        """
        Check if ODEs can be executed without errors.
        """
        total = len(odes_list)
        successful = 0
        failed = []
        
        for i, eq in enumerate(odes_list):
            try:
                # Check if it's a valid SymPy equation
                if hasattr(eq, 'lhs') and hasattr(eq, 'rhs'):
                    # Try to evaluate/simplify
                    sympy.simplify(eq.lhs - eq.rhs)
                    successful += 1
                else:
                    failed.append(i)
            except Exception as e:
                failed.append(i)
        
        return {
            'success_rate': successful / total if total > 0 else 0,
            'successful': successful,
            'failed': len(failed),
            'total': total,
            'failed_indices': failed
        }
    
    def symbol_accuracy(self, extracted_eq, correct_eq):
        """
        Compare symbols between extracted and correct equations.
        """
        # Get all symbols from equations
        extracted_symbols = set(str(s) for s in extracted_eq.free_symbols)
        correct_symbols = set(str(s) for s in correct_eq.free_symbols)
        
        # Calculate accuracy
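        if len(correct_symbols) == 0:
            # Return the same dict shape as below so callers can always index 'accuracy'
            accuracy = 1.0 if len(extracted_symbols) == 0 else 0.0
            return {
                'accuracy': accuracy,
                'missing': [],
                'extra': list(extracted_symbols),
                'correct_count': 0,
                'total_expected': 0
            }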
        
        correct_matches = extracted_symbols & correct_symbols
        missing = correct_symbols - extracted_symbols
        extra = extracted_symbols - correct_symbols
        
        accuracy = len(correct_matches) / len(correct_symbols)
        
        return {
            'accuracy': accuracy,
            'missing': list(missing),
            'extra': list(extra),
            'correct_count': len(correct_matches),
            'total_expected': len(correct_symbols)
        }
    
    def check_mathematical_equivalence(self, extracted_eq, correct_eq):
        """
        Check if two equations are mathematically equivalent.
        """
        try:
            # Calculate difference
            diff = sympy.simplify((extracted_eq.lhs - extracted_eq.rhs) - 
                                (correct_eq.lhs - correct_eq.rhs))
            
            # Check if difference is zero (equations are equivalent)
            is_equivalent = diff == 0
            
            return {
                'equivalent': is_equivalent,
                'difference': str(diff),
                'severity': 0 if is_equivalent else self.calculate_severity(diff)
            }
        except Exception:
            return {
                'equivalent': False,
                'difference': 'Could not compute',
                'severity': 1.0
            }
    
    def calculate_severity(self, diff):
        """
        Calculate error severity based on the difference.
        """
        # If difference is 0, no error
        if diff == 0:
            return 0
        
        # Count additive terms in the expanded difference; counting '+'/'-'
        # characters in str(diff) would also pick up signs inside brackets
        terms = len(sympy.Add.make_args(sympy.expand(diff)))
        
        # More terms = more severe
        if terms == 1:
            return 0.3  # Single term difference
        elif terms == 2:
            return 0.5  # Two terms difference
        else:
            return 0.8  # Many terms difference
    
    def evaluate_odes_set(self, extracted_odes, correct_odes, set_name="Extracted"):
        """
        Evaluate a complete set of ODEs against correct ODEs.
        """
        # Execution success
        exec_results = self.execution_success_rate(extracted_odes)
        
        # Compare each equation
        equation_results = []
        total_accuracy = 0
        total_equivalent = 0
        total_severity = 0
        
        for i, (extracted, correct) in enumerate(zip(extracted_odes, correct_odes)):
            # Symbol accuracy
            symbol_result = self.symbol_accuracy(extracted, correct)
            
            # Mathematical equivalence
            equiv_result = self.check_mathematical_equivalence(extracted, correct)
            
            equation_results.append({
                'index': i,
                'symbol_accuracy': symbol_result['accuracy'],
                'equivalent': equiv_result['equivalent'],
                'severity': equiv_result['severity']
            })
            
            total_accuracy += symbol_result['accuracy']
            total_equivalent += 1 if equiv_result['equivalent'] else 0
            total_severity += equiv_result['severity']
        
        n = len(correct_odes)
        
        return {
            'set_name': set_name,
            'execution_success_rate': exec_results['success_rate'],
            'avg_symbol_accuracy': total_accuracy / n if n > 0 else 0,
            'exact_match_rate': total_equivalent / n if n > 0 else 0,
            'avg_severity': total_severity / n if n > 0 else 0,
            'quality_score': 1 - (total_severity / n) if n > 0 else 0,
            'equation_details': equation_results,
            'summary': self.get_summary(exec_results['success_rate'], 
                                       total_equivalent / n if n > 0 else 0)
        }
    
    def get_summary(self, exec_rate, match_rate):
        """
        Get qualitative summary.
        """
        score = (exec_rate + match_rate) / 2
        if score >= 0.9:
            return "Excellent"
        elif score >= 0.7:
            return "Good"
        elif score >= 0.5:
            return "Fair"
        else:
            return "Needs Improvement"
    
    def compare_all_methods(self, correct_odes_sorted, extracted_odes_sorted, 
                           corrected_odes_sorted, mtx_odes_sorted):
        """
        Compare all extraction methods against correct ODEs.
        """
        print("="*60)
        print("MIRA EQUATION EXTRACTION EVALUATION")
        print("="*60)
        
        results = {}
        
        # 1. Evaluate extracted_odes
        print("\n1. EXTRACTED ODEs EVALUATION")
        print("-"*40)
        extracted_results = self.evaluate_odes_set(
            extracted_odes_sorted, correct_odes_sorted, "Extracted"
        )
        results['extracted'] = extracted_results
        self.print_results(extracted_results)
        
        # 2. Evaluate corrected_odes
        print("\n2. CORRECTED ODEs EVALUATION")
        print("-"*40)
        corrected_results = self.evaluate_odes_set(
            corrected_odes_sorted, correct_odes_sorted, "Corrected"
        )
        results['corrected'] = corrected_results
        self.print_results(corrected_results)
        
        # 3. Evaluate mtx_odes
        print("\n3. MATRIX ODEs EVALUATION")
        print("-"*40)
        # Handle potential length mismatch
        min_len = min(len(mtx_odes_sorted), len(correct_odes_sorted))
        mtx_results = self.evaluate_odes_set(
            mtx_odes_sorted[:min_len], correct_odes_sorted[:min_len], "Matrix"
        )
        results['matrix'] = mtx_results
        self.print_results(mtx_results)
        
        # Overall comparison
        print("\n" + "="*60)
        print("COMPARISON SUMMARY")
        print("="*60)
        print(f"{'Method':<15} {'Exec Rate':<12} {'Match Rate':<12} {'Quality':<12} {'Assessment':<12}")
        print("-"*60)
        
        for method in ['extracted', 'corrected', 'matrix']:
            r = results[method]
            print(f"{method.capitalize():<15} "
                  f"{r['execution_success_rate']:<12.1%} "
                  f"{r['exact_match_rate']:<12.1%} "
                  f"{r['quality_score']:<12.1%} "
                  f"{r['summary']:<12}")
        
        # Find best method
        best_method = max(results.items(), 
                         key=lambda x: x[1]['exact_match_rate'])
        print(f"\nBest performing method: {best_method[0].upper()}")
        
        return results
    
    def print_results(self, results):
        """
        Print evaluation results.
        """
        print(f"Execution Success Rate: {results['execution_success_rate']:.1%}")
        print(f"Symbol Accuracy: {results['avg_symbol_accuracy']:.1%}")
        print(f"Exact Match Rate: {results['exact_match_rate']:.1%}")
        print(f"Quality Score: {results['quality_score']:.1%}")
        print(f"Assessment: {results['summary']}")
        
        # Show problematic equations
        problematic = [eq for eq in results['equation_details'] 
                      if not eq['equivalent']]
        if problematic:
            print(f"Problematic equations: {[eq['index'] for eq in problematic[:5]]}")


# Usage example
def run_evaluation(correct_odes_sorted, extracted_odes_sorted, 
                  corrected_odes_sorted, mtx_odes_sorted, biomodel_name):
    """
    Run the evaluation for your ODEs.
    """
    evaluator = ODEsEvaluator()
    
    # Run comparison
    results = evaluator.compare_all_methods(
        correct_odes_sorted,
        extracted_odes_sorted,
        corrected_odes_sorted,
        mtx_odes_sorted
    )
    
    # Save results to file
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_file = f"evaluation_{biomodel_name}_{timestamp}.txt"
    
    with open(output_file, 'w') as f:
        f.write(f"Evaluation Results for {biomodel_name}\n")
        f.write(f"Timestamp: {timestamp}\n\n")
        
        for method, res in results.items():
            f.write(f"\n{method.upper()} METHOD:\n")
            f.write(f"  Execution Success: {res['execution_success_rate']:.1%}\n")
            f.write(f"  Exact Match Rate: {res['exact_match_rate']:.1%}\n")
            f.write(f"  Quality Score: {res['quality_score']:.1%}\n")
    
    return results


# Example usage with your variables: runs only if they are already defined
# (e.g. by an earlier notebook cell), instead of failing with a NameError
_required_names = ("correct_odes_sorted", "extracted_odes_sorted",
                   "corrected_odes_sorted", "mtx_odes_sorted")
if __name__ == "__main__" and all(name in globals() for name in _required_names):
    evaluator = ODEsEvaluator()
    results = evaluator.compare_all_methods(
        correct_odes_sorted,
        extracted_odes_sorted,
        corrected_odes_sorted,
        mtx_odes_sorted
    )
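The idea behind `check_mathematical_equivalence` can be tried on a toy SIR term (illustrative equations, not taken from a BioModels entry). Subtracting the two residuals `lhs - rhs` and simplifying gives zero for a rewritten but equivalent equation and a nonzero remainder when `/N` is dropped. Counting terms with `sympy.Add.make_args` on the expanded difference is less fragile than counting `+`/`-` characters in its string form.

```python
import sympy

t = sympy.symbols("t")
beta, N = sympy.symbols("beta N", positive=True)
S, I = sympy.symbols("S I", cls=sympy.Function)

def residual(eq: sympy.Eq) -> sympy.Expr:
    """Move everything to one side: lhs - rhs."""
    return eq.lhs - eq.rhs

correct = sympy.Eq(S(t).diff(t), -beta * S(t) * I(t) / N)
reordered = sympy.Eq(S(t).diff(t), -(beta / N) * I(t) * S(t))   # same equation, written differently
missing_n = sympy.Eq(S(t).diff(t), -beta * S(t) * I(t))         # '/N' dropped during extraction

same = sympy.simplify(residual(reordered) - residual(correct))
diff = sympy.expand(sympy.simplify(residual(missing_n) - residual(correct)))

print(same == 0)                          # True
print(diff)
print(len(sympy.Add.make_args(diff)))     # 2 additive terms
```

This comparison assumes both equations are written as `d/dt X = ...` with the same orientation; an equation stored as `Eq(rhs, lhs)` would produce a spurious nonzero difference.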

In [None]:
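from typing import Optional

# Import paths below are assumed from the MIRA repository layout
from mira.openai import OpenAIClient
from mira.sources.sympy_ode.llm_util import get_concepts_from_odes, image_file_to_odes_str


def run_multi_agent_pipeline(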
    image_path: str,
    client: OpenAIClient,
    verbose: bool = True
) -> tuple[str, Optional[dict], dict]:
    """
    Multi-agent pipeline for ODE extraction and validation
    
    Phase 1: Extract ODEs from image
    Phase 2: Fix execution errors
    Phase 3: Validate and correct (independent checks, currently run one after another)
    Phase 4: Evaluate quality
    
    Returns:
        Validated ODE string, concepts, and quality score
    """
    
    if verbose:
        print("="*60)
        print("MULTI-AGENT ODE EXTRACTION & VALIDATION PIPELINE")
        print("="*60)
    
    # Phase 1: Extraction (part of pipeline for context tracking)
    ode_str, concepts = phase1_extract_odes(image_path, client, verbose)
    
    # Phase 2: Execution error correction
    ode_str = phase2_fix_execution_errors(ode_str, client, verbose)
    
    # Phase 3: Validation and mathematical correction
    ode_str, concepts = phase3_validate_and_correct(ode_str, concepts, client, verbose)
    
    # Phase 4: Quality evaluation
    # Note: an empty reports dict is passed, so phase 4 does not see phase 2/3 reports
    quality_score = phase4_evaluate_quality(ode_str, concepts, {}, client, verbose)
    
    if verbose:
        print("="*60)
        print(f"PIPELINE COMPLETE - Quality Score: {quality_score['total_score']:.2%}")
        print("="*60)
    
    return ode_str, concepts, quality_score

# Individual phase functions
def phase1_extract_odes(
    image_path: str, 
    client: OpenAIClient,
    verbose: bool = True
) -> tuple[str, Optional[dict]]:
    """Phase 1: Extract ODEs and concepts from image"""
    if verbose:
        print("\nPHASE 1: ODE Extraction")
    
    # Use existing extraction functions
    ode_str = image_file_to_odes_str(image_path, client)
    
    try:
        concepts = get_concepts_from_odes(ode_str, client)
    except Exception as e:
        if verbose:
            print(f"  Warning: Concept extraction failed: {e}")
        concepts = None
    
    return ode_str, concepts

def phase2_fix_execution_errors(
    ode_str: str, 
    client: OpenAIClient,
    verbose: bool = True
) -> str:
    """Phase 2: Check and fix execution errors"""
    if verbose:
        print("\nPHASE 2: Execution Error Check & Correction")
    
    from agents import ExecutionErrorCorrector
    corrector = ExecutionErrorCorrector(client)
    result = corrector.process({'ode_str': ode_str})
    
    if verbose and result.get('execution_report', {}).get('errors_fixed'):
        print(f"  Fixed {len(result['execution_report']['errors_fixed'])} errors")
    
    return result['ode_str']

def phase3_validate_and_correct(
    ode_str: str, 
    concepts: Optional[dict],
    client: OpenAIClient,
    verbose: bool = True
) -> tuple[str, Optional[dict]]:
    """Phase 3: Validation and mathematical checks with corrections"""
    if verbose:
        print("\nPHASE 3: Validation & Mathematical Checks")
    
    from agents import (
        ValidationAggregator,
        MathematicalAggregator,
        UnifiedErrorCorrector
    )
    
    pipeline_state = {'ode_str': ode_str, 'concepts': concepts}
    
    # Run the validation and mathematical checks (independent of each other, run sequentially here)
    val_aggregator = ValidationAggregator(client)
    val_results = val_aggregator.process(pipeline_state)
    
    math_aggregator = MathematicalAggregator(client)
    math_results = math_aggregator.process(pipeline_state)
    
    # Apply unified corrections
    pipeline_state.update(val_results)
    pipeline_state.update(math_results)
    
    corrector = UnifiedErrorCorrector(client)
    correction_result = corrector.process(pipeline_state)
    
    return correction_result['ode_str'], correction_result.get('concepts', concepts)

def phase4_evaluate_quality(
    ode_str: str,
    concepts: Optional[dict],
    reports: dict,
    client: OpenAIClient,
    verbose: bool = True
) -> dict:
    """Phase 4: Quantitative evaluation of extraction quality"""
    if verbose:
        print("\nPHASE 4: Quantitative Evaluation")
    
    from agents import QuantitativeEvaluator
    evaluator = QuantitativeEvaluator(client)
    
    result = evaluator.process({
        'ode_str': ode_str,
        'concepts': concepts,
        **reports
    })
    
    return result['quality_score']
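In `run_multi_agent_pipeline`, phase 4 receives an empty `reports` dict, so the quality evaluator never sees what phases 2 and 3 found. One way to thread reports through is a shared state dict that every stage reads and updates. This is a minimal sketch; `run_stages`, `fix_execution` and `evaluate_quality` are hypothetical toy stand-ins for the agents, reusing only the key names (`ode_str`, `execution_report`, `errors_fixed`, `quality_score`, `total_score`) that appear in the code above.

```python
from typing import Any, Callable, Dict, List

Stage = Callable[[Dict[str, Any]], Dict[str, Any]]

def run_stages(initial_state: Dict[str, Any], stages: List[Stage]) -> Dict[str, Any]:
    """Run stages in order; each returns a dict of updates merged into one shared state."""
    state = dict(initial_state)
    for stage in stages:
        state.update(stage(state))
    return state

# Toy stand-ins for the agents (hypothetical, for illustration only)
def fix_execution(state):
    code = state["ode_str"]
    fixed = code if code.startswith("import sympy") else "import sympy\n" + code
    errors_fixed = ["missing import"] if fixed != code else []
    return {"ode_str": fixed, "execution_report": {"errors_fixed": errors_fixed}}

def evaluate_quality(state):
    # The evaluator can read reports produced by earlier stages from the shared state
    n_fixed = len(state.get("execution_report", {}).get("errors_fixed", []))
    return {"quality_score": {"total_score": 1.0 if n_fixed == 0 else 0.9}}

final = run_stages({"ode_str": "odes = []"}, [fix_execution, evaluate_quality])
print(final["execution_report"], final["quality_score"])
```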

In [None]:
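from typing import List

# Import paths below are assumed from the MIRA repository layout
from mira.metamodel import TemplateModel
from mira.openai import OpenAIClient
from mira.sources.sympy_ode import template_model_from_sympy_odes
from mira.sources.sympy_ode.llm_util import CodeExecutionError, get_concepts_from_odes


def execute_template_model_from_sympy_odes(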
    ode_str,
    attempt_grounding: bool,
    client: OpenAIClient,
) -> TemplateModel:
    """Create a TemplateModel from the sympy ODEs defined in the code snippet string

    Parameters
    ----------
    ode_str :
        The code snippet defining the ODEs
    attempt_grounding :
        Whether to attempt grounding the concepts in the ODEs. This will prompt the
        OpenAI chat completion to create concepts data to provide grounding for the
        concepts in the ODEs. The concepts data is then used to create the TemplateModel.
    client :
        The OpenAI client

    Returns
    -------
    :
        The TemplateModel created from the sympy ODEs.
    """
    # FixMe, for now use `exec` on the code, but need to find a safer way to execute
    #  the code
    # Import sympy just in case the code snippet does not import it
    import sympy
    odes: List[sympy.Eq] = None
    # Execute the code in one fresh namespace used as both globals and locals:
    # top-level names in the snippet stay visible inside its comprehensions, and
    # the snippet cannot overwrite the notebook's globals (this is still not a sandbox)
    namespace = {"sympy": sympy}
    try:
        exec(ode_str, namespace)
    except Exception as e:
        # Raise a CodeExecutionError to handle the error in the UI
        raise CodeExecutionError(f"Error while executing the code: {e}")
    # `odes` should now be defined in the namespace
    odes = namespace.get("odes")
    assert odes is not None, "The code should define a variable called `odes`"
    if attempt_grounding:
        concept_data = get_concepts_from_odes(ode_str, client)
    else:
        concept_data = None
    return template_model_from_sympy_odes(odes, concept_data=concept_data)
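A small demonstration of why the namespace passed to `exec` matters for generated code. Per the Python documentation, when `exec` receives separate globals and locals dicts, the code runs as if it were embedded in a class definition, so names assigned at the snippet's top level are not visible inside its list comprehensions. LLM-generated ODE code often builds `odes` with a comprehension, so a single namespace dict is the safer choice. `run_snippet` is a hypothetical helper written only for this demonstration.

```python
def run_snippet(code: str, separate_locals: bool) -> dict:
    """Execute `code` with either one shared namespace or separate globals/locals."""
    global_ns = {}
    if separate_locals:
        local_ns = {}
        exec(code, global_ns, local_ns)   # behaves like a class body
        return local_ns
    exec(code, global_ns)                 # behaves like a module
    return global_ns

snippet = "k = 2\nvalues = [k * x for x in range(3)]"

print(run_snippet(snippet, separate_locals=False)["values"])  # [0, 2, 4]
try:
    run_snippet(snippet, separate_locals=True)                # raises NameError for 'k'
except NameError as e:
    print("NameError:", e)
```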