In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import reward as rw
import reward.utils as U



In [2]:
GAMMA = 0.99
NUM_EPOCHS = 10
PPO_CLIP = 0.2
LR = 1e-3
BATCH_SIZE = 4096

In [3]:
use_cuda = torch.cuda.is_available()
device   = torch.device("cuda" if use_cuda else "cpu")

In [4]:
env = rw.env.GymEnv("HalfCheetah-v2")
runner = rw.runner.SingleRunner(env)
batcher = rw.batcher.RolloutBatcher(runner=runner, batch_size=BATCH_SIZE)

Choosing the latest nvidia driver: /usr/lib/nvidia-390, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-390']
Choosing the latest nvidia driver: /usr/lib/nvidia-390, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-390']


In [5]:
class PolicyNN(nn.Module):
    def __init__(self, num_inputs, num_outputs, activation=nn.Tanh):
        super().__init__()
        self.activation = activation()
        
        self.hidden1 = nn.Linear(num_inputs, 64)
        self.hidden2 = nn.Linear(64, 64)
        self.mean = nn.Linear(64, num_outputs)
        self.log_std = nn.Parameter(torch.zeros(1, num_outputs))
        
    def forward(self, x):
        x = self.activation(self.hidden1(x))        
        x = self.activation(self.hidden2(x))        
        mean = self.mean(x)
        log_std = self.log_std.expand_as(mean)
        return mean, log_std        

In [6]:
class ValueNN(nn.Module):
    def __init__(self, num_inputs, activation=nn.Tanh):
        super().__init__()
        self.activation = activation()
        
        self.hidden1 = nn.Linear(num_inputs, 64)
        self.hidden2 = nn.Linear(64, 64)
        self.out = nn.Linear(64, 1)
        
    def forward(self, x):
        x = self.activation(self.hidden1(x))
        x = self.activation(self.hidden2(x))
        return self.out(x)

In [7]:
class GaussianPolicy(rw.policy.BasePolicy):
    def create_dist(self, state):
        mean, log_std = self.nn(state)
        return rw.distributions.Normal(loc=mean, scale=log_std.exp())
    
    def get_action(self, state, step):
        dist = self.create_dist(state)
        return U.to_np(dist.sample())

In [8]:
p_nn = PolicyNN(batcher.state_space.shape[0], batcher.action_space.shape[0]).to(device)
v_nn = ValueNN(batcher.state_space.shape[0]).to(device)
policy = GaussianPolicy(p_nn)

p_opt = torch.optim.Adam(p_nn.parameters(), lr=LR)
v_opt = torch.optim.Adam(v_nn.parameters(), lr=LR)

In [9]:
logger = U.Logger('/tmp/logs/half_cheetah/v0-0')
last_logged_step = 0

Writing logs to: /tmp/logs/half_cheetah/v0-0


In [10]:
for batch in batcher.get_batches(max_steps=5e5, act_fn=policy.get_action):
    batch = batch.to_tensor()
    # Calculate state value
    state_t = U.join_first_dims(batch.state_t, 2)
    v_t = v_nn(state_t)
    # Calculate return
    ret = U.estimators.discounted_sum_rewards(
        rewards=batch.reward,
        dones=batch.done,
        gamma=GAMMA,
        v_t_last=v_t[-1]
    ).detach()
    batch = batch.concat_batch()
    # Calculate advantage    
    adv = (ret - v_t).detach()
    
    # Old policy
    with torch.no_grad():
        old_dist = policy.create_dist(batch.state_t)
        old_logprob = old_dist.log_prob(batch.action).sum(-1, keepdim=True)      
            
    for _ in range(NUM_EPOCHS):        
        #### Calculate policy loss ####
        # New policy
        new_dist = policy.create_dist(batch.state_t)
        new_logprob = new_dist.log_prob(batch.action).sum(-1, keepdim=True)
        prob_ratio = (new_logprob - old_logprob).exp()
        clipped_prob_ratio = prob_ratio.clamp(min=1 - PPO_CLIP, max=1 + PPO_CLIP)
        
        assert prob_ratio.shape == adv.shape
        assert clipped_prob_ratio.shape == adv.shape
        surrogate = prob_ratio * adv
        clipped_surrogate = clipped_prob_ratio * adv
        
        losses = torch.min(surrogate, clipped_surrogate).squeeze()
        assert len(losses.shape) == 1
        p_loss = -losses.mean()
    
        #### Calculate value loss
        v_t = v_nn(state_t)
        v_loss = F.mse_loss(v_t, ret)

        # Optimize
        p_opt.zero_grad()
        p_loss.backward()
        p_opt.step()

        v_opt.zero_grad()
        v_loss.backward()
        v_opt.step()
    
    # Write logs    
    if batcher.num_steps > last_logged_step:
        last_logged_step = batcher.num_steps + 10000
        batcher.write_logs(logger)
        logger.add_log('policy/loss', p_loss)
        logger.add_log('v/loss', v_loss)
        logger.log(step=batcher.num_steps)      

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=500000), HTML(value='')), layout=Layout(displ…


                         Step 4096                         
--------------------------------------------------------------
SingleRunner/Reward                                  | -671.31
SingleRunner/Length                                  | 1000.00
policy/loss                                          |   56.80
v/loss                                               | 3810.94
--------------------------------------------------------------

                         Step 16384                         
--------------------------------------------------------------
SingleRunner/Reward                                  | -684.39
SingleRunner/Length                                  | 1000.00
policy/loss                                          |   57.45
v/loss                                               | 3820.17
--------------------------------------------------------------

                         Step 28672                         
-----------------------------------------------------------


                        Step 237568                        
--------------------------------------------------------------
SingleRunner/Reward                                 | -1166.22
SingleRunner/Length                                 |  1000.00
policy/loss                                         |    58.69
v/loss                                              |  4194.87
--------------------------------------------------------------

                        Step 249856                        
--------------------------------------------------------------
SingleRunner/Reward                                 | -1250.71
SingleRunner/Length                                 |  1000.00
policy/loss                                         |    64.43
v/loss                                              |  4988.26
--------------------------------------------------------------

                        Step 262144                        
-------------------------------------------------------------


                        Step 471040                        
--------------------------------------------------------------
SingleRunner/Reward                                 | -2113.22
SingleRunner/Length                                 |  1000.00
policy/loss                                         |    91.49
v/loss                                              | 10775.83
--------------------------------------------------------------

                        Step 483328                        
--------------------------------------------------------------
SingleRunner/Reward                                 | -2133.19
SingleRunner/Length                                 |  1000.00
policy/loss                                         |    97.97
v/loss                                              | 12168.51
--------------------------------------------------------------

                        Step 495616                        
-------------------------------------------------------------

In [11]:
%debug