In [None]:
!pip install gym[atari]

In [1]:
%matplotlib widget

In [2]:
import torch
from torch import nn
from torch.distributions import Categorical
from torch.optim import Adam, SGD

In [3]:
import numpy

In [4]:
from matplotlib import pyplot as plot

In [5]:
import copy
from time import sleep

In [6]:
import gym

In [7]:
# Torch device string; the networks and every batch tensor in this notebook
# are moved to it with `.to(device)`.
device='cpu'

In [8]:
class ResLinear(nn.Module):
    """Residual fully-connected block: ``act(bn(linear(x))) + x``.

    Args:
        n_in: number of input features.
        n_out: number of output features; must equal ``n_in`` so that the
            skip connection can be added element-wise (checked by an assert).
        act: activation module applied after batch normalisation.
    """

    def __init__(self, n_in, n_out, act=nn.ReLU()):
        super().__init__()
        self.act = act
        self.linear = nn.Linear(n_in, n_out)
        self.bn = nn.BatchNorm1d(n_out)

        # The skip connection adds the input to the output, so widths must agree.
        assert n_in == n_out

    def forward(self, x):
        """Return the transformed input plus the untouched input."""
        projected = self.linear(x)
        normalised = self.bn(projected)
        activated = self.act(normalised)
        return activated + x

In [9]:
class Player(nn.Module):
    """Policy network mapping an observation batch to per-action scores.

    Args:
        n_in: number of observation features.
        n_hid: width of the hidden layers.
        n_out: number of actions (size of the output vector).
    """

    def __init__(self, n_in=128, n_hid=100, n_out=6):
        super().__init__()
        stack = [
            nn.Linear(n_in, n_hid),
            nn.BatchNorm1d(n_hid),
            nn.ReLU(),
            ResLinear(n_hid, n_hid, nn.ReLU()),
            nn.Linear(n_hid, n_out),
        ]
        self.layers = nn.Sequential(*stack)
        # Normalises over dim 1, i.e. across the action scores of each row.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, obs, normalized=False):
        """Return raw scores, or softmax probabilities when ``normalized``."""
        scores = self.layers(obs)
        return self.softmax(scores) if normalized else scores

In [10]:
class Value(nn.Module):
    """Value estimator mapping an observation batch to one scalar per row.

    Args:
        n_in: number of observation features.
        n_hid: width of the hidden layers.
    """

    def __init__(self, n_in=128, n_hid=100):
        super().__init__()
        stack = [
            nn.Linear(n_in, n_hid),
            nn.BatchNorm1d(n_hid),
            nn.ReLU(),
            ResLinear(n_hid, n_hid, nn.ReLU()),
            nn.Linear(n_hid, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, obs):
        """Return the estimate produced by the layer stack (last dim is 1)."""
        estimate = self.layers(obs)
        return estimate

In [11]:
def copy_params(from_, to_):
    """Overwrite the parameters of ``to_`` with those of ``from_``.

    Parameters are paired positionally with ``zip`` over ``.parameters()``,
    so surplus parameters on either side are ignored. Only tensors yielded
    by ``.parameters()`` are copied.

    The copy runs under ``torch.no_grad()`` instead of going through
    ``.data``: the write stays out of the autograd graph, but autograd can
    still detect that a tensor saved for a pending backward pass was changed.
    """
    with torch.no_grad():
        for f_, t_ in zip(from_.parameters(), to_.parameters()):
            t_.copy_(f_)
        
def avg_params(from_, to_, coeff=0.95):
    """Move the parameters of ``to_`` towards those of ``from_``.

    Each parameter of ``to_`` becomes ``coeff * to + (1 - coeff) * from``.
    Parameters are paired positionally with ``zip`` over ``.parameters()``,
    so surplus parameters on either side are ignored.

    The update runs under ``torch.no_grad()`` instead of going through
    ``.data``: the write stays out of the autograd graph, but autograd can
    still detect that a tensor saved for a pending backward pass was changed.
    """
    with torch.no_grad():
        for f_, t_ in zip(from_.parameters(), to_.parameters()):
            t_.copy_(coeff * t_ + (1. - coeff) * f_)

In [12]:
def normalize_obs(obs):
    """Cast ``obs`` to float32 and divide every entry by 255."""
    as_float = obs.astype('float32')
    return as_float / 255.

In [13]:
# collect data
def collect_one_episode(env, player, max_len=50, discount_factor=0.9, deterministic=False, rendering=False, verbose=False):
    """Roll out ``player`` in ``env`` for at most ``max_len`` steps.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning
            ``(obs, reward, done, info)`` and, if ``rendering``, ``render()``.
        player: policy called as ``player(obs_batch, normalized=True)``; it is
            fed a single-row batch and its output row is used as the action
            probabilities. Callers in this notebook switch it to ``eval()``
            before calling.
        max_len: maximum number of environment steps.
        discount_factor: discount used for the cumulative returns.
        deterministic: pick the most probable action instead of sampling.
        rendering: render every step, sleeping 0.05 s between frames.
        verbose: in deterministic mode, print probabilities, action, reward
            and the ``done`` flag of every step.

    Returns:
        ``(observations, crewards, actions, action_probs, total_reward)``.
        The four lists are aligned per step; ``crewards[i]`` is the discounted
        sum of the rewards from step ``i`` to the last step taken, and
        ``total_reward`` is the undiscounted sum over all steps taken.

        If the rollout was cut off by ``max_len``, the last ``max_len // 10``
        steps are dropped from the four lists because their returns are
        truncated by the cut-off. If the environment reported ``done``, the
        returns are complete and nothing is dropped.
    """
    observations = []
    rewards = []
    actions = []
    action_probs = []

    terminated = False
    obs = env.reset()

    for ml in range(max_len):
        if rendering:
            env.render()
            sleep(0.05)

        obs = normalize_obs(obs)

        # Rollouts are never back-propagated through, so skip the autograd graph.
        with torch.no_grad():
            out_probs = player(torch.from_numpy(obs[None,:]).to(device), normalized=True).squeeze(0)

        if deterministic:
            action = int(numpy.argmax(out_probs.to('cpu').numpy()))
            if verbose:
                print(out_probs, action)
        else:
            action = Categorical(out_probs).sample().item()
        action_prob = out_probs[action].item()

        observations.append(obs)
        actions.append(action)
        action_probs.append(action_prob)

        obs, reward, done, info = env.step(action)
        if deterministic and verbose:
            print(reward, done)

        rewards.append(reward)

        # Stop at the end of the episode instead of stepping a finished env.
        if done:
            terminated = True
            break

    rewards = numpy.array(rewards)
    n_steps = len(rewards)

    # Discounted cumulative returns in one backward pass:
    # G[i] = r[i] + discount_factor * G[i + 1].
    returns = numpy.zeros(n_steps)
    running = 0.
    for ri in reversed(range(n_steps)):
        running = rewards[ri] + discount_factor * running
        returns[ri] = running
    crewards = list(returns)

    # discard the final 10% of a cut-off rollout: those returns are truncated by
    # the horizon. this is only for training, not for the total return.
    # An explicit `keep` count is used because `[:-0]` would yield an empty list.
    discard = 0 if terminated else max_len // 10
    keep = max(n_steps - discard, 0)

    return observations[:keep], crewards[:keep], actions[:keep], action_probs[:keep], rewards.sum()

In [14]:
# simple implementation of FIFO-based replay buffer
class Buffer:
    """FIFO replay buffer of per-step records.

    Each record is a dict with the keys ``'obs'``, ``'crew'``, ``'act'`` and
    ``'prob'``. The buffer never holds more than ``max_items`` records; when
    an ``add`` would exceed that, the oldest records are evicted.
    """

    def __init__(self, max_items=10000):
        self.max_items = max_items
        self.buffer = []

    def add(self, observations, crewards, actions, action_probs):
        """Append one record per step, then evict the oldest overflow.

        The four sequences are zipped, so the shortest one determines how
        many records are added.
        """
        for o, c, a, p in zip(observations, crewards, actions, action_probs):
            self.buffer.append({'obs': o,
                                'crew': c,
                                'act': a,
                                'prob': p})

        # Trim after appending so the capacity holds even when a single call
        # brings more records than the buffer currently contains.
        overflow = len(self.buffer) - self.max_items
        if overflow > 0:
            del self.buffer[:overflow]

    def sample(self, n=100):
        """Return ``n`` records drawn uniformly at random with replacement."""
        idxs = numpy.random.choice(len(self.buffer), n)
        return [self.buffer[ii] for ii in idxs]

In [15]:
# Build the Gym environment shared by the training and evaluation rollouts.
# NOTE(review): presumably the `-ram` id yields the 128-value RAM observation
# that the 128-input networks below expect - confirm against the Gym docs.
env = gym.make('Pong-ram-v0')

In [16]:
# Create the policy network: 128 input features, hidden width 128 and
# 6 outputs (one score per action index), placed on `device`.
player = Player(n_in=128, n_hid=128, n_out=6).to(device)

In [17]:
# Create the value estimator: 128 input features, hidden width 128 and a
# single scalar output per observation, placed on `device`.
value = Value(n_in=128, n_hid=128).to(device)

In [18]:
# Initialize one Adam optimizer per network (learning rate 1e-4 each) so the
# policy and the value estimator can be stepped independently.
opt_player = Adam(player.parameters(), lr=0.0001)
opt_value = Adam(value.parameters(), lr=0.0001)

In [19]:
# Initialize the replay buffer that stores the per-step tuples gathered by
# the rollouts; `max_items` is its capacity setting.
replay_buffer = Buffer(max_items=50000)

In [20]:
# ---- hyperparameters ----
n_iter = 1000        # outer iterations: validate -> collect -> fit value -> fit policy
init_collect = 1     # episodes collected in the very first iteration
n_collect = 1        # episodes collected in every later iteration
n_value = 200        # value-function gradient steps per iteration
n_policy = 200       # policy gradient steps per iteration
disp_iter = 1        # print training statistics every `disp_iter` iterations
val_iter = 1         # run a deterministic validation episode every `val_iter` iterations

max_len = 1000       # maximum number of environment steps per episode
batch_size = 1000    # replay tuples drawn per gradient step

ent_coeff = 0. #0.001
discount_factor = .95

# ---- running statistics; -inf marks "not initialised yet" ----
value_loss = -numpy.inf
ret = -numpy.inf
entropy = -numpy.inf   # running mean of sum(p * log p), i.e. the NEGATIVE entropy
valid_ret = -numpy.inf

return_history = []


def _sample_batch(n):
    """Draw ``n`` replay tuples and return them as tensors on ``device``.

    Returns ``(obs, crew, act, prob)``. ``obs`` stacks the stored observations
    along a new first axis. The other three are column vectors of shape
    ``(n, 1)`` so that per-sample quantities combine element-wise instead of
    broadcasting against each other; ``act`` is int64 for use with ``gather``.
    """
    batch = replay_buffer.sample(n)

    def column(key, dtype):
        stacked = numpy.stack([ex[key] for ex in batch]).astype(dtype)
        return torch.from_numpy(stacked[:, None]).to(device)

    obs = torch.from_numpy(numpy.stack([ex['obs'] for ex in batch]).astype('float32')).to(device)
    return obs, column('crew', 'float32'), column('act', 'int64'), column('prob', 'float32')


for ni in range(n_iter):
    player.eval()

    if ni % val_iter == 0:
        _, _, _, _, ret_ = collect_one_episode(env, player, max_len=max_len, deterministic=True)
        return_history.append(ret_)
        if valid_ret == -numpy.inf:
            valid_ret = ret_
        else:
            valid_ret = 0.9 * valid_ret + 0.1 * ret_
        print('Valid run', ret_, valid_ret)

    # collect some episodes using the current policy
    # and push (obs,a,r,p(a)) tuples to the replay buffer.
    nc = n_collect
    if ni == 0:
        nc = init_collect
    for ci in range(nc):
        o_, c_, a_, ap_, ret_ = collect_one_episode(env, player, max_len=max_len, discount_factor=discount_factor)
        replay_buffer.add(o_, c_, a_, ap_)
        if ret == -numpy.inf:
            ret = ret_
        else:
            ret = 0.9 * ret + 0.1 * ret_

    player.train()

    # fit a value function
    loss_ = None
    for vi in range(n_value):
        opt_value.zero_grad()

        batch_x, batch_y, batch_a, batch_q = _sample_batch(batch_size)

        pred_y = value(batch_x)                 # (B, 1)
        loss_ = (batch_y - pred_y) ** 2         # (B, 1) per-sample squared error

        # (clipped) importance weight:
        # because the policy may have changed since the tuple was collected.
        # Both log-probabilities are (B, 1), so each sample is weighted by its
        # own ratio. The weight is a constant for the value update, so no
        # autograd graph is needed for the policy forward pass.
        with torch.no_grad():
            logp = torch.log_softmax(player(batch_x), dim=1).gather(1, batch_a)
            iw = torch.exp((logp - torch.log(batch_q)).clamp(max=0.))

        loss = (iw * loss_).mean()

        loss.backward()
        opt_value.step()

    # Track the unweighted squared error of the last value batch (if any).
    if loss_ is not None:
        if value_loss < 0.:
            value_loss = loss_.mean().item()
        else:
            value_loss = 0.9 * value_loss + 0.1 * loss_.mean().item()

    if ni % disp_iter == 0:
        print('# plays', (ni+1) * n_collect, 'return', ret, 'value_loss', value_loss, 'entropy', -entropy)

    # fit a policy
    for pi in range(n_policy):
        opt_player.zero_grad()

        batch_x, batch_r, batch_a, batch_q = _sample_batch(batch_size)

        # The value estimate is only a baseline here; no gradient flows into it.
        with torch.no_grad():
            batch_v = value(batch_x)            # (B, 1)

        # log_softmax is used instead of log(softmax(.)) so that very small
        # probabilities do not underflow to log(0) = -inf.
        log_pi = torch.log_softmax(player(batch_x), dim=1)   # (B, n_actions)
        logp = log_pi.gather(1, batch_a)                      # (B, 1)

        # advantage
        adv = batch_r - batch_v                 # (B, 1)

        # (clipped) importance weight:
        # because the policy may have changed since the tuple was collected.
        iw = torch.exp((logp.detach() - torch.log(batch_q)).clamp(max=0.))   # (B, 1)

        pg_loss = -(iw * adv * logp).squeeze(1)   # (B,)

        # entropy regularization: though, it doesn't look necessary in this specific case.
        # `ent` is sum(p * log p), the negative entropy, so adding it to the
        # loss pushes the entropy up.
        ent = (log_pi.exp() * log_pi).sum(1)      # (B,)

        if entropy == -numpy.inf:
            entropy = ent.mean().item()
        else:
            entropy = 0.9 * entropy + 0.1 * ent.mean().item()

        loss = (pg_loss + ent_coeff * ent).mean()

        loss.backward()
        opt_player.step()

Valid run -20.0 -20.0
# plays 1 return -17.0 value_loss 0.013182186521589756 entropy inf
Valid run -19.0 -19.9
# plays 2 return -17.0 value_loss 0.012912177853286267 entropy 1.5574188317997752
Valid run -18.0 -19.71
# plays 3 return -16.900000000000002 value_loss 0.012700480185449125 entropy 1.3167072012979841
Valid run -20.0 -19.739
# plays 4 return -16.71 value_loss 0.0127997613362968 entropy 1.286901813873235
Valid run -18.0 -19.5651
# plays 5 return -16.539 value_loss 0.012775040953978899 entropy 1.1745180504039456
Valid run -17.0 -19.308590000000002
# plays 6 return -15.685100000000002 value_loss 0.013014914093576375 entropy 1.1682683885863188
Valid run -18.0 -19.177731000000005
# plays 7 return -15.916590000000003 value_loss 0.013117969909902666 entropy 1.06576830440509
Valid run -15.0 -18.759957900000003
# plays 8 return -16.024931000000002 value_loss 0.013122338749852408 entropy 1.0494577535231562
Valid run -16.0 -18.483962110000004
# plays 9 return -15.722437900000003 value_lo

KeyboardInterrupt: 

In [21]:
# Plot the validation returns recorded in `return_history` by the training cell.
fig, ax = plot.subplots()

ax.plot(return_history)
ax.grid(True)
ax.set_xlabel('# of plays x {}'.format(n_collect))
ax.set_ylabel('Return over the episode of length {}'.format(max_len))

# Save through the figure handle *before* showing it: on backends where
# `show()` blocks, the figure is gone once the window is closed and a later
# `savefig` would write an empty page.
fig.savefig('return_log.pdf', dpi=150)
plot.show()

FigureCanvasNbAgg()

In [None]:
# let the final policy play the pong longer:
# switch the policy to evaluation mode, then run one rendered rollout that
# always takes the most probable action, with a step limit of 1,000,000.
# Only the total reward (`ret_`) of the rollout is kept.
player.eval()
_, _, _, _, ret_ = collect_one_episode(env, player, max_len=1000000, deterministic=True, rendering=True)

RecursionError: maximum recursion depth exceeded