# 05 - Evaluation: Testing Model Performance

**Goal**: Evaluate the fine-tuned model's performance on the test set.

In this notebook, you'll learn:
- How to load fine-tuned models with adapters
- How to run batch inference on test data
- How to calculate accuracy metrics (amounts, addresses, protocols)
- How to create confusion matrices
- How to analyze model predictions
- How to visualize performance breakdowns

**Prerequisites**: Completed `04-fine-tuning.ipynb` and have a trained model

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

import json
import sys
from pathlib import Path
from collections import Counter, defaultdict

import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datasets import load_dataset

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Import project modules
from eth_finetuning.evaluation.evaluator import (
    load_model_for_evaluation,
    run_inference,
    parse_json_output,
)
from eth_finetuning.evaluation.metrics import (
    calculate_accuracy_metrics,
    calculate_per_protocol_metrics,
    calculate_readability_score,
)
from eth_finetuning.evaluation.report import generate_evaluation_report

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

print("✓ Imports successful")
print(f"✓ Project root: {project_root}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")

## Loading Test Dataset

Let's load the test set that was held out during training.

In [None]:
print("LOADING TEST DATASET")
print("=" * 80)

# Load test dataset
test_file = project_root / "data" / "datasets" / "test.jsonl"

if not test_file.exists():
    print(f"\n⚠️  Test file not found: {test_file}")
    print("   Please run notebook 03-dataset-preparation.ipynb first")
else:
    test_dataset = load_dataset('json', data_files=str(test_file))['train']
    
    print(f"\n✓ Loaded test dataset")
    print(f"  Total examples: {len(test_dataset)}")
    
    # Analyze test set composition
    protocols = []
    for example in test_dataset:
        try:
            intent = json.loads(example['output'])
            protocols.append(intent.get('protocol', 'unknown'))
        except (json.JSONDecodeError, TypeError, AttributeError):
            protocols.append('unknown')
    
    protocol_counts = Counter(protocols)
    print(f"\nTest set composition:")
    for protocol, count in protocol_counts.most_common():
        pct = count / len(test_dataset) * 100
        print(f"  {protocol:15s}: {count:3d} ({pct:5.1f}%)")

## Loading Fine-Tuned Model

Now let's load the model we trained in the previous notebook.

**Note**: This loads the model with the adapter merged, ready for inference.

In [None]:
print("LOADING FINE-TUNED MODEL")
print("=" * 80)

# Define model path
model_dir = project_root / "models" / "fine-tuned" / "eth-intent-notebook" / "final"

# Check if model exists
if not model_dir.exists():
    print(f"\n⚠️  Model not found at: {model_dir}")
    print("   Please train a model first using notebook 04-fine-tuning.ipynb")
    print("\n   Alternative: Use a different model path")
    # Try alternative paths
    alt_paths = [
        project_root / "models" / "fine-tuned" / "eth-intent-extractor-v1",
        project_root / "models" / "fine-tuned" / "checkpoint-latest",
    ]
    for alt_path in alt_paths:
        if alt_path.exists():
            print(f"   Found alternative: {alt_path}")
            model_dir = alt_path
            break

if model_dir.exists():
    print(f"\nLoading model from: {model_dir}")
    print("This may take a few minutes...\n")
    
    # Load model and tokenizer
    model, tokenizer = load_model_for_evaluation(str(model_dir))
    
    print("\n✓ Model loaded successfully")
    print(f"  Model type: {type(model).__name__}")
    print(f"  Tokenizer:  {type(tokenizer).__name__}")
    
    # Check VRAM usage
    if torch.cuda.is_available():
        vram = torch.cuda.memory_allocated(0) / 1024**3
        print(f"  VRAM usage: {vram:.2f} GB")
    
    # Set to evaluation mode
    model.eval()
    print("\n✓ Model set to evaluation mode")

## Running Batch Inference

Let's run inference on all test examples and collect predictions.
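The cell below calls `run_inference` once per example. If that becomes slow, prompts can be grouped and generated in batches. The sketch below covers only the batching logic. `generate_batch` is a hypothetical callable you would supply, for example one that tokenizes a list of prompts with padding and calls `model.generate`. With Hugging Face decoder-only models, batched generation generally needs left padding (`tokenizer.padding_side = "left"`), and you may have to set a `pad_token` if the tokenizer lacks one. `format_prompt` reproduces the template used in the cell below. The helper names are illustrative and do not come from the project's `evaluator` module.

```python
from typing import Callable, Dict, Iterator, List, Sequence


def format_prompt(example: Dict[str, str]) -> str:
    """Build the same instruction/input/output prompt used in the inference cell."""
    return f"{example['instruction']}\n\nInput: {example['input']}\n\nOutput:"


def chunked(items: Sequence, size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most `size` items."""
    if size < 1:
        raise ValueError("size must be >= 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def run_in_batches(
    examples: Sequence[Dict[str, str]],
    generate_batch: Callable[[List[str]], List[str]],  # hypothetical: prompts -> decoded outputs
    batch_size: int = 8,
) -> List[str]:
    """Run `generate_batch` over all examples, preserving input order."""
    prompts = [format_prompt(ex) for ex in examples]
    outputs: List[str] = []
    for batch in chunked(prompts, batch_size):
        batch_outputs = generate_batch(list(batch))
        if len(batch_outputs) != len(batch):
            raise RuntimeError("generate_batch must return one output per prompt")
        outputs.extend(batch_outputs)
    return outputs
```

VRAM usually limits the largest usable `batch_size`. Keeping the per-example loop is the simpler choice when the test set is small.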

In [None]:
print("RUNNING BATCH INFERENCE")
print("=" * 80)
print(f"\nProcessing {len(test_dataset)} test examples...")
print("This may take several minutes\n")

predictions = []
ground_truths = []
raw_outputs = []

# Process each test example
for i, example in enumerate(test_dataset):
    if i % 10 == 0:
        print(f"  Progress: {i}/{len(test_dataset)} ({i/len(test_dataset)*100:.0f}%)")
    
    # Format prompt
    prompt = f"{example['instruction']}\n\nInput: {example['input']}\n\nOutput:"
    
    # Run inference
    output = run_inference(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_new_tokens=256,
        temperature=0.1,  # Low temperature: near-deterministic, but not guaranteed identical across runs
    )
    
    raw_outputs.append(output)
    
    # Parse prediction
    pred_intent = parse_json_output(output)
    predictions.append(pred_intent)
    
    # Parse ground truth
    try:
        truth_intent = json.loads(example['output'])
        ground_truths.append(truth_intent)
    except json.JSONDecodeError:
        ground_truths.append(None)

print(f"\n✓ Inference complete")
print(f"  Total predictions: {len(predictions)}")
print(f"  Valid predictions: {sum(1 for p in predictions if p is not None)}")
print(f"  Failed to parse:   {sum(1 for p in predictions if p is None)}")
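`parse_json_output` comes from the project's `evaluator` module, and this notebook does not show its implementation. If you need to investigate the "Failed to parse" count, the sketch below shows one plausible approach. It assumes the model emits a JSON object that may have extra text around it. The function scans for each `{` and lets `json.JSONDecoder.raw_decode` try to decode an object starting at that position. `extract_first_json_object` is a hypothetical name.

```python
import json
from typing import Optional


def extract_first_json_object(text: str) -> Optional[dict]:
    """Return the first decodable JSON object embedded in `text`, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores trailing text
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            start = text.find("{", start + 1)  # try the next opening brace
    return None
```

Applied to `raw_outputs`, this shows whether failures are truncated JSON (for example, hitting `max_new_tokens`) or output that contains no JSON at all.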

## Analyzing Predictions

Let's examine some predictions to see how the model is performing.

In [None]:
print("SAMPLE PREDICTIONS")
print("=" * 80)

# Show first 3 predictions
for i in range(min(3, len(predictions))):
    print(f"\n{'='*80}")
    print(f"Example {i+1}")
    print(f"{'='*80}")
    
    print("\nGround Truth:")
    if ground_truths[i]:
        print(json.dumps(ground_truths[i], indent=2))
    else:
        print("  (Failed to parse)")
    
    print("\nPrediction:")
    if predictions[i]:
        print(json.dumps(predictions[i], indent=2))
    else:
        print("  (Failed to parse)")
        print(f"\n  Raw output: {raw_outputs[i][:200]}...")
    
    # Compare key fields
    if predictions[i] and ground_truths[i]:
        print("\nComparison:")
        for key in ['action', 'protocol', 'outcome']:
            pred_val = predictions[i].get(key, 'N/A')
            truth_val = ground_truths[i].get(key, 'N/A')
            match = '✓' if pred_val == truth_val else '✗'
            print(f"  {match} {key:10s}: {str(pred_val):15s} vs {str(truth_val):15s}")

## Calculating Accuracy Metrics

Now let's calculate comprehensive accuracy metrics:
- **Overall accuracy**: Percentage of perfectly matched predictions
- **Amount accuracy**: Accuracy of numerical amounts (±1% tolerance)
- **Address accuracy**: Accuracy of Ethereum addresses
- **Protocol accuracy**: Protocol classification accuracy
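The exact matching rules live in `calculate_accuracy_metrics`, which this notebook does not show. As an illustration only, the sketch below gives one way to write the two field-level checks. Amounts are compared pairwise, and a prediction passes if it is within ±1% of the ground-truth value. The comparison uses `Decimal` so that large integer token amounts keep their precision, which floats would lose. Addresses are compared after lowercasing, because EIP-55 checksumming changes only the letter case of an address. The function names are hypothetical, and the sketch assumes amounts arrive as ordered lists of numeric strings.

```python
import re
from decimal import Decimal, InvalidOperation
from typing import Optional, Sequence

_ADDRESS_RE = re.compile(r"0x[0-9a-fA-F]{40}")


def amounts_match(pred: Sequence, truth: Sequence, rel_tol: float = 0.01) -> bool:
    """True if every predicted amount is within rel_tol of the matching true amount."""
    if len(pred) != len(truth):
        return False
    tol = Decimal(str(rel_tol))
    for p, t in zip(pred, truth):
        try:
            p_val, t_val = Decimal(str(p)), Decimal(str(t))
        except InvalidOperation:
            return False  # a non-numeric amount counts as a mismatch
        if not (p_val.is_finite() and t_val.is_finite()):
            return False
        if abs(p_val - t_val) > tol * abs(t_val):  # tolerance relative to the true value
            return False
    return True


def normalize_address(addr) -> Optional[str]:
    """Lowercase a 0x-prefixed, 40-hex-digit address; return None if malformed."""
    if not isinstance(addr, str) or not _ADDRESS_RE.fullmatch(addr.strip()):
        return None
    return addr.strip().lower()


def addresses_match(pred, truth) -> bool:
    """Case-insensitive address equality; malformed addresses never match."""
    p, t = normalize_address(pred), normalize_address(truth)
    return p is not None and p == t
```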

In [None]:
print("CALCULATING ACCURACY METRICS")
print("=" * 80)

# Calculate metrics
metrics = calculate_accuracy_metrics(
    predictions=predictions,
    ground_truths=ground_truths,
)

print("\nOVERALL METRICS")
print(f"  Overall Accuracy:  {metrics['overall_accuracy']*100:6.2f}%")
print(f"  Amount Accuracy:   {metrics['amount_accuracy']*100:6.2f}%")
print(f"  Address Accuracy:  {metrics['address_accuracy']*100:6.2f}%")
print(f"  Protocol Accuracy: {metrics['protocol_accuracy']*100:6.2f}%")

# Check against targets
target_accuracy = 0.90  # 90% target from SPEC
print(f"\nTarget: {target_accuracy*100:.0f}% accuracy")

all_above_target = all([
    metrics['amount_accuracy'] >= target_accuracy,
    metrics['address_accuracy'] >= target_accuracy,
    metrics['protocol_accuracy'] >= target_accuracy,
])

if all_above_target:
    print("\n✓ All metrics meet 90% accuracy target!")
else:
    print("\n⚠️  Some metrics below 90% target:")
    if metrics['amount_accuracy'] < target_accuracy:
        print(f"   - Amount accuracy: {metrics['amount_accuracy']*100:.2f}%")
    if metrics['address_accuracy'] < target_accuracy:
        print(f"   - Address accuracy: {metrics['address_accuracy']*100:.2f}%")
    if metrics['protocol_accuracy'] < target_accuracy:
        print(f"   - Protocol accuracy: {metrics['protocol_accuracy']*100:.2f}%")

## Per-Protocol Performance

Let's break down performance by protocol to identify strengths and weaknesses.
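`calculate_per_protocol_metrics` is another project helper whose internals are not shown here. The sketch below captures the core idea under two assumptions: examples are grouped by their ground-truth protocol, and an example counts as correct only when the whole predicted dict equals the ground truth. Grouping by the true label means a misclassified example counts against the protocol it actually belongs to. The table in the cell below also reads `amount_accuracy` and `address_accuracy` when they are present, and this hypothetical `per_protocol_accuracy` omits both.

```python
from collections import defaultdict
from typing import Dict, List, Optional


def per_protocol_accuracy(
    predictions: List[Optional[dict]],
    ground_truths: List[Optional[dict]],
) -> Dict[str, Dict[str, float]]:
    """Exact-match accuracy grouped by the ground-truth protocol."""
    counts: Dict[str, int] = defaultdict(int)
    correct: Dict[str, int] = defaultdict(int)
    for pred, truth in zip(predictions, ground_truths):
        if truth is None:
            continue  # unlabeled examples can't be scored
        protocol = truth.get("protocol", "unknown")
        counts[protocol] += 1
        if pred is not None and pred == truth:
            correct[protocol] += 1
    return {p: {"count": counts[p], "accuracy": correct[p] / counts[p]} for p in counts}
```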

In [None]:
print("PER-PROTOCOL PERFORMANCE")
print("=" * 80)

# Calculate per-protocol metrics
protocol_metrics = calculate_per_protocol_metrics(
    predictions=predictions,
    ground_truths=ground_truths,
)

# Display as table
protocol_data = []
for protocol, stats in protocol_metrics.items():
    protocol_data.append({
        'Protocol': protocol,
        'Count': stats['count'],
        'Accuracy': f"{stats['accuracy']*100:.1f}%",
        'Amount Acc': f"{stats.get('amount_accuracy', 0)*100:.1f}%",
        'Address Acc': f"{stats.get('address_accuracy', 0)*100:.1f}%",
    })

df_protocols = pd.DataFrame(protocol_data)
print("\n" + df_protocols.to_string(index=False))

# Identify best and worst performing protocols
accuracies = [(p, s['accuracy']) for p, s in protocol_metrics.items()]
accuracies.sort(key=lambda x: x[1], reverse=True)

print(f"\nBest performing:  {accuracies[0][0]} ({accuracies[0][1]*100:.1f}%)")
print(f"Worst performing: {accuracies[-1][0]} ({accuracies[-1][1]*100:.1f}%)")

## Confusion Matrix

Let's visualize protocol classification with a confusion matrix.

In [None]:
from sklearn.metrics import confusion_matrix

print("CONFUSION MATRIX")
print("=" * 80)

# Extract protocols for confusion matrix
pred_protocols = [p.get('protocol', 'unknown') if p else 'unknown' for p in predictions]
truth_protocols = [t.get('protocol', 'unknown') if t else 'unknown' for t in ground_truths]

# Get unique protocols
all_protocols = sorted(set(pred_protocols + truth_protocols))

# Compute confusion matrix
cm = confusion_matrix(truth_protocols, pred_protocols, labels=all_protocols)

# Visualize
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=all_protocols,
    yticklabels=all_protocols,
    cbar_kws={'label': 'Count'}
)
plt.title('Protocol Classification Confusion Matrix', fontsize=14, fontweight='bold')
plt.ylabel('True Protocol', fontsize=12)
plt.xlabel('Predicted Protocol', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

print("\nDiagonal elements = correct predictions")
print("Off-diagonal elements = misclassifications")
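When the test set is imbalanced, the largest protocols dominate the color scale and make raw counts hard to read. Dividing each row by its total turns the diagonal into per-protocol recall. scikit-learn's `confusion_matrix` can do this directly with `normalize='true'`. The NumPy sketch below applies the same normalization to an existing `cm` and leaves rows with no examples at zero instead of dividing by zero. `row_normalize` is a hypothetical helper name.

```python
import numpy as np


def row_normalize(cm) -> np.ndarray:
    """Divide each row of a confusion matrix by its sum; all-zero rows stay zero."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips rows whose sum is 0, leaving the pre-filled zeros from `out`
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
```

To plot it, pass `row_normalize(cm)` to `sns.heatmap` with `fmt='.2f'` instead of `fmt='d'`.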

## Visualizing Metric Performance

In [None]:
# Create comprehensive visualization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Overall Metrics Bar Chart
metric_names = ['Overall', 'Amount', 'Address', 'Protocol']
metric_values = [
    metrics['overall_accuracy'],
    metrics['amount_accuracy'],
    metrics['address_accuracy'],
    metrics['protocol_accuracy'],
]
colors = ['steelblue', 'coral', 'lightgreen', 'mediumpurple']

bars = axes[0, 0].bar(metric_names, metric_values, color=colors)
axes[0, 0].axhline(y=0.90, color='red', linestyle='--', label='90% Target')
axes[0, 0].set_title('Accuracy Metrics Overview', fontsize=14, fontweight='bold')
axes[0, 0].set_ylabel('Accuracy')
axes[0, 0].set_ylim(0, 1.1)  # headroom so value labels near 100% stay inside the plot
axes[0, 0].legend()

# Add value labels
for bar, value in zip(bars, metric_values):
    height = bar.get_height()
    axes[0, 0].text(bar.get_x() + bar.get_width()/2., height + 0.02,
                    f'{value*100:.1f}%', ha='center', va='bottom', fontweight='bold')

# Plot 2: Per-Protocol Accuracy
protocols_sorted = sorted(protocol_metrics.items(), key=lambda x: x[1]['accuracy'], reverse=True)
proto_names = [p[0] for p in protocols_sorted]
proto_accs = [p[1]['accuracy'] for p in protocols_sorted]

axes[0, 1].barh(proto_names, proto_accs, color='steelblue')
axes[0, 1].axvline(x=0.90, color='red', linestyle='--', label='90% Target')
axes[0, 1].set_title('Per-Protocol Accuracy', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Accuracy')
axes[0, 1].set_xlim(0, 1.0)
axes[0, 1].legend()

# Plot 3: Sample Count by Protocol
proto_counts = [p[1]['count'] for p in protocols_sorted]
axes[1, 0].barh(proto_names, proto_counts, color='coral')
axes[1, 0].set_title('Test Set Distribution', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Number of Examples')

# Plot 4: Metric Comparison
metric_comparison = pd.DataFrame({
    'Metric': ['Amount', 'Address', 'Protocol'],
    'Accuracy': [
        metrics['amount_accuracy'],
        metrics['address_accuracy'],
        metrics['protocol_accuracy'],
    ],
    'Target': [0.90, 0.90, 0.90]
})

x = np.arange(len(metric_comparison))
width = 0.35

axes[1, 1].bar(x - width/2, metric_comparison['Accuracy'], width, label='Actual', color='steelblue')
axes[1, 1].bar(x + width/2, metric_comparison['Target'], width, label='Target', color='red', alpha=0.5)
axes[1, 1].set_title('Target vs Actual Performance', fontsize=14, fontweight='bold')
axes[1, 1].set_ylabel('Accuracy')
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels(metric_comparison['Metric'])
axes[1, 1].set_ylim(0, 1.0)
axes[1, 1].legend()

plt.tight_layout()
plt.show()

print("\n📊 Visualizations complete")

## Error Analysis

Let's analyze prediction errors to understand failure modes.

In [None]:
print("ERROR ANALYSIS")
print("=" * 80)

# Categorize errors
error_types = defaultdict(list)

for i, (pred, truth) in enumerate(zip(predictions, ground_truths)):
    if pred is None:
        error_types['parsing_failed'].append(i)
        continue
    
    if truth is None:
        continue
    
    # Check for specific errors
    if pred.get('protocol') != truth.get('protocol'):
        error_types['protocol_mismatch'].append(i)
    
    if pred.get('action') != truth.get('action'):
        error_types['action_mismatch'].append(i)
    
    if pred.get('outcome') != truth.get('outcome'):
        error_types['outcome_mismatch'].append(i)
    
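    # Exact list comparison: stricter than the ±1% tolerance behind amount_accuracy,
    # so this count can exceed the amount errors implied by that metric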
    pred_amounts = pred.get('amounts', [])
    truth_amounts = truth.get('amounts', [])
    if pred_amounts != truth_amounts:
        error_types['amount_mismatch'].append(i)

# Display error summary
print("\nError Type Summary:")
for error_type, indices in sorted(error_types.items(), key=lambda x: len(x[1]), reverse=True):
    count = len(indices)
    pct = count / len(predictions) * 100
    print(f"  {error_type:20s}: {count:3d} ({pct:5.1f}%)")

# Show example errors
if error_types['protocol_mismatch']:
    print("\nExample Protocol Mismatch:")
    idx = error_types['protocol_mismatch'][0]
    print(f"  Truth:      {ground_truths[idx].get('protocol')}")
    print(f"  Prediction: {predictions[idx].get('protocol')}")
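The summary above counts mismatches for a fixed set of fields. A complementary view lists every top-level key where a prediction and its ground truth disagree, which also catches fields the loop does not check, such as addresses. The sketch below shows this with two hypothetical helpers, `diff_fields` and `field_error_counts`. It compares values exactly and treats a missing key as a mismatch.

```python
from collections import Counter
from typing import List, Optional

_MISSING = object()  # sentinel so a missing key never equals a real value


def diff_fields(pred: Optional[dict], truth: Optional[dict]) -> List[str]:
    """Return sorted top-level keys whose values differ between pred and truth."""
    if pred is None or truth is None:
        return ["<unparseable>"]
    keys = set(pred) | set(truth)
    return sorted(k for k in keys if pred.get(k, _MISSING) != truth.get(k, _MISSING))


def field_error_counts(predictions, ground_truths) -> Counter:
    """Count how often each field is wrong across the test set."""
    counts: Counter = Counter()
    for pred, truth in zip(predictions, ground_truths):
        counts.update(diff_fields(pred, truth))
    return counts
```

For example, `field_error_counts(predictions, ground_truths).most_common(5)` ranks the fields the model gets wrong most often.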

## Generating Evaluation Report

Let's generate a comprehensive markdown report.
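`generate_evaluation_report` builds the Markdown itself. The sketch below shows roughly how a per-protocol section of such a report could be produced with plain string formatting. It renders the `protocol_metrics` dict as a Markdown table and assumes the `count` and `accuracy` keys used earlier in this notebook. `protocol_table_markdown` is a hypothetical name, not part of the project's `report` module. pandas' `DataFrame.to_markdown` is an alternative, but it needs the optional `tabulate` package.

```python
from typing import Dict


def protocol_table_markdown(protocol_metrics: Dict[str, dict]) -> str:
    """Render per-protocol count/accuracy as a Markdown table, best first."""
    lines = ["| Protocol | Count | Accuracy |", "|---|---:|---:|"]
    ranked = sorted(protocol_metrics.items(), key=lambda kv: kv[1]["accuracy"], reverse=True)
    for protocol, stats in ranked:
        # Note: a '|' inside a protocol name would need escaping as '\|'
        lines.append(f"| {protocol} | {stats['count']} | {stats['accuracy'] * 100:.1f}% |")
    return "\n".join(lines)
```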

In [None]:
print("GENERATING EVALUATION REPORT")
print("=" * 80)

# Create output directory
output_dir = project_root / "outputs" / "reports"
output_dir.mkdir(parents=True, exist_ok=True)

report_path = output_dir / "evaluation_report.md"

# Generate report
report = generate_evaluation_report(
    metrics=metrics,
    protocol_metrics=protocol_metrics,
    test_size=len(test_dataset),
)

# Save report
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report)

print(f"\n✓ Report saved to: {report_path}")

# Display report preview
print("\nReport Preview:")
print("=" * 80)
print(report[:1000])
if len(report) > 1000:
    print("\n... (truncated)")
    print(f"\nFull report: {report_path}")

## Saving Predictions

Save predictions for future analysis or debugging.

In [None]:
print("SAVING PREDICTIONS")
print("=" * 80)

# Create predictions output
predictions_output = []
for i, (pred, truth, raw) in enumerate(zip(predictions, ground_truths, raw_outputs)):
    predictions_output.append({
        'example_id': i,
        'ground_truth': truth,
        'prediction': pred,
        'raw_output': raw,
        'correct': pred == truth if pred and truth else False,
    })

# Save to JSON
predictions_dir = project_root / "outputs" / "predictions"
predictions_dir.mkdir(parents=True, exist_ok=True)
predictions_file = predictions_dir / "test_predictions.json"

with open(predictions_file, 'w') as f:
    json.dump(predictions_output, f, indent=2)

print(f"\n✓ Predictions saved to: {predictions_file}")
print(f"  Total predictions: {len(predictions_output)}")

# Also save metrics
metrics_file = predictions_dir.parent / "metrics" / "test_metrics.json"
metrics_file.parent.mkdir(parents=True, exist_ok=True)

metrics_output = {
    'overall_metrics': metrics,
    'protocol_metrics': protocol_metrics,
    'test_size': len(test_dataset),
}

with open(metrics_file, 'w') as f:
    json.dump(metrics_output, f, indent=2)

print(f"✓ Metrics saved to: {metrics_file}")
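The notebook does not show what types `calculate_accuracy_metrics` and `calculate_per_protocol_metrics` return. If any value is a NumPy array, or a NumPy scalar other than `np.float64` (which subclasses Python `float`), `json.dump` raises `TypeError`. Passing a `default` hook is one way to handle this. `json_default` below is a hypothetical helper that converts NumPy scalars with `.item()` and arrays with `.tolist()`.

```python
import json

import numpy as np


def json_default(obj):
    """Fallback for json.dump: convert NumPy values to built-in Python types."""
    if isinstance(obj, np.generic):
        return obj.item()  # e.g. np.int64 -> int, np.float32 -> float
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
```

Usage: `json.dump(metrics_output, f, indent=2, default=json_default)`.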

## Final Summary

In [None]:
print("\n" + "=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)

print(f"\n📊 Test Set: {len(test_dataset)} examples")
print(f"\n🎯 Overall Performance:")
print(f"   Overall Accuracy:  {metrics['overall_accuracy']*100:6.2f}%")
print(f"   Amount Accuracy:   {metrics['amount_accuracy']*100:6.2f}%")
print(f"   Address Accuracy:  {metrics['address_accuracy']*100:6.2f}%")
print(f"   Protocol Accuracy: {metrics['protocol_accuracy']*100:6.2f}%")

print(f"\n🏆 Best Protocol: {accuracies[0][0]} ({accuracies[0][1]*100:.1f}%)")
print(f"📉 Worst Protocol: {accuracies[-1][0]} ({accuracies[-1][1]*100:.1f}%)")

# Target achievement
target_met = sum([
    metrics['amount_accuracy'] >= 0.90,
    metrics['address_accuracy'] >= 0.90,
    metrics['protocol_accuracy'] >= 0.90,
])

print(f"\n✓ Target Achievement: {target_met}/3 metrics above 90%")

if target_met == 3:
    print("\n🎉 SUCCESS: All accuracy targets met!")
elif target_met >= 2:
    print("\n👍 GOOD: Most accuracy targets met")
else:
    print("\n⚠️  NEEDS IMPROVEMENT: Consider more training or data")

print(f"\n📁 Output Files:")
print(f"   Report:      {report_path}")
print(f"   Predictions: {predictions_file}")
print(f"   Metrics:     {metrics_file}")

print("\n" + "=" * 80)
print("✓ Evaluation complete!")
print("=" * 80)

## Key Takeaways

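✓ **Model Loading**: Fine-tuned adapters are small (~100MB vs ~14GB for the full model), so they are cheap to store and share; evaluation still has to load the base model they sit on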

✓ **Batch Inference**: Run the model over the whole test set with a consistent prompt template and low-temperature decoding

✓ **Comprehensive Metrics**: Multiple accuracy measures reveal different aspects of performance

✓ **Per-Protocol Analysis**: Identifies which transaction types the model handles best and worst

✓ **Confusion Matrix**: Visualizes classification errors and patterns

✓ **Error Analysis**: Categorizes failures to guide improvement efforts

## Improvement Strategies

**If accuracy is below target:**

1. **More Training Data**: Collect additional examples, especially for underperforming protocols
2. **Longer Training**: Increase epochs or training steps
3. **Hyperparameter Tuning**: Adjust learning rate, LoRA rank, batch size
4. **Data Quality**: Review and improve dataset quality, fix labeling errors
5. **Prompt Engineering**: Refine instruction templates for clarity
6. **Model Selection**: Try different base models (e.g., Llama-2 vs Mistral)

**If specific protocols underperform:**

1. **Stratified Sampling**: Ensure balanced representation in training
2. **Protocol-Specific Fine-tuning**: Create focused datasets for problematic protocols
3. **Data Augmentation**: Generate synthetic examples for rare protocols
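Stratified sampling is normally applied when the dataset is split (notebook 03), not during evaluation. The sketch below illustrates the idea, assuming each example can be labeled with its protocol (for instance, by parsing its `output` field as this notebook does). The hypothetical `stratified_split` shuffles each protocol's examples separately and takes the same fraction from each, so rare protocols appear in both splits in proportion. scikit-learn's `train_test_split(..., stratify=labels)` provides similar behavior.

```python
import random
from collections import defaultdict
from typing import Callable, Hashable, List, Sequence, Tuple


def stratified_split(
    examples: Sequence,
    label_fn: Callable[[object], Hashable],
    test_fraction: float = 0.1,
    seed: int = 42,
) -> Tuple[List, List]:
    """Split examples so each label keeps roughly the same train/test ratio."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    groups = defaultdict(list)
    for ex in examples:
        groups[label_fn(ex)].append(ex)
    rng = random.Random(seed)  # fixed seed -> reproducible split
    train, test = [], []
    for label in sorted(groups, key=str):  # deterministic group order
        items = list(groups[label])
        rng.shuffle(items)
        n_test = round(len(items) * test_fraction)  # may be 0 for very small groups
        test.extend(items[:n_test])
        train.extend(items[n_test:])
    return train, test
```

Very rare protocols can still end up with no test examples after rounding. Check the per-protocol counts after splitting.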

## Congratulations!

You've completed the full fine-tuning pipeline:
1. ✅ Data Exploration
2. ✅ Data Extraction
3. ✅ Dataset Preparation
4. ✅ Fine-Tuning
5. ✅ Evaluation

You now know how to fine-tune language models on blockchain data using QLoRA! 🎉

---

**Next Steps**: Use your fine-tuned model for production inference, or iterate on the training process to improve performance.