In [None]:
#|hide
# Project test helper, called once before anything else in the notebook.
# NOTE(review): presumably prepares the notebook/test environment — confirm in fastrl.test_utils.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp agents.dqn.dueling

In [None]:
#|export
# Python native modules
from copy import deepcopy
from typing import Optional,Callable,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import nn,optim
# Local modules
from fastrl.pipes.core import find_dp
from fastrl.memory.experience_replay import ExperienceReplay
from fastrl.loggers.core import BatchCollector,EpochCollector
from fastrl.learner.core import LearnerBase,LearnerHead
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.agents.dqn.basic import (
    LossCollector,
    RollingTerminatedRewardCollector,
    EpisodeCollector,
    StepBatcher,
    TargetCalc,
    LossCalc,
    ModelLearnCalc,
    DQN,
    DQNAgent
)
from fastrl.agents.dqn.target import (
    TargetModelUpdater,
    TargetModelQCalc,
    DQNTargetLearner
)

# DQN Dueling
> DQN using a split (dueling) head that estimates the state value separately from the advantage of each action



## Training DataPipes

In [None]:
#|export
class DuelingHead(nn.Module):
    "Dueling head: merges a state-value stream and a per-action advantage stream into Q-values."
    def __init__(self,
            hidden:int, # Input into the DuelingHead, likely a hidden layer input
            n_actions:int, # Number/dim of actions to output
            lin_cls=nn.Linear # Layer class for both streams, called as `lin_cls(in_features,out_features)`
        ):
        super().__init__()
        # State-value stream: one scalar per input row.
        self.val = lin_cls(hidden,1)
        # Advantage stream: one output per action.
        self.adv = lin_cls(hidden,n_actions)

    def forward(self,xi):
        """Return `val + (adv - mean(adv))`, with the mean taken over the action (last) dimension.

        `xi` is fed to both streams; the result has the same shape as the
        advantage stream's output (`[..., n_actions]` for `nn.Linear`).
        """
        val,adv = self.val(xi),self.adv(xi)
        # Centre the advantages per sample. Averaging over the last dim only
        # (instead of over the whole tensor) keeps each row's Q-values independent
        # of the other rows in the batch, and makes the row mean of the output
        # equal to the value estimate.
        centered_adv = adv - adv.mean(dim=-1,keepdim=True)
        # `val` has a trailing dim of 1, so it is expanded across the actions.
        return val.expand_as(adv) + centered_adv

Try training with basic defaults...

In [None]:
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.envs.gym import GymDataPipe
from fastrl.loggers.core import ProgressBarLogger
from fastrl.dataloading.core import dataloaders

In [None]:
#|eval:false
# Setup Loggers
def logger_bases(pipe):
    "Append the logging stages to `pipe`: `dump_records()` followed by a `ProgressBarLogger`."
    with_records = pipe.dump_records()
    return ProgressBarLogger(with_records)
# Setup the core NN (seeded first so the weight initialisation is repeatable)
torch.manual_seed(0)
model = DQN(
    4,
    2,
    head_layer=DuelingHead
)
model.train()
model = model.share_memory()
# Setup the Agent
agent = DQNAgent(
    model,
    do_logging=True,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=5000
)
# Setup the DataBlock: a single CartPole environment driven by the agent
train_pipe = GymDataPipe(['CartPole-v1']*1,agent=agent,nsteps=2,nskips=2,firstlast=True,bs=1)
dls = dataloaders(train_pipe)
# Setup the Learner; `nsteps` matches the value given to the data pipe above
learner = DQNTargetLearner(
    model,dls,
    logger_bases=logger_bases,
    bs=128,max_sz=100_000,
    nsteps=2,lr=0.01,
    batches=1000,target_sync=300
)
# Train
learner.fit(7)

In [None]:
#|hide
#|eval: false
# `learner` only exists after the training cell above has run, and that cell is
# marked eval:false; skip this cell under the same conditions to avoid a NameError.
learner.validate()

In [None]:
#|hide
#|eval: false
# Shell escape that runs the `nbdev_export` command-line tool; not evaluated during automated runs.
!nbdev_export