In [77]:
import os
from datetime import datetime
import json
from dotenv import load_dotenv
from openai import OpenAI
from IPython.display import Markdown, display
from agents import Agent, Runner, trace, ModelSettings
from agents.extensions.models.litellm_model import LitellmModel
from pydantic import BaseModel, Field
from typing import Literal
from IPython.display import Markdown, display
from pathlib import Path
import asyncio
from typing import Optional, List
import chromadb

In [78]:
# Load variables from a local .env file into the process environment.
# override=True lets values from .env replace ones already set in the shell.
# The call's return value is the cell output (True in the recorded run).
load_dotenv(override=True)

True

In [79]:
openai_api_key = os.getenv('OPENAI_API_KEY')
openrouter_api_key = os.getenv('OPENROUTER_API_KEY')

# Report which provider keys were found. Only the first 8 characters are echoed
# (enough to recognise the key type, e.g. "sk-proj-"), never the whole secret.
for _provider, _key in (("OpenAI", openai_api_key), ("OpenRouter", openrouter_api_key)):
    print(f"{_provider} API Key exists and begins {_key[:8]}" if _key else f"{_provider} API Key not set")
del _provider, _key  # keep the loop temporaries out of the notebook namespace

OpenAI API Key exists and begins sk-proj-
OpenRouter API Key exists and begins sk-or-v1


In [80]:
# xAI Grok models served through OpenRouter and wrapped by LiteLLM.
# NOTE(review): only grok_4_1_fast is referenced in the visible notebook
# (by the aggregator); grok_code_fast_1 is currently unused.
grok_code_fast_1 = LitellmModel(
    model="openrouter/x-ai/grok-code-fast-1",
    api_key=openrouter_api_key,
)
grok_4_1_fast = LitellmModel(
    model="openrouter/x-ai/grok-4.1-fast",
    api_key=openrouter_api_key,
)

## 4 Agents + Aggregator

In [81]:
# One Code Analyzer finding (collected in CodeAnalyzerOutput.findings).
# Documented with comments rather than a class docstring: pydantic copies a class
# docstring into the JSON schema "description", which would change the prompt.
class BugFinding(BaseModel):
    title: str = Field(description="Brief name for the bug")
    description: str = Field(description="Detailed explanation")
    # The 1-10 range is only described to the model; no ge/le validation is applied.
    severity: int = Field(description="Severity 1-10")
    file: str = Field(description="File path")
    # max_length caps the list at 20 line numbers during validation.
    relevant_lines: list[int] = Field(description="Line numbers (max 20 lines per finding)", max_length=20)
    suggested_fix: str = Field(description="Recommended solution")

# One Security agent finding (collected in SecurityOutput.findings). Same fields as
# BugFinding plus an optional CVE id. Comments, not a docstring, so the JSON schema
# sent to the model stays unchanged.
class VulnerabilityFinding(BaseModel):
    title: str = Field(description="Brief name for the vulnerability")
    description: str = Field(description="Detailed explanation")
    # The 1-10 range is only described to the model; no ge/le validation is applied.
    severity: int = Field(description="Severity 1-10")
    file: str = Field(description="File path")
    # max_length caps the list at 20 line numbers during validation.
    relevant_lines: list[int] = Field(description="Line numbers (max 20 lines per finding)", max_length=20)
    suggested_fix: str = Field(description="Recommended solution")
    cve_reference: str | None = Field(default=None, description="CVE ID if applicable")

# One Best Practices agent finding (collected in BestPracticesOutput.findings).
# Comments, not a docstring, so the JSON schema sent to the model stays unchanged.
class BestPracticeFinding(BaseModel):
    title: str = Field(description="Brief name for the best practice violation")
    description: str = Field(description="Detailed explanation")
    # The 1-10 range is only described to the model; no ge/le validation is applied.
    severity: int = Field(description="Severity 1-10")
    file: str = Field(description="File path")
    # max_length caps the list at 20 line numbers during validation.
    relevant_lines: list[int] = Field(description="Line numbers (max 20 lines per finding)", max_length=20)
    suggested_fix: str = Field(description="Recommended solution")
    
# One Test Coverage agent finding (collected in TestCoverageOutput.findings).
# Unlike the other finding types it uses `lines` instead of `relevant_lines` and
# `priority` instead of `severity`; review_code's severity filter falls back to
# `priority` for these. Comments, not a docstring, keep the model-facing schema unchanged.
class TestGap(BaseModel):
    function_name: str = Field(description="Name of the function/method lacking tests")
    file: str = Field(description="File containing the untested code")
    # max_length caps the list at 20 line numbers during validation.
    lines: list[int] = Field(description="Line numbers of the untested code (max 20 lines)", max_length=20)
    missing_scenarios: list[str] = Field(description="Specific test cases that should be added, e.g., ['edge case: empty input', 'error handling: invalid type']")
    # The 1-10 range is only described to the model; no ge/le validation is applied.
    priority: int = Field(description="Priority 1-10, based on code criticality")
    suggested_test_approach: str = Field(description="How to test this (unit test, integration test, etc.)")
    
# Structured output_type of the Code Analyzer agent: just its list of findings.
class CodeAnalyzerOutput(BaseModel):
    findings: list[BugFinding] = Field(description="Bugs and anti-patterns found")

# Structured output_type of the Security agent: just its list of findings.
class SecurityOutput(BaseModel):
    findings: list[VulnerabilityFinding] = Field(description="Security vulnerabilities found")

# Structured output_type of the Best Practices agent: just its list of findings.
class BestPracticesOutput(BaseModel):
    findings: list[BestPracticeFinding] = Field(description="Style and best practice violations")

# Structured output_type of the Test Coverage agent: just its list of test gaps.
class TestCoverageOutput(BaseModel):
    findings: list[TestGap] = Field(description="Testing gaps found")

In [82]:
# System prompt for the Code Analyzer agent: bugs, logic errors, anti-patterns and
# functionality lost through deletions. run_all_agents appends retrieved pattern
# context (Python gotchas, review and refactoring patterns) before use.
# NOTE(review): step 1 asks the model to first describe what changed, but this
# agent's output_type (CodeAnalyzerOutput) has only a `findings` list, so that
# description has no field to go into.

code_analyzer_instructions = """You are a Code Analyzer agent reviewing a pull request diff.

ANALYSIS APPROACH:
1. First, describe what changed: What code was added? What was removed? What was modified?
2. Then, identify potential issues in the changes
3. Consider the inverse: What functionality might be LOST from deletions?

CRITICAL: Only create findings for actual bugs, logic errors, or antipatterns. If the code is clean and correct, return an empty findings list.

DELETION ANALYSIS (CRITICAL):
- When you see removed code (lines starting with -), pay special attention to:
  * Entire functions/classes being deleted - flag if they're called elsewhere
  * Helper functions removed - check if remaining code still works without them
  * Error handling removed - flag if this makes code less safe
  * Imports removed - verify they're truly unused
- If 10+ consecutive lines are deleted, describe what functionality is being removed

BUG PATTERNS TO IDENTIFY:
- Logic errors, unhandled edge cases, null/undefined access, type mismatches
- Off-by-one errors, resource leaks (unclosed files/cursors/connections)
- Infinite loops, missing error handling (no try-except blocks)
- Code duplication, overly complex functions
- Removed functionality that breaks remaining code

IMPORTANT: For each issue, specify ONLY the specific lines where the issue occurs (max 20 lines per finding). 
Do NOT list entire files or large ranges. Be precise and focused."""

# System prompt for the Security agent: vulnerabilities plus security controls that a
# change removes. run_all_agents appends ChromaDB-retrieved security patterns before use.
security_instructions = """You are a Security agent reviewing a pull request diff.

ANALYSIS APPROACH:
1. First, describe what changed from a security perspective
2. Identify what security controls or validations were added or removed
3. Consider: Does this change introduce new attack surface?

CRITICAL: Only create findings for actual security vulnerabilities or risks. If the code is secure and follows security best practices, return an empty findings list.

SECURITY PATTERNS:
- SQL injection, command injection, XSS vulnerabilities
- Hardcoded secrets/credentials, insecure authentication
- Path traversal, insecure deserialization
- Improper input validation
- Missing error handling that could expose sensitive information
- Removed security checks or validation code

DELETION AWARENESS:
- If security-related code is removed (validation, sanitization, auth checks), flag it as HIGH severity
- Consider what protections are LOST, not just what bugs are added

IMPORTANT: For each vulnerability, specify ONLY the specific lines where the vulnerability exists (max 20 lines per finding).
Do NOT list entire files or large ranges. Focus on the exact vulnerable code location."""

# System prompt for the Best Practices agent: style, structure and maintainability.
# run_all_agents appends ChromaDB-retrieved best-practice patterns before use.
# NOTE(review): "Unclosed resources" and "Missing try-except blocks" are also listed
# in the Code Analyzer prompt, so both agents may report the same issue; the
# aggregator prompt asks for such duplicates to be merged.
best_practices_instructions = """You are a Best Practices agent reviewing a pull request diff.

ANALYSIS APPROACH:
1. Describe what changed in terms of code quality
2. Identify violations of best practices in the new/modified code
3. Consider: Does this change make the code harder to maintain?

CRITICAL: Only create findings for actual violations of coding standards and best practices. If the code follows PEP 8, has proper docstrings, and is well-structured, return an empty findings list.

CODE QUALITY ISSUES:
- Unclear variable names, functions exceeding 50 lines
- Nested complexity over 3 levels, missing docstrings
- Inconsistent formatting, magic numbers without explanation
- Violations of DRY principle
- Unclosed resources (files, database cursors, connections)
- Missing try-except blocks for error-prone operations

DELETION AWARENESS:
- If helpful comments, docstrings, or error handling are removed, flag it
- If code is simplified but loses clarity, mention it

IMPORTANT: For each issue, specify ONLY the specific lines with the violation (max 20 lines per finding).
Do NOT list entire files or large ranges. Be specific and targeted."""

# System prompt for the Test Coverage agent. Its findings are TestGap items ranked
# by `priority`, which review_code's min_severity filter compares against, so the
# priority ranges stated here must be consistent: the CRITICAL line defers to the
# PRIORITY GUIDELINES instead of giving its own (conflicting) ranges.
test_coverage_instructions = """You are a Test Coverage agent reviewing a pull request diff.

ANALYSIS APPROACH:
1. Identify what functions/methods are new or modified
2. For each, assess criticality and risk
3. Only flag missing tests for high-risk code

CRITICAL: Only create test gap findings for functions that are genuinely risky if untested. Assign priority using the PRIORITY GUIDELINES below (8-10 for critical code, 7 for complex logic, 4-6 for nice-to-have tests).

PRIORITY GUIDELINES:
- Priority 8-10: Functions handling user input, authentication, authorization, financial transactions, data persistence, security controls, or external API calls
- Priority 7: Functions with complex logic, multiple conditional branches, error-prone operations (file I/O, parsing, calculations)
- Priority 4-6: Simple utility functions, formatters, getters/setters, straightforward data transformations
- Priority 1-3: Trivial helpers (one-liners, simple wrappers, obvious logic)

DO NOT FLAG: Trivial helper functions, simple string formatters, obvious getters/setters, or functions with self-evident correctness.

For each flagged function, suggest test cases covering:
- Normal input cases
- Edge cases (empty, null, boundary values)
- Error conditions (exceptions, failures, timeouts)
- Integration scenarios

IMPORTANT: For each gap, specify ONLY the specific lines of the function needing tests (max 20 lines per gap).
Do NOT list entire files. Focus on the specific untested function location."""

def _review_agent(agent_name, agent_instructions, schema):
    """Create a gpt-4.1-mini reviewer whose final output is parsed into `schema`.

    All reviewers share the same settings (temperature=0.6, max_tokens=4000);
    each call builds its own ModelSettings instance.
    """
    settings = ModelSettings(temperature=0.6, max_tokens=4000)
    return Agent(
        name=agent_name,
        instructions=agent_instructions,
        model="gpt-4.1-mini",
        model_settings=settings,
        output_type=schema,
    )


# Baseline reviewers configured with the static prompts defined above.
code_analyzer = _review_agent("Code Analyzer", code_analyzer_instructions, CodeAnalyzerOutput)
security_agent = _review_agent("Security Agent", security_instructions, SecurityOutput)
best_practices_agent = _review_agent("Best Practices Agent", best_practices_instructions, BestPracticesOutput)
test_coverage_agent = _review_agent("Test Coverage Agent", test_coverage_instructions, TestCoverageOutput)

In [83]:
# Location of the persistent ChromaDB store that holds the pattern collections.
_CHROMA_DB_PATH = "./chroma_db"


def _query_pattern_collection(collection_name: str, code_diff: str, n_results: int) -> str:
    """Return the `n_results` documents of `collection_name` most similar to `code_diff`.

    Args:
        collection_name: Name of an existing ChromaDB collection in _CHROMA_DB_PATH.
        code_diff: Text used as the similarity query.
        n_results: Maximum number of documents to retrieve.

    Returns:
        The retrieved documents joined by blank lines, or "" when nothing comes back.
        Errors from get_collection (e.g. a missing collection) propagate to the caller.
    """
    client = chromadb.PersistentClient(path=_CHROMA_DB_PATH)
    collection = client.get_collection(name=collection_name)
    results = collection.query(query_texts=[code_diff], n_results=n_results)
    # One query text was sent, so `documents` holds a single inner list; guard
    # against it being missing/None instead of raising on the [0] index.
    per_query_documents = results.get("documents") or [[]]
    return "\n\n".join(per_query_documents[0] or [])


def get_relevant_security_patterns(code_diff: str, n_results: int = 5) -> str:
    """Retrieve relevant security patterns from ChromaDB"""
    return _query_pattern_collection("security_patterns", code_diff, n_results)


def get_relevant_best_practices_patterns(code_diff: str, n_results: int = 5) -> str:
    """Retrieve relevant best practices patterns from ChromaDB"""
    return _query_pattern_collection("best_practices_patterns", code_diff, n_results)


def get_relevant_python_gotchas(code_diff: str, n_results: int = 3) -> str:
    """Retrieve relevant Python gotchas patterns from ChromaDB"""
    return _query_pattern_collection("python_gotchas_patterns", code_diff, n_results)


def get_relevant_code_review_patterns(code_diff: str, n_results: int = 3) -> str:
    """Retrieve relevant code review patterns from ChromaDB"""
    return _query_pattern_collection("code_review_patterns", code_diff, n_results)


def get_relevant_refactoring_patterns(code_diff: str, n_results: int = 5) -> str:
    """Retrieve relevant refactoring patterns from ChromaDB (multi-file changes, shotgun surgery, etc.)"""
    return _query_pattern_collection("refactoring_patterns", code_diff, n_results)

In [84]:
def _with_retrieved_context(base_instructions: str, sections) -> str:
    """Append retrieved-pattern sections to an agent's static instructions.

    Args:
        base_instructions: The agent's static prompt.
        sections: Iterable of (heading, patterns) pairs. Pairs whose patterns are
            empty or whitespace-only are skipped, so the prompt never carries a
            heading with nothing under it.

    Returns:
        base_instructions followed by "heading\\npatterns" for each non-empty
        section, all separated by blank lines.
    """
    parts = [base_instructions]
    for heading, patterns in sections:
        if patterns and patterns.strip():
            parts.append(f"{heading}\n{patterns}")
    return "\n\n".join(parts)


async def run_all_agents(diff):
    """Run the four review agents concurrently on `diff`.

    The Code Analyzer, Security and Best Practices agents get their prompts
    extended with patterns retrieved from ChromaDB for this diff; the Test
    Coverage agent runs with its static prompt.

    Returns:
        List of four run results, in order: code analyzer, security,
        best practices, test coverage.
    """
    # Retrieval happens up front (synchronous ChromaDB queries). Security gets a
    # wider net (15 results) to capture more injection patterns.
    security_patterns = get_relevant_security_patterns(diff, n_results=15)
    best_practices_patterns = get_relevant_best_practices_patterns(diff, n_results=5)
    python_gotchas = get_relevant_python_gotchas(diff, n_results=3)
    code_review_patterns = get_relevant_code_review_patterns(diff, n_results=3)
    refactoring_patterns = get_relevant_refactoring_patterns(diff, n_results=5)

    # clone() copies the baseline agent (name, model, settings, output type) and
    # swaps in the RAG-enhanced instructions, so the configuration lives in one place.
    code_analyzer_rag = code_analyzer.clone(
        instructions=_with_retrieved_context(
            code_analyzer_instructions,
            [
                ("RELEVANT PYTHON GOTCHAS TO CHECK:", python_gotchas),
                ("RELEVANT CODE REVIEW PATTERNS TO CHECK:", code_review_patterns),
                ("RELEVANT REFACTORING PATTERNS TO CHECK (Multi-File Changes):", refactoring_patterns),
            ],
        )
    )
    security_agent_rag = security_agent.clone(
        instructions=_with_retrieved_context(
            security_instructions,
            [("RELEVANT SECURITY PATTERNS TO CHECK:", security_patterns)],
        )
    )
    best_practices_agent_rag = best_practices_agent.clone(
        instructions=_with_retrieved_context(
            best_practices_instructions,
            [("RELEVANT BEST PRACTICES PATTERNS TO CHECK:", best_practices_patterns)],
        )
    )

    # Run all agents in parallel; result order matches the argument order.
    results = await asyncio.gather(
        Runner.run(code_analyzer_rag, diff),
        Runner.run(security_agent_rag, diff),
        Runner.run(best_practices_agent_rag, diff),
        Runner.run(test_coverage_agent, diff),  # no retrieved context for test coverage
    )
    return results

In [85]:
def organize_findings(
    code_result,
    security_result,
    best_practices_result,
    test_coverage_result
):
    """
    Group every agent finding under the file it refers to.

    Args:
        code_result, security_result, best_practices_result, test_coverage_result:
            Agent run results whose ``final_output.findings`` is a list of items
            that each carry a ``file`` attribute.

    Returns:
        dict: file path -> findings for that file, e.g.
        {"file.py": [BugFinding, VulnerabilityFinding, TestGap, ...]}.
        Files appear in first-seen order; within a file, findings keep the
        agent order (code, security, best practices, test coverage).
    """
    # A plain dict (not defaultdict) on purpose: the result is str()-ed into the
    # aggregator prompt, and a defaultdict would render with its type prefix.
    by_file = {}
    agent_results = (code_result, security_result, best_practices_result, test_coverage_result)
    for agent_result in agent_results:
        for item in agent_result.final_output.findings:
            by_file.setdefault(item.file, []).append(item)
    return by_file

In [86]:
# System prompt for the Aggregator agent. This is a plain (non-f) string used as
# static instructions, so it must not contain template placeholders: the findings
# themselves arrive in the user message built by aggregator_agent. The report
# format below (summary table header, "Total Distinct Issues" line) is relied on
# by the evaluation code.

aggregator_instructions = """You are a Code Review Aggregator tasked with creating a deduplicated summary report. Your goal is to merge duplicate findings from multiple agents into a clear, actionable report.

CRITICAL: Output your report as plain text/markdown. Do NOT wrap your response in JSON or code fences.

The findings from multiple agents are provided in the user message as a Python dict that maps each file path to a list of finding objects. The object type tells you which agent produced each finding:
- BugFinding -> Code Analyzer
- VulnerabilityFinding -> Security
- BestPracticeFinding -> Best Practices
- TestGap -> Test Coverage (uses `lines` instead of `relevant_lines` and `priority` instead of `severity`; report its priority in the Severity column)

AGGREGATION GUIDELINES:

1. IDENTIFY DUPLICATES: Group findings that describe the same root issue
   - Look for overlapping line numbers and similar descriptions
   - When multiple agents flag the same problem, merge into one issue
   - Use the HIGHEST severity when merging

2. MULTI-FILE AWARENESS (CRITICAL):
   - If findings span multiple files, check for cross-file dependencies
   - Flag if changes in one file might break APIs/contracts in another file
   - Look for patterns like: "File A removes function X, but does File B call it?"
   - Consider the bigger picture: Do these changes work together?

3. PRESERVE INFORMATION: 
   - Keep agent names: Code Analyzer, Security, Best Practices, Test Coverage
   - Include file paths and line numbers
   - Maintain the most comprehensive description from merged findings

4. CATEGORIZE each issue as:
   - Bug: Logic errors, crashes, incorrect behavior  
   - Security: Vulnerabilities, unsafe code
   - Performance: Inefficient algorithms, resource issues
   - Style: Naming, formatting, documentation
   - Test Gap: Missing test coverage

5. CREATE SUMMARY TABLE with these columns:
   | Issue | File | Lines | Severity | Category | Fix | Found By |

6. SEPARATE CONCERNS: Test coverage gaps are distinct from code issues

Present your report in this format:

# Code Review Report

## Executive Summary
[2-3 sentences highlighting the most critical findings. If multi-file change, mention cross-file implications]

## Summary of Actions
| Issue | File | Lines | Severity | Category | Fix | Found By |
|-------|------|-------|----------|----------|-----|----------|
[One row per unique issue]

**Total Distinct Issues: [count]**

CRITICAL REQUIREMENT: 
- EVERY finding from EVERY agent must appear in the summary table
- This includes ALL test coverage gaps reported by the Test Coverage agent
- Test gaps should be listed as separate rows (one per function needing tests)
- Do NOT omit any findings, especially test coverage gaps
- The Total Distinct Issues count must match the number of rows in the table."""

# Merges the four reviewers' findings into one markdown report. It has no
# output_type, so its final output is the report text. It runs on Grok 4.1 Fast
# via OpenRouter, with reasoning switched on through the provider-specific extra_args.
aggregator = Agent(
    name="Aggregator",
    model=grok_4_1_fast,
    instructions=aggregator_instructions,
    model_settings=ModelSettings(
        temperature=0.6,
        extra_args={"reasoning": {"enabled": True}},
    ),
)

In [87]:
async def aggregator_agent(organized):
    """Run the Aggregator agent over per-file findings and return its report.

    Args:
        organized: Mapping of file path -> findings (as built by organize_findings).
            It is embedded in the prompt through its str() form.

    Returns:
        The aggregator run's final output (the markdown report text).
    """
    prompt = f"Aggregate these findings into a structured report:\n\n{organized}"
    run_result = await Runner.run(aggregator, prompt)
    return run_result.final_output

In [88]:
def _save_report(report: str) -> str:
    """Write `report` to user-data/code_review_YYYYmmdd_HHMMSS.md and return the path.

    The file is written as UTF-8 so non-ASCII characters in model output save
    correctly regardless of the platform's default encoding.
    NOTE(review): two reports saved within the same second overwrite each other.
    """
    os.makedirs("user-data", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = f"user-data/code_review_{timestamp}.md"
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(report)
    return filepath


async def review_code(diff: str, save_output: bool = True, min_severity: int = 1) -> str:
    """
    Complete code review pipeline: run the four review agents, drop findings
    below `min_severity`, and have the aggregator merge the rest into a report.

    Args:
        diff: The code diff to review
        save_output: If True, also write the report (including the short
            "no issues" report) to user-data/. (default: True)
        min_severity: Minimum severity threshold (1-10). Findings below this are
            filtered out; test gaps are compared by their `priority`. (default: 1)

    Returns:
        Markdown-formatted code review report
    """
    with trace("Multi-Agent Code Review"):
        results = await run_all_agents(diff)
        code_result, security_result, best_practices_result, test_coverage_result = results

        # Filter findings by severity threshold. TestGap has no `severity`, so its
        # `priority` is compared instead. The run results are rewritten in place;
        # they were created by this call, so no caller observes the change.
        def filter_by_severity(result):
            result.final_output.findings = [
                finding for finding in result.final_output.findings
                if getattr(finding, 'severity', getattr(finding, 'priority', 0)) >= min_severity
            ]
            return result

        code_result = filter_by_severity(code_result)
        security_result = filter_by_severity(security_result)
        best_practices_result = filter_by_severity(best_practices_result)
        test_coverage_result = filter_by_severity(test_coverage_result)

        organized = organize_findings(code_result, security_result, best_practices_result, test_coverage_result)

        if not any(organized.values()):
            # Nothing survived the filter: skip the aggregator call entirely.
            report = "# Code Review Report\n\nNo issues found meeting severity threshold.\n"
            print(report)
        else:
            print("\n" + "="*60)
            print("CALLING AGGREGATOR...")
            print("="*60)

            report = await aggregator_agent(organized)

            print("\n" + "="*60)
            print("AGGREGATOR OUTPUT:")
            print("="*60)
            print(report)
            print("="*60 + "\n")

        if save_output:
            filepath = _save_report(report)
            print(f"Report saved to {filepath}")

        return report

In [74]:
# Smoke test: a small additive diff (a new greet() helper next to add()).
# With min_severity=6 the recorded run produced the short "no issues" report.
# NOTE(review): top-level `await` only works inside IPython/Jupyter.
sample_diff = '''
diff --git a/utils.py b/utils.py
index abc123..def456 100644
--- a/utils.py
+++ b/utils.py
@@ -1,3 +1,8 @@
+def greet(name):
+    """Return a greeting message."""
+    return f"Hello, {name}!"
+
+
 def add(a, b):
     """Add two numbers."""
     return a + b
'''

report = await review_code(sample_diff, save_output=False, min_severity=6)

# Code Review Report

No issues found meeting severity threshold.



In [75]:
# Smoke test with a deliberately vulnerable change: authenticate() builds its SQL
# query by string concatenation of user input (SQL injection). The recorded run
# reported it as a severity-9 Security issue.
# NOTE(review): the hunk header (-5,6 +5,12) does not match the body
# (3 context + 6 added lines); the reviewers see this text verbatim.
serious_diff = '''
diff --git a/user_auth.py b/user_auth.py
index abc123..def456 100644
--- a/user_auth.py
+++ b/user_auth.py
@@ -5,6 +5,12 @@ class UserAuth:
     def __init__(self):
         self.db = sqlite3.connect('users.db')
     
+    def authenticate(self, username, password):
+        query = "SELECT * FROM users WHERE username='" + username + "' AND password='" + password + "'"
+        cursor = self.db.cursor()
+        result = cursor.execute(query)
+        return result.fetchone() is not None
+
'''

report = await review_code(serious_diff, save_output=False, min_severity=5)


CALLING AGGREGATOR...

AGGREGATOR OUTPUT:
# Code Review Report

## Executive Summary
A critical SQL injection vulnerability exists in the `authenticate` method of `user_auth.py` due to direct concatenation of user inputs into SQL queries, flagged by multiple agents as a high-severity security risk that could allow unauthorized access or data leakage. Comprehensive test coverage is missing for this function, including normal, edge, error, and security scenarios. No cross-file dependencies observed in this single-file review.

## Summary of Actions

| Issue | File | Lines | Severity | Category | Fix | Found By |
|-------|------|-------|----------|----------|-----|----------|
| SQL Injection Vulnerability in authenticate Method | user_auth.py | 7-12 | 9 | Security | Use parameterized queries with placeholders to safely pass user inputs to the SQL query, e.g., `cursor.execute("SELECT * FROM users WHERE username=? AND password=?", (username, password))`. Example: `query = "SELECT * FROM us

# Evals

## Synthetic Evals

In [None]:
# System prompt for the evaluation judge, which pairs ground-truth findings with
# findings from a review report. Rule 4 previously contained mis-encoded text
# where the "not equal" sign belongs.
judge_instructions = """You are an evaluation judge for code review systems comparing expected findings (ground truth) against actual findings.

CRITICAL MATCHING RULES:
1. Each actual finding can match AT MOST ONE expected finding
2. Each expected finding can match AT MOST ONE actual finding
3. Once an actual finding is matched, it CANNOT be used again
4. Only match within same category (bugs ≠ test gaps)

PROCESS:
1. Count total_actual from "Total Distinct Issues: X" in report
2. For EACH expected finding:
   - Find the BEST matching actual finding that hasn't been used yet
   - If good match exists: mark as matched=True, record which actual finding
   - If no match: mark as matched=False
   - NEVER reuse an actual finding for multiple expected findings

A match means the same type of issue was identified, even if worded differently.
"""

# One row of the judge's verdict: an expected (ground-truth) finding and, when the
# judge paired it with something in the report, that reported text.
# Comments rather than a class docstring, so the JSON schema sent to the judge stays unchanged.
class MatchedFinding(BaseModel):
    expected: str = Field(description="the expected finding text")
    matched: bool = Field(description="true if the expected finding is present, else false")
    # Defaults to None; meant to be filled only when matched is True (not enforced).
    actual_finding: Optional[str] = Field(default=None, description="the matching text from report (if matched)")

class EvaluationResult(BaseModel):
    # Structured output of the evaluation judge. No class docstring on purpose:
    # pydantic would copy it into the JSON schema sent to the judge model.
    matched_findings: list[MatchedFinding]
    total_expected: int = Field(description="Total number of expected findings from ground truth")
    total_actual: int = Field(description="Count of distinct issues in the report's summary section")

    def model_post_init(self, __context):
        """Print diagnostics when the judge's output breaks its own matching rules.

        Runs after pydantic validation. Problems are printed rather than raised,
        so an imperfect judge response still yields a result that can be inspected.
        """
        matched = [mf for mf in self.matched_findings if mf.matched]
        matches = len(matched)

        # A match without the paired report text cannot be checked for reuse;
        # report it separately instead of mistaking it for a duplicate.
        without_text = sum(1 for mf in matched if not mf.actual_finding)
        if without_text:
            print(f"WARNING: {without_text} matched finding(s) have no actual_finding text")

        # Each reported issue may back at most one expected finding.
        actual_findings_used = [mf.actual_finding for mf in matched if mf.actual_finding]
        unique_actuals = len(set(actual_findings_used))
        if len(actual_findings_used) > unique_actuals:
            print(f"ERROR: {len(actual_findings_used)} matches but only {unique_actuals} unique actual findings used!")
            print("The judge matched the same actual finding multiple times.")

        if matches > self.total_actual:
            print(f"WARNING: Matches ({matches}) > Total Actual ({self.total_actual})")



async def evaluate_report(report: str, ground_truth_content: str) -> dict:
    """
    Fixed evaluation function with proper counting.
    """
    
    judge_agent = Agent(
        name="Evaluation Judge",
        instructions=judge_instructions,
        model="gpt-5.1",
        output_type=EvaluationResult
    )
    
    prompt = f"""
GROUND TRUTH (expected findings):
{ground_truth_content}

ACTUAL REPORT (what the system found):
{report}

For each expected finding, determine if it matches any actual finding.
Output matched_findings list, total_expected, and total_actual.
"""
    
    result = await Runner.run(judge_agent, prompt)
    eval_result = result.final_output
    
    # Calculate matches from the actual data - don't trust LLM counting
    matches = sum(1 for mf in eval_result.matched_findings if mf.matched)
    
    # Calculate metrics
    recall = matches / eval_result.total_expected if eval_result.total_expected > 0 else 0
    precision = matches / eval_result.total_actual if eval_result.total_actual > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return {
        "recall": recall,
        "precision": precision,
        "f1": f1,
        "matches": matches,
        "total_expected": eval_result.total_expected,
        "total_actual": eval_result.total_actual,
        "details": eval_result.matched_findings
    }

In [None]:
# Run test 2
test_dir = Path("test-cases")
diff_file = test_dir / "01_sql_injection.diff"

# Load files
diff_content = diff_file.read_text()
expected_file = diff_file.with_name("01_sql_injection_expected.json")
ground_truth_content = expected_file.read_text()

# Run review WITH saving
report = await review_code(diff_content, save_output=False)

# Evaluate
eval_result = await evaluate_report(report, ground_truth_content)

print("\n" + "="*60)
print("JUDGE OUTPUT:")
print("="*60)
print(f"total_expected: {eval_result['total_expected']}")
print(f"total_actual: {eval_result['total_actual']}")
print(f"matches: {eval_result['matches']}")
print(f"\nmatched_findings:")
for mf in eval_result['details']:
    print(f"\n  Expected: {mf.expected}")
    print(f"  Matched: {mf.matched}")
    if mf.actual_finding:
        print(f"  Actual: {mf.actual_finding[:100]}...")  # truncate if long

print("\n" + "="*60)
print("CALCULATED METRICS:")
print("="*60)
print(f"Recall: {eval_result['recall']:.2f}")
print(f"Precision: {eval_result['precision']:.2f}")
print(f"F1 Score: {eval_result['f1']:.2f}")

In [None]:
# test all test cases

test_cases = [
    "01_sql_injection",
    "02_logic_bug",
    "03_code_quality",
    "04_multi_file_security",
    "05_multi_file_mixed"
]

async def run_all_tests():
    test_dir = Path("test-cases")
    results = []
    
    for test_name in test_cases:
        print(f"\n{'='*60}")
        print(f"TESTING: {test_name}")
        print('='*60)
        
        # Load files
        diff_file = test_dir / f"{test_name}.diff"
        diff_content = diff_file.read_text()
        expected_file = test_dir / f"{test_name}_expected.json"
        ground_truth_content = expected_file.read_text()
        
        # Run review
        report = await review_code(diff_content, save_output=False)
        
        # Evaluate
        eval_result = await evaluate_report(report, ground_truth_content)
        
        # Print detailed judge output
        print("\n" + "="*60)
        print("JUDGE OUTPUT:")
        print("="*60)
        print(f"total_expected: {eval_result['total_expected']}")
        print(f"total_actual: {eval_result['total_actual']}")
        print(f"matches: {eval_result['matches']}")
        print(f"\nmatched_findings:")
        for mf in eval_result['details']:
            print(f"\n  Expected: {mf.expected}")
            print(f"  Matched: {mf.matched}")
            if mf.actual_finding:
                print(f"  Actual: {mf.actual_finding[:100]}...")
        
        # Store results
        results.append({
            'test_name': test_name,
            'recall': eval_result['recall'],
            'precision': eval_result['precision'],
            'f1': eval_result['f1'],
            'passed': eval_result['recall'] >= 0.80 and 
                     eval_result['precision'] >= 0.85 and 
                     eval_result['f1'] >= 0.82
        })
        
        # Print calculated metrics
        print("\n" + "="*60)
        print("CALCULATED METRICS:")
        print("="*60)
        print(f"Recall: {eval_result['recall']:.2f}")
        print(f"Precision: {eval_result['precision']:.2f}")
        print(f"F1 Score: {eval_result['f1']:.2f}")
        print(f"Status: {'‚úì PASSED' if results[-1]['passed'] else '‚úó FAILED'}")
    
    # Print overall summary
    print(f"\n\n{'='*60}")
    print("OVERALL SUMMARY")
    print('='*60)
    for result in results:
        status = '‚úì' if result['passed'] else '‚úó'
        print(f"{status} {result['test_name']}: R={result['recall']:.2f} P={result['precision']:.2f} F1={result['f1']:.2f}")
    
    passed = sum(1 for r in results if r['passed'])
    print(f"\nPassed: {passed}/{len(results)}")
    
    return results

# Run all tests
results = await run_all_tests()

## BugsInPy Evals

In [None]:
# Hybrid Evaluation (Option 4): Models and Utilities
import re

# Matches a unified-diff hunk header: "@@ -old[,count] +new[,count] @@<section text>".
_HUNK_HEADER_RE = re.compile(r'^@@ -(\d+(?:,\d+)?) \+(\d+(?:,\d+)?) @@(.*)$')


def reverse_diff(bug_patch: str) -> str:
    """Reverse a unified diff so a bug *fix* patch reads as the bug's *introduction*.

    - '-' lines become '+' lines and vice versa.
    - Hunk headers swap their ranges (``@@ -a,b +c,d @@`` -> ``@@ -c,d +a,b @@``),
      so the line numbers keep describing the side each line now belongs to.
    - File headers (``---`` / ``+++``) and all other lines are kept verbatim.

    NOTE(review): any line starting with '---' or '+++' is treated as a file
    header, so a removed line whose content starts with '--' (or an added line
    starting with '++') is left unreversed.
    """
    reversed_lines = []
    for line in bug_patch.split('\n'):
        if line.startswith('---') or line.startswith('+++'):
            reversed_lines.append(line)
            continue
        hunk = _HUNK_HEADER_RE.match(line)
        if hunk:
            old_range, new_range, section = hunk.groups()
            reversed_lines.append(f"@@ -{new_range} +{old_range} @@{section}")
        elif line.startswith('-'):
            reversed_lines.append('+' + line[1:])
        elif line.startswith('+'):
            reversed_lines.append('-' + line[1:])
        else:
            reversed_lines.append(line)
    return '\n'.join(reversed_lines)

def parse_changed_locations(bug_patch: str) -> dict:
    """Extract the files and (new-file) line numbers touched by a unified-diff patch.

    Line numbers come from the ``+start,count`` part of each ``@@`` hunk header, so
    every line of a hunk -- including its unchanged context lines -- is counted. This
    is an approximation that pairs with the +/-5-line tolerance in
    ``calculate_location_metrics``.

    Edge cases:
      * Deletion-only hunks (``+N,0``) contribute their anchor line ``N`` (at least 1)
        instead of nothing, so a fix that only removes code still has a location.
      * Deleted files (``+++ /dev/null``) are skipped; their hunks are no longer
        attributed to whichever file preceded them.
      * A trailing tab/timestamp (``diff -u`` style) or CR after the path is dropped.
      * A file appearing in several diff sections accumulates lines instead of resetting.

    Returns:
        ``{'files': set[str], 'lines': dict[str, set[int]]}`` keyed by the path after ``b/``.
    """
    changed_files = set()
    changed_lines = {}

    current_file = None
    for line in bug_patch.split('\n'):
        # Extract filename from +++ line
        if line.startswith('+++'):
            match = re.search(r'\+\+\+ b/(.+)', line)
            if match:
                current_file = match.group(1).split('\t')[0].rstrip()
                changed_files.add(current_file)
                changed_lines.setdefault(current_file, set())
            elif line.startswith('+++ /dev/null'):
                current_file = None  # file deleted: it has no new-side lines

        # Extract line numbers from @@ hunk headers
        elif line.startswith('@@') and current_file:
            match = re.search(r'@@ -\d+,?\d* \+(\d+),?(\d*)', line)
            if match:
                start = int(match.group(1))
                count = int(match.group(2)) if match.group(2) else 1
                if count == 0:
                    # Empty new side: the removal sits right after line `start`.
                    changed_lines[current_file].add(max(start, 1))
                else:
                    changed_lines[current_file].update(range(start, start + count))

    return {'files': changed_files, 'lines': changed_lines}

def parse_flagged_locations(report: str) -> dict:
    """Extract the files and line numbers flagged in a review report's findings table.

    Rows are read from every markdown table whose header contains
    ``| Issue | File | Lines |``; column 2 is the file path and column 3 the line spec.

    Robustness:
      * Separator rows such as ``| --- | :---: |`` are skipped (previously only ``|---``
        was recognised, and a spaced separator crashed on ``int('')``).
      * A table ends at the first line that does not start with ``|``, so later,
        unrelated tables (summaries, counts) are not parsed as findings.
      * Backticks around the path are removed.
      * Line specs are read with a regex, accepting "[82-85]", "7-10, 24", "L82-L85",
        "82–85" and reversed ranges like "85-82"; cells without digits ("N/A") add nothing.

    Returns:
        ``{'files': set[str], 'lines': dict[str, set[int]]}``.
    """
    flagged_files = set()
    flagged_lines = {}

    in_table = False
    for raw_line in report.split('\n'):
        row = raw_line.strip()
        if '| Issue | File | Lines |' in row:
            in_table = True
            continue
        if not in_table:
            continue
        if not row.startswith('|'):
            in_table = False  # the findings table has ended
            continue
        if set(row) <= set('|-: '):
            continue  # alignment / separator row

        parts = [p.strip() for p in row.split('|')]
        if len(parts) <= 3:
            continue
        file_path = parts[2].strip('`').strip()
        if not file_path or file_path == 'File':
            continue

        flagged_files.add(file_path)
        file_lines = flagged_lines.setdefault(file_path, set())
        # Each match is (start, optional end), e.g. "82-85" -> ("82", "85"), "9" -> ("9", "").
        for start_s, end_s in re.findall(r'(\d+)(?:\s*[-\u2013\u2014]\s*L?(\d+))?', parts[3]):
            start = int(start_s)
            end = int(end_s) if end_s else start
            if end < start:
                start, end = end, start
            file_lines.update(range(start, end + 1))

    return {'files': flagged_files, 'lines': flagged_lines}

def calculate_location_metrics(actual: dict, flagged: dict, tolerance: int = 5) -> dict:
    """
    Calculate location-based overlap metrics between a patch and a review report.

    Both arguments use the ``{'files': set[str], 'lines': dict[str, set[int]]}`` shape
    returned by ``parse_changed_locations`` / ``parse_flagged_locations``.

    - file_recall: fraction of actually-changed files that the report mentions.
    - line_recall: of all actual changed lines, how many have a flagged line within
      ``tolerance`` lines in the same file.
    - line_precision: of ALL flagged lines (in any file), how many have an actual
      changed line within ``tolerance`` lines in the same file. Flags in files the
      patch never touched count as misses; previously they were left out of the
      denominator, which inflated precision.

    Args:
        actual: Locations changed by the patch.
        flagged: Locations flagged by the review.
        tolerance: Maximum line distance that still counts as a match (default 5,
            the previously hard-coded value).

    Returns:
        dict with 'file_recall', 'line_recall', 'line_precision'; each is 0.0 when
        its denominator is empty.
    """
    def _has_neighbour(line_no: int, candidates: set) -> bool:
        # True if some candidate line lies within `tolerance` of line_no.
        return any(abs(line_no - other) <= tolerance for other in candidates)

    actual_files = actual['files']
    file_recall = len(flagged['files'] & actual_files) / len(actual_files) if actual_files else 0.0

    # Recall side: every actual changed line in every patched file.
    total_actual_lines = 0
    actual_lines_matched = 0
    for file in actual_files:
        actual_lines = actual['lines'].get(file, set())
        flagged_lines_in_file = flagged['lines'].get(file, set())
        total_actual_lines += len(actual_lines)
        actual_lines_matched += sum(
            1 for line_no in actual_lines if _has_neighbour(line_no, flagged_lines_in_file)
        )

    # Precision side: every flagged line in every flagged file, patched or not.
    total_flagged_lines = 0
    flagged_lines_matched = 0
    for file, flagged_lines_in_file in flagged['lines'].items():
        actual_lines = actual['lines'].get(file, set()) if file in actual_files else set()
        total_flagged_lines += len(flagged_lines_in_file)
        flagged_lines_matched += sum(
            1 for line_no in flagged_lines_in_file if _has_neighbour(line_no, actual_lines)
        )

    line_recall = actual_lines_matched / total_actual_lines if total_actual_lines > 0 else 0.0
    line_precision = flagged_lines_matched / total_flagged_lines if total_flagged_lines > 0 else 0.0

    return {
        'file_recall': file_recall,
        'line_recall': line_recall,
        'line_precision': line_precision
    }

In [None]:
# Hybrid Evaluation Function

class LLMRelevance(BaseModel):
    """LLM's assessment of how relevant the review findings are to the actual fix."""
    # NOTE: the class docstring and Field descriptions are pydantic schema text; since this
    # model is passed as the judge Agent's output_type, that text is presumably part of the
    # structured-output schema the model sees -- edit it as prompt text, not as a comment.
    # NOTE(review): the 0.0-1.0 range is only described, not enforced (no ge/le bounds),
    # so consumers should not assume the returned score is in range.
    relevance_score: float = Field(description="0.0-1.0: How well the review findings align with the actual fix")
    # Free-text justification produced by the judge.
    explanation: str = Field(description="Brief explanation of the score")

async def evaluate_hybrid(report: str, bug_patch: str) -> dict:
    """
    Hybrid evaluation: Location metrics (automated) + LLM relevance (semantic).
    
    Stage 1: Calculate automated location overlap between the fix patch and the
             files/lines listed in the report's findings table.
    Stage 2: If file_recall > 0, use an LLM judge to score semantic relevance.
             The judge's score is clamped to [0.0, 1.0]: the output schema only
             *describes* that range, and an out-of-range answer (e.g. 8 meaning
             "8/10") would otherwise push composite_score past any pass threshold.
    
    Args:
        report: Markdown code-review report produced for the reversed diff.
        bug_patch: The original (fixing) unified diff.

    Returns:
        dict with file_recall, line_precision, line_recall, llm_relevance,
        llm_explanation ('' when the judge was skipped) and composite_score
        (the mean of line_recall and llm_relevance).
    """
    
    # Stage 1: Automated location metrics
    actual_locations = parse_changed_locations(bug_patch)
    flagged_locations = parse_flagged_locations(report)
    location_metrics = calculate_location_metrics(actual_locations, flagged_locations)
    
    # Stage 2: LLM relevance (only if there's file overlap)
    llm_relevance = 0.0
    llm_explanation = ""
    if location_metrics['file_recall'] > 0:
        llm_judge_instructions = """You are evaluating the semantic relevance of code review findings to an actual bug fix.

CRITICAL: Output ONLY valid JSON matching the specified schema. Do NOT wrap your response in markdown code fences or backticks.

Given:
1. ACTUAL FIX PATCH: The changes that were made to fix bugs
2. CODE REVIEW REPORT: What the review system found

Rate the relevance (0.0 to 1.0) of the review findings:
- 1.0: Findings directly identify the bugs that were fixed
- 0.7-0.9: Findings flag related issues that would lead to discovering the bugs
- 0.4-0.6: Findings flag the general area but miss specific bugs
- 0.1-0.3: Findings are tangentially related
- 0.0: No relevant findings

Be objective and strict in your assessment."""

        llm_judge = Agent(
            name="Relevance Judge",
            instructions=llm_judge_instructions,
            model=grok_4_1_fast,
            model_settings=ModelSettings(
                temperature=0.6,
                extra_args={"reasoning": {"enabled": True}}
            ),
            output_type=LLMRelevance
        )
        
        prompt = f"""
ACTUAL FIX PATCH:
{bug_patch}

CODE REVIEW REPORT:
{report}

Rate the semantic relevance of the review findings to the actual fix.
"""
        with trace("LLM Judge"):
            result = await Runner.run(llm_judge, prompt)
            judgement = result.final_output
            # Clamp: the schema does not enforce the documented 0-1 range.
            llm_relevance = min(1.0, max(0.0, float(judgement.relevance_score)))
            llm_explanation = judgement.explanation
    
    # Composite score: average of line recall and LLM relevance
    composite_score = (location_metrics['line_recall'] + llm_relevance) / 2
    
    return {
        'file_recall': location_metrics['file_recall'],
        'line_precision': location_metrics['line_precision'],
        'line_recall': location_metrics['line_recall'],
        'llm_relevance': llm_relevance,
        'llm_explanation': llm_explanation,
        'composite_score': composite_score
    }

In [None]:
# BugsInPy runner: review each bug's reversed fix and report what the agents caught vs. missed.

def _bugsinpy_missed_lines(actual_locations: dict, flagged_locations: dict, tolerance: int = 5) -> dict:
    """Map each patched file to its sorted actual lines with no flagged line within `tolerance`.

    The default matches the 5-line window used by calculate_location_metrics. Files
    without misses are omitted.
    """
    missed = {}
    for file in actual_locations['files']:
        flagged_in_file = flagged_locations['lines'].get(file, set())
        uncaught = sorted(
            line_no
            for line_no in actual_locations['lines'].get(file, set())
            if not any(abs(line_no - flagged) <= tolerance for flagged in flagged_in_file)
        )
        if uncaught:
            missed[file] = uncaught
    return missed


def _format_line_ranges(lines: list[int]) -> str:
    """Collapse sorted line numbers into text like "3-5, 9"."""
    ranges = []
    start = end = None
    for line_no in lines:
        if start is None:
            start = end = line_no
        elif line_no == end + 1:
            end = line_no
        else:
            ranges.append(f"{start}-{end}" if start != end else str(start))
            start = end = line_no
    if start is not None:
        ranges.append(f"{start}-{end}" if start != end else str(start))
    return ', '.join(ranges)


def _print_bugsinpy_result(result: dict) -> None:
    """Print one bug's metrics, its missed files/lines and its pass/fail verdict."""
    print("\n📍 LOCATION METRICS:")
    print(f"  File Recall: {result['file_recall']:.0%}")
    print(f"  Line Precision: {result['line_precision']:.0%}")
    print(f"  Line Recall: {result['line_recall']:.0%}")
    print(f"\n🤖 LLM RELEVANCE: {result['llm_relevance']:.0%}")
    print(f"🎯 COMPOSITE: {result['composite_score']:.0%}")

    if result['missed_files']:
        print(f"\n❌ MISSED FILES: {', '.join(result['missed_files'])}")

    if result['missed_lines']:
        print("\n❌ MISSED LINES:")
        for file, lines in result['missed_lines'].items():
            print(f"  {file}: lines {_format_line_ranges(lines)}")

    print(f"\n{'✓ PASSED' if result['passed'] else '✗ FAILED'}")


def _print_bugsinpy_summary(results: list[dict]) -> None:
    """Print a one-line verdict per bug plus the overall pass rate (safe for an empty list)."""
    print(f"\n\n{'='*60}")
    print("OVERALL SUMMARY")
    print('='*60)
    for result in results:
        label = f"{result['project']}/{result['bug_id']}"
        if 'error' in result:
            print(f"✗ {label}: ERROR")
            continue
        status = '✓' if result['passed'] else '✗'
        missed_info = ""
        if result['missed_files']:
            missed_info += f" | Missed files: {len(result['missed_files'])}"
        if result['missed_lines']:
            total_missed = sum(len(lines) for lines in result['missed_lines'].values())
            missed_info += f" | Missed lines: {total_missed}"
        print(f"{status} {label}: "
              f"Composite={result['composite_score']:.0%} "
              f"(LineRec={result['line_recall']:.0%}, LLM={result['llm_relevance']:.0%})"
              f"{missed_info}")

    passed = sum(1 for r in results if r.get('passed', False))
    # Guard against ZeroDivisionError when no bugs were requested.
    pass_rate = passed / len(results) if results else 0.0
    print(f"\nPassed: {passed}/{len(results)} ({pass_rate:.0%})")


async def test_bugsinpy_with_miss_analysis(
    bugs_to_test: list[tuple[str, int]],
    bugsinpy_root: str = "BugsInPy",
    pass_threshold: float = 0.60,
) -> list[dict]:
    """
    Test multiple BugsInPy bugs with detailed miss analysis.

    For each (project, bug_id): load the fixing patch, reverse it so the review sees
    the bug being introduced, run review_code on it, score the report with
    evaluate_hybrid, and list the patched files/lines the report did not flag.

    Args:
        bugs_to_test: (project, bug_id) pairs.
        bugsinpy_root: Root of the BugsInPy checkout; patches are read from
            <root>/projects/<project>/bugs/<bug_id>/bug_patch.txt.
        pass_threshold: Minimum composite_score for a bug to count as passed.

    Returns:
        One dict per bug. Successful runs carry the metrics plus 'missed_files'
        (sorted list) and 'missed_lines' ({file: sorted lines}); failed runs carry
        'error' and passed=False. An exception for one bug is printed and recorded
        rather than raised, so the remaining bugs still run.
    """
    import traceback

    results = []

    for project, bug_id in bugs_to_test:
        print(f"\n{'='*60}")
        print(f"TESTING: {project} bug {bug_id}")
        print('='*60)

        try:
            # Explicit UTF-8 (with replacement) so the run does not depend on the OS locale.
            bug_patch_path = Path(bugsinpy_root) / "projects" / project / "bugs" / str(bug_id) / "bug_patch.txt"
            bug_patch = bug_patch_path.read_text(encoding="utf-8", errors="replace")

            print("\nACTUAL FIX (first 500 chars):")
            print(bug_patch[:500])
            print("..." if len(bug_patch) > 500 else "")

            # Show the reviewer the bug being introduced, not the fix.
            reversed_diff = reverse_diff(bug_patch)
            report = await review_code(reversed_diff, save_output=False)
            eval_result = await evaluate_hybrid(report, bug_patch)

            # Re-parse locations to explain what was missed.
            actual_locations = parse_changed_locations(bug_patch)
            flagged_locations = parse_flagged_locations(report)
            missed_files = sorted(actual_locations['files'] - flagged_locations['files'])
            missed_lines = _bugsinpy_missed_lines(actual_locations, flagged_locations)

            result = {
                'project': project,
                'bug_id': bug_id,
                'file_recall': eval_result['file_recall'],
                'line_precision': eval_result['line_precision'],
                'line_recall': eval_result['line_recall'],
                'llm_relevance': eval_result['llm_relevance'],
                'llm_explanation': eval_result.get('llm_explanation', ''),
                'composite_score': eval_result['composite_score'],
                'passed': eval_result['composite_score'] >= pass_threshold,
                'missed_files': missed_files,
                'missed_lines': missed_lines
            }
            results.append(result)
            _print_bugsinpy_result(result)

        except Exception as e:
            print(f"ERROR: {e}")
            traceback.print_exc()
            results.append({
                'project': project,
                'bug_id': bug_id,
                'error': str(e),
                'passed': False
            })

    _print_bugsinpy_summary(results)
    return results

In [None]:
# BugsInPy run: 20 bugs across 17 projects (scrapy, ansible and pytest contribute two bugs each).
# Each entry is (project, bug_id); the runner reads the patch from
# BugsInPy/projects/<project>/bugs/<bug_id>/bug_patch.txt relative to the working directory.
# List order is the order the bugs are reviewed in.
bugs_to_test_20_diverse = [
    ("scrapy", 2),        # Web scraping framework
    ("ansible", 2),       # Automation tool
    ("pytest", 2),        # Testing framework
    ("sanic", 2),         # Async web framework
    ("spacy", 2),         # NLP library
    ("youtube-dl", 2),    # Video downloader
    ("thefuck", 2),       # Command corrector
    ("luigi", 4),         # Pipeline framework
    ("black", 2),         # Code formatter
    ("pandas", 3),        # Data analysis
    ("keras", 3),         # ML framework
    ("matplotlib", 2),    # Plotting library
    ("tornado", 2),       # Async networking
    ("tqdm", 2),          # Progress bar
    ("httpie", 2),        # HTTP client
    ("cookiecutter", 2),  # Project templating
    ("fastapi", 2),       # API framework
    ("scrapy", 3),        # More scrapy
    ("ansible", 3),       # More ansible
    ("pytest", 3),        # More pytest
]

# NOTE(review): this rebinds the module-level `results` from run_all_tests() earlier in the notebook.
# Top-level `await` relies on IPython/Jupyter autoawait.
results = await test_bugsinpy_with_miss_analysis(bugs_to_test_20_diverse)

## CVE Evals

In [None]:
# Evaluation Utilities
import re

def reverse_diff(bug_patch: str) -> str:
    """Reverse a unified-diff patch so it shows the bug being introduced instead of fixed.

    Every removed line (``-``) becomes an added line (``+``) and vice versa. File
    headers (``---``/``+++``) and ``@@`` hunk headers are copied unchanged, so line
    numbers in the result keep following the *fixed* file -- the same numbering that
    ``parse_changed_locations`` reads from the original patch.

    Inside a hunk, the line counts from the ``@@`` header decide what is content, so a
    removed/added line whose own text starts with ``--``/``++`` (an rst underline
    ``-----``, a SQL ``--`` comment, a YAML ``---``) is still flipped instead of being
    mistaken for a file header. Outside a recognised hunk the simple prefix heuristic
    (flip any ``-``/``+`` line that is not ``---``/``+++``) is used as a fallback.

    NOTE(review): a function with this name is also defined in the BugsInPy section;
    from this cell on, this definition is the one in effect.

    Args:
        bug_patch: Unified diff text (e.g. a BugsInPy ``bug_patch.txt`` or a CVE patch).

    Returns:
        The reversed diff, with exactly one output line per input line.
    """
    reversed_lines = []
    old_left = new_left = 0  # lines still expected in the current hunk, per side
    for line in bug_patch.split('\n'):
        if old_left > 0 or new_left > 0:
            marker = line[:1]
            if marker == '-':
                reversed_lines.append('+' + line[1:])
                old_left -= 1
                continue
            if marker == '+':
                reversed_lines.append('-' + line[1:])
                new_left -= 1
                continue
            if marker in (' ', ''):
                # Context line (some tools strip the leading space of blank context lines).
                reversed_lines.append(line)
                old_left -= 1
                new_left -= 1
                continue
            if marker == '\\':
                # "\ No newline at end of file" annotates the preceding line.
                reversed_lines.append(line)
                continue
            # Any other prefix means the header over-counted; resynchronise below.
            old_left = new_left = 0

        hunk = re.match(r'@@ -\d+(?:,(\d+))? \+\d+(?:,(\d+))? @@', line)
        if hunk:
            # A missing ",count" means a one-line side.
            old_left = int(hunk.group(1)) if hunk.group(1) is not None else 1
            new_left = int(hunk.group(2)) if hunk.group(2) is not None else 1
            reversed_lines.append(line)
        elif line.startswith('---') or line.startswith('+++'):
            reversed_lines.append(line)
        elif line.startswith('-'):
            reversed_lines.append('+' + line[1:])
        elif line.startswith('+'):
            reversed_lines.append('-' + line[1:])
        else:
            reversed_lines.append(line)
    return '\n'.join(reversed_lines)

def parse_changed_locations(bug_patch: str) -> dict:
    """Extract the files and (new-file) line numbers touched by a unified-diff patch.

    Line numbers come from the ``+start,count`` part of each ``@@`` hunk header, so
    every line of a hunk -- including its unchanged context lines -- is counted. This
    is an approximation that pairs with the +/-5-line tolerance in
    ``calculate_location_metrics``.

    Edge cases:
      * Deletion-only hunks (``+N,0``) contribute their anchor line ``N`` (at least 1)
        instead of nothing, so a fix that only removes code still has a location.
      * Deleted files (``+++ /dev/null``) are skipped; their hunks are no longer
        attributed to whichever file preceded them.
      * A trailing tab/timestamp (``diff -u`` style) or CR after the path is dropped.
      * A file appearing in several diff sections accumulates lines instead of resetting.

    NOTE(review): a function with this name is also defined in the BugsInPy section;
    from this cell on, this definition is the one in effect.

    Returns:
        ``{'files': set[str], 'lines': dict[str, set[int]]}`` keyed by the path after ``b/``.
    """
    changed_files = set()
    changed_lines = {}

    current_file = None
    for line in bug_patch.split('\n'):
        # Extract filename from +++ line
        if line.startswith('+++'):
            match = re.search(r'\+\+\+ b/(.+)', line)
            if match:
                current_file = match.group(1).split('\t')[0].rstrip()
                changed_files.add(current_file)
                changed_lines.setdefault(current_file, set())
            elif line.startswith('+++ /dev/null'):
                current_file = None  # file deleted: it has no new-side lines

        # Extract line numbers from @@ hunk headers
        elif line.startswith('@@') and current_file:
            match = re.search(r'@@ -\d+,?\d* \+(\d+),?(\d*)', line)
            if match:
                start = int(match.group(1))
                count = int(match.group(2)) if match.group(2) else 1
                if count == 0:
                    # Empty new side: the removal sits right after line `start`.
                    changed_lines[current_file].add(max(start, 1))
                else:
                    changed_lines[current_file].update(range(start, start + count))

    return {'files': changed_files, 'lines': changed_lines}

def parse_flagged_locations(report: str) -> dict:
    """Extract the files and line numbers flagged in a review report's findings table.

    Rows are read from every markdown table whose header contains
    ``| Issue | File | Lines |``; column 2 is the file path and column 3 the line spec.

    Robustness:
      * Separator rows such as ``| --- | :---: |`` are skipped (previously only ``|---``
        was recognised, and a spaced separator crashed on ``int('')``).
      * A table ends at the first line that does not start with ``|``, so later,
        unrelated tables (summaries, counts) are not parsed as findings.
      * Backticks around the path are removed.
      * Line specs are read with a regex, accepting "[82-85]", "7-10, 24", "L82-L85",
        "82–85" and reversed ranges like "85-82"; cells without digits ("N/A") add nothing.

    NOTE(review): a function with this name is also defined in the BugsInPy section;
    from this cell on, this definition is the one in effect.

    Returns:
        ``{'files': set[str], 'lines': dict[str, set[int]]}``.
    """
    flagged_files = set()
    flagged_lines = {}

    in_table = False
    for raw_line in report.split('\n'):
        row = raw_line.strip()
        if '| Issue | File | Lines |' in row:
            in_table = True
            continue
        if not in_table:
            continue
        if not row.startswith('|'):
            in_table = False  # the findings table has ended
            continue
        if set(row) <= set('|-: '):
            continue  # alignment / separator row

        parts = [p.strip() for p in row.split('|')]
        if len(parts) <= 3:
            continue
        file_path = parts[2].strip('`').strip()
        if not file_path or file_path == 'File':
            continue

        flagged_files.add(file_path)
        file_lines = flagged_lines.setdefault(file_path, set())
        # Each match is (start, optional end), e.g. "82-85" -> ("82", "85"), "9" -> ("9", "").
        for start_s, end_s in re.findall(r'(\d+)(?:\s*[-\u2013\u2014]\s*L?(\d+))?', parts[3]):
            start = int(start_s)
            end = int(end_s) if end_s else start
            if end < start:
                start, end = end, start
            file_lines.update(range(start, end + 1))

    return {'files': flagged_files, 'lines': flagged_lines}

def calculate_location_metrics(actual: dict, flagged: dict, tolerance: int = 5) -> dict:
    """
    Calculate location-based overlap metrics between a patch and a review report.

    Both arguments use the ``{'files': set[str], 'lines': dict[str, set[int]]}`` shape
    returned by ``parse_changed_locations`` / ``parse_flagged_locations``.

    - file_recall: fraction of actually-changed files that the report mentions.
    - line_recall: of all actual changed lines, how many have a flagged line within
      ``tolerance`` lines in the same file.
    - line_precision: of ALL flagged lines (in any file), how many have an actual
      changed line within ``tolerance`` lines in the same file. Flags in files the
      patch never touched count as misses; previously they were left out of the
      denominator, which inflated precision.

    NOTE(review): a function with this name is also defined in the BugsInPy section;
    from this cell on, this definition is the one in effect.

    Args:
        actual: Locations changed by the patch.
        flagged: Locations flagged by the review.
        tolerance: Maximum line distance that still counts as a match (default 5,
            the previously hard-coded value).

    Returns:
        dict with 'file_recall', 'line_recall', 'line_precision'; each is 0.0 when
        its denominator is empty.
    """
    def _has_neighbour(line_no: int, candidates: set) -> bool:
        # True if some candidate line lies within `tolerance` of line_no.
        return any(abs(line_no - other) <= tolerance for other in candidates)

    actual_files = actual['files']
    file_recall = len(flagged['files'] & actual_files) / len(actual_files) if actual_files else 0.0

    # Recall side: every actual changed line in every patched file.
    total_actual_lines = 0
    actual_lines_matched = 0
    for file in actual_files:
        actual_lines = actual['lines'].get(file, set())
        flagged_lines_in_file = flagged['lines'].get(file, set())
        total_actual_lines += len(actual_lines)
        actual_lines_matched += sum(
            1 for line_no in actual_lines if _has_neighbour(line_no, flagged_lines_in_file)
        )

    # Precision side: every flagged line in every flagged file, patched or not.
    total_flagged_lines = 0
    flagged_lines_matched = 0
    for file, flagged_lines_in_file in flagged['lines'].items():
        actual_lines = actual['lines'].get(file, set()) if file in actual_files else set()
        total_flagged_lines += len(flagged_lines_in_file)
        flagged_lines_matched += sum(
            1 for line_no in flagged_lines_in_file if _has_neighbour(line_no, actual_lines)
        )

    line_recall = actual_lines_matched / total_actual_lines if total_actual_lines > 0 else 0.0
    line_precision = flagged_lines_matched / total_flagged_lines if total_flagged_lines > 0 else 0.0

    return {
        'file_recall': file_recall,
        'line_recall': line_recall,
        'line_precision': line_precision
    }

In [None]:
# CVE Dataset Loading Functions

def load_cve_dataset(json_path: str = "cve_dataset.json") -> list[dict]:
    """Load the curated CVE dataset from a JSON file.

    Args:
        json_path: Path to the dataset file. Later cells read the keys 'cve_id',
            'cwe_id', 'cwe_name', 'cvss_score', 'severity' and 'project' from each record.

    Returns:
        The parsed JSON document (expected to be a list of CVE dicts).
    """
    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on Windows).
    # "utf-8-sig" also tolerates a BOM written by some editors, which json.load rejects.
    with open(json_path, encoding="utf-8-sig") as f:
        return json.load(f)

def load_cve_patch(cve_id: str, patches_dir: str = "cve_patches") -> str:
    """Load the fixing patch for one CVE from ``<patches_dir>/<cve_id>.patch``.

    Decoded as UTF-8 with replacement characters for invalid bytes: patches can
    carry non-UTF-8 content (legacy-encoded source, binary fragments), and the text is
    only fed to the reviewer/judge and the line parser, so one bad byte should not
    abort the whole CVE run.

    Raises:
        FileNotFoundError: if no patch file exists for ``cve_id``.
    """
    patch_path = Path(patches_dir) / f"{cve_id}.patch"
    return patch_path.read_text(encoding="utf-8", errors="replace")

# Load the curated CVE dataset once and summarise its CWE coverage.
# `cve_dataset` is reused by the benchmark cells below.
from collections import Counter

cve_dataset = load_cve_dataset()
print(f"✓ Loaded {len(cve_dataset)} CVEs")
print("\nCWE Coverage:")
cwe_counts = Counter(cve['cwe_name'] for cve in cve_dataset)

for cwe, count in sorted(cwe_counts.items()):
    print(f"  {cwe}: {count}")

In [None]:
# Phase 5: Helper Functions

def _parse_markdown_tables(report: str) -> list[tuple[list[str], list[list[str]]]]:
    """Split `report` into markdown tables as (lower-cased header cells, data rows of cells)."""
    tables = []
    current = None
    for raw_line in report.split('\n'):
        row = raw_line.strip()
        if not row.startswith('|'):
            current = None  # any non-table line closes the current table
            continue
        if set(row) <= set('|-: '):
            continue  # alignment / separator row
        cells = [cell.strip() for cell in row.strip('|').split('|')]
        if current is None:
            current = ([cell.lower() for cell in cells], [])
            tables.append(current)
        else:
            current[1].append(cells)
    return tables


def _find_column(header: list[str], name: str):
    """Index of the first (lower-cased) header cell containing `name`, or None."""
    for idx, cell in enumerate(header):
        if name in cell:
            return idx
    return None


def check_security_agent_flagged(report: str) -> bool:
    """Check if the Security Agent is credited with at least one finding.

    When the report has a table with a "Found By" column, return True only if some
    row's Found By cell mentions "security". The previous test -- "Security" and
    "Found By" both appearing anywhere in the text -- is true for almost any report
    with a findings table, so it is kept only as a fallback for reports without such
    a column.
    """
    saw_attribution_column = False
    for header, rows in _parse_markdown_tables(report):
        col = _find_column(header, 'found by')
        if col is None:
            continue
        saw_attribution_column = True
        if any(col < len(row) and 'security' in row[col].lower() for row in rows):
            return True
    if saw_attribution_column:
        return False
    # Legacy fallback for reports whose format has no "Found By" column.
    return "Security" in report and "Found By" in report


def extract_max_severity(report: str) -> int:
    """Extract the highest severity (1-10 scale) among Security-agent findings; 0 if none.

    Uses the table's "Severity" and "Found By" columns when present. The old regex
    took the first numeric cell of a Security row, which is the Lines column whenever
    a finding spans a single line (e.g. returning 82 instead of 7). It is kept only as
    a fallback for reports without those columns.
    """
    import re
    best = 0
    saw_columns = False
    for header, rows in _parse_markdown_tables(report):
        sev_col = _find_column(header, 'severity')
        by_col = _find_column(header, 'found by')
        if sev_col is None or by_col is None:
            continue
        saw_columns = True
        for row in rows:
            if max(sev_col, by_col) >= len(row) or 'security' not in row[by_col].lower():
                continue
            number = re.search(r'\d+', row[sev_col])  # tolerates "7", "7/10", "**7**"
            if number:
                best = max(best, int(number.group()))
    if saw_columns:
        return best
    # Legacy fallback: parse markdown rows for a numeric cell followed by a "Security" cell.
    severities = re.findall(r'\|\s*(\d+)\s*\|.*\|\s*Security\s*\|', report, re.IGNORECASE)
    return max(map(int, severities)) if severities else 0

print("✓ Helper functions loaded")

In [None]:
# CVE-Specific Evaluation

class LLMRelevance(BaseModel):
    """LLM's assessment of how relevant the review findings are to the actual fix."""
    # NOTE(review): this re-declares the LLMRelevance model from the BugsInPy section;
    # from this cell on, the name refers to this class.
    # NOTE: the class docstring and Field descriptions are pydantic schema text; since this
    # model is passed as the judge Agent's output_type, that text is presumably part of the
    # structured-output schema the model sees -- edit it as prompt text, not as a comment.
    # NOTE(review): the 0.0-1.0 range is only described, not enforced (no ge/le bounds).
    relevance_score: float = Field(description="0.0-1.0: How well the review findings align with the actual fix")
    # Free-text justification produced by the judge.
    explanation: str = Field(description="Brief explanation of the score")

async def evaluate_hybrid_cve(report: str, patch: str, 
                               cve_id: str, cwe_id: str, cwe_name: str,
                               cvss_score: float, severity: str) -> dict:
    """
    Hybrid evaluation for CVEs: Location metrics + LLM relevance + Security detection.

    Stage 1: location overlap between the fixing patch and the report's findings table.
    Stage 2: if the report names at least one patched file, an LLM judge scores semantic
             relevance with the CVE id, CWE and CVSS as context. The score is clamped to
             [0.0, 1.0] because the output schema only describes that range.
    Stage 3: heuristic checks of whether the Security agent flagged anything, and
             whether its highest reported severity is within 3 points of the CVSS score.

    Args:
        report: Markdown review report produced for the reversed patch.
        patch: The original (fixing) unified diff.
        cve_id, cwe_id, cwe_name, cvss_score, severity: CVE metadata shown to the judge.

    Returns:
        dict with file_recall, line_precision, line_recall, llm_relevance,
        llm_explanation ('' when the judge was skipped), composite_score (mean of
        line_recall and llm_relevance), severity_appropriate and security_finding_present.
    """
    
    # Stage 1: Automated location metrics
    actual_locations = parse_changed_locations(patch)
    flagged_locations = parse_flagged_locations(report)
    location_metrics = calculate_location_metrics(actual_locations, flagged_locations)
    
    # Stage 2: LLM relevance with CVE context
    llm_relevance = 0.0
    llm_explanation = ""
    if location_metrics['file_recall'] > 0:
        llm_judge_cve_instructions = f"""You are evaluating code review findings against a real CVE.

CRITICAL: Output ONLY valid JSON matching the specified schema. Do NOT wrap your response in markdown code fences or backticks.

Given:
1. CVE ID: {cve_id}
2. CWE Type: {cwe_name} ({cwe_id})
3. CVSS Score: {cvss_score} ({severity})
4. ACTUAL FIX PATCH: The changes that fixed the vulnerability
5. CODE REVIEW REPORT: What our system found

Rate the relevance (0.0 to 1.0) of the review findings:
- 1.0: Findings directly identify the CVE vulnerability type
- 0.7-0.9: Findings flag related security issues that would lead to discovery
- 0.4-0.6: Findings flag the general area but miss specific vulnerability
- 0.1-0.3: Findings are tangentially related
- 0.0: No relevant findings

Special attention:
- Did the Security Agent flag this as a security issue?
- Is the severity appropriate for the CVE?"""

        llm_judge = Agent(
            name="CVE Relevance Judge",
            instructions=llm_judge_cve_instructions,
            model=grok_4_1_fast,
            model_settings=ModelSettings(
                temperature=0.6,
                extra_args={"reasoning": {"enabled": True}}
            ),
            output_type=LLMRelevance
        )
        
        prompt = f"""
ACTUAL FIX PATCH:
{patch}

CODE REVIEW REPORT:
{report}

Rate the semantic relevance of the review findings to this CVE.
"""
        with trace("CVE LLM Judge"):
            result = await Runner.run(llm_judge, prompt)
            judgement = result.final_output
            # Clamp: an out-of-range score (e.g. 8 meaning "8/10") would otherwise
            # push composite_score above 1 and pass any threshold.
            llm_relevance = min(1.0, max(0.0, float(judgement.relevance_score)))
            llm_explanation = judgement.explanation
    
    # Stage 3: Security detection check
    security_flagged = check_security_agent_flagged(report)
    severity_from_report = extract_max_severity(report)
    severity_appropriate = abs(severity_from_report - cvss_score) <= 3 if severity_from_report > 0 else False
    
    # Composite score: average of line recall and LLM relevance
    composite_score = (location_metrics['line_recall'] + llm_relevance) / 2
    
    return {
        'file_recall': location_metrics['file_recall'],
        'line_precision': location_metrics['line_precision'],
        'line_recall': location_metrics['line_recall'],
        'llm_relevance': llm_relevance,
        'llm_explanation': llm_explanation,
        'composite_score': composite_score,
        'severity_appropriate': severity_appropriate,
        'security_finding_present': security_flagged
    }

In [None]:
# Phase 4: CVE Testing Framework

def _print_cve_summary(results: list[dict]) -> None:
    """Print the overall pass rate, the Security-agent detection rate and per-CWE pass counts.

    Safe when every CVE errored or the dataset was empty (previously a ZeroDivisionError).
    """
    print(f"\n\n{'='*60}")
    print("CVE BENCHMARK SUMMARY")
    print('='*60)

    valid_results = [r for r in results if 'error' not in r]
    error_count = len(results) - len(valid_results)
    if error_count:
        print(f"Errored CVEs: {error_count}/{len(results)}")
    if not valid_results:
        print("No CVEs were evaluated successfully.")
        return

    total = len(valid_results)
    passed = sum(r['passed'] for r in valid_results)
    security_detected = sum(r['security_flagged'] for r in valid_results)
    print(f"Overall Pass Rate: {passed}/{total} ({passed/total:.0%})")
    print(f"Security Agent Detection: {security_detected}/{total} ({security_detected/total:.0%})")

    # By CWE type
    print("\n📋 Results by CWE Type:")
    cwe_results = {}
    for r in valid_results:
        stats = cwe_results.setdefault(r['cwe_name'], {'total': 0, 'passed': 0})
        stats['total'] += 1
        stats['passed'] += r['passed']

    for cwe, stats in sorted(cwe_results.items()):
        print(f"  {cwe}: {stats['passed']}/{stats['total']} passed")


async def test_cve_benchmark(cve_dataset: list[dict], output_dir: str = "user-data",
                             pass_threshold: float = 0.60) -> list[dict]:
    """
    Test code review system on CVE dataset.

    For each CVE: load its fixing patch, reverse it so the review sees the
    vulnerability being introduced, run review_code, and score the report with
    evaluate_hybrid_cve. A failure on one CVE is printed and recorded
    ({'cve_id', 'error', 'passed': False}) so the rest of the run continues.

    Results are written to ``<output_dir>/cve_benchmark_<timestamp>.json`` BEFORE the
    summary is printed, so a long run is never lost to a reporting error.

    Args:
        cve_dataset: CVE records with 'cve_id', 'cwe_id', 'cwe_name', 'cvss_score',
            'severity' and 'project' keys.
        output_dir: Directory for the results JSON (created if missing).
        pass_threshold: Minimum composite_score for a CVE to count as passed.

    Returns:
        One result dict per CVE, in dataset order.
    """
    import traceback

    results = []
    
    for cve in cve_dataset:
        print(f"\n{'='*60}")
        print(f"TESTING: {cve['cve_id']} - {cve['cwe_name']}")
        print(f"CVSS: {cve['cvss_score']} | Project: {cve['project']}")
        print('='*60)
        
        try:
            patch = load_cve_patch(cve['cve_id'])
            
            # Reverse diff (show vulnerability introduction)
            reversed_diff = reverse_diff(patch)
            report = await review_code(reversed_diff, save_output=False)
            
            # Hybrid evaluation with CVE context
            eval_result = await evaluate_hybrid_cve(
                report, patch,
                cve['cve_id'], cve['cwe_id'], cve['cwe_name'],
                cve['cvss_score'], cve['severity']
            )
            
            result = {
                'cve_id': cve['cve_id'],
                'cwe_id': cve['cwe_id'],
                'cwe_name': cve['cwe_name'],
                'cvss_score': cve['cvss_score'],
                'file_recall': eval_result['file_recall'],
                'line_precision': eval_result['line_precision'],
                'line_recall': eval_result['line_recall'],
                'llm_relevance': eval_result['llm_relevance'],
                'llm_explanation': eval_result.get('llm_explanation', ''),
                'composite_score': eval_result['composite_score'],
                'security_flagged': eval_result['security_finding_present'],
                'severity_appropriate': eval_result['severity_appropriate'],
                'passed': eval_result['composite_score'] >= pass_threshold
            }
            results.append(result)
            
            print(f"\n📍 Location: FileRec={result['file_recall']:.0%}, LineRec={result['line_recall']:.0%}")
            print(f"🤖 LLM Relevance: {result['llm_relevance']:.0%}")
            print(f"🛡️  Security Agent: {'✓ FLAGGED' if result['security_flagged'] else '✗ MISSED'}")
            print(f"📊 Composite: {result['composite_score']:.0%} - {'✓ PASSED' if result['passed'] else '✗ FAILED'}")
            
        except Exception as e:
            print(f"\n❌ ERROR: {e}")
            traceback.print_exc()
            results.append({
                'cve_id': cve['cve_id'],
                'error': str(e),
                'passed': False
            })
    
    # Save first: the benchmark is expensive and the summary must not be able to lose it.
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_path = os.path.join(output_dir, f"cve_benchmark_{timestamp}.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    _print_cve_summary(results)
    print(f"\n💾 Results saved to {results_path}")
    
    return results

print("✓ CVE testing framework loaded")

In [None]:
# # Phase 6: Run CVE Benchmark

# # Load CVE dataset
# cve_dataset = load_cve_dataset()
# print(f"Loaded {len(cve_dataset)} CVEs\n")

# # Run benchmark on all CVEs
# results = await test_cve_benchmark(cve_dataset)

In [None]:
# Test the two problematic CVEs that got stuck

problematic_cves = ["CVE-2024-53908", "CVE-2024-23346"]  # SQL Injection and Code Injection
problematic_dataset = [cve for cve in cve_dataset if cve['cve_id'] in problematic_cves]

print(f"Testing {len(problematic_dataset)} problematic CVEs with fixed schema (max 20 lines per finding)")
print("="*60 + "\n")

# Run quick test on just these two
results_test = await test_cve_benchmark(problematic_dataset)

In [None]:
# CVE Benchmark Comparison Summary
# NOTE(review): every figure below is hard-coded from earlier runs; the last configuration
# only points the reader at a re-run. The cell directly above benchmarks just two CVEs --
# the full 17-CVE run is the commented-out "Phase 6" cell.

print("\n" + "="*60)
print("CVE BENCHMARK COMPARISON")
print("="*60)
print("\nConfiguration changes:")
print("  - Security KB: 13 patterns → 43 patterns (OWASP Top 10 2021)")
print("  - RAG retrieval: n_results=5 → n_results=15")
print("\n" + "-"*60)
print("RESULTS COMPARISON:")
print("-"*60)
print("\n📊 No RAG (Baseline):")
print("  Pass Rate: 16/17 (94%)")
print("  Security Detection: 16/17 (94%)")
print("\n🔍 RAG with 13 patterns (n_results=5):")
print("  Pass Rate: 15/17 (88%)")
print("  Security Detection: 16/17 (94%)")
print("\n🎯 RAG with 43 patterns (n_results=15):")
print("  Run cell above to see results...")
print("\n" + "="*60)