In [17]:
# default_exp qlearning.dqn_n_step

In [18]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastrl.ptan_extension import *
from fastrl.qlearning.dqn import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

# DQN with a Target Network and N-Step Returns

In [25]:
# export
class TargetDQNTrainer(Callback):
    """Callback that replaces a batch's predictions and targets with DQN values.

    After the learner's forward pass, `after_pred` overwrites `learn.pred` with the
    online model's Q-value for the action taken and `learn.yb` with
    `reward + discount**n_steps * max_a target_model(last_state)`, so the learner's
    loss function compares the two. The original `yb` is put back in `after_loss`.
    `after_batch` copies the online weights into `learn.target_model` every
    `learn.target_sync` batches.

    Reads `model`, `target_model`, `discount`, `n_steps` and `target_sync` from the learner.

    Parameters
    ----------
    n_batch : int
        Starting value of the batch counter used to schedule target-network syncs.
    """
    def __init__(self,n_batch=0): store_attr()

    def after_pred(self):
        "Swap `learn.pred`/`learn.yb` for Q(s,a) and its bootstrapped n-step target."
        learn=self.learn
        # The batch must unpack into exactly seven fields; the last two are not needed here.
        s,a,r,sp,d,_,_=(learn.xb+learn.yb)
        gamma_n=learn.discount**learn.n_steps

        # Q-value of the action actually taken, from the online network (keeps gradients).
        state_action_values=learn.model(s.float()).gather(1,a.unsqueeze(-1)).squeeze(-1)

        # Bootstrap value from the target network. No gradient should reach it, so the
        # forward pass is not recorded by autograd at all.
        with torch.no_grad():
            next_state_values=learn.target_model(sp.float()).max(1)[0]
            # Entries selected by `d` contribute no bootstrap value.
            # NOTE(review): this relies on `d` being a boolean mask - confirm against the data block.
            next_state_values[d]=0.0
            expected_state_action_values=next_state_values*gamma_n+r

        # Stash the real targets so `after_loss` can restore them.
        learn._yb=learn.yb
        learn.yb=(expected_state_action_values.float(),)
        learn.pred=state_action_values

    def after_loss(self):
        "Restore the targets that `after_pred` replaced."
        self.learn.yb=self.learn._yb

    def after_batch(self):
        "Copy the online weights into the target network every `target_sync` batches."
        # The counter starts at `n_batch` (0 by default), so the first batch always syncs.
        if self.n_batch%self.learn.target_sync==0:
            self.learn.target_model.load_state_dict(self.learn.model.state_dict())
        self.n_batch+=1

In [26]:
# export
class TargetDQNLearner(AgentLearner):
    """`AgentLearner` that trains with `nn.MSELoss` and keeps a deep copy of the model as `target_model`.

    Parameters
    ----------
    dls : data loaders forwarded to the parent constructor.
    discount : per-step discount factor; `TargetDQNTrainer` raises it to the power `n_steps`.
    n_steps : number of environment steps each experience spans.
    target_sync : `TargetDQNTrainer` copies the model weights into `target_model`
        every `target_sync` batches.
    **kwargs : forwarded unchanged to `AgentLearner.__init__`.
    """
    def __init__(self,dls,discount=0.99,n_steps=3,target_sync=300,**kwargs):
        # Saves the constructor arguments as attributes of the same name;
        # `TargetDQNTrainer` reads `discount`, `n_steps` and `target_sync` from the learner.
        store_attr()
        # Not read anywhere in the visible code.
        self.target_q_v=[]
        super().__init__(dls,loss_func=nn.MSELoss(),**kwargs)
        # Copied after the parent constructor has run, presumably because that is
        # where `self.model` gets set - confirm against `AgentLearner`.
        self.target_model=deepcopy(self.model)

In [27]:
# Demo: train the target-network DQN on CartPole.
env='CartPole-v1'
N_STEPS=3       # shared by the experience block and the learner so the two stay in agreement
BATCH_SIZE=32   # shared by the data loader, the epoch size and the replay buffer

model=LinearDQN((4,),2)
agent=DiscreteAgent(model=model.to(default_device()),device=default_device(),
                    a_selector=EpsilonGreedyActionSelector())

block=FirstLastExperienceBlock(agent=agent,seed=0,n_steps=N_STEPS,
                               dls_kwargs={'bs':BATCH_SIZE,'num_workers':0,'verbose':False,
                                           'indexed':True,'shuffle_train':False})
# `blocks` is a real one-element tuple; `(block)` without the comma is just `block`.
blk=IterableDataBlock(blocks=(block,),
                      splitter=FuncSplitter(lambda x:False))
dls=blk.dataloaders([env]*1,n=BATCH_SIZE*100,device=default_device())

# A throwaway environment is only needed to read the episode step cap; close it afterwards.
probe_env=gym.make(env)
max_episode_steps=probe_env._max_episode_steps
probe_env.close()

learner=TargetDQNLearner(dls,agent=agent,n_steps=N_STEPS,
                         cbs=[EpsilonTracker,
                              ExperienceReplay(sz=50000,bs=BATCH_SIZE,starting_els=BATCH_SIZE,
                                               max_steps=max_episode_steps),
                              TargetDQNTrainer],
                         metrics=[AvgEpisodeRewardMetric(experience_cls=ExperienceFirstLast)])
learner.fit(15,lr=0.01,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,0.488254,21.196809,,21.196809,00:03
1,0.196225,21.565,,21.565,00:03
2,0.121115,23.971667,,23.971667,00:03
3,0.652146,24.961667,,24.961667,00:03
4,0.471985,25.873333,,25.873333,00:03
5,0.39346,26.045,,26.045,00:03
6,0.68586,27.056667,,27.056667,00:03
7,0.844604,28.356667,,28.356667,00:03
8,0.761524,28.88,,28.88,00:03
9,1.480243,32.986667,,32.986667,00:03


In [None]:
# hide
from nbdev.export import *
from nbdev.export2html import *
# nbdev export step: presumably writes the cells tagged `export` to the module named by
# the `default_exp` directive at the top of this notebook - confirm against the nbdev docs.
notebook2script()
# Presumably renders the notebooks to HTML documentation - confirm against the nbdev docs.
notebook2html()