# SARSA(λ): Complete Training, Validation & Testing

This notebook provides a complete workflow for a SARSA(λ) agent:
1. **Training**: Learn policy with eligibility traces
2. **Validation**: Train a fresh agent across several seeds and check that it learns significantly better than a random policy
3. **Testing**: Evaluate final performance and behavior

## SARSA(λ) Algorithm
- **Type**: On-policy temporal difference learning with eligibility traces
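- **Update Rule**: $Q(s,a) \leftarrow Q(s,a) + \alpha\,\delta_t\,e(s,a)$ for every $(s,a)$, where $\delta_t = r_{t+1} + \gamma\,Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t)$ is the on-policy TD error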
- **Key Feature**: Eligibility traces for faster credit assignment
- **Lambda (λ)**: Controls trace decay ($\lambda = 0$ reduces to one-step SARSA; $\lambda \to 1$ approaches Monte Carlo returns)

In [None]:
import sys
import time
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Add src to path so the local `apr` package is importable
sys.path.append('../src')

from apr import WarehouseEnv, RLAgentValidator, AgentEvaluator
from apr.agents import create_agent
from apr.train import run_episode
from apr.utils import ensure_outputs_directory
from apr.utils import get_outputs_dir

# Set up plotting: 'seaborn-v0_8' is the Matplotlib >= 3.6 name of the legacy
# seaborn style; figures are rendered inline in the notebook.
plt.style.use('seaborn-v0_8')
%matplotlib inline

# Ensure output directories exist (apr.utils helper; presumably creates the
# outputs/ tree that the save cell at the end writes into — confirm in apr.utils)
ensure_outputs_directory()
print("✅ Output directories ready")

## 1. Environment & Agent Setup

In [None]:
# Create the environment with a fixed seed, then reset it *before* reading any
# episode state, so the printed package count reflects an initialized episode.
env = WarehouseEnv(seed=42)
env.reset()

print(f"Environment: {env.n_rows}x{env.n_cols} warehouse")
print(f"Packages to collect: {len(env.packages_remaining)}")
print(f"Max steps per episode: {env.max_steps}")

# Visualize the initial layout. render(mode='human') is assumed to draw onto the
# current Matplotlib figure, which is why plt.title / plt.show follow it.
env.render(mode='human')
plt.title('SARSA(λ) Training Environment')
plt.show()

# Create SARSA(λ) agent
agent = create_agent(
    'sarsa_lambda',
    env.observation_space,
    env.action_space,
    alpha=0.1,           # Learning rate
    gamma=0.95,          # Discount factor
    epsilon=0.3,         # Initial exploration rate
    epsilon_decay=0.999, # Exploration decay
    epsilon_min=0.05,    # Minimum exploration
    lambda_=0.9          # Eligibility trace decay
)

print(f"\n🤖 Created SARSA(λ) Agent:")
print(f"  Algorithm: {type(agent).__name__}")
print(f"  Learning rate (α): {agent.alpha}")
print(f"  Discount factor (γ): {agent.gamma}")
print(f"  Initial exploration (ε): {agent.epsilon}")
print(f"  Eligibility trace decay (λ): {agent.lambda_}")
print(f"  Learning type: On-policy with eligibility traces")

## 2. Training Phase

In [None]:
# Training parameters
training_episodes = 800
log_interval = 100

print("🚀 Starting SARSA(λ) Training")
print("=" * 40)

# Per-episode training metrics; entry i of every list belongs to episode i
episode_rewards = []
episode_lengths = []
epsilon_values = []
q_table_sizes = []
eligibility_trace_sizes = []
training_times = []

# perf_counter is monotonic and high-resolution, so measured durations are not
# affected by wall-clock adjustments the way time.time() differences are.
start_time = time.perf_counter()

for episode in range(training_episodes):
    episode_start = time.perf_counter()
    
    # Run one training episode; run_episode's return value is recorded as the episode reward
    reward = run_episode(env, agent, training=True)
    
    # Track metrics
    episode_rewards.append(reward)
    # Fall back to 0 when the environment does not expose an episode length
    episode_lengths.append(getattr(env, 'episode_length', 0))
    epsilon_values.append(agent.epsilon)
    q_table_sizes.append(len(agent.Q))
    
    # Count trace entries (presumably one per action) above 0.01 at episode end,
    # if the agent exposes its eligibility traces as `E`
    if hasattr(agent, 'E'):
        active_traces = sum(1 for state_traces in agent.E.values()
                            for trace in state_traces if trace > 0.01)
        eligibility_trace_sizes.append(active_traces)
    else:
        eligibility_trace_sizes.append(0)
    
    training_times.append(time.perf_counter() - episode_start)
    
    # Logging: averages over the last `log_interval` episodes
    if (episode + 1) % log_interval == 0:
        avg_reward = np.mean(episode_rewards[-log_interval:])
        avg_time = np.mean(training_times[-log_interval:])
        avg_traces = np.mean(eligibility_trace_sizes[-log_interval:])
        print(f"Episode {episode + 1:3d}: Reward = {avg_reward:6.1f}, ε = {agent.epsilon:.3f}, "
              f"Q-states = {len(agent.Q):3d}, Traces = {avg_traces:4.1f}, Time = {avg_time:.3f}s")

total_training_time = time.perf_counter() - start_time

print(f"\n✅ Training Complete!")
print(f"  Total time: {total_training_time:.1f}s")
print(f"  Final performance: {np.mean(episode_rewards[-50:]):.1f} (last 50 episodes)")
print(f"  Q-table size: {len(agent.Q)} states")
print(f"  Final exploration: {agent.epsilon:.3f}")
print(f"  Average active traces: {np.mean(eligibility_trace_sizes[-100:]):.1f}")

## 3. Training Analysis

In [None]:
def _moving_average(values, window):
    """Return the moving average of ``values`` over ``window`` consecutive points.

    Uses ``np.convolve(..., mode='valid')``, so the result has
    ``len(values) - window + 1`` points and point ``i`` lines up with
    episode index ``i + window - 1``.
    """
    # True division: np.ones(window) // window floors every weight to 0.0,
    # which silently turns the smoothed curve into all zeros.
    return np.convolve(values, np.ones(window) / window, mode='valid')


# Create comprehensive training analysis
fig, axes = plt.subplots(3, 2, figsize=(15, 18))
window = 50  # smoothing window (episodes) for all moving-average curves

# 1. Learning curve
ax1 = axes[0, 0]
ax1.plot(episode_rewards, alpha=0.3, color='gold', linewidth=0.5, label='Raw')
smoothed = _moving_average(episode_rewards, window)
ax1.plot(range(window-1, len(episode_rewards)), smoothed, 'orange', linewidth=2, label='Smoothed')
ax1.set_xlabel('Episode')
ax1.set_ylabel('Reward')
ax1.set_title('SARSA(λ): Learning Curve')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Epsilon decay
ax2 = axes[0, 1]
ax2.plot(epsilon_values, 'r-', linewidth=2)
ax2.set_xlabel('Episode')
ax2.set_ylabel('Epsilon')
ax2.set_title('Exploration Rate Decay')
ax2.grid(True, alpha=0.3)

# 3. Q-table growth
ax3 = axes[1, 0]
ax3.plot(q_table_sizes, 'g-', linewidth=2)
ax3.set_xlabel('Episode')
ax3.set_ylabel('Number of States')
ax3.set_title('Q-Table Growth')
ax3.grid(True, alpha=0.3)

# 4. Eligibility traces (unique to SARSA-λ)
ax4 = axes[1, 1]
traces_smoothed = _moving_average(eligibility_trace_sizes, window)
ax4.plot(range(window-1, len(eligibility_trace_sizes)), traces_smoothed, 'purple', linewidth=2)
ax4.set_xlabel('Episode')
ax4.set_ylabel('Active Eligibility Traces')
ax4.set_title('Eligibility Traces Activity')
ax4.grid(True, alpha=0.3)

# 5. Episode lengths
ax5 = axes[2, 0]
lengths_smoothed = _moving_average(episode_lengths, window)
ax5.plot(range(window-1, len(episode_lengths)), lengths_smoothed, 'm-', linewidth=2)
ax5.set_xlabel('Episode')
ax5.set_ylabel('Episode Length')
ax5.set_title('Episode Length Over Time')
ax5.grid(True, alpha=0.3)

# 6. Final reward distribution
ax6 = axes[2, 1]
final_rewards = episode_rewards[-200:]  # Last 200 episodes
ax6.hist(final_rewards, bins=20, alpha=0.7, color='gold', edgecolor='black')
ax6.axvline(np.mean(final_rewards), color='red', linestyle='--', 
           label=f'Mean: {np.mean(final_rewards):.1f}')
ax6.set_xlabel('Reward')
ax6.set_ylabel('Frequency')
ax6.set_title('Final Performance Distribution')
ax6.legend()

plt.tight_layout()
plt.show()

# Print training statistics
print("📊 Training Statistics:")
print(f"  Episodes: {training_episodes}")
print(f"  Total time: {total_training_time:.1f}s ({total_training_time/60:.1f} min)")
print(f"  Avg time per episode: {np.mean(training_times):.3f}s")
print(f"  Final 100-episode average: {np.mean(episode_rewards[-100:]):.1f}")
print(f"  Best single episode: {np.max(episode_rewards):.1f}")
print(f"  Q-table final size: {len(agent.Q)} states")
# NOTE(review): the denominator is the number of grid cells; if Q keys encode more
# than the agent position (e.g. package state), this ratio can exceed 100%.
print(f"  State space coverage: {len(agent.Q)/(env.n_rows*env.n_cols)*100:.1f}%")

# SARSA(λ)-specific analysis
print(f"\n🎯 SARSA(λ)-Specific Characteristics:")
print(f"  Lambda (λ): {agent.lambda_}")
print(f"  Eligibility traces: Enable faster credit assignment")
print(f"  Average active traces: {np.mean(eligibility_trace_sizes[-100:]):.1f}")
print(f"  Trace benefit: Updates multiple state-action pairs per step")
print(f"  Memory: Higher than SARSA due to eligibility trace storage")
print(f"  Memory: Higher than SARSA due to eligibility trace storage")

## 4. Validation Phase

In [None]:
print("🔍 SARSA(λ) Agent Validation")
print("=" * 35)

# Create a fresh, untrained agent so validation measures learning from scratch
# (re-using `agent` would score a policy that has already been trained).
# Hyperparameters mirror the training agent, including the ε schedule, so the
# configuration being validated is the one that was trained and tested.
validation_agent = create_agent(
    'sarsa_lambda',
    env.observation_space,
    env.action_space,
    alpha=0.1, gamma=0.95, epsilon=0.3,
    epsilon_decay=0.999, epsilon_min=0.05,
    lambda_=0.9
)

# Run comprehensive validation
validator = RLAgentValidator(validation_agent, env, verbose=True)
validation_results = validator.full_validation(
    training_episodes=300,
    test_episodes=50,
    n_seeds=3
)

# Display validation summary
print("\n" + "="*60)
print("📊 VALIDATION SUMMARY")
print("="*60)

summary = validation_results['summary']
print(f"Overall Assessment: {summary['overall_assessment']}")
print()

# Map score strings to icons; any score other than PASS/WARNING is shown as a failure
status_icons = {"PASS": "✅", "WARNING": "⚠️"}
print("Component Scores:")
for component, score in summary['scores'].items():
    status_icon = status_icons.get(score, "❌")
    print(f"  {status_icon} {component.capitalize()}: {score}")

if summary['warnings']:
    print("\n⚠️  Warnings:")
    for warning in summary['warnings']:
        print(f"  - {warning}")

# Key metrics
learning_result = validation_results['learning']
exploration = validation_results['exploration']
generalization = validation_results['generalization']

print("\nKey Validation Metrics:")
print(f"  Learning improvement: {learning_result['improvement']:.1f} reward vs random")
print(f"  Statistical significance: p={learning_result['statistical_test']['p_value']:.4f}")
print(f"  State coverage: {exploration['state_coverage']['coverage_percent']:.1f}%")
print(f"  Generalization consistency: {generalization['consistency_score']:.3f}")

# SARSA(λ)-specific validation notes (λ reported for the agent that was validated)
print("\n📝 SARSA(λ) Validation Notes:")
print(f"  • Eligibility traces should improve learning speed")
print(f"  • λ={validation_agent.lambda_} provides balance between TD(0) and Monte Carlo")
print(f"  • Expected to show faster convergence than standard SARSA")
if exploration['state_coverage']['coverage_percent'] > 50:
    print(f"  • Good exploration coverage suggests effective trace propagation")
else:
    print(f"  • Lower exploration typical for on-policy methods")

## 5. Validation Visualization

In [None]:
# Plot the diagnostics collected by the validation run above; figure contents
# are determined by RLAgentValidator.visualize_results in the apr package.
validator.visualize_results()
print("✅ Validation visualization complete!")

## 6. Testing Phase

In [None]:
print("🎯 SARSA(λ) Agent Testing")
print("=" * 30)

# Evaluate the trained `agent` (not the fresh validation agent)
evaluator = AgentEvaluator(env, verbose=True)

# Comprehensive evaluation over several environment seeds
test_results = evaluator.evaluate_agent(
    agent,
    num_episodes=100,
    seeds=[42, 123, 456, 789, 999],  # Multiple scenarios
    render=False
)

# Display test results
print("\n📊 TEST RESULTS SUMMARY")
print("=" * 40)

agg_stats = test_results['aggregated_results']['overall_statistics']
print(f"Mean Reward: {agg_stats['mean_reward']:.1f} ± {agg_stats['std_reward']:.1f}")
print(f"Success Rate: {agg_stats['mean_success_rate']:.1%}")
print(f"Mean Episode Length: {agg_stats['mean_episode_length']:.1f}")
print(f"State Coverage: {agg_stats['mean_state_coverage']:.1%}")

# Performance across different seeds; each per-seed result carries its own seed,
# so the dictionary keys are not needed.
seed_results = test_results['per_seed_results']
print("\nPerformance Across Seeds:")
for result in seed_results.values():
    seed = result['seed']
    mean_reward = result['statistics']['mean_reward']
    success_rate = result['statistics']['success_rate']
    print(f"  Seed {seed}: {mean_reward:.1f} reward, {success_rate:.1%} success")

## 7. Testing Visualization

In [None]:
# Plot the evaluation results produced by the testing cell; figure contents are
# determined by AgentEvaluator.visualize_evaluation in the apr package.
evaluator.visualize_evaluation(test_results)
print("✅ Testing visualization complete!")

## 8. Agent Demonstration

In [None]:
print("🎬 SARSA(λ) Agent Demonstration")
print("=" * 35)

# Reset, then roll out the greedy policy for at most `max_demo_steps` steps
env.reset()
done = False
steps = 0
total_reward = 0
max_demo_steps = 25

print(f"Initial state: Agent at {env.agent_pos}")
print(f"Packages to collect: {env.packages_remaining}")
print(f"Dropoff location: {env.dropoff}")
print()

# Action names for readability (list index = action id)
action_names = ['Up', 'Down', 'Left', 'Right']

while not done and steps < max_demo_steps:
    # NOTE(review): the agent is queried with env.agent_pos, while env.step() returns
    # next_state; confirm both match the state representation run_episode() trains on.
    state = env.agent_pos
    action = agent.act(state, training=False)  # No exploration
    
    next_state, reward, done, info = env.step(action)
    total_reward += reward
    
    # One decimal place so fractional rewards (e.g. small step penalties) stay visible
    print(f"Step {steps + 1:2d}: {state} → {action_names[action]:5s} → {next_state} "
          f"(reward: {reward:+6.1f}, total: {total_reward:+6.1f})")
    
    steps += 1

print(f"\nDemo completed after {steps} steps")
print(f"Final reward: {total_reward:.1f}")
print(f"Episode completed: {done}")
print(f"Packages remaining: {len(env.packages_remaining)}")
print(f"Carrying packages: {env.carrying_packages}")

# Show final state
env.render(mode='human')
plt.title(f'SARSA(λ) Agent After {steps} Steps (Reward: {total_reward:.1f})')
plt.show()

print("\n💡 SARSA(λ) Behavior Notes:")
print(f"  • Eligibility traces propagate rewards backward through trajectory")
print(f"  • λ={agent.lambda_} balances immediate vs delayed credit assignment")
print(f"  • Should learn faster than SARSA due to multi-step updates")
print(f"  • More memory intensive due to eligibility trace storage")

## 9. Eligibility Traces Analysis

In [None]:
print("🔥 Eligibility Traces Analysis")
print("=" * 30)

# Analyze eligibility traces if available
if hasattr(agent, 'E') and len(agent.E) > 0:
    # Flatten array-valued trace entries (non-array entries are skipped) and count,
    # per state, the entries above the 0.01 "active" threshold used during training
    all_traces = []
    active_traces_per_state = {}
    
    for state, state_traces in agent.E.items():
        if isinstance(state_traces, np.ndarray):
            all_traces.extend(state_traces)
            active_traces_per_state[state] = int(np.sum(state_traces > 0.01))
    
    print(f"Eligibility Traces Statistics:")
    print(f"  States with traces: {len(agent.E)}")
    print(f"  Total trace values: {len(all_traces)}")
    print(f"  Active traces (>0.01): {sum(1 for t in all_traces if t > 0.01)}")
    if all_traces:
        print(f"  Max trace value: {np.max(all_traces):.3f}")
        print(f"  Mean trace value: {np.mean(all_traces):.3f}")
    else:
        # np.max raises ValueError on an empty sequence
        print("  Max/mean trace value: n/a (no array-valued traces found)")
    print(f"  Trace decay (λ): {agent.lambda_}")
    
    # Visualize eligibility traces
    plt.figure(figsize=(15, 5))
    
    # 1. Trace value distribution (note: 0.001 cut-off here vs 0.01 in the counts above)
    plt.subplot(1, 3, 1)
    active_traces = [t for t in all_traces if t > 0.001]
    if active_traces:
        plt.hist(active_traces, bins=20, alpha=0.7, color='purple', edgecolor='black')
        plt.xlabel('Trace Value')
        plt.ylabel('Frequency')
        plt.title('Active Eligibility Trace Distribution')
        plt.yscale('log')
    
    # 2. Traces per state
    plt.subplot(1, 3, 2)
    trace_counts = list(active_traces_per_state.values())
    if trace_counts:
        # plt.hist needs at least one bin; max(trace_counts) is 0 when every
        # trace has decayed below the threshold (e.g. right after an episode ends)
        n_bins = max(1, min(10, max(trace_counts)))
        plt.hist(trace_counts, bins=n_bins, alpha=0.7, 
                 color='orange', edgecolor='black')
        plt.xlabel('Active Traces per State')
        plt.ylabel('Number of States')
        plt.title('Trace Activity per State')
    
    # 3. Trace activity over training (raw + 50-episode moving average)
    plt.subplot(1, 3, 3)
    plt.plot(eligibility_trace_sizes, alpha=0.5, color='purple', linewidth=0.5)
    trace_smoothed = np.convolve(eligibility_trace_sizes, np.ones(50)/50, mode='valid')
    plt.plot(range(49, len(eligibility_trace_sizes)), trace_smoothed, 'purple', linewidth=2)
    plt.xlabel('Episode')
    plt.ylabel('Active Traces')
    plt.title('Trace Activity Over Training')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n📈 Trace Analysis:")
    print(f"  Peak trace activity: {np.max(eligibility_trace_sizes)} active traces")
    print(f"  Average activity (final 100 episodes): {np.mean(eligibility_trace_sizes[-100:]):.1f}")
    if len(agent.Q) > 0:
        # eligibility_trace_sizes counts active trace *entries*, not states, so this
        # is a ratio of entries to learned states rather than a share of states.
        trace_ratio = np.mean(eligibility_trace_sizes) / len(agent.Q) * 100
        print(f"  Trace efficiency: {trace_ratio:.1f}% (mean active trace entries ÷ Q-table states)")

else:
    print("⚠️  No eligibility traces available for analysis")
    print("    This might indicate the agent wasn't properly initialized as SARSA(λ)")

## 10. Policy Analysis

In [None]:
print("🧠 SARSA(λ) Policy Analysis")
print("=" * 30)

# Names and plot colors indexed by action id, so each action keeps one color
action_names = ['Up', 'Down', 'Left', 'Right']
action_colors = ['red', 'blue', 'green', 'orange']

# Analyze learned Q-values and policy
if hasattr(agent, 'Q') and len(agent.Q) > 0:
    # Greedy policy and state values from array-valued Q entries (others skipped)
    all_q_values = []
    policy = {}
    state_values = {}
    
    for state, q_vals in agent.Q.items():
        if isinstance(q_vals, np.ndarray):
            all_q_values.extend(q_vals)
            policy[state] = int(np.argmax(q_vals))
            state_values[state] = np.max(q_vals)
    
    if not all_q_values:
        # np.min / np.max raise ValueError on an empty sequence
        print("⚠️  Q-table entries are not NumPy arrays; cannot analyze policy")
    else:
        print(f"Q-table Statistics:")
        print(f"  States learned: {len(agent.Q)}")
        print(f"  Q-value range: [{np.min(all_q_values):.1f}, {np.max(all_q_values):.1f}]")
        print(f"  Q-value mean: {np.mean(all_q_values):.1f}")
        print(f"  Q-value std: {np.std(all_q_values):.1f}")
        
        # Action distribution in the greedy policy, reported in action-id order
        action_dist = Counter(policy.values())
        actions = sorted(action_dist)
        print(f"\nPolicy Action Distribution:")
        for action in actions:
            count = action_dist[action]
            percentage = count / len(policy) * 100
            print(f"  {action_names[action]:5s}: {count:3d} states ({percentage:4.1f}%)")
        
        # Compare eligibility-enhanced learning
        print(f"\n🎯 SARSA(λ) Policy Characteristics:")
        print(f"  • Eligibility traces enable faster value propagation")
        print(f"  • Should converge faster than standard SARSA")
        print(f"  • λ={agent.lambda_} balances TD(0) and Monte Carlo methods")
        print(f"  • More stable than Q-Learning in stochastic environments")
        
        # Visualize Q-value distribution and policy
        plt.figure(figsize=(12, 4))
        
        plt.subplot(1, 2, 1)
        plt.hist(all_q_values, bins=30, alpha=0.7, color='gold', edgecolor='black')
        plt.axvline(np.mean(all_q_values), color='red', linestyle='--', 
                    label=f'Mean: {np.mean(all_q_values):.1f}')
        plt.xlabel('Q-Value')
        plt.ylabel('Frequency')
        plt.title('SARSA(λ) Q-Value Distribution')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        counts = [action_dist[a] for a in actions]
        action_labels = [action_names[a] for a in actions]
        colors = [action_colors[a] for a in actions]
        plt.bar(action_labels, counts, color=colors, alpha=0.7)
        plt.xlabel('Action')
        plt.ylabel('Frequency in Policy')
        plt.title('SARSA(λ) Learned Policy Distribution')
        
        plt.tight_layout()
        plt.show()

else:
    print("⚠️  No Q-table available for analysis")

## 11. Save Results

In [None]:
# Create save paths (parents=True so this also works if the outputs tree is missing)
outputs_dir = get_outputs_dir()
models_dir = outputs_dir / 'models'
models_dir.mkdir(parents=True, exist_ok=True)

# Save trained agent (.pkl — presumably pickled; only load such files from trusted sources)
agent_path = models_dir / 'sarsa_lambda_complete.pkl'
agent.save(agent_path)
print(f"✅ Saved trained agent to: {agent_path}")

# Save training metrics (one row per training episode)
training_df = pd.DataFrame({
    'episode': range(1, training_episodes + 1),
    'reward': episode_rewards,
    'epsilon': epsilon_values,
    'q_table_size': q_table_sizes,
    'eligibility_traces': eligibility_trace_sizes,
    'episode_length': episode_lengths,
    'training_time': training_times
})

metrics_path = models_dir / 'sarsa_lambda_training_metrics.csv'
training_df.to_csv(metrics_path, index=False)
print(f"📊 Saved training metrics to: {metrics_path}")

# Save validation results: the summary fields reported by the validation cell
validation_dir = outputs_dir / 'validation_results'
validation_dir.mkdir(parents=True, exist_ok=True)

validation_summary = pd.DataFrame([{
    'algorithm': 'SARSA(λ)',
    'overall_assessment': summary['overall_assessment'],
    **{f'score_{component}': score for component, score in summary['scores'].items()},
    'n_warnings': len(summary['warnings']),
    'learning_improvement': learning_result['improvement'],
    'learning_p_value': learning_result['statistical_test']['p_value'],
    'state_coverage_percent': exploration['state_coverage']['coverage_percent'],
    'generalization_consistency': generalization['consistency_score'],
}])

validation_path = validation_dir / 'sarsa_lambda_validation_summary.csv'
validation_summary.to_csv(validation_path, index=False)
print(f"🔍 Saved validation summary to: {validation_path}")

# Save test results
test_summary = pd.DataFrame([{
    'algorithm': 'SARSA(λ)',
    'lambda': agent.lambda_,
    'mean_reward': agg_stats['mean_reward'],
    'std_reward': agg_stats['std_reward'],
    'success_rate': agg_stats['mean_success_rate'],
    'episode_length': agg_stats['mean_episode_length'],
    'state_coverage': agg_stats['mean_state_coverage'],
    'training_episodes': training_episodes,
    'training_time': total_training_time,
    'final_q_table_size': len(agent.Q),
    'avg_active_traces': np.mean(eligibility_trace_sizes[-100:])
}])

test_path = validation_dir / 'sarsa_lambda_test_summary.csv'
test_summary.to_csv(test_path, index=False)
print(f"🎯 Saved test summary to: {test_path}")

# Final summary
print("\n" + "="*50)
print("🎉 SARSA(λ) COMPLETE WORKFLOW FINISHED!")
print("="*50)
print(f"Training Episodes: {training_episodes}")
print(f"Training Time: {total_training_time:.1f}s")
print(f"Final Performance: {np.mean(episode_rewards[-50:]):.1f} reward")
print(f"Test Performance: {agg_stats['mean_reward']:.1f} ± {agg_stats['std_reward']:.1f}")
print(f"Success Rate: {agg_stats['mean_success_rate']:.1%}")
print(f"Validation Status: {summary['overall_assessment']}")
print(f"\n🔥 SARSA(λ) Advantages:")
print(f"  • Eligibility traces (λ={agent.lambda_}) for faster learning")
print(f"  • Multi-step credit assignment")
print(f"  • On-policy stability with enhanced convergence")
print(f"  • Average {np.mean(eligibility_trace_sizes[-100:]):.1f} active traces per episode")
print(f"  • Better than SARSA for delayed reward scenarios")
print("\n✅ All results saved to outputs directory")