In [1]:
#| default_exp vpg_core

In [1]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

### MLP Generator

In [3]:
#| export
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        sizes: sequence of layer widths ``[input, hidden..., output]``. Each
            adjacent pair of widths becomes one ``nn.Linear`` layer, so ``n``
            sizes yield ``n - 1`` linear layers.
        activation: module class (not an instance) instantiated after every
            linear layer except the last one.
        output_activation: module class instantiated after the last linear
            layer. Defaults to ``nn.Identity``.

    Returns:
        nn.Sequential alternating ``nn.Linear`` and activation modules. It is
        empty when ``sizes`` has fewer than two entries.
    """
    # index of the final linear layer, the only one followed by output_activation
    final_idx = len(sizes) - 2

    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)

In [4]:
# Demo network: 4 inputs -> two hidden layers of 32 units -> 1 output.
hidden_sizes = [32, 32]
in_size = 4
out_size = 1

# ReLU follows the hidden layers; the output layer is followed by Identity.
model = mlp([in_size]+hidden_sizes+[out_size], nn.ReLU, nn.Identity)
model

Sequential(
  (0): Linear(in_features=4, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=32, bias=True)
  (3): ReLU()
  (4): Linear(in_features=32, out_features=1, bias=True)
  (5): Identity()
)

In [5]:
# Each Linear layer contributes a weight of shape (out_features, in_features)
# followed by a bias of shape (out_features,).
[p.shape for p in model.parameters()]

[torch.Size([32, 4]),
 torch.Size([32]),
 torch.Size([32, 32]),
 torch.Size([32]),
 torch.Size([1, 32]),
 torch.Size([1])]

In [6]:
# Total parameter count: add up the number of elements in every weight and bias.
sum([np.prod(p.shape) for p in model.parameters()])

1249

### Count Params in a Model

In [7]:
#| export
def count_params(model):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: an ``nn.Module`` (or any object whose ``parameters()`` yields
            tensors).

    Returns:
        int: sum of the element counts of all parameters; ``0`` when the model
        has no parameters.
    """
    # numel() is always a Python int, including for 0-dim parameters where
    # np.prod(p.shape) would return the float 1.0.
    return sum(p.numel() for p in model.parameters())

In [8]:
# The exported helper gives the same total as the inline sum over parameter shapes.
count_params(model)

1249

### Get Combined Shape of a Batch of Inputs

In [9]:
#| export
def combined_shape(length, shape=None):
    """Prepend ``length`` to ``shape`` and return the result as a tuple.

    Args:
        length: leading dimension (e.g. the number of stored items).
        shape: ``None``, a scalar, or an iterable of dimensions describing a
            single item.

    Returns:
        ``(length,)`` when ``shape`` is ``None``, ``(length, shape)`` when
        ``shape`` is a scalar, otherwise ``(length, *shape)``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length,) + tuple(shape)

In [10]:
# A length of 300 combined with a per-item shape of [4] gives (300, 4).
combined_shape(300, [4])

(300, 4)

### Discounted Cumulative Sum

In [11]:
#| export
import scipy.signal

def discount_cumsum(x, discount):
    """
    Discounted cumulative sum along axis 0 (filter trick taken from rllab).

    input:
        vector x,
        [x0,
         x1,
         x2]

    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2]
    """
    # The recursion y[t] = x[t] + discount * y[t-1] is the linear filter with
    # numerator [1] and denominator [1, -discount].
    numerator = [1]
    denominator = [1, float(-discount)]

    # Filter the reversed sequence so the sum accumulates from the end,
    # then reverse again to restore the original order.
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter(numerator, denominator, reversed_x, axis=0)
    return filtered[::-1]

In [12]:
# With all-ones input and discount 0.9, entry i is the sum of 0.9**k for k = 0 .. 4-i.
discount_cumsum(np.ones(5), 0.9)

array([4.0951, 3.439 , 2.71  , 1.9   , 1.    ])

### Actors

#### Base Actor Class

In [13]:
#| export
class Actor(nn.Module):
    """Abstract policy module.

    Subclasses supply the two hooks below; ``forward`` combines them.
    """

    def _distribution(self, obs):
        """Return the action distribution for ``obs`` (subclass hook)."""
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of ``act`` under ``pi`` (subclass hook)."""
        raise NotImplementedError

    def forward(self, obs, act=None):
        """Return ``(pi, log_prob)`` for the given observation.

        ``log_prob`` is ``None`` unless an action ``act`` is supplied, in
        which case it is that action's log-probability under ``pi``.
        """
        dist = self._distribution(obs)
        if act is None:
            return dist, None
        return dist, self._log_prob_from_distribution(dist, act)

#### MLP Categorical Actor

In [14]:
#| export
from torch.distributions.categorical import Categorical

class MLPCategoricalActor(Actor):
    """Policy over discrete actions: an MLP produces the logits of a Categorical."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Create the logits network mapping ``obs_dim`` inputs to ``act_dim`` logits.

        ``hidden_sizes`` lists the hidden layer widths and ``activation`` is
        the module class used after each hidden layer.
        """
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.logits_net = mlp(layer_sizes, activation)

    def _distribution(self, obs):
        """Return the ``Categorical`` distribution defined by the network's logits."""
        return Categorical(logits=self.logits_net(obs))

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of the action(s) ``act`` under ``pi``."""
        return pi.log_prob(act)

In [15]:
# Categorical policy for 16-dim observations and 2 discrete actions,
# with two hidden layers of 64 units.
actor = MLPCategoricalActor(obs_dim=16, act_dim=2, hidden_sizes=[64, 64], activation=nn.ReLU)
actor

MLPCategoricalActor(
  (logits_net): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
    (5): Identity()
  )
)

In [16]:
# Works on a single observation (16,) and on a batch (128, 16);
# log_prob is None because no action is passed.
actor(torch.randn(16)), actor(torch.randn(128, 16))

((Categorical(logits: torch.Size([2])), None),
 (Categorical(logits: torch.Size([128, 2])), None))

In [17]:
# Sample an action from the distribution returned for a single observation.
act = actor(torch.randn(16))[0].sample()
act

tensor(0)

In [18]:
# Passing an action as well makes forward() return its log-probability.
actor(torch.randn(16), torch.tensor([0]))

(Categorical(logits: torch.Size([2])),
 tensor([-0.7752], grad_fn=<SqueezeBackward1>))

#### MLP Gaussian Actor

In [19]:
#| export
from torch.distributions.normal import Normal

class MLPGuassianActor(Actor):
    """Policy over continuous actions: a diagonal Gaussian.

    The mean comes from an MLP over the observation; the log standard
    deviation is a learnable parameter that does not depend on the observation.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        """Create the log-std parameter (initialised to -0.5) and the mean network."""
        super().__init__()
        initial_log_std = np.full(act_dim, -0.5, dtype=np.float32)
        self.log_std = nn.Parameter(torch.as_tensor(initial_log_std))
        self.mu_net = mlp([obs_dim, *hidden_sizes, act_dim], activation)

    def _distribution(self, obs):
        """Return ``Normal(mu(obs), exp(log_std))``."""
        mean = self.mu_net(obs)
        return Normal(mean, self.log_std.exp())

    def _log_prob_from_distribution(self, pi, act):
        """Return the log-probability of ``act``, summed over its last dimension."""
        # pi.log_prob gives one value per action dimension; the log-prob of the
        # whole action is the sum of those values.
        per_dim = pi.log_prob(act)
        return per_dim.sum(dim=-1)

In [20]:
# Gaussian policy for 16-dim observations and 2-dim continuous actions.
# log_std is an nn.Parameter rather than a submodule, so it is not listed in the printed module tree.
actor = MLPGuassianActor(obs_dim=16, act_dim=2, hidden_sizes=[64, 64], activation=nn.ReLU)
actor

MLPGuassianActor(
  (mu_net): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
    (5): Identity()
  )
)

In [21]:
# Without an action only the distribution is returned; with a 2-dim action the
# per-dimension log-probs are summed into a single scalar.
actor(torch.randn(16)), actor(torch.randn(16), torch.tensor([0.2, 0.25])) 

((Normal(loc: torch.Size([2]), scale: torch.Size([2])), None),
 (Normal(loc: torch.Size([2]), scale: torch.Size([2])),
  tensor(-1.1237, grad_fn=<SumBackward1>)))

### Critics

#### MLP Critic

In [22]:
#| export
class MLPCritic(nn.Module):
    """State-value function: an MLP mapping an observation to a single value."""

    def __init__(self, obs_dim, hidden_sizes, activation):
        """Create the value network with ``obs_dim`` inputs and one output."""
        super().__init__()
        self.v_net = mlp([obs_dim, *hidden_sizes, 1], activation)

    def forward(self, obs):
        """Return the value estimate(s) for ``obs`` without the trailing size-1 dim."""
        values = self.v_net(obs)
        # drop the last dimension: (batch, 1) -> (batch)
        return values.squeeze(-1)

In [23]:
# Value network over 16-dim observations with two hidden layers of 64 units.
critic = MLPCritic(obs_dim=16, hidden_sizes=[64,64], activation=nn.ReLU)
critic

MLPCritic(
  (v_net): Sequential(
    (0): Linear(in_features=16, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=1, bias=True)
    (5): Identity()
  )
)

In [24]:
# A batch of 8 observations yields 8 values; the trailing size-1 dimension is squeezed away.
critic(torch.randn(8, 16))

tensor([ 0.0894,  0.0939, -0.0332, -0.1243, -0.0339, -0.0149, -0.0384,  0.0430],
       grad_fn=<SqueezeBackward1>)

### ActorCritic

In [25]:
#| export
import gym

class MLPActorCritic(nn.Module):
    """Policy (``pi``) and value function (``v``) built for a pair of gym spaces.

    Args:
        observation_space: space whose ``shape[0]`` gives the observation size.
        action_space: ``gym.spaces.Discrete`` (categorical policy) or
            ``gym.spaces.Box`` (Gaussian policy).
        hidden_sizes: hidden layer widths shared by the policy and value networks.
        activation: module class used after each hidden layer.

    Raises:
        TypeError: if ``action_space`` is neither ``Discrete`` nor ``Box``.
    """

    def __init__(self, observation_space, action_space, 
                 hidden_sizes=(64, 64), 
                 activation=nn.Tanh):
        super().__init__()
        
        obs_dim = observation_space.shape[0]
        
        # create a categorical actor for action space of type spaces.Discrete
        # and a gaussian actor for action space of type spaces.Box
        if isinstance(action_space, gym.spaces.Discrete):
            self.pi = MLPCategoricalActor(obs_dim, action_space.n, hidden_sizes, activation)
        elif isinstance(action_space, gym.spaces.Box):
            self.pi = MLPGuassianActor(obs_dim, action_space.shape[0], hidden_sizes, activation)
        else:
            # fail here instead of leaving self.pi undefined and crashing later in step()
            raise TypeError(
                f"Unsupported action space type {type(action_space).__name__}; "
                "expected gym.spaces.Discrete or gym.spaces.Box"
            )
            
        # create the value function network
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)
        
    def step(self, obs):
        """Sample an action for ``obs`` and evaluate it, without tracking gradients.

        Returns:
            Tuple of numpy arrays ``(action, value, log_prob)``.
        """
        with torch.no_grad():
            # get the pi distribution for this observation
            pi = self.pi._distribution(obs)
            # sample actions from pi
            act = pi.sample()
            # get the log_prob of this action
            log_prob = self.pi._log_prob_from_distribution(pi, act)
            # get the value of this state
            v = self.v(obs)
        # .cpu() is a no-op for CPU tensors and makes .numpy() valid for tensors on other devices
        return act.cpu().numpy(), v.cpu().numpy(), log_prob.cpu().numpy()
    
    def act(self, obs):
        """Get only the actions for this observation"""
        return self.step(obs)[0]

In [26]:
# CartPole provides a 4-dim Box observation space and 2 discrete actions.
env = gym.make("CartPole-v0")
env.observation_space, env.action_space



(Box(4,), Discrete(2))

In [27]:
# The Discrete action space selects the categorical actor; the default
# architecture (two hidden layers of 64 units, Tanh) is used for both networks.
actor_critic = MLPActorCritic(env.observation_space, env.action_space)
actor_critic

MLPActorCritic(
  (pi): MLPCategoricalActor(
    (logits_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
      (4): Linear(in_features=64, out_features=2, bias=True)
      (5): Identity()
    )
  )
  (v): MLPCritic(
    (v_net): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
      (4): Linear(in_features=64, out_features=1, bias=True)
      (5): Identity()
    )
  )
)

In [28]:
# Draw a random point from the observation space (not an environment rollout),
# only to get a correctly shaped input.
# NOTE(review): two sampled entries are ~1e38 in magnitude, presumably because
# those dimensions are unbounded; env.reset() would give a realistic observation.
obs = env.observation_space.sample()
obs

array([ 6.9627032e-02, -2.1017475e+38,  4.0349108e-01, -1.0240234e+38],
      dtype=float32)

In [29]:
# step() is given a float32 tensor built from the numpy observation and
# returns numpy values for the action, the state value and the log-prob.
act, val, log_prob = actor_critic.step(torch.as_tensor(obs, dtype=torch.float32))
act, val, log_prob

(array(0), array(0.41797066, dtype=float32), array(-0.4955001, dtype=float32))

In [30]:
#| hide
# import nbdev

# nbdev.nbdev_export()