In [1]:
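# Soft Actor-Critic (SAC; Haarnoja et al., 2018, arXiv:1801.01290)
# Adapted from OpenAI Spinning Up's PyTorch implementation:
# https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/sac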

In [2]:
import numpy as np
import itertools
from copy import deepcopy
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.optim import Adam

# Bounds applied (via torch.clamp in SquashedGaussianMLPActor.forward) to the
# policy's predicted log standard deviation, so std = exp(log_std) stays within
# [e^-20, e^2] and cannot collapse to 0 or blow up.
LOG_STD_MAX = 2
LOG_STD_MIN = -20

In [3]:
def combined_shape(length, shape=None):
    """Return the array shape for storing `length` items of shape `shape`.

    `shape` may be None (scalar items), a scalar (1-D items) or a sequence
    (multi-dimensional items); the result is always a tuple led by `length`.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def layers(sizes, activation, output_activation=nn.Identity):
    """Build an MLP as an ``nn.Sequential`` of (Linear, activation) pairs.

    Args:
        sizes: layer widths, input width first and output width last.
        activation: module class instantiated after every hidden Linear.
        output_activation: module class instantiated after the final Linear.

    Returns:
        nn.Sequential containing ``2 * (len(sizes) - 1)`` modules.
    """
    n_linear = len(sizes) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == n_linear - 1 else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)

def count_vars(module):
    """Return the total number of scalar parameters held by `module`."""
    return sum(np.prod(param.shape) for param in module.parameters())

In [4]:
class SquashedGaussianMLPActor(nn.Module):
    """Gaussian policy whose samples are squashed by tanh into [-act_limit, act_limit].

    Args:
        obs_dim: size of the flat observation vector.
        act_dim: size of the action vector.
        hidden_sizes: widths of the hidden layers of the shared trunk.
        activation: activation module class used after every trunk layer.
        act_limit: scalar bound applied to every action dimension after tanh.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        # The trunk ends with `activation` as well: the mean and log-std heads
        # below act as its output layers.
        self.net = layers([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Return an action for `obs` and, optionally, its log-probability.

        Args:
            obs: observation tensor whose last axis has size obs_dim; either a
                single observation or a batch with any leading dimensions.
            deterministic: if True, use the mean action (test-time behaviour)
                instead of sampling.
            with_logprob: if True, also compute log pi(a|s) of the squashed action.

        Returns:
            (pi_action, logp_pi): pi_action has the leading dims of `obs` plus
            act_dim. logp_pi has just the leading dims, or is None when
            `with_logprob` is False.
        """
        net_out = self.net(obs)

        mu = self.mu_layer(net_out)
        log_std = torch.clamp(self.log_std_layer(net_out), LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        pi_distribution = Normal(mu, std)
        # rsample() uses the reparameterisation trick, so gradients flow into mu/std.
        pi_action = mu if deterministic else pi_distribution.rsample()

        if with_logprob:
            # Gaussian log-likelihood of the pre-squash sample u, minus the tanh
            # change-of-variables term log(1 - tanh(u)^2). That term is written in
            # the numerically stable form 2*(log 2 - u - softplus(-2u)); see the
            # SAC paper (arXiv 1801.01290), appendix C, Eq. 21.
            # Both sums reduce over the LAST axis so that unbatched observations
            # work too (summing the correction over axis 1 raised for 1-D input).
            logp_pi = pi_distribution.log_prob(pi_action).sum(dim=-1)
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(dim=-1)
        else:
            logp_pi = None

        pi_action = self.act_limit * torch.tanh(pi_action)
        return pi_action, logp_pi
    
class MLPQFunction(nn.Module):
    """State-action value network Q(s, a): an MLP applied to concat(obs, act)."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        widths = [obs_dim + act_dim, *hidden_sizes, 1]
        self.q = layers(widths, activation)

    def forward(self, obs, act):
        """Return Q-values with the trailing size-1 output axis removed."""
        joint = torch.cat((obs, act), dim=-1)
        # Critical to ensure q has right shape: dropping the size-1 axis makes Q
        # line up with the 1-D reward / done tensors it is regressed against.
        return self.q(joint).squeeze(-1)
    
class MLPActorCritic(nn.Module):
    """Container for a squashed-Gaussian policy `pi` and twin critics `q1`, `q2`.

    Args:
        observation_space: space object; only ``shape[0]`` is read.
        action_space: space object; ``shape[0]`` gives the action size and
            ``high[0]`` is used as the bound for every action dimension
            (NOTE(review): assumes symmetric, identical per-dimension bounds).
        hidden_sizes: hidden-layer widths shared by the policy and both critics.
        activation: activation module class for all three networks.
    """

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256), activation=nn.ReLU):
        super().__init__()
        obs_size = observation_space.shape[0]
        act_size = action_space.shape[0]
        bound = action_space.high[0]

        # Registration order (pi, q1, q2) fixes the order of .parameters().
        self.pi = SquashedGaussianMLPActor(obs_size, act_size, hidden_sizes, activation, bound)
        self.q1 = MLPQFunction(obs_size, act_size, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_size, act_size, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return the policy's action for `obs` as a numpy array (no autograd graph)."""
        with torch.no_grad():
            action, _ = self.pi(obs, deterministic, False)
        return action.numpy()

In [5]:
class ReplayBuffer:
    '''A simple FIFO experience replay buffer for SAC agents.

    Transitions live in preallocated float32 numpy arrays. Once `size`
    transitions have been stored, new ones overwrite the oldest.
    '''
    def __init__(self, obs_dim, act_dim, size):
        obs_shape = combined_shape(size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.terminated_buf = np.zeros(size, dtype=np.float32)
        self.max_size = size
        self.ptr = 0   # next slot to write
        self.size = 0  # number of slots holding valid data

    def store(self, obs, act, rew, next_obs, terminated):
        """Write one transition at the write cursor, wrapping around when full."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.terminated_buf[slot] = terminated
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """Draw `batch_size` stored transitions uniformly (with replacement) as float32 tensors."""
        idxs = np.random.randint(0, self.size, size=batch_size)
        arrays = {
            'obs': self.obs_buf,
            'obs2': self.obs2_buf,
            'act': self.act_buf,
            'rew': self.rew_buf,
            'terminated': self.terminated_buf,
        }
        return {key: torch.as_tensor(buf[idxs], dtype=torch.float32)
                for key, buf in arrays.items()}

In [6]:
def sac(env_name='HalfCheetah-v4',
        actor_critic=MLPActorCritic,
        steps_per_epoch=5000,
        epochs=50,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        pi_lr=0.001,
        q_lr=0.001,
        alpha=0.2,
        batch_size=100,
        start_steps=10000,
        update_after=1000,
        update_every=50,
        num_test_episodes=10,
        max_ep_len=1000,
        seed=None):
    """Train a Soft Actor-Critic agent on a Gymnasium environment.

    Args:
        env_name: Gymnasium environment id. A second instance is created for evaluation.
        actor_critic: constructor called as ``actor_critic(obs_space, act_space)``;
            the result must expose ``pi``, ``q1``, ``q2`` and ``act``.
        steps_per_epoch: environment steps between evaluations.
        epochs: number of epochs; total steps = steps_per_epoch * epochs.
        replay_size: replay buffer capacity.
        gamma: discount factor.
        polyak: target-network interpolation factor. Each update sets
            target = polyak * target + (1 - polyak) * online.
        pi_lr: Adam learning rate for the policy.
        q_lr: Adam learning rate for both Q-networks.
        alpha: fixed entropy-regularisation coefficient.
        batch_size: minibatch size per gradient step.
        start_steps: number of initial steps taken with uniformly random actions.
        update_after: environment steps to collect before the first update.
        update_every: steps between update phases; each phase runs this many
            gradient steps.
        num_test_episodes: deterministic evaluation episodes per epoch.
        max_ep_len: cap on episode length, applied in addition to the
            environment's own truncation signal.
        seed: optional int. When given, seeds torch, numpy's global RNG (used
            by the replay buffer), the action space and both environments.

    Returns:
        (ac, epoch_ret): the trained actor-critic and a list holding the mean
        evaluation return of each epoch.
    """
    env = gym.make(env_name)
    test_env = gym.make(env_name)
    if seed is not None:
        # Seed every RNG the run draws from: torch (network init, policy
        # sampling), numpy (replay sampling), the action space (random warm-up
        # actions) and the evaluation env (its later resets continue from here).
        torch.manual_seed(seed)
        np.random.seed(seed)
        env.action_space.seed(seed)
        test_env.reset(seed=seed + 1)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape[0]

    ac = actor_critic(env.observation_space, env.action_space)
    ac_targ = deepcopy(ac)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # List of parameters for both Q networks. This must be a real list: a bare
    # itertools.chain is a one-shot iterator, so once Adam consumed it the
    # freeze/unfreeze loops in update() silently iterated over nothing.
    q_params = list(itertools.chain(ac.q1.parameters(), ac.q2.parameters()))

    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    print('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n'%var_counts)

    def compute_loss_q(data):
        """Sum of both critics' MSE losses against the soft Bellman target."""
        o, a, r, o2, terminated = data['obs'], data['act'], data['rew'], data['obs2'], data['terminated']

        q1 = ac.q1(o,a)
        q2 = ac.q2(o,a)
        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from CURRENT policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values: clipped double-Q minus the entropy term.
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - terminated) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2
        return loss_q

    def compute_loss_pi(data):
        """Entropy-regularised policy loss: mean of alpha * log pi(a|s) - min(Q1, Q2)."""
        o = data['obs']
        pi, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, pi)
        q2_pi = ac.q2(o, pi)
        q_pi = torch.min(q1_pi, q2_pi)
        loss_pi = (alpha * logp_pi - q_pi).mean() # Entropy-regularized policy loss
        return loss_pi

    # Optimizers
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    q_optimizer = Adam(q_params, lr=q_lr)

    def update(data):
        """One SAC update: critic step, policy step, then polyak target update."""
        # First run one gradient descent step for Q1 and Q2.
        q_optimizer.zero_grad()
        loss_q = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        # Freeze Q-networks so no gradients are computed for them during the
        # policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so they can be optimised at the next SAC step.
        for p in q_params:
            p.requires_grad = True

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use an in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        """Query the policy with a raw observation and return a numpy action."""
        return ac.act(torch.as_tensor(o, dtype=torch.float32), deterministic)

    def test_agent():
        """Run num_test_episodes deterministic episodes on test_env; return their returns."""
        ep_rets = []
        for _ in range(num_test_episodes):
            o, _ = test_env.reset()
            terminated = truncated = False
            ep_ret, ep_len = 0, 0
            while not (terminated or truncated or (ep_len == max_ep_len)):
                # Take deterministic (mean) actions at test time
                o, r, terminated, truncated, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            ep_rets.append(ep_ret)
        return ep_rets

    total_steps = steps_per_epoch * epochs
    o, _ = env.reset(seed=seed)
    epoch_n = 0
    ep_len = 0
    epoch_ret = []

    try:
        for t in range(total_steps):
            # For the first start_steps steps, sample actions uniformly for better
            # exploration; afterwards, use the learned policy.
            if t >= start_steps:
                a = get_action(o)
            else:
                a = env.action_space.sample()

            o2, r, terminated, truncated, _ = env.step(a)
            ep_len += 1
            # Only true termination is stored: a time-limit truncation is not a
            # terminal state, so the backup must still bootstrap from o2.
            replay_buffer.store(o, a, r, o2, terminated)
            o = o2

            # End of trajectory: termination, the env's own truncation (stepping
            # past it is not allowed) or our max_ep_len cap.
            if terminated or truncated or (ep_len == max_ep_len):
                o, _ = env.reset()
                ep_len = 0

            if t >= update_after and t % update_every == 0:
                for _ in range(update_every):
                    batch = replay_buffer.sample_batch(batch_size)
                    update(data=batch)

            if (t+1) % steps_per_epoch == 0:
                epoch_n += 1
                ep_rets = test_agent()
                epoch_ret.append(np.mean(ep_rets))
                print('Epoch: %3d \t Mean epoch return %.3f \t '% (epoch_n, epoch_ret[-1]))
    finally:
        # Release simulator / renderer resources even if training is interrupted.
        env.close()
        test_env.close()

    return ac, epoch_ret

In [None]:
# Run experiment
# NOTE(review): this cell is recorded as "In [None]" (not executed in the saved
# notebook), yet the visualisation cell below uses `ac` from it. On a fresh
# kernel, run this cell first.

ac, epoch_ret = sac() # steps_per_epoch=5000, epochs=50 -> 250,000 TotalEnvInteracts

In [8]:
# Visualisation
# Roll out one episode with the trained agent `ac` (created by the training
# cell above; run that cell first on a fresh kernel) and record rendered frames.

env = gym.make('HalfCheetah-v4', render_mode='rgb_array')
o, _ = env.reset()

# Use the deterministic (mean) action, the same way test_agent() evaluates the
# policy during training. Record the initial frame plus one frame after every
# step, so the final state of the episode is included.
frames = [env.render()]
terminated = False
truncated = False
while not (terminated or truncated):
    a = ac.act(torch.as_tensor(o, dtype=torch.float32), True)
    o, _, terminated, truncated, _ = env.step(a)
    frames.append(env.render())
env.close()

from PIL import Image
def create_gif(frames, filename='sac.gif', duration=50):
    """Save a sequence of image arrays as a looping GIF.

    Args:
        frames: non-empty sequence of arrays accepted by ``Image.fromarray``
            (presumably RGB frames from ``render_mode='rgb_array'``).
        filename: output path.
        duration: display time per frame in milliseconds. GIF stores delays in
            1/100 s units, so the previous hard-coded 1 ms became 0 and viewers
            fell back to their own default delay.

    Raises:
        ValueError: if `frames` is empty.
    """
    if not frames:
        raise ValueError('create_gif() needs at least one frame')
    images = [Image.fromarray(frame) for frame in frames]
    images[0].save(filename, save_all=True, append_images=images[1:],
                   optimize=False, duration=duration, loop=0)

# Play back at the environment's advertised render rate, with 20 fps as a
# fallback when the env does not declare one.
create_gif(frames, duration=int(1000 / (env.metadata.get('render_fps') or 20)))