In [1]:
import gym
import torch
from torch import nn
from collections import namedtuple
import random
from tqdm.notebook import tqdm_notebook as tqdm
import numpy as np

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
Experience = namedtuple("Experience", ["state", "action", "reward", "next_state"])

In [4]:
class ReplayMemory:
    def __init__(self, memory_capacity):
        self.memory = []
        self.capacity = memory_capacity
        # the pointer should always point to the next position
        self.pointer = 1
    
    def __len__(self):
        return len(self.memory)
    
    def append(self, experience):
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.pointer] = experience
        self.pointer = (self.pointer + 1) % self.capacity
            
    def sample(self, sample_size):
        return random.sample(self.memory, sample_size)

In [42]:
class StateCache(ReplayMemory):
    def __init__(self, capacity = 10):
        super().__init__(capacity)
    
    def get_average_states(self):
        if len(self.memory) == 1:
            return self.memory[0]
        else:
            return torch.stack(self.memory).mean(dim = 0)

In [43]:
class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Sequential(
                        nn.Linear(128,64),
                        nn.ReLU(),
                        nn.Linear(64,32),
                        nn.ReLU(),
                        nn.Linear(32, 16),
                        nn.ReLU(),
                        nn.Linear(16,6))
    def forward(self, x):
        x = x / 255.0 #normalize input
        x = self.weights(x)
        return x

In [44]:
def epsilon(start = 1, stop = 0.1, num = 1000000):
    epsilons = np.linspace(start, stop, num)
    for i in epsilons:
        yield i
    while True:
        yield stop

In [45]:
env = gym.make("Pong-ram-v0")

# Training Loop

In [46]:
replay_memory = ReplayMemory(10000)
dqn = DQN()
dqn.to(DEVICE)
episodes = 5000
epsilon_generator = iter(epsilon())
bs = 100
gamma = 0.999
optimizer = torch.optim.RMSprop(dqn.parameters())

env.reset()
current_state, reward, done, _ = env.step(env.action_space.sample())
current_state = torch.tensor(current_state, dtype = torch.float32, device = DEVICE)
for episode in tqdm(range(episodes)):
    env.reset()
    done = False
    while not done:
        state_cache = StateCache(5)
        state_cache.append(current_state)
        current_state = state_cache.get_average_states()
        
        if random.uniform(0,1) < next(epsilon_generator):
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                action = dqn(current_state).argmax().item()
                
        next_state, reward, done, _ = env.step(action)
        next_state = torch.tensor(next_state, dtype = torch.float32, device = DEVICE)
        
        state_cache.append(next_state)
        next_state = state_cache.get_average_states()
        
        replay_memory.append(Experience(current_state, action, reward, next_state))
        
        if len(replay_memory) < bs:
            break
        
        
        random_minibatch = replay_memory.sample(bs)
        y = torch.tensor([experience.reward for experience in random_minibatch], device = DEVICE)
        if not done:
            batch_next_state = torch.stack([exp.next_state for exp in random_minibatch])
            y += gamma * dqn(batch_next_state).max(dim = 1)[0]
        

        batch_current_state = torch.stack([experience.state for experience in random_minibatch])
        l = ((y - dqn(batch_current_state).max(dim = 1)[0])**2).mean()
        
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        
        current_state = next_state
        
        env.render()
        
    if episode%100 == 0:
        torch.save(dqn.state_dict(), f"./dqn_at_ep{episode}")
env.close()
        

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

KeyboardInterrupt: 

In [48]:
next(epsilon_generator)

0.7028611028611029

In [47]:
env.close()

In [16]:
current_state = torch.tensor(foo[0], dtype = torch.float32, device = DEVICE)

NameError: name 'foo' is not defined

In [17]:
dqn = DQN()
dqn.to(DEVICE)

DQN(
  (weights): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
    (5): ReLU()
    (6): Linear(in_features=16, out_features=6, bias=True)
  )
)

In [41]:
torch.stack([current_state, current_state]).mean(dim = 1)

tensor([71.5000, 71.5000], device='cuda:0')

In [34]:
torch.stack([current_state, current_state])

tensor([[71.6562],
        [71.6562]], device='cuda:0')

# Scratch

In [3]:
env = gym.make("Pong-ram-v0")

In [5]:
env.action_space

Discrete(6)

In [11]:
env.observation_space.sample()

array([127,  60, 165,  12,  20, 164, 119,  31,  88, 226,  47, 210,  13,
        28, 171, 211, 108, 253,   7, 246,  11, 149, 135,  62,  83,  66,
        75, 113,  67, 243,  64, 172, 253, 198,   9, 183, 210,  73, 165,
        69,   7, 105, 107,  53,  50, 159, 243, 219,  14, 201, 236,  51,
        98,  44,  91, 213, 159, 247, 214,  50,   7,  11,  96,  76, 142,
       233,  83,  90, 224,  48, 155,  49, 137, 123, 231,  72,  84, 210,
       145, 101, 239,  67, 130,  63,  19, 235, 250, 231, 168,  65,  96,
        83,  12, 198,  72, 193, 114, 162,  63, 251, 105, 111,  74,  34,
       203, 252,  74,  29, 122,  88, 230, 139,  49,  72,  88, 172, 129,
       185,  20, 104, 210, 162, 232,  44, 172,  37,  87, 129], dtype=uint8)

In [12]:
env.render()

True

In [13]:
env.close()

In [17]:
env.reset()
done = False
while not done:
    env.render()
    output, reward, done, _ = env.step(env.action_space.sample())
env.close()

(array([192,   0,  64,   0, 110,  38,   0,   7,  23,  15,   0,  63,  14,
         21,   0,  63, 255,   0,   0,   2,   0,  52,   0,  24, 128,  32,
          1,  86, 247,  86, 247,  86, 247, 134, 243, 245, 243, 240, 240,
        242, 242,  32,  32,  64,  64,  64, 188,  65, 189, 205,  52,  38,
         37,  37,  51,   0, 255,   0, 255, 109,  38,  37,  37, 192, 192,
        192, 192, 192, 192, 207, 247, 202, 247, 212, 247, 202, 247,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0, 180,  85,  54, 236, 242, 121, 240], dtype=uint8),
 0.0,
 True,
 {'ale.lives': 0, 'TimeLimit.truncated': False})