In [1]:
# Taken from this repo: https://github.com/MEfeTiryaki/trpo

import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import math
import numpy as np
import random
from collections import namedtuple
from collections import deque
from itertools import count
import gym
import scipy.optimize
torch.utils.backcompat.broadcast_warning.enabled = True
torch.utils.backcompat.keepdim_warning.enabled = True
torch.set_default_tensor_type('torch.DoubleTensor')

In [2]:
class Policy(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)
        
        self.action_mean = nn.Linear(64, num_outputs)
        self.action_mean.weight.data.mul_(0.1)
        self.action_mean.bias.data.mul_(0.0)
        
        self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs))
        
        self.saved_actions = []
        
        self.rewards = []
        self.final_value = 0
        
    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))
        
        action_mean = self.action_mean(x)
        action_log_std = self.action_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)
        
        return action_mean, action_log_std, action_std
    
class Value(nn.Module):
    def __init__(self, num_inputs):
        super(Value, self).__init__()
        self.affine1 = nn.Linear(num_inputs, 64)
        self.affine2 = nn.Linear(64, 64)
        self.value_head = nn.Linear(64, 1)
        self.value_head.weight.data.mul_(0.1)
        self.value_head.bias.data.mul_(0.0)
        
    def forward(self, x):
        x = torch.tanh(self.affine1(x))
        x = torch.tanh(self.affine2(x))
        
        state_values = self.value_head(x)
        return state_values

In [3]:
def normal_entropy(std):
    var = std.pow(2)
    entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi)
    return entropy.sum(1, keepdim=True)

def normal_log_density(x, mean, log_std, std):
    var = std.pow(2)
    log_density = -(x - mean).pow(2) / (
        2 * var) - 0.5 * math.log(2 * math.pi) - log_std
    return log_density.sum(1, keepdim=True)

def get_flat_params_from(model):
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    flat_params = torch.cat(params)
    return flat_params

def set_flat_params_to(model, flat_params):
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(
            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size
        
def get_flat_grad_from(net, grad_grad=False):
    grads = []
    for param in net.parameters():
        if grad_grad:
            grads.append(param.grad.grad.view(-1))
        else:
            grads.append(param.grad.view(-1))
    flat_grad = torch.cat(grads)
    return flat_grad
    

In [4]:
def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b - Avp(x)
    p = r
    rdotr = torch.dot(r, r)
    
    for i in range(nsteps):
        _Avp = Avp(p)
        alpha = rdotr / torch.dot(p, _Avp)
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x

In [5]:
# taken from https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb
Transition = namedtuple('Transition', ('state','action', 'mask', 'next_state', 'reward'))

class Memory(object):
    def __init__(self):
        self.memory = []
    
    def push(self, *args):
        """saves transition"""
        self.memory.append(Transition(*args))
    
    def sample(self):
        return Transition(*zip(*self.memory))
    
    def __len__(self):
        return len(self.memory)

In [6]:
# from https://github.com/joschu/modular_rl
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
    def __init__(self, shape):
        self._n = 0
        self._M = np.zeros(shape)
        self._S = np.zeros(shape)
        
    def push(self, x):
        x = np.asarray(x)
        assert x.shape == self._M.shape
        self._n += 1
        if self._n == 1:
            self._M[...] = x
        else:
            oldM = self._M.copy()
            self._M[...] = oldM + (x - oldM) / self._n
            self._S[...] = self._S + (x - oldM) * (x - self._M)
            
    @property
    def n(self):
        return self_n
        
    @property
    def mean(self):
        return self._M
    
    @property
    def var(self):
        return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
    
    @property
    def std(self):
        return np.sqrt(self.var)
    
    @property
    def shape(self):
        return self._M.shape

In [7]:
class ZFilter:
    """
    y = (x - mean) / std
    using running estimates of mean, std
    """
    
    def __init__(self, shape, demean=True, destd=True, clip=10.0):
        self.demean= demean
        self.destd = destd
        self.clip = clip
        
        self.rs = RunningStat(shape)
        
    def __call__(self, x, update=True):
        if update:
            self.rs.push(x)
        if self.demean:
            x = x - self.rs.mean
        if self.destd:
            x = x / (self.rs.std + 1e-8)
        if self.clip:
            x = np.clip(x, -self.clip, self.clip)
        return x
    
    def output_shape(self, input_space):
        return input_space.shape

In [8]:
def linesearch(model, f, x, fullstep, expected_improve_rate, max_backtracks=10,
              accept_ratio=.1):
    fval = f(True).data
    print("fval before", fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())
        
        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            print("fval after", newfval.item())
            return True, xnew
    return False, x        

In [9]:
def trpo_step(model, get_loss, get_kl, max_kl, damping):
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
    
    def Fvp(v):
        kl = get_kl()
        kl = kl.mean()
        
        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
        
        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
        
        return flat_grad_grad_kl + v * damping
    
    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)
    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
    
    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]
    
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))
    
    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep, neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)
    
    return loss
        
        

In [10]:
def select_action(state):
    state = torch.from_numpy(state).unsqueeze(0)
    action_mean, _, action_std = policy_net(Variable(state))
    action = torch.normal(action_mean, action_std)
    return action


def update_params(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))
    
    returns = torch.Tensor(actions.size(0), 1)
    deltas = torch.Tensor(actions.size(0), 1)
    advantages = torch.Tensor(actions.size(0), 1)
    
    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i]
        
        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]
        
    targets = Variable(returns)
        
    def get_value_loss(flat_params):
        set_flat_params_to(value_net, torch.Tensor(flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
                
        values_ = value_net(Variable(states))
        value_loss = (values_ - targets).pow(2).mean() # mean squared error
        
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(), get_flat_grad_from(value_net).data.double().numpy())
    
    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, 
                                get_flat_params_from(value_net).double().numpy(), maxiter=25)
    set_flat_params_to(value_net, torch.Tensor(flat_params))
    
    advantages = (advantages - advantages.mean()) / advantages.std()
    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
            
        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()
    
    def get_kl():
        mean1, log_std1, std1 = policy_net(Variable(states))
        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)
    
    trpo_step(policy_net, get_loss, get_kl, max_kl, damping)  

In [11]:
env_name = "BipedalWalker-v3"
seed = 543
gamma = 0.995
tau = 0.97
l2_reg = 0.0001
max_kl = 0.01
damping = .1
batch_size = 15000
log_interval = 1
render = True

env = gym.make(env_name)
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
env.seed(seed)
torch.manual_seed(seed)
policy_net = Policy(num_inputs, num_actions)
value_net = Value(num_inputs)

running_state = ZFilter((num_inputs,), clip=5)
running_reward = ZFilter((1,), demean=False, clip=10)



In [12]:
for i_episode in count(1):
    memory = Memory()
    
    num_steps = 0
    reward_batch = 0
    num_episodes = 0
    while num_steps < batch_size:
        state = env.reset()
        state = running_state(state)
        reward_sum = 0
        
        for t in range(10000):
            action = select_action(state)
            action = action.data[0].numpy()
            next_state, reward, done, _ = env.step(action)
            reward_sum += reward
            
            next_state = running_state(next_state) # need to understand this line
            
            mask = 1
            if done:
                mask = 0
            
            memory.push(state, np.array([action]), mask, next_state, reward)
            
            if render:
                env.render()
            if done:
                break
                
            state = next_state
        
        num_steps += (t-1)
        num_episodes += 1
        reward_batch += reward_sum
        
    reward_batch /= num_episodes
    batch = memory.sample()
    update_params(batch)
    
    if i_episode % log_interval == 0:
        print('Episode {}\tLast Reward: {}\t Average Reward {:.2f}'.format(
        i_episode, reward_sum, reward_batch))

('lagrange multiplier:', tensor(0.5765), 'grad_norm:', tensor(0.0542))
fval before -1.7532360391341687e-18
a/e/r 0.010656751657885214 0.011640831390811483 0.9154631056933753
fval after -0.010656751657885216
Episode 1	Last Reward: -114.51168962065101	 Average Reward -114.72
('lagrange multiplier:', tensor(0.3720), 'grad_norm:', tensor(0.0356))
fval before -9.806724436172494e-18
a/e/r 0.007940812518537918 0.007586430311059919 1.0467126425667366
fval after -0.007940812518537928
Episode 2	Last Reward: -124.593293644332	 Average Reward -112.17
('lagrange multiplier:', tensor(0.5592), 'grad_norm:', tensor(0.0616))
fval before -1.0333800612627887e-17
a/e/r 0.010557414354480339 0.011131448225994989 0.9484313397627702
fval after -0.010557414354480349
Episode 3	Last Reward: -101.70240413587075	 Average Reward -111.17
('lagrange multiplier:', tensor(0.6933), 'grad_norm:', tensor(0.0918))
fval before -2.3543496877405573e-18
a/e/r 0.014395014155051459 0.013882288213480845 1.036933820540674
fval aft

('lagrange multiplier:', tensor(0.6470), 'grad_norm:', tensor(0.1322))
fval before -5.861977570020827e-17
a/e/r 0.012426681880380392 0.012694483984153333 0.9789040575333947
fval after -0.012426681880380451
Episode 32	Last Reward: 42.87545658730558	 Average Reward 37.44
('lagrange multiplier:', tensor(0.5821), 'grad_norm:', tensor(0.1408))
fval before -2.2648549702353192e-17
a/e/r 0.011257061700635338 0.011917952668799671 0.9445466023795768
fval after -0.011257061700635361
Episode 33	Last Reward: 52.21992763319794	 Average Reward 52.11
('lagrange multiplier:', tensor(0.6273), 'grad_norm:', tensor(0.1689))
fval before 3.885780586188048e-18
a/e/r 0.01235365259498061 0.012567956794145208 0.9829483660172644
fval after -0.012353652594980606
Episode 34	Last Reward: 46.91998688753894	 Average Reward 60.17
('lagrange multiplier:', tensor(0.4192), 'grad_norm:', tensor(0.0783))
fval before 4.6388725520618364e-18
a/e/r 0.0089455639270449 0.008679083810617 1.0307037150743859
fval after -0.008945563

('lagrange multiplier:', tensor(0.4357), 'grad_norm:', tensor(0.1399))
fval before -6.439293542825908e-18
a/e/r 0.008864121347936144 0.009210307490563167 0.9624131829494592
fval after -0.008864121347936151
Episode 63	Last Reward: 148.90854949314382	 Average Reward 154.36
('lagrange multiplier:', tensor(0.4837), 'grad_norm:', tensor(0.1481))
fval before 4.440892098500626e-18
a/e/r 0.009564948549295018 0.0101220188285948 0.9449645086881232
fval after -0.009564948549295012
Episode 64	Last Reward: 154.41602001176022	 Average Reward 158.35
('lagrange multiplier:', tensor(0.4849), 'grad_norm:', tensor(0.1437))
fval before -8.881784197001253e-19
a/e/r 0.00997601704904702 0.010305265308316972 0.9680504820187182
fval after -0.009976017049047022
Episode 65	Last Reward: 154.95999772397587	 Average Reward 160.73
('lagrange multiplier:', tensor(0.4925), 'grad_norm:', tensor(0.1555))
fval before 2.6645352591003758e-18
a/e/r 0.009823998368318822 0.009822608036417262 1.0001415440681747
fval after -0.0

KeyboardInterrupt: 