In [3]:
# !pip install stable-baselines3 gym
# !pip install "shimmy>=2.0"
# !pip install opencv-python
# !pip install seaborn


In [4]:
# Add project root to import path (temporary)
import sys
sys.path.insert(0, r"c:\Users\Acer\Desktop\snake-gym")
import snake_gym

In [5]:
import gym
import snake_gym
import random

env = gym.make("snake-v0")
# obs = env.reset()
# done = False

# while not done:
#     action = random.choice([0, 1, 2, 3])
#     obs, reward, done, info = env.step(action)

  logger.warn(


In [None]:
import random
import numpy as np
import time
import psutil
import pickle
import cv2
import pygame
from collections import deque
from snake_gym.envs.snake import SnakeGame
from snake_gym.envs.modules import GRIDSIZE

class MetricsLogger:
    """Comprehensive metrics tracking for RL training"""
    
    def __init__(self, agent_name, seed, record_gameplay=True, record_frequency=500):
        self.agent_name = agent_name
        self.seed = seed
        self.start_time = time.time()
        
        # Gameplay recording settings
        self.record_gameplay = record_gameplay
        self.record_frequency = record_frequency  # Record every N episodes
        self.recorded_episodes = []  # List of episode numbers that were recorded
        self.best_episode_score = 0
        self.best_episode_number = 0
        
        # Episode-level metrics
        self.episode_rewards = []
        self.episode_scores = []
        self.episode_lengths = []
        self.episode_times = []
        
        # Training stability metrics
        self.reward_variance_per_block = []
        self.score_variance_per_block = []
        
        # Exploration metrics
        self.epsilon_history = []
        
        # Action distribution (for behavioral analysis)
        self.action_counts = {0: 0, 1: 0, 2: 0, 3: 0}
        self.action_history_per_block = []
        
        # Computational metrics
        self.cpu_usage = []
        self.memory_usage = []
        
        # Convergence metrics
        self.moving_avg_reward = []
        self.moving_avg_score = []
        
        # Sample efficiency
        self.episodes_to_threshold = {}
    
    def should_record_episode(self, episode_num):
        """Determine if this episode should be recorded"""
        if not self.record_gameplay:
            return False
        
        # Record at specific intervals
        if episode_num % self.record_frequency == 0:
            return True
        
        return False
    
    def is_best_episode(self, score):
        """Check if this is the best episode so far"""
        return score > self.best_episode_score
    
    def update_best_episode(self, episode_num, score):
        """Update best episode tracker"""
        if score > self.best_episode_score:
            self.best_episode_score = score
            self.best_episode_number = episode_num
            return True
        return False
    
    def log_episode(self, reward, score, length, epsilon, episode_time, actions_taken):
        """Log metrics for a single episode"""
        self.episode_rewards.append(reward)
        self.episode_scores.append(score)
        self.episode_lengths.append(length)
        self.episode_times.append(episode_time)
        self.epsilon_history.append(epsilon)
        
        # Update action counts
        for action in actions_taken:
            self.action_counts[action] += 1
        
        # Calculate moving averages (window of 100)
        window = 100
        if len(self.episode_rewards) >= window:
            self.moving_avg_reward.append(np.mean(self.episode_rewards[-window:]))
            self.moving_avg_score.append(np.mean(self.episode_scores[-window:]))
        else:
            self.moving_avg_reward.append(np.mean(self.episode_rewards))
            self.moving_avg_score.append(np.mean(self.episode_scores))
        
        # Log system metrics
        self.cpu_usage.append(psutil.cpu_percent(interval=0.1))
        self.memory_usage.append(psutil.virtual_memory().percent)
    
    def log_block_statistics(self, block_size=1000):
        """Calculate variance for blocks of episodes"""
        if len(self.episode_rewards) >= block_size:
            recent_rewards = self.episode_rewards[-block_size:]
            recent_scores = self.episode_scores[-block_size:]
            
            self.reward_variance_per_block.append(np.var(recent_rewards))
            self.score_variance_per_block.append(np.var(recent_scores))
            
            # Action distribution for this block
            total_actions = sum(self.action_counts.values())
            if total_actions > 0:
                action_dist = {k: v/total_actions for k, v in self.action_counts.items()}
                self.action_history_per_block.append(action_dist)
    
    def check_convergence_threshold(self, threshold_percent=0.8):
        """Check if agent reached threshold of max performance"""
        if len(self.moving_avg_score) < 100:
            return
        
        max_score = max(self.moving_avg_score)
        threshold = threshold_percent * max_score
        
        threshold_name = f"{int(threshold_percent*100)}%_max"
        
        if threshold_name not in self.episodes_to_threshold:
            for i, score in enumerate(self.moving_avg_score):
                if score >= threshold:
                    self.episodes_to_threshold[threshold_name] = i + 1
                    print(f"  ✓ Reached {threshold_name} performance at episode {i+1}")
                    break
    
    def get_training_duration(self):
        """Get total training time"""
        if hasattr(self, 'end_time') and self.end_time is not None:
            return self.end_time - self.start_time
        elif hasattr(self, 'start_time') and self.start_time is not None:
            return time.time() - self.start_time
        else:
            return 0

    
    def get_avg_episode_time(self):
        """Get average time per episode"""
        return np.mean(self.episode_times) if self.episode_times else 0
    
    def save(self, filename):
        """Save all metrics to file"""
        with open(filename, 'wb') as f:
            pickle.dump(self.__dict__, f)
        print(f"✅ Metrics saved to {filename}")


def record_episode(env, state_func, Q, epsilon, episode_num, agent_name, seed, 
                   video_path, max_steps=1000):
    """
    Record a single episode to video
    
    Args:
        env: Snake environment
        state_func: Function to get state representation
        Q: Q-table
        epsilon: Current exploration rate
        episode_num: Episode number
        agent_name: Name of the agent
        seed: Random seed
        video_path: Path to save video
        max_steps: Maximum steps per episode
    """
    # Capture first frame to determine video size
    state = env.reset()
    state = state_func(env)
    first_frame = pygame.surfarray.array3d(env.screen)
    first_frame = np.transpose(first_frame, (1, 0, 2))
    first_frame = cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR)
    
    # Overlay height for info text
    overlay_height = 50
    height, width, _ = first_frame.shape
    frame_height = height + overlay_height
    
    # Initialize VideoWriter
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = 15
    video_writer = cv2.VideoWriter(video_path, fourcc, fps, (width, frame_height))
    
    state = env.reset()
    state = state_func(env)
    done = False
    steps = 0
    total_reward = 0
    actions_taken = []
    
    while not done and steps < max_steps:
        # Q-values
        if state not in Q:
            Q[state] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}
        
        # Epsilon-greedy
        if random.random() < min(epsilon, 0.1):  # less exploration for recording
            action = random.randint(0, 3)
        else:
            action = max(Q[state], key=Q[state].get)
        
        actions_taken.append(action)
        
        # Step in environment
        _, reward, done, _ = env.step(action)
        next_state = state_func(env)
        
        # Capture frame
        frame = pygame.surfarray.array3d(env.screen)
        frame = np.transpose(frame, (1, 0, 2))
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        
        # Create overlay frame
        overlay = np.zeros((frame_height, width, 3), dtype=np.uint8)
        overlay[0:height, :, :] = frame  # original game frame
        
        # Add info text
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.35
        color = (255, 255, 255)
        thickness = 1
        y_offset = height + 15
        
        text_lines = [
            f"Agent: {agent_name}",
            f"Seed: {seed}",
            f"Episode: {episode_num}",
            f"Step: {steps}",
            f"Score: {env.snake.length - 1}"
        ]
        
        for i, text in enumerate(text_lines):
            cv2.putText(overlay, text, (5, y_offset + i*12), 
                        font, font_scale, color, thickness, cv2.LINE_AA)
        
        overlay = np.ascontiguousarray(overlay, dtype=np.uint8)
        
        # Write frame
        video_writer.write(overlay)
        
        state = next_state
        total_reward += reward
        steps += 1
    
    video_writer.release()
    score = env.snake.length - 1
    
    return score, total_reward, steps, actions_taken


def add_info_overlay(frame, episode, steps, score, agent_name, seed):
    """Add informative text overlay to frame"""
    # Make frame larger to fit text
    overlay = np.zeros((200, 150, 3), dtype=np.uint8)
    overlay[0:150, :] = frame
    
    # Add text information
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35
    color = (255, 255, 255)
    thickness = 1
    
    text_lines = [
        f"Agent: {agent_name}",
        f"Seed: {seed}",
        f"Episode: {episode}",
        f"Step: {steps}",
        f"Score: {score}"
    ]
    
    y_offset = 160
    for i, text in enumerate(text_lines):
        cv2.putText(overlay, text, (5, y_offset + i * 10), 
                   font, font_scale, color, thickness, cv2.LINE_AA)
    
    return overlay


def get_state(env):
    """State representation function"""
    head_x, head_y = env.snake.get_head_position()
    apple_x, apple_y = env.apple.position
    direction = env.snake.direction
    
    def is_danger(dx, dy):
        new_x = head_x + dx * GRIDSIZE
        new_y = head_y + dy * GRIDSIZE
        
        if new_x < 0 or new_x >= 150 or new_y < 0 or new_y >= 150:
            return 1
        if (new_x, new_y) in env.snake.positions[:-1]:
            return 1
        return 0
    
    danger_up = is_danger(0, -GRIDSIZE)
    danger_down = is_danger(0, GRIDSIZE)
    danger_left = is_danger(-GRIDSIZE, 0)
    danger_right = is_danger(GRIDSIZE, 0)
    
    apple_up = int(apple_y < head_y)
    apple_down = int(apple_y > head_y)
    apple_left = int(apple_x < head_x)
    apple_right = int(apple_x > head_x)
    
    dir_up = int(direction == (0, -1))
    dir_down = int(direction == (0, 1))
    dir_left = int(direction == (-1, 0))
    dir_right = int(direction == (1, 0))
    
    return (danger_up, danger_down, danger_left, danger_right,
            apple_up, apple_down, apple_left, apple_right,
            dir_up, dir_down, dir_left, dir_right)


def train_q_learning(seed=42, num_episodes=2000, agent_name="Q-Learning", 
                    record_gameplay=True, record_frequency=500):
    """Train Q-Learning with comprehensive metrics and gameplay recording"""
    
    # Set random seed for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    
    # Create recordings directory
    import os
    recordings_dir = f"recordings_{agent_name.lower().replace(' ', '_')}_seed{seed}"
    os.makedirs(recordings_dir, exist_ok=True)
    
    # Hyperparameters
    alpha = 0.1
    gamma = 0.95
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.995
    
    # Initialize
    env = SnakeGame()
    Q = {}
    metrics = MetricsLogger(agent_name, seed, record_gameplay, record_frequency)
    metrics.start_time = time.time()     # ✅ start timer

    
    print(f"\n{'='*70}")
    print(f"Training {agent_name} (Seed: {seed})")
    if record_gameplay:
        print(f"Recording gameplay every {record_frequency} episodes to: {recordings_dir}/")
    print(f"{'='*70}\n")
    
    for episode in range(num_episodes):
        episode_start = time.time()
        
        # Check if we should record this episode
        should_record = metrics.should_record_episode(episode + 1)
        
        if should_record:
            print(f"  🎥 Recording episode {episode + 1}...")
            video_path = f"{recordings_dir}/episode_{episode+1:05d}.mp4"
            score, total_reward, steps, actions_taken = record_episode(
                env, get_state, Q, epsilon, episode + 1, agent_name, seed, video_path
            )
            metrics.recorded_episodes.append(episode + 1)
            episode_time = time.time() - episode_start
            
        else:
            # Normal training without recording
            state = env.reset()
            state = get_state(env)
            done = False
            total_reward = 0
            steps = 0
            actions_taken = []
            
            while not done and steps < 1000:
                # Initialize Q-values
                if state not in Q:
                    Q[state] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}
                
                # Epsilon-greedy
                if random.random() < epsilon:
                    action = random.randint(0, 3)
                else:
                    action = max(Q[state], key=Q[state].get)
                
                actions_taken.append(action)
                
                # Take step
                _, reward, done, info = env.step(action)
                next_state = get_state(env)
                
                if next_state not in Q:
                    Q[next_state] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}
                
                # Q-Learning update
                best_next = max(Q[next_state].values()) if not done else 0
                Q[state][action] += alpha * (reward + gamma * best_next - Q[state][action])
                
                state = next_state
                total_reward += reward
                steps += 1
            
            score = env.snake.length - 1
            episode_time = time.time() - episode_start
            
        metrics.end_time = time.time()       # ✅ end timer

        
        # Check if this is the best episode so far
        if metrics.update_best_episode(episode + 1, score):
            print(f"  🏆 New best score: {score} at episode {episode + 1}")
        
        # Decay epsilon
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        
        # Log metrics
        metrics.log_episode(total_reward, score, steps, epsilon, episode_time, actions_taken)
        
        # Check convergence thresholds
        if (episode + 1) % 100 == 0:
            metrics.check_convergence_threshold(0.5)
            metrics.check_convergence_threshold(0.8)
            metrics.check_convergence_threshold(0.9)
        
        # Log block statistics
        if (episode + 1) % 1000 == 0:
            metrics.log_block_statistics(1000)
        
        # Print progress
        if (episode + 1) % 100 == 0:
            avg_score = np.mean(metrics.episode_scores[-100:])
            avg_reward = np.mean(metrics.episode_rewards[-100:])
            max_score = max(metrics.episode_scores[-100:])
            print(f"Episode {episode+1:5d} | "
                  f"Avg Score: {avg_score:5.2f} | "
                  f"Max Score: {max_score:3d} | "
                  f"Avg Reward: {avg_reward:7.2f} | "
                  f"ε: {epsilon:.3f} | "
                  f"States: {len(Q):5d} | "
                  f"Time: {metrics.get_avg_episode_time()*1000:.1f}ms/ep")
    
    # Record the best episode one more time at the end
    if record_gameplay and metrics.best_episode_number > 0:
        print(f"\n🎬 Recording best episode replay (Episode {metrics.best_episode_number}, Score: {metrics.best_episode_score})...")
        best_video_path = f"{recordings_dir}/BEST_episode_{metrics.best_episode_number:05d}_score_{metrics.best_episode_score}.mp4"
        record_episode(env, get_state, Q, 0.0, metrics.best_episode_number, 
                      agent_name, seed, best_video_path)
    
    # Final statistics
    total_time = metrics.get_training_duration()
    print(f"\n{'='*70}")
    print(f"Training Complete!")
    print(f"  Total time: {total_time/60:.2f} minutes")
    print(f"  Avg time per episode: {metrics.get_avg_episode_time()*1000:.2f}ms")
    print(f"  Best score: {max(metrics.episode_scores)}")
    print(f"  Final avg score (last 100): {np.mean(metrics.episode_scores[-100:]):.2f}")
    print(f"  Total states explored: {len(Q)}")
    print(f"  Convergence milestones: {metrics.episodes_to_threshold}")
    if record_gameplay:
        print(f"  Recorded episodes: {len(metrics.recorded_episodes)}")
        print(f"  Videos saved to: {recordings_dir}/")
    print(f"{'='*70}\n")
    
    # Save Q-table and metrics
    with open(f'q_table_seed{seed}.pkl', 'wb') as f:
        pickle.dump(Q, f)
    
    metrics.save(f'metrics_{agent_name.lower().replace(" ", "_")}_seed{seed}.pkl')
    
    # 🎬 Record final greedy gameplay after full training
    print("\n🎬 Recording final gameplay with trained Q-Learning agent...")
    env = SnakeGame()
    final_video_path = f"{recordings_dir}/FINAL_gameplay_seed{seed}.mp4"
    record_episode(
        env, get_state, Q,
        epsilon=0.0,                     # no exploration
        episode_num=num_episodes,        # tag as final episode
        agent_name="Q-Learning-Final",
        seed=seed,
        video_path=final_video_path
    )
    print(f"✅ Final gameplay recording complete! Saved to {final_video_path}")
    
    return Q, metrics


# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    # Train with multiple seeds and automatic gameplay recording
    seeds = [42, 123, 456]
    
    all_metrics = []
    
    for seed in seeds:
        Q, metrics = train_q_learning(
            seed=seed, 
            num_episodes=2000, 
            agent_name="Q-Learning",
            record_gameplay=True,  # Enable automatic recording
            record_frequency=500   # Record every 500 episodes
        )
        all_metrics.append(metrics)
        
        # Small break between runs
        time.sleep(2)
    
    print("\n🎉 All training runs complete!")
    print(f"Metrics saved for {len(seeds)} different seeds")
    print("Check the recordings_* directories for gameplay videos!")


Training Q-Learning (Seed: 42)
Recording gameplay every 500 episodes to: recordings_q-learning_seed42/

  🏆 New best score: 1 at episode 1
  🏆 New best score: 2 at episode 39


In [None]:
import cv2
import numpy as np
from pathlib import Path

# =========================
# Safe text overlay
# =========================
def safe_put_text(img, text, start_x, start_y, font_scale=0.4,
                  color=(255,255,255), thickness=1, line_height=15, max_chars_per_line=30):
    """
    Safely put text on a frame (ensures contiguous array and correct type)
    """
    # Ensure img is contiguous and uint8
    img = np.ascontiguousarray(img, dtype=np.uint8)

    # Split long text into lines
    lines = [text[i:i+max_chars_per_line] for i in range(0, len(text), max_chars_per_line)]
    for i, line in enumerate(lines):
        y = int(start_y + i*line_height)
        cv2.putText(img, line, (int(start_x), y),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness, cv2.LINE_AA)
    return img

# =========================
# Training progression video
# =========================
def create_training_progression_video(recordings_dir, output_path='training_progression.mp4'):
    """
    Create a compilation showing training progression
    """
    video_files = sorted(Path(recordings_dir).glob("episode_*.mp4"))
    
    if not video_files:
        print(f"❌ No videos found in {recordings_dir}")
        return
    
    print(f"Found {len(video_files)} recorded episodes")
    print(f"Creating progression video: {output_path}")
    
    # Read first video to get dimensions and FPS
    first_video = cv2.VideoCapture(str(video_files[0]))
    width = int(first_video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(first_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(first_video.get(cv2.CAP_PROP_FPS))
    first_video.release()
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    
    for video_file in video_files:
        print(f"  Processing {video_file.name}...")
        cap = cv2.VideoCapture(str(video_file))
        
        # Safe title frame
        title_frame = np.zeros((height, width, 3), dtype=np.uint8)
        title_frame = np.ascontiguousarray(title_frame, dtype=np.uint8)
        
        episode_number = video_file.stem.split('_')[1]
        text = f"Episode {episode_number}"
        
        # Center text
        text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0]
        text_x = (width - text_size[0]) // 2
        text_y = (height + text_size[1]) // 2
        
        # Draw text safely
        title_frame = safe_put_text(title_frame, text, text_x, text_y, font_scale=0.8, thickness=2, max_chars_per_line=50)
        
        # Write title frame for 1 second
        for _ in range(fps):
            output_video.write(title_frame)
        
        # Write all frames from this episode
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = np.ascontiguousarray(frame, dtype=np.uint8)
            output_video.write(frame)
        
        cap.release()
        
        # Black separator for 0.5 seconds
        black_frame = np.zeros((height, width, 3), dtype=np.uint8)
        black_frame = np.ascontiguousarray(black_frame, dtype=np.uint8)
        for _ in range(fps // 2):
            output_video.write(black_frame)
    
    output_video.release()
    print(f"✅ Compilation video saved: {output_path}")

# =========================
# Best episodes compilation
# =========================
def create_best_episodes_compilation(seed_list=[42, 123, 456], agent_name="Q-Learning"):
    output_path = f'best_episodes_{agent_name.lower().replace(" ", "_")}.mp4'
    print(f"Creating best episodes compilation for {agent_name}...")

    best_videos = []
    for seed in seed_list:
        recordings_dir = f"recordings_{agent_name.lower().replace(' ', '_')}_seed{seed}"
        best_video = list(Path(recordings_dir).glob("BEST_*.mp4"))
        if best_video:
            best_videos.extend(best_video)

    if not best_videos:
        print("❌ No best episode videos found")
        return

    # Get dimensions
    cap = cv2.VideoCapture(str(best_videos[0]))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    cap.release()

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    for video_file in best_videos:
        print(f"  Adding {video_file.name}...")

        # Safe title frame
        title_frame = np.zeros((height, width, 3), dtype=np.uint8)
        title_frame = np.ascontiguousarray(title_frame, dtype=np.uint8)
        title_text = video_file.stem.replace('_', ' ').title()
        title_frame = safe_put_text(title_frame, title_text, width//8, height//2, font_scale=0.6, thickness=2)

        for _ in range(fps):
            output_video.write(title_frame)

        # Video frames
        cap = cv2.VideoCapture(str(video_file))
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = np.ascontiguousarray(frame, dtype=np.uint8)
            output_video.write(frame)
        cap.release()

        # Black separator for 1 second
        black_frame = np.zeros((height, width, 3), dtype=np.uint8)
        black_frame = np.ascontiguousarray(black_frame, dtype=np.uint8)
        for _ in range(fps):
            output_video.write(black_frame)

    output_video.release()
    print(f"✅ Best episodes compilation saved: {output_path}")


# =========================
# MAIN EXECUTION
# =========================
if __name__ == "__main__":
    print("\n" + "="*80)
    print("VIDEO COMPILATION CREATOR")
    print("="*80 + "\n")

    # Create progression videos for each seed
    for seed in [42, 123, 456]:
        recordings_dir = f"recordings_q-learning_seed{seed}"
        if Path(recordings_dir).exists():
            create_training_progression_video(recordings_dir, f'progression_seed{seed}.mp4')

    # Create best episodes compilation
    create_best_episodes_compilation([42, 123, 456], "Q-Learning")

    print("\n🎉 All compilation videos created!")



VIDEO COMPILATION CREATOR

Found 5 recorded episodes
Creating progression video: progression_seed42.mp4
  Processing episode_00200.mp4...
  Processing episode_00400.mp4...
  Processing episode_00600.mp4...
  Processing episode_00800.mp4...
  Processing episode_01000.mp4...
✅ Compilation video saved: progression_seed42.mp4
Found 5 recorded episodes
Creating progression video: progression_seed123.mp4
  Processing episode_00200.mp4...
  Processing episode_00400.mp4...
  Processing episode_00600.mp4...
  Processing episode_00800.mp4...
  Processing episode_01000.mp4...
✅ Compilation video saved: progression_seed123.mp4
Found 5 recorded episodes
Creating progression video: progression_seed456.mp4
  Processing episode_00200.mp4...
  Processing episode_00400.mp4...
  Processing episode_00600.mp4...
  Processing episode_00800.mp4...
  Processing episode_01000.mp4...
✅ Compilation video saved: progression_seed456.mp4
Creating best episodes compilation for Q-Learning...
  Adding BEST_episode_00

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# =============================
# Self-contained MetricsLogger
# =============================
class MetricsLogger:
    def __init__(self, agent_name, seed):
        self.agent_name = agent_name
        self.seed = seed
        
        # Placeholders for metrics (fill with actual data from pickle)
        self.episode_scores = []
        self.moving_avg_reward = []
        self.moving_avg_score = []
        self.episodes_to_threshold = {}
        self.epsilon_history = []
        self.reward_variance_per_block = []
        self.score_variance_per_block = []
        self.action_counts = {0:0, 1:0, 2:0, 3:0}
        self.episode_lengths = []
        self.cpu_usage = []
        self.memory_usage = []
        self.episode_rewards = []

    def get_training_duration(self):
        return getattr(self, 'training_duration', 0)

    def get_avg_episode_time(self):
        return getattr(self, 'avg_episode_time', 0)

# =============================
# Load metrics
# =============================
def load_metrics(pattern="metrics_q-learning_seed*.pkl"):
    metrics_list = []
    for file in Path('.').glob(pattern):
        with open(file, 'rb') as f:
            metrics = pickle.load(f)
            m = MetricsLogger(metrics['agent_name'], metrics['seed'])
            m.__dict__.update(metrics)
            metrics_list.append(m)
    return metrics_list


def plot_sample_efficiency(metrics_list, save_path='1_sample_efficiency.png'):
    """
    METRIC 1: Sample Efficiency Analysis
    Shows how fast each agent learns
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('1. Sample Efficiency Analysis', fontsize=16, fontweight='bold')
    
    colors = ['#2E86AB', '#A23B72', '#F18F01']
    
    # Plot 1: Reward vs Episodes (all seeds)
    ax = axes[0, 0]
    for i, m in enumerate(metrics_list):
        episodes = range(1, len(m.moving_avg_reward) + 1)
        ax.plot(episodes, m.moving_avg_reward, 
               label=f'Seed {m.seed}', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode')
    ax.set_ylabel('Moving Avg Reward (window=100)')
    ax.set_title('Learning Curves Across Seeds')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Score vs Episodes (all seeds)
    ax = axes[0, 1]
    for i, m in enumerate(metrics_list):
        episodes = range(1, len(m.moving_avg_score) + 1)
        ax.plot(episodes, m.moving_avg_score,
               label=f'Seed {m.seed}', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode')
    ax.set_ylabel('Moving Avg Score (window=100)')
    ax.set_title('Score Progression Across Seeds')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 3: Episodes to Convergence
    ax = axes[1, 0]
    thresholds = ['50%_max', '80%_max', '90%_max']
    threshold_data = {t: [] for t in thresholds}
    
    for m in metrics_list:
        for t in thresholds:
            value = m.episodes_to_threshold.get(t, np.nan)
            threshold_data[t].append(value)
    
    x_pos = np.arange(len(thresholds))
    means = [np.nanmean(threshold_data[t]) for t in thresholds]
    stds = [np.nanstd(threshold_data[t]) for t in thresholds]
    
    ax.bar(x_pos, means, yerr=stds, capsize=5, color='#2E86AB', alpha=0.7)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(thresholds)
    ax.set_ylabel('Episodes Required')
    ax.set_title('Sample Efficiency: Episodes to Reach Performance Thresholds')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Plot 4: Cumulative max score
    ax = axes[1, 1]
    for i, m in enumerate(metrics_list):
        cumulative_max = np.maximum.accumulate(m.episode_scores)
        episodes = range(1, len(cumulative_max) + 1)
        ax.plot(episodes, cumulative_max,
               label=f'Seed {m.seed}', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode')
    ax.set_ylabel('Cumulative Max Score')
    ax.set_title('Best Performance Over Time')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {save_path}")
    plt.close()


def plot_exploration_stability(metrics_list, save_path='2_exploration_stability.png'):
    """
    METRIC 2: Exploration Stability
    Shows how exploration affects learning stability
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('2. Exploration & Behavioral Stability', fontsize=16, fontweight='bold')
    
    colors = ['#2E86AB', '#A23B72', '#F18F01']
    
    # Plot 1: Epsilon decay with reward overlay
    ax = axes[0, 0]
    ax2 = ax.twinx()
    
    for i, m in enumerate(metrics_list):
        episodes = range(1, len(m.epsilon_history) + 1)
        ax.plot(episodes, m.epsilon_history,
               label=f'ε (Seed {m.seed})', alpha=0.5, linestyle='--', color=colors[i % len(colors)])
        ax2.plot(episodes, m.moving_avg_reward,
                label=f'Reward (Seed {m.seed})', alpha=0.7, color=colors[i % len(colors)])
    
    ax.set_xlabel('Episode')
    ax.set_ylabel('Epsilon (Exploration Rate)', color='black')
    ax2.set_ylabel('Moving Avg Reward', color='black')
    ax.set_title('Exploration Decay vs Learning Performance')
    ax.legend(loc='upper left')
    ax2.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Reward variance per block
    ax = axes[0, 1]
    for i, m in enumerate(metrics_list):
        blocks = range(1, len(m.reward_variance_per_block) + 1)
        block_episodes = [b * 1000 for b in blocks]
        ax.plot(block_episodes, m.reward_variance_per_block,
               label=f'Seed {m.seed}', marker='o', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode (×1000)')
    ax.set_ylabel('Reward Variance')
    ax.set_title('Learning Stability: Reward Variance per 1000 Episodes')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 3: Score variance per block
    ax = axes[1, 0]
    for i, m in enumerate(metrics_list):
        blocks = range(1, len(m.score_variance_per_block) + 1)
        block_episodes = [b * 1000 for b in blocks]
        ax.plot(block_episodes, m.score_variance_per_block,
               label=f'Seed {m.seed}', marker='o', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode (×1000)')
    ax.set_ylabel('Score Variance')
    ax.set_title('Performance Stability: Score Variance per 1000 Episodes')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 4: Action distribution (last seed only for clarity)
    ax = axes[1, 1]
    m = metrics_list[-1]  # Use last seed
    action_names = ['UP (0)', 'DOWN (1)', 'LEFT (2)', 'RIGHT (3)']
    total_actions = sum(m.action_counts.values())
    action_percentages = [m.action_counts[i] / total_actions * 100 for i in range(4)]
    
    ax.bar(action_names, action_percentages, color='#2E86AB', alpha=0.7)
    ax.set_ylabel('Percentage of Total Actions (%)')
    ax.set_title(f'Action Distribution (Seed {m.seed})')
    ax.grid(True, alpha=0.3, axis='y')
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {save_path}")
    plt.close()


def plot_computational_efficiency(metrics_list, save_path='3_computational_efficiency.png'):
    """
    METRIC 3: Runtime / Computational Efficiency
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('3. Computational Efficiency', fontsize=16, fontweight='bold')
    
    colors = ['#2E86AB', '#A23B72', '#F18F01']
    
    # Plot 1: Training time comparison
    ax = axes[0, 0]
    seeds = [m.seed for m in metrics_list]
    training_times = [m.get_training_duration() / 60 for m in metrics_list]  # in minutes
    
    ax.bar(range(len(seeds)), training_times, color='#2E86AB', alpha=0.7)
    ax.set_xticks(range(len(seeds)))
    ax.set_xticklabels([f'Seed {s}' for s in seeds])
    ax.set_ylabel('Training Time (minutes)')
    ax.set_title('Total Training Duration')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add mean line
    mean_time = np.mean(training_times)
    ax.axhline(mean_time, color='red', linestyle='--', label=f'Mean: {mean_time:.2f} min')
    ax.legend()
    
    # Plot 2: Average time per episode
    ax = axes[0, 1]
    avg_times = [m.get_avg_episode_time() * 1000 for m in metrics_list]  # in ms
    
    ax.bar(range(len(seeds)), avg_times, color='#A23B72', alpha=0.7)
    ax.set_xticks(range(len(seeds)))
    ax.set_xticklabels([f'Seed {s}' for s in seeds])
    ax.set_ylabel('Time per Episode (ms)')
    ax.set_title('Average Episode Duration')
    ax.grid(True, alpha=0.3, axis='y')
    
    mean_time_ep = np.mean(avg_times)
    ax.axhline(mean_time_ep, color='red', linestyle='--', label=f'Mean: {mean_time_ep:.2f} ms')
    ax.legend()
    
    # Plot 3: CPU usage over time
    ax = axes[1, 0]
    for i, m in enumerate(metrics_list):
        # Sample every 100 episodes for clarity
        sampled_episodes = range(0, len(m.cpu_usage), 100)
        sampled_cpu = [m.cpu_usage[i] for i in sampled_episodes]
        ax.plot(sampled_episodes, sampled_cpu,
               label=f'Seed {m.seed}', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode')
    ax.set_ylabel('CPU Usage (%)')
    ax.set_title('CPU Utilization During Training')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 4: Memory usage over time
    ax = axes[1, 1]
    for i, m in enumerate(metrics_list):
        sampled_episodes = range(0, len(m.memory_usage), 100)
        sampled_mem = [m.memory_usage[i] for i in sampled_episodes]
        ax.plot(sampled_episodes, sampled_mem,
               label=f'Seed {m.seed}', alpha=0.7, color=colors[i % len(colors)])
    ax.set_xlabel('Episode')
    ax.set_ylabel('Memory Usage (%)')
    ax.set_title('Memory Utilization During Training')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {save_path}")
    plt.close()


def plot_convergence_stability(metrics_list, save_path='4_convergence_stability.png'):
    """
    METRIC 4: Convergence and Stability Visualization
    Shows how consistently the agent converges across multiple runs
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('4. Convergence & Stability Analysis', fontsize=16, fontweight='bold')
    
    colors = ['#2E86AB', '#A23B72', '#F18F01']
    
    # Plot 1: All reward curves together
    ax = axes[0, 0]
    for i, m in enumerate(metrics_list):
        episodes = range(1, len(m.moving_avg_reward) + 1)
        ax.plot(episodes, m.moving_avg_reward,
               label=f'Seed {m.seed}', alpha=0.7, linewidth=2, color=colors[i % len(colors)])
    
    # Add mean and std dev band
    max_len = max(len(m.moving_avg_reward) for m in metrics_list)
    all_rewards = np.zeros((len(metrics_list), max_len))
    for i, m in enumerate(metrics_list):
        all_rewards[i, :len(m.moving_avg_reward)] = m.moving_avg_reward
        if len(m.moving_avg_reward) < max_len:
            all_rewards[i, len(m.moving_avg_reward):] = m.moving_avg_reward[-1]
    
    mean_reward = np.mean(all_rewards, axis=0)
    std_reward = np.std(all_rewards, axis=0)
    episodes = range(1, max_len + 1)
    
    ax.fill_between(episodes, mean_reward - std_reward, mean_reward + std_reward,
                    alpha=0.2, color='gray', label='±1 Std Dev')
    ax.plot(episodes, mean_reward, 'k--', linewidth=2, label='Mean')
    
    ax.set_xlabel('Episode')
    ax.set_ylabel('Moving Avg Reward')
    ax.set_title('Reward Convergence Across Multiple Seeds')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 2: All score curves together
    ax = axes[0, 1]
    for i, m in enumerate(metrics_list):
        episodes = range(1, len(m.moving_avg_score) + 1)
        ax.plot(episodes, m.moving_avg_score,
               label=f'Seed {m.seed}', alpha=0.7, linewidth=2, color=colors[i % len(colors)])
    
    # Add mean and std dev band
    all_scores = np.zeros((len(metrics_list), max_len))
    for i, m in enumerate(metrics_list):
        all_scores[i, :len(m.moving_avg_score)] = m.moving_avg_score
        if len(m.moving_avg_score) < max_len:
            all_scores[i, len(m.moving_avg_score):] = m.moving_avg_score[-1]
    
    mean_score = np.mean(all_scores, axis=0)
    std_score = np.std(all_scores, axis=0)
    
    ax.fill_between(episodes, mean_score - std_score, mean_score + std_score,
                    alpha=0.2, color='gray', label='±1 Std Dev')
    ax.plot(episodes, mean_score, 'k--', linewidth=2, label='Mean')
    
    ax.set_xlabel('Episode')
    ax.set_ylabel('Moving Avg Score')
    ax.set_title('Score Convergence Across Multiple Seeds')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 3: Variance over time
    ax = axes[1, 0]
    
    window = 100
    variance_over_time = []
    for i in range(0, max_len, window):
        window_scores = all_scores[:, i:min(i+window, max_len)]
        variance_over_time.append(np.mean(np.var(window_scores, axis=0)))
    
    episodes_binned = range(window, max_len + 1, window)
    ax.plot(episodes_binned, variance_over_time, marker='o', color='#2E86AB', linewidth=2)
    ax.set_xlabel('Episode')
    ax.set_ylabel(f'Mean Variance (per {window} episodes)')
    ax.set_title('Stability: Inter-Seed Variance Over Training')
    ax.grid(True, alpha=0.3)
    
    # Plot 4: Final performance comparison
    ax = axes[1, 1]
    
    final_scores = [np.mean(m.episode_scores[-100:]) for m in metrics_list]
    final_rewards = [np.mean(m.episode_rewards[-100:]) for m in metrics_list]
    seeds = [m.seed for m in metrics_list]
    
    x_pos = np.arange(len(seeds))
    width = 0.35
    
    ax.bar(x_pos - width/2, final_scores, width, label='Avg Score (last 100)', color='#2E86AB', alpha=0.7)
    ax.bar(x_pos + width/2, final_rewards, width, label='Avg Reward (last 100)', color='#F18F01', alpha=0.7)
    
    ax.set_xticks(x_pos)
    ax.set_xticklabels([f'Seed {s}' for s in seeds])
    ax.set_ylabel('Performance')
    ax.set_title('Final Performance Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {save_path}")
    plt.close()


def plot_policy_behavior(metrics_list, save_path='5_policy_behavior.png'):
    """
    METRIC 5: Policy Behavior Check
    Shows how the policy evolves over time
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle
    fig.suptitle('5. Policy Behavior & Evolution', fontsize=16, fontweight='bold')
    
    colors = ['#2E86AB', '#A23B72', '#F18F01']
    
    # Plot 1: Average survival time over training
    ax = axes[0, 0]
    for i, m in enumerate(metrics_list):
        # Calculate moving average of episode lengths (survival time)
        window = 100
        if len(m.episode_lengths) >= window:
            moving_avg_length = [np.mean(m.episode_lengths[max(0, j-window):j+1]) 
                                for j in range(len(m.episode_lengths))]
        else:
            moving_avg_length = m.episode_lengths
        
        episodes = range(1, len(moving_avg_length) + 1)
        ax.plot(episodes, moving_avg_length,
               label=f'Seed {m.seed}', alpha=0.7, color=colors[i % len(colors)])
    
    ax.set_xlabel('Episode')
    ax.set_ylabel('Average Steps Survived (window=100)')
    ax.set_title('Policy Evolution: Survival Time Over Training')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot 2: Early vs Late behavior comparison
    ax = axes[0, 1]
    
    # Compare first 500 episodes vs last 500 episodes
    early_window = slice(0, 500)
    late_window = slice(-500, None)
    
    comparison_data = {
        'Early Training\n(Episodes 1-500)': [],
        'Late Training\n(Last 500 episodes)': []
    }
    
    for m in metrics_list:
        comparison_data['Early Training\n(Episodes 1-500)'].append(
            np.mean(m.episode_scores[early_window])
        )
        comparison_data['Late Training\n(Last 500 episodes)'].append(
            np.mean(m.episode_scores[late_window])
        )
    
    x_pos = np.arange(len(comparison_data))
    means = [np.mean(comparison_data[k]) for k in comparison_data.keys()]
    stds = [np.std(comparison_data[k]) for k in comparison_data.keys()]
    
    ax.bar(x_pos, means, yerr=stds, capsize=5, color=['#E63946', '#06FFA5'], alpha=0.7)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(comparison_data.keys())
    ax.set_ylabel('Average Score')
    ax.set_title('Behavioral Shift: Early vs Late Training')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add significance annotation
    improvement = ((means[1] - means[0]) / means[0] * 100) if means[0] > 0 else 0
    ax.text(0.5, max(means) * 0.9, f'+{improvement:.1f}% improvement', 
           ha='center', fontsize=12, fontweight='bold',
           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # Plot 3: Action distribution evolution
    ax = axes[1, 0]
    
    # Use first seed for detailed analysis
    m = metrics_list[0]
    
    # Divide training into 4 phases
    num_episodes = len(m.episode_scores)
    phase_size = num_episodes // 4
    phases = ['Early\n(0-25%)', 'Mid-Early\n(25-50%)', 'Mid-Late\n(50-75%)', 'Late\n(75-100%)']
    
    # We need to reconstruct action distribution per phase
    # Since we only have total counts, we'll show the final distribution
    # and note that this is a limitation
    
    action_names = ['UP', 'DOWN', 'LEFT', 'RIGHT']
    total_actions = sum(m.action_counts.values())
    action_percentages = [m.action_counts[i] / total_actions * 100 for i in range(4)]
    
    colors_actions = ['#E63946', '#F1FAEE', '#A8DADC', '#457B9D']
    ax.bar(action_names, action_percentages, color=colors_actions, alpha=0.7)
    ax.set_ylabel('Percentage of Total Actions (%)')
    ax.set_title(f'Action Distribution (Seed {m.seed})')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add balanced line
    ax.axhline(25, color='red', linestyle='--', linewidth=1, label='Balanced (25%)')
    ax.legend()
    
    # Plot 4: Performance consistency (success rate per block)
    ax = axes[1, 1]
    
    # Calculate "success rate" as percentage of episodes with score > 0
    block_size = 500
    
    for i, m in enumerate(metrics_list):
        success_rates = []
        block_episodes = []
        
        for start in range(0, len(m.episode_scores), block_size):
            end = min(start + block_size, len(m.episode_scores))
            block_scores = m.episode_scores[start:end]
            success_rate = sum(1 for s in block_scores if s > 0) / len(block_scores) * 100
            success_rates.append(success_rate)
            block_episodes.append(start + block_size // 2)
        
        ax.plot(block_episodes, success_rates,
               label=f'Seed {m.seed}', marker='o', alpha=0.7, color=colors[i % len(colors)])
    
    ax.set_xlabel('Episode')
    ax.set_ylabel('Success Rate (% with score > 0)')
    ax.set_title(f'Learning Progress: Success Rate per {block_size} Episodes')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 105])
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {save_path}")
    plt.close()


def generate_summary_report(metrics_list, agent_name="Q-Learning", save_path='summary_report.txt'):
    """
    Generate a comprehensive text summary of all metrics
    """
    report = []
    report.append("="*80)
    report.append(f"COMPREHENSIVE BENCHMARKING REPORT: {agent_name}")
    report.append("="*80)
    report.append("")
    
    # 1. Sample Efficiency
    report.append("1. SAMPLE EFFICIENCY ANALYSIS")
    report.append("-" * 80)
    
    for m in metrics_list:
        report.append(f"  Seed {m.seed}:")
        report.append(f"    - Final Avg Score (last 100): {np.mean(m.episode_scores[-100:]):.2f}")
        report.append(f"    - Best Score: {max(m.episode_scores)}")
        report.append(f"    - Episodes to 80% max: {m.episodes_to_threshold.get('80%_max', 'N/A')}")
    
    # Calculate aggregate statistics
    final_scores = [np.mean(m.episode_scores[-100:]) for m in metrics_list]
    best_scores = [max(m.episode_scores) for m in metrics_list]
    
    report.append(f"  Aggregate Statistics:")
    report.append(f"    - Mean final score: {np.mean(final_scores):.2f} ± {np.std(final_scores):.2f}")
    report.append(f"    - Mean best score: {np.mean(best_scores):.2f} ± {np.std(best_scores):.2f}")
    report.append("")
    
    # 2. Exploration Stability
    report.append("2. EXPLORATION & STABILITY ANALYSIS")
    report.append("-" * 80)
    
    for m in metrics_list:
        if m.reward_variance_per_block:
            avg_variance = np.mean(m.reward_variance_per_block)
            report.append(f"  Seed {m.seed}:")
            report.append(f"    - Average reward variance: {avg_variance:.2f}")
            report.append(f"    - Final epsilon: {m.epsilon_history[-1]:.4f}")
    
    report.append("")
    
    # 3. Computational Efficiency
    report.append("3. COMPUTATIONAL EFFICIENCY")
    report.append("-" * 80)
    
    training_times = [m.get_training_duration() / 60 for m in metrics_list]
    episode_times = [m.get_avg_episode_time() * 1000 for m in metrics_list]
    
    report.append(f"  Training Duration:")
    report.append(f"    - Mean: {np.mean(training_times):.2f} minutes")
    report.append(f"    - Std: {np.std(training_times):.2f} minutes")
    report.append(f"  Episode Duration:")
    report.append(f"    - Mean: {np.mean(episode_times):.2f} ms/episode")
    report.append(f"    - Std: {np.std(episode_times):.2f} ms/episode")
    
    # CPU and Memory
    avg_cpu = [np.mean(m.cpu_usage) for m in metrics_list]
    avg_memory = [np.mean(m.memory_usage) for m in metrics_list]
    
    report.append(f"  Resource Usage:")
    report.append(f"    - Mean CPU usage: {np.mean(avg_cpu):.1f}%")
    report.append(f"    - Mean Memory usage: {np.mean(avg_memory):.1f}%")
    report.append("")
    
    # 4. Convergence Stability
    report.append("4. CONVERGENCE & STABILITY")
    report.append("-" * 80)
    
    # Calculate coefficient of variation across seeds
    if len(metrics_list) > 1:
        final_score_cv = np.std(final_scores) / np.mean(final_scores) * 100
        report.append(f"  Cross-seed variability:")
        report.append(f"    - Coefficient of Variation (CV): {final_score_cv:.2f}%")
        
        if final_score_cv < 10:
            report.append(f"    - Interpretation: HIGHLY STABLE (CV < 10%)")
        elif final_score_cv < 20:
            report.append(f"    - Interpretation: MODERATELY STABLE (10% ≤ CV < 20%)")
        else:
            report.append(f"    - Interpretation: VARIABLE (CV ≥ 20%)")
    
    report.append("")
    
    # 5. Policy Behavior
    report.append("5. POLICY BEHAVIOR EVOLUTION")
    report.append("-" * 80)
    
    for m in metrics_list:
        early_score = np.mean(m.episode_scores[:500])
        late_score = np.mean(m.episode_scores[-500:])
        improvement = ((late_score - early_score) / early_score * 100) if early_score > 0 else 0
        
        early_survival = np.mean(m.episode_lengths[:500])
        late_survival = np.mean(m.episode_lengths[-500:])
        
        report.append(f"  Seed {m.seed}:")
        report.append(f"    - Early score (first 500): {early_score:.2f}")
        report.append(f"    - Late score (last 500): {late_score:.2f}")
        report.append(f"    - Improvement: {improvement:.1f}%")
        report.append(f"    - Early survival: {early_survival:.1f} steps")
        report.append(f"    - Late survival: {late_survival:.1f} steps")
    
    report.append("")
    report.append("="*80)
    report.append("END OF REPORT")
    report.append("="*80)
    
    # Save report
    report_text = "\n".join(report)
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(report_text)
    
    print(f"✅ Saved: {save_path}")
    print("\n" + report_text)
    
    return report_text


# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    print("\n" + "="*80)
    print("COMPREHENSIVE METRICS VISUALIZATION")
    print("="*80 + "\n")
    
    # Load metrics
    print("Loading metrics files...")
    metrics_list = load_metrics("metrics_q-learning_seed*.pkl")
    
    if not metrics_list:
        print("❌ No metrics files found! Run train_with_metrics.py first.")
        exit(1)
    
    print(f"✅ Loaded {len(metrics_list)} metric files\n")
    
    # Generate all visualizations
    print("Generating visualizations...\n")
    
    plot_sample_efficiency(metrics_list, '1_sample_efficiency.png')
    plot_exploration_stability(metrics_list, '2_exploration_stability.png')
    plot_computational_efficiency(metrics_list, '3_computational_efficiency.png')
    plot_convergence_stability(metrics_list, '4_convergence_stability.png')
    plot_policy_behavior(metrics_list, '5_policy_behavior.png')
    
    # Generate summary report
    print("\nGenerating summary report...\n")
    generate_summary_report(metrics_list, agent_name="Q-Learning", save_path='summary_report.txt')
    
    print("\n" + "="*80)
    print("🎉 ALL VISUALIZATIONS COMPLETE!")
    print("="*80)
    print("\nGenerated files:")
    print("  1. 1_sample_efficiency.png")
    print("  2. 2_exploration_stability.png")
    print("  3. 3_computational_efficiency.png")
    print("  4. 4_convergence_stability.png")
    print("  5. 5_policy_behavior.png")
    print("  6. summary_report.txt")
    print("\nThese visualizations are ready for your paper/presentation!")


COMPREHENSIVE METRICS VISUALIZATION

Loading metrics files...
✅ Loaded 3 metric files

Generating visualizations...

✅ Saved: 1_sample_efficiency.png
✅ Saved: 2_exploration_stability.png
✅ Saved: 3_computational_efficiency.png
✅ Saved: 4_convergence_stability.png
✅ Saved: 5_policy_behavior.png

Generating summary report...

✅ Saved: summary_report.txt

COMPREHENSIVE BENCHMARKING REPORT: Q-Learning

1. SAMPLE EFFICIENCY ANALYSIS
--------------------------------------------------------------------------------
  Seed 123:
    - Final Avg Score (last 100): 1.12
    - Best Score: 21
    - Episodes to 80% max: 1
  Seed 42:
    - Final Avg Score (last 100): 0.46
    - Best Score: 13
    - Episodes to 80% max: 1
  Seed 456:
    - Final Avg Score (last 100): 0.68
    - Best Score: 18
    - Episodes to 80% max: 55
  Aggregate Statistics:
    - Mean final score: 0.75 ± 0.27
    - Mean best score: 17.33 ± 3.30

2. EXPLORATION & STABILITY ANALYSIS
--------------------------------------------------

In [None]:
# # ...existing code...
# def record_episode(env, state_func, Q, epsilon, episode_num, agent_name, seed,
#                    video_path, max_steps=1000):
#     """
#     Evaluation recorder: keeps the last valid score even if the env auto-resets on death.
#     """
#     # Reset and size video from first frame
#     env.reset()
#     _ = state_func(env)

#     surf = getattr(env, "screen", None) or getattr(env, "surface", None)
#     assert surf is not None, "SnakeGame has no screen/surface attribute"

#     first_frame = pygame.surfarray.array3d(surf)
#     first_frame = np.transpose(first_frame, (1, 0, 2))
#     first_frame = cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR)

#     overlay_h = 50
#     h, w, _ = first_frame.shape
#     frame_h = h + overlay_h

#     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
#     fps = 15
#     vw = cv2.VideoWriter(video_path, fourcc, fps, (w, frame_h))

#     state = state_func(env)
#     done = False
#     steps = 0
#     total_reward = 0.0
#     actions_taken = []
#     last_score = max(0, getattr(env.snake, "length", 1) - 1)

#     while not done and steps < max_steps:
#         if state not in Q:
#             Q[state] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}

#         # Greedy (tiny exploration cap just in case)
#         if random.random() < min(epsilon, 0.1):
#             action = random.randint(0, 3)
#         else:
#             action = int(max(Q[state], key=Q[state].get))

#         actions_taken.append(action)

#         # Step
#         prior_score = max(0, getattr(env.snake, "length", 1) - 1)
#         _, reward, done, info = env.step(action)
#         next_state = state_func(env)
#         total_reward += reward

#         # Stable score even if env auto-resets
#         info_score = info.get("score") if isinstance(info, dict) else None
#         cur_score = max(0, getattr(env.snake, "length", 1) - 1)
#         last_score = max(last_score, prior_score, cur_score if info_score is None else int(info_score))

#         # Capture frame (screen/surface fallback)
#         surf = getattr(env, "screen", None) or getattr(env, "surface", None)
#         frame = pygame.surfarray.array3d(surf)
#         frame = np.transpose(frame, (1, 0, 2))
#         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

#         # Overlay
#         overlay = np.zeros((frame_h, w, 3), dtype=np.uint8)
#         overlay[0:h, :, :] = frame
#         font = cv2.FONT_HERSHEY_SIMPLEX
#         fs, color, th = 0.35, (255, 255, 255), 1
#         y0 = h + 15
#         for i, text in enumerate([
#             f"Agent: {agent_name}",
#             f"Seed: {seed}",
#             f"Episode: {episode_num}",
#             f"Step: {steps+1}",
#             f"Score: {last_score}",
#         ]):
#             cv2.putText(overlay, text, (5, y0 + i*12), font, fs, color, th, cv2.LINE_AA)

#         vw.write(np.ascontiguousarray(overlay, dtype=np.uint8))

#         state = next_state
#         steps += 1

#     vw.release()
#     return last_score, total_reward, steps, actions_taken
# # ...existing code...

In [None]:
import pickle
from snake_gym.envs.snake import SnakeGame
# from your_q_module import get_state, record_episode  # if not defined in this notebook

seed = 123
agent_name = "Q-Learning"

# Load Q-table saved during training
with open(f"q_table_seed{seed}.pkl", "rb") as f:
    Q = pickle.load(f)

env = SnakeGame()
video_path = f"replay_q_learning_seed{seed}.mp4"

# Greedy replay (epsilon=0.0)
score, total_reward, steps, actions = record_episode(
    env, get_state, Q, epsilon=0.0,
    episode_num=1, agent_name=agent_name, seed=seed, video_path=video_path
)

print(f"Replay done. Score: {score}, Steps: {steps}, Total reward: {total_reward}")
print(f"Saved video: {video_path}")

Replay done. Score: 14, Steps: 104, Total reward: 137.7200000000002
Saved video: replay_q_learning_seed123.mp4


In [None]:
# import pickle
# with open("q_table.pkl", "wb") as f:
#     pickle.dump(Q, f)


In [None]:
# import pickle
# import numpy as np
# import pygame
# from snake_gym.envs.snake import SnakeGame
# from snake_gym.envs.modules import GRIDSIZE
# import cv2
# import time

# # Load Q-table
# with open('q_table.pkl', 'rb') as f:
#     Q = pickle.load(f)

# print(f"Loaded Q-table with {len(Q)} states")

# def get_state(env):
#     """Same as training"""
#     head_x, head_y = env.snake.get_head_position()
#     apple_x, apple_y = env.apple.position
#     direction = env.snake.direction
    
#     def is_danger(dx, dy):
#         new_x = head_x + dx * GRIDSIZE
#         new_y = head_y + dy * GRIDSIZE
#         if new_x < 0 or new_x >= 150 or new_y < 0 or new_y >= 150:
#             return 1
#         if (new_x, new_y) in env.snake.positions[:-1]:
#             return 1
#         return 0
    
#     return (is_danger(0, -GRIDSIZE), is_danger(0, GRIDSIZE), 
#             is_danger(-GRIDSIZE, 0), is_danger(GRIDSIZE, 0),
#             int(apple_y < head_y), int(apple_y > head_y),
#             int(apple_x < head_x), int(apple_x > head_x),
#             int(direction == (0, -1)), int(direction == (0, 1)),
#             int(direction == (-1, 0)), int(direction == (1, 0)))

# # Initialize environment
# env = SnakeGame()
# state = env.reset()
# state = get_state(env)
# done = False
# steps = 0

# # --- Video recording setup ---
# frame_width, frame_height = 150, 150
# video_filename = 'snake_play.mp4'
# fourcc = cv2.VideoWriter_fourcc(*'mp4v')
# out = cv2.VideoWriter(video_filename, fourcc, 10.0, (frame_width, frame_height))

# print("\nPlaying game... (close window to stop)")

# while not done and steps < 1000:
#     # Get best action from Q-table
#     if state in Q:
#         action = max(Q[state], key=Q[state].get)
#     else:
#         action = np.random.randint(0, 4)
    
#     _, reward, done, _ = env.step(action)
    
#     # Capture frame from Pygame surface
#     frame = pygame.surfarray.array3d(env.surface)
#     frame = np.transpose(frame, (1, 0, 2))  # Swap axes for OpenCV
#     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    
#     # Write frame to video
#     out.write(frame)
    
#     pygame.display.flip()
#     time.sleep(0.05)  # Slow down for viewing
    
#     state = get_state(env)
#     steps += 1

# # Release video writer
# out.release()
# pygame.quit()

# print(f"\nGame Over!")
# print(f"Final Score: {env.snake.length - 1}")
# print(f"Steps survived: {steps}")
# print(f"Video saved as {video_filename}")


In [None]:
# import random
# import numpy as np
# from snake_gym.envs.snake import SnakeGame

# # Hyperparameters
# alpha = 0.1
# gamma = 0.95
# epsilon = 1.0
# epsilon_min = 0.01
# epsilon_decay = 0.995
# num_episodes = 1000

# # Initialize
# env = SnakeGame()
# Q = {}

# def get_state(env):
#     """Get a simplified state representation"""
#     head_x, head_y = env.snake.get_head_position()
#     apple_x, apple_y = env.apple.position
#     direction = env.snake.direction
    
#     # Danger detection
#     def is_danger(dx, dy):
#         new_x = head_x + dx * GRIDSIZE
#         new_y = head_y + dy * GRIDSIZE
        
#         # Wall collision
#         if new_x < 0 or new_x >= 150 or new_y < 0 or new_y >= 150:
#             return 1
        
#         # Self collision
#         if (new_x, new_y) in env.snake.positions[:-1]:
#             return 1
        
#         return 0
    
#     # Check danger in all 4 directions
#     danger_up = is_danger(0, -GRIDSIZE)
#     danger_down = is_danger(0, GRIDSIZE)
#     danger_left = is_danger(-GRIDSIZE, 0)
#     danger_right = is_danger(GRIDSIZE, 0)
    
#     # Apple direction
#     apple_up = int(apple_y < head_y)
#     apple_down = int(apple_y > head_y)
#     apple_left = int(apple_x < head_x)
#     apple_right = int(apple_x > head_x)
    
#     # Current direction
#     dir_up = int(direction == (0, -1))
#     dir_down = int(direction == (0, 1))
#     dir_left = int(direction == (-1, 0))
#     dir_right = int(direction == (1, 0))
    
#     return (danger_up, danger_down, danger_left, danger_right,
#             apple_up, apple_down, apple_left, apple_right,
#             dir_up, dir_down, dir_left, dir_right)

# # Import GRIDSIZE from modules
# from snake_gym.envs.modules import GRIDSIZE

# # Training
# scores = []
# rewards_list = []

# print("Starting training...")
# print("=" * 60)

# for episode in range(num_episodes):
#     state = env.reset()
#     state = get_state(env)
#     done = False
#     total_reward = 0
#     steps = 0
    
#     while not done and steps < 1000:
#         # Initialize Q-values
#         if state not in Q:
#             Q[state] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}
        
#         # Epsilon-greedy
#         if random.random() < epsilon:
#             action = random.randint(0, 3)
#         else:
#             action = max(Q[state], key=Q[state].get)
        
#         # Take step
#         _, reward, done, info = env.step(action)
#         next_state = get_state(env)
        
#         if next_state not in Q:
#             Q[next_state] = {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0}
        
#         # Q-Learning update
#         best_next = max(Q[next_state].values()) if not done else 0
#         Q[state][action] += alpha * (reward + gamma * best_next - Q[state][action])
        
#         state = next_state
#         total_reward += reward
#         steps += 1
    
#     # Decay epsilon
#     epsilon = max(epsilon_min, epsilon * epsilon_decay)
    
#     # Track metrics
#     score = env.snake.length - 1
#     scores.append(score)
#     rewards_list.append(total_reward)
    
#     # Print progress
#     if (episode + 1) % 100 == 0:
#         avg_score = np.mean(scores[-100:])
#         avg_reward = np.mean(rewards_list[-100:])
#         max_score = max(scores[-100:])
#         print(f"Episode {episode+1:5d} | "
#               f"Avg Score: {avg_score:5.2f} | "
#               f"Max Score: {max_score:3d} | "
#               f"Avg Reward: {avg_reward:7.2f} | "
#               f"Epsilon: {epsilon:.3f} | "
#               f"States: {len(Q)}")

# print("\n" + "=" * 60)
# print("Training Complete!")
# print(f"Best Score: {max(scores)}")
# print(f"Final Avg (last 100): {np.mean(scores[-100:]):.2f}")

Starting training...
Episode   100 | Avg Score:  0.72 | Max Score:   4 | Avg Reward:   -2.80 | Epsilon: 0.606 | States: 34
Episode   200 | Avg Score:  1.23 | Max Score:   6 | Avg Reward:    9.20 | Epsilon: 0.367 | States: 35
Episode   300 | Avg Score:  1.67 | Max Score:  11 | Avg Reward:   25.00 | Epsilon: 0.222 | States: 36
Episode   400 | Avg Score:  1.71 | Max Score:  13 | Avg Reward:   36.80 | Epsilon: 0.135 | States: 36
Episode   500 | Avg Score:  1.10 | Max Score:  10 | Avg Reward:   44.80 | Epsilon: 0.082 | States: 36
Episode   600 | Avg Score:  1.65 | Max Score:  15 | Avg Reward:   56.50 | Epsilon: 0.049 | States: 36
Episode   700 | Avg Score:  1.03 | Max Score:  10 | Avg Reward:   60.90 | Epsilon: 0.030 | States: 36
Episode   800 | Avg Score:  1.08 | Max Score:  12 | Avg Reward:   81.20 | Epsilon: 0.018 | States: 36
Episode   900 | Avg Score:  0.96 | Max Score:  16 | Avg Reward:   73.70 | Epsilon: 0.011 | States: 36
Episode  1000 | Avg Score:  0.62 | Max Score:   8 | Avg Rewar

In [None]:
# import pickle
# import numpy as np
# import pygame
# from snake_gym.envs.snake import SnakeGame
# from snake_gym.envs.modules import GRIDSIZE
# import time

# # Load Q-table
# with open('q_table.pkl', 'rb') as f:
#     Q = pickle.load(f)

# print(f"Loaded Q-table with {len(Q)} states")

# def get_state(env):
#     """Same as training"""
#     head_x, head_y = env.snake.get_head_position()
#     apple_x, apple_y = env.apple.position
#     direction = env.snake.direction
    
#     def is_danger(dx, dy):
#         new_x = head_x + dx * GRIDSIZE
#         new_y = head_y + dy * GRIDSIZE
#         if new_x < 0 or new_x >= 150 or new_y < 0 or new_y >= 150:
#             return 1
#         if (new_x, new_y) in env.snake.positions[:-1]:
#             return 1
#         return 0
    
#     return (is_danger(0, -GRIDSIZE), is_danger(0, GRIDSIZE), 
#             is_danger(-GRIDSIZE, 0), is_danger(GRIDSIZE, 0),
#             int(apple_y < head_y), int(apple_y > head_y),
#             int(apple_x < head_x), int(apple_x > head_x),
#             int(direction == (0, -1)), int(direction == (0, 1)),
#             int(direction == (-1, 0)), int(direction == (1, 0)))

# # Play game
# env = SnakeGame()
# state = env.reset()
# state = get_state(env)
# done = False
# steps = 0

# print("\nPlaying game... (close window to stop)")

# while not done and steps < 1000:
#     # Get best action from Q-table
#     if state in Q:
#         action = max(Q[state], key=Q[state].get)
#     else:
#         action = np.random.randint(0, 4)
    
#     _, reward, done, _ = env.step(action)
#     pygame.display.flip()
#     time.sleep(0.1)  # Slow down for viewing
    
#     state = get_state(env)
#     steps += 1

# print(f"\nGame Over!")
# print(f"Final Score: {env.snake.length - 1}")
# print(f"Steps survived: {steps}")

Loaded Q-table with 36 states

Playing game... (close window to stop)

Game Over!
Final Score: 0
Steps survived: 78
