In [None]:
#|hide
from fastrl.test_utils import initialize_notebook
# NOTE(review): project test helper; its effects are not visible in this
# notebook — confirm what it sets up in `fastrl.test_utils`.
initialize_notebook()

In [None]:
#|default_exp agents.dqn.dueling

In [None]:
#|export
# Python native modules
# Third party libs
import torch
from torch import nn
# Local modules
from fastrl.agents.dqn.basic import (
    DQN,
    DQNAgent
)
from fastrl.agents.dqn.target import DQNTargetLearner

# DQN Dueling
> DQN using a split head for comparing the advantage of different actions



## Training DataPipes

In [None]:
#|export
class DuelingHead(nn.Module):
    """Dueling Q-value head: a value stream plus a mean-centred advantage stream.

    Computes `Q = V + (A - mean(A))`, where the mean is taken over the action
    (last) dimension separately for every sample, so one sample's Q-values
    never depend on the other samples in the batch.
    """
    def __init__(
            self,
            hidden: int, # Input into the DuelingHead, likely a hidden layer input
            n_actions: int, # Number/dim of actions to output
            lin_cls = nn.Linear # Layer class, called as `lin_cls(in_features, out_features)`
        ):
        super().__init__()
        # Value stream: a single output per sample.
        self.val = lin_cls(hidden,1)
        # Advantage stream: one output per action.
        self.adv = lin_cls(hidden,n_actions)

    def forward(self,xi):
        """Return Q-values with the same shape as the advantage stream's output."""
        val,adv = self.val(xi),self.adv(xi)
        # Centre the advantages per sample over the action dimension. A bare
        # `adv.mean()` would also average across the batch, coupling the
        # Q-values of unrelated samples.
        adv = adv - adv.mean(dim=-1,keepdim=True)
        return val.expand_as(adv) + adv

Try training with basic defaults...

In [None]:
from fastrl.envs.gym import GymDataPipe
from fastrl.dataloading.core import dataloaders

In [None]:
#|eval:false
# Seed torch before building the network so its initialisation is repeatable
torch.manual_seed(0)
# Core network: a DQN(4,2) with the dueling head plugged in
model = DQN(4,2,head_layer=DuelingHead)
model.train()
# Agent wrapping the model
agent = DQNAgent(
    model,
    do_logging=True,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=5000
)
# Keyword arguments shared by both data pipes
params = {
    'source': ['CartPole-v1']*1,
    'agent': agent,
    'nsteps': 2,
    'nskips': 2,
    'firstlast': True
}
# The second pipe additionally asks for images and is unbatched
pipes = (
    GymDataPipe(**params),
    GymDataPipe(**params,include_images=True).unbatch()
)
dls = dataloaders(pipes)
# Learner, then fit
learner = DQNTargetLearner(
    model,dls,
    bs=128,max_sz=100_000,nsteps=2,
    lr=0.01,batches=1000,target_sync=300
)
learner.fit(7)

In [None]:
#|eval:false
# Call `validate()` on the `learner` created and fitted in the training cell above.
learner.validate()

In [None]:
#|hide
#|eval: false
# Shell out to the `nbdev_export` command line tool.
!nbdev_export