In [1]:
# default_exp qlearning.dqn_target

In [2]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastrl.ptan_extension import *
from fastrl.qlearning.dqn import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

  return torch._C._cuda_getDeviceCount() > 0


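# Target DQN

DQN variant whose bootstrapped targets are computed by a separate `target_model`, which is overwritten with the online model's weights every `target_sync` batches.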

In [3]:
# export
class TargetDQNTrainer(Callback):
    """Train a DQN against a separate target network held on the learner.

    After each prediction the callback rebuilds the loss inputs from the sampled
    experiences: `learn.pred` becomes Q(s, a) from `learn.model`, and `learn.yb`
    becomes `r + discount**n_steps * max_a' Q_target(s', a')`, with the
    bootstrap term zeroed where `done` is set. The original `learn.yb` is put
    back once the loss has been computed. Every `learn.target_sync` batches the
    online weights are copied into `learn.target_model`.

    `n_batch` is the running batch counter used to schedule that copy.
    """
    def __init__(self,n_batch=0): store_attr()

    def after_pred(self):
        "Swap `learn.pred`/`learn.yb` for Q(s,a) and its bootstrapped target."
        device=default_device()
        # `sample_yb` is read from the learner; presumably it is filled by the
        # experience-replay callback -- TODO confirm against ExperienceReplay.
        exps=[ExperienceFirstLast(*o) for o in self.learn.sample_yb]
        s=torch.stack([e.state for e in exps]).float().to(device)
        a=torch.stack([e.action for e in exps]).to(device)
        sp=torch.stack([e.last_state for e in exps]).float().to(device)
        r=torch.stack([e.reward for e in exps]).float().to(device)
        d=torch.stack([e.done for e in exps]).to(device)

        # Pick, per sample, the Q-value of the action that was actually taken.
        # Assumes the model returns (batch, n_actions) and `a` holds integer
        # action indices of shape (batch,) -- TODO confirm.
        state_action_values=self.learn.model(s).gather(1,a.unsqueeze(-1)).squeeze(-1)

        # The target is a constant w.r.t. the loss, so build it without autograd:
        # no graph is recorded through the target network and the in-place
        # masking below never touches a graph-tracked tensor.
        with torch.no_grad():
            next_state_values=self.get_next_state_values(sp)
            # No bootstrapping from states flagged as done.
            next_state_values[d]=0.0
            expected_state_action_values=next_state_values*(self.learn.discount**self.learn.n_steps)+r

        # Keep the original targets so `after_loss` can restore them.
        self.learn._yb=self.learn.yb
        self.learn.yb=(expected_state_action_values.float(),)
        self.learn.pred=state_action_values

    def get_next_state_values(self,sp):
        "Max over dim 1 of the target network's output for the next states `sp`."
        return self.learn.target_model(sp.float()).max(1)[0]

    def after_loss(self):
        "Put back the `yb` that `after_pred` replaced."
        self.learn.yb=self.learn._yb

    def after_batch(self):
        "Copy online weights into the target network every `target_sync` batches."
        # With the default `n_batch=0` the first batch also triggers a copy.
        if self.n_batch%self.learn.target_sync==0:
            self.learn.target_model.load_state_dict(self.learn.model.state_dict())
        self.n_batch+=1

In [4]:
# export
class TargetDQNLearner(AgentLearner):
    """`AgentLearner` with an MSE loss that also carries a `target_model`.

    `discount`, `n_steps` and `target_sync` are kept on the learner (via
    `store_attr`) so that callbacks can read them. Remaining keyword arguments
    are forwarded unchanged to `AgentLearner`.
    """
    def __init__(
        self,
        dls,
        discount=0.99,
        n_steps=3,
        target_sync=300,
        **kwargs
    ):
        store_attr()
        self.target_q_v=list()
        super().__init__(dls,loss_func=nn.MSELoss(),**kwargs)
        # `self.model` is expected to exist once the parent constructor has run;
        # the target network starts out as an independent copy of it.
        self.target_model=deepcopy(self.model)

In [5]:
# Demo run: train a target-network DQN on CartPole with a single environment.
env='CartPole-v1'
# Presumably (observation shape, number of actions) -- TODO confirm against LinearDQN.
model=LinearDQN((4,),2)
agent=DiscreteAgent(
    model=model.to(default_device()),
    device=default_device(),
    a_selector=EpsilonGreedyActionSelector(),
)

loader_opts={
    'bs':1,
    'num_workers':0,
    'verbose':False,
    'indexed':True,
    'shuffle_train':False,
}
block=FirstLastExperienceBlock(agent=agent,seed=0,n_steps=1,dls_kwargs=loader_opts)
# Note: a single block is passed directly (the original `(block)` is not a tuple).
blk=IterableDataBlock(blocks=block,splitter=FuncSplitter(lambda x:False))
dls=blk.dataloaders([env],n=1000,device=default_device())

# Episode step cap read from a freshly made environment.
episode_limit=gym.make(env)._max_episode_steps
demo_cbs=[
    EpsilonTracker(e_steps=100),
    ExperienceReplay(sz=100000,bs=32,starting_els=32,max_steps=episode_limit),
    TargetDQNTrainer,
]
demo_metrics=[AvgEpisodeRewardMetric(experience_cls=ExperienceFirstLast,always_extend=True)]

learner=TargetDQNLearner(dls,agent=agent,n_steps=1,cbs=demo_cbs,metrics=demo_metrics)
learner.fit(47,lr=0.0001,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,0.523472,10.336842,,10.336842,00:10
1,1.146026,11.15,,11.15,00:10
2,2.486364,12.98,,12.98,00:10
3,3.273238,11.55,,11.55,00:10
4,3.176242,19.37,,19.37,00:10
5,3.657235,25.84,,25.84,00:10
6,3.982712,33.01,,33.01,00:10
7,6.177357,40.39,,40.39,00:10
8,5.508453,47.19,,47.19,00:10
9,7.158568,50.64,,50.64,00:10


  warn("Your generator is empty.")


KeyboardInterrupt: 

In [6]:
# hide
# nbdev housekeeping cell (hidden from the rendered docs): re-export the
# project's notebooks after editing this one.
from nbdev.export import *
from nbdev.export2html import *
# Prints one "Converted <notebook>." line per processed notebook.
notebook2script()
# Prints one "converting: <path>" line per notebook it renders.
notebook2html()

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_learner.ipynb.
Converted 05a_ptan_extend.ipynb.
Converted 05b_data.ipynb.
Converted 05c_async_data.ipynb.
Converted 13_metrics.ipynb.
Converted 14a_actorcritic.sac.ipynb.
Converted 14b_actorcritic.diayn.ipynb.
Converted 15_actorcritic.a3c_data.ipynb.
Converted 16_actorcritic.a2c.ipynb.
Converted 17_actorcritc.v1.dads.ipynb.
Converted 18_policy_gradient.ppo.ipynb.
Converted 19_policy_gradient.trpo.ipynb.
Converted 20a_qlearning.dqn.ipynb.
Converted 20b_qlearning.dqn_n_step.ipynb.
Converted 20c_qlearning.dqn_target.ipynb.
Converted 20d_qlearning.dqn_double.ipynb.
Converted 20e_qlearning.dqn_noisy.ipynb.
Converted index.ipynb.
Converted notes.ipynb.


converting: /opt/project/fastrl/nbs/20c_qlearning.dqn_target.ipynb
converting: /opt/project/fastrl/nbs/20e_qlearning.dqn_noisy.ipynb
converting: /opt/project/fastrl/nbs/05a_ptan_extend.ipynb
