# Snake AI Model Testing & Evaluation

This notebook provides tools for:
- Loading and testing trained models
- Evaluating model performance on different scenarios
- Visualizing model decision-making
- Comparing different model versions

In [51]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from pathlib import Path
import sys
import json
from typing import List, Dict, Tuple

# Add the src directory to Python path
sys.path.append('..\src')

# Import our Snake AI modules
try:
    from snake_ai import DQN, SnakeAI
    from snake_game import SnakeGame
    from utils import get_state
    print("✅ Snake AI modules imported successfully!")
except ImportError as e:
    print(f"❌ Error importing modules: {e}")
    print("Make sure you're running this from the correct directory and modules exist.")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

❌ Error importing modules: cannot import name 'get_state' from 'utils' (c:\Users\jross\Source\ai-snake\notebooks\../src\utils.py)
Make sure you're running this from the correct directory and modules exist.


## 1. Model Loading and Setup

In [42]:
def find_model_files():
    """Find available model files in the project directory."""
    # Look in the user's roaming data directory
    data_path = Path.home() / "AppData" / "Roaming" / "SnakeAI"
    
    if not data_path.exists():
        print(f"Project data directory not found: {data_path}")
        return []
    
    # Find all .pth model files
    model_files = list(data_path.glob("**/models/*.pth"))
    
    if not model_files:
        print("No model files found.")
        print(f"Looking in: {data_path}")
        return []
    
    print(f"Found {len(model_files)} model files:")
    for i, model_file in enumerate(model_files):
        print(f"  {i+1}. {model_file.parent.parent.name}/{model_file.name}")
    
    return model_files

def load_model(model_path: Path, input_size: int = 11, output_size: int = 3):
    """Load a trained model from file."""
    try:
        # Create model architecture
        model = DQN(input_size, output_size)
        
        # Load trained weights
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model.eval()
        
        print(f"✅ Model loaded successfully from: {model_path.name}")
        return model
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None

# Find and display available models
model_files = find_model_files()

# Load the most recent model (or specify index)
if model_files:
    # Load the most recent model
    latest_model_path = max(model_files, key=lambda x: x.stat().st_mtime)
    model = load_model(latest_model_path)
    
    if model is not None:
        print(f"\nModel architecture:")
        print(model)
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nModel parameters:")
        print(f"  Total: {total_params:,}")
        print(f"  Trainable: {trainable_params:,}")
else:
    model = None
    print("No models available for testing.")

Found 12 model files:
  1. SnakeAI/snake_model_best.pth
  2. SnakeAI/snake_model_best_best.pth
  3. SnakeAI/snake_model_best_checkpoint.pth
  4. Snake_AI_Project/best_model.pth
  5. Snake_AI_Project/checkpoint.pth
  6. Snake_AI_Project_0/best_model.pth
  7. Snake_AI_Project_100k/best_model.pth
  8. Snake_AI_Project_100k/checkpoint.pth
  9. Snake_AI_Project_2/best_model.pth
  10. Snake_AI_Project_2/checkpoint.pth
  11. Snake_AI_Project_3/best_model.pth
  12. Snake_AI_Project_3/checkpoint.pth
❌ Error loading model: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively,

## 2. Model Testing Functions

In [43]:
def test_model_performance(model, num_games: int = 100, board_size: int = 10):
    """Test model performance over multiple games."""
    if model is None:
        print("No model loaded for testing.")
        return None
    
    scores = []
    steps_list = []
    game_results = []
    
    print(f"Testing model performance over {num_games} games...")
    
    for game_num in range(num_games):
        # Create game instance
        game = SnakeGame(board_size, board_size)
        game.reset()
        
        steps = 0
        max_steps = board_size * board_size * 2  # Prevent infinite loops
        
        while not game.game_over and steps < max_steps:
            # Get current state
            state = get_state(game)
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            
            # Get model prediction
            with torch.no_grad():
                q_values = model(state_tensor)
                action = q_values.argmax().item()
            
            # Take action
            game.move(action)
            steps += 1
        
        scores.append(game.score)
        steps_list.append(steps)
        
        # Record detailed game result
        game_results.append({
            'game': game_num + 1,
            'score': game.score,
            'steps': steps,
            'survived': game.score > 0,
            'efficiency': game.score / steps if steps > 0 else 0
        })
        
        if (game_num + 1) % 20 == 0:
            print(f"  Completed {game_num + 1}/{num_games} games...")
    
    # Create results DataFrame
    results_df = pd.DataFrame(game_results)
    
    # Calculate statistics
    avg_score = np.mean(scores)
    max_score = np.max(scores)
    success_rate = (np.array(scores) > 0).mean() * 100
    avg_steps = np.mean(steps_list)
    
    print(f"\n=== Test Results ===")
    print(f"Average Score: {avg_score:.2f}")
    print(f"Best Score: {max_score}")
    print(f"Success Rate: {success_rate:.1f}%")
    print(f"Average Steps: {avg_steps:.1f}")
    print(f"Score Std Dev: {np.std(scores):.2f}")
    
    return results_df

# Run performance test
if model is not None:
    test_results = test_model_performance(model, num_games=50)
    
    if test_results is not None:
        print("\nFirst few test results:")
        display(test_results.head(10))
else:
    test_results = None

## 3. Performance Visualization

In [44]:
# Visualize test results
if test_results is not None:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Score distribution
    axes[0,0].hist(test_results['score'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0,0].axvline(test_results['score'].mean(), color='red', linestyle='--', 
                     label=f'Mean: {test_results["score"].mean():.1f}')
    axes[0,0].set_title('Score Distribution')
    axes[0,0].set_xlabel('Score')
    axes[0,0].set_ylabel('Frequency')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)
    
    # Steps vs Score scatter
    axes[0,1].scatter(test_results['steps'], test_results['score'], alpha=0.6, color='green')
    axes[0,1].set_title('Steps vs Score')
    axes[0,1].set_xlabel('Steps Taken')
    axes[0,1].set_ylabel('Score Achieved')
    axes[0,1].grid(True, alpha=0.3)
    
    # Add correlation
    correlation = test_results['steps'].corr(test_results['score'])
    axes[0,1].text(0.05, 0.95, f'Correlation: {correlation:.2f}', 
                  transform=axes[0,1].transAxes, verticalalignment='top',
                  bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # Efficiency distribution
    axes[1,0].hist(test_results['efficiency'], bins=20, alpha=0.7, color='orange', edgecolor='black')
    axes[1,0].axvline(test_results['efficiency'].mean(), color='red', linestyle='--',
                     label=f'Mean: {test_results["efficiency"].mean():.3f}')
    axes[1,0].set_title('Efficiency Distribution (Score/Steps)')
    axes[1,0].set_xlabel('Efficiency')
    axes[1,0].set_ylabel('Frequency')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)
    
    # Performance over games (to check consistency)
    axes[1,1].plot(test_results['game'], test_results['score'], alpha=0.6, color='purple')
    axes[1,1].plot(test_results['game'], test_results['score'].rolling(10, min_periods=1).mean(), 
                  color='darkred', linewidth=2, label='Rolling Mean (10 games)')
    axes[1,1].set_title('Performance Consistency')
    axes[1,1].set_xlabel('Game Number')
    axes[1,1].set_ylabel('Score')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Summary statistics table
    summary_stats = test_results.describe()
    print("\nDetailed Statistics:")
    display(summary_stats)
else:
    print("No test results to visualize.")

No test results to visualize.


## 4. Model Decision Analysis

In [45]:
def analyze_model_decisions(model, num_states: int = 100):
    """Analyze what types of decisions the model makes in different situations."""
    if model is None:
        print("No model loaded for analysis.")
        return None
    
    # Action names for interpretation
    action_names = ['Straight', 'Right', 'Left']
    
    # Collect decision data
    decision_data = []
    
    print(f"Analyzing model decisions over {num_states} random game states...")
    
    for i in range(num_states):
        # Create a random game state
        game = SnakeGame(10, 10)
        game.reset()
        
        # Take a few random moves to create variety
        for _ in range(np.random.randint(0, 10)):
            if not game.game_over:
                game.move(np.random.randint(0, 3))
        
        if not game.game_over:
            # Get state and model decision
            state = get_state(game)
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            
            with torch.no_grad():
                q_values = model(state_tensor)
                action_probs = torch.softmax(q_values, dim=1)
                action = q_values.argmax().item()
            
            # Analyze the state
            decision_data.append({
                'state_id': i,
                'chosen_action': action,
                'action_name': action_names[action],
                'q_straight': q_values[0][0].item(),
                'q_right': q_values[0][1].item(),
                'q_left': q_values[0][2].item(),
                'confidence': action_probs[0][action].item(),
                'danger_straight': state[0],  # Danger straight ahead
                'danger_right': state[1],     # Danger to the right
                'danger_left': state[2],      # Danger to the left
                'food_direction': np.argmax(state[7:11]),  # Food direction
                'snake_length': game.score + 1
            })
    
    decisions_df = pd.DataFrame(decision_data)
    
    print(f"\n=== Decision Analysis ===")
    print(f"Action Distribution:")
    action_counts = decisions_df['action_name'].value_counts()
    for action, count in action_counts.items():
        print(f"  {action}: {count} ({count/len(decisions_df)*100:.1f}%)")
    
    # Analyze decision confidence
    print(f"\nAverage Decision Confidence: {decisions_df['confidence'].mean():.2f}")
    
    return decisions_df

# Analyze model decisions
if model is not None:
    decision_analysis = analyze_model_decisions(model, num_states=200)
    
    if decision_analysis is not None:
        print("\nFirst few decision records:")
        display(decision_analysis.head())
else:
    decision_analysis = None

## 5. Decision Visualization

In [46]:
# Visualize decision patterns
if decision_analysis is not None:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Action distribution pie chart
    action_counts = decision_analysis['action_name'].value_counts()
    axes[0,0].pie(action_counts.values, labels=action_counts.index, autopct='%1.1f%%',
                 colors=['lightblue', 'lightgreen', 'lightcoral'])
    axes[0,0].set_title('Action Distribution')
    
    # Decision confidence by action
    sns.boxplot(data=decision_analysis, x='action_name', y='confidence', ax=axes[0,1])
    axes[0,1].set_title('Decision Confidence by Action')
    axes[0,1].set_xlabel('Action')
    axes[0,1].set_ylabel('Confidence')
    
    # Q-values distribution
    q_values_cols = ['q_straight', 'q_right', 'q_left']
    q_values_data = decision_analysis[q_values_cols]
    
    axes[1,0].hist([q_values_data['q_straight'], q_values_data['q_right'], q_values_data['q_left']], 
                  bins=20, alpha=0.7, label=['Straight', 'Right', 'Left'], color=['blue', 'green', 'red'])
    axes[1,0].set_title('Q-Values Distribution')
    axes[1,0].set_xlabel('Q-Value')
    axes[1,0].set_ylabel('Frequency')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)
    
    # Danger vs Action analysis
    danger_actions = []
    for _, row in decision_analysis.iterrows():
        if row['danger_straight'] == 1:
            if row['chosen_action'] == 0:  # Chose straight despite danger
                danger_actions.append('Risky')
            else:
                danger_actions.append('Safe')
        else:
            danger_actions.append('No immediate danger')
    
    danger_df = pd.DataFrame({'decision_type': danger_actions})
    danger_counts = danger_df['decision_type'].value_counts()
    
    axes[1,1].bar(danger_counts.index, danger_counts.values, 
                 color=['red', 'green', 'blue'])
    axes[1,1].set_title('Danger Response Analysis')
    axes[1,1].set_xlabel('Decision Type')
    axes[1,1].set_ylabel('Count')
    axes[1,1].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed analysis
    risky_decisions = sum(1 for d in danger_actions if d == 'Risky')
    total_danger_situations = sum(1 for d in danger_actions if d != 'No immediate danger')
    
    if total_danger_situations > 0:
        safety_rate = (total_danger_situations - risky_decisions) / total_danger_situations * 100
        print(f"\n🛡️  Safety Analysis:")
        print(f"   Safe decisions in danger: {safety_rate:.1f}%")
        print(f"   Risky decisions: {risky_decisions}/{total_danger_situations}")
    
    # Correlation analysis
    print(f"\n📊 Correlations:")
    print(f"   Confidence vs Snake Length: {decision_analysis['confidence'].corr(decision_analysis['snake_length']):.2f}")
    
else:
    print("No decision analysis data to visualize.")

No decision analysis data to visualize.


## 6. Model Comparison (if multiple models available)

In [47]:
def compare_models(model_files: List[Path], num_games: int = 50):
    """Compare performance of multiple models."""
    if len(model_files) < 2:
        print("Need at least 2 models for comparison.")
        return None
    
    comparison_results = []
    
    for model_path in model_files:
        print(f"\nTesting model: {model_path.name}")
        
        # Load model
        test_model = load_model(model_path)
        if test_model is None:
            continue
        
        # Test performance
        results = test_model_performance(test_model, num_games=num_games)
        
        if results is not None:
            # Calculate summary stats
            summary = {
                'model_name': model_path.name,
                'avg_score': results['score'].mean(),
                'max_score': results['score'].max(),
                'success_rate': (results['score'] > 0).mean() * 100,
                'avg_steps': results['steps'].mean(),
                'avg_efficiency': results['efficiency'].mean(),
                'score_std': results['score'].std()
            }
            comparison_results.append(summary)
    
    if comparison_results:
        comparison_df = pd.DataFrame(comparison_results)
        
        print("\n=== Model Comparison ===")
        display(comparison_df.round(2))
        
        # Find best model for each metric
        print("\n🏆 Best Models by Metric:")
        metrics = ['avg_score', 'max_score', 'success_rate', 'avg_efficiency']
        
        for metric in metrics:
            best_idx = comparison_df[metric].idxmax()
            best_model = comparison_df.loc[best_idx, 'model_name']
            best_value = comparison_df.loc[best_idx, metric]
            print(f"   {metric}: {best_model} ({best_value:.2f})")
        
        return comparison_df
    
    return None

# Compare models if multiple available
if len(model_files) >= 2:
    print(f"Comparing {len(model_files)} available models...")
    model_comparison = compare_models(model_files[:3], num_games=30)  # Limit to 3 models for speed
else:
    print(f"Only {len(model_files)} model(s) available. Need at least 2 for comparison.")
    model_comparison = None

Comparing 12 available models...

Testing model: snake_model_best.pth
❌ Error loading model: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy._core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy._core.multiarray.scalar])` or the `torch.serialization.safe_globals([numpy._core.multiarray.scalar])` context manager to allowlist this global if you

## 7. Save Test Results

In [48]:
# Save test results for later analysis
output_dir = Path('../model_test_results')
output_dir.mkdir(exist_ok=True)

timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')

if test_results is not None:
    # Save individual test results
    test_results.to_csv(output_dir / f'model_test_{timestamp}.csv', index=False)
    print(f"✅ Test results saved to: model_test_{timestamp}.csv")

if decision_analysis is not None:
    # Save decision analysis
    decision_analysis.to_csv(output_dir / f'decision_analysis_{timestamp}.csv', index=False)
    print(f"✅ Decision analysis saved to: decision_analysis_{timestamp}.csv")

if model_comparison is not None:
    # Save model comparison
    model_comparison.to_csv(output_dir / f'model_comparison_{timestamp}.csv', index=False)
    print(f"✅ Model comparison saved to: model_comparison_{timestamp}.csv")

print(f"\n📁 All results saved to: {output_dir}")
print("\n🎯 Model evaluation complete!")


📁 All results saved to: ..\model_test_results

🎯 Model evaluation complete!
