# üîç AML Investigation Agent (Gemma + GRPO)

An autonomous AI Agent for investigating financial transaction graphs to identify money laundering patterns using FunctionGemma function-calling format and GRPO reinforcement learning.

## Architecture Overview

| Component | Technology | Description |
|-----------|------------|-------------|
| **Base Model** | Gemma (Gemma-2-2B-IT) | XML-style function calling format |
| **Fine-Tuning** | Unsloth + LoRA/QLoRA | 4-bit quantization, 2x faster training |
| **Data Processing** | Polars | High-performance dataframes |
| **Graph Analysis** | NetworkX | Transaction network traversal |
| **Agent Framework** | LangGraph + MemorySaver | Stateful exploration with checkpointing |
| **RL Training** | TRL GRPOTrainer | Group Relative Policy Optimization |
| **Observability** | MLflow Tracing | Full agent trace logging |
| **Evaluation** | Gemini LLM-as-Judge | Strategy quality scoring |

## Evaluation Flow

This notebook implements a **three-stage evaluation** to measure training impact:

| Stage | Model State | Purpose |
|-------|-------------|---------|
| **1. Baseline** | Pre-trained Gemma-2-2B-IT | Measure zero-shot performance |
| **2. Post-SFT** | After Supervised Fine-Tuning | Measure SFT improvement |
| **3. Post-GRPO** | After RL Training | Measure GRPO improvement |

## Gemma Tool Format (with Internal Reasoning)

```xml
<thinking>
[Internal reasoning about investigation strategy]
</thinking>
<start_function_call>call:function_name{param: value}</start_function_call>
<start_function_output>call:function_name{result: value}</start_function_output>
```

## Investigation Tools

| Tool | Description |
|------|-------------|
| `get_account_summary` | Get account metadata and risk assessment |
| `get_recent_transactions` | Get top-5 recent transaction flows |
| `check_sanctions_list` | Verify against OFAC watchlist |
| `submit_sar` | Terminal action - Submit Suspicious Activity Report |

## Win Condition
`submit_sar` on an entity that is **both sanctioned AND reachable via a laundering path** from the seed account.


## 1. Setup & Configuration


In [None]:
# ============================================================================
# AML INVESTIGATION AGENT - Setup & Dependencies
# ============================================================================

import os
import sys
import json
import random
import time
import re
import uuid
from datetime import datetime
from dataclasses import dataclass, field
from typing import Dict, List, Any, Tuple, Optional, Annotated, Literal, TypedDict
import operator
from pathlib import Path

# Numerical & Data Processing
import numpy as np
import polars as pl
import pandas as pd
import networkx as nx

# ML & Deep Learning
import torch

# Environment
from dotenv import load_dotenv
load_dotenv()

# ============================================================================
# CONFIGURATION
# ============================================================================

# Model Configuration
MODEL_NAME = "google/gemma-2-2b-it"  # FunctionGemma compatible base
GEMINI_MODEL = "gemini-2.0-flash"    # LLM-as-Judge

# Agent Configuration
MAX_STEPS = 50                        # Max steps per investigation
MAX_HISTORY_TURNS = 6                 # Conversation history limit

# Training Configuration
SFT_EPOCHS = 3
SFT_LEARNING_RATE = 2e-4
GRPO_EPOCHS = 1
GRPO_LEARNING_RATE = 5e-6

# LoRA Configuration (per design doc Section 3.3)
LORA_R = 32                           # Higher rank for complex reasoning
LORA_ALPHA = 64
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj"
]

# Evaluation Configuration
EVAL_EPISODES = 10                    # Episodes per evaluation stage

# Random Seed
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Paths (relative to notebook location)
# Notebook is at: notebooks/agents/aml_investigation_agent_v2.ipynb
# Project root is 2 levels up
NOTEBOOK_DIR = Path(".").resolve()
PROJECT_ROOT = NOTEBOOK_DIR.parent.parent  # Go up from notebooks/agents to project root
DATA_DIR = PROJECT_ROOT / "data" / "raw"
MODELS_DIR = PROJECT_ROOT / "models"
OUTPUT_DIR = PROJECT_ROOT / "outputs"

# Ensure directories exist
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Dataset Selection
DATASET_SIZE = "Small"   # Options: "Small", "Medium", "Large"
DATASET_PREFIX = "LI"    # Options: "LI" (Low Illicit), "HI" (High Illicit)

# ============================================================================
# RESULTS STORAGE - For comparison across training stages
# ============================================================================

evaluation_results = {
    "baseline": None,
    "post_sft": None,
    "post_grpo": None,
}

print("=" * 60)
print("üîç AML INVESTIGATION AGENT - Configuration")
print("=" * 60)
print(f"  Model:           {MODEL_NAME}")
print(f"  Judge:           {GEMINI_MODEL}")
print(f"  Dataset:         {DATASET_PREFIX}-{DATASET_SIZE}")
print(f"  Max Steps:       {MAX_STEPS}")
print(f"  Eval Episodes:   {EVAL_EPISODES}")
print(f"  LoRA Rank:       {LORA_R}")
print(f"  LoRA Alpha:      {LORA_ALPHA}")
print(f"  Random Seed:     {RANDOM_SEED}")
print(f"  Project Root:    {PROJECT_ROOT}")
print(f"  Models Dir:      {MODELS_DIR}")
print(f"  GPU Available:   {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  GPU Device:      {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory:      {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print("=" * 60)


## 2. Data Loading with Polars

Download the IBM AML dataset from Kaggle (if not already present) and load using Polars for high-performance data manipulation.

**Dataset**: [IBM Transactions for Anti Money Laundering (AML)](https://www.kaggle.com/datasets/ealtman2019/ibm-transactions-for-anti-money-laundering-aml/data)


In [None]:
# ============================================================================
# DOWNLOAD DATASET FROM KAGGLE - IBM AML Transactions Dataset
# Dataset: https://www.kaggle.com/datasets/ealtman2019/ibm-transactions-for-anti-money-laundering-aml
# ============================================================================

import zipfile
import shutil

# Kaggle dataset identifier
KAGGLE_DATASET = "ealtman2019/ibm-transactions-for-anti-money-laundering-aml"

# Check if required files exist
trans_file = DATA_DIR / f"{DATASET_PREFIX}-{DATASET_SIZE}_Trans.csv"
accounts_file = DATA_DIR / f"{DATASET_PREFIX}-{DATASET_SIZE}_accounts.csv"
patterns_file = DATA_DIR / f"{DATASET_PREFIX}-{DATASET_SIZE}_Patterns.txt"

files_exist = trans_file.exists() and accounts_file.exists() and patterns_file.exists()

if files_exist:
    print(f"‚úì Dataset already exists at {DATA_DIR}")
    print(f"  - Transactions: {trans_file.name}")
    print(f"  - Accounts: {accounts_file.name}")
    print(f"  - Patterns: {patterns_file.name}")
else:
    print(f"üì• Dataset not found. Downloading from Kaggle...")
    print(f"   Dataset: {KAGGLE_DATASET}")
    
    # Ensure data directory exists
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    
    try:
        # Import and authenticate Kaggle API
        from kaggle.api.kaggle_api_extended import KaggleApi
        
        api = KaggleApi()
        api.authenticate()
        
        print(f"   ‚úì Kaggle API authenticated")
        
        # Download dataset
        print(f"   Downloading dataset to {DATA_DIR}...")
        api.dataset_download_files(
            dataset=KAGGLE_DATASET,
            path=str(DATA_DIR),
            unzip=True,
            quiet=False
        )
        
        print(f"   ‚úì Download complete!")
        
        # List downloaded files
        print(f"\n   Downloaded files:")
        for f in sorted(DATA_DIR.glob("*")):
            size_mb = f.stat().st_size / (1024 * 1024)
            print(f"     - {f.name} ({size_mb:.1f} MB)")
        
    except ImportError:
        print("   ‚ùå Kaggle package not installed.")
        print("   Run: pip install kaggle")
        print("   Then set up ~/.kaggle/kaggle.json with your API credentials")
        raise ImportError("Please install kaggle package: pip install kaggle")
        
    except Exception as e:
        print(f"   ‚ùå Error downloading dataset: {e}")
        print("\n   Manual download instructions:")
        print(f"   1. Visit: https://www.kaggle.com/datasets/{KAGGLE_DATASET}")
        print(f"   2. Download and extract to: {DATA_DIR}")
        raise

# Verify files exist after download
assert trans_file.exists(), f"Transaction file not found: {trans_file}"
assert accounts_file.exists(), f"Accounts file not found: {accounts_file}"
assert patterns_file.exists(), f"Patterns file not found: {patterns_file}"

print(f"\n{'=' * 60}")
print(f"‚úì DATASET READY: {DATASET_PREFIX}-{DATASET_SIZE}")
print(f"{'=' * 60}")


In [None]:
# ============================================================================
# DATA LOADING - IBM AML Dataset with Polars
# ============================================================================

print("üìä Loading IBM AML Dataset with Polars...")

# Load transactions
raw_trans_pl = pl.read_csv(
    trans_file,
    new_columns=[
        'timestamp', 'from_bank', 'from_account', 'to_bank', 'to_account',
        'amount_received', 'receiving_currency', 'amount_paid',
        'payment_currency', 'payment_format', 'is_laundering'
    ],
    skip_rows=1
)
print(f"‚úì Loaded {len(raw_trans_pl):,} transactions")

# Load accounts
raw_accounts_pl = pl.read_csv(accounts_file)
print(f"‚úì Loaded {len(raw_accounts_pl):,} accounts")

# Process transactions - Create unique account IDs
transactions_pl = raw_trans_pl.with_columns([
    (pl.col('from_bank').cast(pl.Utf8) + '-' + pl.col('from_account').cast(pl.Utf8)).alias('from_account_id'),
    (pl.col('to_bank').cast(pl.Utf8) + '-' + pl.col('to_account').cast(pl.Utf8)).alias('to_account_id'),
    (pl.lit('TXN-') + pl.arange(0, pl.len()).cast(pl.Utf8)).alias('transaction_id'),
    pl.col('is_laundering').cast(pl.Int32),
]).select([
    'transaction_id',
    pl.col('from_account_id').alias('from_account'),
    pl.col('to_account_id').alias('to_account'),
    pl.col('amount_received').alias('amount'),
    pl.col('receiving_currency').alias('currency'),
    'timestamp', 'is_laundering', 'payment_format',
])

# Identify laundering destinations for sanctioned marking
laundering_dests = set(
    transactions_pl.filter(pl.col('is_laundering') == 1)['to_account'].unique().to_list()
)

# Process accounts - Add risk scores and sanctioned flags
accounts_pl = raw_accounts_pl.rename({
    'Bank Name': 'bank_name', 'Bank ID': 'bank_id',
    'Account Number': 'account_number', 'Entity ID': 'entity_id', 'Entity Name': 'entity_name'
}).with_columns([
    (pl.col('bank_id').cast(pl.Utf8) + '-' + pl.col('account_number').cast(pl.Utf8)).alias('account_id'),
    pl.when(pl.col('entity_name').str.contains('Corporation')).then(pl.lit('Corporate'))
        .when(pl.col('entity_name').str.contains('Partnership')).then(pl.lit('Partnership'))
        .when(pl.col('entity_name').str.contains('Sole Proprietorship')).then(pl.lit('Individual'))
        .otherwise(pl.lit('Unknown')).alias('account_type'),
])

# Add sanctioned flag (30% of laundering destinations) and risk scores
account_ids = accounts_pl['account_id'].to_list()
accounts_pl = accounts_pl.with_columns([
    pl.Series('is_sanctioned', [acc in laundering_dests and random.random() < 0.3 for acc in account_ids]),
    pl.Series('risk_score', [round(random.uniform(0.1, 0.9), 2) for _ in range(len(account_ids))]),
]).select(['account_id', 'bank_id', 'bank_name', 'entity_id', 'entity_name', 'account_type', 'risk_score', 'is_sanctioned'])

# Convert to pandas for NetworkX
transactions_df = transactions_pl.to_pandas()
accounts_df = accounts_pl.to_pandas()

# Summary
n_accounts, n_transactions = len(accounts_df), len(transactions_df)
n_laundering = int(transactions_df['is_laundering'].sum())
n_sanctioned = int(accounts_df['is_sanctioned'].sum())

print(f"\n{'=' * 60}")
print(f"üìä DATASET SUMMARY: {DATASET_PREFIX}-{DATASET_SIZE}")
print(f"{'=' * 60}")
print(f"  Accounts:             {n_accounts:>12,}")
print(f"  Transactions:         {n_transactions:>12,}")
print(f"  Laundering Txns:      {n_laundering:>12,} ({n_laundering/n_transactions*100:.2f}%)")
print(f"  Sanctioned Accounts:  {n_sanctioned:>12,}")
print(f"{'=' * 60}")


In [None]:
# ============================================================================
# PARSE LAUNDERING PATTERNS - Extract Pattern Seeds for Training
# ============================================================================

@dataclass
class LaunderingPattern:
    """Represents a single money laundering pattern from the dataset."""
    pattern_type: str
    pattern_info: str
    transactions: List[dict]
    accounts_involved: set
    
    @property
    def seed_account(self) -> str:
        return self.transactions[0].get('from_account', '') if self.transactions else ''
    
    @property
    def terminal_account(self) -> str:
        return self.transactions[-1].get('to_account', '') if self.transactions else ''
    
    @property
    def total_amount(self) -> float:
        return sum(t.get('amount', 0) for t in self.transactions)
    
    @property
    def hop_count(self) -> int:
        return len(self.transactions)


def parse_patterns_file(filepath: Path) -> List[LaunderingPattern]:
    """Parse patterns file to extract laundering patterns."""
    patterns = []
    current_pattern = None
    
    with open(filepath, 'r') as f:
        for line in f:
            line = line.strip()
            
            if line.startswith('BEGIN LAUNDERING ATTEMPT'):
                match = re.match(r'BEGIN LAUNDERING ATTEMPT - (\w+(?:-\w+)?):?\s*(.*)', line)
                if match:
                    current_pattern = LaunderingPattern(
                        pattern_type=match.group(1),
                        pattern_info=match.group(2).strip() if match.group(2) else "",
                        transactions=[], accounts_involved=set()
                    )
            
            elif line.startswith('END LAUNDERING ATTEMPT'):
                if current_pattern and current_pattern.transactions:
                    patterns.append(current_pattern)
                current_pattern = None
            
            elif current_pattern and line and not line.startswith('BEGIN') and not line.startswith('END'):
                parts = line.split(',')
                if len(parts) >= 7:
                    try:
                        from_account = f"{parts[1].strip()}-{parts[2].strip()}"
                        to_account = f"{parts[3].strip()}-{parts[4].strip()}"
                        amount = float(parts[5].strip())
                        
                        current_pattern.transactions.append({
                            'timestamp': parts[0].strip(),
                            'from_account': from_account, 'to_account': to_account,
                            'amount': amount, 'currency': parts[6].strip(),
                        })
                        current_pattern.accounts_involved.add(from_account)
                        current_pattern.accounts_involved.add(to_account)
                    except (ValueError, IndexError):
                        pass
    
    return patterns


# Parse patterns
laundering_patterns = parse_patterns_file(patterns_file)

# Statistics
pattern_types = {}
for p in laundering_patterns:
    pattern_types[p.pattern_type] = pattern_types.get(p.pattern_type, 0) + 1

print(f"\n{'=' * 60}")
print(f"üîó LAUNDERING PATTERNS PARSED")
print(f"{'=' * 60}")
print(f"  Total Patterns: {len(laundering_patterns):,}")
for ptype, count in sorted(pattern_types.items(), key=lambda x: -x[1]):
    print(f"    - {ptype:<20} {count:>6,}")
print(f"{'=' * 60}")


## 3. Financial Environment with MLflow-Instrumented Tools

Build the transaction graph with NetworkX and create the `FinancialEnvironment` class with MLflow-traced tool functions.


In [None]:
# ============================================================================
# FINANCIAL ENVIRONMENT - NetworkX Graph with MLflow-Traced Tools
# ============================================================================

import mlflow

mlflow.set_experiment("AML_Investigation_Agent_v2")
print("‚úì MLflow experiment: AML_Investigation_Agent_v2")


@dataclass 
class FinancialEnvironment:
    """Financial investigation environment with path-validated SAR evaluation."""
    graph: nx.DiGraph = field(default_factory=nx.DiGraph)
    accounts: Dict[str, dict] = field(default_factory=dict)
    laundering_targets: List[str] = field(default_factory=list)
    all_sanctioned: set = field(default_factory=set)
    laundering_destinations: set = field(default_factory=set)
    transitive_illicit: set = field(default_factory=set)
    current_start_account: str = ""
    
    @classmethod
    def from_dataframes(cls, transactions_df: pd.DataFrame, accounts_df: pd.DataFrame) -> 'FinancialEnvironment':
        env = cls()
        
        for _, row in accounts_df.iterrows():
            env.accounts[row['account_id']] = row.to_dict()
            env.graph.add_node(row['account_id'], **row.to_dict())
        
        for _, row in transactions_df.iterrows():
            env.graph.add_edge(
                row['from_account'], row['to_account'],
                transaction_id=row['transaction_id'], amount=row['amount'],
                currency=row.get('currency', 'USD'), timestamp=row['timestamp'],
                is_laundering=row['is_laundering']
            )
        
        env.all_sanctioned = set(accounts_df[accounts_df['is_sanctioned']]['account_id'])
        laundering_txns = transactions_df[transactions_df['is_laundering'] == 1]
        env.laundering_destinations = set(laundering_txns['to_account'].unique())
        env._compute_transitive_illicit()
        env.laundering_targets = list(env.all_sanctioned & env.laundering_destinations)
        
        return env
    
    def _compute_transitive_illicit(self):
        self.transitive_illicit = set()
        laundering_sources = set()
        for u, v, data in self.graph.edges(data=True):
            if data.get('is_laundering', 0) == 1:
                laundering_sources.add(u)
                self.transitive_illicit.add(u)
                self.transitive_illicit.add(v)
        
        for source in laundering_sources:
            visited = {source}
            queue = [source]
            while queue:
                node = queue.pop(0)
                for neighbor in self.graph.successors(node):
                    edge_data = self.graph.edges[node, neighbor]
                    if edge_data.get('is_laundering', 0) == 1 and neighbor not in visited:
                        visited.add(neighbor)
                        self.transitive_illicit.add(neighbor)
                        queue.append(neighbor)
    
    def is_on_laundering_path(self, entity_id: str, max_depth: int = 10) -> bool:
        if not self.current_start_account:
            return False
        try:
            for path in nx.all_simple_paths(self.graph, self.current_start_account, entity_id, cutoff=max_depth):
                if all(self.graph.edges[path[i], path[i+1]].get('is_laundering', 0) == 1 for i in range(len(path)-1)):
                    return True
            return False
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return False
    
    @mlflow.trace(span_type="TOOL")
    def get_account_summary(self, account_id: str) -> dict:
        if account_id not in self.accounts:
            return {"error": f"Account {account_id} not found"}
        acc = self.accounts[account_id]
        return {
            "account_id": account_id, "account_type": acc.get('account_type', 'Unknown'),
            "entity_name": acc.get('entity_name', 'Unknown'), "bank_name": acc.get('bank_name', 'Unknown'),
            "risk_score": round(acc.get('risk_score', 0), 2), "is_sanctioned": acc.get('is_sanctioned', False),
            "transitive_illicit": account_id in self.transitive_illicit,
        }
    
    @mlflow.trace(span_type="TOOL")
    def get_recent_transactions(self, account_id: str, direction: str = "outgoing", limit: int = 5) -> List[dict]:
        if account_id not in self.graph:
            return []
        edges = list(self.graph.out_edges(account_id, data=True) if direction == "outgoing" 
                     else self.graph.in_edges(account_id, data=True))
        edges = sorted(edges, key=lambda e: e[2].get('amount', 0), reverse=True)[:limit]
        
        results = []
        for edge in edges:
            target = edge[1] if direction == "outgoing" else edge[0]
            results.append({
                "counterparty": target, "amount": round(edge[2].get('amount', 0), 2),
                "currency": edge[2].get('currency', 'USD'), "is_laundering": edge[2].get('is_laundering', 0),
                "high_risk_indicator": target in self.transitive_illicit,
            })
        return results
    
    @mlflow.trace(span_type="TOOL")
    def check_sanctions_list(self, entity_id: str) -> dict:
        is_sanctioned = entity_id in self.all_sanctioned
        return {"entity_id": entity_id, "on_sanctions_list": is_sanctioned, "list_type": "OFAC SDN" if is_sanctioned else None}
    
    @mlflow.trace(span_type="TOOL")
    def submit_sar(self, entity_id: str, reason: str) -> dict:
        is_sanctioned = entity_id in self.all_sanctioned
        is_primary = entity_id in self.laundering_targets
        on_path = self.is_on_laundering_path(entity_id)
        correct = is_primary or (is_sanctioned and on_path)
        
        if is_primary:
            eval_reason = "PRIMARY_TARGET: sanctioned + receives laundering directly"
        elif is_sanctioned and on_path:
            eval_reason = f"VALID: sanctioned + on laundering path from {self.current_start_account}"
        elif is_sanctioned:
            eval_reason = "INVALID: sanctioned but NOT on laundering path from start"
        else:
            eval_reason = "INVALID: entity is not sanctioned"
        
        return {
            "entity_id": entity_id, "reason": reason, "report_id": f"SAR-{uuid.uuid4().hex[:8].upper()}",
            "correct_identification": correct, "is_sanctioned": is_sanctioned,
            "is_primary_target": is_primary, "on_laundering_path": on_path, "evaluation_reason": eval_reason,
        }
    
    def reset_investigation(self, start_account: str):
        self.current_start_account = start_account


# Build environment
env = FinancialEnvironment.from_dataframes(transactions_df, accounts_df)

print(f"\n{'=' * 60}")
print(f"üè¶ FINANCIAL ENVIRONMENT BUILT")
print(f"{'=' * 60}")
print(f"  Graph Nodes:          {env.graph.number_of_nodes():>12,}")
print(f"  Graph Edges:          {env.graph.number_of_edges():>12,}")
print(f"  Sanctioned Accounts:  {len(env.all_sanctioned):>12,}")
print(f"  Primary Targets:      {len(env.laundering_targets):>12,}")
print(f"  Transitive Illicit:   {len(env.transitive_illicit):>12,}")
print(f"{'=' * 60}")


## 4. FunctionGemma Tool Format and Parsing

Implement the FunctionGemma XML-style function calling format with internal reasoning (`<thinking>` tags) and robust parsing.


In [None]:
# ============================================================================
# FUNCTIONGEMMA TOOL FORMAT & PARSING (with Internal Reasoning)
# Ref: https://ai.google.dev/gemma/docs/functiongemma
# Ref: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/FunctionGemma_(270M).ipynb
# ============================================================================

TOOL_DECLARATIONS = """<start_function_declaration>
{"name": "get_account_summary", "description": "Get account metadata and risk assessment for an account ID",
 "parameters": {"type": "object", "properties": {"account_id": {"type": "string", "description": "The unique account identifier"}}, "required": ["account_id"]}}
</start_function_declaration>

<start_function_declaration>
{"name": "get_recent_transactions", "description": "Get the top-5 recent transactions by amount for an account",
 "parameters": {"type": "object", "properties": {"account_id": {"type": "string", "description": "The account to get transactions for"}, "direction": {"type": "string", "enum": ["outgoing", "incoming"]}}, "required": ["account_id"]}}
</start_function_declaration>

<start_function_declaration>
{"name": "check_sanctions_list", "description": "Check if an entity is on the OFAC sanctions list",
 "parameters": {"type": "object", "properties": {"entity_id": {"type": "string", "description": "The account/entity ID to check"}}, "required": ["entity_id"]}}
</start_function_declaration>

<start_function_declaration>
{"name": "submit_sar", "description": "Submit a Suspicious Activity Report - TERMINAL ACTION",
 "parameters": {"type": "object", "properties": {"entity_id": {"type": "string", "description": "The sanctioned entity to report"}, "reason": {"type": "string", "description": "Justification for the SAR"}}, "required": ["entity_id", "reason"]}}
</start_function_declaration>"""

# Investigation prompt with internal reasoning (FunctionGemma doesn't support thinking by default)
INVESTIGATION_PROMPT = """You are an expert AML investigator. Investigate financial transaction networks to identify money laundering.

STRATEGY:
1. START: Get account summary of seed account
2. EXPLORE: Get recent transactions to find money flows
3. FOLLOW: Investigate high-amount counterparties
4. VERIFY: Check sanctions list for suspicious entities
5. REPORT: Submit SAR only after confirming sanctioned status

{tool_declarations}

RESPONSE FORMAT:
You MUST first provide your internal reasoning in <thinking> tags, then make a function call.

<thinking>
[Your reasoning about what to do next and why]
</thinking>
<start_function_call>call:function_name{{param: value}}</start_function_call>
"""

# Tool mapping to environment functions
TOOL_MAPPING = {
    "get_account_summary": lambda args: env.get_account_summary(args.get("account_id", "")),
    "get_recent_transactions": lambda args: env.get_recent_transactions(args.get("account_id", ""), args.get("direction", "outgoing")),
    "check_sanctions_list": lambda args: env.check_sanctions_list(args.get("entity_id", "")),
    "submit_sar": lambda args: env.submit_sar(args.get("entity_id", ""), args.get("reason", "Suspicious activity")),
}

VALID_TOOLS = set(TOOL_MAPPING.keys())

# Tool name corrections for hallucination handling
TOOL_NAME_CORRECTIONS = {
    "get_account": "get_account_summary",
    "get_transactions": "get_recent_transactions",
    "check_sanctions": "check_sanctions_list",
    "submit_report": "submit_sar",
}

def harden_tool_call(call: dict) -> dict:
    """Correct common tool name hallucinations."""
    if not call:
        return call
    name = call.get("name", "").lower().strip()
    return {"name": TOOL_NAME_CORRECTIONS.get(name, name), "arguments": call.get("arguments", {})}

def extract_thinking(text: str) -> Optional[str]:
    """Extract internal reasoning from <thinking> tags."""
    match = re.search(r"<thinking>(.*?)</thinking>", text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None

def extract_function_call(text: str) -> Optional[dict]:
    """Extract FunctionGemma-style function call from model output."""
    _ = extract_thinking(text)  # Extract thinking for logging
    match = re.search(r"<start_function_call>call:(\w+)\{(.*?)\}</start_function_call>", text, re.DOTALL)
    if match:
        name = match.group(1)
        args = {}
        for arg_match in re.finditer(r"(\w+):\s*([^,}]+)", match.group(2)):
            args[arg_match.group(1)] = arg_match.group(2).strip().strip('"\'')
        return harden_tool_call({"name": name, "arguments": args})
    return None

def format_function_output(tool_name: str, result: Any) -> str:
    """Format tool result in FunctionGemma output format."""
    return f"<start_function_output>call:{tool_name}{{{json.dumps(result, default=str)}}}</start_function_output>"

def execute_tool_call(call: dict) -> Tuple[str, Any]:
    """Execute a tool call and return (tool_name, result)."""
    if not call:
        return ("error", {"error": "No valid tool call"})
    tool_name = call.get("name", "")
    if tool_name not in TOOL_MAPPING:
        return ("error", {"error": f"Unknown tool: {tool_name}"})
    try:
        return (tool_name, TOOL_MAPPING[tool_name](call.get("arguments", {})))
    except Exception as e:
        return ("error", {"error": str(e)})

print("‚úì FunctionGemma tool format configured (with <thinking> support)")
print(f"  Available tools: {list(VALID_TOOLS)}")


## 5. Load Base Model with Unsloth

Load FunctionGemma (Gemma-2-2B-IT) with Unsloth for optimized 4-bit inference.


In [None]:
# ============================================================================
# LOAD MODEL WITH UNSLOTH - Optimized 4-bit Loading
# ============================================================================

from unsloth import FastLanguageModel

print("üì• Loading model with Unsloth...")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=8192,
    dtype=None,
    load_in_4bit=True,
    device_map="cuda:0",  # Explicit GPU mapping to prevent CPU offload
)

FastLanguageModel.for_inference(model)

print(f"\n{'=' * 60}")
print(f"ü§ñ MODEL LOADED")
print(f"{'=' * 60}")
print(f"  Model:            {MODEL_NAME}")
print(f"  Parameters:       {model.num_parameters():,}")
print(f"  Max Seq Length:   8,192")
print(f"  Quantization:     4-bit")
print(f"  Device:           {next(model.parameters()).device}")
print(f"{'=' * 60}")


## 6. Agent State and Execution Logic

Define the `InvestigationState`, `InvestigationEpisode`, and agent execution functions with reward calculation.


In [None]:
# ============================================================================
# AGENT STATE AND EXECUTION LOGIC
# ============================================================================

class InvestigationState(TypedDict):
    """Complete state for the investigation agent."""
    start_account: str
    accounts_analyzed: Dict[str, dict]
    entities_checked: Dict[str, bool]
    risk_indicators: List[str]
    investigation_path: List[str]
    total_amount_traced: float
    current_strategy: str
    messages: List[dict]
    steps: List[dict]
    step_count: int
    terminated: bool
    success: bool
    final_result: dict


def create_initial_state(start_account: str) -> InvestigationState:
    return InvestigationState(
        start_account=start_account, accounts_analyzed={}, entities_checked={},
        risk_indicators=[], investigation_path=[], total_amount_traced=0.0,
        current_strategy="explore", messages=[], steps=[], step_count=0,
        terminated=False, success=False, final_result={},
    )


def build_prompt(state: InvestigationState) -> str:
    system_prompt = INVESTIGATION_PROMPT.format(tool_declarations=TOOL_DECLARATIONS)
    prompt = f"<start_of_turn>user\n{system_prompt}\n\nINVESTIGATION TARGET: {state['start_account']}\n\n"
    prompt += f"CURRENT STATUS:\n- Accounts analyzed: {len(state['accounts_analyzed'])}\n"
    prompt += f"- Path: {' ‚Üí '.join(state['investigation_path'][-5:]) if state['investigation_path'] else '(none)'}\n"
    prompt += f"- Amount traced: ${state['total_amount_traced']:,.2f}\n"
    if state['risk_indicators']:
        prompt += f"- Risk indicators: {'; '.join(state['risk_indicators'][-3:])}\n"
    prompt += "\nBegin investigation.<end_of_turn>\n"
    
    for msg in state['messages'][-MAX_HISTORY_TURNS*2:]:
        if msg['role'] == 'assistant':
            prompt += f"<start_of_turn>model\n{msg['content']}<end_of_turn>\n"
        elif msg['role'] == 'tool':
            prompt += f"<start_of_turn>user\n{msg['content']}<end_of_turn>\n"
    prompt += "<start_of_turn>model\n"
    return prompt


def update_state_from_result(state: InvestigationState, tool_name: str, args: dict, result: Any) -> InvestigationState:
    new_state = dict(state)
    new_state['accounts_analyzed'] = dict(state['accounts_analyzed'])
    new_state['entities_checked'] = dict(state['entities_checked'])
    new_state['risk_indicators'] = list(state['risk_indicators'])
    new_state['investigation_path'] = list(state['investigation_path'])
    
    if tool_name == "get_account_summary":
        acc_id = args.get("account_id", "")
        if isinstance(result, dict) and "error" not in result:
            new_state['accounts_analyzed'][acc_id] = result
            if acc_id not in new_state['investigation_path']:
                new_state['investigation_path'].append(acc_id)
            if result.get("risk_score", 0) > 0.7:
                new_state['risk_indicators'].append(f"HIGH_RISK:{acc_id}")
            if result.get("is_sanctioned"):
                new_state['risk_indicators'].append(f"SANCTIONED:{acc_id}")
            if result.get("transitive_illicit"):
                new_state['risk_indicators'].append(f"ILLICIT_PATH:{acc_id}")
    
    elif tool_name == "get_recent_transactions":
        if isinstance(result, list):
            for txn in result:
                new_state['total_amount_traced'] += txn.get("amount", 0)
                counterparty = txn.get("counterparty")
                if counterparty and counterparty not in new_state['investigation_path']:
                    new_state['investigation_path'].append(counterparty)
                if txn.get("high_risk_indicator"):
                    new_state['risk_indicators'].append(f"ILLICIT_TXN:{counterparty}")
    
    elif tool_name == "check_sanctions_list":
        entity_id = args.get("entity_id", "")
        is_sanctioned = result.get("on_sanctions_list", False) if isinstance(result, dict) else False
        new_state['entities_checked'][entity_id] = is_sanctioned
        if is_sanctioned:
            new_state['risk_indicators'].append(f"SANCTIONED:{entity_id}")
    
    return new_state


@dataclass
class InvestigationEpisode:
    """Records an investigation episode for evaluation and training."""
    start_account: str
    steps: List[dict] = field(default_factory=list)
    terminated: bool = False
    success: bool = False
    final_result: dict = field(default_factory=dict)
    total_reward: float = 0.0
    
    def add_step(self, tool_name: str, args: dict, result: Any, reward: float = 0.0):
        self.steps.append({"step": len(self.steps) + 1, "tool_name": tool_name, "arguments": args, "result": result, "reward": reward})
        self.total_reward += reward


def calculate_reward(tool_name: str, args: dict, result: Any, state: InvestigationState) -> float:
    """GRPO reward function based on design doc Section 4.2."""
    reward = -0.1  # Base step penalty
    
    if tool_name == "get_account_summary":
        if isinstance(result, dict) and result.get("transitive_illicit"):
            reward += 0.5  # R_Discovery
    elif tool_name == "get_recent_transactions":
        if any("HIGH_RISK" in r or "ILLICIT" in r for r in state['risk_indicators']):
            reward += 0.3  # R_Logic
        if isinstance(result, list):
            for txn in result:
                if txn.get("high_risk_indicator"):
                    reward += 0.2
    elif tool_name == "check_sanctions_list":
        if isinstance(result, dict) and result.get("on_sanctions_list"):
            reward += 0.5
    elif tool_name == "submit_sar":
        if isinstance(result, dict):
            if result.get("correct_identification"):
                reward += 2.0  # R_Outcome
            else:
                reward -= 1.0
    return reward


print("‚úì Agent state and execution logic defined")


In [None]:
# ============================================================================
# AGENT EXECUTION - Run Investigation Episode with Detailed Logging
# ============================================================================

def run_investigation(
    start_account: str, 
    model, 
    tokenizer, 
    max_steps: int = MAX_STEPS, 
    verbose: bool = False,
    show_thinking: bool = True,
    show_memory: bool = True,
    show_raw_response: bool = False
) -> InvestigationEpisode:
    """
    Run a complete investigation episode using the FunctionGemma agent.
    
    Args:
        verbose: Show step-by-step execution details
        show_thinking: Display agent's <thinking> reasoning
        show_memory: Display memory context (accounts analyzed, risk indicators)
        show_raw_response: Display raw model output (for debugging)
    """
    env.reset_investigation(start_account)
    state = create_initial_state(start_account)
    episode = InvestigationEpisode(start_account=start_account)
    
    if verbose:
        print(f"\n{'‚ïê' * 70}")
        print(f"üîç INVESTIGATION START")
        print(f"{'‚ïê' * 70}")
        print(f"  Target Account: {start_account}")
        print(f"  Max Steps: {max_steps}")
        print(f"{'‚îÄ' * 70}")
    
    for step_num in range(max_steps):
        # Build prompt from current state
        prompt = build_prompt(state)
        prompt_tokens = len(tokenizer.encode(prompt))
        
        if verbose and show_memory:
            print(f"\n‚îå‚îÄ STEP {step_num + 1} {'‚îÄ' * 55}‚îê")
            print(f"‚îÇ üìä MEMORY CONTEXT:")
            print(f"‚îÇ   Accounts Analyzed: {len(state['accounts_analyzed'])}")
            print(f"‚îÇ   Path: {' ‚Üí '.join(state['investigation_path'][-4:]) if state['investigation_path'] else '(empty)'}")
            print(f"‚îÇ   Amount Traced: ${state['total_amount_traced']:,.2f}")
            print(f"‚îÇ   Risk Indicators: {len(state['risk_indicators'])} ({state['risk_indicators'][-3:] if state['risk_indicators'] else 'none'})")
            print(f"‚îÇ   Strategy: {state['current_strategy'].upper()}")
            print(f"‚îÇ   Prompt Tokens: {prompt_tokens:,}")
        
        # Generate model response
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=7000).to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,  # Increased for thinking
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
        
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=False)
        if "<end_of_turn>" in response:
            response = response.split("<end_of_turn>")[0]
        
        # Extract thinking and function call
        thinking = extract_thinking(response)
        tool_call = extract_function_call(response)
        used_fallback = False
        
        if verbose and show_raw_response:
            print(f"‚îÇ üî§ RAW RESPONSE:")
            for line in response[:500].split('\n'):
                print(f"‚îÇ   {line[:65]}")
            if len(response) > 500:
                print(f"‚îÇ   ... (truncated, {len(response)} chars total)")
        
        if verbose and show_thinking and thinking:
            print(f"‚îÇ üí≠ AGENT THINKING:")
            for line in thinking.split('\n')[:5]:
                print(f"‚îÇ   {line[:65]}")
        
        if not tool_call:
            used_fallback = True
            # Smart fallback strategy that makes real progress
            
            # Check if we found a sanctioned entity - submit SAR immediately
            sanctioned_found = [entity for entity, is_sanct in state['entities_checked'].items() if is_sanct]
            if sanctioned_found:
                tool_call = {"name": "submit_sar", "arguments": {
                    "entity_id": sanctioned_found[0], 
                    "reason": f"Sanctioned entity found on laundering path from {start_account}"
                }}
            # Step 0: Start with account summary
            elif state['step_count'] == 0 or not state['accounts_analyzed']:
                tool_call = {"name": "get_account_summary", "arguments": {"account_id": start_account}}
            # Step 1: Get transactions to find connected accounts
            elif len(state['accounts_analyzed']) == 1 and not any("TXN" in r for r in state['risk_indicators']):
                tool_call = {"name": "get_recent_transactions", "arguments": {"account_id": start_account}}
            else:
                # Priority: Check sanctions on accounts in path (high-risk first)
                unchecked = [a for a in state['investigation_path'] if a not in state['entities_checked']]
                
                # Prioritize high-risk accounts
                high_risk_unchecked = [a for a in unchecked if any(f"HIGH_RISK:{a}" in r or f"ILLICIT" in r for r in state['risk_indicators'])]
                
                if high_risk_unchecked:
                    tool_call = {"name": "check_sanctions_list", "arguments": {"entity_id": high_risk_unchecked[0]}}
                elif unchecked:
                    tool_call = {"name": "check_sanctions_list", "arguments": {"entity_id": unchecked[0]}}
                else:
                    # Explore next account in path
                    unexplored = [a for a in state['investigation_path'] if a not in state['accounts_analyzed']]
                    if unexplored:
                        tool_call = {"name": "get_account_summary", "arguments": {"account_id": unexplored[0]}}
                    elif state['investigation_path']:
                        # Get transactions from last explored account
                        for acc in reversed(state['investigation_path'][-5:]):
                            if acc in state['accounts_analyzed']:
                                tool_call = {"name": "get_recent_transactions", "arguments": {"account_id": acc}}
                                break
                    
                    if not tool_call:
                        if verbose:
                            print(f"‚îÇ ‚ö†Ô∏è No valid action - terminating")
                        break
        
        # Execute tool call
        tool_name, result = execute_tool_call(tool_call)
        args = tool_call.get("arguments", {})
        
        if verbose:
            fallback_indicator = " (FALLBACK)" if used_fallback else ""
            print(f"‚îÇ üîß TOOL CALL{fallback_indicator}:")
            print(f"‚îÇ   Function: {tool_name}")
            print(f"‚îÇ   Arguments: {json.dumps(args, default=str)[:60]}")
            
            # Show result summary based on tool type
            if tool_name == "get_account_summary" and isinstance(result, dict):
                illicit = "üî¥ ILLICIT" if result.get("transitive_illicit") else ""
                sanct = "‚ö†Ô∏è SANCTIONED" if result.get("is_sanctioned") else ""
                print(f"‚îÇ   Result: {result.get('account_type', 'Unknown')} | Risk: {result.get('risk_score', 0):.2f} {illicit} {sanct}")
            elif tool_name == "get_recent_transactions" and isinstance(result, list):
                print(f"‚îÇ   Result: {len(result)} transactions found")
                for txn in result[:3]:
                    risk = "üî¥" if txn.get("high_risk_indicator") else ""
                    print(f"‚îÇ     ‚Üí {txn.get('counterparty', '?')[:25]}: ${txn.get('amount', 0):,.2f} {risk}")
            elif tool_name == "check_sanctions_list" and isinstance(result, dict):
                status = "‚ö†Ô∏è ON SANCTIONS LIST" if result.get("on_sanctions_list") else "‚úì Clear"
                print(f"‚îÇ   Result: {status}")
            elif tool_name == "submit_sar" and isinstance(result, dict):
                status = "‚úÖ CORRECT" if result.get("correct_identification") else "‚ùå INCORRECT"
                print(f"‚îÇ   Result: {status}")
                print(f"‚îÇ   Reason: {result.get('evaluation_reason', '')[:50]}")
            elif isinstance(result, dict) and "error" in result:
                print(f"‚îÇ   Result: ‚ö†Ô∏è ERROR - {result.get('error', '')[:50]}")
        
        # Calculate reward
        reward = calculate_reward(tool_name, args, result, state)
        
        if verbose:
            reward_color = "üü¢" if reward > 0 else ("üî¥" if reward < 0 else "‚ö™")
            print(f"‚îÇ üí∞ REWARD: {reward_color} {reward:+.2f}")
            print(f"‚îî{'‚îÄ' * 68}‚îò")
        
        # Record step
        episode.add_step(tool_name, args, result, reward)
        
        # Update state
        state = update_state_from_result(state, tool_name, args, result)
        state['step_count'] = step_num + 1
        state['messages'] = list(state['messages']) + [
            {"role": "assistant", "content": response},
            {"role": "tool", "content": format_function_output(tool_name, result)},
        ]
        
        # Check for terminal action (SAR submission)
        if tool_name == "submit_sar":
            episode.terminated = True
            episode.success = result.get("correct_identification", False)
            episode.final_result = result
            break
    
    if not episode.terminated:
        episode.terminated = True
        episode.success = False
    
    if verbose:
        print(f"\n{'‚ïê' * 70}")
        print(f"üìã INVESTIGATION COMPLETE")
        print(f"{'‚ïê' * 70}")
        print(f"  Result: {'‚úÖ SUCCESS' if episode.success else '‚ùå FAILED'}")
        print(f"  Steps: {len(episode.steps)}")
        print(f"  Total Reward: {episode.total_reward:+.2f}")
        if episode.final_result:
            print(f"  SAR Reason: {episode.final_result.get('evaluation_reason', 'N/A')}")
        print(f"{'‚ïê' * 70}")
    
    return episode


print("‚úì Agent execution function defined (verbose mode available)")


## 7. LLM-as-Judge Evaluation with Gemini

Evaluate agent performance using Gemini as an LLM judge with integrated MLflow tracing.

**Rubric:**
1. **Strategy Quality** (0-10): Did the agent prioritize high-value/high-risk transfers?
2. **Decision Persistence** (0-10): Did it pivot correctly after hitting dead ends?
3. **Outcome** (0-10): Did the agent correctly identify a sanctioned entity?


In [None]:
# ============================================================================
# LLM-AS-JUDGE EVALUATION - Gemini-based Scoring
# Using google.genai library (google.generativeai is deprecated)
# ============================================================================

from google import genai

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
if GEMINI_API_KEY:
    gemini_client = genai.Client(api_key=GEMINI_API_KEY)
    print("‚úì Gemini API configured (google.genai)")
else:
    gemini_client = None
    print("‚ö†Ô∏è GEMINI_API_KEY not set - LLM-as-Judge will be disabled")


JUDGE_PROMPT = """You are an expert evaluator of AML investigations.

## Investigation Trace
{trace}

## Evaluation Rubric
1. **Strategy Quality** (0-10): Did the agent prioritize high-value and high-risk transfers?
2. **Decision Persistence** (0-10): Did the agent correctly pivot after hitting dead ends?
3. **Outcome Quality** (0-10): Did the agent correctly identify a sanctioned entity?

## Response Format
Provide your evaluation as JSON:
{{"strategy_score": <0-10>, "persistence_score": <0-10>, "outcome_score": <0-10>, "overall_score": <0-10>, "reasoning": "<brief explanation>"}}
"""


@mlflow.trace(span_type="LLM_JUDGE")
def evaluate_episode_with_llm(episode: InvestigationEpisode) -> dict:
    """Evaluate an investigation episode using Gemini as LLM judge."""
    default_scores = {
        "strategy_score": 0, 
        "persistence_score": 0, 
        "outcome_score": 10 if episode.success else 0,
        "overall_score": 5 if episode.success else 0, 
        "reasoning": "Default score"
    }
    
    if not gemini_client:
        default_scores["reasoning"] = "LLM evaluation disabled (no API key)"
        return default_scores
    
    # Build investigation trace
    trace_lines = [f"Start Account: {episode.start_account}", ""]
    for step in episode.steps:
        tool_name = step.get('tool_name', 'unknown')
        trace_lines.append(f"Step {step['step']}: {tool_name}")
        trace_lines.append(f"  Args: {json.dumps(step.get('arguments', {}), default=str)[:80]}")
        
        # More detailed result summary based on tool type
        result = step.get('result', {})
        if tool_name == "get_account_summary" and isinstance(result, dict):
            trace_lines.append(f"  Result: Type={result.get('account_type', 'Unknown')}, Risk={result.get('risk_score', 0):.2f}, Illicit={result.get('transitive_illicit', False)}")
        elif tool_name == "get_recent_transactions" and isinstance(result, list):
            trace_lines.append(f"  Result: {len(result)} transactions (high_risk: {sum(1 for t in result if t.get('high_risk_indicator'))})")
        elif tool_name == "check_sanctions_list" and isinstance(result, dict):
            trace_lines.append(f"  Result: Sanctioned={result.get('on_sanctions_list', False)}")
        elif tool_name == "submit_sar" and isinstance(result, dict):
            trace_lines.append(f"  Result: Correct={result.get('correct_identification', False)}, Reason={result.get('evaluation_reason', 'N/A')[:50]}")
        else:
            trace_lines.append(f"  Result: {json.dumps(result, default=str)[:100]}")
        trace_lines.append(f"  Reward: {step.get('reward', 0):+.2f}")
    
    trace_lines.append(f"\nFinal: {'SUCCESS' if episode.success else 'FAILED'} | Steps: {len(episode.steps)} | Total Reward: {episode.total_reward:+.2f}")
    
    trace_text = "\n".join(trace_lines)
    
    try:
        response = gemini_client.models.generate_content(
            model=GEMINI_MODEL, 
            contents=JUDGE_PROMPT.format(trace=trace_text)
        )
        
        # Extract JSON from response
        json_match = re.search(r'\{[^{}]*\}', response.text, re.DOTALL)
        if json_match:
            scores = json.loads(json_match.group())
            # Ensure all required fields are present
            for key in ["strategy_score", "persistence_score", "outcome_score", "overall_score"]:
                if key not in scores:
                    scores[key] = 0
            if "reasoning" not in scores:
                scores["reasoning"] = "No reasoning provided"
            return scores
        else:
            default_scores["reasoning"] = f"Failed to parse JSON from: {response.text[:100]}"
            return default_scores
            
    except Exception as e:
        default_scores["reasoning"] = f"API error: {str(e)[:50]}"
        return default_scores


In [None]:
# ============================================================================
# EVALUATION FUNCTION - Run Multiple Episodes with Detailed Logging
# ============================================================================

def run_evaluation(
    model, 
    tokenizer, 
    stage_name: str, 
    n_episodes: int = EVAL_EPISODES, 
    use_llm_judge: bool = True, 
    verbose: bool = True,
    show_episode_details: bool = True,  # Show step-by-step for each episode
    show_thinking: bool = True,
    show_memory: bool = True,
    show_raw_response: bool = False     # Show raw model output for debugging
) -> pd.DataFrame:
    """
    Run evaluation episodes and collect metrics. Returns DataFrame with results.
    
    Args:
        verbose: Show progress and summaries
        show_episode_details: Show step-by-step details for each episode
        show_thinking: Show agent's <thinking> reasoning
        show_memory: Show memory context at each step
    """
    # Use pattern seed accounts for guaranteed laundering paths
    eval_seeds = [p.seed_account for p in random.sample(laundering_patterns, min(n_episodes, len(laundering_patterns))) if p.seed_account]
    
    results = []
    
    print(f"\n{'‚ïê' * 70}")
    print(f"üß™ EVALUATION: {stage_name.upper()}")
    print(f"{'‚ïê' * 70}")
    print(f"  Episodes: {n_episodes}")
    print(f"  Max Steps: {MAX_STEPS}")
    print(f"  LLM Judge: {'Enabled' if use_llm_judge and gemini_client else 'Disabled'}")
    print(f"{'‚ïê' * 70}")
    
    with mlflow.start_run(run_name=f"eval_{stage_name}"):
        for i, seed in enumerate(eval_seeds[:n_episodes]):
            print(f"\n{'‚ñì' * 70}")
            print(f"  EPISODE {i+1}/{n_episodes}")
            print(f"  Seed: {seed}")
            print(f"{'‚ñì' * 70}")
            
            # Run investigation with verbose output if requested
            episode = run_investigation(
                seed, model, tokenizer, 
                verbose=show_episode_details,
                show_thinking=show_thinking,
                show_memory=show_memory,
                show_raw_response=show_raw_response
            )
            
            # LLM Judge evaluation
            if use_llm_judge and gemini_client:
                print(f"\n  ü§ñ LLM-AS-JUDGE EVALUATION...")
                llm_scores = evaluate_episode_with_llm(episode)
                print(f"     Strategy:    {llm_scores.get('strategy_score', 0)}/10")
                print(f"     Persistence: {llm_scores.get('persistence_score', 0)}/10")
                print(f"     Outcome:     {llm_scores.get('outcome_score', 0)}/10")
                print(f"     Overall:     {llm_scores.get('overall_score', 0)}/10")
                if llm_scores.get('reasoning'):
                    print(f"     Reasoning:   {llm_scores.get('reasoning', '')[:60]}...")
            else:
                llm_scores = {
                    "strategy_score": 0, 
                    "persistence_score": 0, 
                    "outcome_score": 10 if episode.success else 0, 
                    "overall_score": 5 if episode.success else 0,
                    "reasoning": "LLM judge disabled"
                }
            
            result = {
                "seed_account": seed, 
                "success": episode.success, 
                "steps": len(episode.steps),
                "total_reward": episode.total_reward, 
                **llm_scores
            }
            results.append(result)
            
            # Episode summary
            print(f"\n  ‚îå‚îÄ EPISODE {i+1} RESULT {'‚îÄ' * 44}‚îê")
            print(f"  ‚îÇ Outcome:      {'‚úÖ SUCCESS' if episode.success else '‚ùå FAILED':<20} ‚îÇ")
            print(f"  ‚îÇ Steps:        {len(episode.steps):<20} ‚îÇ")
            print(f"  ‚îÇ Total Reward: {episode.total_reward:+.2f}{' ' * 17} ‚îÇ")
            print(f"  ‚îÇ LLM Score:    {llm_scores.get('overall_score', 0)}/10{' ' * 16} ‚îÇ")
            print(f"  ‚îî{'‚îÄ' * 53}‚îò")
        
        df = pd.DataFrame(results)
        
        # Log to MLflow
        mlflow.log_metrics({
            f"{stage_name}_success_rate": df['success'].mean(),
            f"{stage_name}_avg_steps": df['steps'].mean(),
            f"{stage_name}_avg_reward": df['total_reward'].mean(),
            f"{stage_name}_avg_score": df['overall_score'].mean(),
        })
    
    # Final Summary
    print(f"\n{'‚ïê' * 70}")
    print(f"üìä {stage_name.upper()} - FINAL SUMMARY")
    print(f"{'‚ïê' * 70}")
    print(f"  Episodes Run:     {len(results)}")
    print(f"  Success Rate:     {df['success'].mean()*100:.1f}% ({df['success'].sum()}/{len(results)})")
    print(f"  Avg Steps:        {df['steps'].mean():.1f}")
    print(f"  Avg Reward:       {df['total_reward'].mean():+.2f}")
    print(f"  Avg LLM Score:    {df['overall_score'].mean():.1f}/10")
    print(f"{'‚îÄ' * 70}")
    print(f"  Score Breakdown:")
    print(f"    Strategy:       {df['strategy_score'].mean():.1f}/10")
    print(f"    Persistence:    {df['persistence_score'].mean():.1f}/10")
    print(f"    Outcome:        {df['outcome_score'].mean():.1f}/10")
    print(f"{'‚ïê' * 70}")
    
    return df


print("‚úì Evaluation framework configured (detailed logging enabled)")


---

## 8. STAGE 1: Baseline Evaluation (Pre-Training)

Run evaluation on the **base Gemma-2-2B-IT model** before any fine-tuning to establish baseline performance.


In [None]:
# ============================================================================
# CLEANUP - Remove Previous Training Outputs (for notebook re-runs)
# ============================================================================

import shutil

def cleanup_training_outputs():
    """Remove previous training outputs to ensure clean re-runs."""
    cleanup_dirs = [
        MODELS_DIR / "sft_output",
        MODELS_DIR / "sft_adapter", 
        MODELS_DIR / "grpo_output",
        MODELS_DIR / "grpo_adapter",
        MODELS_DIR / "aml_agent_final",
    ]
    
    print("üßπ Cleaning up previous training outputs...")
    
    for dir_path in cleanup_dirs:
        if dir_path.exists():
            try:
                shutil.rmtree(dir_path)
                print(f"   ‚úì Removed: {dir_path.name}")
            except Exception as e:
                print(f"   ‚ö† Could not remove {dir_path.name}: {e}")
        else:
            print(f"   - Not found: {dir_path.name} (skipping)")
    
    print("‚úì Cleanup complete\n")

# Run cleanup
cleanup_training_outputs()


In [None]:
# ============================================================================
# STAGE 1: BASELINE EVALUATION - Pre-Training Performance
# ============================================================================

print("üîç STAGE 1: BASELINE EVALUATION")
print("   Testing base Gemma-2-2B-IT model (no fine-tuning)")

# Initialize results dictionary for all stages
all_results = {}

# Ensure model is in inference mode
FastLanguageModel.for_inference(model)

# Configuration for evaluation verbosity
# Set show_raw_response=True to see what the model actually generates
SHOW_EPISODE_DETAILS = True  # Show step-by-step execution
SHOW_THINKING = True         # Show agent's <thinking> reasoning
SHOW_MEMORY = True           # Show memory context
SHOW_RAW_RESPONSE = False    # Show raw model output (useful for debugging)

# Run baseline evaluation with detailed logging
baseline_results = run_evaluation(
    model=model,
    tokenizer=tokenizer,
    stage_name="Baseline",
    n_episodes=EVAL_EPISODES,
    use_llm_judge=bool(gemini_client),
    verbose=True,
    show_episode_details=SHOW_EPISODE_DETAILS,
    show_thinking=SHOW_THINKING,
    show_memory=SHOW_MEMORY
)

# Store results for comparison
all_results["Baseline"] = baseline_results

print("\nüìä Baseline Results Summary:")
print(baseline_results[['success', 'steps', 'total_reward', 'overall_score']].describe())


---

## 9. SFT Training Data Generation

Generate Supervised Fine-Tuning samples from laundering patterns with internal reasoning (`<thinking>` tags).


In [None]:
# ============================================================================
# SFT TRAINING DATA GENERATION - Multi-turn Conversations with Thinking
# ============================================================================

def generate_sft_sample(pattern: LaunderingPattern, include_backtrack: bool = False) -> List[dict]:
    """Generate a multi-turn SFT training sample with <thinking> tags."""
    messages = []
    seed_account = pattern.seed_account
    terminal_account = pattern.terminal_account
    
    if not seed_account or not terminal_account:
        return []
    
    system_prompt = INVESTIGATION_PROMPT.format(tool_declarations=TOOL_DECLARATIONS)
    messages.append({"role": "user", "content": f"{system_prompt}\n\nINVESTIGATION TARGET: {seed_account}\n\nBegin investigation."})
    
    # Step 1: Get account summary with thinking
    messages.append({"role": "assistant", "content": f"""<thinking>
I need to start by getting the account summary for {seed_account} to understand its risk profile.
</thinking>
<start_function_call>call:get_account_summary{{account_id: {seed_account}}}</start_function_call>"""})
    messages.append({"role": "user", "content": f"TOOL RESULT: {json.dumps({'account_id': seed_account, 'account_type': 'Corporate', 'risk_score': 0.75, 'transitive_illicit': True})}"})
    
    # Step 2: Get transactions with thinking
    messages.append({"role": "assistant", "content": f"""<thinking>
The account shows transitive_illicit=True. I should trace outgoing transactions to follow the money flow.
</thinking>
<start_function_call>call:get_recent_transactions{{account_id: {seed_account}, direction: outgoing}}</start_function_call>"""})
    txn_results = [{"counterparty": t['to_account'], "amount": t['amount'], "high_risk_indicator": True} for t in pattern.transactions[:3]]
    messages.append({"role": "user", "content": f"TOOL RESULT: {json.dumps(txn_results)}"})
    
    # Step 3: Check sanctions with thinking
    messages.append({"role": "assistant", "content": f"""<thinking>
I've found high-risk counterparties. Before submitting a SAR, I MUST verify the target is sanctioned.
</thinking>
<start_function_call>call:check_sanctions_list{{entity_id: {terminal_account}}}</start_function_call>"""})
    messages.append({"role": "user", "content": f'TOOL RESULT: {{"entity_id": "{terminal_account}", "on_sanctions_list": true, "list_type": "OFAC SDN"}}'})
    
    # Step 4: Submit SAR with thinking
    reason = f"Identified {pattern.pattern_type} pattern with ${pattern.total_amount:,.2f}"
    messages.append({"role": "assistant", "content": f"""<thinking>
CONFIRMED: {terminal_account} is ON the OFAC sanctions list. I can now submit the SAR with confidence.
</thinking>
<start_function_call>call:submit_sar{{entity_id: {terminal_account}, reason: {reason}}}</start_function_call>"""})
    
    return messages


def generate_sft_dataset(patterns: List[LaunderingPattern], max_samples: int = 100) -> List[dict]:
    dataset = []
    for i, pattern in enumerate(random.sample(patterns, min(max_samples, len(patterns)))):
        messages = generate_sft_sample(pattern)
        if messages:
            dataset.append({"id": f"sft_{i}", "pattern_type": pattern.pattern_type, "messages": messages})
    return dataset


def format_for_unsloth_sft(dataset: List[dict]) -> List[dict]:
    formatted = []
    for sample in dataset:
        text = ""
        for msg in sample['messages']:
            if msg['role'] == 'user':
                text += f"<start_of_turn>user\n{msg['content']}<end_of_turn>\n"
            elif msg['role'] == 'assistant':
                text += f"<start_of_turn>model\n{msg['content']}<end_of_turn>\n"
        formatted.append({"text": text})
    return formatted


# Generate SFT dataset
print("üìä Generating SFT training data...")
sft_dataset = generate_sft_dataset(laundering_patterns, max_samples=100)
sft_formatted = format_for_unsloth_sft(sft_dataset)

print(f"‚úì Generated {len(sft_formatted)} SFT training samples")


## 10. SFT Training with Unsloth

Fine-tune with LoRA adapters (r=32, alpha=64) targeting attention and MLP layers.


In [None]:
# ============================================================================
# SFT TRAINING WITH UNSLOTH - LoRA Fine-tuning
# ============================================================================

from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments

print("üì• Configuring LoRA adapters for SFT...")

# Get PEFT model with LoRA
model_for_training = FastLanguageModel.get_peft_model(
    model, r=LORA_R, target_modules=LORA_TARGET_MODULES, lora_alpha=LORA_ALPHA,
    lora_dropout=0.05, bias="none", use_gradient_checkpointing="unsloth", random_state=RANDOM_SEED,
)

# Create dataset
sft_hf_dataset = Dataset.from_list([{"text": s["text"]} for s in sft_formatted])

# Training arguments
sft_output_dir = MODELS_DIR / "sft_output"
sft_args = TrainingArguments(
    output_dir=str(sft_output_dir), per_device_train_batch_size=2, gradient_accumulation_steps=4,
    num_train_epochs=SFT_EPOCHS, learning_rate=SFT_LEARNING_RATE, warmup_ratio=0.1,
    logging_steps=10, save_strategy="epoch", fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(), optim="adamw_8bit", seed=RANDOM_SEED, report_to="none",
)

# Create trainer
sft_trainer = SFTTrainer(
    model=model_for_training, tokenizer=tokenizer, train_dataset=sft_hf_dataset,
    args=sft_args, max_seq_length=4096,
)

print(f"\n{'=' * 60}")
print(f"üéì STARTING SFT TRAINING")
print(f"{'=' * 60}")
print(f"  Epochs: {SFT_EPOCHS} | LR: {SFT_LEARNING_RATE} | LoRA r={LORA_R}")
print(f"{'=' * 60}")

# Train
sft_trainer.train()

# Save adapter
sft_adapter_path = MODELS_DIR / "sft_adapter"
model_for_training.save_pretrained(str(sft_adapter_path))
tokenizer.save_pretrained(str(sft_adapter_path))

print(f"\n✓ SFT training complete! Adapter saved to: {sft_adapter_path}")


## 11. STAGE 2: Post-SFT Evaluation

Evaluate the **SFT-tuned model** to measure improvement from supervised fine-tuning.


In [None]:
# ============================================================================
# STAGE 2: POST-SFT EVALUATION
# ============================================================================

print("üîç STAGE 2: POST-SFT EVALUATION")
print("   Testing model after Supervised Fine-Tuning")

# Switch to inference mode
FastLanguageModel.for_inference(model_for_training)

# Run post-SFT evaluation with detailed logging
post_sft_results = run_evaluation(
    model=model_for_training,
    tokenizer=tokenizer,
    stage_name="Post-SFT",
    n_episodes=EVAL_EPISODES,
    use_llm_judge=bool(gemini_client),
    verbose=True,
    show_episode_details=SHOW_EPISODE_DETAILS,
    show_thinking=SHOW_THINKING,
    show_memory=SHOW_MEMORY
)

# Store results for comparison
all_results["Post-SFT"] = post_sft_results

# Show improvement over baseline
baseline_sr = all_results["Baseline"]['success'].mean()
post_sft_sr = post_sft_results['success'].mean()
print(f"\nüìà SFT Improvement: {baseline_sr*100:.1f}% ‚Üí {post_sft_sr*100:.1f}% ({(post_sft_sr-baseline_sr)*100:+.1f}%)")

print("\nüìä Post-SFT Results Summary:")
print(post_sft_results[['success', 'steps', 'total_reward', 'overall_score']].describe())


---

## 12. GRPO Training - Reinforcement Learning

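Train with **Group Relative Policy Optimization (GRPO)** using TRL. The reward signals are:
- **R_Discovery (+0.5)**: Discovering `transitive_illicit` nodes
- **R_Logic (+0.3)**: Correct tool sequencing (requesting transactions after a high-risk finding)
- **R_Outcome (+2.0)**: Submitting a SAR on the correct entity
- **R_Efficiency (-0.1)**: Per-step penalty that discourages looping

The GRPO reward function in the cell below scores only the first function call in each sampled completion.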
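To make these terms concrete, here is a minimal sketch of how they could combine into one episode reward, followed by the group-relative advantage that gives GRPO its name: each completion sampled for a prompt is scored, and its reward is centered on the group mean and divided by the group's standard deviation. The `EpisodeOutcome` record and `shaped_reward` helper are hypothetical (not the notebook's `calculate_reward`), R_Discovery is assumed to be paid per discovered node, and the normalization follows the standard GRPO formulation rather than any specific TRL version.

```python
import statistics
from dataclasses import dataclass
from typing import List

# Reward magnitudes from the list above.
R_DISCOVERY = 0.5     # assumed: paid per transitive_illicit node discovered
R_LOGIC = 0.3
R_OUTCOME = 2.0
R_EFFICIENCY = -0.1   # per step


@dataclass
class EpisodeOutcome:
    """Hypothetical summary of one investigation episode."""
    illicit_nodes_found: int
    logic_ok: bool        # tools were called in a sensible order
    sar_correct: bool     # submit_sar targeted a sanctioned, reachable entity
    steps: int


def shaped_reward(ep: EpisodeOutcome) -> float:
    """Sum the four reward terms into one scalar (illustrative weighting)."""
    reward = R_DISCOVERY * ep.illicit_nodes_found
    reward += R_LOGIC if ep.logic_ok else 0.0
    reward += R_OUTCOME if ep.sar_correct else 0.0
    reward += R_EFFICIENCY * ep.steps
    return reward


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Standardize rewards within the group of completions for one prompt."""
    mean = statistics.mean(rewards)
    std = statistics.stdev(rewards)  # sample std; requires at least two completions
    return [(r - mean) / (std + eps) for r in rewards]


# Four completions sampled for the same seed account.
group = [
    EpisodeOutcome(illicit_nodes_found=1, logic_ok=True, sar_correct=True, steps=4),
    EpisodeOutcome(illicit_nodes_found=1, logic_ok=True, sar_correct=False, steps=6),
    EpisodeOutcome(illicit_nodes_found=0, logic_ok=False, sar_correct=False, steps=10),
    EpisodeOutcome(illicit_nodes_found=2, logic_ok=True, sar_correct=True, steps=8),
]
rewards = [shaped_reward(ep) for ep in group]
advantages = group_relative_advantages(rewards)
for r, a in zip(rewards, advantages):
    print(f"reward={r:+.2f}  advantage={a:+.2f}")
```

If every completion in a group earns the same reward, all advantages are zero and that prompt provides no reward signal for the update, which is why GRPO samples several generations per prompt.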


In [None]:
# ============================================================================
# GRPO TRAINING - Group Relative Policy Optimization
# ============================================================================

from trl import GRPOTrainer, GRPOConfig

def generate_grpo_prompts(patterns: List[LaunderingPattern], n_prompts: int = 50) -> List[str]:
    """Generate prompts for GRPO training from laundering patterns."""
    prompts = []
    selected = random.sample(patterns, min(n_prompts, len(patterns)))
    
    for pattern in selected:
        seed = pattern.seed_account
        if seed:
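            system = INVESTIGATION_PROMPT.format(tool_declarations=TOOL_DECLARATIONS)
            # GRPOTrainer tokenizes raw-string prompts with add_special_tokens=False,
            # so Gemma's <bos> token is prepended explicitly here.
            prompt = f"<bos><start_of_turn>user\n{system}\n\nINVESTIGATION TARGET: {seed}\n\nBegin investigation.<end_of_turn>\n<start_of_turn>model\n"
            prompts.append(prompt)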
    
    return prompts


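def grpo_reward_function(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """
    GRPO reward function that evaluates model completions.
    Runs the first function call in each completion through the environment and returns its reward.
    TRL passes `prompts` and `completions` as keyword arguments, along with extra keys
    (e.g. other dataset columns), so **kwargs absorbs anything unused.
    """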
    rewards = []
    
    for completion, prompt in zip(completions, prompts):
        # Extract seed account from prompt
        match = re.search(r'INVESTIGATION TARGET:\s*([^\n]+)', prompt)
        if not match:
            rewards.append(-1.0)
            continue
        
        seed_account = match.group(1).strip()
        total_reward = 0.0
        
        # Reset environment
        env.reset_investigation(seed_account)
        state = create_initial_state(seed_account)
        
        # Extract and execute function call from completion
        tool_call = extract_function_call(completion)
        
        if tool_call:
            tool_name, result = execute_tool_call(tool_call)
            args = tool_call.get("arguments", {})
            total_reward = calculate_reward(tool_name, args, result, state)
        else:
            total_reward = -0.5  # Penalty for invalid output
        
        rewards.append(total_reward)
    
    return rewards


# Generate GRPO training prompts
print("üìä Generating GRPO training prompts...")
grpo_prompts = generate_grpo_prompts(laundering_patterns, n_prompts=50)
grpo_dataset = Dataset.from_dict({"prompt": grpo_prompts})
print(f"‚úì Generated {len(grpo_prompts)} GRPO training prompts")

# GRPO Training Configuration
grpo_output_dir = MODELS_DIR / "grpo_output"
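grpo_config = GRPOConfig(
    output_dir=str(grpo_output_dir),
    # GRPO samples a group of completions per prompt; the batch size must be
    # divisible by num_generations, so a per-device batch size of 1 is rejected.
    per_device_train_batch_size=4,
    num_generations=4,
    gradient_accumulation_steps=4,
    num_train_epochs=GRPO_EPOCHS,
    learning_rate=GRPO_LEARNING_RATE,
    logging_steps=5,
    save_strategy="epoch",
    report_to="none",
)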

print(f"\n{'=' * 60}")
print(f"üéØ STARTING GRPO TRAINING")
print(f"{'=' * 60}")
print(f"  Epochs:        {GRPO_EPOCHS}")
print(f"  Learning Rate: {GRPO_LEARNING_RATE}")
print(f"  Prompts:       {len(grpo_prompts)}")
print(f"{'=' * 60}")

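# Re-enable training mode: the Post-SFT evaluation switched the model to inference mode
FastLanguageModel.for_training(model_for_training)

# Create GRPO Trainer (config goes in `args`, tokenizer in `processing_class`)
grpo_trainer = GRPOTrainer(
    model=model_for_training,
    reward_funcs=grpo_reward_function,
    args=grpo_config,
    train_dataset=grpo_dataset,
    processing_class=tokenizer,
)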

# Train with GRPO
grpo_trainer.train()

# Save GRPO adapter
grpo_adapter_path = MODELS_DIR / "grpo_adapter"
model_for_training.save_pretrained(str(grpo_adapter_path))
tokenizer.save_pretrained(str(grpo_adapter_path))

print(f"\n✓ GRPO training complete!")
print(f"✓ Adapter saved to: {grpo_adapter_path}")


## 13. STAGE 3: Post-GRPO Evaluation

Evaluate the **GRPO-trained model** to measure improvement from reinforcement learning.


In [None]:
# ============================================================================
# STAGE 3: POST-GRPO EVALUATION
# ============================================================================

print("üîç STAGE 3: POST-GRPO EVALUATION")
print("   Testing model after GRPO Reinforcement Learning")

# Switch to inference mode
FastLanguageModel.for_inference(model_for_training)

# Run post-GRPO evaluation with detailed logging
post_grpo_results = run_evaluation(
    model=model_for_training,
    tokenizer=tokenizer,
    stage_name="Post-GRPO",
    n_episodes=EVAL_EPISODES,
    use_llm_judge=bool(gemini_client),
    verbose=True,
    show_episode_details=SHOW_EPISODE_DETAILS,
    show_thinking=SHOW_THINKING,
    show_memory=SHOW_MEMORY
)

# Store results for comparison
all_results["Post-GRPO"] = post_grpo_results

print("\nüìä Post-GRPO Results Summary:")
print(post_grpo_results[['success', 'steps', 'total_reward', 'overall_score']].describe())


---

## 14. Final Comparison: Baseline vs SFT vs GRPO

Compare metrics across all three training stages:
- **Success Rate**: Percentage of episodes that end in a correct SAR submission
- **Average Steps**: Mean number of steps per investigation (lower is more efficient)
- **Average Reward**: Mean `total_reward` per episode
- **LLM-as-Judge Scores**: Strategy, Persistence, Outcome, Overall (0-10)
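The sketch below shows how these per-stage aggregates and the improvement deltas can be computed from per-episode records in plain Python. The `Episode` record mirrors columns of the results DataFrames but is hypothetical, and the numbers in the usage example are placeholders, not results from this notebook. Note that the success-rate delta is an absolute change in percentage points, not a relative percentage.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Episode:
    """Hypothetical per-episode record."""
    success: bool
    steps: int
    total_reward: float
    overall_score: float  # LLM-as-Judge overall score, 0-10


def summarize(episodes: List[Episode]) -> Dict[str, float]:
    """Aggregate one stage's episodes into the comparison metrics."""
    n = len(episodes)
    return {
        "success_rate_pct": 100.0 * sum(e.success for e in episodes) / n,
        "avg_steps": sum(e.steps for e in episodes) / n,
        "avg_reward": sum(e.total_reward for e in episodes) / n,
        "avg_overall_score": sum(e.overall_score for e in episodes) / n,
    }


def deltas_vs_baseline(baseline: Dict[str, float], stage: Dict[str, float]) -> Dict[str, float]:
    """Absolute differences per metric; the success-rate delta is in percentage points."""
    return {key: stage[key] - baseline[key] for key in baseline}


# Placeholder data for illustration only.
baseline_eps = [Episode(False, 10, -1.0, 3.0), Episode(True, 6, 1.8, 6.0)]
grpo_eps = [Episode(True, 5, 2.0, 7.0), Episode(True, 7, 1.6, 6.5)]

base = summarize(baseline_eps)
grpo = summarize(grpo_eps)
for key, delta in deltas_vs_baseline(base, grpo).items():
    print(f"{key}: {delta:+.2f}")
```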


In [None]:
# ============================================================================
# FINAL COMPARISON - Baseline vs SFT vs GRPO
# ============================================================================

print("\n" + "=" * 80)
print("üìä FINAL COMPARISON: BASELINE vs SFT vs GRPO")
print("=" * 80)

# Build comparison DataFrame
comparison_data = []
for stage_name, results_df in all_results.items():
    comparison_data.append({
        "Stage": stage_name,
        "Success Rate (%)": results_df['success'].mean() * 100,
        "Avg Steps": results_df['steps'].mean(),
        "Avg Reward": results_df['total_reward'].mean(),
        "Avg Strategy Score": results_df['strategy_score'].mean() if 'strategy_score' in results_df else 0,
        "Avg Persistence Score": results_df['persistence_score'].mean() if 'persistence_score' in results_df else 0,
        "Avg Outcome Score": results_df['outcome_score'].mean() if 'outcome_score' in results_df else 0,
        "Avg Overall Score": results_df['overall_score'].mean() if 'overall_score' in results_df else 0,
    })

comparison_df = pd.DataFrame(comparison_data)
comparison_df = comparison_df.set_index("Stage")

# Display comparison table
print("\nüìà METRICS COMPARISON:")
print("-" * 80)
print(comparison_df.round(2).to_string())
print("-" * 80)

# Calculate improvements (assumes the Baseline stage was recorded first)
if len(comparison_data) >= 2:
    baseline = comparison_data[0]
    
    print("\n📊 IMPROVEMENTS OVER BASELINE:")
    print("-" * 80)
    
    for stage in comparison_data[1:]:
        stage_name = stage["Stage"]
        success_improvement = stage["Success Rate (%)"] - baseline["Success Rate (%)"]
        reward_improvement = stage["Avg Reward"] - baseline["Avg Reward"]
        score_improvement = stage["Avg Overall Score"] - baseline["Avg Overall Score"]
        
        print(f"\n  {stage_name}:")
        print(f"    Success Rate:  {success_improvement:+.1f} pp")
        print(f"    Avg Reward:    {reward_improvement:+.2f}")
        print(f"    Overall Score: {score_improvement:+.1f} (0-10 scale)")

print("\n" + "=" * 80)

# Log to MLflow
with mlflow.start_run(run_name="final_comparison"):
    for stage_name, results_df in all_results.items():
        mlflow.log_metrics({
            f"{stage_name}_success_rate": results_df['success'].mean(),
            f"{stage_name}_avg_steps": results_df['steps'].mean(),
            f"{stage_name}_avg_reward": results_df['total_reward'].mean(),
            f"{stage_name}_avg_overall_score": results_df.get('overall_score', pd.Series([0])).mean(),
        })

print("✓ Comparison logged to MLflow")


In [None]:
# ============================================================================
# VISUALIZATION - Performance Comparison Charts
# ============================================================================

import matplotlib.pyplot as plt

# Create comparison visualizations
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('AML Investigation Agent: Training Stage Comparison', fontsize=14, fontweight='bold')

stages = list(all_results.keys())
colors = ['#e74c3c', '#3498db', '#2ecc71'][:len(stages)]

# 1. Success Rate
ax1 = axes[0, 0]
success_rates = [all_results[s]['success'].mean() * 100 for s in stages]
bars1 = ax1.bar(stages, success_rates, color=colors, edgecolor='black', linewidth=1.2)
ax1.set_ylabel('Success Rate (%)', fontweight='bold')
ax1.set_title('SAR Submission Accuracy', fontweight='bold')
ax1.set_ylim(0, 100)
for bar, val in zip(bars1, success_rates):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2, f'{val:.1f}%', 
             ha='center', va='bottom', fontweight='bold')

# 2. Average Reward
ax2 = axes[0, 1]
avg_rewards = [all_results[s]['total_reward'].mean() for s in stages]
bars2 = ax2.bar(stages, avg_rewards, color=colors, edgecolor='black', linewidth=1.2)
ax2.set_ylabel('Average Reward', fontweight='bold')
ax2.set_title('GRPO Reward Score', fontweight='bold')
ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
for bar, val in zip(bars2, avg_rewards):
    ypos = bar.get_height() + 0.1 if val >= 0 else bar.get_height() - 0.3
    ax2.text(bar.get_x() + bar.get_width()/2, ypos, f'{val:.2f}', 
             ha='center', va='bottom' if val >= 0 else 'top', fontweight='bold')

# 3. Average Steps
ax3 = axes[1, 0]
avg_steps = [all_results[s]['steps'].mean() for s in stages]
bars3 = ax3.bar(stages, avg_steps, color=colors, edgecolor='black', linewidth=1.2)
ax3.set_ylabel('Average Steps', fontweight='bold')
ax3.set_title('Investigation Efficiency', fontweight='bold')
for bar, val in zip(bars3, avg_steps):
    ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3, f'{val:.1f}', 
             ha='center', va='bottom', fontweight='bold')

# 4. LLM-as-Judge Overall Score
ax4 = axes[1, 1]
if 'overall_score' in all_results[stages[0]].columns:
    overall_scores = [all_results[s]['overall_score'].mean() for s in stages]
    bars4 = ax4.bar(stages, overall_scores, color=colors, edgecolor='black', linewidth=1.2)
    ax4.set_ylabel('Overall Score (0-10)', fontweight='bold')
    ax4.set_title('LLM-as-Judge Score', fontweight='bold')
    ax4.set_ylim(0, 10)
    for bar, val in zip(bars4, overall_scores):
        ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.2, f'{val:.1f}', 
                 ha='center', va='bottom', fontweight='bold')
else:
    ax4.text(0.5, 0.5, 'LLM-as-Judge\nNot Available', ha='center', va='center', fontsize=12)
    ax4.set_title('LLM-as-Judge Score', fontweight='bold')

plt.tight_layout()
comparison_chart_path = OUTPUT_DIR / "training_comparison.png"
plt.savefig(str(comparison_chart_path), dpi=150, bbox_inches='tight')
plt.show()

print(f"✓ Comparison chart saved to {comparison_chart_path}")


---

## 15. Save Final Model

Save the trained model adapters for deployment.


In [None]:
# ============================================================================
# SAVE FINAL MODEL
# ============================================================================

# Save final adapter to models directory
final_model_path = MODELS_DIR / "aml_agent_final"
model_for_training.save_pretrained(str(final_model_path))
tokenizer.save_pretrained(str(final_model_path))

# Save comparison results
comparison_csv_path = OUTPUT_DIR / "training_comparison.csv"
comparison_df.to_csv(str(comparison_csv_path))

print(f"\n{'=' * 60}")
print(f"✅ TRAINING COMPLETE")
print(f"{'=' * 60}")
print(f"  Final Model:      {final_model_path}")
print(f"  SFT Adapter:      {sft_adapter_path}")
print(f"  GRPO Adapter:     {grpo_adapter_path}")
print(f"  Comparison Chart: {OUTPUT_DIR / 'training_comparison.png'}")
print(f"  Comparison CSV:   {comparison_csv_path}")
print(f"{'=' * 60}")

print(f"\nüéâ AML Investigation Agent Training Complete!")
print(f"\nüìä Final Performance Summary:")
print(comparison_df.round(2).to_string())

print(f"\nNext Steps:")
print(f"  1. Load adapter with Unsloth for inference")
print(f"  2. Deploy agent with LangGraph orchestration")
print(f"  3. Monitor with MLflow tracing")
