In [None]:
#|hide
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp agents.dqn.double

In [None]:
#|export
# Python native modules
from copy import deepcopy
from typing import Optional,Callable,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import nn,optim
# Local modules
from fastrl.pipes.core import find_dp
from fastrl.memory.experience_replay import ExperienceReplay
from fastrl.loggers.core import BatchCollector,EpochCollector
from fastrl.learner.core import LearnerBase,LearnerHead
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.agents.dqn.basic import (
    LossCollector,
    RollingTerminatedRewardCollector,
    EpisodeCollector,
    StepBatcher,
    TargetCalc,
    LossCalc,
    ModelLearnCalc,
    DQN,
    DQNAgent
)
from fastrl.agents.dqn.target import (
    TargetModelUpdater,
    TargetModelQCalc
)

# DQN Double
> Instead of letting the target model both choose and evaluate the next action, use the current (online) model to choose the next action and the target model to estimate that action's value.



## Training DataPipes

In [None]:
#|export
class DoubleQCalc(dp.iter.IterDataPipe):
    """Compute the next-state Q values used for the Double DQN target.

    For every batch coming from `source_datapipe`, the learner's `model`
    picks the next action (argmax over dim 1) and the learner's
    `target_model` supplies the value of that action. The result is stored
    on the learner as `next_q`, together with `done_mask`; the batch itself
    is yielded unchanged.
    """
    def __init__(self,source_datapipe=None):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        # The LearnerBase found in this pipe's graph carries `model`,
        # `target_model` and receives `done_mask` / `next_q`.
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        learner = self.learner
        for batch in self.source_datapipe:
            next_state = batch.next_state
            learner.done_mask = batch.terminated.reshape(-1,)
            with torch.no_grad():
                # Action selection by the current model ...
                # (assumes model output is (batch, n_actions) -- TODO confirm)
                online_q = learner.model(next_state)
                greedy_actions = online_q.argmax(dim=1).reshape(-1,1)
                # ... action evaluation by the target model.
                target_q = learner.target_model(next_state)
                learner.next_q = target_q.gather(1,greedy_actions)
            # Rows flagged by `done_mask` (from `batch.terminated`) get a next-state value of 0.
            learner.next_q[learner.done_mask] = 0
            yield batch

In [None]:
#|export
def DoubleDQNLearner(
    model,
    dls,
    logger_bases:Optional[Callable]=None,
    loss_func=nn.MSELoss(),
    opt=optim.AdamW,
    lr=0.005,
    bs=128,
    max_sz=10000,
    nsteps=1,
    device=None,
    batches=None,
    target_sync=300
) -> LearnerHead:
    """Assemble the Double DQN training pipeline and wrap it in a `LearnerHead`.

    Parameters:
        model: network being trained; its parameters are handed to `opt`.
        dls: indexable of dataloaders. `dls[0]` feeds training; when
            `len(dls)==2`, `dls[1]` feeds a separate validation pipeline.
        logger_bases: optional callable applied to the pipeline. When given,
            the reward, episode and loss collectors are attached as well.
        loss_func: forwarded to `LossCalc`.
        opt: optimizer constructor, called as `opt(model.parameters(),lr=lr)`.
        lr: learning rate passed to `opt`.
        bs: forwarded to `ExperienceReplay` as `bs`.
        max_sz: forwarded to `ExperienceReplay` as `max_sz`.
        nsteps: forwarded to `TargetCalc`.
        device: forwarded to `StepBatcher`.
        batches: forwarded to `BatchCollector` (training and validation).
        target_sync: forwarded to `TargetModelUpdater`.

    Returns:
        `LearnerHead` over the training pipeline, or over the
        `(training, validation)` pair when two dataloaders are supplied.
    """
    # --- collection / logging ---
    pipe = LearnerBase(model,dls=dls[0])
    pipe = BatchCollector(pipe,batches=batches)
    pipe = EpochCollector(pipe)
    if logger_bases:
        pipe = logger_bases(pipe)
        pipe = RollingTerminatedRewardCollector(pipe)
        pipe = EpisodeCollector(pipe)
    pipe = pipe.catch_records()

    # --- replay memory and batching ---
    pipe = ExperienceReplay(pipe,bs=bs,max_sz=max_sz)
    pipe = StepBatcher(pipe,device=device)

    # --- learning: DoubleQCalc is used here in place of TargetModelQCalc ---
    pipe = DoubleQCalc(pipe)
    pipe = TargetCalc(pipe,nsteps=nsteps)
    pipe = LossCalc(pipe,loss_func=loss_func)
    optimizer = opt(model.parameters(),lr=lr)
    pipe = ModelLearnCalc(pipe,opt=optimizer)
    pipe = TargetModelUpdater(pipe,target_sync=target_sync)
    if logger_bases:
        pipe = LossCollector(pipe).catch_records()

    # Only a training dataloader: nothing more to build.
    if len(dls)!=2:
        return LearnerHead(pipe,model)

    # --- validation pipeline over the second dataloader ---
    val_pipe = LearnerBase(model,dls[1])
    val_pipe = BatchCollector(val_pipe,batches=batches)
    val_pipe = EpochCollector(val_pipe).catch_records(drop=True)
    val_pipe = VSCodeDataPipe(val_pipe)
    return LearnerHead((pipe,val_pipe),model)

Try training with basic defaults...

In [None]:
from fastrl.envs.gym import GymDataPipe
from fastrl.loggers.core import ProgressBarLogger
from fastrl.dataloading.core import dataloaders

In [None]:
#|eval:false
# Set up the loggers
def logger_bases(pipe):
    "Apply `dump_records()` to `pipe` and wrap the result in a `ProgressBarLogger`."
    return ProgressBarLogger(pipe.dump_records())
# Set up the core NN (seeded so the run is repeatable)
torch.manual_seed(0)
model = DQN(4,2)
# Set up the agent
agent = DQNAgent(
    model,
    do_logging=True,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=5000
)
# Set up the data pipeline: a single CartPole environment
train_pipe = GymDataPipe(['CartPole-v1'],agent=agent,nsteps=2,nskips=2,firstlast=True,bs=1)
dls = dataloaders(train_pipe)
# Set up the learner and train
learner = DoubleDQNLearner(
    model,dls,
    logger_bases=logger_bases,
    bs=128,max_sz=100_000,
    nsteps=2,lr=0.01,
    batches=1000,target_sync=300
)
learner.fit(7)

The DQN learns, but I wonder if we can get it to learn faster...

In [None]:
#|eval:false
# This cell reuses `agent`, `model` and `logger_bases` from the training cell
# above. That cell is marked `#|eval:false`, so this one must be skipped under
# the same conditions; otherwise those names are undefined when it runs.
train_pipe = GymDataPipe(
    ['CartPole-v1']*1,
    agent=agent,
    nsteps=2,
    nskips=2,
    firstlast=True,
    bs=1
)
# Two pipes are passed so that the learner receives a second dataloader
# (presumably used for validation -- see the `len(dls)==2` branch of `DoubleDQNLearner`).
dls = dataloaders([train_pipe,train_pipe])
# Setup the Learner
learner = DoubleDQNLearner(
    model,
    dls,
    logger_bases=logger_bases,
    bs=128,
    max_sz=100_000,
    nsteps=2,
    lr=0.01,
    batches=1000,
    target_sync=300
)

In [None]:
#|hide
#|eval:false
# Call `validate()` on the learner built in the previous cell (the one given two dataloaders).
learner.validate()

In [None]:
#|hide
#|eval: false
# Shell out to nbdev's export command (presumably writes the `#|export` cells to the `default_exp` module).
!nbdev_export