In [1]:
import pdb
import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import reward as rw
import reward.utils as U



In [2]:
ENV = 'InvertedDoublePendulum-v2'
LOG_DIR = 'logs/doublependulum/repar-prioritized-schedule-v0-1'
REPAR = True
REWARD_SCALE = 1.
CLIP_GRAD = float('inf')
GAMMA = 0.99
TARGET_UP_WEIGHT = 0.01
BATCH_SIZE = 256
MAX_STEPS = 80000
LOG_FREQ = 1000
REPLAY_BUFFER_MAXLEN = 80000

In [3]:
pr_weight = U.schedules.linear_schedule(.7, .1, final_step=MAX_STEPS)

In [4]:
use_cuda = torch.cuda.is_available()
device   = torch.device("cuda" if use_cuda else "cpu")

In [5]:
env = rw.envs.GymEnv(ENV)
env = rw.envs.wrappers.ActionBound(env)
runner = rw.runners.SingleRunner(env)
batcher = rw.batchers.PrReplayBatcher(
    runner=runner,
    batch_size=256,
    replay_buffer_maxlen=REPLAY_BUFFER_MAXLEN,
    learning_freq=1,
    grad_steps_per_batch=1,
    pr_weight=pr_weight,
    transforms=[
#         rw.batchers.transforms.StateRunNorm(),        
    ],
)

state_features = batcher.get_state_info().shape[0]
num_actions = batcher.get_action_info().shape[0]

Choosing the latest nvidia driver: /usr/lib/nvidia-390, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-390']
Choosing the latest nvidia driver: /usr/lib/nvidia-390, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-390']
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [6]:
class PolicyNN(nn.Module):
    def __init__(self, num_inputs, num_outputs, hidden_units=256,
                 activation=nn.ReLU, log_std_range=(-20, 2)):
        super().__init__()
        self.log_std_range = log_std_range
        
        layers = []
        layers += [nn.Linear(num_inputs, hidden_units), activation()]
        layers += [nn.Linear(hidden_units, hidden_units), activation()]
        self.layers = nn.Sequential(*layers)
        
        self.mean = nn.Linear(hidden_units, num_outputs)
        self.mean.weight.data.uniform_(-3e-3, 3e-3)
        self.mean.bias.data.uniform_(-3e-3, 3e-3)
        
        self.log_std = nn.Linear(hidden_units, num_outputs)
        self.log_std.weight.data.uniform_(-3e-3, 3e-3)
        self.log_std.bias.data.uniform_(-3e-3, 3e-3)
        
    def forward(self, x):
        x = self.layers(x)
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(*self.log_std_range)
        return mean, log_std        

In [7]:
class ValueNN(nn.Module):
    def __init__(self, num_inputs, hidden_units=256, activation=nn.ReLU):
        super().__init__()
        
        layers = []
        layers += [nn.Linear(num_inputs, hidden_units), activation()]
        layers += [nn.Linear(hidden_units, hidden_units), activation()]
        final_layer = nn.Linear(hidden_units, 1)
        final_layer.weight.data.uniform_(-3e-3, 3e-3)
        final_layer.bias.data.uniform_(-3e-3, 3e-3)
        layers += [final_layer]
        
        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.layers(x)

In [8]:
class QValueNN(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_units=256, activation=nn.ReLU):
        super().__init__()
        
        layers = []
        layers += [nn.Linear(num_inputs + num_actions, hidden_units), activation()]
        layers += [nn.Linear(hidden_units, hidden_units), activation()]
        final_layer = nn.Linear(hidden_units, 1)
        final_layer.weight.data.uniform_(-3e-3, 3e-3)
        final_layer.bias.data.uniform_(-3e-3, 3e-3)
        layers += [final_layer]
        
        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        state, action = x
        x = torch.cat([state, action], dim=1)
        return self.layers(x)

In [9]:
class TanhNormalPolicy(rw.policy.BasePolicy):
    def create_dist(self, state):
        parameters = self.nn(state)
        mean, log_std = parameters
        return rw.distributions.TanhNormal(loc=mean, scale=log_std.exp())

    def get_action(self, state, step):
        dist = self.create_dist(state=state)
        action = U.to_np(dist.sample())
        assert not np.isnan(action).any()
        return action

In [10]:
p_nn = PolicyNN(num_inputs=state_features, num_outputs=num_actions).to(device)
v_nn = ValueNN(num_inputs=state_features).to(device)
v_nn_target = ValueNN(num_inputs=state_features).to(device).eval()
q1_nn = QValueNN(num_inputs=state_features, num_actions=num_actions).to(device)
q2_nn = QValueNN(num_inputs=state_features, num_actions=num_actions).to(device)

In [11]:
U.copy_weights(from_nn=v_nn, to_nn=v_nn_target, weight=1.)

In [12]:
policy = TanhNormalPolicy(nn=p_nn)

In [13]:
p_opt = torch.optim.Adam(p_nn.parameters(), lr=3e-4)
v_opt = torch.optim.Adam(v_nn.parameters(), lr=3e-4)
q1_opt = torch.optim.Adam(q1_nn.parameters(), lr=3e-4)
q2_opt = torch.optim.Adam(q2_nn.parameters(), lr=3e-4)

In [14]:
logger = U.Logger(LOG_DIR)

Writing logs to: logs/doublependulum/repar-prioritized-schedule-v0-1


In [15]:
batcher.populate(n=1000, get_action_fn=policy.get_action)

Populating Replay Buffer...


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




In [None]:
for batch in batcher.get_batches(MAX_STEPS, policy.get_action):
    batch = batch.to_tensor().concat_batch()

    ##### Calculate losses ######
    q1_batch = q1_nn((batch.state_t, batch.action))
    q2_batch = q2_nn((batch.state_t, batch.action))
    v_batch = v_nn(batch.state_t)

    dist = policy.create_dist(batch.state_t)
    if REPAR:
        action, pre_tanh_action = dist.rsample_with_pre()
    else:
        action, pre_tanh_action = dist.sample_with_pre()
    log_prob = dist.log_prob_pre(pre_tanh_action).sum(-1, keepdim=True)
    log_prob /= float(REWARD_SCALE)

    # Q loss
    v_target_tp1 = v_nn_target(batch.state_tp1)
    q_t_next = U.estimators.td_target(rewards=batch.reward, dones=batch.done,
                                         v_tp1=v_target_tp1, gamma=GAMMA)
    q1_loss = F.mse_loss(q1_batch, q_t_next.detach())
    q2_loss = F.mse_loss(q2_batch, q_t_next.detach())

    # V loss    
    q1_new_t = q1_nn((batch.state_t, action))
    q2_new_t = q2_nn((batch.state_t, action))
    q_new_t = torch.min(q1_new_t, q2_new_t)
    next_value = q_new_t - log_prob
    v_loss = F.mse_loss(v_batch, next_value.detach())

    # Policy loss    
    if REPAR:        
        p_loss = (log_prob - q_new_t).mean()
    else:
        next_log_prob = q_new_t - v_batch
        p_loss = (log_prob * (log_prob - next_log_prob).detach()).mean()
    # Policy regularization losses
    mean_loss = 1e-3 * dist.loc.pow(2).mean()
    log_std_loss = 1e-3 * dist.scale.log().pow(2).mean()
    pre_tanh_loss = 0 * pre_tanh_action.pow(2).sum(1).mean()
    # Combine all losses
    p_loss += mean_loss + log_std_loss + pre_tanh_loss

    ###### Optimize ######
    q1_opt.zero_grad()
    q1_loss.backward()
#     torch.nn.utils.clip_grad_norm_(q1_nn.parameters(), CLIP_GRAD)
    q1_grad = U.mean_grad(q1_nn)
    q1_opt.step()

    q2_opt.zero_grad()
    q2_loss.backward()
#     torch.nn.utils.clip_grad_norm_(q2_nn.parameters(), CLIP_GRAD)
    q2_grad = U.mean_grad(q2_nn)
    q2_opt.step()

    v_opt.zero_grad()
    v_loss.backward()
#     torch.nn.utils.clip_grad_norm_(v_nn.parameters(), CLIP_GRAD)
    v_grad = U.mean_grad(v_nn)
    v_opt.step()

    p_opt.zero_grad()
    p_loss.backward()
#     torch.nn.utils.clip_grad_norm_(p_nn.parameters(), CLIP_GRAD)
    p_grad = U.mean_grad(p_nn)
    p_opt.step()

    ###### Update target value network ######
    U.copy_weights(from_nn=v_nn, to_nn=v_nn_target, weight=TARGET_UP_WEIGHT)
    
    ###### Update replay batcher priorities #######
    idx = U.to_np(batch.idx).astype('int')
    td_error = U.to_np((q_new_t - q_t_next).abs())
    batcher.update_pr(idx=idx, pr=td_error)

    ###### Write logs ######
    if batcher.num_steps % LOG_FREQ == 0 and batcher.runner.rewards:
        batcher.write_logs(logger)    

        logger.add_log('policy/loss', p_loss)
        logger.add_log('v/loss', v_loss)
        logger.add_log('q1/loss', q1_loss)
        logger.add_log('q2/loss', q2_loss)

        logger.add_log('policy/grad', p_grad)
        logger.add_log('v/grad', v_grad)
        logger.add_log('q1/grad', q1_grad)
        logger.add_log('q2/grad', q2_grad)

        logger.add_histogram('policy/log_prob', log_prob)
        logger.add_histogram('policy/mean', dist.loc)
        logger.add_histogram('policy/std', dist.scale.exp())
        logger.add_histogram('v/value', v_batch)
        logger.add_histogram('q1/value', q1_batch)
        logger.add_histogram('q2/value', q2_batch)

        logger.log(step=batcher.num_steps)

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=80000), HTML(value='')), layout=Layout(displa…


                         Step 2000                         
--------------------------------------------------------------
Env/Reward/Episode (New)                              |  61.68
Env/Length/Episode (New)                              |   6.77
Env/Reward/Episode (Last 50)                          | 101.90
Env/Length/Episode (Last 50)                          |  11.06
policy/loss                                           | -49.51
v/loss                                                |   3.02
q1/loss                                               |  21.42
q2/loss                                               |  19.52
policy/grad                                           |  -0.01
v/grad                                                |  -0.22
q1/grad                                               |   0.01
q2/grad                                               |  -0.07
--------------------------------------------------------------

                         Step 3000                      


                         Step 11000                         
--------------------------------------------------------------
Env/Reward/Episode (New)                             | 6515.96
Env/Length/Episode (New)                             |  696.50
Env/Reward/Episode (Last 50)                         | 1041.54
Env/Length/Episode (Last 50)                         |  111.54
policy/loss                                          | -256.40
v/loss                                               |   43.29
q1/loss                                              |  565.89
q2/loss                                              |  568.88
policy/grad                                          |   -1.16
v/grad                                               |   -0.62
q1/grad                                              |   -0.44
q2/grad                                              |   -1.06
--------------------------------------------------------------

                         Step 12000                    


                         Step 20000                         
--------------------------------------------------------------
Env/Reward/Episode (New)                             | 1198.53
Env/Length/Episode (New)                             |  128.33
Env/Reward/Episode (Last 50)                         | 2577.65
Env/Length/Episode (Last 50)                         |  275.70
policy/loss                                          | -338.96
v/loss                                               |   37.86
q1/loss                                              |  687.43
q2/loss                                              |  701.87
policy/grad                                          |    2.26
v/grad                                               |   -0.66
q1/grad                                              |    0.19
q2/grad                                              |   -1.18
--------------------------------------------------------------

                         Step 21000                    


                         Step 29000                         
--------------------------------------------------------------
Env/Reward/Episode (New)                             | 3440.26
Env/Length/Episode (New)                             |  368.00
Env/Reward/Episode (Last 50)                         | 4176.35
Env/Length/Episode (Last 50)                         |  446.62
policy/loss                                          | -384.88
v/loss                                               |   36.61
q1/loss                                              |  116.97
q2/loss                                              |  147.22
policy/grad                                          |   -0.60
v/grad                                               |   -0.79
q1/grad                                              |   -1.77
q2/grad                                              |    0.22
--------------------------------------------------------------

                         Step 30000                    