In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np

# HACK because multinomial not yet implemented on MPS
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK']="1"

import torch
import torch.nn as nn
import gym
import scipy.signal
from pdb import set_trace

In [2]:
useMPS = False
if useMPS and torch.backends.mps.is_available():
    device = torch.device("mps") # mps for my M1 Mac
else:
    device = torch.device("cpu")
print(torch.__version__)
print(device)

1.13.0.dev20220629
cpu


## Hyperparameters and Setup

In [3]:
epochs = 10000
gamma = 0.99
lr = 1e-4
hid_dim = 64

In [4]:
env_name = "ALE/Pong-v5"
env = gym.make(env_name,
               render_mode='rgb_array'
#                render_mode='human'
              )

obs_dim = env.observation_space.shape
print("obs_dim", obs_dim)

action_dim = env.action_space.n
print("action_dim", action_dim)

obs_dim (210, 160, 3)
action_dim 6


A.L.E: Arcade Learning Environment (version 0.7.5+db37282)
[Powered by Stella]
  logger.warn(


## Model

In [5]:
class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        activation = nn.Tanh
        self.last_obs = np.zeros((80,80))
        
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=6, stride=2),
            nn.MaxPool2d(2, stride=2),
            activation(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            activation(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            activation(),
        )
        
        self.mlp = nn.Sequential(
            nn.Linear(2304, hid_dim),
            activation(),
            nn.Linear(hid_dim, hid_dim),
            activation(),
        )
        
        self.actor_head = nn.Linear(hid_dim, action_dim)
        
        self.critic_head = nn.Linear(hid_dim, 1)
        
    def forward(self, x):
        # TODO add conv layers
#         set_trace()
        body = self.preprocess(x)
        body = self.cnn(body)
        body = body.reshape(-1)
        body = self.mlp(body)
        
        logits = self.actor_head(body)
        pi = torch.distributions.Categorical(logits=logits)
        action = pi.sample()
        logp = pi.log_prob(action)
        
        state_value = self.critic_head(body).squeeze()
        
        return action.item(), logp, state_value
    
    def preprocess(self, obs):
        """Preprocess Pong observation"""

        # slice off top and bottom, and downsample
        obs = obs[34:194:2,::2,2]

        # result is diff between two frames to make easier for NN w/o memory
        result = obs - self.last_obs

        # Simplify values to {0, 1}
        result[result != 0] = 1
        result[result != 1] = 0

        self.last_obs = obs

        result = torch.as_tensor(result,
                                 dtype=torch.float32,
                                 device=device)
        result = result.unsqueeze(dim=0)
        
        return result

agent = ActorCritic().to(device)
print(agent)

ActorCritic(
  (cnn): Sequential(
    (0): Conv2d(1, 32, kernel_size=(6, 6), stride=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Tanh()
    (3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (4): Tanh()
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (6): Tanh()
  )
  (mlp): Sequential(
    (0): Linear(in_features=2304, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Tanh()
  )
  (actor_head): Linear(in_features=64, out_features=6, bias=True)
  (critic_head): Linear(in_features=64, out_features=1, bias=True)
)


In [6]:
optimizer = torch.optim.Adam(agent.parameters(), lr=lr)

In [7]:
def get_returns(rewards, discount):
    # https://stackoverflow.com/a/47971187/379547
    r = rewards[::-1]
    a = [1, -discount]
    b = [1]
    y = scipy.signal.lfilter(b, a, x=r)
    y = y[::-1].copy()
    
    
    
    
    
    return torch.as_tensor(y, dtype=torch.float32, device=device)

In [8]:
def get_loss(returns, logprobs, values):
    actor_loss = -(logprobs * returns).mean()
    critic_loss = -(values - returns).pow(2).mean()
    # TODO add entropy bonus
    return actor_loss + critic_loss

In [9]:
def train_one_epoch():
    obs = env.reset()
    done = False
    rewards = []
    logprobs = []
    values = []
    while not done:
        action, logp, value = agent(obs)
        obs, reward, done, info = env.step(action)
        rewards.append(reward)
        logprobs.append(logp)
        values.append(value)
    
    returns = get_returns(rewards, discount=gamma)
    logprobs = torch.stack(logprobs)
    values = torch.stack(values)
    
    optimizer.zero_grad()
    loss = get_loss(returns, logprobs, values)
    loss.backward()
    optimizer.step()
    return torch.as_tensor(rewards).sum().item(), loss.item(), len(rewards)

In [10]:
%time
for epoch in range(epochs):
    rewards, loss, length = train_one_epoch()
    print(f'{epoch} Rewards:{rewards}, Loss:{loss:.5f}, Ep.Length:{length}')

CPU times: user 0 ns, sys: 1 µs, total: 1 µs
Wall time: 3.1 µs
0 Rewards:-20.0, Loss:-8.32459, Ep.Length:927
1 Rewards:-21.0, Loss:-8.24611, Ep.Length:960
2 Rewards:-21.0, Loss:-10.17312, Ep.Length:843
3 Rewards:-21.0, Loss:-11.78729, Ep.Length:765
4 Rewards:-21.0, Loss:-11.82553, Ep.Length:783
5 Rewards:-20.0, Loss:-6.62542, Ep.Length:1210
6 Rewards:-21.0, Loss:-10.64589, Ep.Length:854
7 Rewards:-21.0, Loss:-10.85267, Ep.Length:853
8 Rewards:-20.0, Loss:-8.84638, Ep.Length:1008
9 Rewards:-17.0, Loss:-8.33974, Ep.Length:1113
10 Rewards:-20.0, Loss:-9.75285, Ep.Length:1010
11 Rewards:-20.0, Loss:-10.48933, Ep.Length:928
12 Rewards:-21.0, Loss:-13.59065, Ep.Length:825
13 Rewards:-20.0, Loss:-14.30078, Ep.Length:861
14 Rewards:-21.0, Loss:-15.26849, Ep.Length:855
15 Rewards:-18.0, Loss:-9.81456, Ep.Length:1146
16 Rewards:-21.0, Loss:-15.08982, Ep.Length:902
17 Rewards:-21.0, Loss:-13.33950, Ep.Length:1051
18 Rewards:-20.0, Loss:-14.20237, Ep.Length:1007
19 Rewards:-20.0, Loss:-17.51690, E

KeyboardInterrupt: 

## Watch

In [None]:
env = gym.make(env_name, render_mode='human')
obs = env.reset()
done = False
with torch.no_grad():
    while not done:
        action, _, _ = agent(obs)
        obs, _, done, _ = env.step(action)