In [None]:
# default_exp a3c.a3c_data
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
# Unless the IN_TEST environment variable is set to a non-empty value, require that this
# notebook is running in a local (non-Colab) IPython notebook kernel.
_in_test=os.environ.get("IN_TEST", None)
if not _in_test:
    assert IN_NOTEBOOK and IN_IPYTHON and not IN_COLAB

# A3C Data

> A decoupled actor-critic agent that trains on data collected from environments running in separate processes.

In [None]:
# export
# from fastai.basic_data import *
import torch.nn.utils as nn_utils
from fastai.torch_core import *
from fastai.callbacks import *
from fastrl.wrappers import *
from fastrl.basic_agents import *
from fastrl.basic_train import *
from fastrl.data_block import *
from fastrl.metrics import *
from fastai.basic_train import *
from dataclasses import asdict
from functools import partial
from fastprogress.fastprogress import IN_NOTEBOOK
from fastcore.utils import *
import torch.multiprocessing as mp
from queue import Empty
import textwrap
import logging
import gym

# Root-logger format: timestamp, process id (`p%(process)s`), source line and level for every record.
logging.basicConfig(datefmt='%m-%d %H:%M:%S',
                    format='[%(asctime)s] p%(process)s line:%(lineno)d %(levelname)s - %(message)s')
_logger=logging.getLogger(name=__name__)

In [None]:
# hide
# Show INFO-and-above messages from this module while developing in the notebook.
_logger.setLevel(logging.INFO)
from fastcore.foundation import *
import sys

In [None]:
# export
# @safe_fit
def a3c_data_fitter(model,agent,ds,data_queue,pause_event,
                    cancel_event,metric_queue):
    """Producer loop: stream experience batches from `ds()` into `data_queue` until cancelled.

    The dataset is built once by calling `ds()` and iterated over repeatedly until
    `cancel_event` is set. Every `yb` it yields is put on `data_queue`.
    While `pause_event` is set, no further data is produced: the loop idles, polling
    every 0.1s, and a set `cancel_event` always ends the pause immediately.
    If `metric_queue` is not None, the values returned by `dataset.pop_total_r()`
    are averaged and forwarded as a `TotalRewards` metric whenever there are any.

    `model` and `agent` are accepted for interface compatibility but are not used here.
    """
    dataset=ds()
    while not cancel_event.is_set():
        for xb,yb in dataset:
            data_queue.put(yb)
            # Stay idle for as long as we are paused. Waiting on `cancel_event` instead of
            # sleeping lets a cancel interrupt the pause right away.
            while pause_event.is_set() and not cancel_event.is_set():
                cancel_event.wait(0.1)
            if cancel_event.is_set():break

            if metric_queue is not None:
                rs=dataset.pop_total_r()
                if len(rs)!=0:metric_queue.put(TotalRewards(np.mean(rs)))

            # Re-check before pulling the next (possibly expensive) item from the dataset.
            if cancel_event.is_set():break

@dataclass
class A3CLearner(AgentLearner):
    """`AgentLearner` for A3C whose model is shared with its agent and moved to shared memory.

    After construction the learner and the agent refer to the same model (whichever of
    the two was given one), and `model.share_memory()` is called on it. Presumably this
    lets the separate data-collection processes use the weights being trained; confirm.
    """
    fitter:Callable=a3c_data_fitter  # producer loop handed to the data-collection side (see `a3c_data_fitter`)
    batch_sz:int=128                 # number of experiences A3CTrainer buffers before each update
    discount:float=0.99              # gamma; A3CTrainer bootstraps with `discount**data.steps`
    entropy_beta:float=0.01          # weight of the entropy term in the loss
    clip_grad:float=0.1              # max gradient norm passed to `clip_grad_norm_`

    def __post_init__(self):
        super().__post_init__()
        # Make the learner and the agent use one and the same model.
        if self.model is None:self.model=self.agent.model
        if self.agent.model is None:self.agent.model=self.model
        if self.model is None:
            raise ValueError('A3CLearner needs a model: pass `model` or an `agent` whose `.model` is set.')
        self.model.share_memory()

    def predict(self,s):
        "Run the model on `s`; if it returns a tuple (e.g. `(policy, value)`), keep only the first element."
        out=self.model(s)
        return out[0] if type(out) is tuple else out

In [None]:
# Hand-built fixture of six `Experience`s with 4-dim states (CartPole-like). Every field has a
# leading dim of 1 (one transition per experience), and only the last one is terminal (d=1).
# NOTE(review): `sp` is identical to `s` in every row. Presumably it should hold the next (n-step)
# state; confirm before trusting the bootstrapped rewards computed from it below.
batch=[
 Experience(s=tensor([[-0.0285,  0.1640, -0.0033, -0.3421]]),sp=tensor([[-0.0285,  0.1640, -0.0033, -0.3421]]),
            a=tensor([1]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0252, -0.0311, -0.0101, -0.0504]]),sp=tensor([[-0.0252, -0.0311, -0.0101, -0.0504]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0258, -0.2261, -0.0111,  0.2391]]),sp=tensor([[-0.0258, -0.2261, -0.0111,  0.2391]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0517, -0.2260,  0.0195,  0.2377]]),sp=tensor([[-0.0517, -0.2260,  0.0195,  0.2377]]),
            a=tensor([1]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0562, -0.4214,  0.0242,  0.5365]]),sp=tensor([[-0.0562, -0.4214,  0.0242,  0.5365]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0647, -0.6169,  0.0349,  0.8367]]),sp=tensor([[-0.0647, -0.6169,  0.0349,  0.8367]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([1.]),agent_s=tensor([[[0.]]]))
]
class LinearA2C(nn.Module):
    """Two-headed MLP for discrete-action actor-critic.

    `forward(x)` returns `(policy_logits, value)`, shaped `(..., n_actions)` and `(..., 1)`.
    `input_shape[0]` is the flat observation size; `hidden_sz` is the width of each
    head's single hidden layer (defaults to the original 512).
    """
    def __init__(self, input_shape, n_actions, hidden_sz=512):
        super().__init__()
        n_in=input_shape[0]

        self.policy = nn.Sequential(
            nn.Linear(n_in, hidden_sz),
            nn.ReLU(),
            nn.Linear(hidden_sz, n_actions)
        )

        self.value = nn.Sequential(
            nn.Linear(n_in, hidden_sz),
            nn.ReLU(),
            nn.Linear(hidden_sz, 1)
        )

    def forward(self, x):
        # Observations may arrive as doubles/ints; the linear layers expect float32.
        fx=x.float()
        return self.policy(fx),self.value(fx)
# 4-dim observations, 2 discrete actions (matches the fixture above).
model = LinearA2C(input_shape=(4,), n_actions=2)

In [None]:
# export
def r_estimate(s,r,d_mask,model,val_gamma,device):
    """Bootstrap rewards: add `val_gamma * V(s)` to the entries of `r` listed in `d_mask`.

    `r` holds one reward entry per experience. `d_mask` lists the indices of `r` that
    are not done, and `s` holds the next states for exactly those indices, in the same
    order. `model(x)` must return a pair whose second element is the value estimate.
    Returns the rewards as a float32 numpy array; done entries are left unchanged.
    """
    r_np=np.array(r,dtype=np.float32)
    if len(d_mask)==0:return r_np
    # Stack into a single ndarray first: building a tensor from a list of ndarrays is very slow.
    s_v=torch.as_tensor(np.asarray(s,dtype=np.float32)).to(device)
    # The values are only used as targets, so don't build an autograd graph for them.
    with torch.no_grad():
        v=model(s_v)[1]
    v_np=v.cpu().numpy()[:,0]
    r_np[d_mask]+=val_gamma*v_np
    return r_np

def _as_rows(t,n_rows):
    "Return `t` as a tensor whose leading dim indexes its `n_rows` transitions, adding that dim when missing."
    t=torch.as_tensor(t)
    return t if t.dim()>1 and t.shape[0]==n_rows else t.unsqueeze(0)

def unbatch(batch,model,last_val_gamma,device='cpu')->tuple:
    """Flatten `Experience`s into training tensors `(states, actions, estimated_rewards)`.

    Each experience may carry a single transition (fields with a leading dim of 1, or an
    unbatched state) or several stacked ones (e.g. one row per env, `s` of shape
    `(n_envs, obs)`). The number of transitions in an experience is `exp.a.numel()`.
    Rows are concatenated in batch order.

    Rewards of transitions that are not done (`d == 0`) are bootstrapped with the model's
    value of the next state, `r + last_val_gamma * V(sp)`; done transitions keep their raw
    reward. `model(x)` must return `(logits, values)`.

    Returns `s_t` (float, `(N, *obs)`), `a_t` (long, `(N,)`) and `r_t` (float, `(N,)`),
    all on `device`.
    """
    states,actions,rewards,not_done,next_states=[],[],[],[],[]
    for exp in batch:
        # TODO: discrete actions only (one integer per transition); continuous actions need another layout.
        a=torch.as_tensor(exp.a).reshape(-1)
        nd=torch.as_tensor(exp.d).reshape(-1)==0
        states.append(_as_rows(exp.s,a.numel()))
        actions.append(a)
        rewards.append(torch.as_tensor(exp.r).reshape(-1))
        not_done.append(nd)
        # Only not-done rows need `sp`; fully-done experiences never touch it.
        if bool(nd.any()):next_states.append(_as_rows(exp.sp,a.numel())[nd])
    s_t=torch.cat(states).float().to(device)
    a_t=torch.cat(actions).long().to(device)
    r_t=torch.cat(rewards).float().to(device)
    if next_states:
        sp_t=torch.cat(next_states).float().to(device)
        # Bootstrap values are regression targets, not something to backpropagate through.
        with torch.no_grad():
            v=model(sp_t)[1].reshape(-1).float()
        r_t[torch.cat(not_done).to(device)]+=last_val_gamma*v
    return s_t,a_t,r_t

In [None]:
# Demo only: a real bootstrap factor is `discount**n_steps` (as A3CTrainer passes), not 2.
unbatch(batch,model,last_val_gamma=2)

(tensor([[-0.0285,  0.1640, -0.0033, -0.3421],
         [-0.0252, -0.0311, -0.0101, -0.0504],
         [-0.0258, -0.2261, -0.0111,  0.2391],
         [-0.0517, -0.2260,  0.0195,  0.2377],
         [-0.0562, -0.4214,  0.0242,  0.5365],
         [-0.0647, -0.6169,  0.0349,  0.8367]]),
 tensor([1, 0, 0, 1, 0, 0]),
 tensor([1.6767, 1.8241, 1.9454, 1.9464, 2.0984, 1.0000]))

In [None]:
# export
class A3CTrainer(LearnerCallback):
    """Buffer `Experience`s and run one advantage actor-critic update per `learn.batch_sz` of them.

    Until enough experiences are buffered, backward, optimizer step and zero-grad are
    all skipped. The loss is
    `mse(V(s), R) - mean(A * log pi(a|s)) + entropy_beta * mean(sum(p * log p))`,
    where `R` are the bootstrapped rewards from `unbatch` and `A = R - V(s)` (detached).
    Gradients are clipped to `learn.clip_grad` before the optimizer step.
    """
    def __init__(self,*args,**kwargs):
        super().__init__(*args,**kwargs)
        self.batch=[]

    @property
    def skip_process_batch(self):
        "True while fewer than `learn.batch_sz` experiences are buffered."
        return len(self.batch)<self.learn.batch_sz

    def on_train_begin(self,**kwargs):
        self.batch.clear()

    def on_batch_begin(self,last_target,**kwargs):
        # Each element of `last_target` holds the fields of one `Experience`.
        self.batch.extend([Experience(**o) for o in last_target])

    def on_backward_begin(self,last_loss,**kwargs):
        if self.skip_process_batch:return {'skip_bwd':True}

        device=next(self.learn.model.parameters()).device
        s_t,a_t,r_est=unbatch(self.batch,self.learn.model,
                              self.learn.discount**self.learn.data.steps,device=device)

        logits_v,value_v=self.learn.model(s_t)
        # The value head returns (N,1). Flatten it to line up with `r_est` (N,);
        # otherwise `r_est - value_v` silently broadcasts to an (N,N) matrix.
        value_v=value_v.squeeze(-1)
        loss_value_v=F.mse_loss(value_v,r_est)

        log_prob_v=F.log_softmax(logits_v,dim=1)
        adv_v=r_est-value_v.detach()  # detached: the policy term must not train the critic
        # log pi(a_t|s_t) for every row. The row count may exceed batch_sz, because
        # whole batches are buffered at a time.
        log_prob_actions_v=adv_v*log_prob_v.gather(1,a_t.unsqueeze(1)).squeeze(1)
        loss_policy_v=-log_prob_actions_v.mean()

        prob_v=F.softmax(logits_v,dim=1)
        # sum(p*log p) is the negative entropy, so minimizing it favours exploration.
        entropy_loss_v=self.learn.entropy_beta*(prob_v*log_prob_v).sum(dim=1).mean()

        loss_v=entropy_loss_v+loss_policy_v+loss_value_v
        self.learn.loss_func.loss=loss_v.detach()
        return {'last_loss':loss_v,'skip_bwd':False}

    def on_backward_end(self,*args,**kwargs):
        skip=self.skip_process_batch
        # Clip after backward but before the optimizer step, so clipping affects this update.
        if not skip:nn_utils.clip_grad_norm_(self.learn.model.parameters(),self.learn.clip_grad)
        return {'skip_bwd':skip,'skip_step':skip,'skip_zero':skip}

    def on_step_end(self,*args,**kwargs):
        # The buffered experiences were consumed by the update that just happened.
        if not self.skip_process_batch:self.batch.clear()

In [None]:
import pytest
# End-to-end run: 4 worker processes x 2 CartPole envs each, trained with A3CTrainer.
data=AsyncExperienceSourceDataBunch.from_env(
    'CartPole-v1',display=False,firstlast=True,add_valid=False,
    skip_step=4,n_processes=4,n_envs=2)
model=LinearA2C((4,),2)
agent=PolicyAgent(model=model)
learn=A3CLearner(data,model,agent=agent,callback_fns=[A3CTrainer,RewardMetric])
learn.fit(4,lr=1e-3,wd=0)

epoch,train_loss,valid_loss,train_reward,time


torch.Size([2, 4])


TypeError: only size-1 arrays can be converted to Python scalars

In [None]:
# hide
from nbdev.export import *
# nbdev: write the `# export` cells out to the library (`notebook2script`) and build the HTML docs
# (`notebook2html`). NOTE(review): n_workers=0 presumably converts in-process without a pool; confirm.
notebook2script()
notebook2html(n_workers=0)