In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as distributions

import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
import numpy as np
import gym

In [2]:
train_env = gym.make('CartPole-v1')
test_env = gym.make('CartPole-v1')

assert isinstance(train_env.observation_space, gym.spaces.Box)
assert isinstance(train_env.action_space, gym.spaces.Discrete)

In [3]:
SEED = 1234

train_env.seed(SEED);
test_env.seed(SEED);
np.random.seed(SEED);
torch.manual_seed(SEED);

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout = 0.0):
        super().__init__()

        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc_1(x)
        x = self.dropout(x)
        x = F.relu(x)
        x = self.fc_2(x)
        return x

In [5]:
INPUT_DIM = train_env.observation_space.shape[0]
HIDDEN_DIM = 256
OUTPUT_DIM = train_env.action_space.n

actor = MLP(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
critic = MLP(INPUT_DIM, HIDDEN_DIM, 1)

In [6]:
"""def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_normal_(m.weight)
        m.bias.data.fill_(0)
        
actor.apply(init_weights)
critic.apply(init_weights)"""

'def init_weights(m):\n    if type(m) == nn.Linear:\n        torch.nn.init.xavier_normal_(m.weight)\n        m.bias.data.fill_(0)\n        \nactor.apply(init_weights)\ncritic.apply(init_weights)'

In [7]:
LEARNING_RATE = 0.01

actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-4)

In [8]:
def train(env, actor, critic, actor_optimizer, critic_optimizer, n_steps, discount_factor):
    
    log_prob_actions = torch.zeros(n_steps)
    entropies = torch.zeros(n_steps)
    values = torch.zeros(n_steps)
    rewards = torch.zeros(n_steps)
    masks = torch.zeros(n_steps)
    episode_reward = 0

    state = env.state

    for step in range(n_steps):

        state = torch.FloatTensor(state).unsqueeze(0)
        
        action_preds = actor(state)
        value_pred = critic(state).squeeze(-1)

        action_probs = F.softmax(action_preds, dim = -1)
                
        dist = distributions.Categorical(action_probs)

        action = dist.sample()
        
        log_prob_action = dist.log_prob(action)
        
        entropy = dist.entropy()
        
        state, reward, done, _ = env.step(action.item())

        log_prob_actions[step] = log_prob_action
        entropies[step] = entropy
        values[step] = value_pred
        rewards[step] = reward
        masks[step] = 1 - done
    
        if done:
            state = env.reset()
    
    next_value = critic(torch.FloatTensor(state).unsqueeze(0)).squeeze(-1)
    returns = calculate_returns(rewards, next_value, masks, discount_factor)
    advantages = calculate_advantages(returns, values)
    
    policy_loss, value_loss = update_policy(advantages, log_prob_actions, returns, values, entropies, actor_optimizer, critic_optimizer)

    return policy_loss, value_loss

In [9]:
def calculate_returns(rewards, next_value, masks, discount_factor, normalize = False):
    
    returns = torch.zeros_like(rewards)
    R = next_value.item()
    
    for i, (r, m) in enumerate(zip(reversed(rewards), reversed(masks))):
        R = r + R * discount_factor * m
        returns[i] = R
    
    if normalize:
        
        returns = (returns - returns.mean()) / returns.std()
        
    return returns

In [10]:
def calculate_advantages(returns, values, normalize = False):
    
    advantages = returns - values
    
    if normalize:
        
        advantages = (advantages - advantages.mean()) / advantages.std()
        
    return advantages

In [11]:
def update_policy(advantages, log_prob_actions, returns, values, entropies, actor_optimizer, critic_optimizer):
        
    advantages = advantages.detach()
    returns = returns.detach()
        
    policy_loss = - (advantages * log_prob_actions).mean() - 0.001 * entropies.mean()
    
    value_loss = 0.5 * F.smooth_l1_loss(returns, values).mean()
        
    actor_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    
    policy_loss.backward()
    value_loss.backward()
    
    actor_optimizer.step()
    critic_optimizer.step()
    
    return policy_loss.item(), value_loss.item()

In [12]:
def evaluate(env, actor, critic):
    
    done = False
    episode_reward = 0
    
    state = env.reset()
    
    while not done:
        
        state = torch.FloatTensor(state).unsqueeze(0)
        
        action_preds = actor(state)
        
        action_probs = F.softmax(action_preds, dim = -1)
        
        dist = distributions.Categorical(action_probs)

        action = dist.sample() 
        
        state, reward, done, _ = env.step(action.item())
        
        episode_reward += reward
        
    return episode_reward

In [13]:
MAX_STEPS = 100_000
N_UPDATE_STEPS =  100
DISCOUNT_FACTOR = 0.99
N_TRIALS = 25
REWARD_THRESHOLD = 475
PRINT_EVERY = 10

episode_rewards = []

_ = train_env.reset()

for step in tqdm(range(MAX_STEPS)):
        
    policy_loss, value_loss = train(train_env, actor, critic, actor_optimizer, critic_optimizer, N_UPDATE_STEPS, DISCOUNT_FACTOR)
    
    episode_reward = evaluate(test_env, actor, critic)
    
    episode_rewards.append(episode_reward)
    
    mean_trial_rewards = np.mean(episode_rewards[-N_TRIALS:])
    
    if step % PRINT_EVERY == 0:
            
        print(f'| Steps: {N_UPDATE_STEPS*step:6} | Mean Rewards: {mean_trial_rewards:6.2f} |')
    
    if mean_trial_rewards >= REWARD_THRESHOLD:
        
        print(f'Reached reward threshold in {N_UPDATE_STEPS*step} steps')
        
        break

HBox(children=(IntProgress(value=0, max=100000), HTML(value='')))

| Steps:      0 | Mean Rewards:   8.00 |
| Steps:   1000 | Mean Rewards:  17.73 |
| Steps:   2000 | Mean Rewards:  20.05 |
| Steps:   3000 | Mean Rewards:  22.32 |
| Steps:   4000 | Mean Rewards:  22.20 |
| Steps:   5000 | Mean Rewards:  19.16 |
| Steps:   6000 | Mean Rewards:  15.52 |
| Steps:   7000 | Mean Rewards:  17.20 |
| Steps:   8000 | Mean Rewards:  18.68 |
| Steps:   9000 | Mean Rewards:  20.40 |
| Steps:  10000 | Mean Rewards:  22.84 |
| Steps:  11000 | Mean Rewards:  25.56 |
| Steps:  12000 | Mean Rewards:  22.88 |
| Steps:  13000 | Mean Rewards:  19.24 |
| Steps:  14000 | Mean Rewards:  22.16 |
| Steps:  15000 | Mean Rewards:  20.68 |
| Steps:  16000 | Mean Rewards:  22.04 |
| Steps:  17000 | Mean Rewards:  19.84 |
| Steps:  18000 | Mean Rewards:  22.96 |
| Steps:  19000 | Mean Rewards:  23.00 |
| Steps:  20000 | Mean Rewards:  21.80 |
| Steps:  21000 | Mean Rewards:  20.76 |
| Steps:  22000 | Mean Rewards:  25.92 |
| Steps:  23000 | Mean Rewards:  28.28 |
| Steps:  24000 

| Steps: 200000 | Mean Rewards: 167.32 |
| Steps: 201000 | Mean Rewards: 151.08 |
| Steps: 202000 | Mean Rewards: 137.28 |
| Steps: 203000 | Mean Rewards: 135.60 |
| Steps: 204000 | Mean Rewards: 112.80 |
| Steps: 205000 | Mean Rewards: 126.84 |
| Steps: 206000 | Mean Rewards: 143.48 |
| Steps: 207000 | Mean Rewards: 150.52 |
| Steps: 208000 | Mean Rewards: 148.32 |
| Steps: 209000 | Mean Rewards: 166.24 |
| Steps: 210000 | Mean Rewards: 156.04 |
| Steps: 211000 | Mean Rewards: 158.96 |
| Steps: 212000 | Mean Rewards: 144.52 |
| Steps: 213000 | Mean Rewards: 147.12 |
| Steps: 214000 | Mean Rewards: 168.76 |
| Steps: 215000 | Mean Rewards: 180.16 |
| Steps: 216000 | Mean Rewards: 175.96 |
| Steps: 217000 | Mean Rewards: 173.72 |
| Steps: 218000 | Mean Rewards: 163.16 |
| Steps: 219000 | Mean Rewards: 163.88 |
| Steps: 220000 | Mean Rewards: 172.76 |
| Steps: 221000 | Mean Rewards: 173.48 |
| Steps: 222000 | Mean Rewards: 162.28 |
| Steps: 223000 | Mean Rewards: 177.72 |
| Steps: 224000 

| Steps: 400000 | Mean Rewards: 373.04 |
| Steps: 401000 | Mean Rewards: 399.16 |
| Steps: 402000 | Mean Rewards: 391.84 |
| Steps: 403000 | Mean Rewards: 394.60 |
| Steps: 404000 | Mean Rewards: 394.44 |
| Steps: 405000 | Mean Rewards: 391.80 |
| Steps: 406000 | Mean Rewards: 397.52 |
| Steps: 407000 | Mean Rewards: 409.84 |
| Steps: 408000 | Mean Rewards: 380.64 |
| Steps: 409000 | Mean Rewards: 340.48 |
| Steps: 410000 | Mean Rewards: 338.44 |
| Steps: 411000 | Mean Rewards: 322.60 |
| Steps: 412000 | Mean Rewards: 281.36 |
| Steps: 413000 | Mean Rewards: 234.84 |
| Steps: 414000 | Mean Rewards: 240.76 |
| Steps: 415000 | Mean Rewards: 247.20 |
| Steps: 416000 | Mean Rewards: 267.56 |
| Steps: 417000 | Mean Rewards: 262.96 |
| Steps: 418000 | Mean Rewards: 243.92 |
| Steps: 419000 | Mean Rewards: 234.76 |
| Steps: 420000 | Mean Rewards: 233.92 |
| Steps: 421000 | Mean Rewards: 245.84 |
| Steps: 422000 | Mean Rewards: 258.00 |
| Steps: 423000 | Mean Rewards: 289.04 |
| Steps: 424000 

| Steps: 600000 | Mean Rewards: 267.28 |
| Steps: 601000 | Mean Rewards: 307.32 |
| Steps: 602000 | Mean Rewards: 345.68 |
| Steps: 603000 | Mean Rewards: 353.92 |
| Steps: 604000 | Mean Rewards: 319.68 |
| Steps: 605000 | Mean Rewards: 295.80 |
| Steps: 606000 | Mean Rewards: 314.76 |
| Steps: 607000 | Mean Rewards: 329.72 |
| Steps: 608000 | Mean Rewards: 390.28 |
| Steps: 609000 | Mean Rewards: 422.48 |
| Steps: 610000 | Mean Rewards: 406.08 |
| Steps: 611000 | Mean Rewards: 376.20 |
| Steps: 612000 | Mean Rewards: 385.68 |
| Steps: 613000 | Mean Rewards: 373.92 |
| Steps: 614000 | Mean Rewards: 410.20 |
| Steps: 615000 | Mean Rewards: 437.12 |
| Steps: 616000 | Mean Rewards: 416.48 |
| Steps: 617000 | Mean Rewards: 437.56 |
| Steps: 618000 | Mean Rewards: 450.60 |
Reached reward threshold in 618400 steps
