In [1]:
import pdb
import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import reward as rw
import reward.utils as U



In [2]:
# --- Environment / logging ---
ENV = 'Humanoid-v2'
LOG_DIR = '/tmp/logs/humanoid/paper-v0-1'

# --- Algorithm switches and hyperparameters ---
# True: reparameterized policy gradient (rsample); False: score-function
# estimator with V(s) as baseline (see the training loop).
REPAR = True
# Log-probabilities are divided by this in the training loop, so the entropy
# terms are weighted by 1 / REWARD_SCALE.
REWARD_SCALE = 20.
# Only referenced by commented-out gradient-clipping calls; inf means "no clipping".
CLIP_GRAD = float('inf')
GAMMA = 0.99               # discount factor passed to the TD target
TARGET_UP_WEIGHT = 0.005   # weight passed to U.copy_weights for the target value-net update
BATCH_SIZE = 256           # replay minibatch size
# Total number of environment steps; kept as an int because it is a step count.
MAX_STEPS = int(40e6)

In [3]:
# Train on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [4]:
# Gym environment with actions bounded by the ActionBound wrapper, driven by a
# single runner that feeds a replay-buffer batcher.
env = rw.env.GymEnv(ENV)
env = rw.env.wrappers.ActionBound(env)
runner = rw.runner.SingleRunner(env)
batcher = rw.batcher.ReplayBatcher(
    runner=runner,
    batch_size=BATCH_SIZE,   # use the config constant instead of a duplicated literal
    maxlen=int(1e6),         # replay capacity; an integer count
    learning_freq=1,
    grad_steps_per_batch=1,
    # No batch transforms. rw.batcher.transforms.StateRunNorm() can be added
    # here (presumably running state normalization -- confirm in `reward`).
    transforms=[],
)

# Dimensionality of the state vector and of the action vector.
s_features = batcher.s_space.shape[0]
num_acs = batcher.ac_space.shape[0]

Choosing the latest nvidia driver: /usr/lib/nvidia-390, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-390']
Choosing the latest nvidia driver: /usr/lib/nvidia-390, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-390']
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [5]:
class PolicyNN(nn.Module):
    """Two-hidden-layer MLP that outputs the parameters of a Gaussian policy.

    Args:
        num_inputs: size of the state feature vector.
        num_outputs: number of action dimensions.
        hidden_units: width of each of the two hidden layers.
        activation: activation module class instantiated after each hidden layer.
        log_std_range: bounds unpacked into ``clamp`` for the log-std output.

    ``forward`` returns a ``(mean, log_std)`` pair, each with ``num_outputs``
    features.
    """

    def __init__(self, num_inputs, num_outputs, hidden_units=256,
                 activation=nn.ReLU, log_std_range=(-20, 2)):
        super().__init__()
        self.log_std_range = log_std_range

        # Shared trunk: indices 0 and 2 are the Linear layers (state_dict keys
        # "layers.0.*" / "layers.2.*").
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, hidden_units), activation(),
            nn.Linear(hidden_units, hidden_units), activation(),
        )

        # Output heads start with tiny weights/biases so initial outputs are small.
        self.mean = self._small_linear(hidden_units, num_outputs)
        self.log_std = self._small_linear(hidden_units, num_outputs)

    @staticmethod
    def _small_linear(in_features, out_features, bound=3e-3):
        """Linear layer whose weight and bias are re-drawn from U(-bound, bound)."""
        layer = nn.Linear(in_features, out_features)
        layer.weight.data.uniform_(-bound, bound)
        layer.bias.data.uniform_(-bound, bound)
        return layer

    def forward(self, x):
        features = self.layers(x)
        bounded_log_std = self.log_std(features).clamp(*self.log_std_range)
        return self.mean(features), bounded_log_std

In [6]:
class ValueNN(nn.Module):
    """State-value network: two hidden layers followed by a single scalar output.

    Args:
        num_inputs: size of the state feature vector.
        hidden_units: width of each hidden layer.
        activation: activation module class instantiated after each hidden layer.
    """

    def __init__(self, num_inputs, hidden_units=256, activation=nn.ReLU):
        super().__init__()
        trunk = [
            nn.Linear(num_inputs, hidden_units), activation(),
            nn.Linear(hidden_units, hidden_units), activation(),
        ]
        # Output layer re-initialized with tiny weights so initial values are near 0.
        head = nn.Linear(hidden_units, 1)
        for param in (head.weight, head.bias):
            param.data.uniform_(-3e-3, 3e-3)
        self.layers = nn.Sequential(*trunk, head)

    def forward(self, x):
        return self.layers(x)

In [7]:
class QValueNN(nn.Module):
    """Action-value network Q(s, a) on the concatenation of state and action.

    Args:
        num_inputs: size of the state feature vector.
        num_acs: number of action dimensions.
        hidden_units: width of each hidden layer.
        activation: activation module class instantiated after each hidden layer.

    ``forward`` takes a ``(state, action)`` pair and concatenates along dim 1,
    so both tensors are expected to be 2-D ``(batch, features)``.
    """

    def __init__(self, num_inputs, num_acs, hidden_units=256, activation=nn.ReLU):
        super().__init__()
        trunk = [
            nn.Linear(num_inputs + num_acs, hidden_units), activation(),
            nn.Linear(hidden_units, hidden_units), activation(),
        ]
        # Output layer re-initialized with tiny weights so initial Q-values are near 0.
        head = nn.Linear(hidden_units, 1)
        for param in (head.weight, head.bias):
            param.data.uniform_(-3e-3, 3e-3)
        self.layers = nn.Sequential(*trunk, head)

    def forward(self, x):
        s, ac = x
        return self.layers(torch.cat((s, ac), dim=1))

In [8]:
class TanhNormalPolicy(rw.policy.BasePolicy):
    """Stochastic policy whose actions are tanh-squashed Normal samples.

    ``self.nn`` (presumably stored by BasePolicy from the ``nn=`` constructor
    argument -- TODO confirm) must map states to a ``(mean, log_std)`` pair,
    e.g. ``PolicyNN``.
    """

    def create_dist(self, s):
        """Build the TanhNormal action distribution for states ``s``.

        Differentiable w.r.t. the network parameters; used by the training loop.
        """
        mean, log_std = self.nn(s)
        return rw.distributions.TanhNormal(loc=mean, scale=log_std.exp())

    def get_ac(self, s, step):
        """Sample an action for acting in the environment, as a numpy array.

        ``step`` is part of the act-function interface and is unused here.
        The forward pass runs under ``torch.no_grad()``: the sampled action is
        converted straight to numpy, so recording an autograd graph for every
        environment step is pure overhead.
        """
        with torch.no_grad():
            dist = self.create_dist(s=s)
            ac = U.to_np(dist.sample())
        assert not np.isnan(ac).any(), 'policy produced NaN actions'
        return ac

In [9]:
# Actor, value network plus its target copy (kept in eval mode), and twin Q critics.
p_nn = PolicyNN(num_inputs=s_features, num_outputs=num_acs).to(device)
v_nn, v_nn_target = (ValueNN(num_inputs=s_features).to(device) for _ in range(2))
v_nn_target.eval()
q1_nn, q2_nn = (QValueNN(num_inputs=s_features, num_acs=num_acs).to(device)
                for _ in range(2))

In [10]:
# Initialize the target value network from v_nn (weight=1.0, presumably a full copy).
U.copy_weights(to_nn=v_nn_target, from_nn=v_nn, weight=1.0)

In [11]:
# Wrap the policy network so the batcher can query actions via policy.get_ac.
policy = TanhNormalPolicy(nn=p_nn)

In [12]:
# One Adam optimizer per network, all sharing the same learning rate.
p_opt, v_opt, q1_opt, q2_opt = (
    torch.optim.Adam(net.parameters(), lr=3e-4)
    for net in (p_nn, v_nn, q1_nn, q2_nn)
)

In [13]:
# Logger that writes training summaries under LOG_DIR.
logger = U.Logger(LOG_DIR)

Writing logs to: /tmp/logs/humanoid/paper-v0-1


In [14]:
# Warm up the replay buffer (n=1000) using the untrained policy before learning starts.
batcher.populate(act_fn=policy.get_ac, n=1000)

Populating Replay Buffer...


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




In [15]:
LOG_FREQ = 4000  # environment steps between log writes


def _optimize_step(opt, loss, net):
    """Run one optimizer update of ``net`` on ``loss``.

    Zeroes ``opt``'s gradients, backpropagates ``loss``, records
    ``U.mean_grad(net)`` (measured before the step, for logging), then calls
    ``opt.step()``. Returns the recorded mean gradient.
    """
    opt.zero_grad()
    loss.backward()
    # Gradient clipping is disabled (CLIP_GRAD is inf). To enable:
    # torch.nn.utils.clip_grad_norm_(net.parameters(), CLIP_GRAD)
    grad = U.mean_grad(net)
    opt.step()
    return grad


for batch in batcher.get_batches(MAX_STEPS, policy.get_ac):
    batch = batch.to_tensor().concat_batch()

    ##### Calculate losses ######
    q1_batch = q1_nn((batch.s, batch.ac))
    q2_batch = q2_nn((batch.s, batch.ac))
    v_batch = v_nn(batch.s)

    dist = policy.create_dist(batch.s)
    if REPAR:
        # Reparameterized sample: `ac` is differentiable w.r.t. the policy.
        ac, pre_tanh_ac = dist.rsample_with_pre()
    else:
        ac, pre_tanh_ac = dist.sample_with_pre()
    # Dividing by REWARD_SCALE weights the entropy terms below by 1 / REWARD_SCALE.
    log_prob = dist.log_prob_pre(pre_tanh_ac).sum(-1, keepdim=True) / float(REWARD_SCALE)

    # Q loss: both critics regress onto a TD target built from the target value net.
    v_target_tp1 = v_nn_target(batch.sn)
    q_value_tp1 = U.estimators.td_target(rs=batch.r, ds=batch.d,
                                         v_tp1=v_target_tp1, gamma=GAMMA)
    q1_loss = F.mse_loss(q1_batch, q_value_tp1.detach())
    q2_loss = F.mse_loss(q2_batch, q_value_tp1.detach())

    # V loss: regress V(s) onto min(Q1, Q2)(s, a ~ pi) - log_prob.
    q1_new_t = q1_nn((batch.s, ac))
    q2_new_t = q2_nn((batch.s, ac))
    q_new_t = torch.min(q1_new_t, q2_new_t)
    next_value = q_new_t - log_prob
    v_loss = F.mse_loss(v_batch, next_value.detach())

    # Policy loss
    if REPAR:
        p_loss = (log_prob - q_new_t).mean()
    else:
        # Score-function estimator with V(s) as baseline.
        next_log_prob = q_new_t - v_batch
        p_loss = (log_prob * (log_prob - next_log_prob).detach()).mean()
    # Policy regularization losses
    mean_loss = 1e-3 * dist.loc.pow(2).mean()
    log_std_loss = 1e-3 * dist.scale.log().pow(2).mean()
    pre_tanh_loss = 0 * pre_tanh_ac.pow(2).sum(1).mean()  # weight 0: currently disabled
    # Combine all losses (out-of-place, same summation order as `p_loss += ...`)
    p_loss = p_loss + (mean_loss + log_std_loss + pre_tanh_loss)

    ###### Optimize ######
    # Update the policy FIRST. With REPAR, p_loss backpropagates through
    # q1_nn/q2_nn, so stepping the critics beforehand modifies weights saved in
    # p_loss's graph (stale-weight gradients on old PyTorch, an in-place
    # modification RuntimeError on newer releases). The policy step touches no
    # tensor the other losses still need, and each critic's zero_grad() clears
    # the gradient p_loss left on it.
    p_grad = _optimize_step(p_opt, p_loss, p_nn)
    q1_grad = _optimize_step(q1_opt, q1_loss, q1_nn)
    q2_grad = _optimize_step(q2_opt, q2_loss, q2_nn)
    v_grad = _optimize_step(v_opt, v_loss, v_nn)

    ###### Update target value network ######
    U.copy_weights(from_nn=v_nn, to_nn=v_nn_target, weight=TARGET_UP_WEIGHT)

    ###### Write logs ######
    if batcher.num_steps % LOG_FREQ == 0 and batcher.runner.rs:
        batcher.write_logs(logger)

        logger.add_log('policy/loss', p_loss)
        logger.add_log('v/loss', v_loss)
        logger.add_log('q1/loss', q1_loss)
        logger.add_log('q2/loss', q2_loss)

        logger.add_log('policy/grad', p_grad)
        logger.add_log('v/grad', v_grad)
        logger.add_log('q1/grad', q1_grad)
        logger.add_log('q2/grad', q2_grad)

        logger.add_histogram('policy/log_prob', log_prob)
        logger.add_histogram('policy/mean', dist.loc)
        # dist.scale already is the std (the distribution is built with
        # scale=log_std.exp()), so it is logged as-is.
        logger.add_histogram('policy/std', dist.scale)
        logger.add_histogram('v/value', v_batch)
        logger.add_histogram('q1/value', q1_batch)
        logger.add_histogram('q2/value', q2_batch)

        logger.log(step=batcher.num_steps)

HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=40000000), HTML(value='')), layout=Layout(dis…




ValueError: <class 'list'> not suported