# Trust Region Policy Optimization Practice

In [1]:
import numpy as np
import time
import csv
import torch
import os
import copy
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Independent
from torch.distributions.normal import Normal
from torch.optim import Adam
from itertools import chain
from memory import OnPolicyMemory
from utils import cg, fisher_vector_product, backtracking_line_search, update_model, flat_params

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('current device : ', device)

current device :  cpu


# 0. Network Architectures

In [3]:
class Actor(nn.Module):
    """Diagonal-Gaussian policy network f_phi(s) -> (mu, sigma).

    Two tanh hidden layers feed two linear heads. One head gives the action
    mean. The other gives log(sigma), which is exponentiated so sigma > 0.
    """

    def __init__(self, obs_dim, act_dim, hidden1, hidden2):
        super().__init__()
        # Attribute names and registration order (fc1..fc4) define the
        # state_dict keys and the parameters() order. Other code copies
        # weights between Actor instances via load_state_dict and flattens
        # parameters, so keep both stable.
        self.fc1 = nn.Linear(obs_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, act_dim)  # mean head (mu)
        self.fc4 = nn.Linear(hidden2, act_dim)  # log-std head (log sigma)

    def forward(self, obs):
        """Return (mu, sigma), each with trailing dimension act_dim."""
        features = torch.tanh(self.fc2(torch.tanh(self.fc1(obs))))
        return self.fc3(features), self.fc4(features).exp()

    def log_prob(self, obs, act):
        """Log-density of `act` under the policy at `obs`, summed over action dims."""
        mu, sigma = self(obs)
        # Independent(..., 1) treats the last axis as one event, so the
        # per-dimension log-densities are summed.
        return Independent(Normal(mu, sigma), 1).log_prob(act)

class Critic(nn.Module):
    """State-value network V(s; theta): two tanh hidden layers, one scalar output."""

    def __init__(self, obs_dim, hidden1, hidden2):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)  # output has trailing dimension 1

    def forward(self, obs):
        """Return V(obs) with a trailing dimension of size 1."""
        hidden = obs
        for layer in (self.fc1, self.fc2):
            hidden = torch.tanh(layer(hidden))
        return self.fc3(hidden)

# 1. Agent Definition

In [4]:
class TRPOAgent:
    """Holds the TRPO policy network (`pi`) and value network (`V`) on `device`.

    Args:
        obs_dim: size of the observation vector.
        act_dim: size of the action vector.
        hidden1: width of the first hidden layer of both networks.
        hidden2: width of the second hidden layer of both networks.
    """

    def __init__(self, obs_dim, act_dim, hidden1=64, hidden2=32):
        # The sizes are kept so an identically shaped Actor can be rebuilt later.
        self.obs_dim, self.act_dim = obs_dim, act_dim
        self.hidden1, self.hidden2 = hidden1, hidden2

        # Policy first, then critic: keeps the same parameter-init RNG order.
        self.pi = Actor(obs_dim, act_dim, hidden1, hidden2).to(device)
        self.V = Critic(obs_dim, hidden1, hidden2).to(device)

    def act(self, obs):
        """Sample an action for `obs` without tracking gradients.

        Returns:
            Tuple of NumPy arrays (action, log_prob, val):
            - action: sampled from Independent(Normal(mu, sigma), 1).
            - log_prob: the log-probability of that action.
            - val: the critic's estimate V(obs).
        """
        obs_t = torch.tensor(obs, dtype=torch.float).to(device)
        with torch.no_grad():
            mu, sigma = self.pi(obs_t)
            policy = Independent(Normal(mu, sigma), 1)
            sampled = policy.sample()
            results = (sampled, policy.log_prob(sampled), self.V(obs_t))
        return tuple(tensor.cpu().numpy() for tensor in results)

# 2. Actor & Critic Update

In [5]:
def update(agent, memory, critic_optim, delta, num_updates):
    """Update the critic and the actor (TRPO step) from one on-policy batch.

    Each of ``num_updates`` iterations performs:
      1. One ``critic_optim`` step on the mean squared error between V(s)
         and ``batch['val']``.
      2. A natural-gradient actor step. ``g`` is the gradient of the
         surrogate E[ratio * A]. ``cg`` (presumably conjugate gradient with
         Fisher-vector products) returns a direction ``s``. The step is
         scaled so that 1/2 * step^T F step == delta. The final parameters
         are chosen by ``backtracking_line_search``.

    Args:
        agent: TRPOAgent whose ``pi`` and ``V`` are updated in place.
        memory: on-policy buffer. ``load()`` must return a dict with keys
            'state', 'action', 'val', 'A' and 'log_prob'.
        critic_optim: optimizer over ``agent.V.parameters()``.
        delta: trust-region radius used to scale the actor step.
        num_updates: number of critic/actor iterations on this batch.
    """
    batch = memory.load()

    states = torch.Tensor(batch['state']).to(device)
    actions = torch.Tensor(batch['action']).to(device)
    # Flatten per-step scalars to shape (N,). agent.V returns (N, 1) and
    # agent.pi.log_prob returns (N,). Combining (N, 1) with (N,) would
    # silently broadcast to (N, N) and corrupt both losses.
    target_v = torch.Tensor(batch['val']).to(device).reshape(-1)
    A = torch.Tensor(batch['A']).to(device).reshape(-1)
    old_log_probs = torch.Tensor(batch['log_prob']).to(device).reshape(-1)

    for _ in range(num_updates):
        ################
        # train critic #
        ################
        out = agent.V(states).reshape(-1)
        critic_loss = torch.mean((out - target_v) ** 2)

        critic_optim.zero_grad()
        critic_loss.backward()
        critic_optim.step()

        ###############
        # train actor #
        ###############
        log_probs = agent.pi.log_prob(states, actions)

        # \pi(a_t | s_t ; \phi) / \pi(a_t | s_t ; \phi_old)
        prob_ratio = torch.exp(log_probs - old_log_probs)

        # Surrogate objective. The step below moves along +s, i.e. it ascends it.
        actor_loss = torch.mean(prob_ratio * A)
        loss_grad = torch.autograd.grad(actor_loss, agent.pi.parameters())
        # flatten gradients of params
        g = torch.cat([grad.view(-1) for grad in loss_grad]).data

        s = cg(fisher_vector_product, g, agent.pi, states)

        # Scale so that 1/2 * step^T F step == delta.
        sAs = torch.sum(fisher_vector_product(s, agent.pi, states) * s, dim=0, keepdim=True)
        step_size = torch.sqrt(2 * delta / sAs)[0]
        step = step_size * s

        # Snapshot of the current (pre-step) policy for the line search.
        # It must live on the same device as `states`: a freshly constructed
        # Actor is on CPU, and load_state_dict does not move it.
        old_actor = Actor(agent.obs_dim, agent.act_dim, agent.hidden1, agent.hidden2).to(device)
        old_actor.load_state_dict(agent.pi.state_dict())

        params = flat_params(agent.pi)

        backtracking_line_search(old_actor, agent.pi, actor_loss, g,
                                 old_log_probs, params, step, delta, A, states, actions)
    return

In [6]:
def evaluate(agent, env, num_episodes=5):
    """Play `num_episodes` complete episodes and summarize their returns.

    The action taken at each step is the first element returned by
    ``agent.act(obs)``. ``env`` must follow the 4-tuple ``step`` API
    ``(obs, reward, done, info)``.

    Returns:
        (mean, std) of the undiscounted episode returns. std is the
        population standard deviation (``np.std`` default, ddof=0).
    """
    def _episode_return():
        observation = env.reset()
        total = 0.
        finished = False
        while not finished:
            chosen = agent.act(observation)[0]
            observation, reward, finished, _ = env.step(chosen)
            total += reward
        return total

    returns = np.array([_episode_return() for _ in range(num_episodes)], dtype=float)
    return np.mean(returns), np.std(returns)

# 3. Training!

In [7]:
def train(env, agent, max_iter, gamma=0.99, lr=3e-4, lam=0.95, delta=1e-3, steps_per_epoch=4000, eval_interval=4000):
    """Train `agent` with on-policy rollouts and TRPO updates, logging evaluations.

    ``max_iter // steps_per_epoch`` epochs are run. Each epoch collects
    ``steps_per_epoch`` transitions with the current policy into an
    ``OnPolicyMemory`` and then calls ``update`` once. Every
    ``eval_interval`` environment steps (counted over the whole run), the
    agent is evaluated on a deep copy of ``env``. The result is printed and
    appended to ``./learning_curves/res.csv`` as ``[total_t, avg, std]``.

    Args:
        env: gym-style env with the 4-tuple ``step`` API and a
            ``_max_episode_steps`` attribute (time limit).
        agent: TRPOAgent to train in place.
        max_iter: total environment-step budget. Only whole epochs are run.
        gamma, lam: forwarded to ``OnPolicyMemory``.
        lr: Adam learning rate for the critic.
        delta: trust-region radius passed to ``update``.
        steps_per_epoch: transitions collected per update (also the memory size).
        eval_interval: environment steps between evaluations.
    """
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    max_ep_len = env._max_episode_steps
    memory = OnPolicyMemory(obs_dim, act_dim, gamma, lam, lim=steps_per_epoch)
    test_env = copy.deepcopy(env)
    critic_optim = Adam(agent.V.parameters(), lr=lr)
    os.makedirs('./learning_curves/', exist_ok=True)
    num_epochs = max_iter // steps_per_epoch
    total_t = 0
    begin = time.time()
    # `with` guarantees the CSV is flushed and closed even if training raises.
    with open('./learning_curves/res.csv', 'w', encoding='utf-8', newline='') as log_file:
        logger = csv.writer(log_file)
        for _ in range(num_epochs):
            # start agent-env interaction
            state = env.reset()
            step_count = 0

            for t in range(steps_per_epoch):
                # collect transition samples by executing the policy
                action, log_prob, v = agent.act(state)

                next_state, reward, done, _ = env.step(action)
                memory.append(state, action, reward, v, log_prob)

                step_count += 1

                timeout = step_count == max_ep_len
                epoch_end = t == steps_per_epoch - 1
                if done and not timeout:
                    # A genuine terminal state has no future return, even if
                    # it coincides with the last step of the epoch.
                    memory.compute_values(0.0)
                elif timeout or epoch_end:
                    # Truncated by the time limit or by the memory size:
                    # bootstrap the remaining return with V(s').
                    with torch.no_grad():
                        s_last = torch.tensor(next_state, dtype=torch.float).to(device)
                        v_last = agent.V(s_last).item()
                    memory.compute_values(v_last)

                state = next_state

                if done:
                    state = env.reset()
                    step_count = 0

                if total_t % eval_interval == 0:
                    avg_score, std_score = evaluate(agent, test_env, num_episodes=5)
                    elapsed_t = time.time() - begin
                    print('[elapsed time : {:.1f}s| iter {}] score = {:.2f}'.format(elapsed_t, total_t, avg_score), u'\u00B1', '{:.4f}'.format(std_score))
                    # Log the global step; the in-epoch index `t` repeats every epoch.
                    evaluation_log = [total_t, avg_score, std_score]
                    logger.writerow(evaluation_log)

                total_t += 1

            # train agent at the end of each epoch
            update(agent, memory, critic_optim, delta, num_updates=1)
    return

In [8]:
# Continuous-action LunarLander. The rest of the notebook uses the classic
# gym 4-tuple step API.
env = gym.make('LunarLanderContinuous-v2')
obs_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]
print('observation space dim. : {} / action space dim. : {}'.format(obs_dim, act_dim))

observation space dim. : 8 / action space dim. : 2


In [9]:
agent = TRPOAgent(obs_dim, act_dim, hidden1=64, hidden2=32)  # default hidden widths spelled out

In [10]:
# 600000 // 4000 = 150 epochs. With eval_interval == steps_per_epoch,
# evaluation runs on the first step of every epoch.
train(
    env,
    agent,
    max_iter=600000,
    gamma=0.99,
    lr=3e-4,
    lam=0.95,
    steps_per_epoch=4000,
    eval_interval=4000,
)

[elapsed time : 0.3s| iter 0] score = -218.82 ± 157.0771
[elapsed time : 3.4s| iter 4000] score = -272.99 ± 177.5552
[elapsed time : 6.6s| iter 8000] score = -358.84 ± 43.0449
[elapsed time : 9.6s| iter 12000] score = -394.81 ± 162.3906
[elapsed time : 12.5s| iter 16000] score = -257.32 ± 177.0494
[elapsed time : 15.7s| iter 20000] score = -262.12 ± 160.9821
[elapsed time : 18.9s| iter 24000] score = -209.06 ± 58.4583
[elapsed time : 22.0s| iter 28000] score = -246.32 ± 136.6173
[elapsed time : 25.1s| iter 32000] score = -219.50 ± 74.6381
[elapsed time : 28.1s| iter 36000] score = -226.45 ± 113.7727
[elapsed time : 32.4s| iter 40000] score = -122.09 ± 116.1492
[elapsed time : 35.6s| iter 44000] score = -216.59 ± 129.6992
[elapsed time : 38.8s| iter 48000] score = -205.48 ± 158.1739
[elapsed time : 42.0s| iter 52000] score = -121.50 ± 71.5140
[elapsed time : 45.4s| iter 56000] score = -226.53 ± 98.5263
[elapsed time : 48.3s| iter 60000] score = -87.40 ± 22.1071
[elapsed time : 51.4s| it

[elapsed time : 1272.7s| iter 536000] score = -1.06 ± 44.4518
[elapsed time : 1292.9s| iter 540000] score = 52.14 ± 27.4368
[elapsed time : 1311.6s| iter 544000] score = 51.36 ± 52.5124
[elapsed time : 1327.1s| iter 548000] score = 6.20 ± 75.6901
[elapsed time : 1345.0s| iter 552000] score = 19.65 ± 60.1252
[elapsed time : 1363.1s| iter 556000] score = 28.47 ± 64.6490
[elapsed time : 1384.6s| iter 560000] score = 3.58 ± 43.3114
[elapsed time : 1406.7s| iter 564000] score = -3.18 ± 16.9393
[elapsed time : 1427.5s| iter 568000] score = -2.01 ± 26.7813
[elapsed time : 1448.2s| iter 572000] score = 0.52 ± 60.1744
[elapsed time : 1467.7s| iter 576000] score = -17.85 ± 35.0557
[elapsed time : 1487.6s| iter 580000] score = 7.40 ± 35.1258
[elapsed time : 1509.6s| iter 584000] score = 2.05 ± 41.8522
[elapsed time : 1528.8s| iter 588000] score = -42.14 ± 55.0282
[elapsed time : 1549.0s| iter 592000] score = -33.80 ± 30.4635
[elapsed time : 1567.5s| iter 596000] score = -20.94 ± 43.5944


# 4. Watch how your agent solves the task!

In [11]:
# Roll out one episode with the trained policy (actions are sampled) and
# print its undiscounted return. Uncomment the wrap/render/show lines to
# record a video.
# env = wrap_env(env)
obs = env.reset()
score = 0.
done = False
while not done:
    # env.render()
    action = agent.act(obs)[0]
    obs, rew, done, _ = env.step(action)
    score += rew
env.close()
print('score : ', score)
# show_video()

score :  -22.608029337759753
