In [1]:
# Step 1: Environment Setup with Gymnasium
import ale_py
import gymnasium as gym
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import random
import pygame
from typing import Optional

In [2]:
class AtariPreprocessor:
    """Handles frame preprocessing for Atari environments"""
    def __init__(self, frame_size=84):
        self.frame_size = frame_size
        self.frame_buffer = deque(maxlen=4)

    def preprocess(self, frame):
        """Convert RGB frame to grayscale and resize"""
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(gray, (self.frame_size, self.frame_size))
        return resized.astype(np.uint8)

    def stack_frames(self, frame, reset=False):
        """Stack 4 consecutive frames"""
        if reset:
            self.frame_buffer.clear()
            for _ in range(4):
                self.frame_buffer.append(frame)
        else:
            self.frame_buffer.append(frame)
        return np.stack(self.frame_buffer, axis=0)

In [3]:
# Step 2: Custom Gymnasium Wrappers
class EpisodicLifeWrapper(gym.Wrapper):
    """End episode only when all lives are exhausted"""
    def __init__(self, env):
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_done = terminated or truncated

        current_lives = self.env.unwrapped.ale.lives()
        if 0 < current_lives < self.lives:
            terminated = False
        self.lives = current_lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
            self.lives = self.env.unwrapped.ale.lives()
        else:
            obs, _, _, _, info = self.env.step(0)
        return obs, info


In [4]:
# Step 3: Sum Tree Implementation for Prioritized Replay
class SumTree:
    """Binary heap structure for efficient priority sampling"""
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.empty(capacity, dtype=object)
        self.write_ptr = 0
        self.size = 0

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(left + 1, s - self.tree[left])

    def total(self):
        return self.tree[0]

    def add(self, priority, data):
        idx = self.write_ptr + self.capacity - 1
        self.data[self.write_ptr] = data
        self.update(idx, priority)
        self.write_ptr = (self.write_ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, change)

    def get(self, s):
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[data_idx])

    def __len__(self):
        """Return current number of stored elements"""
        return self.size


In [5]:
# Step 4: Prioritized Replay Buffer
class PrioritizedReplayBuffer:
    """Experience buffer with importance sampling"""
    def __init__(self, capacity, alpha=0.6, beta=0.4, initial_size=10000):  # More reasonable initial size
        self.alpha = alpha
        self.beta = beta
        self.initial_size = initial_size
        self.tree = SumTree(capacity)
        self.max_priority = 1.0
        self.epsilon = 1e-6

    def add(self, transition):
        """Add new experience with max priority"""
        priority = self.max_priority ** self.alpha
        self.tree.add(priority, transition)

    def sample(self, batch_size):
        """Sample batch with importance sampling weights"""
        batch = []
        indices = []
        priorities = []
        segment = self.tree.total() / batch_size

        self.beta = min(1.0, self.beta + 0.0001)

        for i in range(batch_size):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            idx, priority, data = self.tree.get(s)
            priorities.append(priority)
            batch.append(data)
            indices.append(idx)

        probs = np.array(priorities) / self.tree.total()
        weights = (len(self.tree) * probs) ** -self.beta
        weights /= weights.max()

        return batch, indices, np.array(weights, dtype=np.float32)

    def update_priorities(self, indices, priorities):
        """Update priorities after training"""
        priorities = np.abs(priorities) + self.epsilon
        priorities = priorities ** self.alpha
        for idx, priority in zip(indices, priorities):
            self.tree.update(idx, priority)
            self.max_priority = max(self.max_priority, priority)


In [6]:
# Step 5: Deep Q-Network Architecture
class DQN(nn.Module):
    """CNN architecture for Atari games"""
    def __init__(self, input_shape, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)


config = {
    "gamma": 0.99,
    "learning_rate": 2.5e-4,  # More stable LR for Atari DQN
    "batch_size": 32,
    "epsilon_decay": 0.9997,
    "epsilon_min": 0.05,
    "target_update": 1000,
    "train_step": 0,
    "epsilon": 1.0,
    "buffer_capacity": 100000,
    "episodes": 2000,
    "learning_starts": 10000,
    "train_frequency": 4
}


In [7]:
# Step 6: DQN Agent Implementation
class DQNAgent:
    """Deep Q-Learning Agent with Prioritized Experience Replay"""
    def __init__(self, state_shape, action_size, config):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.action_size = action_size
        self.gamma = config['gamma']
        self.lr = config['learning_rate']
        self.batch_size = config['batch_size']
        self.buffer_capacity = config['buffer_capacity']

        self.q_net = DQN(state_shape, action_size).to(self.device)
        self.target_net = DQN(state_shape, action_size).to(self.device)
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.lr)
        self.buffer = PrioritizedReplayBuffer(self.buffer_capacity)
        self.update_target()

    def update_target(self):
        self.target_net.load_state_dict(self.q_net.state_dict())

    def act(self, state, epsilon=0.0):
        """Epsilon-greedy action selection"""
        if random.random() > epsilon:
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device) / 255.0
            with torch.no_grad():
                return self.q_net(state).argmax().item()
        return random.randint(0, self.action_size-1)

    def compute_loss(self, batch, weights):
        """Prioritized Double DQN loss with importance sampling"""
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.from_numpy(np.asarray(states, dtype=np.uint8)).float().to(self.device) / 255.0
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.from_numpy(np.asarray(next_states, dtype=np.uint8)).float().to(self.device) / 255.0
        dones = torch.BoolTensor(dones).to(self.device)
        weights = torch.FloatTensor(weights).to(self.device)

        current_q = self.q_net(states).gather(1, actions.unsqueeze(-1))

        with torch.no_grad():
            next_actions = self.q_net(next_states).argmax(1)
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(-1))
            target_q = rewards.unsqueeze(-1) + (1 - dones.float().unsqueeze(-1)) * self.gamma * next_q

        td_error = torch.abs(current_q - target_q).squeeze()
        loss = (weights * F.smooth_l1_loss(current_q.squeeze(), target_q.squeeze(), reduction='none')).mean()

        return loss, td_error.detach().cpu().numpy()

    def update(self, batch_size):
        if len(self.buffer.tree) < self.buffer.initial_size:
            return None

        batch, indices, weights = self.buffer.sample(batch_size)
        loss, td_errors = self.compute_loss(batch, weights)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0)
        self.optimizer.step()

        self.buffer.update_priorities(indices, td_errors)
        return loss.item()


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        terminated = False
        truncated = False

        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if terminated or truncated:
                break

        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self._obs_buffer.clear()
        self._obs_buffer.append(obs)
        return obs, info


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            obs, info = self.env.reset(**kwargs)
        return obs, info


In [8]:
# Step 7: Training Loop
class AtariTrainer:
    """Complete training pipeline for Atari games"""
    def __init__(self, config):
        env_name = "ALE/Breakout-v5"
        self.env = gym.make(env_name, render_mode="rgb_array")
        self.preprocessor = AtariPreprocessor()

        self.env = MaxAndSkipEnv(self.env, skip=4)
        self.env = FireResetEnv(self.env)
        self.env = EpisodicLifeWrapper(self.env)
        self.config = config
        self.best_mean_score = -float('inf')

        self.state_shape = (4, 84, 84)
        self.action_size = self.env.action_space.n
        self.agent = DQNAgent(self.state_shape, self.action_size, config)

        self.episodes = config['episodes']
        self.epsilon = config['epsilon']
        self.epsilon_decay = config['epsilon_decay']
        self.epsilon_min = config['epsilon_min']
        self.batch_size = config['batch_size']
        self.target_update = config['target_update']
        self.train_step = 0
        self.learning_starts = config['learning_starts']
        self.train_frequency = config['train_frequency']

    def _process_obs(self, obs, reset=False):
        """Apply preprocessing pipeline"""
        processed = self.preprocessor.preprocess(obs)
        return self.preprocessor.stack_frames(processed, reset)

    def render_env(self, obs, screen):
        frame_rgb = cv2.cvtColor(obs, cv2.COLOR_BGR2RGB)
        frame_resized = cv2.resize(frame_rgb, (672, 672))
        surface = pygame.surfarray.make_surface(np.transpose(frame_resized, (1, 0, 2)))
        screen.fill((0, 0, 0))
        screen.blit(surface, (0, 0))
        pygame.display.flip()
        pygame.time.delay(1000 // 15)


    def train(self):
        episode_rewards = []

        for episode in range(self.config['episodes']):
            obs, _ = self.env.reset()
            state = self._process_obs(obs, reset=True)
            total_reward = 0
            done = False

            if (episode + 1) % 100 == 0:
                pygame.init()
                screen = pygame.display.set_mode((672, 672))
                pygame.display.set_caption(f"Breakout Episode {episode+1}")
                render = True
            else:
                render = False

            while not done:
                action = self.agent.act(state, self.epsilon)
                next_obs, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                next_state = self._process_obs(next_obs)

                self.agent.buffer.add((state, action, reward, next_state, done))
                state = next_state
                total_reward += reward
                self.train_step += 1

                if self.train_step > self.config["learning_starts"] and self.train_step % self.config["train_frequency"] == 0:
                    self.agent.update(self.config["batch_size"])

                if self.train_step % self.config["target_update"] == 0:
                    self.agent.update_target()

                if self.train_step > self.config["learning_starts"]:
                    self.epsilon = max(self.epsilon * self.config["epsilon_decay"], self.config["epsilon_min"])


                if render:
                    self.render_env(next_obs, screen)

            episode_rewards.append(total_reward)
            mean_score = np.mean(episode_rewards[-50:])  #mean over last 50

            if (episode+1)%200 == 0:
                print(f"Episode {episode+1}, Reward: {total_reward}, Mean(50): {mean_score:.2f}, Epsilon: {self.epsilon:.3f}")

                if mean_score > self.best_mean_score:
                  self.best_mean_score = mean_score
                  torch.save(self.agent.q_net.state_dict(), "model.pth")
                  print(f"[Checkpoint] ✅ New best mean score: {mean_score:.2f} at episode {episode+1}. Model saved.")

            if render:
                pygame.quit()

        self.env.close()




In [None]:
if __name__ == "__main__":
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    trainer = AtariTrainer(config)
    trainer.train()

Episode 200, Reward: 1.0, Mean(50): 2.00, Epsilon: 0.579
[Checkpoint] ✅ New best mean score: 2.00 at episode 200. Model saved.
Episode 400, Reward: 12.0, Mean(50): 8.46, Epsilon: 0.050
[Checkpoint] ✅ New best mean score: 8.46 at episode 400. Model saved.
Episode 600, Reward: 10.0, Mean(50): 10.80, Epsilon: 0.050
[Checkpoint] ✅ New best mean score: 10.80 at episode 600. Model saved.
Episode 800, Reward: 13.0, Mean(50): 11.78, Epsilon: 0.050
[Checkpoint] ✅ New best mean score: 11.78 at episode 800. Model saved.
