In [76]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from collections import deque

import gymnasium as gym
# Create the Breakout environment; later cells read its action/observation
# spaces and call reset() on it.
# NOTE(review): newer Gymnasium releases may require the ALE environments to be
# registered explicitly (via ale_py) before this call — confirm for the
# installed version.
env = gym.make("ALE/Breakout-v5")

In [77]:
class ReplayMemory():
    """Fixed-capacity FIFO buffer of transitions for experience replay.

    Each transition is stored as a ``(state, action, next_state, reward)``
    tuple. Once ``max_samples`` transitions are held, pushing another one
    silently discards the oldest (behaviour of ``deque`` with ``maxlen``).
    """

    def __init__(self, max_samples):
        """Create an empty buffer that keeps at most ``max_samples`` transitions."""
        self.memory = deque([], maxlen=max_samples)

    def push(self, state, action, next_state, reward):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.memory.append((state, action, next_state, reward))

    def sample(self, sample_size):
        """Return ``sample_size`` distinct stored transitions, drawn uniformly.

        Returns:
            list: ``(state, action, next_state, reward)`` tuples, exactly as
            they were pushed.

        Raises:
            ValueError: if ``sample_size`` is larger than the number of
                transitions currently stored.
        """
        # np.random.choice only accepts a 1-D array-like. Passing the deque of
        # tuples makes NumPy build a multi-dimensional (or ragged) array and
        # raise, so draw positions instead and look the transitions up.
        indices = np.random.choice(len(self.memory), size=sample_size, replace=False)
        return [self.memory[int(i)] for i in indices]

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


In [78]:
class DQN(nn.Module):
    """Small convolutional Q-network: two conv/max-pool stages, then an MLP head.

    Args:
        width: Width in pixels of the input frames.
        height: Height in pixels of the input frames.
        channels: Number of input channels.
        output_size: Number of outputs (one Q-value per action).

    ``forward`` expects a batch laid out as ``(N, channels, height, width)``
    and returns a ``(N, output_size)`` tensor.
    """

    def __init__(self, width, height, channels, output_size):
        super().__init__()

        # Both convolutions use "same" padding, so only the pools change the
        # spatial size; each 2x2/stride-2 pool halves it, rounding down.
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=12, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Two floor-halvings equal one floor-division by 4. Float division
        # (the previous `width / 4`) over-counts whenever a side is not a
        # multiple of 4, e.g. 210 -> 52.5 instead of the real 52.
        flat_features = 24 * (int(width) // 4) * (int(height) // 4)

        self.hidden1 = nn.Linear(flat_features, 128)
        self.relu1 = nn.ReLU()
        # in_features must equal hidden1's 128 outputs (it was 512, which
        # made the forward pass fail with a shape mismatch).
        self.hidden2 = nn.Linear(128, 512)
        self.relu2 = nn.ReLU()
        self.out = nn.Linear(512, output_size)

    def forward(self, x):
        """Map a batch of frames ``(N, C, H, W)`` to Q-values ``(N, output_size)``."""
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)

        # Collapse (24, H/4, W/4) into one feature vector per sample; the
        # linear layers operate on the last dimension only.
        x = x.flatten(start_dim=1)

        x = self.hidden1(x)
        x = self.relu1(x)
        x = self.hidden2(x)
        x = self.relu2(x)
        x = self.out(x)

        return x

In [79]:
# Read the number of discrete actions and the observation shape (unpacked as
# height, width, channels) from the environment's spaces.
# NOTE(review): the constants cell that follows repeats these two assignments
# verbatim, so this cell looks redundant — confirm before removing it.
action_size = env.action_space.n
height, width, channels = env.observation_space.shape

In [80]:
# ---- Constants / run configuration ----

# Sizes taken from the environment: observation shape is unpacked as
# (height, width, channels); action_size is the discrete action count.
height, width, channels = env.observation_space.shape
action_size = env.action_space.n

# NOTE(review): eps_* are not used in the visible cells; presumably the
# epsilon-greedy exploration bounds and decay scale — confirm.
eps_max, eps_min = 0.95, 0.05
eps_decay = 100

# Step size handed to the Adam optimizer.
learning_rate = 0.005
# NOTE(review): not used in the visible cells; presumably the soft-update
# mixing factor for the target network — confirm.
tau = 0.1
# Capacity of the replay buffer (number of transitions kept).
replay_memory_size = 5000

# Number of iterations of the episode loop.
episodes = 600
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [81]:
# Build the target and online ("running") networks with the same architecture
# and move both to `device`: the episode loop creates its state tensors on
# `device`, so leaving the networks on the CPU would raise a device-mismatch
# error on the first forward pass when CUDA is available.
target_policy = DQN(width, height, channels, action_size).to(device)
running_policy = DQN(width, height, channels, action_size).to(device)
# Start the target network as an exact copy of the online network.
target_policy.load_state_dict(running_policy.state_dict())

# Replay buffer holding at most `replay_memory_size` transitions.
memory = ReplayMemory(replay_memory_size)

In [82]:
# Optimize the online (running) network. The optimizer was previously built
# over target_policy's parameters, which would leave running_policy untrained;
# the target network starts as a copy of the online one and is not meant to be
# stepped by gradient descent.
running_optimizer = optim.Adam(running_policy.parameters(), lr=learning_rate, amsgrad=True)
# Huber loss used as the Bellman-error criterion.
bellmann_error = nn.HuberLoss()

In [96]:
for episode in range(episodes):
    # Begin a new episode: reset() returns the first observation and an info dict.
    state, info = env.reset()

    # Copy the observation into a float32 tensor on `device`, then move the
    # last axis to the front and prepend a leading axis of size 1.
    state = torch.tensor(state, dtype=torch.float32, device=device)
    state = state.permute(2, 0, 1).unsqueeze(0)

    print(info)
    # NOTE(review): unconditional break — only the first iteration ever runs,
    # so this is currently a smoke test of reset() rather than a training loop.
    break

    

{'lives': 5, 'episode_frame_number': 0, 'frame_number': 0}
tensor(True, device='cuda:0')
