In [None]:
#|hide
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp agents.dqn.basic

In [None]:
#|export
# Python native modules
import os
from collections import deque
from typing import Callable,Optional,List
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import optim
from torch import nn
# Local modules
from fastrl.agents.core import AgentHead,AgentBase
from fastrl.pipes.core import find_dp
from fastrl.memory.experience_replay import ExperienceReplay
from fastrl.agents.core import StepFieldSelector,SimpleModelRunner,NumpyConverter
from fastrl.agents.discrete import EpsilonCollector,PyPrimativeConverter,ArgMaxer,EpsilonSelector
from fastrl.loggers.core import (
    Record,BatchCollector,EpochCollector,RollingTerminatedRewardCollector,EpisodeCollector,ProgressBarLogger
)
from fastrl.learner.core import LearnerBase,LearnerHead,StepBatcher
from fastrl.torch_core import Module

# DQN Basic
> Core DQN modules, pipes, and tooling

## Model

In [None]:
#|export
class DQN(Module):
    "A basic Q-network: `Linear(state_sz,hidden)` -> activation -> `head_layer(hidden,action_sz)`."
    def __init__(
            self,
            state_sz:int,  # The input dim of the state
            action_sz:int, # The output dim of the actions
            hidden=512,    # Number of neurons in the hidden layer between the input and output layers
            head_layer:Module=nn.Linear, # DQN extensions such as Dueling DQNs have custom heads
            activition_fn:Module=nn.ReLU # The activation fn used by `DQN` (parameter name kept as-is for compatibility)
        ):
        # NOTE(review): no explicit `super().__init__()` here; presumably `fastrl.torch_core.Module`
        # initializes `nn.Module` itself (fastai-style) — confirm before subclassing differently.
        # Stages are created in the same order as they run, which also fixes the order of weight init.
        input_layer = nn.Linear(state_sz,hidden)
        activation  = activition_fn()
        output_head = head_layer(hidden,action_sz)
        self.layers = nn.Sequential(input_layer,activation,output_head)

    def forward(self,x):
        "Run a batch of states `x` through the network and return the head's outputs."
        return self.layers(x)


## Agent

In [None]:
#|export
DataPipeAugmentationFn = Callable[[DataPipe],Optional[DataPipe]]

def DQNAgent(
    model,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=1000,
    device='cpu',
    do_logging:bool=False
)->AgentHead:
    "Build an epsilon-greedy agent around `model` that yields (action, raw model output) pairs."
    # Shared front end: select each step's `next_state` and run it through the model on `device`
    front_end = SimpleModelRunner(
        StepFieldSelector(AgentBase(model),field='next_state')
    ).to(device=device)
    # One branch turns model outputs into actions; the other passes the raw outputs through
    action_branch,raw_branch = front_end.fork(2)

    # Clone before selecting actions, presumably so any in-place work in this branch cannot
    # touch the tensor shared with `raw_branch`.
    action_branch = EpsilonSelector(
        ArgMaxer(action_branch.map(torch.clone)),
        min_epsilon=min_epsilon,
        max_epsilon=max_epsilon,
        max_steps=max_steps,
        device=device
    )
    if do_logging:
        action_branch = EpsilonCollector(action_branch).catch_records()
    # Reduce to the chosen index and convert it to a plain Python value
    action_branch = PyPrimativeConverter(NumpyConverter(ArgMaxer(action_branch,only_idx=True)))

    return AgentHead(action_branch.zip(raw_branch))

In [None]:
# Seeded so the network weights (and therefore the printed actions) are reproducible
torch.manual_seed(0)
model = DQN(state_sz=4,action_sz=2)

# Default agent: no logging, CPU
agent = DQNAgent(model)

In [None]:
from fastcore.all import test_eq
from fastrl.core import SimpleStep

In [None]:
input_tensor = torch.tensor([1,2,3,4]).float()
step = SimpleStep(next_state=input_tensor)
expected_input = torch.tensor([1., 2., 3., 4.])

# Feed the same step through the agent repeatedly; the step's `next_state` must come out unmodified
for _ in range(10):
    for action in agent([step]):
        print(action)

test_eq(input_tensor,expected_input)

In [None]:
from fastrl.envs.gym import GymDataPipe

In [None]:

# Set up the core NN (seeded for reproducibility)
torch.manual_seed(0)
model = DQN(4,2)

# `do_logging=True` makes `DQNAgent` add its `EpsilonCollector` to the agent pipeline
agent = DQNAgent(model,do_logging=True)

# Roll out a single CartPole environment driven by the agent
pipe = GymDataPipe(['CartPole-v1']*1,agent=agent,n=10)
pipe = BatchCollector(pipe,batches=5)
pipe = EpochCollector(pipe,epochs=10).dump_records()
# Set up the logger
pipe = ProgressBarLogger(pipe)

# Drain the pipeline; the trailing `;` suppresses displaying the returned list
list(pipe);

## Training DataPipes

In [None]:
#|export
class QCalc(dp.iter.IterDataPipe):
    "Compute the bootstrap value `max_a Q(s',a)` for each batch and store it on the learner as `next_q`."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe
        
    def __iter__(self):
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        for batch in self.source_datapipe:
            # Cast to bool so the flags act as a row mask; a 0/1 int tensor would otherwise be
            # interpreted as row *indices* by the assignment below.
            self.learner.done_mask = batch.terminated.reshape(-1,).bool()
            # The bootstrap target is a constant in the TD loss (semi-gradient DQN), so do not
            # build a graph through Q(s'): gradients must only flow through Q(s,a).
            with torch.no_grad():
                next_q = self.learner.model(batch.next_state)
            next_q = next_q.max(dim=1).values.reshape(-1,1)
            # Terminated transitions have no future value to bootstrap from
            next_q[self.learner.done_mask] = 0
            self.learner.next_q = next_q
            yield batch

In [None]:
#|export
class TargetCalc(dp.iter.IterDataPipe):
    "Build TD targets `reward + discount**nsteps * next_q` and the per-action target tensor `target_qs`."
    def __init__(self,source_datapipe,discount=0.99,nsteps=1):
        self.source_datapipe = source_datapipe
        self.discount = discount
        self.nsteps = nsteps
        self.learner = None
        
    def __iter__(self):
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        for batch in self.source_datapipe:
            # Targets are constants for the loss: detach so no gradient flows through them,
            # regardless of whether upstream computed `next_q` under `no_grad`.
            self.learner.targets = (batch.reward+self.learner.next_q*(self.discount**self.nsteps)).detach()
            self.learner.pred = self.learner.model(batch.state)
            # Start from the (detached) current predictions so non-taken actions give zero error,
            # then overwrite each row's taken action (`batch.action`) with its TD target.
            self.learner.target_qs = self.learner.pred.detach().clone().float()
            self.learner.target_qs.scatter_(1,batch.action.long(),self.learner.targets.float())
            yield batch

In [None]:
#|export
class LossCalc(dp.iter.IterDataPipe):
    "Evaluate `loss_func(learner.pred, learner.target_qs)` per batch and store it as `learner.loss_grad`."
    def __init__(self,source_datapipe,loss_func):
        self.source_datapipe = source_datapipe
        self.loss_func = loss_func
        
    def __iter__(self):
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        for item in self.source_datapipe:
            prediction,target = self.learner.pred,self.learner.target_qs
            self.learner.loss_grad = self.loss_func(prediction,target)
            # Pass the batch through unchanged; the loss travels on the learner
            yield item

In [None]:
#|export
class ModelLearnCalc(dp.iter.IterDataPipe):
    "Apply one optimizer update per incoming batch using `learner.loss_grad`; yields a copy of the loss."
    def __init__(self,source_datapipe, opt):
        self.source_datapipe = source_datapipe
        self.opt = opt
        
    def __iter__(self):
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        for _ in self.source_datapipe:
            # Backprop, update the weights, then clear gradients for the next batch
            self.learner.loss_grad.backward()
            self.opt.step()
            self.opt.zero_grad()
            # The batch itself is dropped here: downstream receives the loss instead
            loss_copy = self.learner.loss_grad.clone()
            self.learner.loss = loss_copy
            yield loss_copy

In [None]:
#|export
class LossCollector(dp.iter.IterDataPipe):
    "Emit a `Record` of `learner.loss` (under `title`) before passing each item through."
    title:str='loss'

    def __init__(self,
            source_datapipe, # The parent datapipe, likely the one to collect metrics from
        ):
        self.source_datapipe = source_datapipe
        # Not used by this pipe; kept so existing attribute access keeps working
        self.main_buffers = None
        
    def __iter__(self):
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        # Initial empty record, presumably so downstream loggers learn the metric name up front
        yield Record(self.title,None)
        for steps in self.source_datapipe:
            # Use the same name as the header record (was hard-coded to 'loss'), so subclasses
            # overriding `title` stay consistent.
            yield Record(self.title,self.learner.loss.detach().cpu().numpy())
            yield steps

In [None]:
#|export
def DQNLearner(
    model,
    dls,
    do_logging:bool=True,
    loss_func=nn.MSELoss(),
    opt=optim.AdamW,
    lr=0.005,
    bs=128,
    max_sz=10000,
    nsteps=1,
    device=None,
    batches=None
) -> LearnerHead:
    """Assemble a DQN training pipeline over `dls[0]`.

    When `dls` has two loaders, `dls[1]` also gets a validation pipeline.

    The training chain is: experience replay -> batching -> Q(s') -> TD targets -> loss -> optimizer step.
    """
    learner = LearnerBase(model,dls[0])
    learner = BatchCollector(learner,batches=batches)
    learner = EpochCollector(learner)
    if do_logging: 
        learner = learner.dump_records()
        learner = ProgressBarLogger(learner)
        learner = RollingTerminatedRewardCollector(learner)
        learner = EpisodeCollector(learner)
    # Without logging, records are dropped rather than forwarded
    learner = learner.catch_records(drop=not do_logging)

    learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,freeze_memory=True)
    learner = StepBatcher(learner,device=device)
    learner = QCalc(learner)
    learner = TargetCalc(learner,nsteps=nsteps)
    learner = LossCalc(learner,loss_func=loss_func)
    learner = ModelLearnCalc(learner,opt=opt(model.parameters(),lr=lr))
    if do_logging: 
        learner = LossCollector(learner).catch_records()

    if len(dls)==2:
        # `.visualize_vscode()` is a functional datapipe, presumably registered as a side effect of
        # importing this module; the notebook imports it only in a non-exported cell, so import it
        # here to keep the exported module self-contained. Local import also avoids import cycles.
        import fastrl.loggers.vscode_visualizers  # noqa: F401
        val_learner = LearnerBase(model,dls[1]).visualize_vscode()
        val_learner = BatchCollector(val_learner,batches=batches)
        val_learner = EpochCollector(val_learner).dump_records()
        learner = LearnerHead((learner,val_learner))
    else:
        learner = LearnerHead(learner)
    return learner

Try training with basic defaults...

In [None]:
from fastrl.dataloading.core import dataloaders
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe

In [None]:
#|eval:false
# Use the GPU when available so this cell also runs on CPU-only machines
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Set up the core NN
torch.manual_seed(0)
model = DQN(4,2).to(device)
# Set up the agent
agent = DQNAgent(model,do_logging=True,max_steps=4000,device=device)
# Set up the dataloaders: a training stream, plus a second stream with images used for validation
params = dict(source=['CartPole-v1']*1,agent=agent,nsteps=1,nskips=1,firstlast=False,bs=1)
dls = dataloaders((
    GymDataPipe(**params), GymDataPipe(**params,include_images=True)
))

# Set up the learner
learner = DQNLearner(
    model,
    dls,
    batches=1000,
    bs=128,
    max_sz=1000,
    device=device
)
learner.fit(5)

In [None]:
#|eval:false
# `learner` only exists if the (eval:false) training cell above ran, so skip this cell under nbdev_test too
learner.validate(2)

If we try a regular DQN with nsteps/nskips, it doesn't really converge, even after 130. We cannot expect any stability here. nsteps (correctly) tries to reduce the number of duplicated states so that the agent can sample more unique state transitions. The problem is that the base DQN is not stable, so giving it lots of "new", unique state transitions does not help. In other words, it's going to forget the old transitions very quickly, and having duplicate states helps "remind" it.

In [None]:
#|hide
#|eval: false
!nbdev_export