In [1]:
# Imports
import torch
import numpy as np
import gymnasium as gym
from collections import deque
import pygame
import random
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# DQN model which takes in the state as an input and outputs predicted q values for every possible action
class DQN(torch.nn.Module):
    def __init__(self, state_space, action_space):
        super().__init__()
        # Add your architecture parameters here
        # You can use nn.Functional
        # Remember that the input is of size batch_size x state_space
        # and the output is of size batch_size x action_space (ulta ho sakta hai dekh lo)
        # TODO: Add code here
        self.layer1 = nn.Linear(state_space, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, action_space)

    def forward(self, input):
        # TODO: Complete based on your implementation
        outs1 = F.relu(self.layer1(input))
        outs2 = F.relu(self.layer2(outs1))
        return self.layer3(outs2)

In [3]:
# While training neural networks, we split the data into batches.
# To improve the training, we need to remove the "correlation" between game states
# The buffer starts storing states and once it reaches maximum capacity, it replaces
# states at random which reduces the correlation.
class ExperienceBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)


In [4]:
# TODO: Implement training logic for CartPole environment here
# Remember to use the ExperienceBuffer and a target network
# Details can be found in the book sent in the group

def train(qnet, targetnet, buffer, env, episodes=1000, batch_size=64, gamma=0.99, epsilon_max = 1, epsilon_min=0.1, epsilon_decay=200, target_update=100, lr=1e-3):
  optimiser = torch.optim.Adam(qnet.parameters(), lr = lr)
  loss_fn = nn.MSELoss()

  epsilon = epsilon_max
  ep_rewards = []

  for episode in range(episodes):
    state, _= env.reset()
    state = torch.FloatTensor(state).unsqueeze(0)
    total_reward = 0
    done = False

    while not done:
      if np.random.uniform(0,1) < epsilon:
        action = env.action_space.sample()
      else:
        q_values = qnet(state)
        action = q_values.argmax().item()

      next_state, reward, done, *info = env.step(action)
      total_reward += reward
      next_state = torch.FloatTensor(next_state).unsqueeze(0)
      buffer.push(state, action, reward, next_state, done)
      state = next_state

      if len(buffer) >= batch_size:
        states, actions, rewards, next_states, dones = buffer.sample(batch_size)
        states = torch.cat(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.cat(next_states)
        dones = torch.FloatTensor(dones)

        # Add an extra dimension to actions to match the dimensions of qnet(states)
        actions = actions.unsqueeze(1)

        current_q_values = qnet(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            next_q_values = targetnet(next_states).max(1)[0]
            target_q_values = rewards + (1 - dones) * gamma * next_q_values

        loss = loss_fn(current_q_values, target_q_values)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    ep_rewards.append(total_reward)
    epsilon = epsilon_min + (epsilon_max - epsilon_min) * np.exp(-1. * episode / epsilon_decay)

    if episode % target_update == 0:
      targetnet.load_state_dict(qnet.state_dict())

    if episode % 100 == 0:
      avg_reward = np.mean(ep_rewards[-100:]) if episode > 0 else total_reward
      print(f'Episode {episode}, Average Reward: {avg_reward:.2f}, Epsilon: {epsilon:.3f}')

  return ep_rewards

In [5]:
def evaluate_cartpole_model(model, episodes=10, render=True):
    env = gym.make("CartPole-v1", render_mode="human" if render else None)
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n

    model.eval()

    rewards = []

    for episode in range(episodes):
        obs, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            state = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                q_values = model(state)
                action = torch.argmax(q_values, dim=1).item()

            obs, reward, done, _, _ = env.step(action)
            total_reward += reward

            if render:
                env.render()

        rewards.append(total_reward)
        print(f"Episode {episode + 1}: Reward = {total_reward}")

    env.close()
    avg_reward = sum(rewards) / episodes
    print(f"Average reward over {episodes} episodes: {avg_reward}")

In [6]:

if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    state_space = env.observation_space.shape[0]
    action_space = env.action_space.n

    qnet = DQN(state_space, action_space)
    targetnet = DQN(state_space, action_space)
    targetnet.load_state_dict(qnet.state_dict())

    buffer = ExperienceBuffer(capacity=10000)
train(qnet, targetnet, buffer, env)


Episode 0, Average Reward: 19.00, Epsilon: 1.000
Episode 100, Average Reward: 17.10, Epsilon: 0.646
Episode 200, Average Reward: 14.95, Epsilon: 0.431
Episode 300, Average Reward: 18.57, Epsilon: 0.301
Episode 400, Average Reward: 25.67, Epsilon: 0.222
Episode 500, Average Reward: 40.35, Epsilon: 0.174
Episode 600, Average Reward: 98.37, Epsilon: 0.145
Episode 700, Average Reward: 131.45, Epsilon: 0.127
Episode 800, Average Reward: 173.81, Epsilon: 0.116
Episode 900, Average Reward: 321.37, Epsilon: 0.110


[19.0,
 13.0,
 13.0,
 12.0,
 17.0,
 16.0,
 13.0,
 21.0,
 39.0,
 14.0,
 11.0,
 15.0,
 11.0,
 12.0,
 22.0,
 21.0,
 19.0,
 20.0,
 12.0,
 18.0,
 12.0,
 10.0,
 18.0,
 23.0,
 10.0,
 32.0,
 69.0,
 21.0,
 20.0,
 31.0,
 18.0,
 10.0,
 14.0,
 20.0,
 10.0,
 20.0,
 11.0,
 15.0,
 21.0,
 10.0,
 21.0,
 9.0,
 22.0,
 12.0,
 59.0,
 16.0,
 12.0,
 20.0,
 26.0,
 16.0,
 14.0,
 11.0,
 12.0,
 11.0,
 26.0,
 10.0,
 13.0,
 13.0,
 15.0,
 13.0,
 10.0,
 30.0,
 9.0,
 22.0,
 13.0,
 40.0,
 9.0,
 23.0,
 12.0,
 31.0,
 12.0,
 11.0,
 15.0,
 14.0,
 15.0,
 21.0,
 16.0,
 19.0,
 9.0,
 15.0,
 13.0,
 18.0,
 17.0,
 24.0,
 11.0,
 22.0,
 21.0,
 12.0,
 12.0,
 12.0,
 13.0,
 14.0,
 11.0,
 17.0,
 10.0,
 11.0,
 11.0,
 12.0,
 13.0,
 14.0,
 11.0,
 13.0,
 32.0,
 16.0,
 9.0,
 8.0,
 14.0,
 13.0,
 23.0,
 13.0,
 9.0,
 20.0,
 15.0,
 10.0,
 10.0,
 21.0,
 17.0,
 16.0,
 57.0,
 22.0,
 10.0,
 11.0,
 10.0,
 15.0,
 13.0,
 9.0,
 26.0,
 12.0,
 8.0,
 16.0,
 26.0,
 9.0,
 14.0,
 11.0,
 12.0,
 10.0,
 14.0,
 15.0,
 14.0,
 38.0,
 12.0,
 19.0,
 11.0,
 14.0,
 1

In [7]:
# TODO: Run evaluation for cartpole here
evaluate_cartpole_model(qnet, episodes=10)

Episode 1: Reward = 416.0
Episode 2: Reward = 420.0
Episode 3: Reward = 149.0
Episode 4: Reward = 362.0
Episode 5: Reward = 114.0
Episode 6: Reward = 135.0
Episode 7: Reward = 117.0
Episode 8: Reward = 378.0
Episode 9: Reward = 304.0
Episode 10: Reward = 407.0
Average reward over 10 episodes: 280.2


In [8]:
class SnakeGame(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 10}

    def __init__(self, size=10, render_mode=None):
        super().__init__()
        self.size = size
        self.cell_size = 30
        self.screen_size = self.size * self.cell_size
        self.render_mode = render_mode

        self.action_space = gym.spaces.Discrete(4)  # 0: right, 1: up, 2: left, 3: down
        self.observation_space = gym.spaces.Box(0, 2, shape=(self.size, self.size), dtype=np.uint8)

        self.screen = None
        self.clock = None

        self.snake = deque()
        self.food = None
        self.direction = [1, 0]

        if self.render_mode == "human":
            pygame.init()
            self.screen = pygame.display.set_mode((self.screen_size, self.screen_size))
            self.clock = pygame.time.Clock()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.snake.clear()
        mid = self.size // 2
        self.snake.appendleft([mid, mid])
        self.direction = [1, 0]
        self._place_food()
        obs = self._get_obs()

        if self.render_mode == "human":
            self._render_init()

        return obs, {}

    def step(self, action):
        # TODO: Change reward schema to avoid the following
        # 1) 180 degree turns
        # 2) Wall collisions
        # 3) Being slow at collecting food
        dist_before = abs(self.snake[0][0] - self.food[0]) + abs(self.snake[0][1] - self.food[1])


        if action == 0 and self.direction != [-1, 0]: self.direction = [1, 0]
        elif action == 1 and self.direction != [0, 1]: self.direction = [0, -1]
        elif action == 2 and self.direction != [1, 0]: self.direction = [-1, 0]
        elif action == 3 and self.direction != [0, -1]: self.direction = [0, 1]

        head = self.snake[0]
        new_head = [head[0] + self.direction[0], head[1] + self.direction[1]]

        done = False
        reward = 0

        if not (0 <= new_head[0] < self.size and 0 <= new_head[1] < self.size):
            done = True
            reward = -10
        else:
            body_to_check = list(self.snake)[:-1] if new_head != self.food else list(self.snake)
            if new_head in body_to_check:
                done = True
                reward = -10

        if not done:
            self.snake.appendleft(new_head)
            if new_head == self.food:
                reward = 10
                self._place_food()
            else:
                self.snake.pop()
                dist_after = abs(self.snake[0][0] - self.food[0]) + abs(self.snake[0][1] - self.food[1])
                if dist_after < dist_before:
                    reward = 1
                else:
                    reward = -1

        obs = self._get_obs()

        if self.render_mode == "human":
            self.render()

        return obs, reward, done, False, {}

    def _get_obs(self):
        # TODO: Return an observation state, take inspiration from the observation_space attribute
        obs = np.zeros((self.size, self.size), dtype=np.uint8)
        if self.food:
          obs[self.food[1], self.food[0]] = 2
        for part in self.snake:
          obs[part[1], part[0]] = 1
        return obs

    def _place_food(self):
        positions = set(tuple(p) for p in self.snake)
        empty = [(x, y) for x in range(self.size) for y in range(self.size) if (x, y) not in positions]
        self.food = list(random.choice(empty)) if empty else None

    def render(self):
        if self.screen is None:
            self._render_init()

        self.screen.fill((0, 0, 0))
        for x, y in self.snake:
            pygame.draw.rect(
                self.screen, (0, 255, 0),
                pygame.Rect(x * self.cell_size, y * self.cell_size, self.cell_size, self.cell_size)
            )
        if self.food:
            fx, fy = self.food
            pygame.draw.rect(
                self.screen, (255, 0, 0),
                pygame.Rect(fx * self.cell_size, fy * self.cell_size, self.cell_size, self.cell_size)
            )

        pygame.display.flip()
        self.clock.tick(self.metadata["render_fps"])

    def _render_init(self):
        pygame.init()
        self.screen = pygame.display.set_mode((self.size * self.cell_size, self.size * self.cell_size))
        self.clock = pygame.time.Clock()

    def close(self):
        if self.screen:
            pygame.quit()
            self.screen = None

In [9]:
# TODO: Implement training logic for Snake Game here
def train_snake():
    learning_rate = 1e-4
    batch_size = 64
    gamma = 0.99
    num_episodes = 1000
    buffer_capacity = 10000
    target_update_freq = 15

    epsilon_start = 1.0
    epsilon_end = 0.05
    epsilon_decay = 0.995

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = SnakeGame()

    state_space_size = env.observation_space.shape[0] * env.observation_space.shape[1]
    action_space_size = env.action_space.n

    policy_net = DQN(state_space_size, action_space_size).to(device)
    target_net = DQN(state_space_size, action_space_size).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = torch.optim.Adam(policy_net.parameters(), lr=learning_rate)
    replay_buffer = ExperienceBuffer(capacity=buffer_capacity)

    epsilon = epsilon_start
    all_scores = []


    for episode in range(num_episodes):
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).flatten().unsqueeze(0)

        done = False
        episode_score = 0

        while not done:
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                q_values = policy_net(state)
                action = q_values.argmax().item()

            next_state, reward, done, _, _ = env.step(action)
            episode_score += reward

            replay_buffer.push(state, action, reward, next_state, done)

            state = torch.tensor(next_state, dtype=torch.float32, device=device).flatten().unsqueeze(0)

            if len(replay_buffer) >= batch_size:
                optimize_model(policy_net, target_net, replay_buffer, optimizer, batch_size, gamma, device)

        all_scores.append(episode_score)

        if epsilon > epsilon_end:
            epsilon *= epsilon_decay

        if episode % target_update_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())

        if (episode + 1) % 50 == 0:
            avg_score = np.mean(all_scores[-50:])
            print(f"Episode {episode + 1}/{num_episodes} | Avg Score (Last 50): {avg_score:.2f} | Epsilon: {epsilon:.3f}")

    print("Training complete.")
    torch.save(policy_net.state_dict(), 'snake_model_weights.pth')
    print("Model weights saved to snake_model_weights.pth")
    env.close()

def optimize_model(policy_net, target_net, buffer, optimizer, batch_size, gamma, device):
    states, actions, rewards, next_states_np, dones = buffer.sample(batch_size)

    states = torch.cat(states)
    next_states_list = [torch.tensor(s, dtype=torch.float32, device=device).flatten().unsqueeze(0) for s in next_states_np]
    next_states = torch.cat(next_states_list)
    actions = torch.tensor(actions, dtype=torch.long, device=device).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
    dones = torch.tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)

    current_q_values = policy_net(states).gather(1, actions)

    with torch.no_grad():
        next_q_values = target_net(next_states).max(1)[0].unsqueeze(1)

    expected_q_values = rewards + (gamma * next_q_values * (1 - dones))

    loss = F.smooth_l1_loss(current_q_values, expected_q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [10]:
train_snake()

Episode 50/1000 | Avg Score (Last 50): -12.04 | Epsilon: 0.778
Episode 100/1000 | Avg Score (Last 50): -10.84 | Epsilon: 0.606
Episode 150/1000 | Avg Score (Last 50): -10.80 | Epsilon: 0.471
Episode 200/1000 | Avg Score (Last 50): -10.40 | Epsilon: 0.367
Episode 250/1000 | Avg Score (Last 50): -9.20 | Epsilon: 0.286
Episode 300/1000 | Avg Score (Last 50): -9.94 | Epsilon: 0.222
Episode 350/1000 | Avg Score (Last 50): -9.16 | Epsilon: 0.173
Episode 400/1000 | Avg Score (Last 50): -6.60 | Epsilon: 0.135
Episode 450/1000 | Avg Score (Last 50): -8.84 | Epsilon: 0.105
Episode 500/1000 | Avg Score (Last 50): -8.66 | Epsilon: 0.082
Episode 550/1000 | Avg Score (Last 50): -8.60 | Epsilon: 0.063
Episode 600/1000 | Avg Score (Last 50): -9.34 | Epsilon: 0.050
Episode 650/1000 | Avg Score (Last 50): -7.58 | Epsilon: 0.050
Episode 700/1000 | Avg Score (Last 50): -9.26 | Epsilon: 0.050
Episode 750/1000 | Avg Score (Last 50): -11.00 | Epsilon: 0.050
Episode 800/1000 | Avg Score (Last 50): -9.74 | Eps

In [11]:
def evaluate_snake_model(model, size=20, episodes=10, render=True):
    env = SnakeGame(size=size, render_mode="human" if render else None)
    model.eval()

    rewards = []

    for episode in range(episodes):
        obs, _ = env.reset()
        total_reward = 0
        done = False
        steps = 0
        max_steps = 10000

        while not done and steps < max_steps:
            state = torch.tensor(obs, dtype=torch.float32).flatten().unsqueeze(0)
            with torch.no_grad():
                q_values = model(state)
                action = torch.argmax(q_values, dim=1).item()

            obs, reward, done, _, _ = env.step(action)
            total_reward += reward

            if render:
                env.render()
            steps +=1

        rewards.append(total_reward)
        print(f"Episode {episode + 1}: Reward = {total_reward}")

    env.close()
    avg_reward = sum(rewards) / episodes

    print(f"Average reward over {episodes} episodes: {avg_reward}")

In [12]:
# TODO: Run evaluation for Snake Game here

GRID_SIZE = 10
MODEL_WEIGHTS_PATH = 'snake_model_weights.pth'

state_space = GRID_SIZE * GRID_SIZE
action_space = 4
model_to_evaluate = DQN(state_space, action_space)


model_to_evaluate.load_state_dict(torch.load(MODEL_WEIGHTS_PATH))
print("Model weights loaded successfully.")

evaluate_snake_model(
    model=model_to_evaluate,
    size=10,
    episodes=10,
    render=False
)


Model weights loaded successfully.
Episode 1: Reward = -16
Episode 2: Reward = -5
Episode 3: Reward = -18
Episode 4: Reward = -14
Episode 5: Reward = -10
Episode 6: Reward = -2
Episode 7: Reward = -4
Episode 8: Reward = 3
Episode 9: Reward = -5
Episode 10: Reward = -12
Average reward over 10 episodes: -8.3


In [13]:
class ChaseEscapeEnv(gym.Env):
    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=None):
        super().__init__()

        self.dt = 0.1
        self.max_speed = 0.4
        self.agent_radius = 0.05
        self.target_radius = 0.05
        self.chaser_radius = 0.07
        self.chaser_speed = 0.03

        self.action_space = gym.spaces.MultiDiscrete([3, 3])  # actions in {0,1,2} map to [-1,0,1]
        self.observation_space = gym.spaces.Box(
            low=-1,
            high=1,
            shape=(8,),
            dtype=np.float32,
        )

        self.render_mode = render_mode
        self.screen_size = 500
        self.np_random = None

        if render_mode == "human":
            pygame.init()
            self.screen = pygame.display.set_mode((self.screen_size, self.screen_size))
            self.clock = pygame.time.Clock()

    def sample_pos(self, far_from=None, min_dist=0.5):
        while True:
            pos = self.np_random.uniform(low=-0.8, high=0.8, size=(2,))
            if far_from is None or np.linalg.norm(pos - far_from) >= min_dist:
                return pos

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        self.agent_pos = self.sample_pos()
        self.agent_vel = np.zeros(2, dtype=np.float32)
        self.target_pos = self.sample_pos(far_from=self.agent_pos, min_dist=0.5)
        self.chaser_pos = self.sample_pos(far_from=self.agent_pos, min_dist=0.7)

        return self._get_obs(), {}

    def _get_obs(self):
        # TODO: Decide how to pass the state (don't use pixel values)
        agent_x, agent_y = self.agent_pos
        agent_vx, agent_vy = self.agent_vel
        target_x, target_y = self.target_pos
        chaser_x, chaser_y = self.chaser_pos
        return np.array([agent_x, agent_y, agent_vx, agent_vy, target_x, target_y, chaser_x, chaser_y], dtype=np.float32)

    def _get_info(self):
        return {}

    def step(self, action):
        # TODO: Add reward scheme
        # 1) Try to make the agent stay within bounds
        # 2) The agent shouldn't idle around
        # 3) The agent should go for the reward
        # 4) The agent should avoid the chaser
        reward = 0.0
        prev_dist_to_target = np.linalg.norm(self.agent_pos - self.target_pos)
        accel = (np.array(action) - 1) * 0.1
        self.agent_vel += accel
        self.agent_vel = np.clip(self.agent_vel, -self.max_speed, self.max_speed)

        new_pos = self.agent_pos.copy()
        new_pos += self.agent_vel * self.dt

        if np.any(new_pos < -1) or np.any(new_pos > 1):
            reward -= 0.1


        self.agent_pos = np.clip(new_pos, -1, 1)

        if np.linalg.norm(self.agent_vel) < 0.01:
            reward -= 0.05

        direction = self.agent_pos - self.chaser_pos
        norm = np.linalg.norm(direction)
        if norm > 1e-5:
            self.chaser_pos += self.chaser_speed * direction / norm

        dist_to_target = np.linalg.norm(self.agent_pos - self.target_pos)
        dist_to_chaser = np.linalg.norm(self.agent_pos - self.chaser_pos)

        if dist_to_chaser < 0.5:
            if norm > 1e-5:
                projected_chaser_pos = self.chaser_pos - self.chaser_speed * direction / norm
                prev_dist_to_chaser = np.linalg.norm(self.agent_pos - projected_chaser_pos)
            else:
                prev_dist_to_chaser = dist_to_chaser

            delta_chaser = prev_dist_to_chaser - dist_to_chaser
            reward += delta_chaser * 0.3

        delta = prev_dist_to_target - dist_to_target
        reward += delta * 0.5


        terminated = False

        info = {}
        if dist_to_target < self.agent_radius + self.target_radius:
            reward += 10.0
            self.target_pos = self.sample_pos(far_from=self.agent_pos, min_dist=0.5)
            info["target_captured"] = True

        if dist_to_target < 0.3:
            reward += 0.1

        if dist_to_chaser < self.agent_radius + self.chaser_radius:
            reward -= 3.0
            terminated = True
            info["caught_by_chaser"] = True

        return self._get_obs(), reward, terminated, False, info

    def render(self):
        if self.render_mode != "human":
            return

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.close()

        self.screen.fill((255, 255, 255))

        def to_screen(p):
            x = int((p[0] + 1) / 2 * self.screen_size)
            y = int((1 - (p[1] + 1) / 2) * self.screen_size)
            return x, y

        pygame.draw.circle(self.screen, (0, 255, 0), to_screen(self.target_pos), int(self.target_radius * self.screen_size))
        pygame.draw.circle(self.screen, (0, 0, 255), to_screen(self.agent_pos), int(self.agent_radius * self.screen_size))
        pygame.draw.circle(self.screen, (255, 0, 0), to_screen(self.chaser_pos), int(self.chaser_radius * self.screen_size))

        pygame.display.flip()
        self.clock.tick(self.metadata["render_fps"])

    def close(self):
        if self.render_mode == "human":
            pygame.quit()


In [14]:
# TODO: Train and evaluate CatMouseEnv
def select_action(state, model, epsilon, action_space):
    if random.random() < epsilon:
        return [random.randint(0, 2), random.randint(0, 2)]
    with torch.no_grad():
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        q_values = model(state_tensor)
        action = q_values.argmax(dim=1).item()
        return [action // 3, action % 3]  # Convert flat index to (x,y) action

env = ChaseEscapeEnv()
dqn = DQN(state_space=8, action_space=9)  # 3x3 = 9 actions
target_dqn = DQN(8, 9)
target_dqn.load_state_dict(dqn.state_dict())
optimizer = torch.optim.Adam(dqn.parameters(), lr=1e-3)
buffer = ExperienceBuffer(capacity=10000)

gamma = 0.99
batch_size = 64
epsilon = 1.0
min_epsilon = 0.05
epsilon_decay = 0.995
max_steps = 500

for episode in range(500):
    state, _ = env.reset()
    total_reward = 0
    steps = 0
    done = False
    steps = 0
    target_captures = 0
    caught_by_chaser = False

    while not done and steps < max_steps:
        steps += 1
        action = select_action(state, dqn, epsilon, env.action_space)
        flat_action = action[0] * 3 + action[1]
        next_state, reward, terminated, truncated, info = env.step(action)
        buffer.push(state, flat_action, reward, next_state, terminated)
        state = next_state
        total_reward += reward
        done = terminated

        # Track extra info
        if info.get("target_captured"):
            target_captures += 1
        if info.get("caught_by_chaser"):
            caught_by_chaser = True

        # Train if enough samples
        if len(buffer) > batch_size:
            states, actions, rewards, next_states, dones = buffer.sample(batch_size)
            states = torch.FloatTensor(states)
            actions = torch.LongTensor(actions)
            rewards = torch.FloatTensor(rewards)
            next_states = torch.FloatTensor(next_states)
            dones = torch.FloatTensor(dones)

            q_values = dqn(states)
            next_q_values = target_dqn(next_states)

            q_target = rewards + gamma * next_q_values.max(1)[0] * (1 - dones)
            q_expected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            loss = F.mse_loss(q_expected, q_target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if epsilon > min_epsilon:
        epsilon *= epsilon_decay

    if episode % 10 == 0:
        target_dqn.load_state_dict(dqn.state_dict())

    # ✅ Clean, compact log per episode
    print(
        f"Ep {episode:03d} | "
        f"Reward: {total_reward:.2f} | "
        f"Steps: {steps} | "
        f"Captures: {target_captures} | "
        f"Caught: {caught_by_chaser}"
    )

# Evaluation phase
state, _ = env.reset()
done = False
total_reward = 0
while not done:
    env.render()
    action = select_action(state, dqn, epsilon=0.0, action_space=env.action_space)
    state, reward, terminated, _, _ = env.step(action)
    total_reward += reward
    done = terminated
    pygame.time.delay(50)

print(f"Total reward in evaluation: {total_reward:.2f}")
env.close()


Ep 000 | Reward: -3.35 | Steps: 32 | Captures: 0 | Caught: True
Ep 001 | Reward: -5.37 | Steps: 55 | Captures: 0 | Caught: True


  states = torch.FloatTensor(states)


Ep 002 | Reward: -3.07 | Steps: 41 | Captures: 0 | Caught: True
Ep 003 | Reward: -6.71 | Steps: 62 | Captures: 0 | Caught: True
Ep 004 | Reward: -2.67 | Steps: 28 | Captures: 0 | Caught: True
Ep 005 | Reward: -1.65 | Steps: 34 | Captures: 0 | Caught: True
Ep 006 | Reward: -7.27 | Steps: 74 | Captures: 0 | Caught: True
Ep 007 | Reward: -1.62 | Steps: 38 | Captures: 0 | Caught: True
Ep 008 | Reward: -3.00 | Steps: 16 | Captures: 0 | Caught: True
Ep 009 | Reward: -3.00 | Steps: 41 | Captures: 0 | Caught: True
Ep 010 | Reward: -1.07 | Steps: 62 | Captures: 0 | Caught: True
Ep 011 | Reward: -2.76 | Steps: 10 | Captures: 0 | Caught: True
Ep 012 | Reward: -2.85 | Steps: 18 | Captures: 0 | Caught: True
Ep 013 | Reward: -5.44 | Steps: 40 | Captures: 0 | Caught: True
Ep 014 | Reward: -3.40 | Steps: 45 | Captures: 0 | Caught: True
Ep 015 | Reward: -2.13 | Steps: 14 | Captures: 0 | Caught: True
Ep 016 | Reward: -2.85 | Steps: 25 | Captures: 0 | Caught: True
Ep 017 | Reward: -3.09 | Steps: 44 | Cap