In [None]:
#|hide
# Hidden setup cell: runs fastrl's notebook initialization before any other cell.
# NOTE(review): what exactly it configures lives in fastrl.test_utils (not visible here).
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp agents.dqn.rainbow

In [None]:
#|export
# Python native modules
from copy import deepcopy
from typing import Optional,Callable,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import nn,optim
from fastcore.all import store_attr,ifnone
import numpy as np
import torch.nn.functional as F
# Local modules
from fastrl.torch_core import default_device,to_detach,evaluating
from fastrl.pipes.core import find_dp
from fastrl.agents.core import StepFieldSelector,SimpleModelRunner,NumpyConverter
from fastrl.agents.discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
from fastrl.memory.experience_replay import ExperienceReplay
from fastrl.loggers.core import BatchCollector,EpochCollector
from fastrl.learner.core import LearnerBase,LearnerHead
from fastrl.agents.core import AgentHead,AgentBase
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.loggers.core import ProgressBarLogger
from fastrl.agents.dqn.basic import (
    LossCollector,
    RollingTerminatedRewardCollector,
    EpisodeCollector,
    StepBatcher,
    TargetCalc,
    LossCalc,
    ModelLearnCalc,
    DQN,
    DQNAgent
)
from fastrl.agents.dqn.target import (
    TargetModelUpdater,
    TargetModelQCalc
)
from fastrl.agents.dqn.dueling import DuelingHead 
from fastrl.agents.dqn.categorical import (
    CategoricalDQNAgent,
    CategoricalDQN,
    CategoricalTargetQCalc,
    PartialCrossEntropy
)  

# DQN Rainbow
> Combines target-network, dueling, double, and categorical (distributional) DQNs, with optional n-step returns.

> Important: to count as an official Rainbow implementation, this likely still needs special exploration layers (noisy networks); the original Rainbow also uses prioritized experience replay.

In [None]:
#|export
def DQNRainbowLearner(
    model,
    dls,
    do_logging:bool=True,
    loss_func=PartialCrossEntropy,
    opt=optim.AdamW,
    lr=0.005,
    bs=128,
    max_sz=10000,
    nsteps=1,
    device=None,
    batches=None,
    target_sync=300,
    # Use DoubleDQN target strategy
    double_dqn_strategy=True
) -> LearnerHead:
    """Assemble a Rainbow-style DQN `LearnerHead`.

    The training pipeline is fed by `dls[0]`. It uses categorical
    (distributional) targets via `CategoricalTargetQCalc`, with optional
    double-DQN target selection (`double_dqn_strategy`). It also includes
    an experience replay buffer of size `max_sz` sampled in batches of `bs`,
    and a target model maintained by `TargetModelUpdater` (`target_sync`
    is forwarded to it; presumably steps between syncs, confirm there).
    The dueling part of Rainbow comes from the `model` passed in,
    e.g. `CategoricalDQN(..., head_layer=DuelingHead)`.

    Args:
        model: the network being trained; its parameters are given to `opt`.
        dls: one or two dataloaders. If exactly two are given, `dls[1]` feeds a
            validation pipeline with VSCode visualization whose records are dropped.
        do_logging: add progress-bar, rolling reward, episode and loss collectors
            to the training pipeline.
        loss_func: loss passed to `LossCalc`.
        opt: optimizer class, instantiated as `opt(model.parameters(), lr=lr)`.
        nsteps: n-step return length used by the target calculation.
        device: forwarded to `StepBatcher` and to the target-calc `.to(...)`.
        batches: batches per epoch for the `BatchCollector`s.

    Returns:
        A `LearnerHead` wrapping the training pipeline, or the
        `(train, valid)` pair when two dataloaders are supplied.
    """
    # --- training pipeline ---
    train_pipe = LearnerBase(model,dls=dls[0])
    train_pipe = EpochCollector(BatchCollector(train_pipe,batches=batches))
    if do_logging:
        train_pipe = ProgressBarLogger(train_pipe.dump_records())
        train_pipe = RollingTerminatedRewardCollector(train_pipe)
        train_pipe = EpisodeCollector(train_pipe).catch_records()
    train_pipe = StepBatcher(ExperienceReplay(train_pipe,bs=bs,max_sz=max_sz),device=device)
    target_calc = CategoricalTargetQCalc(
        train_pipe,
        nsteps=nsteps,
        double_dqn_strategy=double_dqn_strategy
    )
    train_pipe = target_calc.to(device=device)
    train_pipe = LossCalc(train_pipe,loss_func=loss_func)
    optimizer = opt(model.parameters(),lr=lr)
    train_pipe = ModelLearnCalc(train_pipe,opt=optimizer)
    train_pipe = TargetModelUpdater(train_pipe,target_sync=target_sync)
    if do_logging:
        train_pipe = LossCollector(train_pipe).catch_records()

    # Without exactly two dataloaders there is no validation pipeline.
    if len(dls)!=2:
        return LearnerHead(train_pipe)

    # --- validation pipeline ---
    valid_pipe = LearnerBase(model,dls[1]).visualize_vscode()
    valid_pipe = BatchCollector(valid_pipe,batches=batches)
    valid_pipe = EpochCollector(valid_pipe).catch_records(drop=True)
    return LearnerHead((train_pipe,valid_pipe))

In [None]:
from fastrl.envs.gym import GymDataPipe
from fastrl.dataloading.core import dataloaders

In [None]:
#|eval:false
# Set up the core NN
torch.manual_seed(0)
# n-step return length. The dataloaders' n-step transitions and the learner's
# target calculation should use the same n, so define it once instead of
# repeating the literal in two places.
NSTEPS = 2
# Rainbow uses a CategoricalDQN with a DuelingHead (Dueling DQN).
# CartPole-v1: 4-dimensional observation, 2 discrete actions.
model = CategoricalDQN(4,2,head_layer=DuelingHead)
# Set up the Agent
agent = CategoricalDQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)
# Set up the Dataloaders: a training pipe, plus an image-producing pipe
# (`include_images=True`) used by the learner's validation/visualization branch.
params = dict(
    source=['CartPole-v1']*1,
    agent=agent,
    nsteps=NSTEPS,
    nskips=2,
    firstlast=True
)
dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))
# Set up the Learner
learner = DQNRainbowLearner(
    model,
    dls,
    bs=128,
    max_sz=100_000,
    nsteps=NSTEPS,
    lr=0.001,
    batches=1000,
    target_sync=300
)
learner.fit(7)

In [None]:
#|eval:false
# Run validation. `learner` was built above with two dataloaders, so
# DQNRainbowLearner attached a validation learner fed by the image-producing pipe.
# NOTE(review): what `validate()` runs and reports is defined on LearnerHead (not visible here).
learner.validate()

In [None]:
#|hide
#|eval: false
# Regenerate the library modules from the `#|export` cells; this notebook's
# exports go to `agents.dqn.rainbow` (set by `#|default_exp` at the top).
!nbdev_export