In [312]:
# default_exp a3c.a3c_data
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [313]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# A3C Data

> A decoupled actor critic agent which trains on data collected from environments running in a completely separate process.

In [314]:
# export
# from fastai.basic_data import *
import torch.nn.utils as nn_utils
from fastai.torch_core import *
from fastai.callbacks import *
from fastai.basic_train import *
from fastai.callback import *
from fastai.basic_data import *
from fastrl.wrappers import *
from fastrl.basic_agents import *
from fastrl.basic_train import *
from fastrl.data_block import *
from fastrl.metrics import *
from dataclasses import asdict
from functools import partial
from fastprogress.fastprogress import IN_NOTEBOOK
from fastcore.utils import *
import torch.multiprocessing as mp
import torch.optim as optim
from queue import Empty
import textwrap
import logging
import gym

logging.basicConfig(format='[%(asctime)s] p%(process)s line:%(lineno)d %(levelname)s - %(message)s',
                    datefmt='%m-%d %H:%M:%S')
_logger=logging.getLogger(__name__)

In [315]:
# hide
_logger.setLevel('INFO')
from fastcore.foundation import *
import sys

In [316]:
# export
def a3c_data_fitter(model:Optional[nn.Module],learner_cls:Optional['AgentLearner'],agent:Optional['BaseAgent'],ds_cls:ExperienceSourceDataset,
            pause_event:mp.Event,cancel_event:mp.Event,main_queue:Optional[mp.JoinableQueue],metric_queue:Optional[mp.JoinableQueue],display=False,
            rows=1,cols=1,max_w=800):
    "A3C fitter for AsyncExperienceSourceDataset."
    ds=ds_cls()
    if display:ds=DatasetDisplayWrapper(ds,rows=rows,cols=cols,max_w=max_w)
    dl=DataLoader(ds,batch_size=1,num_workers=0)
    if learner_cls is not None:
        learn=learner_cls(data=DataBunch(dl,dl),model=model,agent=agent)
        ds.learn=learn
    try:
        while not cancel_event.is_set():
            for xb,yb in ds:
#                 print(yb)
                while pause_event.is_set() and not self.cancel_event.is_set():cancel_event.wait(0.1)
                if main_queue is not None:main_queue.put(yb)
            if metric_queue is not None:
                total_rewards=ds.pop_total_rewards()
                if total_rewards:
#                     print(total_rewards)
                    sys.stdout.flush()
                    if metric_queue.full():_logger.warning('Metric queue is full. Increase its size,empty it, or set metric_queue to None.')
                    metric_queue.put(TotalRewards(total_rewards))                    
            while pause_event.is_set():pass
    finally:
        main_queue.put(None)
        metric_queue.put(None)
        cancel_event.set()
        sys.stdout.flush()

@dataclass
class A3CLearner(AgentLearner):
    fitter_fn:Callable=a3c_data_fitter
    batch_sz:int=128
    discount:float=0.99
    entropy_beta:float=0.01
    clip_grad:float=0.1
    def init(self, init):print('skipping')
        
    def __post_init__(self):
        super(A3CLearner,self).__post_init__()
        if self.model is None:self.model=self.agent.model
        if self.agent.model is None: self.agent.model=self.model
        self.model.share_memory()
        self.opt=OptimWrapper(AdamW(self.model.parameters(),eps=1e-3))
        
    def predict(self,s):
        out=self.agent(s)
        if type(out)==tuple:return out[0],None
        return out,None

In [317]:
batch=[
 Experience(s=tensor([[-0.0285,  0.1640, -0.0033, -0.3421]]),sp=tensor([[-0.0285,  0.1640, -0.0033, -0.3421]]),
            a=tensor([1]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0252, -0.0311, -0.0101, -0.0504]]),sp=tensor([[-0.0252, -0.0311, -0.0101, -0.0504]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0258, -0.2261, -0.0111,  0.2391]]),sp=tensor([[-0.0258, -0.2261, -0.0111,  0.2391]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0517, -0.2260,  0.0195,  0.2377]]),sp=tensor([[-0.0517, -0.2260,  0.0195,  0.2377]]),
            a=tensor([1]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0562, -0.4214,  0.0242,  0.5365]]),sp=tensor([[-0.0562, -0.4214,  0.0242,  0.5365]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.]),agent_s=tensor([[[0.]]])),
 Experience(s=tensor([[-0.0647, -0.6169,  0.0349,  0.8367]]),sp=tensor([[-0.0647, -0.6169,  0.0349,  0.8367]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([1.]),agent_s=tensor([[[0.]]]))
]
class LinearA2C(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LinearA2C, self).__init__()

        self.policy = nn.Sequential(
            nn.Linear(input_shape[0], 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

        self.value = nn.Sequential(
            nn.Linear(input_shape[0], 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        o=self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self,x):
        fx=x.float()
#         batch_sz=fx.shape[0]
#         t=torch.full((batch_sz,2),0,dtype=float)
#         t[:,0]=1.0
        
        return self.policy(fx),self.value(fx)
model=LinearA2C((4,),2)

In [318]:
def getBack(var_grad_fn):
    print(var_grad_fn)
    for n in var_grad_fn.next_functions:
        if n[0]:
            try:
                tensor = getattr(n[0], 'variable')
                print(n[0])
                print('Tensor with grad found:', tensor)
                print(' - gradient:', tensor.grad)
                print()
            except AttributeError as e:
                getBack(n[0])

In [319]:
# export
def r_estimate(s,r,d_mask,model,val_gamma,device):
    "Returns rewards `r` estimated direction by `model` from states `s`"
    r_np=np.array(r,dtype=np.float32)
#     print(len(d_mask),len(r),len(s))
    if len(d_mask)!=0:
        s_v=torch.FloatTensor(s).to(device)
        v=model(s_v)[1] # Remember that models are going to return the actions and the values
        v_np=v.data.cpu().numpy()[:,0]
        r_np[d_mask]+=val_gamma*v_np
        
    return r_np

def unbatch(batch,model,last_val_gamma,device='cpu')->Tuple(List,List,List):
    s,a,r,d_mask,sp=[],[],[],[],[]
    for i,exp in enumerate(batch):
        s.append(exp.s.numpy())
        a.append(int(exp.a.numpy())) # TODO can we change this to toggle between discrete and continuous actions?
        r.append(exp.r.numpy().astype(np.float32))
        if not bool(exp.d):
            d_mask.append(i)
            sp.append(exp.sp.numpy())
    s_t=torch.FloatTensor(s).to(device)
    a_t=torch.LongTensor(a).to(device)
    
    r_np=r_estimate(sp,r,d_mask,model,last_val_gamma,device)
    estimated_r=torch.FloatTensor(r_np).to(device)
    return s_t,a_t,estimated_r

In [320]:
unbatch(batch,model,2)

(tensor([[[-0.0285,  0.1640, -0.0033, -0.3421]],
 
         [[-0.0252, -0.0311, -0.0101, -0.0504]],
 
         [[-0.0258, -0.2261, -0.0111,  0.2391]],
 
         [[-0.0517, -0.2260,  0.0195,  0.2377]],
 
         [[-0.0562, -0.4214,  0.0242,  0.5365]],
 
         [[-0.0647, -0.6169,  0.0349,  0.8367]]]),
 tensor([1, 0, 0, 1, 0, 0]),
 tensor([[0.8744],
         [0.8577],
         [0.8597],
         [0.8421],
         [0.8626],
         [1.0000]]))

In [321]:
# export
debug_batch=[]

class A3CTrainer(LearnerCallback):
    def __init__(self,*args,**kwargs):
        super(A3CTrainer,self).__init__(*args,**kwargs)
        self.batch=[]
        
    @property
    def skip_process_batch(self):return len(self.batch)<self.learn.data.bs
    def on_train_begin(self,**kwargs):self.batch.clear()
    
    def on_batch_begin(self,last_target,**kwargs):
        self.batch.extend([Experience(**{k:v[i] if len(v)!=0 else None for k,v in last_target.items()}) for i in range(len(last_target['s']))])
        
    def on_backward_begin(self,last_loss,**kwargs):
        if self.skip_process_batch:return {'skip_bwd':self.skip_process_batch}
        s_t,a_t,r_est=unbatch(self.batch,self.learn.model,self.learn.discount**self.data.ds_kwargs['skip_n_steps'])
        self.learn.opt.zero_grad()
        logits_v,value_v=self.learn.model(s_t)

        loss_value_v=F.mse_loss(value_v.squeeze(-1),r_est)
#         print((r_est.mean(),value_v.mean()))
        log_prob_v=F.log_softmax(logits_v,dim=1)
        adv_v=r_est-value_v.detach()
        log_prob_actions_v=adv_v*log_prob_v[range(self.learn.data.bs),a_t]
        loss_policy_v=-log_prob_actions_v.mean()
        
        prob_v=F.softmax(logits_v,dim=1)
#         print(prob_v.max(),log_prob_v.max(),prob_v.min(),log_prob_v.min())
        entropy_loss_v=self.learn.entropy_beta*(prob_v*log_prob_v).sum(dim=1).mean()
    
        print(entropy_loss_v,loss_policy_v,loss_value_v)
        loss_v=entropy_loss_v+loss_policy_v+loss_value_v
        
        self.learn.loss_func.loss=loss_v.detach()
        return {'last_loss':loss_v,'skip_bwd':self.skip_process_batch}
        
            
    def on_backward_end(self,*args,**kwargs): 
        if not self.skip_process_batch:nn_utils.clip_grad_norm_(self.learn.model.parameters(),self.learn.clip_grad)
        return {'skip_bwd':self.skip_process_batch,
                'skip_step':self.skip_process_batch,
                'skip_zero':self.skip_process_batch}
    def on_step_end(self,last_loss,*args,**kwargs):
        getBack(last_loss.grad_fn)
        if self.skip_process_batch:return
        self.batch.clear()

In [322]:
data=AsyncExperienceSourceDataBunch.from_env('CartPole-v1',bs=128,n_processes=4,firstlast=True,ds_kwargs={'n_envs':15,'skip_n_steps':4},num_workers=12)
model=LinearA2C((4,),2)
agent=ActorCriticAgent(model=model)
learn=A3CLearner(data,model,agent=agent,callback_fns=[A3CTrainer,RewardMetric,NGamesMetric])
learn.fit(1,lr=0.001,wd=0)

epoch,train_loss,valid_loss,train_reward,train_n_games,time
0,0.5,#na#,17.296296,27,00:03


tensor(-0.0068, grad_fn=<MulBackward0>) tensor(2.4953, grad_fn=<NegBackward>) tensor(13.6656, grad_fn=<MseLossBackward>)
<AddBackward0 object at 0x7fb4b7b1db50>
<AddBackward0 object at 0x7fb4b7b1d2d0>
<MulBackward0 object at 0x7fb4b7b20d10>
<MeanBackward0 object at 0x7fb42b51c2d0>
<SumBackward1 object at 0x7fb4b7b20f50>
<MulBackward0 object at 0x7fb4b7b92850>
<SoftmaxBackward object at 0x7fb42b51c9d0>
<AddmmBackward object at 0x7fb42b51ce10>
<AccumulateGrad object at 0x7fb42b5840d0>
Tensor with grad found: Parameter containing:
tensor([0.0304, 0.0143], requires_grad=True)
 - gradient: tensor([ 0.0008, -0.0008])

<ReluBackward0 object at 0x7fb42b584610>
<AddmmBackward object at 0x7fb42b5848d0>
<AccumulateGrad object at 0x7fb42b51ce90>
Tensor with grad found: Parameter containing:
tensor([ 0.0836,  0.1201,  0.3913, -0.4708,  0.1459,  0.2671, -0.0359,  0.2574,
         0.4952, -0.0368, -0.3959, -0.3020,  0.3118,  0.3358,  0.1539, -0.4785,
        -0.4129,  0.0446, -0.1756,  0.4835,  0.491

In [323]:
# hide
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 02_callbacks.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_metrics.ipynb.
Converted 05_data_block.ipynb.
Converted 06_basic_train.ipynb.
Converted 12_a3c.a3c_data.ipynb.
Converted index.ipynb.
Converted notes.ipynb.


converting: /opt/project/fastrl/nbs/12_a3c.a3c_data.ipynb
