In [32]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

In [33]:
def mlp(layer_sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        layer_sizes: Sequence of layer widths, input width first and output
            width last. Each adjacent pair becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer except
            the last one.
        output_activation: Module class instantiated after the last linear layer.

    Returns:
        An ``nn.Sequential`` alternating linear layers and activations. It is
        empty when ``layer_sizes`` has fewer than two entries.
    """
    final_idx = len(layer_sizes) - 2
    modules = []

    for idx, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        # Only the last linear layer is followed by the output activation.
        if idx < final_idx:
            modules.append(activation())
        else:
            modules.append(output_activation())

    return nn.Sequential(*modules)

In [34]:
# Sanity check: build a small 10 -> 32 -> 2 network and display its layers.
mlp([10, 32, 2])

Sequential(
  (0): Linear(in_features=10, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

In [35]:
import gym
import numpy as np

import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

In [36]:
# Id of the Gym environment to train on.
gym_env = "CartPole-v0"

env = gym.make(gym_env)

# Later cells read `observation_space.shape` and `action_space.n`, so fail early
# unless observations are a Box and actions are Discrete.
assert isinstance(env.observation_space, gym.spaces.Box), "This environment only works for environment with continuous state spaces"
assert isinstance(env.action_space, gym.spaces.Discrete), "This environment only works for environment with discrete action spaces"

In [37]:
# Size of the first axis of the observation space; used as the policy network's input width.
obs_dim = env.observation_space.shape[0]
# Number of discrete actions; used as the policy network's output width (one logit per action).
act_dim = env.action_space.n

In [38]:
# Widths of the hidden layers placed between the observation input and the action logits.
HIDDEN_SIZES = [32]

# Policy network: obs_dim inputs -> hidden layers -> act_dim logits.
logits_net = mlp([obs_dim]+HIDDEN_SIZES+[act_dim])
# Bare expression so the notebook displays the layer structure.
logits_net

Sequential(
  (0): Linear(in_features=4, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

In [39]:
def get_policy(obs):
    """Return the action distribution for ``obs``.

    The observation tensor is passed through the module-level ``logits_net``
    and the resulting logits parameterize a ``Categorical`` distribution.
    """
    return Categorical(logits=logits_net(obs))

In [40]:
def get_action(obs):
    """Sample one action for a single observation.

    Args:
        obs: Observation tensor accepted by ``get_policy``. The sampled result
            must hold exactly one element because it is unwrapped with
            ``.item()``, so pass one observation rather than a batch.

    Returns:
        The sampled action as a plain Python number (via ``.item()``).
    """
    # Acting only needs the sampled value, never gradients, so skip building an
    # autograd graph for the forward pass. The loss recomputes log-probs later.
    with torch.no_grad():
        return get_policy(obs).sample().item()

In [41]:
def compute_loss(obs, acts, weights):
    """Return the negated mean of ``log_prob(acts) * weights``.

    Args:
        obs: Observations passed to ``get_policy`` to obtain the distribution.
        acts: Actions whose log-probability is evaluated under that distribution.
        weights: Per-sample multipliers applied to the log-probabilities.

    Returns:
        A scalar tensor; minimizing it raises the weighted log-probability.
    """
    policy = get_policy(obs)
    weighted_logp = policy.log_prob(acts) * weights
    return -weighted_logp.mean()

In [42]:
# Smoke test of the policy on random inputs. Shapes come from obs_dim / act_dim
# rather than literals so the cell stays valid if `gym_env` is changed.
obs = torch.randn(100, obs_dim) # batch, obs_dim
policy = get_policy(obs) # Categorical Policy - a distribution is returned

print("Policy:", policy)
print(policy.sample())
acts = torch.randint(0, act_dim, [100]) # one random valid action index per row
policy.log_prob(acts)

Policy: Categorical(logits: torch.Size([100, 2]))
tensor([1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0,
        0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0,
        1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0,
        0, 0, 1, 0])


tensor([-0.7274, -0.6375, -0.6431, -0.4836, -0.5551, -0.5950, -0.5674, -0.8721,
        -1.0282, -0.9089, -0.5469, -0.6654, -0.6133, -0.8735, -0.8167, -0.6684,
        -0.6578, -0.8105, -0.6609, -0.7503, -0.8359, -0.7647, -0.6769, -0.6627,
        -0.8990, -0.5357, -0.6252, -0.9206, -0.5655, -0.8761, -0.4550, -0.7382,
        -0.5013, -0.5299, -0.7774, -0.6554, -0.5443, -0.8650, -0.7305, -0.7554,
        -0.8329, -0.7287, -0.8471, -0.5098, -0.5916, -0.5357, -0.9608, -0.8672,
        -0.8355, -0.8727, -0.6262, -0.6127, -0.5338, -0.9311, -0.5907, -0.8618,
        -0.5818, -0.5532, -0.5281, -0.9945, -0.8214, -0.4418, -0.7916, -0.6644,
        -0.9708, -0.7182, -0.7832, -0.8636, -1.0214, -0.4789, -0.5438, -0.7430,
        -0.7120, -0.6144, -0.6762, -0.5949, -0.6695, -0.5573, -0.7534, -0.8118,
        -0.6147, -1.0073, -1.0187, -1.0009, -0.5821, -0.6101, -0.9304, -0.6295,
        -0.5807, -0.5537, -0.7836, -0.7905, -0.4483, -0.7675, -0.8361, -0.8706,
        -0.6805, -0.5934, -0.6450, -0.56

In [43]:
def _reset_env(env):
    """Reset ``env`` and return only the observation.

    Accepts both reset conventions: a bare observation, or an
    ``(observation, info)`` tuple.
    """
    result = env.reset()
    return result[0] if isinstance(result, tuple) else result


def _step_env(env, act):
    """Step ``env`` and return ``(obs, reward, done)``.

    Accepts both the 4-tuple ``(obs, reward, done, info)`` result and the
    5-tuple ``(obs, reward, terminated, truncated, info)`` result.
    """
    result = env.step(act)
    if len(result) == 5:
        obs, rew, terminated, truncated, _ = result
        return obs, rew, bool(terminated or truncated)
    obs, rew, done, _ = result
    return obs, rew, done


def train_one_epoch(env, optimizer, batch_size, render=True):
    """Collect one batch of complete episodes and take one policy-gradient step.

    Episodes are run with actions from ``get_action`` until, at an episode
    boundary, more than ``batch_size`` steps have been collected. Every step of
    an episode is weighted by that episode's total return.

    Args:
        env: Environment exposing ``reset``, ``step``, ``render`` and ``close``.
            Observations must provide ``.copy()`` (e.g. NumPy arrays).
        optimizer: Optimizer over the policy parameters; ``zero_grad`` and
            ``step`` are each called once.
        batch_size: Step threshold; collection stops at the first episode end
            where the number of stored steps exceeds it.
        render: When true, the first episode of the epoch is rendered.

    Returns:
        Tuple ``(loss, episode_returns, episode_lengths)`` where ``loss`` is a
        Python float and the other two are lists with one entry per episode.

    Side effects:
        Calls ``env.close()`` before returning.
    """
    batch_obs = []
    batch_acts = []
    batch_weights = []
    batch_rets = []
    batch_lens = []

    # get initial observation from starting distribution
    obs = _reset_env(env)
    ep_rews = []

    # Render the first episode of the epoch
    finish_rendering_this_epoch = False

    while True:
        # rendering
        if render and not finish_rendering_this_epoch:
            env.render()
            # Alternative for inline notebook display:
            # plt.imshow(env.render(mode="rgb_array"))
            # display.display(plt.gcf())
            # display.clear_output(wait=True)

        # save the observation the action is chosen from (copied so later
        # in-place changes by the environment cannot alter the stored batch)
        batch_obs.append(obs.copy())

        # get action for the current observation
        act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, done = _step_env(env, act)

        # save the action and reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # record info about the episode
            ep_ret = sum(ep_rews)
            ep_len = len(ep_rews)

            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)
            # the weight of every step is the return of its whole episode,
            # repeated so it lines up element-wise with the log-probs in compute_loss
            batch_weights += [ep_ret] * ep_len

            # reset the environment for the next episode
            obs = _reset_env(env)
            ep_rews = []

            # won't render again after first episode in the epoch
            finish_rendering_this_epoch = True

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break

    # perform a single update step
    optimizer.zero_grad()
    batch_loss = compute_loss(
        # stack into one ndarray first: building a tensor from a Python list of
        # ndarrays is much slower
        obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
        # actions are category indices, so keep them integral
        acts=torch.as_tensor(batch_acts, dtype=torch.int64),
        weights=torch.as_tensor(batch_weights, dtype=torch.float32),
    )
    batch_loss.backward()
    optimizer.step()
    # NOTE(review): the caller reuses `env` across epochs, so this relies on the
    # environment remaining usable after close() - confirm for other envs.
    env.close()

    return batch_loss.item(), batch_rets, batch_lens
        

In [44]:
# Learning rate handed to Adam below.
# NOTE(review): 0.1 looks large for Adam and may explain the swings in the logged
# returns - confirm this value is intentional.
LR = 0.1

# Adam over all parameters of the policy network.
optimizer = torch.optim.Adam(logits_net.parameters(), lr=LR)

In [45]:
# Step threshold per epoch: train_one_epoch keeps running whole episodes until it
# has stored more than this many steps, then performs one update.
BATCH_SIZES = 5_000
# Number of collect-then-update iterations.
EPOCHS = 50

# Training Loop
for epoch in range(EPOCHS):
    # render=True renders the first episode of every epoch.
    batch_loss, batch_rets, batch_lens = train_one_epoch(env, optimizer, BATCH_SIZES, render=True)
    # Report the loss plus the mean episode return and length for this epoch.
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {batch_loss:.2f} Average Return: {np.mean(batch_rets):.1f} Average Steps: {np.mean(batch_lens):.1f}")

Epoch [1/50], Loss: 18.80 Average Return: 21.6 Average Steps: 21.6
Epoch [2/50], Loss: 18.02 Average Return: 24.1 Average Steps: 24.1
Epoch [3/50], Loss: 38.26 Average Return: 58.5 Average Steps: 58.5
Epoch [4/50], Loss: 16.72 Average Return: 30.2 Average Steps: 30.2
Epoch [5/50], Loss: 14.36 Average Return: 28.9 Average Steps: 28.9
Epoch [6/50], Loss: 15.51 Average Return: 34.0 Average Steps: 34.0
Epoch [7/50], Loss: 18.60 Average Return: 45.0 Average Steps: 45.0
Epoch [8/50], Loss: 27.18 Average Return: 69.6 Average Steps: 69.6
Epoch [9/50], Loss: 33.27 Average Return: 101.3 Average Steps: 101.3
Epoch [10/50], Loss: 25.82 Average Return: 92.0 Average Steps: 92.0
Epoch [11/50], Loss: 24.51 Average Return: 102.1 Average Steps: 102.1
Epoch [12/50], Loss: 22.96 Average Return: 109.5 Average Steps: 109.5
Epoch [13/50], Loss: 21.48 Average Return: 113.4 Average Steps: 113.4
Epoch [14/50], Loss: 23.43 Average Return: 127.9 Average Steps: 127.9
Epoch [15/50], Loss: 30.56 Average Return: 174.