In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distributions

import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
import numpy as np
import gym

In [37]:
train_env = gym.make('CartPole-v0')
test_env = gym.make('CartPole-v0')

assert isinstance(train_env.observation_space, gym.spaces.Box)
assert isinstance(train_env.action_space, gym.spaces.Discrete)

In [38]:
SEED = 1234

train_env.seed(SEED);
test_env.seed(SEED);
np.random.seed(SEED);
torch.manual_seed(SEED);

In [39]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout = 0.0):
        super().__init__()

        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc_1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.fc_2(x)
        return x

In [40]:
INPUT_DIM = train_env.observation_space.shape[0]
HIDDEN_DIM = 256
OUTPUT_DIM = train_env.action_space.n

actor = MLP(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
critic = MLP(INPUT_DIM, HIDDEN_DIM, 1)

In [41]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0)
        
actor.apply(init_weights)
critic.apply(init_weights)

MLP(
  (fc_1): Linear(in_features=4, out_features=256, bias=True)
  (fc_2): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [42]:
LEARNING_RATE = 0.01

actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-4)

In [43]:
def train(env, actor, critic, actor_optimizer, critic_optimizer, n_steps, discount_factor):
    
    log_prob_actions = torch.zeros(n_steps)
    entropies = torch.zeros(n_steps)
    values = torch.zeros(n_steps)
    rewards = torch.zeros(n_steps)
    masks = torch.zeros(n_steps)
    episode_reward = 0

    state = env.state

    for step in range(n_steps):

        state = torch.FloatTensor(state).unsqueeze(0)
        
        action_preds = actor(state)
        value_pred = critic(state).squeeze(-1)

        action_probs = F.softmax(action_preds, dim = -1)
                
        dist = distributions.Categorical(action_probs)

        action = dist.sample()
        
        log_prob_action = dist.log_prob(action)
        
        entropy = dist.entropy()
        
        state, reward, done, _ = env.step(action.item())

        log_prob_actions[step] = log_prob_action
        entropies[step] = entropy
        values[step] = value_pred
        rewards[step] = reward
        masks[step] = 1 - done
    
        if done:
            state = env.reset()
    
    next_value = critic(torch.FloatTensor(state).unsqueeze(0)).squeeze(-1)
    returns = calculate_returns(rewards, next_value, masks, discount_factor)
    advantages = calculate_advantages(returns, values)
    
    policy_loss, value_loss = update_policy(advantages, log_prob_actions, returns, values, entropies, actor_optimizer, critic_optimizer)

    return policy_loss, value_loss

In [44]:
def calculate_returns(rewards, next_value, masks, discount_factor, normalize = False):
    
    returns = torch.zeros_like(rewards)
    R = next_value.item()
    
    for i, (r, m) in enumerate(zip(reversed(rewards), reversed(masks))):
        R = r + R * discount_factor * m
        returns[i] = R
    
    if normalize:
        
        returns = (returns - returns.mean()) / returns.std()
        
    return returns

In [45]:
def calculate_advantages(returns, values, normalize = False):
    
    advantages = returns - values
    
    if normalize:
        
        advantages = (advantages - advantages.mean()) / advantages.std()
        
    return advantages

In [46]:
def update_policy(advantages, log_prob_actions, returns, values, entropies, actor_optimizer, critic_optimizer):
        
    advantages = advantages.detach()
    returns = returns.detach()
        
    policy_loss = - (advantages * log_prob_actions).mean() - 0.001 * entropies.mean()
    
    value_loss = 0.5 * F.smooth_l1_loss(returns, values).mean()
        
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    
    policy_loss.backward()
    value_loss.backward()
    
    actor_optimizer.step()
    critic_optimizer.step()
    
    return policy_loss.item(), value_loss.item()

In [47]:
def evaluate(env, actor, critic):
    
    done = False
    episode_reward = 0
    
    state = env.reset()
    
    while not done:
        
        state = torch.FloatTensor(state).unsqueeze(0)
        
        action_preds = actor(state)
        
        action_probs = F.softmax(action_preds, dim = -1)
        
        dist = distributions.Categorical(action_probs)

        action = dist.sample() 
        
        state, reward, done, _ = env.step(action.item())
        
        episode_reward += reward
        
    return episode_reward

In [49]:
MAX_STEPS = 100_000
N_UPDATE_STEPS =  100
DISCOUNT_FACTOR = 0.99
N_TRIALS = 25
REWARD_THRESHOLD = 475
PRINT_EVERY = 10

episode_rewards = []

_ = train_env.reset()

for step in range(MAX_STEPS):
        
    policy_loss, value_loss = train(train_env, actor, critic, actor_optimizer, critic_optimizer, N_UPDATE_STEPS, DISCOUNT_FACTOR)
    
    episode_reward = evaluate(test_env, actor, critic)
    
    episode_rewards.append(episode_reward)
    
    mean_trial_rewards = np.mean(episode_rewards[-N_TRIALS:])
    
    if step % PRINT_EVERY == 0:
            
        print(f'| Steps: {step:6} | Mean Rewards: {mean_trial_rewards:6.2f} |')
    
    if mean_trial_rewards >= REWARD_THRESHOLD:
        
        print(f'Reached reward threshold in {N_UPDATE_STEPS*step} steps')
        
        break

| Steps:      0 | Mean Rewards:  10.00 |
| Steps:   1000 | Mean Rewards:  19.00 |
| Steps:   2000 | Mean Rewards:  19.10 |
| Steps:   3000 | Mean Rewards:  19.72 |
| Steps:   4000 | Mean Rewards:  18.40 |
| Steps:   5000 | Mean Rewards:  18.52 |
| Steps:   6000 | Mean Rewards:  19.16 |
| Steps:   7000 | Mean Rewards:  17.60 |
| Steps:   8000 | Mean Rewards:  18.64 |
| Steps:   9000 | Mean Rewards:  18.64 |
| Steps:  10000 | Mean Rewards:  17.80 |
| Steps:  11000 | Mean Rewards:  21.56 |
| Steps:  12000 | Mean Rewards:  23.00 |
| Steps:  13000 | Mean Rewards:  24.20 |
| Steps:  14000 | Mean Rewards:  22.96 |
| Steps:  15000 | Mean Rewards:  25.00 |
| Steps:  16000 | Mean Rewards:  21.28 |
| Steps:  17000 | Mean Rewards:  20.88 |
| Steps:  18000 | Mean Rewards:  19.40 |
| Steps:  19000 | Mean Rewards:  19.44 |
| Steps:  20000 | Mean Rewards:  19.00 |
| Steps:  21000 | Mean Rewards:  23.56 |
| Steps:  22000 | Mean Rewards:  24.00 |
| Steps:  23000 | Mean Rewards:  21.96 |
| Steps:  24000 

KeyboardInterrupt: 