In [1]:
import gymnasium as gym
import ale_py
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import torch.nn.functional as F

In [2]:
# Use the CUDA device when PyTorch reports one is available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
device

device(type='cuda')

In [None]:
# --- Environment -----------------------------------------------------------
ENV = "PongDeterministic-v4"  # id passed to gym.make

# --- Optimisation ----------------------------------------------------------
BATCH_SIZE = 64  # transitions sampled per gradient step
LEARNING_RATE = 2.5e-4  # Adam step size
GAMMA = 0.97  # discount applied to the bootstrapped next-state value
TARGET_UPDATE = 1000  # step interval between target-network syncs

# --- Exploration (epsilon-greedy) ------------------------------------------
EPSILON_START = 1.0  # initial probability of taking a random action
EPSILON_END = 0.05  # decay stops once epsilon is no longer above this value
EPSILON_DECAY = 0.99  # multiplicative factor applied to epsilon in the training loop

# --- Replay memory ---------------------------------------------------------
MAX_MEMORY_SIZE = 25_000  # maxlen of the replay deque
MIN_MEMORY_SIZE = 24_000  # optimisation is skipped until this many transitions are stored

In [None]:
class DualDQN(nn.Module):
    """Dueling Q-network: a shared convolutional trunk feeding separate
    state-value and advantage heads.

    The two heads are combined as ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``,
    where the mean is taken per sample over the action dimension.

    Args:
        action_space: number of discrete actions (width of the advantage head).

    Input to ``forward`` is a ``(batch, 4, H, W)`` float tensor. The linear
    heads expect 1536 flattened features, i.e. 64 channels x 4 x 6 spatial
    positions, which is what the three convolutions below produce for a
    64x80 (or 80x64) input.
    """

    def __init__(self, action_space):
        super().__init__()
        # Convolutional trunk; each convolution is followed by batch norm.
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Advantage head: one output per action.
        self.advantage_stream = nn.Linear(1536, 128)
        self.advantage_stream2 = nn.Linear(128, action_space)
        # Value head: a single scalar per state.
        self.value_stream = nn.Linear(1536, 128)
        self.value_stream2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_space)`` for a batch of states."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.flatten(start_dim=1)

        value = self.value_stream2(F.leaky_relu(self.value_stream(x)))
        advantage = self.advantage_stream2(F.leaky_relu(self.advantage_stream(x)))

        # Centre the advantages per sample (over actions only). A global
        # mean over the whole tensor would make each sample's Q-values depend
        # on whatever other samples happen to share its minibatch.
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)

        return q_values

In [None]:
# Register the ALE environments with Gymnasium, then build the environment.
gym.register_envs(ale_py)
env = gym.make(ENV)

# Two networks with the same architecture: the online (policy) network that
# is optimised, and a target network that starts as an exact copy of it and
# is kept in eval mode.
policy_net, target_net = (
    DualDQN(env.action_space.n).to(device) for _ in range(2)
)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only the policy network's parameters are handed to the optimiser.
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)

# Bounded replay buffer: once full, the oldest transitions are dropped.
memory = deque(maxlen=MAX_MEMORY_SIZE)

In [None]:
def preprocess_observation(image):
    """Turn a raw colour frame into a small normalised grayscale tensor.

    Steps: convert to grayscale, drop the top 20 rows, resize to 64 wide by
    80 tall, and divide by 255.

    Args:
        image: colour frame as an ``(H, W, 3)`` array. Assumed to be uint8 in
            RGB channel order, as returned by the environment -- verify
            against the env if it is ever swapped out.

    Returns:
        ``torch.float32`` tensor of shape ``(80, 64)`` (height, width) on
        ``device``.
    """
    target_h = 80
    target_w = 64
    top_crop = 20  # rows removed from the top of the frame

    # The frame is treated as RGB, so use the RGB conversion; the BGR variant
    # would swap the red and blue luminance weights.
    frame = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    frame = frame[top_crop:, :]

    # cv2.resize takes dsize as (width, height) and returns an array shaped
    # (height, width) == (target_h, target_w). It is deliberately NOT reshaped
    # to (target_w, target_h): reshape reinterprets the row-major buffer and
    # would scramble the rows instead of transposing the picture.
    frame = cv2.resize(frame, (target_w, target_h))

    return torch.tensor(frame / 255, dtype=torch.float32).to(device=device)


In [None]:
def optimize_model():
    """Run one Double-DQN gradient step on a random minibatch from ``memory``.

    The greedy next action is chosen by ``policy_net`` and evaluated by
    ``target_net``; the TD target is ``r + GAMMA * Q_target(s', a*)`` with the
    bootstrap term zeroed for transitions flagged as done.

    Returns:
        ``None`` while ``memory`` holds fewer than ``MIN_MEMORY_SIZE``
        transitions (no update is performed); otherwise the smooth-L1 loss of
        the step as a Python float.
    """
    if len(memory) < MIN_MEMORY_SIZE:
        return None

    transitions = random.sample(memory, BATCH_SIZE)
    states, actions, rewards, next_states, dones = zip(*transitions)

    # Each stored state is a stack of frames; stacking adds the batch dim.
    batch_state = torch.stack(states).to(device)
    batch_next_state = torch.stack(next_states).to(device)
    # Explicit dtypes: gather needs int64 indices, and the reward must match
    # the network's float32 output regardless of how the env typed it.
    batch_action = torch.tensor(actions, dtype=torch.int64, device=device)
    batch_reward = torch.tensor(rewards, dtype=torch.float32, device=device)
    batch_done = torch.tensor(dones, dtype=torch.bool, device=device)

    current_q_values = (
        policy_net(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    )

    # The target is a constant with respect to the loss, so build it without
    # recording an autograd graph.
    with torch.no_grad():
        best_next_actions = policy_net(batch_next_state).argmax(dim=1, keepdim=True)
        next_target_q_values = (
            target_net(batch_next_state).gather(1, best_next_actions).squeeze(1)
        )
        expected_q_values = batch_reward + GAMMA * next_target_q_values * (
            ~batch_done
        )

    loss = F.smooth_l1_loss(current_q_values, expected_q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

In [None]:
def select_action(state, epsilon, action_space):
    """Pick an action epsilon-greedily.

    Args:
        state: a single (unbatched) state tensor; a batch dimension is added
            before it is passed to ``policy_net``.
        epsilon: probability of returning a uniformly random action.
        action_space: number of discrete actions.

    Returns:
        The chosen action index as a Python int.
    """
    if random.random() < epsilon:
        return random.randrange(action_space)

    # Greedy branch. The network contains BatchNorm layers, so run it in eval
    # mode: in training mode a single-frame forward would normalise with that
    # one sample's statistics and also fold them into the running averages.
    # The previous mode is restored afterwards so training is unaffected.
    was_training = policy_net.training
    policy_net.eval()
    try:
        with torch.no_grad():
            q_values = policy_net(state.unsqueeze(0))
        return int(q_values.argmax(dim=1).item())
    finally:
        policy_net.train(was_training)

In [None]:
def run_train(num_episodes=1000):
    """Train ``policy_net`` on ``env`` with epsilon-greedy exploration.

    Every environment step stores a transition in ``memory`` and calls
    ``optimize_model``. The target network is synced and epsilon is decayed on
    fixed global-step intervals. Per-episode loss, length, epsilon and reward
    are written to a TensorBoard run directory under ``runs/``.

    Args:
        num_episodes: number of episodes to play (defaults to 1000).
    """
    epsilon = EPSILON_START
    epsilon_decay_interval = 1000  # global steps between epsilon decays
    writer = SummaryWriter(
        f'runs/pong_dqn_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
    )
    global_step = 0

    # try/finally so the event file is flushed and closed even if training is
    # interrupted or raises.
    try:
        for episode in tqdm(range(num_episodes), desc="Training Episodes"):
            obs, _ = env.reset()
            first_frame = preprocess_observation(obs)
            # No history yet: fill all four slots with the first frame.
            state = torch.stack([first_frame] * 4, dim=0)

            total_reward = 0
            total_loss = 0
            steps = 0
            episode_over = False

            while not episode_over:
                action = select_action(state, epsilon, env.action_space.n)
                next_obs, reward, terminated, truncated, _ = env.step(action)
                episode_over = terminated or truncated
                total_reward += reward

                # Newest frame goes first; the oldest frame is dropped.
                next_frame = preprocess_observation(next_obs)
                next_state = torch.stack(
                    (next_frame, state[0], state[1], state[2])
                )

                # Store only true termination as the done flag: a truncated
                # episode did not reach a terminal state, so its next-state
                # value must still be bootstrapped.
                memory.append(
                    (state, action, reward, next_state, bool(terminated))
                )
                state = next_state
                steps += 1
                global_step += 1

                loss = optimize_model()
                if loss is not None:
                    total_loss += loss

                # Sync on the global counter. The per-episode counter resets
                # every episode, so episodes shorter than TARGET_UPDATE steps
                # would never trigger a sync.
                if global_step % TARGET_UPDATE == 0:
                    target_net.load_state_dict(policy_net.state_dict())

                if global_step % epsilon_decay_interval == 0:
                    # Clamp so epsilon never drops below its floor.
                    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)

            writer.add_scalar("Total Loss", total_loss, episode)
            writer.add_scalar("Steps", steps, episode)
            writer.add_scalar("Epsilon", epsilon, episode)
            writer.add_scalar("Total Reward", total_reward, episode)
    finally:
        writer.close()

In [None]:
# Start training: runs every episode of the loop defined in run_train and
# writes per-episode metrics to a TensorBoard run directory under runs/.
run_train()

In [20]:
# env.reset()
# obs = env.step(0)
# obs = preprocess_observation(obs[0])