In [1]:
import torch
import torch.nn as nn
from torch.nn import MSELoss
import torch.nn.functional as F
import copy
import numpy as np
import torch
from torch.optim import Adam
from buffer import ReplayBuffer
import gym

# 0. Define Q-network & policy-network

In [2]:
# Critic (Q-network): a multi-layer perceptron with 2 hidden layers
class Critic(nn.Module):
    """Action-value approximator that scores (observation, action) pairs.

    The observation and the action are concatenated along dim 1, passed
    through two ReLU hidden layers and finally through a linear head with a
    single output unit.
    """

    def __init__(self, obs_dim, act_dim, hidden1, hidden2):
        super().__init__()
        # layers are created in input-to-output order
        self.fc1 = nn.Linear(obs_dim + act_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)

    def forward(self, obs, act):
        """Return the output of the linear head for the given obs/act tensors."""
        hidden = torch.cat([obs, act], dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
    
    
# Actor (policy network): a multi-layer perceptron with 2 hidden layers
class Actor(nn.Module):
    """Deterministic policy mapping an observation to a bounded action.

    Two ReLU hidden layers feed a linear head; the head output is squashed
    with tanh and multiplied by ``ctrl_range``, so every action component has
    magnitude at most ``ctrl_range``.
    """

    def __init__(self, obs_dim, act_dim, ctrl_range, hidden1, hidden2):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, act_dim)
        # scale applied to the tanh output
        self.ctrl_range = ctrl_range

    def forward(self, obs):
        """Return the scaled, tanh-squashed action for ``obs``."""
        hidden = F.relu(self.fc1(obs))
        hidden = F.relu(self.fc2(hidden))
        squashed = torch.tanh(self.fc3(hidden))
        return self.ctrl_range * squashed

# 1. Define DDPG agent

In [3]:
class DDPGAgent:
    """Container for the DDPG networks: actor, critic and their target copies."""

    def __init__(self, obs_dim, act_dim, ctrl_range, hidden1, hidden2):
        # online networks (actor is built first, then the critic)
        self.actor = Actor(obs_dim, act_dim, ctrl_range, hidden1, hidden2)
        self.critic = Critic(obs_dim, act_dim, hidden1, hidden2)

        # target networks start out as independent deep copies of the online ones
        self.targ_actor = copy.deepcopy(self.actor)
        self.targ_critic = copy.deepcopy(self.critic)

    def act(self, obs):
        """Return the actor's action for a single observation as a numpy array.

        ``obs`` is indexed with ``np.newaxis`` to add a leading batch axis,
        pushed through the actor without gradient tracking, and the batch axis
        is removed again from the result.
        """
        batched_obs = obs[np.newaxis]

        with torch.no_grad():
            batched_act = self.actor(torch.Tensor(batched_obs))

        # back to numpy, dropping the batch axis added above
        return np.squeeze(batched_act.numpy(), axis=0)
    
    

## 1.1. Test

In [4]:
# Smoke test: build a small agent and query it with one 4-dimensional observation.
agent = DDPGAgent(obs_dim=4, act_dim=2, ctrl_range=3, hidden1=32, hidden2=32)
action = agent.act(
    np.array([3., -1., 2., -5.])
)
print(action)

[-0.89769304 -1.2820984 ]


# 2. Implement one-step param update

In [5]:
def update(agent, replay_buf, gamma, actor_optim, critic_optim, tau, batch_size):
    """Run one DDPG update: critic step, actor step, then soft target update.

    Args:
        agent: object exposing the ``actor``, ``critic``, ``targ_actor`` and
            ``targ_critic`` modules to be trained.
        replay_buf: buffer whose ``sample_batch(batch_size=...)`` returns an
            object with ``obs``, ``act``, ``next_obs``, ``rew`` and ``done``.
        gamma: discount factor used in the bootstrapped target.
        actor_optim: optimizer over the actor parameters.
        critic_optim: optimizer over the critic parameters.
        tau: soft-update coefficient; targets become
            ``(1 - tau) * target + tau * online``.
        batch_size: number of transitions to sample.

    Returns:
        None. The networks and optimizers are modified in place.
    """
    batch = replay_buf.sample_batch(batch_size=batch_size)

    # unroll batch
    obs = torch.as_tensor(batch.obs, dtype=torch.float32)
    act = torch.as_tensor(batch.act, dtype=torch.float32)
    next_obs = torch.as_tensor(batch.next_obs, dtype=torch.float32)
    # Force rewards / done flags into column vectors. If the buffer hands back
    # 1-D arrays of length B, adding them to a (B, 1) critic output would
    # silently broadcast to (B, B) and corrupt the regression target.
    # For inputs that are already (B, 1) this reshape is a no-op.
    rew = torch.as_tensor(batch.rew, dtype=torch.float32).reshape(-1, 1)
    done = torch.as_tensor(batch.done, dtype=torch.float32).reshape(-1, 1)

    ################
    # train critic #
    ################
    # target construction does not need backward ftns
    with torch.no_grad():
        next_act = agent.targ_actor(next_obs)
        next_q = agent.targ_critic(next_obs, next_act)
        # (1 - done) removes the bootstrap term for terminal transitions
        target = rew + gamma * (1.0 - done) * next_q

    critic_loss = F.mse_loss(agent.critic(obs, act), target)

    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()

    ###############
    # train actor #
    ###############
    # Freeze the critic during the actor step: the actor loss is
    # back-propagated *through* the critic, but only the actor should receive
    # parameter gradients, so no gradients are computed or stored for the
    # critic weights here.
    for p in agent.critic.parameters():
        p.requires_grad_(False)
    try:
        actor_loss = -agent.critic(obs, agent.actor(obs)).mean()

        actor_optim.zero_grad()
        actor_loss.backward()
        actor_optim.step()
    finally:
        # always unfreeze, even if the actor step raised
        for p in agent.critic.parameters():
            p.requires_grad_(True)

    # soft target update (both actor & critic network); done without autograd
    # tracking since target parameters are never trained by gradient descent
    with torch.no_grad():
        for net, targ_net in ((agent.actor, agent.targ_actor),
                              (agent.critic, agent.targ_critic)):
            for p, targ_p in zip(net.parameters(), targ_net.parameters()):
                targ_p.copy_((1. - tau) * targ_p + tau * p)

In [6]:
def evaluate(agent, env, num_episodes=5):
    """Return the mean episode return of ``agent`` over ``num_episodes`` rollouts.

    The rollouts run on a deep copy of ``env``, so the environment passed in is
    not stepped or reset. Actions come straight from ``agent.act`` and each
    episode runs until the environment reports ``done``.
    """
    test_env = copy.deepcopy(env)

    total_return = 0.
    for _episode in range(num_episodes):
        obs = test_env.reset()
        episode_return = 0.

        while True:
            obs, rew, done, _info = test_env.step(agent.act(obs))
            episode_return += rew
            if done:
                break

        total_return += episode_return

    return total_return / num_episodes

# 3. Combining these, we get the full training loop

In [7]:
def train(agent, env, gamma,
          actor_lr, critic_lr, tau, noise_std,
          ep_len, num_updates, batch_size,
          init_buffer=5000, buffer_size=100000,
          start_train=2000, train_interval=50,
          eval_interval=2000):
    """Interact with ``env`` and train ``agent`` with DDPG.

    The loop performs ``num_updates + 1`` environment steps. For the first
    ``init_buffer`` steps actions are sampled from the action space; afterwards
    the actor's action plus Gaussian noise (std ``noise_std``) is used, clipped
    to ``[-ctrl_range, ctrl_range]`` where ``ctrl_range`` is the first entry of
    ``env.action_space.high``. Episodes are cut after ``ep_len`` steps. Once
    ``t > start_train``, every ``train_interval`` steps the networks receive
    ``train_interval`` calls to ``update``. Every ``eval_interval`` steps the
    result of ``evaluate`` is printed.
    """
    actor_optim = Adam(agent.actor.parameters(), lr=actor_lr)
    critic_optim = Adam(agent.critic.parameters(), lr=critic_lr)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    ctrl_range = env.action_space.high[0]

    replay_buf = ReplayBuffer(obs_dim, act_dim, buffer_size)

    # main loop
    obs = env.reset()
    ep_step = 0  # steps taken in the current episode

    for t in range(num_updates + 1):
        if t >= init_buffer:
            # noisy action: a_t = \pi(s_t) + N(0, \sigma^2), clipped to the control range
            greedy_action = agent.act(obs)
            action = greedy_action + noise_std * np.random.randn(act_dim)
            action = np.clip(action, -ctrl_range, ctrl_range)
        else:
            # purely random actions until enough samples have been collected (exploration)
            action = env.action_space.sample()

        next_obs, rew, done, _ = env.step(action)
        ep_step += 1
        hit_limit = ep_step == ep_len

        # a transition on the step where the length limit is hit is always
        # stored as non-terminal, whatever the environment reported
        stored_done = False if hit_limit else done
        replay_buf.append(obs, action, next_obs, rew, stored_done)
        obs = next_obs

        if hit_limit or stored_done:
            # start a new episode on termination or when the length limit is reached
            obs = env.reset()
            ep_step = 0

        if t > start_train and t % train_interval == 0:
            # training starts only after a fixed number of steps; this may mitigate
            # overfitting to the few samples collected at the very beginning
            for _ in range(train_interval):
                update(agent, replay_buf, gamma, actor_optim, critic_optim, tau, batch_size)

        if t % eval_interval == 0:
            score = evaluate(agent, env)
            print('[iteration {}] evaluation score : {}'.format(t, score))

# 4. Let's test the code!

In [8]:
# Create the environment and read the problem sizes off its spaces.
env = gym.make('BipedalWalker-v3')

obs_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]
# first entry of the action upper bound, used as the control range
ctrl_range = env.action_space.high[0]

print('observation space dim : {} / action space dim : {}'.format(obs_dim, act_dim))
print('ctrl range : ', ctrl_range)

observation space dim : 24 / action space dim : 4
ctrl range :  1.0




In [9]:
# Agent sized for the environment, using 256 units in both hidden layers.
agent = DDPGAgent(
    obs_dim=obs_dim,
    act_dim=act_dim,
    ctrl_range=ctrl_range,
    hidden1=256,
    hidden2=256,
)

In [10]:
# Hyperparameters passed to train() below.
gamma = 0.99             # discount factor
actor_lr = 0.0001        # Adam learning rate for the actor
critic_lr = 0.001        # Adam learning rate for the critic
tau = 0.001              # soft target-update coefficient
noise_std = 0.1          # std of the Gaussian exploration noise
ep_len = 500             # episode length limit used during training
num_updates = 1_000_000  # upper bound of the training loop counter
batch_size = 128         # transitions sampled per update

In [11]:
# Launch training with the hyperparameters defined above.
train(
    agent=agent,
    env=env,
    gamma=gamma,
    actor_lr=actor_lr,
    critic_lr=critic_lr,
    tau=tau,
    noise_std=noise_std,
    ep_len=ep_len,
    num_updates=num_updates,
    batch_size=batch_size,
    init_buffer=5000,
    buffer_size=1000000,
    start_train=2000,
    train_interval=50,
    eval_interval=5000,
)



[iteration 0] evaluation score : -105.18096577831925
[iteration 5000] evaluation score : -149.42335325775898
[iteration 10000] evaluation score : -117.20837319950499
[iteration 15000] evaluation score : -103.57316732793561
[iteration 20000] evaluation score : -111.85199683076887
[iteration 25000] evaluation score : -105.68771600659984
[iteration 30000] evaluation score : -126.66331381985417
[iteration 35000] evaluation score : -122.11559117571228
[iteration 40000] evaluation score : -119.58603488314411
[iteration 45000] evaluation score : -118.27418510137
[iteration 50000] evaluation score : -122.26617781308087
[iteration 55000] evaluation score : -128.72361778657395
[iteration 60000] evaluation score : -128.63449127786168
[iteration 65000] evaluation score : -138.25478470315934
[iteration 70000] evaluation score : -139.09374152110217
[iteration 75000] evaluation score : -133.98284320065378
[iteration 80000] evaluation score : -138.70358006132116
[iteration 85000] evaluation score : -1

[iteration 715000] evaluation score : -152.84649391752228
[iteration 720000] evaluation score : -144.79366149822053
[iteration 725000] evaluation score : -135.76441646748708
[iteration 730000] evaluation score : -135.9217606600875
[iteration 735000] evaluation score : -119.40010283725296
[iteration 740000] evaluation score : -133.21465693016407
[iteration 745000] evaluation score : -199.10399893962625
[iteration 750000] evaluation score : -116.17272406831023
[iteration 755000] evaluation score : -130.77469738684812
[iteration 760000] evaluation score : -138.87501428273188
[iteration 765000] evaluation score : -139.0995160091315
[iteration 770000] evaluation score : -111.87070219704533
[iteration 775000] evaluation score : -69.03112494620791
[iteration 780000] evaluation score : -149.1098507742875
[iteration 785000] evaluation score : -139.85497107735662
[iteration 790000] evaluation score : -164.39760835992072
[iteration 795000] evaluation score : -162.20011362472016
[iteration 800000]

# 5. Watch the trained agent

In [12]:
# Roll out one rendered episode with the noise-free policy and report its return.
score = 0.
obs = env.reset()
done = False

while True:
    env.render()
    obs, rew, done, _ = env.step(agent.act(obs))
    score += rew
    if done:
        break

env.close()
print('score : ', score)

score :  -71.75615367728673
