In [None]:
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import time
from typing import Deque, Tuple
from collections import deque 

# Best effort: ask matplotlib for the Tk GUI backend so the live chase
# animation can open its own window. If the backend cannot be loaded the
# ImportError is swallowed and matplotlib keeps its current backend.
try:
    plt.switch_backend('TkAgg')
except ImportError:
    ...  # keep whatever backend matplotlib already selected

# Fix the seeds of both random generators (stdlib `random` and NumPy's
# legacy global generator) so runs are repeatable.
np.random.seed(42)
random.seed(42)

class ChaseEnv:
    """Grid-world pursuit task: an agent chases a randomly moving prey.

    The board is ``size`` x ``size`` with fixed trap cells at every
    (row, col) combination of 2 and 7 that lies on the board. Stepping off
    the board or onto a trap ends the episode with a -20 penalty; landing on
    the prey gives +20 and respawns the prey on another free cell.

    Observations are 6-element vectors (see ``_get_state``) which
    ``_get_state_index`` packs into one of 64 discrete states.
    """

    # Action id -> (row delta, col delta): 0 = up, 1 = down, 2 = left, 3 = right.
    MOVES = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self, size=10, gamma=0.9):
        self.SIZE = size
        self.nA = 4
        self.gamma = gamma  # kept for callers; the environment itself does not use it
        self.TRAPS = self._generate_traps()
        self.reset()

    def _generate_traps(self):
        """Return the set of trap cells: (r, c) for r, c in {2, 7}, restricted to the board."""
        return {(r, c) for r in (2, 7) for c in (2, 7) if r < self.SIZE and c < self.SIZE}

    def _is_blocked(self, r, c):
        """True when (r, c) is off the board or a trap, i.e. entering it ends the episode."""
        return not (0 <= r < self.SIZE and 0 <= c < self.SIZE) or (r, c) in self.TRAPS

    def reset(self):
        """Start a new episode and return the initial observation vector."""
        self.agent_pos = self._get_safe_random_pos()
        # Never spawn the prey on the agent's cell (that would be a free capture).
        self.prey_pos = self._get_safe_random_pos(exclude=(self.agent_pos,))
        # Heading = last action taken. It starts random so that the
        # "danger ahead" feature is defined before the first move.
        self.direction = random.randint(0, self.nA - 1)
        self.is_terminal = False
        self.score = 0
        return self._get_state()

    def _get_safe_random_pos(self, exclude=()):
        """Return a uniformly random non-trap cell that is not listed in ``exclude``."""
        while True:
            pos = (random.randint(0, self.SIZE - 1), random.randint(0, self.SIZE - 1))
            if pos not in self.TRAPS and pos not in exclude:
                return pos

    def _get_state(self):
        """Return the observation vector for the current positions and heading.

        Layout: ``[normalized distance to prey, danger ahead (0/1),
        prey is up, prey is down, prey is left, prey is right]``.
        "Ahead" is the agent's current heading (its last action), so the
        observation is a deterministic function of the environment state.
        """
        r_agent, c_agent = self.agent_pos
        r_prey, c_prey = self.prey_pos
        # Divided by 1.5*SIZE so the value stays below 1 (max on-board distance is sqrt(2)*(SIZE-1)).
        dist_to_prey = np.sqrt((r_agent - r_prey)**2 + (c_agent - c_prey)**2) / (self.SIZE * 1.5)
        dr, dc = self.MOVES[self.direction]
        danger_ahead = int(self._is_blocked(r_agent + dr, c_agent + dc))
        prey_direction = [r_prey < r_agent, r_prey > r_agent, c_prey < c_agent, c_prey > c_agent]
        return np.array([dist_to_prey, danger_ahead] + prey_direction)

    def _get_state_index(self, state_vector):
        """Pack an observation vector into a state index in [0, 63].

        Bit 5: prey is close (normalized distance < 0.2); bit 4: danger
        ahead; bits 3..0: prey is up / down / left / right of the agent.
        """
        food_close = 1 if state_vector[0] < 0.2 else 0
        danger_ahead = int(state_vector[1])
        f_up, f_down, f_left, f_right = [int(x) for x in state_vector[2:]]
        return (food_close * 32) + (danger_ahead * 16) + (f_up * 8) + (f_down * 4) + (f_left * 2) + f_right

    def _move_prey(self):
        """Move the prey one random step; moves off the board or onto a trap are skipped."""
        r, c = self.prey_pos
        dr, dc = self.MOVES[random.randint(0, 3)]
        nr, nc = r + dr, c + dc
        if not self._is_blocked(nr, nc):
            self.prey_pos = (nr, nc)

    def step(self, action):
        """Apply ``action`` and return ``(state_index, reward, done, info)``.

        Rewards: -0.1 per step plus 0.2 for ending closer to the prey.
        Catching the prey sets the base reward to 20.0; the closer-bonus still
        applies, so a capture yields 20.2. Walking off the board or into a
        trap gives -20.0, leaves the agent in place, and ends the episode.
        """
        reward = -0.1
        done = False
        self.direction = action
        r, c = self.agent_pos
        dr, dc = self.MOVES[action]
        next_pos = (r + dr, c + dc)

        if self._is_blocked(*next_pos):
            reward = -20.0
            done = True
            self.is_terminal = True
        else:
            # Measure shaping against the prey the agent was chasing, i.e.
            # before a capture respawns it somewhere else.
            prey = np.array(self.prey_pos)
            old_dist = np.linalg.norm(np.array([r, c]) - prey)
            new_dist = np.linalg.norm(np.array(next_pos) - prey)
            self.agent_pos = next_pos

            if next_pos == self.prey_pos:
                reward = 20.0
                self.score += 1
                self.prey_pos = self._get_safe_random_pos(exclude=(next_pos,))

            if new_dist < old_dist:
                reward += 0.2

            self._move_prey()

        return self._get_state_index(self._get_state()), reward, done, {}

class BaseTDAgent:
    """Base class providing common Q-table management and action selection.

    The Q-table has one row per discrete state (as produced by
    ``env._get_state_index``) and one column per action. The public methods
    take raw observation vectors and discretize them through ``env``.
    """
    def __init__(self, env, alpha=0.1, gamma=0.9):
        # 64 = 2**6 states: must match the 6-bit packing in env._get_state_index.
        self.nS = 64
        self.nA = env.nA
        self.Q = np.zeros((self.nS, self.nA))
        self.alpha = alpha
        self.gamma = gamma

    def _discretize(self, state_vector, env):
        return env._get_state_index(state_vector)

    def _greedy(self, s):
        """Return an action with maximal Q[s], breaking ties uniformly at random.

        ``np.argmax`` alone always returns the first maximal index, which with
        a zero-initialized table makes every unvisited state choose action 0.
        """
        q_row = self.Q[s]
        best_actions = np.flatnonzero(q_row == q_row.max())
        return int(random.choice(best_actions))

    def get_action(self, s_vector, epsilon, env):
        """Epsilon-greedy action: uniform random with probability ``epsilon``, else greedy."""
        if random.random() < epsilon:
            return random.randint(0, self.nA - 1)
        return self._greedy(self._discretize(s_vector, env))

    def get_greedy_action(self, s_vector, env):
        """Greedy action (random tie-break) for the observation ``s_vector``."""
        return self._greedy(self._discretize(s_vector, env))

    def get_max_q(self, s_vector, env):
        """Largest Q value over actions in the state of ``s_vector``."""
        s = self._discretize(s_vector, env)
        return np.max(self.Q[s])

    def get_q_value(self, s_vector, a, env):
        """Q value of action ``a`` in the state of ``s_vector``."""
        s = self._discretize(s_vector, env)
        return self.Q[s, a]

    def set_q_value(self, s_vector, a, new_value, env):
        """Overwrite the Q value of action ``a`` in the state of ``s_vector``."""
        s = self._discretize(s_vector, env)
        self.Q[s, a] = new_value

class QAgent(BaseTDAgent):
    """Tabular Q-learning (off-policy TD control)."""

    def update_Q(self, s_vec, a, r, ns_vec, env, done=False):
        """Apply one Q-learning update to Q(s, a) from a single transition.

        Target: ``r + gamma * max_a' Q(s', a')``. When ``done`` is True the
        transition ended the episode, so s' has no future value and the
        target is just ``r``.
        """
        next_value = 0.0 if done else self.get_max_q(ns_vec, env)
        td_target = r + self.gamma * next_value

        old_value = self.get_q_value(s_vec, a, env)
        td_error = td_target - old_value
        self.set_q_value(s_vec, a, old_value + self.alpha * td_error, env)

class SARSAAgent(BaseTDAgent):
    """Tabular SARSA (on-policy TD control)."""

    def update_Q(self, s_vec, a, r, ns_vec, next_a, env, done=False):
        """Apply one SARSA update to Q(s, a) from (s, a, r, s', a').

        Target: ``r + gamma * Q(s', a')`` where a' is the action actually
        chosen next; when ``done`` is True the target is just ``r``.
        """
        next_value = 0.0 if done else self.get_q_value(ns_vec, next_a, env)
        td_target = r + self.gamma * next_value

        old_value = self.get_q_value(s_vec, a, env)
        td_error = td_target - old_value
        self.set_q_value(s_vec, a, old_value + self.alpha * td_error, env)

def _run_episode(env, agent, epsilon, agent_type, max_steps):
    """Play one epsilon-greedy episode, updating ``agent`` after every step.

    'Q' picks each action just before acting; 'SARSA' picks the next action
    before the update so the same action is both learned from and executed.

    Returns:
        The undiscounted sum of rewards collected in the episode.

    Raises:
        ValueError: if ``agent_type`` is neither 'Q' nor 'SARSA'.
    """
    if agent_type not in ('Q', 'SARSA'):
        raise ValueError(f"agent_type must be 'Q' or 'SARSA', got {agent_type!r}")

    s_vector = env.reset()
    episode_reward = 0.0
    if agent_type == 'SARSA':
        a = agent.get_action(s_vector, epsilon, env)

    for _ in range(max_steps):
        if agent_type == 'Q':
            a = agent.get_action(s_vector, epsilon, env)

        _, reward, done, _ = env.step(a)
        # Agents work on observation vectors; env.step only returns the state index.
        ns_vector = env._get_state()

        if agent_type == 'SARSA':
            next_a = agent.get_action(ns_vector, epsilon, env)
            agent.update_Q(s_vector, a, reward, ns_vector, next_a, env, done=done)
            a = next_a
        else:
            agent.update_Q(s_vector, a, reward, ns_vector, env, done=done)

        s_vector = ns_vector
        episode_reward += reward
        if done:
            break

    return episode_reward

def train_agent(env, agent, num_episodes=5000, agent_type='Q', *, max_steps=100,
                epsilon=1.0, epsilon_min=0.05, epsilon_decay=0.999):
    """Train ``agent`` on ``env`` for ``num_episodes`` epsilon-greedy episodes.

    Epsilon starts at ``epsilon`` and is multiplied by ``epsilon_decay``
    after every episode, never dropping below ``epsilon_min``.

    Returns:
        (list of per-episode total rewards, the trained agent)
    """
    total_rewards = []
    for _ in tqdm(range(num_episodes), desc=f"{agent_type} Training"):
        total_rewards.append(_run_episode(env, agent, epsilon, agent_type, max_steps))
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
    return total_rewards, agent

def track_convergence_time(env, agent, num_episodes, agent_type, reward_threshold=5.0, window_size=50,
                           *, max_steps=100, epsilon=1.0, epsilon_min=0.05, epsilon_decay=0.999):
    """Run a single training session and return the episode count when the reward threshold is consistently met.

    "Consistently" means the mean reward over the last ``window_size``
    episodes is >= ``reward_threshold``. If that never happens,
    ``num_episodes`` is returned. Note that this value is indistinguishable
    from converging exactly at the last episode.
    """
    rewards = deque(maxlen=window_size)
    for episode in range(1, num_episodes + 1):
        rewards.append(_run_episode(env, agent, epsilon, agent_type, max_steps))
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if len(rewards) == window_size and np.mean(rewards) >= reward_threshold:
            return episode

    return num_episodes

def compare_convergence_speed(num_runs=10, max_episodes=5000, reward_threshold=5.0, seed=42):
    """Compare how many episodes Q-learning and SARSA need to converge, over several runs.

    Each run trains a fresh environment/agent pair per algorithm with
    ``track_convergence_time``. Both algorithms in run ``k`` start from the
    same RNG seed ``seed + k``, so the comparison is fair and each run is
    reproducible on its own. A run that never reaches the threshold is
    recorded as ``max_episodes``. Results are plotted and summarized on stdout.

    Returns:
        (q_times, sarsa_times): episodes-to-converge per run for each algorithm.
    """
    q_times = []
    sarsa_times = []
    contenders = (('Q', QAgent, q_times), ('SARSA', SARSAAgent, sarsa_times))

    for run in tqdm(range(num_runs), desc="Convergence Comparison Runs"):
        for agent_type, agent_cls, times in contenders:
            # Seed before building the env: its constructor already draws random start positions.
            random.seed(seed + run)
            np.random.seed(seed + run)
            env = ChaseEnv(size=10, gamma=0.9)
            agent = agent_cls(env)
            times.append(track_convergence_time(env, agent, max_episodes, agent_type, reward_threshold))

    q_failed = sum(t == max_episodes for t in q_times)
    sarsa_failed = sum(t == max_episodes for t in sarsa_times)

    plt.figure(figsize=(10, 6))
    x = np.arange(1, num_runs + 1)
    plt.plot(x, q_times, marker='o', linestyle='-', label='Q-Learning Time to Solve')
    plt.plot(x, sarsa_times, marker='s', linestyle='--', label='SARSA Time to Solve')

    if q_failed or sarsa_failed:
        plt.axhline(max_episodes, color='r', linestyle=':', label='Max Episodes Reached (Failure)')

    # The convergence test in track_convergence_time uses >=, so say so in the title.
    plt.title(f"Convergence Speed Comparison (Target: Avg Reward >= {reward_threshold:.1f})")
    plt.xlabel("Run Number")
    plt.ylabel("Episodes Required for Convergence")
    plt.xticks(x)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show(block=False)

    print("\n--- Summary of Convergence ---")
    print(f"Q-Learning Average Convergence Time (Episodes): {np.mean(q_times):.1f}")
    print(f"SARSA Average Convergence Time (Episodes): {np.mean(sarsa_times):.1f}")
    # Failed runs enter the averages above as max_episodes, so report them explicitly.
    print(f"Runs hitting the {max_episodes}-episode cap: "
          f"Q-Learning {q_failed}/{num_runs}, SARSA {sarsa_failed}/{num_runs}")

    return q_times, sarsa_times


def _draw_board(ax, env, title):
    """Draw the board of ``env`` on ``ax``: 0 empty, 1 trap, 3 agent, 4 prey (viridis, 0..4)."""
    ax.cla()
    grid = np.zeros((env.SIZE, env.SIZE))
    grid[env.agent_pos] = 3.0
    grid[env.prey_pos] = 4.0  # painted after the agent, so the prey shows if both share a cell
    for r, c in env.TRAPS:
        grid[r, c] = 1.0
    ax.imshow(grid, cmap='viridis', vmin=0, vmax=4)
    ax.set_title(title)
    ax.set_xticks(np.arange(env.SIZE))
    ax.set_yticks(np.arange(env.SIZE))
    ax.grid(color='white', linestyle='-', linewidth=0.5)

def visualize_chase(env, agent, num_runs=3, delay=0.1, agent_name='Agent', max_steps=200):
    """Animate greedy roll-outs of ``agent`` on ``env`` in a matplotlib window.

    Args:
        env: environment to play in. It is reset at the start of every run,
            so this call changes its state.
        agent: object exposing ``get_greedy_action(state_vector, env)``.
        num_runs: number of demo episodes.
        delay: pause in seconds after drawing each step.
        agent_name: label for the window title and console messages.
        max_steps: step cap per run.

    Between runs the figure waits for a key or mouse press. Interactive mode
    is switched off and the figure closed even if drawing raises.
    """
    plt.ion()
    fig, ax = plt.subplots(figsize=(6, 6))
    plt.show(block=False)

    try:
        for run in range(1, num_runs + 1):
            env.reset()
            print(f"\n--- Running Chase Demo {run}/{num_runs} for {agent_name} ---")

            for step in range(1, max_steps + 1):
                action = agent.get_greedy_action(env._get_state(), env)
                _, _, done, _ = env.step(action)

                _draw_board(ax, env, f"{agent_name} Run {run} | Score: {env.score} | Step: {step}")
                fig.canvas.draw()
                fig.canvas.flush_events()
                time.sleep(delay)

                if done:
                    # `step` is 1-based, so it equals the number of steps actually taken.
                    print(f"Run {run} finished! Agent crashed after {step} steps. Final Score: {env.score}")
                    break
            else:
                print(f"Run {run} finished! Max steps reached. Final Score: {env.score}")

            if run < num_runs:
                ax.cla()
                ax.set_title(f"Run {run} Complete. Score: {env.score}. Click to start Run {run+1}.")
                fig.canvas.draw()
                fig.canvas.flush_events()
                plt.waitforbuttonpress()
    finally:
        plt.ioff()
        plt.close(fig)


if __name__ == '__main__':
    ENV = ChaseEnv(size=10, gamma=0.9)
    AGENT_Q = QAgent(ENV)
    AGENT_SARSA = SARSAAgent(ENV)

    # Both agents train on the same environment object; train_agent resets it every episode.
    REWARDS_Q, FINAL_AGENT_Q = train_agent(ENV, AGENT_Q, num_episodes=5000, agent_type='Q')
    REWARDS_SARSA, FINAL_AGENT_SARSA = train_agent(ENV, AGENT_SARSA, num_episodes=5000, agent_type='SARSA')

    # Moving average of per-episode rewards. With mode='valid' the k-th
    # smoothed value averages episodes k+1 .. k+window, so it is plotted at
    # episode k+window rather than at index k. The window is capped by the
    # episode count so the kernel is never longer than the data.
    window = min(100, len(REWARDS_Q), len(REWARDS_SARSA))
    kernel = np.ones(window) / window
    smoothed_Q = np.convolve(REWARDS_Q, kernel, mode='valid')
    smoothed_SARSA = np.convolve(REWARDS_SARSA, kernel, mode='valid')

    plt.figure(figsize=(10, 5))
    plt.plot(np.arange(window, window + len(smoothed_Q)), smoothed_Q, label='Q-Learning (Off-Policy)')
    plt.plot(np.arange(window, window + len(smoothed_SARSA)), smoothed_SARSA, label='SARSA (On-Policy)')
    plt.title("ChaseRL: Q-Learning vs SARSA Performance Comparison")
    plt.xlabel(f"Episode (Smoothed over {window} episodes)")
    plt.ylabel("Avg. Cumulative Reward")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show(block=False)

    compare_convergence_speed(num_runs=10, max_episodes=5000, reward_threshold=5.0)

    visualize_chase(ENV, FINAL_AGENT_Q, num_runs=3, delay=0.1, agent_name='Q-Learning')

    visualize_chase(ENV, FINAL_AGENT_SARSA, num_runs=3, delay=0.1, agent_name='SARSA')

Q Training: 100%|██████████| 5000/5000 [00:01<00:00, 3643.41it/s]
SARSA Training: 100%|██████████| 5000/5000 [00:01<00:00, 3840.97it/s]
Convergence Comparison Runs: 100%|██████████| 10/10 [00:06<00:00,  1.61it/s]



--- Summary of Convergence ---
Q-Learning Average Convergence Time (Episodes): 1608.2
SARSA Average Convergence Time (Episodes): 1632.2

--- Running Chase Demo 1/3 for Q-Learning ---
Run 1 finished! Agent crashed after 16 steps. Final Score: 1

--- Running Chase Demo 2/3 for Q-Learning ---
Run 2 finished! Agent crashed after 26 steps. Final Score: 3

--- Running Chase Demo 3/3 for Q-Learning ---
Run 3 finished! Agent crashed after 7 steps. Final Score: 0

--- Running Chase Demo 1/3 for SARSA ---
Run 1 finished! Agent crashed after 31 steps. Final Score: 2

--- Running Chase Demo 2/3 for SARSA ---
Run 2 finished! Agent crashed after 138 steps. Final Score: 9

--- Running Chase Demo 3/3 for SARSA ---
Run 3 finished! Agent crashed after 70 steps. Final Score: 8


: 