In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

In [2]:
def mlp(layer_sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Assemble a fully connected feed-forward network.

    Args:
        layer_sizes: Sequence of layer widths, input width first and output
            width last. Each consecutive pair becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer
            except the last one.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        An ``nn.Sequential`` alternating ``nn.Linear`` and activation modules.
        It is empty when ``layer_sizes`` has fewer than two entries.
    """
    n_linear = len(layer_sizes) - 1
    # (fan_in, fan_out) for every linear layer, in network order.
    size_pairs = [(layer_sizes[k], layer_sizes[k + 1]) for k in range(n_linear)]

    modules = []
    for idx, (fan_in, fan_out) in enumerate(size_pairs):
        modules.append(nn.Linear(fan_in, fan_out))
        # Only the final linear layer is followed by the output activation.
        if idx < n_linear - 1:
            modules.append(activation())
        else:
            modules.append(output_activation())

    return nn.Sequential(*modules)

In [3]:
# Sanity check: build a small MLP (10 inputs, one hidden layer of 32 units,
# 2 outputs) and display its structure.
mlp([10, 32, 2])

Sequential(
  (0): Linear(in_features=10, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

In [4]:
import gym
import numpy as np

import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

In [5]:
# Id of the Gym environment used throughout the notebook.
gym_env = "CartPole-v0"

env = gym.make(gym_env)

# The policy below feeds the observation vector into an MLP and samples from a
# Categorical distribution, so it needs a Box observation space and a Discrete
# action space. The restriction belongs to this notebook, not to the environment.
assert isinstance(env.observation_space, gym.spaces.Box), "This notebook only works for environments with continuous state spaces"
assert isinstance(env.action_space, gym.spaces.Discrete), "This notebook only works for environments with discrete action spaces"



In [6]:
# Length of the first axis of the observation space (network input width).
obs_dim = env.observation_space.shape[0]
# Number of discrete actions (network output width, one logit per action).
act_dim = env.action_space.n

In [7]:
# Widths of the hidden layers of the policy network.
HIDDEN_SIZES = [32]

# Policy network: obs_dim inputs -> hidden layers -> act_dim outputs (action logits).
logits_net = mlp([obs_dim]+HIDDEN_SIZES+[act_dim])
logits_net

Sequential(
  (0): Linear(in_features=4, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

In [8]:
def get_policy(obs):
    """Return the action distribution of the policy for ``obs``.

    ``obs`` is passed through the module-level ``logits_net`` and the
    resulting logits parameterize a ``Categorical`` distribution.
    """
    return Categorical(logits=logits_net(obs))

In [9]:
def get_action(obs):
    """Sample one action from the policy for ``obs`` and return it as a Python number.

    The sampled tensor is converted with ``.item()``, which requires it to
    hold exactly one element, so ``obs`` must be a single observation rather
    than a batch.
    """
    action_dist = get_policy(obs)
    sampled = action_dist.sample()
    return sampled.item()

In [10]:
def compute_loss(obs, acts, weights):
    """Policy-gradient surrogate loss.

    Computes the log-probability of each action in ``acts`` under the policy
    evaluated at ``obs``, scales it element-wise by ``weights``, and returns
    the negated mean.
    """
    action_dist = get_policy(obs)
    weighted_logp = action_dist.log_prob(acts) * weights
    return -weighted_logp.mean()

In [21]:
# Shape check with random inputs: one observation of length 4 and a batch of
# 100 such observations (4 matches the in_features of logits_net shown above).
single_obs = torch.randn([4])
batch_obs = torch.randn([100, 4])
# Displays: a sampled action, the logits shape for one observation, and the
# logits shape for the batch.
get_action(single_obs), logits_net(single_obs).shape, logits_net(batch_obs).shape

(1, torch.Size([2]), torch.Size([100, 2]))

In [11]:
# Demo: inspect the policy distribution on a random batch of observations.
obs = torch.randn(100, 4) # (batch, obs_dim)
policy = get_policy(obs) # Categorical policy - a distribution object is returned

print("Policy:", policy)
# One sampled action per observation in the batch.
print(policy.sample())
# Random actions in {0, 1}, then their log-probabilities under the policy
# (one value per observation).
acts = torch.randint(0, 2, [100])
policy.log_prob(acts)

Policy: Categorical(logits: torch.Size([100, 2]))
tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
        1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0,
        0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0,
        1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1,
        1, 1, 0, 1])


tensor([-0.6972, -1.2423, -1.3719, -1.0404, -0.8116, -0.5229, -0.6577, -1.0504,
        -0.5490, -0.3167, -0.6885, -0.5156, -1.1858, -1.0370, -0.6696, -0.6612,
        -0.5351, -0.7732, -1.0667, -0.3474, -0.7559, -0.3566, -0.5733, -0.5500,
        -0.5365, -0.4914, -0.8499, -0.9789, -0.4449, -0.8379, -0.3769, -0.4798,
        -0.8374, -0.9946, -1.0269, -0.7567, -0.6642, -0.7981, -0.4114, -0.7447,
        -0.8401, -1.5441, -0.3796, -1.3311, -0.5071, -0.5817, -0.4597, -0.4234,
        -0.5877, -1.0471, -0.7150, -0.6503, -0.6240, -0.2856, -0.6392, -0.7687,
        -0.6203, -0.9354, -0.5417, -0.6070, -0.4574, -0.6183, -0.2483, -0.4218,
        -1.5511, -0.2440, -0.7468, -0.4616, -0.4793, -0.3291, -0.6744, -0.6055,
        -0.8404, -0.5995, -0.9508, -1.2535, -0.4451, -0.7578, -0.9332, -0.5868,
        -0.5781, -0.8848, -0.7732, -0.3316, -0.3877, -0.3990, -0.9403, -0.4419,
        -0.4145, -1.4540, -0.3879, -0.7687, -0.5717, -0.2877, -0.5896, -0.6645,
        -0.4730, -0.2664, -0.5230, -1.57

In [43]:
def train_one_epoch(env, optimizer, batch_size, render=True):
    """Collect one batch of on-policy experience and take a single policy-gradient step.

    Whole episodes are rolled out with the current policy (via ``get_action``)
    until more than ``batch_size`` environment steps have been gathered. Every
    step of an episode is weighted by that episode's total return, the loss is
    computed with ``compute_loss``, and ``optimizer`` is stepped once.

    Uses the classic Gym API: ``env.reset()`` returns the observation and
    ``env.step(act)`` returns a 4-tuple ``(obs, rew, done, info)``.

    Args:
        env: Environment to interact with. It is closed before returning.
        optimizer: Optimizer used for the single update step.
        batch_size: Step threshold; collection stops at the first episode
            boundary at which more than ``batch_size`` steps were recorded.
        render: If true, the first episode of the epoch is rendered.

    Returns:
        Tuple ``(batch_loss, batch_rets, batch_lens)``: the loss value as a
        Python float, the list of episode returns, and the list of episode
        lengths collected during this epoch.
    """
    batch_obs = []      # observation seen before each action
    batch_acts = []     # action taken at each step
    batch_weights = []  # per-step weight (return of the episode the step belongs to)
    batch_rets = []     # return of each finished episode
    batch_lens = []     # length of each finished episode

    # Get the initial observation from the starting distribution.
    obs = env.reset()
    ep_rews = []

    # Only the first episode of the epoch is rendered.
    finish_rendering_this_epoch = False

    while True:
        if render and not finish_rendering_this_epoch:
            # Opens a viewer window; for inline notebook display,
            # env.render(mode="rgb_array") could be shown with plt.imshow instead.
            env.render()

        # Save the current observation (copied so later steps cannot alias it).
        batch_obs.append(obs.copy())

        # Sampling an action needs no gradients; skipping graph construction
        # during rollouts does not change the sampled values.
        with torch.no_grad():
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, done, _ = env.step(act)

        # Save the action and the reward.
        ep_rews.append(rew)
        batch_acts.append(act)

        if done:
            # Record info about the finished episode.
            ep_ret = sum(ep_rews)
            ep_len = len(ep_rews)

            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)
            # Every step of the episode gets the episode return as its weight,
            # so the weights line up element-wise with the log-probabilities
            # inside compute_loss.
            batch_weights += [ep_ret] * ep_len

            # Reset the environment for the next episode.
            obs = env.reset()
            ep_rews = []

            # Won't render again after the first episode in the epoch.
            finish_rendering_this_epoch = True

            # End the experience loop once enough steps have been collected.
            if len(batch_obs) > batch_size:
                break

    # Perform a single update step.
    optimizer.zero_grad()
    batch_loss = compute_loss(
        # Stack the per-step arrays once; converting a list of ndarrays
        # element by element is slow.
        obs=torch.as_tensor(np.asarray(batch_obs), dtype=torch.float32),
        # Actions are discrete class indices, so keep them as integers.
        acts=torch.as_tensor(batch_acts, dtype=torch.int64),
        weights=torch.as_tensor(batch_weights, dtype=torch.float32),
    )
    batch_loss.backward()
    optimizer.step()
    # NOTE(review): this closes the caller's env, which the training loop then
    # reuses on the next epoch. That relies on the env accepting reset() after
    # close() - confirm before using environments other than CartPole.
    env.close()

    return batch_loss.item(), batch_rets, batch_lens
        

In [44]:
# Learning rate passed to the Adam optimizer.
LR = 0.1

# Adam optimizer over the parameters of the policy network (logits_net).
optimizer = torch.optim.Adam(logits_net.parameters(), lr=LR)

In [45]:
# Step threshold per epoch: train_one_epoch keeps collecting whole episodes
# until more than this many environment steps have been recorded.
BATCH_SIZES = 5_000
# Number of epochs; each epoch performs one optimizer step.
EPOCHS = 50

# Training Loop
for epoch in range(EPOCHS):
    # One epoch = collect a batch of episodes, then apply one policy update.
    batch_loss, batch_rets, batch_lens = train_one_epoch(env, optimizer, BATCH_SIZES, render=True)
    # Report the loss plus the mean episode return and length for this batch.
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {batch_loss:.2f} Average Return: {np.mean(batch_rets):.1f} Average Steps: {np.mean(batch_lens):.1f}")

Epoch [1/50], Loss: 18.80 Average Return: 21.6 Average Steps: 21.6
Epoch [2/50], Loss: 18.02 Average Return: 24.1 Average Steps: 24.1
Epoch [3/50], Loss: 38.26 Average Return: 58.5 Average Steps: 58.5
Epoch [4/50], Loss: 16.72 Average Return: 30.2 Average Steps: 30.2
Epoch [5/50], Loss: 14.36 Average Return: 28.9 Average Steps: 28.9
Epoch [6/50], Loss: 15.51 Average Return: 34.0 Average Steps: 34.0
Epoch [7/50], Loss: 18.60 Average Return: 45.0 Average Steps: 45.0
Epoch [8/50], Loss: 27.18 Average Return: 69.6 Average Steps: 69.6
Epoch [9/50], Loss: 33.27 Average Return: 101.3 Average Steps: 101.3
Epoch [10/50], Loss: 25.82 Average Return: 92.0 Average Steps: 92.0
Epoch [11/50], Loss: 24.51 Average Return: 102.1 Average Steps: 102.1
Epoch [12/50], Loss: 22.96 Average Return: 109.5 Average Steps: 109.5
Epoch [13/50], Loss: 21.48 Average Return: 113.4 Average Steps: 113.4
Epoch [14/50], Loss: 23.43 Average Return: 127.9 Average Steps: 127.9
Epoch [15/50], Loss: 30.56 Average Return: 174.