# Imports

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
import random as rand
from gym.envs.atari.atari_env import AtariEnv

# Criação do Environment

In [7]:
# Space Invaders environment that yields image (screen-pixel) observations.
env_si = AtariEnv(game="space_invaders", obs_type="image")

# Report the discrete action space and what each action index means.
print("Número de ações:", env_si.action_space, sep="\n")
print("Significado das ações:", env_si.get_action_meanings(), sep="\n")

Número de ações:
Discrete(6)
Significado das ações:
['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']


# Buffer do Replay para DQN

In [1]:
class ReplayBuffer():
    """Fixed-capacity circular buffer of (state, action, reward, state_next) transitions.

    Slots start empty and are filled in order. Once ``capacity`` transitions
    have been stored, each new transition overwrites the oldest one.
    ``sample`` only draws from slots that actually hold a stored transition.
    """

    def __init__(self, capacity = 1000):
        """Create an empty buffer.

        Args:
            capacity: maximum number of transitions kept (must be >= 1).

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")

        # One row per transition; columns are (state, action, reward, state_next).
        # dtype=object lets each cell hold an arbitrary object (e.g. a frame array).
        # Rows are filled lazily rather than pre-populated with placeholder frames:
        # placeholders cost memory and could be sampled as fake transitions.
        self.buffer = np.empty((capacity, 4), dtype=object)
        # Index of the slot the next call to ``add`` writes to.
        self.pointer = 0
        # Number of slots holding a real transition (never exceeds capacity).
        self.size = 0

    def __len__(self):
        """Return how many real transitions are currently stored."""
        return self.size

    def add(self, state, action, reward, state_next):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        row = self.buffer[self.pointer]  # a view, so writes land in self.buffer
        row[0] = state
        row[1] = action
        row[2] = reward
        row[3] = state_next

        capacity = self.buffer.shape[0]
        # Advance after writing so the first transition goes to slot 0.
        self.pointer = (self.pointer + 1) % capacity
        self.size = min(self.size + 1, capacity)

    def sample(self, n):
        """Draw ``n`` stored transitions uniformly at random, with replacement.

        Args:
            n: number of transitions to draw.

        Returns:
            An object ndarray of shape (n, 4); each row is
            (state, action, reward, state_next).

        Raises:
            ValueError: if no transition has been stored yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        # Slots [0, size) are exactly the filled ones (filled in order, then wrapped).
        idx = np.random.randint(0, self.size, n)
        return self.buffer[idx, :]


# Definição da Q Network

In [4]:
class QNetwork(nn.Module):
    """Convolutional Q-network mapping an image batch to one Q-value per action.

    Three blocks of (conv -> ReLU -> 4x4 max-pool) feed a linear head. Each
    conv keeps the spatial size (padding = kernel // 2) and each pool floors it
    by 4. For a 210x160 input this gives 210x160 -> 52x40 -> 13x10 -> 3x2, so
    the head's 768 input features are 128 * 3 * 2.
    """

    def __init__(self, in_channels=3, n_actions=6):
        """Build the layers.

        Args:
            in_channels: number of channels of the input image (default 3).
            n_actions: number of discrete actions, i.e. Q-values produced (default 6).
        """
        super(QNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 32, 7, padding=3),
                                   nn.ReLU(True),
                                   nn.MaxPool2d(4, 4))

        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, padding=2),
                                   nn.ReLU(True),
                                   nn.MaxPool2d(4, 4))

        self.conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1),
                                   nn.ReLU(True),
                                   nn.MaxPool2d(4, 4))

        # No activation on the output: Q-values estimate returns and may be
        # negative. A ReLU here would clip them to >= 0 and block their gradient.
        # Kept as a Sequential so state_dict keys stay "linear.0.*".
        self.linear = nn.Sequential(nn.Linear(768, n_actions))

    def forward(self, x):
        """Return Q-values of shape (N, n_actions).

        Args:
            x: float tensor laid out channels-first as (N, in_channels, H, W).
               H, W must pool down to 3x2 (e.g. 210x160 Atari frames).
               NOTE(review): frames stored as (210, 160, 3) arrays need permuting
               to channels-first and converting to float before this call.
        """
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        # Collapse (N, 128, 3, 2) into (N, 768); without this the linear layer
        # would be applied along the last spatial axis and fail on the shape.
        y = torch.flatten(y, 1)
        return self.linear(y)


# Definição da função de loss e optimizer

In [6]:
# Network to train, squared-error loss on Q-value targets, and its optimizer.
qnet = QNetwork()
loss_function = nn.MSELoss()

# SGD with learning rate 2.5e-4 and momentum 0.95.
# NOTE(review): these values match the RMSProp settings reported for DQN;
# confirm plain SGD (rather than optim.RMSprop) is intended.
optimizer = optim.SGD(qnet.parameters(), lr=0.00025, momentum=0.95)