In [None]:
!pip install gym[atari]

In [1]:
%matplotlib widget

In [2]:
from IPython.core.debugger import set_trace

In [3]:
import torch
from torch import nn
from torch.distributions import Categorical
from torch.optim import Adam, SGD

In [4]:
import numpy

In [5]:
from matplotlib import pyplot as plot

In [6]:
import copy
from time import sleep

In [7]:
import gym

In [8]:
from utils import Buffer, collect_one_episode, normalize_obs, copy_params, avg_params

In [9]:
device='cpu'

In [10]:
class ResLinear(nn.Module):
    def __init__(self, n_in, n_out, act=nn.ReLU()):
        super(ResLinear, self).__init__()
        self.act = act
        self.linear = nn.Linear(n_in, n_out)
        self.bn = nn.BatchNorm1d(n_out)
        
        assert(n_in == n_out)
    
    def forward(self, x):
        h = self.act(self.bn(self.linear(x)))
        return h + x

In [11]:
class Player(nn.Module):
    def __init__(self, n_in=128, n_hid=100, n_out=6):
        super(Player, self).__init__()
        self.layers = nn.Sequential(nn.Linear(n_in, n_hid),
                                    nn.BatchNorm1d(n_hid),
                                    nn.ReLU(),
                                    ResLinear(n_hid, n_hid, nn.ReLU()),
                                    nn.Linear(n_hid, n_out))
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, obs, normalized=False):
        if normalized:
            return self.softmax(self.layers(obs))
        else:
            return self.layers(obs)

In [12]:
class Qnet(nn.Module):
    def __init__(self, n_in=128, n_act=6, n_hid=100):
        super(Qnet, self).__init__()
        self.layers = nn.Sequential(nn.Linear(n_in, n_hid), 
                                    nn.BatchNorm1d(n_hid),
                                    nn.ReLU(),
                                    ResLinear(n_hid, n_hid, nn.ReLU()),
                                    nn.Linear(n_hid, n_act))
    
    def forward(self, obs, act):
#         set_trace()
        return self.layers(obs).gather(1, act.long())
    
    def value(self, obs):
        return self.layers(obs).max(dim=1)
    
    def q(self, obs):
        return self.layers(obs)

In [68]:
env = gym.make('Pong-ram-v0')

n_frames = 1

In [74]:
# create a policy
player = Player(n_in=128*n_frames, n_hid=32, n_out=6).to(device)

In [75]:
# create a q estimator
qnet = Qnet(n_in=128*n_frames, n_hid=32, n_act=6).to(device)
qold = Qnet(n_in=128*n_frames, n_hid=32, n_act=6).to(device)
copy_params(qnet, qold)

In [76]:
# initialize optimizers
opt_player = Adam(player.parameters(), lr=0.0001)
opt_q = Adam(qnet.parameters(), lr=0.0001)

In [77]:
# initialize replay buffer
replay_buffer = Buffer(max_items=50000, n_frames=n_frames)

In [None]:
n_iter = 1000
init_collect = 1
n_collect = 1
n_q = 150
n_policy = 150
disp_iter = 1
val_iter = 1

max_len = 1000
batch_size = 1000

ent_coeff = 0. #0.001
discount_factor = .95

q_loss = -numpy.Inf
ret = -numpy.Inf
entropy = -numpy.Inf
valid_ret = -numpy.Inf

return_history = []

for ni in range(n_iter):
    player.eval()

    if numpy.mod(ni, val_iter) == 0:
        _, _, _, _, _, ret_ = collect_one_episode(env, player, max_len=max_len, deterministic=True, n_frames=n_frames)
        return_history.append(ret_)
        if valid_ret == -numpy.Inf:
            valid_ret = ret_
        else:
            valid_ret = 0.9 * valid_ret + 0.1 * ret_
        print('Valid run', ret_, valid_ret)

    # collect some episodes using the current policy
    # and push (obs,a,r,p(a)) tuples to the replay buffer.
    nc = n_collect
    if ni == 0:
        nc = init_collect
    for ci in range(nc):
        o_, r_, c_, a_, ap_, ret_ = collect_one_episode(env, player, max_len=max_len, discount_factor=discount_factor, n_frames=n_frames)
        replay_buffer.add(o_, r_, c_, a_, ap_)
        if ret == -numpy.Inf:
            ret = ret_
        else:
            ret = 0.9 * ret + 0.1 * ret_
    
    # fit a q function
    # TD learning: min_Q (Q(s,a) - (r + \gamma \max_a' Q(s',a')))^2
    qnet.train()
    qold.eval()
    player.eval()
    for qi in range(n_q):
        opt_player.zero_grad()
        opt_q.zero_grad()
        
        batch = replay_buffer.sample(batch_size)
        
        batch_x = torch.from_numpy(numpy.stack([ex['current']['obs'] for ex in batch]).astype('float32')).to(device)
        batch_xn = torch.from_numpy(numpy.stack([ex['next']['obs'] for ex in batch]).astype('float32')).to(device)
        
        batch_y = torch.from_numpy(numpy.stack([ex['current']['rew'] for ex in batch]).astype('float32')).to(device)
        batch_a = torch.from_numpy(numpy.stack([ex['current']['act'] for ex in batch]).astype('float32')[:,None]).to(device)
        
        q_pred = qnet(batch_x, batch_a).squeeze()
        q_next = qold.value(batch_xn)[0].squeeze()
        
        loss_ = ((q_pred - (batch_y + discount_factor * q_next)) ** 2)

        # importance weighting
        batch_pi = player(batch_x, normalized=True)        
        logp = torch.log(batch_pi.gather(1, batch_a.long())).clone().detach()
        
        batch_q = torch.from_numpy(numpy.stack([ex['current']['prob'] for ex in batch]).astype('float32')).to(device)
        
        iw = torch.exp((logp - torch.log(batch_q)).clamp(max=0.))
        loss = iw * loss_
        
        loss = loss.mean()
        
        loss.backward()
        opt_q.step()
        
    if q_loss == -numpy.Inf:
        q_loss = loss_.mean().item()
    else:
        q_loss = 0.9 * q_loss + 0.1 * loss_.mean().item()
        
    copy_params(qnet, qold)
        
    if numpy.mod(ni, disp_iter) == 0:
        print('# plays', (ni+1) * n_collect, 'return', ret, 'q_loss', q_loss, 'entropy', -entropy)
    
    # fit a policy
    # advantage: (r(a,s) + \gamma * V(s') - V(s))
    qnet.eval()
    player.train()
    for pi in range(n_policy):
        opt_player.zero_grad()
        opt_q.zero_grad()
        
        batch = replay_buffer.sample(batch_size)
        
        batch_x = torch.from_numpy(numpy.stack([ex['current']['obs'] for ex in batch]).astype('float32')).to(device)
        batch_r = torch.from_numpy(numpy.stack([ex['current']['crew'] for ex in batch]).astype('float32')[:,None]).to(device)
        batch_a = torch.from_numpy(numpy.stack([ex['current']['act'] for ex in batch]).astype('float32')[:,None]).to(device)

        batch_pi = player(batch_x, normalized=True)        
        batch_pin = player(batch_xn, normalized=True)        
        
        logp = torch.log(batch_pi.gather(1, batch_a.long()))
        greedy_a = batch_pi.max(1)[1].resize_(batch_pi.size(0), 1)
        
        # V(s) = E_pi(a|s)[Q(s,a)]
#         batch_v = (qnet.q(batch_x) * batch_pi).sum(1).clone().detach()
#         batch_vn = (qnet.q(batch_xn) * batch_pin).sum(1).clone().detach()
        batch_v = qnet.value(batch_x)[0].clone().detach()
        batch_vn = qnet.value(batch_xn)[0].clone().detach()
        
        # r(s,a) + \gamma * E_pi(a|s')[Q(s',a)] - E_pi(a|s)[Q(s,a)]
        adv = batch_r + discount_factor * batch_vn - batch_v
        adv = adv / adv.abs().max().clamp(min=1.)
        
        loss = -(adv * logp)
        
        batch_q = torch.from_numpy(numpy.stack([ex['current']['prob'] for ex in batch]).astype('float32')).to(device)

        # (clipped) importance weight: 
        # because the policy may have changed since the tuple was collected.
        iw = torch.exp((logp.clone().detach() - torch.log(batch_q)).clamp(max=0.))
    
        loss = iw * loss
        
        # entropy regularization: though, it doesn't look necessary in this specific case.
        ent = (batch_pi * torch.log(batch_pi)).sum(1)
        
        if entropy == -numpy.Inf:
            entropy = ent.mean().item()
        else:
            entropy = 0.9 * entropy + 0.1 * ent.mean().item()
        
        loss = (loss + ent_coeff * ent).mean()
        
        loss.backward()
        opt_player.step()

Valid run -20.0 -20.0
# plays 1 return -18.0 q_loss 0.021508656442165375 entropy inf
Valid run -20.0 -20.0
# plays 2 return -18.2 q_loss 0.021850073523819447 entropy 1.6776935750536182
Valid run -20.0 -20.0
# plays 3 return -18.38 q_loss 0.021976803634315728 entropy 1.5746033315842718
Valid run -20.0 -20.0
# plays 4 return -18.541999999999998 q_loss 0.022038987742736934 entropy 1.4610136357878813
Valid run -20.0 -20.0
# plays 5 return -18.6878 q_loss 0.02190871332604438 entropy 1.4479458931157334
Valid run -19.0 -19.9
# plays 6 return -18.41902 q_loss 0.021644383694399146 entropy 1.26043908765652
Valid run -20.0 -19.91
# plays 7 return -17.977117999999997 q_loss 0.021101396740080786 entropy 1.0175646834039835
Valid run -20.0 -19.919
# plays 8 return -18.1794062 q_loss 0.021025721544565446 entropy 0.8565105594441192
Valid run -19.0 -19.827099999999998
# plays 9 return -18.16146558 q_loss 0.020564443936534244 entropy 0.8571373841567795
Valid run -19.0 -19.744389999999996
# plays 10 retur

In [None]:
plot.figure()

plot.plot(return_history)
plot.grid(True)
plot.xlabel('# of plays x {}'.format(n_collect))
plot.ylabel('Return over the episode of length {}'.format(max_len))

plot.show()
plot.savefig('return_log.pdf', dpi=150)

In [22]:
# let the final policy play the pong longer
player.eval()
_, _, _, _, ret_ = collect_one_episode(env, player, max_len=1000000, deterministic=True, rendering=True)

KeyboardInterrupt: 