In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from reward.utils import torch_utils



In [3]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import reward as rw
import reward.utils as U

In [4]:
GAMMA = 0.99

In [5]:
use_cuda = torch.cuda.is_available()
device   = torch.device("cuda" if use_cuda else "cpu")

In [6]:
env = rw.envs.GymEnv("CartPole-v0")
runner = rw.runners.SingleRunner(env)
batcher = rw.batchers.RolloutBatcher(runner=runner, batch_size=512)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [7]:
class PolicyNN(nn.Module):
    def __init__(self, num_inputs, num_outputs, activation=nn.Tanh):
        super().__init__()
        self.activation = activation()
        
        self.hidden = nn.Linear(num_inputs, 64)
        self.out = nn.Linear(64, num_outputs)
        
    def forward(self, x):
        x = self.activation(self.hidden(x))
        return self.out(x)  

In [8]:
class ValueNN(nn.Module):
    def __init__(self, num_inputs, activation=nn.Tanh):
        super().__init__()
        self.activation = activation()
        
        self.hidden = nn.Linear(num_inputs, 64)
        self.out = nn.Linear(64, 1)
        
    def forward(self, x):
        x = self.activation(self.hidden(x))
        return self.out(x)

In [9]:
class CategoricalPolicy(rw.policy.BasePolicy):
    def create_dist(self, state):
        logits = self.nn(state)
        return rw.distributions.Categorical(logits=logits)
    
    def get_action(self, state, step):
        dist = self.create_dist(state)
        return U.to_np(dist.sample())

In [10]:
p_nn = PolicyNN(batcher.state_info.shape[0], batcher.action_info.shape).to(device)
v_nn = ValueNN(batcher.state_info.shape[0]).to(device)
policy = CategoricalPolicy(p_nn)

p_opt = torch.optim.Adam(p_nn.parameters(), lr=1e-3)
v_opt = torch.optim.Adam(v_nn.parameters(), lr=1e-3)

In [11]:
logger = U.Logger('logs/cart_pole/v0-1')
last_logged_step = 0

Writing logs to: logs/cart_pole/v0-1


In [12]:
for batch in batcher.get_batches(max_steps=3e5, get_action_fn=policy.get_action):
    batch = batch.to_tensor()
    # Calculate state value
    state_t = U.join_first_dims(batch.state_t, 2)
    v_t = v_nn(state_t)
    # Calculate return
    ret = U.estimators.discounted_sum_rewards(
        rewards=batch.reward,
        dones=batch.done,
        gamma=GAMMA,
        v_t_last=v_t[-1]
    ).detach()
    batch = batch.concat_batch()
    # Calculate advantage    
    adv = (ret - v_t).detach()
    
    # Calculate policy loss
    dist = policy.create_dist(batch.state_t)
    log_prob = dist.log_prob(batch.action)
    assert ret.shape == log_prob.shape
    p_loss = -(adv * log_prob).mean()
    
    # Calculate value loss
    v_loss = F.mse_loss(v_t, ret)
    
    # Optimize
    p_opt.zero_grad()
    p_loss.backward()
    p_opt.step()
    
    v_opt.zero_grad()
    v_loss.backward()
    v_opt.step()
    
    # Write logs    
    if batcher.num_steps > last_logged_step:
        last_logged_step = batcher.num_steps + 10000
        batcher.write_logs(logger)
        logger.add_log('policy/loss', p_loss)
        logger.add_log('v/loss', v_loss)
        logger.log(step=batcher.num_steps)    

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=300000), HTML(value='')), layout=Layout(displ…


--------------------------------------------------------------
Env/Reward/Episode (New)                              |  25.63
Env/Length/Episode (New)                              |  25.63
Env/Reward/Episode (Last 50)                          |  25.63
Env/Length/Episode (Last 50)                          |  25.63
policy/loss                                           |  10.62
v/loss                                                | 368.68
--------------------------------------------------------------

--------------------------------------------------------------
Env/Reward/Episode (New)                              |  30.22
Env/Length/Episode (New)                              |  30.22
Env/Reward/Episode (Last 50)                          |  34.80
Env/Length/Episode (Last 50)                          |  34.80
policy/loss                                           |  12.24
v/loss                                                | 509.28
-----------------------------------------------------


--------------------------------------------------------------
Env/Reward/Episode (New)                             |  188.75
Env/Length/Episode (New)                             |  188.75
Env/Reward/Episode (Last 50)                         |  192.12
Env/Length/Episode (Last 50)                         |  192.12
policy/loss                                          |   17.56
v/loss                                               | 1431.57
--------------------------------------------------------------

--------------------------------------------------------------
Env/Reward/Episode (New)                             |  190.74
Env/Length/Episode (New)                             |  190.74
Env/Reward/Episode (Last 50)                         |  190.18
Env/Length/Episode (Last 50)                         |  190.18
policy/loss                                          |   20.02
v/loss                                               | 1861.77
-----------------------------------------------------