In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

In [35]:
def mlp(layer_sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a fully connected feed-forward network.

    Args:
        layer_sizes: Sequence of layer widths, input size first and output
            size last. Each consecutive pair becomes one ``nn.Linear``.
        activation: Module class instantiated after every linear layer
            except the last one.
        output_activation: Module class instantiated after the last linear
            layer.

    Returns:
        An ``nn.Sequential`` alternating linear layers and activations.
    """
    final_idx = len(layer_sizes) - 2
    modules = []

    # Walk over consecutive (input width, output width) pairs.
    for idx, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        nonlinearity = output_activation if idx >= final_idx else activation
        modules.append(nonlinearity())

    return nn.Sequential(*modules)

In [6]:
# Sanity check: build a small 10 -> 32 -> 2 network and display its layers.
mlp([10, 32, 2])

Sequential(
  (0): Linear(in_features=10, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

In [8]:
import gym
import numpy as np

In [23]:
# Gym environment id used for training.
gym_env = "CartPole-v0"

env = gym.make(gym_env)

# The rest of the notebook reads `observation_space.shape` and `action_space.n`
# and samples actions from a Categorical distribution, so fail early on
# environments that do not provide a Box observation space and a Discrete
# action space.
assert isinstance(env.observation_space, gym.spaces.Box), (
    "This example only works for environments with continuous state spaces"
)
assert isinstance(env.action_space, gym.spaces.Discrete), (
    "This example only works for environments with discrete action spaces"
)

In [38]:
# Length of the first axis of the observation space's shape; used as the
# input width of the policy network.
obs_dim = env.observation_space.shape[0]
# `n` of the Discrete action space; the policy network emits one logit per action.
act_dim = env.action_space.n

In [98]:
# Widths of the hidden layers placed between the observation input and the
# logits output.
HIDDEN_SIZES = [32]

# Policy network: obs_dim inputs -> hidden layers -> act_dim action logits.
logits_net = mlp([obs_dim]+HIDDEN_SIZES+[act_dim])
logits_net

Sequential(
  (0): Linear(in_features=4, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

In [52]:
def get_policy(obs):
    """Return the categorical action distribution for ``obs``.

    The observation tensor is fed through the module-level ``logits_net`` and
    its output is used as the (unnormalized) logits of a ``Categorical``.
    """
    return Categorical(logits=logits_net(obs))

In [42]:
def get_action(obs):
    """Sample one action from the current policy for ``obs``.

    Returns the sampled action as a plain Python number (via ``.item()``),
    so ``obs`` must lead to a single-element sample.
    """
    action_dist = get_policy(obs)
    sampled = action_dist.sample()
    return sampled.item()

In [43]:
def compute_loss(obs, acts, weights):
    """Surrogate loss whose gradient is the policy gradient.

    Args:
        obs: Batch of observations passed to ``get_policy``.
        acts: Actions taken for those observations.
        weights: Per-step weight multiplied with each action's log-probability.

    Returns:
        Scalar tensor: the negated mean of ``log_prob(acts) * weights``.
    """
    log_probs = get_policy(obs).log_prob(acts)
    weighted = log_probs * weights
    return -weighted.mean()

In [67]:
# Smoke-test the policy on a random batch. The sizes come from obs_dim/act_dim
# (instead of hard-coded 4 and 2) so the cell keeps working if the environment
# is changed.
obs = torch.randn(100, obs_dim)  # (batch, obs_dim)
policy = get_policy(obs)  # Categorical policy - a distribution is returned

print("Policy:", policy)
print(policy.sample())
acts = torch.randint(0, act_dim, [100])  # one random valid action index per row
policy.log_prob(acts)

Policy: Categorical(logits: torch.Size([100, 2]))
tensor([1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,
        0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1,
        0, 0, 1, 1])


tensor([-0.5415, -0.8377, -0.6083, -0.7339, -0.6113, -0.7687, -0.5087, -0.6006,
        -0.7922, -0.4937, -0.8201, -0.4605, -0.5554, -0.5678, -0.5015, -0.5988,
        -0.5054, -0.7782, -1.0411, -1.0663, -0.5864, -0.6540, -1.0945, -0.9020,
        -0.7447, -0.7926, -0.9266, -0.5421, -0.5797, -0.5679, -0.7111, -0.4245,
        -1.2585, -0.5568, -1.2851, -0.8817, -0.6015, -0.7757, -0.9865, -0.8387,
        -0.5655, -0.9332, -0.7439, -0.5720, -0.9351, -0.4925, -0.3700, -0.5772,
        -0.7524, -0.5554, -0.3710, -0.7247, -0.5859, -0.7001, -0.3992, -0.9132,
        -0.8990, -0.8715, -0.5999, -0.8695, -0.3711, -0.5721, -1.1485, -0.9614,
        -0.6303, -0.4826, -0.9654, -0.4747, -0.4386, -0.6433, -0.6811, -1.0519,
        -0.5775, -0.9157, -0.4639, -0.6025, -0.7009, -0.8217, -0.8470, -0.8259,
        -0.8275, -0.7497, -0.6220, -0.7388, -0.4583, -0.7115, -0.5521, -0.5768,
        -0.8267, -1.1320, -0.7166, -0.8797, -0.9437, -0.9349, -0.8862, -1.0286,
        -0.8073, -0.5718, -0.8782, -0.67

In [91]:
def train_one_epoch(env, optimizer, batch_size, render=False):
    """Collect one batch of on-policy experience and take one gradient step.

    Episodes are rolled out with the current policy (through ``get_action``)
    until more than ``batch_size`` transitions have been gathered. Collection
    only stops at an episode boundary, so the batch can overshoot
    ``batch_size``. Every step of an episode is weighted by that episode's
    total return when the loss is computed with ``compute_loss``.

    Args:
        env: Environment used through ``reset() -> obs``,
            ``step(act) -> (obs, rew, done, info)`` and ``render()``.
            Observations must support ``.copy()``.
        optimizer: Optimizer on which ``zero_grad()`` and ``step()`` are
            called once per epoch.
        batch_size: Collection stops after the first episode that brings the
            number of stored transitions above this value.
        render: If True, render the first episode of the epoch only.

    Returns:
        Tuple ``(loss, episode_returns, episode_lengths)`` where ``loss`` is
        a Python float and the other two are lists with one entry per
        completed episode.
    """
    batch_obs = []      # observations, one per step
    batch_acts = []     # actions, one per step
    batch_weights = []  # episode return, repeated for every step of the episode
    batch_rets = []     # total return per episode
    batch_lens = []     # length per episode

    # get initial observation from starting distribution
    obs = env.reset()
    ep_rews = []

    # only the first episode of the epoch is rendered
    finished_rendering_this_epoch = False

    while True:
        if render and not finished_rendering_this_epoch:
            env.render()

        # save the current observation (copied so later steps cannot alias it)
        batch_obs.append(obs.copy())

        # act in the environment; sampling needs no autograd graph
        with torch.no_grad():
            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
        obs, rew, done, _ = env.step(act)

        # save the action and reward
        batch_acts.append(act)
        ep_rews.append(rew)

        if done:
            # record info about the finished episode
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_rets.append(ep_ret)
            batch_lens.append(ep_len)

            # the weight of every step is the return of its whole episode,
            # repeated so it lines up element-wise with the log-probs
            batch_weights += [ep_ret] * ep_len

            # reset the environment for the next episode
            obs = env.reset()
            ep_rews = []

            # won't render again after the first episode in the epoch
            finished_rendering_this_epoch = True

            # end experience loop if we have enough of it
            if len(batch_obs) > batch_size:
                break

    # perform a single update step
    optimizer.zero_grad()
    batch_loss = compute_loss(
        # stack into one ndarray first: converting a list of arrays directly is slow
        obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
        # actions are categorical indices, so keep them integral
        acts=torch.as_tensor(batch_acts, dtype=torch.int64),
        weights=torch.as_tensor(batch_weights, dtype=torch.float32),
    )
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item(), batch_rets, batch_lens
        

In [99]:
# Learning rate passed to Adam.
# NOTE(review): 0.1 is large compared to Adam's default of 1e-3 — confirm this
# is intentional.
LR = 0.1

# Adam optimizer over the parameters of the policy network (logits_net).
optimizer = torch.optim.Adam(logits_net.parameters(), lr=LR)

In [100]:
!export LANG=en_US

BATCH_SIZES = 20_000
EPOCHS = 50

# Training Loop
for epoch in range(EPOCHS):
    batch_loss, batch_rets, batch_lens = train_one_epoch(env, optimizer, BATCH_SIZES)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {batch_loss:.2f} Average Return: {np.mean(batch_rets):.1f} Average Steps: {np.mean(batch_lens):.1f}")

Epoch [1/50], Loss: 20.57 Average Return: 23.2 Average Steps: 23.2
Epoch [2/50], Loss: 12.65 Average Return: 18.9 Average Steps: 18.9
Epoch [3/50], Loss: 26.81 Average Return: 41.0 Average Steps: 41.0
Epoch [4/50], Loss: 30.55 Average Return: 56.5 Average Steps: 56.5
Epoch [5/50], Loss: 22.75 Average Return: 44.0 Average Steps: 44.0
Epoch [6/50], Loss: 23.23 Average Return: 49.5 Average Steps: 49.5
Epoch [7/50], Loss: 27.89 Average Return: 68.6 Average Steps: 68.6
Epoch [8/50], Loss: 29.95 Average Return: 84.5 Average Steps: 84.5
Epoch [9/50], Loss: 24.14 Average Return: 76.6 Average Steps: 76.6
Epoch [10/50], Loss: 22.85 Average Return: 80.2 Average Steps: 80.2
Epoch [11/50], Loss: 21.65 Average Return: 91.0 Average Steps: 91.0
Epoch [12/50], Loss: 23.71 Average Return: 107.6 Average Steps: 107.6
Epoch [13/50], Loss: 27.17 Average Return: 123.7 Average Steps: 123.7
Epoch [14/50], Loss: 28.18 Average Return: 150.7 Average Steps: 150.7
Epoch [15/50], Loss: 35.74 Average Return: 196.9 Av

KeyboardInterrupt: 