In [None]:
# default_exp a3c.a3c_data

In [2]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastai.callback.progress import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

In [3]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# A3C Datawise

## A3C Model

In [4]:
# export
class LinearA2C(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LinearA2C, self).__init__()

        self.policy = nn.Sequential(
            nn.Linear(input_shape[0], 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

        self.value = nn.Sequential(
            nn.Linear(input_shape[0], 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        o=self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self,x):
        fx=x.float()
        return self.policy(fx),self.value(fx)

## A3C Learner

In [5]:
batch=[
 Experience(s=tensor([[-0.0285,  0.1640, -0.0033, -0.3421]]),sp=tensor([[-0.0285,  0.1640, -0.0033, -0.3421]]),
            a=tensor([1]),r=tensor([1.]),d=tensor([0.])),
 Experience(s=tensor([[-0.0252, -0.0311, -0.0101, -0.0504]]),sp=tensor([[-0.0252, -0.0311, -0.0101, -0.0504]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.])),
 Experience(s=tensor([[-0.0258, -0.2261, -0.0111,  0.2391]]),sp=tensor([[-0.0258, -0.2261, -0.0111,  0.2391]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.])),
 Experience(s=tensor([[-0.0517, -0.2260,  0.0195,  0.2377]]),sp=tensor([[-0.0517, -0.2260,  0.0195,  0.2377]]),
            a=tensor([1]),r=tensor([1.]),d=tensor([0.])),
 Experience(s=tensor([[-0.0562, -0.4214,  0.0242,  0.5365]]),sp=tensor([[-0.0562, -0.4214,  0.0242,  0.5365]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([0.])),
 Experience(s=tensor([[-0.0647, -0.6169,  0.0349,  0.8367]]),sp=tensor([[-0.0647, -0.6169,  0.0349,  0.8367]]),
            a=tensor([0]),r=tensor([1.]),d=tensor([1.]))
]

In [6]:
# export
def r_estimate(s,r,d_mask,non_d_mask,model,val_gamma,device):
    "Returns rewards `r` estimated direction by `model` from states `s`"
    r_np = np.array(r, dtype=np.float32)
    if len(d_mask) != 0:
        s_v = torch.FloatTensor(s).to(device)
        v = model(s_v)[1]  # Remember that models are going to return the actions and the values
        v_np = v.data.cpu().numpy()[:, 0]
        r_np[d_mask] += val_gamma * v_np
    return r_np

def unbatch(batch,model,last_val_gamma,device='cpu'):
    s, a, r, d_mask, sp = [], [], [], [], []
    non_d_mask = []
    for i, exp in enumerate(batch):
#         print(exp.s.shape,exp.r.shape,exp.sp.shape,exp.a.shape,exp.d.shape)
#         raise Exception
        s.append(exp.s.numpy()[0])
        a.append(int(exp.a.numpy()))  # TODO can we change this to toggle between discrete and continuous actions?
        r.append(exp.r.numpy().astype(np.float32).reshape(1,))
        if not bool(exp.d):
            d_mask.append(i)
            sp.append(exp.sp.numpy()[0].reshape(1,-1))
        else:
            non_d_mask.append(i)
    s_t = torch.FloatTensor(s).to(device)
    a_t = torch.LongTensor(a).to(device)
    r_np = r_estimate(sp, r, d_mask, non_d_mask,model, last_val_gamma, device)
    estimated_r = torch.FloatTensor(r_np).to(device)
    return s_t, a_t, estimated_r

In [7]:
model=LinearA2C((4,),2)
unbatch(batch,model,2)

(tensor([[-0.0285,  0.1640, -0.0033, -0.3421],
         [-0.0252, -0.0311, -0.0101, -0.0504],
         [-0.0258, -0.2261, -0.0111,  0.2391],
         [-0.0517, -0.2260,  0.0195,  0.2377],
         [-0.0562, -0.4214,  0.0242,  0.5365],
         [-0.0647, -0.6169,  0.0349,  0.8367]]),
 tensor([1, 0, 0, 1, 0, 0]),
 tensor([[0.8995],
         [0.7793],
         [0.7228],
         [0.7366],
         [0.6730],
         [1.0000]]))

In [8]:
# export
def loss_func(pred,yb,learn):
#     print(yb)
    yb=[Experience(**{k:yb[k][i] for k in yb}) for i in range(learn.dls.bs)]
    s_t,a_t,r_est=unbatch(yb,learn.model,learn.discount**learn.reward_steps)
#     print(s_t.shape,a_t.shape,r_est.shape)
    r_est=r_est.squeeze(1)

    learn.opt.zero_grad()
    logits_v,value_v=learn.model(s_t)

    loss_value_v=F.mse_loss(value_v.squeeze(-1),r_est)

    log_prob_v=F.log_softmax(logits_v,dim=1)
    adv_v=r_est-value_v.detach()

    log_prob_actions_v=adv_v*log_prob_v[range(learn.dls.bs),a_t]
    loss_policy_v=-log_prob_actions_v.mean()

    prob_v=F.softmax(logits_v,dim=1)
    entropy_loss_v=learn.entropy_beta*(prob_v*log_prob_v).sum(dim=1).mean()

    loss_v=entropy_loss_v+loss_value_v+loss_policy_v

    return loss_v

class A3CLearner(AgentLearner):
    def __init__(self,dls,discount=0.99,entropy_beta=0.01,clip_grad=0.1,reward_steps=1,**kwargs):
        self.create_m=True
        super().__init__(dls,loss_func=partial(loss_func,learn=self),**kwargs)
        self.opt=OptimWrapper(AdamW(self.model.parameters(),eps=1e-3))
        self.model.share_memory()
        self.discount=discount
        self.entropy_beta=entropy_beta
        self.reward_steps=reward_steps
        self.clip_grad=clip_grad
        
    def _split(self, b):
        if len(b)==1 and type(b[0])==tuple:b=b[0]
        super()._split(b)

In [9]:
# export
class A3CTrainer(Callback):
    
    def after_backward(self):
        nn_utils.clip_grad_norm_(self.learn.model.parameters(),self.learn.clip_grad)

In [10]:
def data_fit(queue:mp.JoinableQueue=None,items:L=None,agent:BaseAgent=None,learner_cls:Learner=None,experience_block:ExperienceBlock=None,
             cancel:mp.Event=None):
#     print(agent,flush=True)
    blk=IterableDataBlock(blocks=(experience_block(agent=agent)),
                          splitter=FuncSplitter(lambda x:False))
    dls=blk.dataloaders(items)
    while True:
        for x in dls[0]:
            queue.put(x)
            if cancel.is_set():
                queue.put(None)
                return None

In [15]:
# export
class AvgEpisodeRewardMetric(Metric):
    def __init__(self):self.rolling_rewards=[deque([0],maxlen=100)]
        
    def accumulate(self,learn):
        yb=learn.yb[0]
        yb=[Experience(**{k:yb[k][i] for k in yb}) for i in range(learn.dls.bs)]
        rewards=[y.episode_r for y in yb if y.absolute_end]
#         print([y for y in yb if y.absolute_end])
        for r in rewards:self.rolling_rewards.append(r.numpy())
#         print(len(rewards))
#         if len(rewards)!=0:self.r=sum(rewards)/len(rewards)
        
    @property
    def value(self): return np.mean(self.rolling_rewards)
    @property
    def name(self):return 'avg_episode_r'

In [None]:
env='CartPole-v1'
model=LinearA2C((4,),2)

block=AsyncExperienceBlock(
    experience_block=partial(FirstLastExperienceBlock,a=0,seed=None,dls_kwargs={'bs':1,'n_steps':4,'num_workers':0,
                                                                             'verbose':False,'indexed':True,'shuffle_train':False,
                                                                             'batch_tfms':lambda x:(x['s'],x)}),
    n_processes=4,
    n=128*20,
    data_fit=data_fit,
    agent=ActorCriticAgent(model)
)
blk=IterableDataBlock(blocks=(block),
                      splitter=FuncSplitter(lambda x:False),
                      batch_tfms=lambda x:(x['s'],x),
                     )
dls=blk.dataloaders([env]*15,bs=128)

agent=ActorCriticAgent(model=model)
learner=A3CLearner(dls,agent=agent,cbs=[A3CTrainer],reward_steps=4,metrics=[AvgEpisodeRewardMetric()])
learner.fit(600,lr=0.001,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,1.625428,14.455291,,14.455291,00:07
1,1.908461,13.63015,,13.63015,00:07
2,3.037269,11.44929,,11.44929,00:07
3,3.583667,10.05849,,10.05849,00:07
4,3.155716,9.191704,,9.191704,00:07
5,2.492913,8.611614,,8.611614,00:07
6,1.93327,8.24117,,8.24117,00:08
7,2.475357,8.19125,,8.19125,00:07
8,5.102181,8.170218,,8.170218,00:07
9,4.77487,8.010156,,8.010156,00:07


In [None]:
# hide
from nbdev.export import *
notebook2script()
notebook2html(n_workers=0)