In [None]:
#|hide
# Notebook bootstrap: run fastrl's shared notebook initializer before any other cell.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp agents.dqn.double

In [None]:
#|export
# Python native modules
# import os
# from collections import deque
# from typing import *
# Third party libs
# from fastcore.all import *
# import torchdata.datapipes as dp
# from torch.utils.data.dataloader_experimental import DataLoader2
# from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
# from torchdata.dataloader2.graph import find_dps,traverse,DataPipe,replace_dp,remove_dp
# Local modules
# import torch
# from torch.nn import *
# import torch.nn.functional as F
# from torch.optim import *

# from fastrl.torch_core import *

# from fastrl.core import *
# from fastrl.agents.core import *
# from fastrl.pipes.core import *
# from fastrl.data.block import *
# from fastrl.memory.experience_replay import *
# from fastrl.agents.core import *
# from fastrl.agents.discrete import *
# from fastrl.loggers.core import *
# from fastrl.loggers.jupyter_visualizers import *
# from fastrl.learner.core import *
# from fastrl.agents.dqn.basic import *
# from fastrl.agents.dqn.target import *

# DQN Double
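> When computing the bootstrap target for the next state, decouple action *selection* from action *evaluation*: the current (online) model chooses the greedy next action, and the target model supplies that action's Q-value (instead of the target model taking the max over its own estimates).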



## Training DataPipes

In [None]:
#|export
class DoubleQCalc(dp.iter.IterDataPipe):
    "Double DQN targets: the online model picks each next action, the target model scores it."
    def __init__(self,source_datapipe=None):
        # Upstream datapipe; each batch is read for `terminated` and `next_state`.
        self.source_datapipe = source_datapipe

    def __iter__(self):
        # Find the LearnerBase in this pipe's graph; per-batch `done_mask` and `next_q`
        # are written onto it, and it is also kept as `self.learner`.
        learner = find_dp(traverse(self),LearnerBase)
        self.learner = learner
        for batch in self.source_datapipe:
            # Per-sample terminal flags, flattened to 1-D.
            learner.done_mask = batch.terminated.reshape(-1,)
            next_state = batch.next_state
            with torch.no_grad():
                # Selection: greedy next action under the online model, as a column of indices.
                greedy_actions = learner.model(next_state).argmax(dim=1)
                greedy_actions = greedy_actions.reshape(-1,1)
                # Evaluation: the target model's Q-value for the selected action.
                target_q = learner.target_model(next_state)
                learner.next_q = target_q.gather(1,greedy_actions)
            # Terminal transitions get no bootstrapped next-state value.
            # NOTE(review): assumes `terminated` is a bool tensor so this acts as a row mask — confirm.
            learner.next_q[learner.done_mask] = 0
            yield batch

Try training with basic defaults...

In [None]:
from fastrl.envs.gym import GymDataPipe
from fastrl.loggers.core import ProgressBarLogger
from fastrl.dataloading.core import dataloaders

In [None]:
# Double DQN on CartPole-v1 with nsteps=1 transitions.
# NOTE(review): DQN, DQNAgent, DataBlock, GymTransformBlock, L, DQNLearner, TargetModelUpdater
# and torch are not imported anywhere visible in this notebook (every line of the `#|export`
# import cell is commented out). Unless initialize_notebook() provides them, this cell fails
# on a fresh kernel — TODO confirm.

# Set up loggers
logger_base = ProgressBarLogger()

# Set up the core NN. Seed before building the model so its initial weights are reproducible.
torch.manual_seed(0)
model = DQN(4,2)  # presumably (observation size, number of actions) for CartPole-v1 — confirm
# Set up the agent
agent = DQNAgent(model,[logger_base],max_steps=4000)
# Set up the DataBlock: a single GymTransformBlock with nsteps=1, nskips=1
block = DataBlock(
    GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,n=1000,bs=1)
)
dls = L(block.dataloaders(['CartPole-v1']*1))
# Set up the learner. The augmentation fns presumably (a) insert the target-network updater
# (DoubleQCalc reads learner.target_model) and (b) swap DoubleQCalc in for the default
# next-Q calculation — confirm against DQNLearner / replace_dp.
learner = DQNLearner(model,dls,logger_bases=[logger_base],bs=128,max_sz=100_000,
                    dp_augmentation_fns=[
                        TargetModelUpdater.insert_dp(),
                        DoubleQCalc.replace_dp()
                    ])
learner.fit(3)
# learner.fit(25)  # longer training run, presumably kept commented so the notebook runs quickly

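The DQN learns, but can we get it to learn faster? The next run uses 2-step transitions (`nsteps=2`, `nskips=2`), a smaller replay buffer (`max_sz=20_000`), an explicit learning rate (`lr=0.001`), more agent steps, and a second, image-rendering dataloader...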

In [None]:
# Double DQN on CartPole-v1 with 2-step transitions, a smaller replay buffer and an explicit lr.
# NOTE(review): DQN, DQNAgent, DataBlock, GymTransformBlock, VSCodeTransformBlock, L,
# DQNLearner, TargetModelUpdater, EpocherCollector, BatchCollector and torch are not imported
# anywhere visible in this notebook — TODO confirm they are in scope on a fresh kernel.

# Set up loggers; the epoch/batch collectors are passed as the pipes the progress bar reads from
logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                 batch_on_pipe=BatchCollector)

# Set up the core NN. Seed before building the model so its initial weights are reproducible.
torch.manual_seed(0)
model = DQN(4,2)
# Set up the agent
agent = DQNAgent(model,[logger_base],max_steps=10000)
# Set up the DataBlock with two GymTransformBlocks. Presumably the first feeds training and the
# second (include_images=True, paired with VSCodeTransformBlock) feeds learner.validate() — confirm.
block = DataBlock(
    GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True,n=1000,bs=1), # merge every 2 env steps into 1 transition (nsteps=2) and skip ahead (nskips=2)
    (GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True,n=100,include_images=True),VSCodeTransformBlock())
)
# pipes = L(block.datapipes(['CartPole-v1']*1,n=10))  # debugging alternative: build the raw datapipes instead of dataloaders
dls = L(block.dataloaders(['CartPole-v1']*1))
# Set up the learner. nsteps=2 here matches the transform blocks' nsteps; presumably it tells the
# learner the transitions span 2 steps — confirm against DQNLearner.
learner = DQNLearner(model,dls,logger_bases=[logger_base],bs=128,max_sz=20_000,nsteps=2,lr=0.001,
                    dp_augmentation_fns=[
                        TargetModelUpdater.insert_dp(),
                        DoubleQCalc.replace_dp()
                    ])
learner.fit(3)
# learner.fit(10)  # longer training run, presumably kept commented so the notebook runs quickly

In [None]:
#|hide
#|eval: false
# Run the learner's validation pass interactively; `eval: false` keeps nbdev's test run from executing it.
learner.validate()

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Write this notebook's `#|export` cells out to the library module.
# Colab still requires tornado<6, so nbdev is only imported when not running there.
_running_in_colab = in_colab()
if not _running_in_colab:
    from nbdev import nbdev_export
    nbdev_export()