In [6]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Set up the Pong environment and report its action and observation spaces.
env = gym.make("ALE/Pong-v5")
num_actions = env.action_space.n
obs_shape = env.observation_space.shape
# Observations are image frames, so shape[0] is the frame height, not a number of
# observations. Kept under its original name for any later cell that reads it.
num_observations = obs_shape[0]
print(f"There are {num_actions} possible actions: {env.unwrapped.get_action_meanings()} \nand observations of shape {obs_shape}")

There are 6 possible actions: ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE'] 
and 210 observations


In [8]:
# Preprocess a raw RGB frame: grayscale -> crop -> resize -> scale to [0, 1].
def preProcess(self, image):
    """
    Convert a raw environment frame into a normalized single-channel image.

    This function was lifted out of a class, hence the ``self`` parameter.

    Parameters
    ----------
    self : object
        Any object exposing ``crop_dim`` (``(top, bottom, left, right)`` slice
        bounds applied after grayscale conversion), ``target_w`` and ``target_h``
        (output width and height in pixels).
    image : array-like
        Frame of shape ``(H, W, 3)`` in RGB channel order, or an already
        grayscale ``(H, W)`` frame.

    Returns
    -------
    np.ndarray
        ``float32`` array of shape ``(target_h, target_w)``; values lie in
        ``[0, 1]`` for 8-bit input.
    """
    frame = np.asarray(image, dtype=np.float32)
    if frame.ndim == 3:
        # BT.601 luma weights for R, G, B -- the same ones cv2.COLOR_RGB2GRAY uses.
        frame = frame[..., :3] @ np.array([0.299, 0.587, 0.114], dtype=np.float32)

    top, bottom, left, right = self.crop_dim
    frame = frame[top:bottom, left:right]

    # Bilinear resize. interpolate() takes (N, C, H, W) and size=(height, width),
    # so the result is already (target_h, target_w) -- no reshape needed.
    resized = torch.nn.functional.interpolate(
        torch.from_numpy(np.ascontiguousarray(frame))[None, None],
        size=(self.target_h, self.target_w),
        mode="bilinear",
        align_corners=False,
    )
    return resized[0, 0].numpy() / 255.0

In [7]:
class DQN(nn.Module):
    """
    Convolutional Q-network mapping a single-channel frame to one Q-value per action.

    Layer sizes follow the Nature DQN (Mnih et al., 2015), with BatchNorm after each
    convolution. ``fc1`` is sized for 84x84 inputs: the conv stack reduces the
    spatial size 84 -> 20 -> 9 -> 7, giving 7*7*64 features.

    Args:
        num_actions: number of discrete actions (width of the output layer).
    """

    def __init__(self, num_actions):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(7*7*64, 512)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(512, num_actions)

    def forward(self, x):
        """
        Compute Q-values for one frame or a batch of frames.

        Accepted input shapes:
            (H, W)        -- a single frame
            (B, H, W)     -- a batch of frames
            (B, 1, H, W)  -- a batch with an explicit channel axis

        Returns:
            Tensor of shape (B, num_actions); B == 1 for a single frame.

        Raises:
            ValueError: if ``x`` is not 2D, 3D or 4D.
        """
        # Tensor.to() is not in-place, so reassign. Following the parameters'
        # device/dtype avoids relying on a notebook-global `device`.
        param = next(self.parameters())
        x = x.to(device=param.device, dtype=param.dtype)

        # Normalize to (batch, channel=1, H, W): conv1 takes exactly one channel and
        # BatchNorm2d requires a 4D input.
        if x.dim() == 2:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 3:
            x = x.unsqueeze(1)
        elif x.dim() != 4:
            raise ValueError(f"DQN expects a 2D, 3D or 4D input, got shape {tuple(x.shape)}")

        # block 1
        x = self.relu1(self.bn1(self.conv1(x)))
        # block 2
        x = self.relu2(self.bn2(self.conv2(x)))
        # block 3
        x = self.relu3(self.bn3(self.conv3(x)))

        # Flatten everything except the batch axis for the fully connected head.
        x = torch.flatten(x, 1)
        x = self.relu4(self.fc1(x))
        return self.fc2(x)
    
class ReplayBuffer:
    """
    Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Once ``capacity`` transitions are held, adding a new one drops the oldest
    (``deque(maxlen=...)`` semantics).
    """

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored, so callers can check ``len(buffer)`` before sampling."""
        return len(self.buffer)

    def add(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """
        Return ``batch_size`` distinct transitions drawn uniformly at random.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored
                (propagated from ``random.sample``).
        """
        return random.sample(self.buffer, batch_size)

In [None]:
class Agent:
    """
    Epsilon-greedy DQN agent with an online network, a target network and a replay buffer.

    Args:
        env: environment providing ``action_space.n`` and ``action_space.sample()``.
        device: torch device holding both networks and every training batch.
    """

    def __init__(self, env, device):
        self.env = env
        self.num_actions = self.env.action_space.n
        self.device = device

        # Online network (optimized) and target network (not optimized; used for
        # bootstrap targets and re-synced by update_target_model()).
        self.model = DQN(self.num_actions).to(device)
        self.target_model = DQN(self.num_actions).to(device)
        self.target_model.load_state_dict(self.model.state_dict())

        # Replay buffer, optimizer and hyper-parameters.
        self.buffer = ReplayBuffer()
        self.alpha = 0.0001          # Adam learning rate
        self.gamma = 0.99            # discount factor for future rewards
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.alpha)
        self.loss_fn = nn.MSELoss()
        self.epsilon = 1             # exploration rate; starts fully random
        self.epsilon_decay = 0.99    # multiplicative decay per decay step
        self.epsilon_minimum = 0.05  # floor for epsilon

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def select_action(self, state):
        """
        Pick an action epsilon-greedily.

        With probability ``epsilon``, return a random action sampled from the
        environment. Otherwise, return the index of the highest predicted Q-value
        as a Python ``int``.
        """
        if random.random() < self.epsilon:
            return self.env.action_space.sample()  # random action

        # NOTE(review): the model is left in train mode here, so BatchNorm
        # normalizes with this single frame's statistics and updates its running
        # stats; consider self.model.eval() for action selection.
        with torch.no_grad():
            state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
            q_values = self.model(state_t)
        return int(torch.argmax(q_values).item())  # greedy action

    def train(self, batch_size):
        """
        Run one DQN update on a random minibatch from the replay buffer.

        Returns:
            ``(loss, max_q)``: the pre-update loss and the largest predicted
            Q-value in the batch, or ``(0, 0)`` when the buffer holds fewer than
            ``batch_size`` transitions.
        """
        # Read the underlying deque so this works whether or not ReplayBuffer
        # defines __len__.
        if len(self.buffer.buffer) < batch_size:
            return 0, 0

        samples = self.buffer.sample(batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)

        # Stack into arrays first: building a tensor from a tuple of ndarrays is slow.
        batch_state = torch.as_tensor(np.asarray(states, dtype=np.float32), device=self.device)
        batch_action = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=self.device)
        batch_reward = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=self.device)
        batch_next_state = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=self.device)
        batch_done = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=self.device)

        # Q(s, a) for the actions actually taken -- shape (batch,).
        q_values = self.model(batch_state)
        q_taken = q_values.gather(1, batch_action.unsqueeze(1)).squeeze(1)

        # Bootstrap target r + gamma * max_a' Q_target(s', a'), zeroed past terminal states.
        with torch.no_grad():
            next_q_values = self.target_model(batch_next_state).max(1).values
            expected_q_values = batch_reward + (1.0 - batch_done) * self.gamma * next_q_values

        loss = self.loss_fn(q_taken, expected_q_values)

        # Backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # NOTE(review): epsilon decays on every training step here, and
        # update_epsilon() applies the same decay; calling both decays twice per
        # step. Confirm which one the training loop relies on.
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_minimum)

        return loss.item(), q_values.max().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer."""
        self.buffer.add(state, action, reward, next_state, done)

    def update_epsilon(self):
        """Decay epsilon multiplicatively, never letting it drop below ``epsilon_minimum``."""
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_minimum)
        
        