In [4]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete, Box

In [15]:
def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a fully connected feed-forward network.

    Args:
        sizes: layer widths, input first and output last; each consecutive
            pair becomes one ``nn.Linear``.
        activation: module class instantiated after every hidden linear layer.
        output_activation: module class instantiated after the last linear layer.

    Returns:
        An ``nn.Sequential`` alternating linear layers and activations
        (empty when ``sizes`` has fewer than two entries).
    """
    last_layer_idx = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(fan_in, fan_out))
        # only the final linear layer is followed by the output activation
        if idx < last_layer_idx:
            modules.append(activation())
        else:
            modules.append(output_activation())
    return nn.Sequential(*modules)

In [142]:
def train(
    env="CartPole-v0",
    hidden_sizes=[32],
    lr=1e-02,
    epochs=50,
    batch_size=5000,
    render=False,
    use_reward_to_go=False,
):
    """Train a categorical MLP policy with a basic policy-gradient update.

    Each epoch collects whole episodes with the current policy until more than
    ``batch_size`` steps have been gathered, then takes a single Adam step on
    the surrogate loss ``-(log pi(a|s) * weight).mean()`` and prints one
    summary line.

    Args:
        env: environment id passed to ``gym.make``.
        hidden_sizes: widths of the hidden layers of the policy network
            (any sequence; it is copied, never mutated).
        lr: Adam learning rate.
        epochs: number of collect-then-update iterations.
        batch_size: minimum number of environment steps gathered per epoch.
        render: when True, ``env.render()`` is called during the first episode
            of every epoch.
        use_reward_to_go: when True each step is weighted by the sum of rewards
            from that step onward; otherwise by the full episode return.

    Raises:
        AssertionError: if the observation space is not a ``Box`` or the
            action space is not ``Discrete``.
    """
    env = gym.make(env)
    assert isinstance(env.observation_space, Box), "This example works only for envs with continuous state spaces"
    assert isinstance(env.action_space, Discrete), "This example works only for envs with discrete action space"

    observation_dims = env.observation_space.shape[0]
    n_actions = env.action_space.n

    # list(...) copies the shared default argument and also accepts tuples
    logits_net = mlp(sizes=[observation_dims] + list(hidden_sizes) + [n_actions])

    def reward_to_go(rewards):
        """Return, for every step, the sum of rewards from that step to the episode end."""
        n = len(rewards)
        rtgs = np.zeros_like(rewards)
        for i in reversed(range(n)):
            rtgs[i] = rewards[i] + (rtgs[i + 1] if i + 1 < n else 0)
        return rtgs

    def get_policy(obs):
        """Build the action distribution pi(a|s); categorical because actions are discrete."""
        logits = logits_net(obs)
        # Categorical accepts raw logits, so no explicit softmax is needed
        return Categorical(logits=logits)

    def get_action(obs):
        """Sample one action (a Python int) for a single observation."""
        # sampling needs no gradients, so skip building the autograd graph
        with torch.no_grad():
            # .item() extracts the value from a tensor containing a single element
            return get_policy(obs).sample().item()

    def compute_loss(obs, actions, weights):
        """Surrogate objective whose gradient, for on-policy data, is the policy gradient.

        When the weights are the episode returns, the gradient of this value
        equals the sample policy gradient; see
        https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html
        It is not an actual loss function in the supervised-learning sense.
        """
        # log-probability of each taken action under the current policy
        logp = get_policy(obs).log_prob(actions)
        return -(logp * weights).mean()

    optimizer = Adam(logits_net.parameters(), lr=lr)

    def train_one_epoch():
        """Collect one batch of episodes and take one policy-gradient step.

        Returns:
            (loss value as a float, list of episode returns, list of episode lengths)
        """
        batch_obs = []          # observations
        batch_actions = []      # actions
        batch_weights = []      # R(tau) / reward-to-go weights for the policy gradient
        batch_returns = []      # episode returns, for logging
        batch_lens = []         # episode lengths, for logging

        # episode-specific state
        obs, _ = env.reset()    # first obs comes from the starting distribution
        ep_rewards = []         # rewards accrued throughout the current episode

        finished_rendering_this_epoch = False

        # collect experience by acting in the environment with the current policy
        while True:
            # NOTE(review): gymnasium normally expects a render_mode to be chosen at
            # gym.make() time; without one this call probably only warns -- confirm.
            if (not finished_rendering_this_epoch) and render:
                env.render()

            batch_obs.append(obs.copy())

            action = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, reward, terminated, truncated, _ = env.step(action)
            # an episode ends on a terminal state OR on truncation (e.g. a time
            # limit); ignoring truncation lets rollouts run past the episode cap
            done = terminated or truncated

            batch_actions.append(action)
            ep_rewards.append(reward)

            if done:
                # episode is over: record its return and length
                ep_return, ep_len = sum(ep_rewards), len(ep_rewards)
                batch_returns.append(ep_return)
                batch_lens.append(ep_len)

                # the weight for each logprob(a|s) is R(tau) or the reward-to-go
                if use_reward_to_go:
                    batch_weights += list(reward_to_go(ep_rewards))
                else:
                    batch_weights += [ep_return] * ep_len

                # reset episode-specific state
                (obs, _), ep_rewards = env.reset(), []
                # won't render again this epoch
                finished_rendering_this_epoch = True

                # end the experience loop once enough steps were gathered
                if len(batch_obs) > batch_size:
                    break

        # take a single policy gradient update step
        optimizer.zero_grad()
        batch_loss = compute_loss(
            # stack into one ndarray first: converting a list of arrays is slow
            obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
            actions=torch.as_tensor(batch_actions, dtype=torch.int32),
            weights=torch.as_tensor(batch_weights, dtype=torch.float32)
        )
        batch_loss.backward()
        optimizer.step()
        # hand back a plain float so no autograd graph outlives the update
        return batch_loss.item(), batch_returns, batch_lens

    try:
        for i in range(epochs):
            batch_loss, batch_returns, batch_lens = train_one_epoch()
            print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                    (i, batch_loss, np.mean(batch_returns), np.mean(batch_lens)))
    finally:
        # release the environment's resources even if training is interrupted
        env.close()

In [58]:
# Baseline run: default hyperparameters with rendering requested; every step of an
# episode is weighted by the full episode return (use_reward_to_go defaults to False).
train(render=True)

epoch:   0 	 loss: 18.793 	 return: 22.545 	 ep_len: 22.545
epoch:   1 	 loss: 27.058 	 return: 30.311 	 ep_len: 30.311
epoch:   2 	 loss: 24.918 	 return: 29.645 	 ep_len: 29.645
epoch:   3 	 loss: 29.556 	 return: 36.065 	 ep_len: 36.065
epoch:   4 	 loss: 33.602 	 return: 41.397 	 ep_len: 41.397
epoch:   5 	 loss: 37.174 	 return: 44.561 	 ep_len: 44.561
epoch:   6 	 loss: 36.712 	 return: 47.358 	 ep_len: 47.358
epoch:   7 	 loss: 39.574 	 return: 51.051 	 ep_len: 51.051
epoch:   8 	 loss: 40.138 	 return: 53.213 	 ep_len: 53.213
epoch:   9 	 loss: 39.717 	 return: 52.758 	 ep_len: 52.758
epoch:  10 	 loss: 43.056 	 return: 56.539 	 ep_len: 56.539
epoch:  11 	 loss: 40.463 	 return: 56.088 	 ep_len: 56.088
epoch:  12 	 loss: 40.984 	 return: 59.140 	 ep_len: 59.140
epoch:  13 	 loss: 41.284 	 return: 58.419 	 ep_len: 58.419
epoch:  14 	 loss: 39.991 	 return: 56.629 	 ep_len: 56.629
epoch:  15 	 loss: 39.541 	 return: 56.989 	 ep_len: 56.989
epoch:  16 	 loss: 44.369 	 return: 63.9

In [76]:
# Same defaults, but each step is weighted by its reward-to-go (sum of rewards from
# that step onward) instead of the whole-episode return.
train(use_reward_to_go=True)

epoch:   0 	 loss: 10.597 	 return: 22.484 	 ep_len: 22.484
epoch:   1 	 loss: 13.204 	 return: 27.794 	 ep_len: 27.794
epoch:   2 	 loss: 11.977 	 return: 26.898 	 ep_len: 26.898
epoch:   3 	 loss: 14.334 	 return: 31.879 	 ep_len: 31.879
epoch:   4 	 loss: 14.291 	 return: 32.172 	 ep_len: 32.172
epoch:   5 	 loss: 18.655 	 return: 40.516 	 ep_len: 40.516
epoch:   6 	 loss: 16.793 	 return: 39.016 	 ep_len: 39.016
epoch:   7 	 loss: 19.588 	 return: 44.804 	 ep_len: 44.804
epoch:   8 	 loss: 20.537 	 return: 46.343 	 ep_len: 46.343
epoch:   9 	 loss: 21.589 	 return: 52.031 	 ep_len: 52.031
epoch:  10 	 loss: 20.967 	 return: 52.188 	 ep_len: 52.188
epoch:  11 	 loss: 20.233 	 return: 53.564 	 ep_len: 53.564
epoch:  12 	 loss: 23.753 	 return: 63.175 	 ep_len: 63.175
epoch:  13 	 loss: 27.168 	 return: 70.620 	 ep_len: 70.620
epoch:  14 	 loss: 30.306 	 return: 79.429 	 ep_len: 79.429
epoch:  15 	 loss: 30.020 	 return: 84.017 	 ep_len: 84.017
epoch:  16 	 loss: 32.824 	 return: 83.5

In [111]:
def get_policy(obs):
    """Treat ``obs`` directly as unnormalized logits and wrap them in a Categorical.

    Scratch helper for inspecting distribution shapes: no network is involved,
    and the logits need no softmax because ``Categorical`` normalizes them.
    """
    raw_logits = torch.as_tensor(obs)
    policy = Categorical(logits=raw_logits)
    return policy

In [137]:
# Dummy batch for the shape check: 20 rows of 5 standard-normal "logits",
# plus one all-zero action index per row.
obs = np.random.standard_normal(size=(20, 5))
actions = np.zeros(shape=(20,))

In [139]:
# Shape check: log_prob over a (20, 5) batch of logits with 20 action indices
# gives one log-probability per row.
get_policy(obs).log_prob(torch.as_tensor(actions)).shape

torch.Size([20])