# Agent Validation Notebook

This notebook provides comprehensive validation of trained RL agents to ensure they are actually learning and not just memorizing paths or getting stuck in local optima.

## Validation Tests:
1. **Learning vs Random**: Statistical significance testing
2. **Convergence Analysis**: Stability and improvement over time
3. **Exploration Analysis**: State coverage and action diversity
4. **Generalization**: Performance across different scenarios
5. **Local Optima Detection**: Multiple training runs comparison
6. **Policy Analysis**: Learned behavior examination

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add src to path for imports
sys.path.append('../src')

from apr import WarehouseEnv, RLAgentValidator
from apr.agents import create_agent

# Set up plotting: apply a matplotlib stylesheet, then switch seaborn's
# default color palette to "husl".
# NOTE(review): 'seaborn-v0_8' is the style name shipped with matplotlib >= 3.6
# (older releases call it 'seaborn') — confirm against the pinned version.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
%matplotlib inline

## Setup Environment and Agents

In [None]:
# Build the warehouse environment with a fixed seed and report its dimensions
env = WarehouseEnv(seed=42)
print(f"Environment: {env.n_rows}x{env.n_cols} warehouse")
print(f"Action space: {env.n_actions} actions")
print(f"State space: {env.n_rows * env.n_cols} possible states")

# All agents are created with the same hyperparameters; only the algorithm differs
shared_hyperparams = {'alpha': 0.1, 'gamma': 0.95, 'epsilon': 0.3}
agent_kinds = (
    ('Q-Learning', 'q_learning'),
    ('Double Q-Learning', 'double_q_learning'),
    ('SARSA', 'sarsa'),
)

# Map display name -> newly created agent, in the order listed above
agents_to_validate = {
    display_name: create_agent(kind, env.observation_space, env.action_space,
                               **shared_hyperparams)
    for display_name, kind in agent_kinds
}

print(f"\nAgents to validate: {list(agents_to_validate.keys())}")

## Validation Function

In [None]:
def validate_agent(agent_name, agent, training_episodes=400, test_episodes=50, n_seeds=3):
    """
    Run the full validation suite for one agent and return its results.

    Prints a banner naming the agent, wraps the agent together with the
    notebook-level `env` in an RLAgentValidator (verbose output enabled), and
    forwards the episode and seed budgets to `full_validation`.

    Args:
        agent_name: Label shown in the banner.
        agent: Agent instance handed to RLAgentValidator.
        training_episodes: Forwarded as `training_episodes`.
        test_episodes: Forwarded as `test_episodes`.
        n_seeds: Forwarded as `n_seeds`.

    Returns:
        Whatever `RLAgentValidator.full_validation` returns.
    """
    banner_rule = "=" * 50
    print(f"\n🔍 Validating {agent_name}")
    print(banner_rule)

    # Budgets handed unchanged to the validation suite
    suite_budget = {
        'training_episodes': training_episodes,
        'test_episodes': test_episodes,
        'n_seeds': n_seeds,
    }

    return RLAgentValidator(agent, env, verbose=True).full_validation(**suite_budget)

## Individual Agent Validation

In [None]:
# Run the validation suite on the Q-Learning agent (default episode/seed budgets)
qlearning_results = validate_agent(
    agent_name='Q-Learning',
    agent=agents_to_validate['Q-Learning'],
)

In [None]:
# Run the validation suite on the Double Q-Learning agent (default episode/seed budgets)
double_qlearning_results = validate_agent(
    agent_name='Double Q-Learning',
    agent=agents_to_validate['Double Q-Learning'],
)

In [None]:
# Run the validation suite on the SARSA agent (default episode/seed budgets)
sarsa_results = validate_agent(
    agent_name='SARSA',
    agent=agents_to_validate['SARSA'],
)

## Validation Results Comparison

In [None]:
# Gather each agent's validation results under its display name
all_results = {
    'Q-Learning': qlearning_results,
    'Double Q-Learning': double_qlearning_results,
    'SARSA': sarsa_results
}

# Header for the comparison summary
print("\n📊 VALIDATION COMPARISON SUMMARY")
print("=" * 60)

# One pre-formatted row per agent: verdict, learning gain, p-value,
# state coverage and generalization consistency
comparison_data = [
    {
        'Agent': agent_name,
        'Overall': agent_results['summary']['overall_assessment'],
        'Learning Improvement': f"{agent_results['learning']['improvement']:.1f}",
        'P-Value': f"{agent_results['learning']['statistical_test']['p_value']:.4f}",
        'State Coverage': f"{agent_results['exploration']['state_coverage']['coverage_percent']:.1f}%",
        'Generalization': f"{agent_results['generalization']['consistency_score']:.3f}",
    }
    for agent_name, agent_results in all_results.items()
]

# Render the rows as a plain-text table
import pandas as pd
comparison_df = pd.DataFrame(comparison_data)
print(comparison_df.to_string(index=False))

# Tally how many agents received each overall verdict
outcomes = []
for agent_results in all_results.values():
    outcomes.append(agent_results['summary']['overall_assessment'])

print(f"\n📈 Validation Outcomes:")
for verdict in ('PASS', 'PASS_WITH_WARNINGS', 'FAIL'):
    print(f"  {verdict}: {outcomes.count(verdict)}")

## Detailed Validation Analysis

In [None]:
# Create detailed comparison plots: a 2x3 grid with one bar chart per metric
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

agent_names = list(all_results.keys())
colors = ['skyblue', 'lightgreen', 'lightcoral']


def _bar_panel(ax, values, title, ylabel, label_fmt=None):
    """Draw one per-agent bar chart on `ax` and return the bar container.

    Args:
        ax: Target matplotlib Axes.
        values: One numeric value per entry of `agent_names`.
        title: Panel title.
        ylabel: Y-axis label.
        label_fmt: printf-style format for per-bar value labels; None draws no labels.
    """
    bars = ax.bar(agent_names, values, color=colors)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='x', rotation=45)
    if label_fmt is not None:
        # bar_label pads in points rather than data units, so labels stay
        # readable whatever the value scale and also sit correctly on negative bars.
        ax.bar_label(bars, fmt=label_fmt, padding=3)
    return bars


# 1. Learning improvement comparison
ax1 = axes[0, 0]
improvements = [all_results[name]['learning']['improvement'] for name in agent_names]
bars1 = _bar_panel(ax1, improvements, 'Learning vs Random Performance',
                   'Reward Improvement', label_fmt='%.1f')

# 2. State coverage comparison
ax2 = axes[0, 1]
coverages = [all_results[name]['exploration']['state_coverage']['coverage_percent']
             for name in agent_names]
bars2 = _bar_panel(ax2, coverages, 'State Coverage', 'Coverage %', label_fmt='%.1f%%')

# 3. Generalization consistency
ax3 = axes[0, 2]
consistencies = [all_results[name]['generalization']['consistency_score']
                 for name in agent_names]
bars3 = _bar_panel(ax3, consistencies, 'Generalization Consistency',
                   'Consistency Score', label_fmt='%.3f')

# 4. Statistical significance
ax4 = axes[1, 0]
p_values = [all_results[name]['learning']['statistical_test']['p_value']
            for name in agent_names]
# A p-value that underflows to exactly 0 would give -log10(0) = inf and an
# undrawable bar; floor it at the smallest positive normal float instead.
smallest_p = np.finfo(float).tiny
neg_log_p_values = [-np.log10(max(p, smallest_p)) for p in p_values]
bars4 = _bar_panel(ax4, neg_log_p_values, 'Statistical Significance (-log10(p))',
                   '-log10(p-value)')
ax4.axhline(-np.log10(0.05), color='red', linestyle='--', label='p=0.05 threshold')
ax4.legend()

# 5. Action diversity
ax5 = axes[1, 1]
diversities = [all_results[name]['exploration']['action_diversity']['diversity_percent']
               for name in agent_names]
bars5 = _bar_panel(ax5, diversities, 'Action Diversity', 'Diversity %', label_fmt='%.1f%%')

# 6. Local optima performance range
ax6 = axes[1, 2]
ranges = [all_results[name]['local_optima']['performance_range']
          for name in agent_names]
bars6 = _bar_panel(ax6, ranges, 'Local Optima (Performance Range)', 'Performance Range')
ax6.axhline(50, color='red', linestyle='--', label='Concern threshold')
ax6.legend()

plt.tight_layout()
plt.show()

## Validation Issues and Recommendations

In [None]:
print("⚠️  VALIDATION ISSUES AND RECOMMENDATIONS")
print("=" * 55)

# Per-agent report: overall verdict followed by any critical issues and warnings
for agent_name, results in all_results.items():
    summary = results['summary']

    print(f"\n{agent_name}:")
    print(f"  Overall Assessment: {summary['overall_assessment']}")

    anything_flagged = False
    for heading, summary_key in (("  ❌ Critical Issues:", 'critical_issues'),
                                 ("  ⚠️  Warnings:", 'warnings')):
        entries = summary[summary_key]
        if entries:
            anything_flagged = True
            print(heading)
            for entry in entries:
                print(f"    - {entry}")

    if not anything_flagged:
        print("  ✅ No issues detected - agent is learning properly!")

# Cross-agent recommendations
print("\n💡 OVERALL RECOMMENDATIONS:")
print("-" * 30)

# Agents whose state coverage is below 50%
low_coverage_agents = []
for name, results in all_results.items():
    if results['exploration']['state_coverage']['coverage_percent'] < 50:
        low_coverage_agents.append(name)

if low_coverage_agents:
    print(f"• Agents with low state coverage: {', '.join(low_coverage_agents)}")
    print("  → Consider increasing exploration (higher epsilon) or longer training")

# Agents whose learning-vs-random test was not flagged as significant
non_significant = []
for name, results in all_results.items():
    if not results['learning']['statistical_test']['significant']:
        non_significant.append(name)

if non_significant:
    print(f"• Agents without statistically significant learning: {', '.join(non_significant)}")
    print("  → These agents may not be learning effectively")

# Rank agents from largest to smallest learning improvement
performance_ranking = sorted(
    all_results.items(),
    key=lambda entry: entry[1]['learning']['improvement'],
    reverse=True,
)

print(f"\n🏆 Performance Ranking (by learning improvement):")
for rank, (name, results) in enumerate(performance_ranking, start=1):
    improvement = results['learning']['improvement']
    print(f"  {rank}. {name}: {improvement:.1f} reward improvement")

## Individual Agent Visualizations

In [None]:
# The ranking is sorted best-first, so its head is the top-performing agent
best_agent_name, best_results = performance_ranking[0]

print(f"🔍 Detailed visualization for best agent: {best_agent_name}")

# Build a quiet validator around that agent and attach its stored results
# rather than re-running the validation suite
best_agent = agents_to_validate[best_agent_name]
validator = RLAgentValidator(best_agent, env, verbose=False)
validator.validation_results = best_results

# Render the validator's own visualization of those results
validator.visualize_results()

print(f"\n✅ Detailed validation visualization complete for {best_agent_name}")

## Save Validation Results

In [None]:
# Save validation results
results_dir = Path('../validation_results')
# parents=True so the cell also works when intermediate directories are missing
results_dir.mkdir(parents=True, exist_ok=True)

import json
import pickle  # NOTE(review): not used in this cell; kept in case other cells rely on it


def _to_jsonable(value):
    """Recursively convert NumPy arrays/scalars inside `value` to plain Python objects.

    Dicts, lists and tuples are walked at any depth (tuples come back as lists,
    which is how JSON would encode them anyway). NumPy arrays become nested
    lists and NumPy scalars become native Python scalars, so they are written
    as real JSON numbers/booleans instead of being stringified by
    `json.dump(default=str)`. Any other value is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        return {
            # json.dump rejects NumPy scalar keys, so unwrap those as well
            (key.item() if isinstance(key, np.generic) else key): _to_jsonable(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


# Save detailed results, one JSON file per agent
for agent_name, results in all_results.items():
    json_results = _to_jsonable(results)

    # File name is the agent's display name, lower-cased with spaces as underscores
    json_path = results_dir / f'{agent_name.lower().replace(" ", "_")}_validation.json'
    with open(json_path, 'w', encoding='utf-8') as f:
        # default=str remains as a last resort for objects JSON cannot encode
        json.dump(json_results, f, indent=2, default=str)

    print(f"💾 Saved {agent_name} validation results to {json_path}")

# Save comparison summary
comparison_path = results_dir / 'validation_comparison.csv'
comparison_df.to_csv(comparison_path, index=False)
print(f"📊 Saved comparison summary to {comparison_path}")

print(f"\n✅ All validation results saved to {results_dir}")