In [1]:
# default_exp actorcritic.sac

In [2]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastai.callback.progress import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [3]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# SAC

> Soft Actor Critic

In [None]:
# export
def combined_shape(length, shape=None):
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)

def mlp(sizes, activation, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes)-1):
        act = activation if j < len(sizes)-2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
    return nn.Sequential(*layers)

def count_vars(module):
    return sum([np.prod(p.shape) for p in module.parameters()])


LOG_STD_MAX = 2
LOG_STD_MIN = -20

class SquashedGaussianMLPActor(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            pi_action = pi_distribution.rsample()

        if with_logprob:
            # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
            # NOTE: The correction formula is a little bit magic. To get an understanding 
            # of where it comes from, check out the original SAC paper (arXiv 1801.01290) 
            # and look in appendix C. This is a more numerically-stable equivalent to Eq 21.
            # Try deriving it yourself as a (very difficult) exercise. :)
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2*(np.log(2) - pi_action - F.softplus(-2*pi_action))).sum(axis=1)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi


class MLPQFunction(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1) # Critical to ensure q has right shape.

class MLPActorCritic(nn.Module):

    def __init__(self, observation_space, action_space, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        super().__init__()

        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0]
        act_limit = action_space.high[0]

        # build policy and value functions
        self.pi = SquashedGaussianMLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit)
        self.q1 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        with torch.no_grad():
            a, _ = self.pi(obs, deterministic, False)
            return a.numpy()

In [None]:
# export
class ReplayBuffer:
    """
    A simple FIFO experience replay buffer for SAC agents.
    """

    def __init__(self, obs_dim, act_dim, size):
        self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, act, rew, next_obs, done):
        self.obs_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr+1) % self.max_size
        self.size = min(self.size+1, self.max_size)

    def sample_batch(self, batch_size=32):
        idxs = np.random.randint(0, self.size, size=batch_size)
        batch = dict(obs=self.obs_buf[idxs],
                     obs2=self.obs2_buf[idxs],
                     act=self.act_buf[idxs],
                     rew=self.rew_buf[idxs],
                     done=self.done_buf[idxs])
        return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in batch.items()}



In [4]:
# export
class Critic(Module):
    def __init__(self, input_shape, n_actions):
        self.q=nn.Sequential(
            nn.Linear(input_shape[0]+n_actions, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self,x):
        fx=x.float()
        return self.q(fx)
    
class Actor(Module):
    def __init__(self, input_shape, n_actions):
        self.actions=nn.Sequential(
            nn.Linear(input_shape[0], 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self,x):
        fx=x.float()
        return self.actions(fx)

In [5]:
class SACModule(Module):
    def __init__(self, input_shape, n_actions):
        self.critic=Critic(input_shape,n_actions)
        self.critic_target=deepcopy(self.critic)
        
        self.actor=Actor(input_shape,n_actions)
        
        self.critic_mode=False
        
    def forward(self,*args):
        return self.actor(*args)
    
    def soft_copy(self,tau=0.99):
        for target_param, local_param in zip(self.critic_target.parameters(),self.critic.parameters()):
            target_param.data.copy_(tau*local_param.data+(1.0-tau)*target_param.data)


In [6]:
# export
class SACTrainer(Callback):
    def __init__(self):     self.batch_n=0
    def before_fit(self):   self.batch_n=0
    def after_batch(self):  
        if self.batch_n%self.soft_copy_freq==0: self.model.soft_copy()
        self.batch_n+=1            
            
    def after_backward(self):
        nn_utils.clip_grad_norm_(self.learn.model.parameters(),self.learn.clip_grad)

In [None]:
# export
# Set up function for computing SAC Q-losses
def compute_loss_q(data):
    o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

    q1 = ac.q1(o,a)
    q2 = ac.q2(o,a)

    # Bellman backup for Q functions
    with torch.no_grad():
        # Target actions come from *current* policy
        a2, logp_a2 = ac.pi(o2)

        # Target Q-values
        q1_pi_targ = ac_targ.q1(o2, a2)
        q2_pi_targ = ac_targ.q2(o2, a2)
        q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
        backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

    # MSE loss against Bellman backup
    loss_q1 = ((q1 - backup)**2).mean()
    loss_q2 = ((q2 - backup)**2).mean()
    loss_q = loss_q1 + loss_q2

    return loss_q, q_info

# Set up function for computing SAC pi loss
def compute_loss_pi(data):
    o = data['obs']
    pi, logp_pi = ac.pi(o)
    q1_pi = ac.q1(o, pi)
    q2_pi = ac.q2(o, pi)
    q_pi = torch.min(q1_pi, q2_pi)

    # Entropy-regularized policy loss
    loss_pi = (alpha * logp_pi - q_pi).mean()

    return loss_pi, pi_info


In [9]:
# export
def loss_func(pred,yb,learn):

    q_optimizer.zero_grad()
    loss_q, q_info = compute_loss_q(data)
    loss_q.backward()
    q_optimizer.step()

    # Record things
    logger.store(LossQ=loss_q.item(), **q_info)

    # Freeze Q-networks so you don't waste computational effort 
    # computing gradients for them during the policy learning step.
    for p in q_params:
        p.requires_grad = False

    # Next run one gradient descent step for pi.
    pi_optimizer.zero_grad()
    loss_pi, pi_info = compute_loss_pi(data)
    loss_pi.backward()
    pi_optimizer.step()

    # Unfreeze Q-networks so you can optimize it at next DDPG step.
    for p in q_params:
        p.requires_grad = True

    # Record things
    logger.store(LossPi=loss_pi.item(), **pi_info)

    # Finally, update target networks by polyak averaging.
    with torch.no_grad():
        for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
            # NB: We use an in-place operations "mul_", "add_" to update target
            # params, as opposed to "mul" and "add", which would make new tensors.
            p_targ.data.mul_(polyak)
            p_targ.data.add_((1 - polyak) * p.data)

    return loss_v

class SACLearner(AgentLearner):
    def __init__(self,dls,actor_critic=core.MLPActorCritic, ac_kwargs=dict(), seed=0, 
                    steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99, 
                    polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000, 
                    update_after=1000, update_every=50, max_ep_len=1000, 
                    logger_kwargs=dict(), save_freq=1,**kwargs):
        store_attr()
        super().__init__(dls,loss_func=partial(loss_func,learn=self),**kwargs)

        
#     def _split(self, b):
#         if len(b)==1 and type(b[0])==tuple:b=b[0]
#         super()._split(b)

In [16]:
class ContinousActorCriticAgent(BaseAgent):
    a_selector: fastrl.basic_agents.ActionSelector = None,
    device: str = None,
    preprocessor: Callable = default_states_preprocessor,
    apply_softmax: bool = False
        
    def __post_init__(self):
        self.a_selector=ifnone(self.a_selector,ProbabilityActionSelector())
        self.apply_softmax=True
    
    @torch.no_grad()
    def __call__(self,sl,asl,include_batch_dim=False):
        x=self.preprocessor(x) if self.preprocessor is not None else s
        asl=np.zeros(x.shape) if asl is None or len(asl)==0 else asl
        if torch.is_tensor(x):
            x=x.to(self.device)
        actions_dist=self.model.actor(x)
        

In [10]:
env='CartPole-v1'
model=SACModule((4,),2).to(device=default_device())
agent=PolicyAgent(model=model,device=default_device())

block=FirstLastExperienceBlock(agent=agent,seed=0,n_steps=4,dls_kwargs={'bs':128,'num_workers':0,'verbose':False,'indexed':True,'shuffle_train':False})
blk=IterableDataBlock(blocks=(block),
                      splitter=FuncSplitter(lambda x:False),
#                       batch_tfms=lambda x:(x['s'],x),
                     )
dls=blk.dataloaders([env]*15,n=128*100,device=default_device())

learner=SACLearner(dls,agent=agent,cbs=[SACTrainer],metrics=[AvgEpisodeRewardMetric()])
learner.fit(400,lr=0.001,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,0.0,5.25,00:00,,


TypeError: loss_func() got multiple values for argument 'learn'

In [None]:
# hide
from nbdev.export import *
from nbdev.export2html import *
notebook2script()
notebook2html(n_workers=0)