In [None]:
# default_exp actorcritic.a2c

In [None]:
import os
# Debugging aid: make CUDA kernel launches synchronous so a device-side error is
# reported at the call that caused it (slows GPU execution). Presumably placed in
# the first cell so it is set before CUDA is initialised.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastrl.ptan_extension import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
# Unless the IN_TEST environment variable is set (presumably by the nbdev test
# runner), require an interactive Jupyter/IPython session that is not Google Colab.
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# A2C

> Synchronous Actor Critic

In [None]:
# export
class LinearA2C(nn.Module):
    "Actor-critic network: separate two-layer MLP heads for the policy logits and the state value."
    def __init__(self, input_shape, n_actions, hidden_size=512):
        """
        Args:
            input_shape: observation shape; only `input_shape[0]` (the flat feature size) is used.
            n_actions: number of discrete actions, i.e. the width of the policy output.
            hidden_size: width of the hidden layer in each head (defaults to 512, the
                previously hard-coded value).
        """
        super().__init__()
        n_in = input_shape[0]

        # Layers are created policy-first, then value, so parameter initialisation
        # consumes the torch RNG in the same order as before.
        self.policy = nn.Sequential(
            nn.Linear(n_in, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions)
        )

        self.value = nn.Sequential(
            nn.Linear(n_in, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        "Return `(policy_logits, state_value)` for batch `x`; `x` is cast to float first."
        fx = x.float()
        return self.policy(fx), self.value(fx)

In [None]:
# export
def unbatch(batch, net, val_gamma,device='cpu'):
    """Convert a list of first/last experiences into A2C training tensors.

    Args:
        batch: sequence of experience records exposing tensor attributes `state`,
            `action`, `reward`, `done` and `last_state` (`last_state` is only read
            when `done` is falsy).
        net: model whose output tuple carries state values at index 1; column 0 of
            that output is used as V(last_state).
        val_gamma: discount applied to the bootstrapped V(last_state)
            (e.g. `gamma ** n_steps`).
        device: device the returned tensors are placed on.

    Returns:
        `(states_v, actions_t, ref_vals_v)`: float32 states, int64 actions, and float32
        reference values `reward + val_gamma * V(last_state)` for non-terminal
        experiences (just `reward` for terminal ones).
    """
    states, actions, rewards = [], [], []
    not_done_idx, last_states = [], []
    for idx, exp in enumerate(batch):
        states.append(exp.state.detach().cpu().numpy())
        actions.append(int(exp.action.detach().cpu()))
        rewards.append(float(exp.reward.detach().cpu()))
        if not exp.done:
            not_done_idx.append(idx)
            last_states.append(exp.last_state.detach().cpu().numpy())

    # Stack into one ndarray before building the tensor: constructing a tensor from
    # a Python list of ndarrays is very slow and triggers a PyTorch warning.
    states_v = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)
    actions_t = torch.as_tensor(actions, dtype=torch.long, device=device)

    rewards_np = np.asarray(rewards, dtype=np.float32)
    if not_done_idx:
        last_states_v = torch.as_tensor(np.asarray(last_states, dtype=np.float32), device=device)
        # The bootstrap value is only a regression target, so don't build an autograd graph.
        with torch.no_grad():
            last_vals_v = net(last_states_v)[1]
        last_vals_np = last_vals_v.detach().cpu().numpy()[:, 0]
        rewards_np[not_done_idx] += val_gamma * last_vals_np

    ref_vals_v = torch.as_tensor(rewards_np, device=device)
    return states_v, actions_t, ref_vals_v

In [None]:
# export
class A2CTrainer(Callback):
    "Callback that clips gradient norms after backward and folds the policy loss into `learn.loss` after the step."

    def after_backward(self):
        "Clip the total gradient norm of every model parameter to `learn.clip_grad`."
        learn = self.learn
        nn_utils.clip_grad_norm_(learn.model.parameters(), learn.clip_grad)

    def after_step(self):
        "Add the separately back-propagated policy loss to `learn.loss` (presumably so logged losses cover all A2C terms)."
        learn = self.learn
        learn.loss += learn.loss_policy_v

In [None]:
# export
def loss_func(pred,a,r,sp,d,episode_rewards,learn=None):
    """A2C loss for an `A2CLearner`.

    The policy-gradient term is back-propagated inside this function (graph retained);
    the returned value + entropy term is left for the training loop to back-propagate.

    Args:
        pred: model output for the batch; unused, because the model is re-run on
            the states produced by `unbatch`.
        a, r, sp, d: per-sample actions, rewards, last states and done flags, used to
            rebuild `ExperienceFirstLast` records when `learn.yb` does not already hold them.
        episode_rewards: unused.
        learn: the learner; reads `xb`, `yb`, `model`, `opt`, `discount`,
            `reward_steps` and `entropy_beta`.

    Returns:
        `entropy_beta * mean(sum(p * log p)) + MSE(V(s), r_est)`. The policy loss is
        stored on `learn.loss_policy_v` as a side effect.
    """
    yb=learn.yb[0]
    if not isinstance(yb[0],ExperienceFirstLast):
        # Collated tensors: rebuild one experience record per state in the batch.
        yb=[ExperienceFirstLast(state=state,action=a[i],reward=r[i],last_state=sp[i],done=d[i],episode_reward=0)
            for i,state in enumerate(learn.xb[0])]

    # n-step targets: V(last_state) is discounted by discount**reward_steps.
    s_t,a_t,r_est=unbatch(yb,learn.model,learn.discount**learn.reward_steps,default_device())

    learn.opt.zero_grad()
    logits_v,value_v=learn.model(s_t)
    value_v=value_v.squeeze(-1)

    loss_value_v=F.mse_loss(value_v,r_est)

    log_prob_v=F.log_softmax(logits_v,dim=1)
    adv_v=r_est-value_v.detach()

    # log pi(a_t|s_t) for each sample: gather along the action axis so the row count
    # always matches `a_t`, whichever branch built the batch.
    log_prob_actions_v=adv_v*log_prob_v.gather(1,a_t.unsqueeze(1)).squeeze(1)
    loss_policy_v=-log_prob_actions_v.mean()

    prob_v=F.softmax(logits_v,dim=1)
    # sum(p * log p) is the negative entropy, so minimising this term raises entropy.
    entropy_loss_v=learn.entropy_beta*(prob_v*log_prob_v).sum(dim=1).mean()

    # Calculate the policy gradients only; keep the graph for the later backward on the
    # returned loss. Skipped when autograd is disabled (e.g. validation), where
    # `backward` would raise.
    if loss_policy_v.requires_grad:
        loss_policy_v.backward(retain_graph=True)

    learn.loss_policy_v=loss_policy_v
    return entropy_loss_v+loss_value_v

class A2CLearner(AgentLearner):
    "AgentLearner for synchronous A2C: binds the module-level `loss_func` to itself and uses an AdamW optimizer."
    def __init__(self,dls,discount=0.99,entropy_beta=0.01,clip_grad=0.1,reward_steps=1,**kwargs):
        # `loss_func` reads the learner's model, optimizer and hyper-parameters, so bind `self` to it.
        bound_loss=partial(loss_func,learn=self)
        super().__init__(dls,loss_func=bound_loss,**kwargs)
        # NOTE(review): eps=1e-3 is far larger than AdamW's default (1e-8); presumably for stability — confirm.
        adamw=AdamW(self.model.parameters(),eps=1e-3)
        self.opt=OptimWrapper(adamw)
        # Hyper-parameters consumed by `loss_func` (discount, entropy_beta, reward_steps)
        # and by `A2CTrainer` (clip_grad).
        self.clip_grad=clip_grad
        self.reward_steps=reward_steps
        self.entropy_beta=entropy_beta
        self.discount=discount

Note that A2C without additional stabilization loses what it has learned very quickly. Even after it reaches a reward of 200+, it will eventually **forget** the strategies that got it there, and the reward curve will drop back down. Some form of early stopping on the reward is therefore recommended. A more interesting solution is to experiment with ways of keeping the agent stable, for example by retaining samples that it deems "important".

In [None]:
# Train A2C on CartPole-v1 with 4-step first/last experiences.
# NOTE: statement order matters for reproducing the output below, since building the
# model consumes the torch RNG.
env='CartPole-v1'
# 4 observation features -> 2 discrete actions (CartPole's spaces; TODO confirm if `env` changes).
model=LinearA2C((4,),2)
agent=ActorCriticAgent(model=model.to(default_device()),device=default_device())

# `n_steps` must agree with `reward_steps=4` passed to the learner below: `loss_func`
# bootstraps V(last_state) with discount**reward_steps.
block=FirstLastExperienceBlock(agent=agent,seed=0,n_steps=4,dls_kwargs={'bs':128,'num_workers':0,'verbose':False,'indexed':True,'shuffle_train':False})
# NOTE: `(block)` is just `block`, not a 1-tuple.
# The splitter appears to send every item to training (valid_loss is blank in the output below).
blk=IterableDataBlock(blocks=(block),
                      splitter=FuncSplitter(lambda x:False),
#                       batch_tfms=lambda x:(x['s'],x),
                     )
# 15 copies of the env name as source items; n=128*100 presumably sets the items per
# epoch (100 batches of 128) — confirm against IterableDataBlock.
dls=blk.dataloaders([env]*15,n=128*100,device=default_device())

learner=A2CLearner(dls,agent=agent,cbs=[A2CTrainer],reward_steps=4,metrics=[AvgEpisodeRewardMetric()])
# Sanity print: mean of the per-parameter means, to compare initialisations across runs.
print('model start',np.mean([o.cpu().detach().numpy().mean() for o in learner.model.parameters()]))
learner.fit(10,lr=0.001,wd=0)

model start 0.0030695


epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,17.594112,26.006762,,26.006762,00:12
1,26.896385,27.03896,,27.03896,00:12
2,31.740593,50.093048,,50.093048,00:12
3,49.083435,81.3525,,81.3525,00:13
4,53.767052,113.584167,,113.584167,00:11
5,72.958267,143.551667,,143.551667,00:11
6,78.430984,162.593333,,162.593333,00:11
7,84.090927,160.688333,,160.688333,00:11
8,98.564392,182.206667,,182.206667,00:11
9,110.922218,192.848333,,192.848333,00:11


  warn("Your generator is empty.")


In [None]:
# hide
from nbdev.export import *
from nbdev.export2html import *
# nbdev: export the notebooks' `# export` cells to the library and rebuild the HTML
# docs (see the "Converted ..." / "converting: ..." output below).
notebook2script()
notebook2html()

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_learner.ipynb.
Converted 05a_data.ipynb.
Converted 05b_async_data.ipynb.
Converted 13_metrics.ipynb.
Converted 14_actorcritic.sac.ipynb.
Converted 15_actorcritic.a3c_data.ipynb.
Converted 16_actorcritic.a2c.ipynb.
Converted index.ipynb.
Converted notes.ipynb.


converting: /opt/project/fastrl/nbs/14_actorcritic.sac.ipynb
converting: /opt/project/fastrl/nbs/16_actorcritic.a2c.ipynb


In [None]:
import torch
# On machines with a GPU, clear the notebook namespace and restart the kernel,
# presumably to release GPU memory after the run. This is intentionally the last cell.
if torch.cuda.is_available():
    %reset -f
    import IPython
    app = IPython.Application.instance()
    # `True` requests a restart rather than a plain shutdown (ipykernel `do_shutdown(restart)`).
    app.kernel.do_shutdown(True) 