# Imports

In [43]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
import random as rand
from copy import deepcopy
from gym.envs.atari.atari_env import AtariEnv

# Criação do Environment

In [44]:
# Build the Space Invaders emulator with raw image observations, then show
# the size of its discrete action space and what each action index means.
env_si = AtariEnv(game="space_invaders", obs_type="image")
print("Número de ações:", env_si.action_space, sep="\n")
print("Significado das ações:", env_si.get_action_meanings(), sep="\n")

Número de ações:
Discrete(6)
Significado das ações:
['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']


# Buffer do Replay para DQN

In [45]:
class ReplayBuffer():
    """Fixed-capacity circular store of (state, action, reward, next_state) transitions.

    Transitions are kept in an object array with one row per transition.
    Once `capacity` transitions have been added, the oldest row is overwritten.
    Only rows that have actually been written are ever returned by `sample`.
    """

    def __init__(self, capacity = 1000):
        """Allocate an empty buffer.

        Args:
            capacity: maximum number of transitions kept at once (must be >= 1).

        Raises:
            ValueError: if `capacity` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1")

        self.buffer = np.empty((capacity, 4), dtype=object)
        # Index of the slot the NEXT transition will be written to.
        self.pointer = 0
        # Number of slots that currently hold a real transition.
        self.size = 0

    def add(self, state, action, reward, state_next):
        """Store one transition, overwriting the oldest one when full."""
        # Write first, advance afterwards: slot 0 is used by the very first
        # transition instead of being skipped until the first wrap-around.
        row = self.buffer[self.pointer]
        row[0] = state
        row[1] = action
        row[2] = reward
        row[3] = state_next

        capacity = self.buffer.shape[0]
        self.pointer = (self.pointer + 1) % capacity
        self.size = min(self.size + 1, capacity)

    def sample(self, n = 32):
        """Return `n` transitions drawn uniformly (with replacement).

        Args:
            n: number of transitions to draw.

        Returns:
            Object array of shape (n, 4): state, action, reward, next_state.

        Raises:
            ValueError: if no transition has been stored yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # Restrict the draw to written rows so unfilled (None) slots never leak out.
        idx = np.random.randint(0, self.size, n)
        return self.buffer[idx,:]

    def save(self):
        """Write the raw transition array to 'dqnmemorybuffer.npy'."""
        np.save("dqnmemorybuffer", self.buffer)

    def load(self):
        """Replace the buffer contents with those stored in 'dqnmemorybuffer.npy'.

        Raises:
            ValueError: if the file does not hold a non-empty (capacity, 4) array.
        """
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # files produced by `save` from a trusted source.
        loaded = np.load(file="dqnmemorybuffer.npy", allow_pickle=True)
        if loaded.ndim != 2 or loaded.shape[0] == 0 or loaded.shape[1] != 4:
            raise ValueError("saved buffer must be a non-empty array with 4 columns")

        # Files written by a partially filled buffer may contain empty rows;
        # move the filled rows to the front so `size` describes them exactly.
        filled = np.array([state is not None for state in loaded[:, 0]], dtype=bool)
        self.buffer = np.concatenate([loaded[filled], loaded[~filled]])
        self.size = int(filled.sum())
        self.pointer = self.size % self.buffer.shape[0]


# Definição da Q Network

In [46]:
class QNetwork(nn.Module):
    """Convolutional Q-network mapping RGB frames to one Q-value per action.

    Three conv + ReLU + 4x4 max-pool stages are followed by a linear layer
    producing 6 outputs. The linear layer expects 768 features, which is
    128 channels x 3 x 2 for a 210x160 input (210x160 -> 52x40 -> 13x10 -> 3x2).
    """

    def __init__(self):
        super(QNetwork, self).__init__()
        
        self.conv1 = nn.Sequential(nn.Conv2d(3, 32, 7, padding=3),
                                   nn.ReLU(True),
                                   nn.MaxPool2d(4, 4))

        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, padding=2),
                                   nn.ReLU(True),
                                   nn.MaxPool2d(4, 4))
        
        self.conv3 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1),
                                   nn.ReLU(True),
                                   nn.MaxPool2d(4, 4))

        # No activation on the output: Q-values are unbounded regression
        # targets, and a trailing ReLU would zero the gradient of every action
        # whose pre-activation is negative. Kept as a Sequential so the
        # parameter names ("linear.0.weight", "linear.0.bias") do not change.
        self.linear = nn.Sequential(nn.Linear(768, 6))

    def forward(self, x):
        """Compute Q-values.

        Args:
            x: float tensor of shape (N, 3, 210, 160), or a single frame of
               shape (3, 210, 160).

        Returns:
            Tensor of shape (N, 6), or (6,) when a single frame was given.
        """
        single_frame = x.dim() == 3
        if single_frame:
            x = x.unsqueeze(0)

        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        # Collapse (channels, height, width) into one feature axis; without
        # this the linear layer is applied to the last spatial axis and fails.
        y = torch.flatten(y, 1)
        y = self.linear(y)

        return y.squeeze(0) if single_frame else y


# Definição da função de loss e optimizer

In [47]:
# Online Q-network, squared-error loss and the SGD optimizer
# (learning rate 0.00025, momentum 0.95) that updates the network's weights.
qnet = QNetwork()
loss_function = nn.MSELoss()
optimizer = optim.SGD(params=qnet.parameters(), lr=0.00025, momentum=0.95)

# Povoamento do Buffer

In [48]:
%%script false --no-raise-error
# This cell is skipped by the %%script magic above; remove that line to
# regenerate the saved buffer with 1000 random-policy transitions.

# Seed first, then reset, and start from the observation reset() returns.
# NOTE(review): in the old gym Atari env, seed() appears to reload the ROM,
# so calling it after reset() would discard the reset — confirm for the
# installed gym version.
env_si.seed()
state = env_si.reset().transpose((2,0,1))  # HWC -> CHW, as the network expects
buffer = ReplayBuffer()

for i in range(1000):

    action = rand.randint(0,5)
    next_state, reward, done, _ = env_si.step(action)
    next_state = next_state.transpose((2,0,1))
    buffer.add(state, action, reward, next_state)
    state = next_state

    if done:
        # Start the next transition from the new episode's first frame rather
        # than from the last frame of the finished episode.
        env_si.seed()
        state = env_si.reset().transpose((2,0,1))

env_si.reset()
buffer.save()

# Carregar Buffer salvo

In [49]:
# Create a buffer with the default capacity, then replace its contents with
# the transitions stored in "dqnmemorybuffer.npy" (the file ReplayBuffer.save writes).
# NOTE: load() reads the file with allow_pickle=True; only use a file you created.
buffer = ReplayBuffer()
buffer.load()

# Treino

In [None]:
# --- Hyperparameters -------------------------------------------------------
GAMMA = 0.99                 # discount applied to the bootstrapped next-state value
BATCH_SIZE = 32              # transitions per gradient step
TARGET_SYNC_EPISODES = 100   # episodes between target-network refreshes
EPSILON_STEP = 0.1           # epsilon decrease applied at every refresh


def frames_to_tensor(frames):
    """Convert a (N, 3, H, W) array of 0-255 frames to a float32 tensor in [0, 1]."""
    return torch.as_tensor(np.asarray(frames), dtype=torch.float32) / 255.0


# Plain lists while training (appending is O(1)); converted to arrays at the end.
loss_array = []
episode_reward_array = []

epsilon = 1
min_epsilon = 0.1
n_episodes = 1000

for episode in range(n_episodes):

    if episode % TARGET_SYNC_EPISODES == 0:

        # Freeze a copy of the online network to compute stable TD targets.
        qnet_copy = deepcopy(qnet)
        epsilon = max(min_epsilon, epsilon - EPSILON_STEP)

    # Seed before resetting so the reset is the last thing done to the emulator.
    env_si.seed()
    state = env_si.reset().transpose((2,0,1))  # HWC -> CHW

    done = False
    total_reward = 0
    episode_losses = []

    while not done:

        # Epsilon-greedy action selection using the online network.
        if rand.random() < epsilon:
            action = rand.randint(0, 5)
        else:
            with torch.no_grad():
                q_values = qnet(frames_to_tensor(state[None]))
            # env.step needs a plain int, not a tensor.
            action = int(torch.argmax(q_values).item())

        next_state, reward, done, _ = env_si.step(action)
        next_state = next_state.transpose((2,0,1))
        buffer.add(state, action, reward, next_state)
        total_reward += reward
        state = next_state  # advance, otherwise every transition starts from the first frame

        # Build the minibatch from the SAMPLED transitions.
        samples = buffer.sample(BATCH_SIZE)
        states = frames_to_tensor(np.stack(samples[:, 0]))
        actions = torch.as_tensor(samples[:, 1].astype(np.int64)).unsqueeze(1)
        rewards = torch.as_tensor(samples[:, 2].astype(np.float32))
        next_states = frames_to_tensor(np.stack(samples[:, 3]))

        # TD target: r + gamma * max_a' Q_target(s', a').
        # NOTE(review): the buffer does not record `done`, so terminal
        # transitions still bootstrap from their next state.
        with torch.no_grad():
            next_q = qnet_copy(next_states).max(dim=1)[0]
        targets = rewards + GAMMA * next_q

        optimizer.zero_grad()
        # Only the Q-value of the action actually taken is regressed.
        predicted = qnet(states).gather(1, actions).squeeze(1)
        loss = loss_function(predicted, targets)
        loss.backward()
        optimizer.step()

        # .item() stores a float instead of a tensor that keeps the graph alive.
        step_loss = loss.item()
        loss_array.append(step_loss)
        episode_losses.append(step_loss)

    episode_reward_array.append(total_reward)
    # One summary line per episode instead of one line per environment step.
    print(f"Loss é: {np.mean(episode_losses)}")
    print(f"Reward total do episódio foi {total_reward}")

loss_array = np.asarray(loss_array)
episode_reward_array = np.asarray(episode_reward_array)

            

    

