In [1]:
!pip install gym



In [2]:
import torch
from torch import nn
from torch.distributions.categorical import Categorical
from torch import optim

import numpy as np
import gym
from gym.spaces import Discrete, Box

In [3]:
print(f"torch version: {torch.__version__}")
print(f"gym version: {gym.__version__}")

torch version: 1.7.0
gym version: 0.18.0


## Define Model

In [28]:
def mlp(sizes, hidden_act_fn=nn.Tanh, output_act_fn=nn.Identity):
    '''
    sizes is list of integers specifying the number of nodes
        in each layer of the network, including input and output layers.
    Returns a torch.nn.Sequential object.
    '''
    assert isinstance(sizes, list)
    layers = []
    num_gaps = len(sizes) - 1
    for i in range(num_gaps):
        act_fn = hidden_act_fn if i < num_gaps-1 else output_act_fn
        layers.extend([nn.Linear(sizes[i], sizes[i+1]), act_fn()])
    return nn.Sequential(*layers)

In [29]:
net = mlp([4,32,2])
net

Sequential(
  (0): Linear(in_features=4, out_features=32, bias=True)
  (1): Tanh()
  (2): Linear(in_features=32, out_features=2, bias=True)
  (3): Identity()
)

## Cumulative Future Discounted Rewards helper function

In [25]:
def accumulate_discount(trajec_rewards, gamma=0.99):
    '''
    trajec_rewards must be a list of scalar reward values for each step.
    Returns a list of reverse-accumlated, discounted rewards, where each value
        represents the cumulative discounted rewards from that step onwards up to the end of the trajectory.
    '''
    assert isinstance(trajec_rewards, list)
    trajec_len = len(trajec_rewards)
    cum_disc_rewards = [None for i in range(trajec_len)]
    for step in reversed(range(trajec_len)):
        cum_disc_rewards[step] = trajec_rewards[step] + gamma * (cum_disc_rewards[step+1] if step+1 < trajec_len else 0)
    return cum_disc_rewards

In [26]:
ep_rewards = np.random.randint(-2, 6, size=(12,)).tolist()
ep_rewards

[5, 3, 0, -1, -1, 2, 0, -2, -1, 3, 5, 1]

In [27]:
accumulate_discount(ep_rewards)

[13.312010071311759,
 8.395969769001777,
 5.45047451414321,
 5.505529812265868,
 6.571242234611988,
 7.647719428900999,
 5.704767099899999,
 5.762391009999999,
 7.840798999999999,
 8.9301,
 5.99,
 1.0]

## Loss Function

The combination of chosen actions and weights sort of behaves like labels. This loss function is essentially just cross entropy loss (negative log likelihood loss), except the loss for each sample is weighted by the expected return at that timestep.

In [68]:
def loss_func(activations, chosen_acts, weights):
    '''
    outputs must be a tensor, shape is (~step_batch_size, n_acts), dtype is torch.float32
    chosen_acts must be a tensor, shape is (~step_batch_size) dtype is torch.int32
    weights must be a tensor, shape is (~step_batch_size) dtype is torch.float32
    '''
    assert activations.dtype == torch.float32
    assert chosen_acts.dtype == torch.int32
    assert weights.dtype == torch.float32
    log_probs = Categorical(logits=activations).log_prob(chosen_acts)  # Returns a batch of nll losses
        # log_prob does cross entropy loss (softmax --> take prob corres to chosen class --> log --> negative)
    return -(log_probs * weights).mean()

In [59]:
inputs = torch.randn(5, 4, dtype=torch.float32)
inputs

tensor([[-0.2520, -1.0268,  0.7331, -0.7374],
        [ 2.5440, -0.7746,  0.5921,  0.2919],
        [-0.9547, -0.1100,  1.3107, -2.3316],
        [-1.5341, -0.4752, -0.3159, -1.7549],
        [-1.1733,  0.8263, -0.3815,  0.4881]])

In [60]:
activations = net(inputs)
activations

tensor([[ 0.2991,  0.0791],
        [-0.5278, -0.1625],
        [ 0.5637,  0.3496],
        [ 0.8434,  0.3493],
        [ 0.2714,  0.3622]], grad_fn=<AddmmBackward>)

In [77]:
actions = np.random.randint(0, 2, size=(5,))
actions

array([0, 1, 0, 1, 1])

In [74]:
cum_disc_rewards = accumulate_discount(np.random.randint(-2, 6, size=(5,)).tolist())
cum_disc_rewards

[2.98010298, -0.02009800000000017, 0.9897999999999998, 3.02, -2.0]

In [78]:
batch_loss = loss_func(activations=torch.tensor(activations, dtype=torch.float32), 
                       chosen_acts=torch.tensor(actions, dtype=torch.int32), 
                       weights=torch.tensor(cum_disc_rewards, dtype=torch.float32)
                      )
batch_loss

  batch_loss = loss_func(activations=torch.tensor(activations, dtype=torch.float32),


tensor(0.7928)

## Training Loop

### Test forward propagation:

In [30]:
state = np.random.randn(4)
state

array([2.30437415, 0.7246936 , 1.55096719, 0.12279306])

Looks like the `nn.Sequential` object will take a single sample with no problem.

In [51]:
activations = net(torch.tensor(state, dtype=torch.float32))
activations

tensor([-0.6728,  0.1078], grad_fn=<AddBackward0>)

### Test action sampling:

`Categorical` exhibits weird behaviour...

`Categorical`'s `probs` argument takes in a tensor of 'probabilities' in range `[0, inf)` ie. non-negative but does not need to sum to 1, as the class will automatically normalize the values to make the distribution. Make sure to sigmoid or softmax activations before passing this argument.

`Categorical`'s `logits` argument takes a tensor of values in range `(-inf, inf)` and will turn it into a probability distribution that sums to 1, probably with softmax but idk.

In [48]:
probs_list = [0.25, 0.25, 0.21, 0.56]
dist = torch.distributions.categorical.Categorical(probs=torch.tensor(probs_list))
print(f"Normalize: {dist.probs}\nThen natural log: {dist.logits}")

Normalize: tensor([0.1969, 0.1969, 0.1654, 0.4409])
Then natural log: tensor([-1.6253, -1.6253, -1.7997, -0.8188])


In [49]:
logits_list = [-1.05, -0.15, 0.41, 1.20]
dist = torch.distributions.categorical.Categorical(logits=torch.tensor(logits_list))
print(f"Softmax of logits: {dist.probs}\nThen natural log: {dist.logits}")

Softmax of logits: tensor([0.0580, 0.1426, 0.2496, 0.5499])
Then natural log: tensor([-2.8480, -1.9480, -1.3880, -0.5980])


### Training loop:

In [92]:
def train(env_name='CartPole-v0', 
          hidden_sizes=[32], 
          lr=1e-2, 
          num_epochs=50, 
          step_batch_size=5000, 
          render=False
         ):
    
    env = gym.make(env_name)
    assert isinstance(env.observation_space, Box), \
        "This example only works for envs with continuous state spaces."
    assert isinstance(env.action_space, Discrete), \
        "This example only works for envs with discrete action spaces."
    
    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n
    
    net = mlp(sizes=[obs_dim]+hidden_sizes+[n_acts])
    
    optimizer = optim.Adam(net.parameters(), lr=lr)
    
    for epoch in range(1, num_epochs+1):
        
        # Epoch-specific variables, resets each epoch
        batch_states = []      # State at each step, shape is (num steps over all episodes this epoch ie. >= step_batch_size, obs_dim)
        batch_acts = []        # Action at each step, shape is (num steps over all episodes this epoch, n_acts)
        batch_weights = []     # Cumulative future discounted reward at each step, shape is (num steps over all episodes this epoch)
        batch_ep_rets = []     # Returns for each episode in epoch, shape is (num episodes this epoch)
        batch_ep_lens = []     # Lengths (number of steps) of each episode in epoch, shape is (num episodes this epoch)
        
        # Episode-specific variables, resets each episode
        cur_state = env.reset()
        done = False
        ep_rewards = []
        render_episode = True
        
        while True:
            
            activations = net(torch.tensor(cur_state, dtype=torch.float32))
            action = Categorical(logits=activations).sample().item()
            
            next_state, reward, done, _ = env.step(action)
            
            batch_states.append(cur_state.copy())
            batch_acts.append(action)
            ep_rewards.append(reward)
            
            cur_state = next_state
            
            if render_episode and render:
                env.render()
            
            if done:
                # If episode over record info about episode
                ep_ret, ep_len = sum(ep_rewards), len(ep_rewards)
                batch_ep_rets.append(ep_ret)
                batch_ep_lens.append(ep_len)
                
                batch_weights.extend(accumulate_discount(ep_rewards))
                
                # Reset episode-specific variables
                cur_state = env.reset()
                done = False
                ep_rewards = []        
                render_episode = False
                
                if len(batch_states) >= step_batch_size:
                    '''
                    We are only allowed to break at the end of an episode.
                    If at the end of this episode we finally have enough steps,
                        then we take this opportunity to break and call it an epoch.
                    '''
                    break

                    
        optimizer.zero_grad()
        batch_loss = loss_func(activations=net(torch.tensor(batch_states, dtype=torch.float32)), 
                               chosen_acts=torch.tensor(batch_acts, dtype=torch.int32), 
                               weights=torch.tensor(batch_weights, dtype=torch.float32)
                              )
        batch_loss.backward()
        optimizer.step()
        print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                (epoch, batch_loss, np.mean(batch_ep_rets), np.mean(batch_ep_lens)))

In [93]:
train()

epoch: 9999 	 loss: 9.801 	 return: 23.986 	 ep_len: 23.986
epoch: 9999 	 loss: 9.919 	 return: 24.544 	 ep_len: 24.544
epoch: 9999 	 loss: 11.491 	 return: 28.600 	 ep_len: 28.600
epoch: 9999 	 loss: 11.989 	 return: 30.303 	 ep_len: 30.303
epoch: 9999 	 loss: 13.136 	 return: 33.079 	 ep_len: 33.079
epoch: 9999 	 loss: 13.969 	 return: 40.008 	 ep_len: 40.008
epoch: 9999 	 loss: 15.235 	 return: 43.241 	 ep_len: 43.241
epoch: 9999 	 loss: 15.558 	 return: 47.781 	 ep_len: 47.781
epoch: 9999 	 loss: 16.142 	 return: 50.520 	 ep_len: 50.520
epoch: 9999 	 loss: 16.218 	 return: 52.905 	 ep_len: 52.905
epoch: 9999 	 loss: 16.700 	 return: 54.543 	 ep_len: 54.543
epoch: 9999 	 loss: 18.658 	 return: 63.087 	 ep_len: 63.087
epoch: 9999 	 loss: 19.156 	 return: 69.162 	 ep_len: 69.162
epoch: 9999 	 loss: 19.559 	 return: 74.044 	 ep_len: 74.044
epoch: 9999 	 loss: 20.407 	 return: 79.921 	 ep_len: 79.921
epoch: 9999 	 loss: 21.241 	 return: 82.426 	 ep_len: 82.426
epoch: 9999 	 loss: 21.964