In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distributions

import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
import numpy as np
import gym

In [37]:
train_env = gym.make('CartPole-v0')
test_env = gym.make('CartPole-v0')

assert isinstance(train_env.observation_space, gym.spaces.Box)
assert isinstance(train_env.action_space, gym.spaces.Discrete)

In [38]:
SEED = 1234

train_env.seed(SEED);
test_env.seed(SEED);
np.random.seed(SEED);
torch.manual_seed(SEED);

In [39]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout = 0.0):
        super().__init__()

        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc_1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.fc_2(x)
        return x

In [40]:
INPUT_DIM = train_env.observation_space.shape[0]
HIDDEN_DIM = 256
OUTPUT_DIM = train_env.action_space.n

actor = MLP(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
critic = MLP(INPUT_DIM, HIDDEN_DIM, 1)

In [41]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0)
        
actor.apply(init_weights)
critic.apply(init_weights)

MLP(
  (fc_1): Linear(in_features=4, out_features=256, bias=True)
  (fc_2): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [42]:
LEARNING_RATE = 0.01

actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-4)

In [43]:
def train(env, actor, critic, actor_optimizer, critic_optimizer, n_steps, discount_factor):
    
    log_prob_actions = torch.zeros(n_steps)
    entropies = torch.zeros(n_steps)
    values = torch.zeros(n_steps)
    rewards = torch.zeros(n_steps)
    masks = torch.zeros(n_steps)
    episode_reward = 0

    state = env.state

    for step in range(n_steps):

        state = torch.FloatTensor(state).unsqueeze(0)
        
        action_preds = actor(state)
        value_pred = critic(state).squeeze(-1)

        action_probs = F.softmax(action_preds, dim = -1)
                
        dist = distributions.Categorical(action_probs)

        action = dist.sample()
        
        log_prob_action = dist.log_prob(action)
        
        entropy = dist.entropy()
        
        state, reward, done, _ = env.step(action.item())

        log_prob_actions[step] = log_prob_action
        entropies[step] = entropy
        values[step] = value_pred
        rewards[step] = reward
        masks[step] = 1 - done
    
        if done:
            state = env.reset()
    
    next_value = critic(torch.FloatTensor(state).unsqueeze(0)).squeeze(-1)
    returns = calculate_returns(rewards, next_value, masks, discount_factor)
    advantages = calculate_advantages(returns, values)
    
    policy_loss, value_loss = update_policy(advantages, log_prob_actions, returns, values, entropies, actor_optimizer, critic_optimizer)

    return policy_loss, value_loss

In [44]:
def calculate_returns(rewards, next_value, masks, discount_factor, normalize = False):
    
    returns = torch.zeros_like(rewards)
    R = next_value.item()
    
    for i, (r, m) in enumerate(zip(reversed(rewards), reversed(masks))):
        R = r + R * discount_factor * m
        returns[i] = R
    
    if normalize:
        
        returns = (returns - returns.mean()) / returns.std()
        
    return returns

In [45]:
def calculate_advantages(returns, values, normalize = False):
    
    advantages = returns - values
    
    if normalize:
        
        advantages = (advantages - advantages.mean()) / advantages.std()
        
    return advantages

In [46]:
def update_policy(advantages, log_prob_actions, returns, values, entropies, actor_optimizer, critic_optimizer):
        
    advantages = advantages.detach()
    returns = returns.detach()
        
    policy_loss = - (advantages * log_prob_actions).mean() - 0.001 * entropies.mean()
    
    value_loss = 0.5 * F.smooth_l1_loss(returns, values).mean()
        
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    
    policy_loss.backward()
    value_loss.backward()
    
    actor_optimizer.step()
    critic_optimizer.step()
    
    return policy_loss.item(), value_loss.item()

In [47]:
def evaluate(env, actor, critic):
    
    done = False
    episode_reward = 0
    
    state = env.reset()
    
    while not done:
        
        state = torch.FloatTensor(state).unsqueeze(0)
        
        action_preds = actor(state)
        
        action_probs = F.softmax(action_preds, dim = -1)
        
        dist = distributions.Categorical(action_probs)

        action = dist.sample() 
        
        state, reward, done, _ = env.step(action.item())
        
        episode_reward += reward
        
    return episode_reward

In [50]:
MAX_STEPS = 100_000
N_UPDATE_STEPS =  100
DISCOUNT_FACTOR = 0.99
N_TRIALS = 25
REWARD_THRESHOLD = 475
PRINT_EVERY = 10

episode_rewards = []

_ = train_env.reset()

for step in range(MAX_STEPS):
        
    policy_loss, value_loss = train(train_env, actor, critic, actor_optimizer, critic_optimizer, N_UPDATE_STEPS, DISCOUNT_FACTOR)
    
    episode_reward = evaluate(test_env, actor, critic)
    
    episode_rewards.append(episode_reward)
    
    mean_trial_rewards = np.mean(episode_rewards[-N_TRIALS:])
    
    if step % PRINT_EVERY == 0:
            
        print(f'| Steps: {step:6} | Mean Rewards: {mean_trial_rewards:6.2f} |')
    
    if mean_trial_rewards >= REWARD_THRESHOLD:
        
        print(f'Reached reward threshold in {N_UPDATE_STEPS*step} steps')
        
        break

| Steps:      0 | Mean Rewards:  75.00 |
| Steps:     10 | Mean Rewards:  36.36 |
| Steps:     20 | Mean Rewards:  39.90 |
| Steps:     30 | Mean Rewards:  41.00 |
| Steps:     40 | Mean Rewards:  45.08 |
| Steps:     50 | Mean Rewards:  42.28 |
| Steps:     60 | Mean Rewards:  46.36 |
| Steps:     70 | Mean Rewards:  48.04 |
| Steps:     80 | Mean Rewards:  39.60 |
| Steps:     90 | Mean Rewards:  46.48 |
| Steps:    100 | Mean Rewards:  46.64 |
| Steps:    110 | Mean Rewards:  47.96 |
| Steps:    120 | Mean Rewards:  47.68 |
| Steps:    130 | Mean Rewards:  42.12 |
| Steps:    140 | Mean Rewards:  47.88 |
| Steps:    150 | Mean Rewards:  48.40 |
| Steps:    160 | Mean Rewards:  47.32 |
| Steps:    170 | Mean Rewards:  44.40 |
| Steps:    180 | Mean Rewards:  41.80 |
| Steps:    190 | Mean Rewards:  48.28 |
| Steps:    200 | Mean Rewards:  40.36 |
| Steps:    210 | Mean Rewards:  34.28 |
| Steps:    220 | Mean Rewards:  42.36 |
| Steps:    230 | Mean Rewards:  45.80 |
| Steps:    240 

| Steps:   2000 | Mean Rewards: 173.60 |
| Steps:   2010 | Mean Rewards: 167.48 |
| Steps:   2020 | Mean Rewards: 167.04 |
| Steps:   2030 | Mean Rewards: 170.12 |
| Steps:   2040 | Mean Rewards: 174.20 |
| Steps:   2050 | Mean Rewards: 165.60 |
| Steps:   2060 | Mean Rewards: 170.56 |
| Steps:   2070 | Mean Rewards: 173.44 |
| Steps:   2080 | Mean Rewards: 173.24 |
| Steps:   2090 | Mean Rewards: 178.12 |
| Steps:   2100 | Mean Rewards: 179.04 |
| Steps:   2110 | Mean Rewards: 173.44 |
| Steps:   2120 | Mean Rewards: 175.08 |
| Steps:   2130 | Mean Rewards: 174.76 |
| Steps:   2140 | Mean Rewards: 184.96 |
| Steps:   2150 | Mean Rewards: 188.32 |
| Steps:   2160 | Mean Rewards: 189.32 |
| Steps:   2170 | Mean Rewards: 182.24 |
| Steps:   2180 | Mean Rewards: 178.12 |
| Steps:   2190 | Mean Rewards: 181.20 |
| Steps:   2200 | Mean Rewards: 175.08 |
| Steps:   2210 | Mean Rewards: 170.68 |
| Steps:   2220 | Mean Rewards: 166.52 |
| Steps:   2230 | Mean Rewards: 179.08 |
| Steps:   2240 

| Steps:   4000 | Mean Rewards: 178.24 |
| Steps:   4010 | Mean Rewards: 183.32 |
| Steps:   4020 | Mean Rewards: 175.60 |
| Steps:   4030 | Mean Rewards: 170.68 |
| Steps:   4040 | Mean Rewards: 162.20 |
| Steps:   4050 | Mean Rewards: 160.84 |
| Steps:   4060 | Mean Rewards: 163.48 |
| Steps:   4070 | Mean Rewards: 170.20 |
| Steps:   4080 | Mean Rewards: 160.88 |
| Steps:   4090 | Mean Rewards: 155.52 |
| Steps:   4100 | Mean Rewards: 156.84 |
| Steps:   4110 | Mean Rewards: 158.36 |
| Steps:   4120 | Mean Rewards: 157.96 |
| Steps:   4130 | Mean Rewards: 154.28 |
| Steps:   4140 | Mean Rewards: 154.56 |
| Steps:   4150 | Mean Rewards: 156.60 |
| Steps:   4160 | Mean Rewards: 161.60 |
| Steps:   4170 | Mean Rewards: 151.72 |
| Steps:   4180 | Mean Rewards: 143.76 |
| Steps:   4190 | Mean Rewards: 141.88 |
| Steps:   4200 | Mean Rewards: 136.76 |
| Steps:   4210 | Mean Rewards: 126.64 |
| Steps:   4220 | Mean Rewards: 124.44 |
| Steps:   4230 | Mean Rewards: 125.36 |
| Steps:   4240 

| Steps:   6000 | Mean Rewards: 187.56 |
| Steps:   6010 | Mean Rewards: 193.00 |
| Steps:   6020 | Mean Rewards: 189.52 |
| Steps:   6030 | Mean Rewards: 182.84 |
| Steps:   6040 | Mean Rewards: 180.32 |
| Steps:   6050 | Mean Rewards: 176.92 |
| Steps:   6060 | Mean Rewards: 174.72 |
| Steps:   6070 | Mean Rewards: 171.48 |
| Steps:   6080 | Mean Rewards: 180.80 |
| Steps:   6090 | Mean Rewards: 174.68 |
| Steps:   6100 | Mean Rewards: 166.08 |
| Steps:   6110 | Mean Rewards: 166.20 |
| Steps:   6120 | Mean Rewards: 166.56 |
| Steps:   6130 | Mean Rewards: 175.80 |
| Steps:   6140 | Mean Rewards: 178.20 |
| Steps:   6150 | Mean Rewards: 168.28 |
| Steps:   6160 | Mean Rewards: 168.68 |
| Steps:   6170 | Mean Rewards: 176.44 |
| Steps:   6180 | Mean Rewards: 188.76 |
| Steps:   6190 | Mean Rewards: 194.60 |
| Steps:   6200 | Mean Rewards: 194.64 |
| Steps:   6210 | Mean Rewards: 191.80 |
| Steps:   6220 | Mean Rewards: 194.04 |
| Steps:   6230 | Mean Rewards: 189.48 |
| Steps:   6240 

| Steps:   8000 | Mean Rewards: 192.44 |
| Steps:   8010 | Mean Rewards: 197.20 |
| Steps:   8020 | Mean Rewards: 198.52 |
| Steps:   8030 | Mean Rewards: 199.08 |
| Steps:   8040 | Mean Rewards: 199.08 |
| Steps:   8050 | Mean Rewards: 200.00 |
| Steps:   8060 | Mean Rewards: 200.00 |
| Steps:   8070 | Mean Rewards: 200.00 |
| Steps:   8080 | Mean Rewards: 200.00 |
| Steps:   8090 | Mean Rewards: 200.00 |
| Steps:   8100 | Mean Rewards: 200.00 |
| Steps:   8110 | Mean Rewards: 200.00 |
| Steps:   8120 | Mean Rewards: 200.00 |
| Steps:   8130 | Mean Rewards: 200.00 |
| Steps:   8140 | Mean Rewards: 199.08 |
| Steps:   8150 | Mean Rewards: 197.84 |
| Steps:   8160 | Mean Rewards: 198.36 |
| Steps:   8170 | Mean Rewards: 194.16 |
| Steps:   8180 | Mean Rewards: 193.04 |
| Steps:   8190 | Mean Rewards: 192.76 |
| Steps:   8200 | Mean Rewards: 198.20 |
| Steps:   8210 | Mean Rewards: 195.12 |
| Steps:   8220 | Mean Rewards: 196.92 |
| Steps:   8230 | Mean Rewards: 198.72 |
| Steps:   8240 

KeyboardInterrupt: 