In [1]:
import torch
from torch import nn
from torch.distributions import Categorical
from torch.optim import Adam, SGD

In [2]:
import copy
from time import sleep

In [3]:
import numpy

In [4]:
import gym

In [38]:
class Player(nn.Module):
    def __init__(self, n_in=128, n_hid=100, n_out=6):
        super(Player, self).__init__()
        self.layers = nn.Sequential(nn.Linear(n_in, n_hid),
                                   nn.Tanh(),
                                   nn.Linear(n_hid, n_out))
        self.softmax = nn.Softmax()
    
    def forward(self, obs, normalized=False):
        if normalized:
            return self.softmax(self.layers(obs))
        else:
            return self.layers(obs)

In [39]:
class Value(nn.Module):
    def __init__(self, n_in=128, n_hid=100):
        super(Value, self).__init__()
        self.layers = nn.Sequential(nn.Linear(n_in, n_hid),
                                   nn.Tanh(),
                                   nn.Linear(n_hid, 1))
    
    def forward(self, obs):
        return self.layers(obs)

In [40]:
def copy_params(from_, to_):
    for f_, t_ in zip(from_.parameters(), to_.parameters()):
        t_.data.copy_(f_.data)

In [41]:
def normalize_obs(obs):
    return obs.astype('float32') / 255.

In [42]:
env = gym.make('Pong-ram-v0')

In [43]:
# collect data
def collect_one_episode(env, player, max_len=50, discount_factor=0.9, deterministic=False, rendering=False):
    episode = []

    observations = []

    rewards = []
    crewards = []

    actions = []
    action_probs = []

    obs = env.reset()
    for ml in range(max_len):
        if rendering:
            env.render()
            sleep(0.5)
        obs = normalize_obs(obs)

        out_probs = player(torch.from_numpy(obs), normalized=True)
        if deterministic:
            action = numpy.argmax(out_probs.data.numpy())
        else:
            act_dist = Categorical(out_probs)
            action = act_dist.sample().item()
        action_prob = out_probs[action].item()

        observations.append(obs)
        actions.append(action)
        action_probs.append(action_prob)

        obs, reward, done, info = env.step(action)
        
#         if not deterministic:
#             reward = reward - (out_probs.data.numpy() * numpy.log(out_probs.data.numpy())).sum()

        rewards.append(reward)

    rewards = numpy.array(rewards)

    for ri in range(len(rewards)):
        factors = (discount_factor ** numpy.arange(len(rewards)-ri))
        crewards.append(numpy.sum(rewards[ri:] * factors))

    return observations, crewards, actions, action_probs, rewards.sum()

In [44]:
class Buffer:
    def __init__(self, max_items=10000):
        self.max_items = max_items
        self.buffer = []
        
    def add(self, observations, crewards, actions, action_probs):
        new_n = len(observations)
        old_n = len(self.buffer)
        if new_n + old_n > self.max_items:
            del self.buffer[:new_n]
        for o, c, a, p in zip(observations, crewards, actions, action_probs):
            self.buffer.append((o, c, a, p))
            
    def sample(self, n=100):
        idxs = numpy.random.choice(len(self.buffer),n)
        return [self.buffer[ii] for ii in idxs]

In [70]:
# create two models
player = Player(n_in=128, n_hid=32, n_out=6)
player_old = Player(n_in=128, n_hid=32, n_out=6)
copy_params(player, player_old)

In [65]:
# create a value estimator
value = Value(n_in=128, n_hid=32)

In [66]:
opt_player = SGD(player.parameters(), lr=0.0001)
opt_value = Adam(value.parameters(), lr=0.0001)

In [67]:
# initialize replay buffer
replay_buffer = Buffer(max_items=5000)

In [68]:
# o_, c_, a_, ap_, ret_ = collect_one_episode(env, player, max_len=10000, discount_factor=0.9, rendering=True)

In [69]:
n_iter = 1000 #1000
n_collect = 10 #100
n_value = 100 #100
n_policy = 5 #10
disp_iter = 1
val_iter = 1

max_len = 1000
batch_size = 1000

coeff = 0. #1.
ent_coeff = 0. #0.001
discount_factor = 0.9

value_loss = -numpy.Inf
ret = -numpy.Inf
entropy = -numpy.Inf

for ni in range(n_iter):
    if numpy.mod(ni, val_iter) == 0:
        _, _, _, _, ret_ = collect_one_episode(env, player, max_len=max_len, deterministic=True)
        print('Valid run', ret_)
    
    for ci in range(n_collect):
        o_, c_, a_, ap_, ret_ = collect_one_episode(env, player, max_len=max_len, discount_factor=discount_factor)
        replay_buffer.add(o_, c_, a_, ap_)
        if ret == -numpy.Inf:
            ret = ret_
        else:
            ret = 0.9 * ret + 0.1 * ret_
    
    # fit a value function
    for vi in range(n_value):
        opt_value.zero_grad()
        
        batch = replay_buffer.sample(batch_size)
        batch_x = torch.from_numpy(numpy.stack([ex[0] for ex in batch]).astype('float32'))
        batch_y = torch.from_numpy(numpy.stack([ex[1] for ex in batch]).astype('float32'))
        pred_y = value(batch_x).squeeze()
        loss = ((batch_y - pred_y) ** 2)
        
        batch_q = torch.from_numpy(numpy.stack([ex[3] for ex in batch]).astype('float32'))
        logp = torch.log(batch_pi.gather(1, batch_a.long()))

        iw = torch.exp((logp.detach() - torch.log(batch_q)).clamp(max=0.))
    
#         print('iw', iw.mean())
        
        loss = iw * loss
        
        loss = loss.mean()
        
        loss.backward()
        opt_value.step()
        
    if value_loss < 0.:
        value_loss = loss.item()
    else:
        value_loss = 0.9 * value_loss + 0.1 * loss.item()
    
    if numpy.mod(ni, disp_iter) == 0:
        print('# plays', (ni+1) * n_collect, 'return', ret, 'value_loss', value_loss, 'neg entropy', entropy)
    
    # fit a policy
    for pi in range(n_policy):
        opt_player.zero_grad()
        
        batch = replay_buffer.sample(batch_size)
        batch_x = torch.from_numpy(numpy.stack([ex[0] for ex in batch]).astype('float32'))
        batch_r = torch.from_numpy(numpy.stack([ex[1] for ex in batch]).astype('float32')[:,None])
        batch_v = value(batch_x)
        batch_a = torch.from_numpy(numpy.stack([ex[2] for ex in batch]).astype('float32')[:,None])
        batch_q = torch.from_numpy(numpy.stack([ex[3] for ex in batch]).astype('float32'))

        batch_pi = player(batch_x, normalized=True)
        batch_pi_old = player_old(batch_x, normalized=True)
        
        logp = torch.log(batch_pi.gather(1, batch_a.long()))
        logp_old = torch.log(batch_pi_old.gather(1, batch_a.long()))
        
        loss = -((batch_r - batch_v.detach()) * logp)
        
#         print('adv', (batch_r - batch_v).mean().item())

        iw = torch.exp((logp.detach() - torch.log(batch_q)).clamp(max=0.))
    
#         print('iw', iw.mean())
        
        loss = iw * loss
        
        kl = -(batch_pi_old * torch.log(batch_pi)).sum(1)
        ent = -(batch_pi * torch.log(batch_pi)).sum(1)
        
        if entropy == -numpy.Inf:
            entropy = ent.mean().item()
        else:
            entropy = 0.9 * entropy + 0.1 * ent.mean().item()
        
        loss = (loss + coeff * kl - ent_coeff * ent).mean()
        
        loss.backward()
        opt_player.step()
        
    copy_params(player, player_old)

  # This is added back by InteractiveShellApp.init_path()


Valid run -20.0
# plays 10 return -14.546244003000002 value_loss 0.06610212475061417 neg entropy -inf
Valid run -20.0
# plays 20 return -15.139324659880023 value_loss 0.06532078608870506 neg entropy 1.7746279826521874
Valid run -20.0
# plays 30 return -16.18089329567443 value_loss 0.0650927808135748 neg entropy 1.7744690969500816
Valid run -20.0
# plays 40 return -14.504050098460315 value_loss 0.06350601192563773 neg entropy 1.774468991636902
Valid run -20.0
# plays 50 return -14.980427906663396 value_loss 0.06259739116951824 neg entropy 1.7740966458308804
Valid run -20.0
# plays 60 return -15.690164217225906 value_loss 0.06149487390629947 neg entropy 1.7743870604394663
Valid run -20.0
# plays 70 return -16.63365527147517 value_loss 0.06037816206675023 neg entropy 1.7740050872715298
Valid run -20.0
# plays 80 return -15.282565009019107 value_loss 0.05823260366971419 neg entropy 1.7742102274364973
Valid run -20.0


KeyboardInterrupt: 