In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

import random
import numpy as np
from collections import deque
import time

import pygame
from pygame import display

import gymnasium as gym
env = gym.make("ALE/Breakout-v5", render_mode="human")
#env = gym.make("ALE/Breakout-v5")

: 

In [2]:
class ReplayMemory():
    def __init__(self, max_samples):
        self.memory = deque([], maxlen=max_samples)

    def push(self, state, action, next_state, reward):
        self.memory.append((state, action, next_state, reward))

    def sample(self, sample_size):
        return random.sample(self.memory, sample_size)

    def __len__(self):
        return len(self.memory)


In [3]:
class DQN(nn.Module):
    def __init__(self, width, height, channels, output_size):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=12, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)
        
        self.hidden1 = nn.Linear(int((24 * np.floor((width / 4)) * np.floor((height / 4)))), 265)
        self.relu1 = nn.ReLU()
        self.hidden2 = nn.Linear(265, 265)
        self.relu2 = nn.ReLU()
        self.out = nn.Linear(265, output_size)

    def forward(self, x):
        
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.out(x)

        return x

In [4]:
#constants
action_size = env.action_space.n
height, width, channels = env.observation_space.shape

learning_rate = 0.005
tau = 0.01
gamma = 0.99
replay_memory_size = 5000
batch_size = 150

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
episodes = 600

In [5]:
target_policy = DQN(width, height, channels, action_size).to(device)
policy_net = DQN(width, height, channels, action_size).to(device)
target_policy.load_state_dict(policy_net.state_dict())

memory = ReplayMemory(replay_memory_size)

In [6]:
policy_optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate, amsgrad=True)
bellmann_error = nn.HuberLoss()

In [7]:
def choose_action(state, steps):
    eps_max = 0.95
    eps_min = 0.05
    eps_step = 1000

    threshold = eps_min + (eps_max - eps_min) * np.exp(-1 * (steps / eps_step))

    selection = np.random.rand()

    if selection > threshold:
        with torch.no_grad():
            print(state.shape)
            return policy_net(state).max(1).indices.item()
    else:
        return env.action_space.sample()


In [8]:


def optimize():
    if len(memory) < batch_size:
        return
    
    
    transitions = memory.sample(batch_size)

    start = time.time()

    #there won't be None values in state, which is why we can cat it
    states = torch.cat([t[0] for t in transitions])
    actions = torch.tensor([t[1] for t in transitions], device=device).unsqueeze(-1)
    rewards = torch.tensor([t[3] for t in transitions], device=device)
    
    #there will be None values in next_state which is why we cant cat it -> create mask
    next_states = [t[2] for t in transitions]
    non_final_next_states_idxs = torch.tensor([i for i, t in enumerate(next_states) if t is not None], device=device)
    non_final_next_states = torch.cat([t for t in next_states if t is not None])
    
    end = time.time()
    print("duration: ", end-start)
    
    start = time.time()
    
    q_values_policy_net = policy_net(states).gather(1, actions)
    
    q_values_next_states = torch.zeros(batch_size, dtype=torch.float32, device=device)
    
    with torch.no_grad():
        q_values_next_states[non_final_next_states_idxs] = target_policy(non_final_next_states).max(1).values
        
    expected_q_values = rewards + (gamma * q_values_next_states)
    
    loss = bellmann_error(q_values_policy_net, expected_q_values.unsqueeze(-1))
    
    
    policy_optimizer.zero_grad()
    
    loss.backward()
    print("loss: ", loss.item())
    
    end = time.time()
    
    
    policy_optimizer.step()
    
    print("Optimization and Inference time: ", end-start)

: 

In [9]:
for episode in range(episodes):
    state, info = env.reset()

    display.set_mode((600, 500))

    state = (torch.tensor(state, device=device).float() / 255).permute(2, 0, 1).unsqueeze(0)
    
    done = False
    step = 0
    
    print("Episode Done")
    print("------------------------------------------------------")
    
    while not done:
        action = choose_action(state, step)
        

        new_state, reward, terminated, truncated, _ = env.step(action)

        done = terminated or truncated
        
        print(done)

        if terminated:
            new_state = None
        else:
            new_state = (torch.tensor(new_state, device=device).float() / 255).permute(2, 0, 1).unsqueeze(0)

        memory.push(state, action, new_state, reward)

        state = new_state

        # optimize step here
        optimize()
        # until here

        #update weights of target with policy
        target_dic = target_policy.state_dict()
        policy_dic = policy_net.state_dict()

        for keys in target_dic:
            target_dic[keys] = policy_dic[keys] * tau + target_dic[keys] * (1-tau)

        target_policy.load_state_dict(target_dic)

        step += 1
        print(step)

    

Episode Done
------------------------------------------------------
False
1
torch.Size([1, 3, 210, 160])
False
2
False
3
False
4
False
5
False
6
torch.Size([1, 3, 210, 160])
False
7
False
8
False
9
False
10
False
11
False
12
False
13
False
14
False
15
False
16
False
17
False
18
torch.Size([1, 3, 210, 160])
False
19
False
20
False
21
False
22
torch.Size([1, 3, 210, 160])
False
23
False
24
False
25
torch.Size([1, 3, 210, 160])
False
26
False
27
False
28
False
29
False
30
False
31
False
32
False
33
False
34
False
35
False
36
False
37
torch.Size([1, 3, 210, 160])
False
38
False
39
False
40
False
41
False
42
False
43
False
44
False
45
False
46
torch.Size([1, 3, 210, 160])
False
47
False
48
False
49
False
50
False
51
False
52
False
53
False
54
False
55
False
56
False
57
False
58
False
59
False
60
False
61
False
62
False
63
False
64
torch.Size([1, 3, 210, 160])
False
65
torch.Size([1, 3, 210, 160])
False
66
False
67
torch.Size([1, 3, 210, 160])
False
68
False
69
False
70
False
71
False
72
Fal