In [None]:
#|hide
from fastrl.test_utils import initialize_notebook
# Notebook-level setup from fastrl's test utilities; run before anything else in this notebook.
initialize_notebook()

In [None]:
#|default_exp agents.dqn.target

In [None]:
#|export
# Python native modules
from copy import deepcopy
from typing import Optional,Callable,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import traverse_dps,DataPipe
import torch
from torch import nn,optim
# Local modules
from fastrl.pipes.core import find_dp
from fastrl.memory.experience_replay import ExperienceReplay
from fastrl.loggers.core import BatchCollector,EpochCollector
from fastrl.learner.core import LearnerBase,LearnerHead
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe
from fastrl.loggers.core import ProgressBarLogger
from fastrl.agents.dqn.basic import (
    LossCollector,
    RollingTerminatedRewardCollector,
    EpisodeCollector,
    StepBatcher,
    TargetCalc,
    LossCalc,
    ModelLearnCalc,
    DQN,
    DQNAgent
)

In [None]:
#|hide
import logging
from fastrl.core import default_logging

In [None]:
#|hide
# Configure the root logger using the keyword arguments supplied by fastrl's `default_logging()`.
logging.basicConfig(**default_logging())

# DQN Target
> DQN that bootstraps from a periodically synced snapshot of the model (a target network) to stabilize training



## Training DataPipes

In [None]:
#|export
class TargetModelUpdater(dp.iter.IterDataPipe):
    """Maintain `learner.target_model`, a copy of `learner.model` that is re-synced periodically.

    On construction and on every `reset()`, the `LearnerBase` reachable from
    `source_datapipe` is located and given a fresh deep copy of its `model` as
    `target_model`. While iterating, every batch passes through unchanged. The
    target weights are overwritten with the current model weights whenever the
    running batch counter is a multiple of `target_sync`.

    Args:
        source_datapipe: upstream learner pipe; a `LearnerBase` must be reachable from it.
        target_sync: number of batches between target syncs; must be positive.

    Raises:
        ValueError: if `target_sync` is not positive (it is used as a modulus).
    """
    def __init__(self,source_datapipe,target_sync=300):
        if target_sync <= 0:
            raise ValueError(f"target_sync must be a positive integer, got {target_sync!r}")
        self.source_datapipe = source_datapipe
        self.target_sync = target_sync
        # Counts batches across all iterations; it is intentionally not cleared by `reset()`.
        self.n_batch = 0
        # Use the private helper rather than `self.reset()`: torch wraps `reset` and
        # mutates the pipe's snapshot state when it is called.
        self._snapshot_target()

    def _snapshot_target(self):
        "Find the `LearnerBase` in the graph and give it a fresh deep copy of `model` as `target_model`."
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        with torch.no_grad():
            self.learner.target_model = deepcopy(self.learner.model)

    def reset(self):
        "Re-locate the learner and re-create `target_model` from the current `model`."
        self._snapshot_target()

    def __iter__(self):
        # torch's IterDataPipe iterator hook calls `reset()` when a new iterator is
        # created, so no manual reset is performed here.
        for batch in self.source_datapipe:
            if self.n_batch%self.target_sync==0:
                # Copy the online weights into the target network without tracking gradients.
                with torch.no_grad():
                    self.learner.target_model.load_state_dict(self.learner.model.state_dict())
            self.n_batch+=1
            yield batch

In [None]:
#|export
class TargetModelQCalc(dp.iter.IterDataPipe):
    """Compute bootstrapped next-state values from the learner's `target_model`.

    For each batch this:
    - stores the flattened `batch.terminated` as `learner.done_mask`;
    - evaluates `learner.target_model(batch.next_state)` without gradients;
    - keeps the maximum over dim 1 (presumably the action dimension) as a single column;
    - zeroes the rows flagged in `done_mask`, then stores the result as `learner.next_q`.

    The batch itself is yielded unchanged.
    """
    def __init__(self,source_datapipe=None):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        self.learner = find_dp(traverse_dps(self),LearnerBase)
        learner = self.learner
        for batch in self.source_datapipe:
            terminal_rows = batch.terminated.reshape(-1,)
            learner.done_mask = terminal_rows
            with torch.no_grad():
                q_values = learner.target_model(batch.next_state)
            best_next = q_values.max(dim=1).values.reshape(-1,1)
            # Terminal transitions have no future value to bootstrap from.
            best_next[terminal_rows] = 0
            learner.next_q = best_next
            yield batch

In [None]:
#|export
def DQNTargetLearner(
    model,
    dls,
    do_logging:bool=True,
    loss_func=nn.MSELoss(),
    opt=optim.AdamW,
    lr=0.005,
    bs=128,
    max_sz=10000,
    nsteps=1,
    device=None,
    batches=None,
    target_sync=300
) -> LearnerHead:
    """Assemble a DQN learner whose targets bootstrap from a target network.

    Training pipe, built on `dls[0]`:
    - `LearnerBase` feeds `BatchCollector` (limited by `batches`) and then `EpochCollector`;
    - with `do_logging`, progress-bar, rolling-reward and episode collectors are added;
    - `ExperienceReplay` (`bs`, `max_sz`) feeds `StepBatcher` (`device`);
    - `TargetModelQCalc` computes next-state values with `target_model`, and `TargetCalc` (`nsteps`) turns them into targets;
    - `LossCalc` applies `loss_func`;
    - `ModelLearnCalc` steps an optimizer built as `opt(model.parameters(), lr=lr)`;
    - `TargetModelUpdater` re-syncs `target_model` every `target_sync` batches;
    - with `do_logging`, losses are collected at the end.

    If `dls` holds exactly two loaders, `dls[1]` drives a separate validation pipe
    (`visualize_vscode()`, same `batches` limit). The result is a `LearnerHead`
    over both pipes; otherwise it wraps only the training pipe.
    """
    train = LearnerBase(model,dls=dls[0])
    train = EpochCollector(BatchCollector(train,batches=batches))
    if do_logging:
        train = ProgressBarLogger(train.dump_records())
        train = EpisodeCollector(RollingTerminatedRewardCollector(train)).catch_records()
    train = StepBatcher(ExperienceReplay(train,bs=bs,max_sz=max_sz),device=device)
    train = TargetCalc(TargetModelQCalc(train),nsteps=nsteps)
    train = ModelLearnCalc(LossCalc(train,loss_func=loss_func),opt=opt(model.parameters(),lr=lr))
    train = TargetModelUpdater(train,target_sync=target_sync)
    if do_logging:
        train = LossCollector(train).catch_records()

    if len(dls)!=2:
        return LearnerHead(train)

    val = LearnerBase(model,dls[1]).visualize_vscode()
    val = EpochCollector(BatchCollector(val,batches=batches)).catch_records(drop=True)
    return LearnerHead((train,val))

In [None]:
from fastrl.envs.gym import GymDataPipe
from fastrl.dataloading.core import dataloaders

In [None]:
#|eval:false
# Number of training epochs for this demo; raise it (e.g. to 7) for a longer run.
N_EPOCHS = 1
# Set up the core NN (seeded for reproducibility)
torch.manual_seed(0)
model = DQN(4,2)
# Set up the agent
agent = DQNAgent(model,do_logging=True,min_epsilon=0.02,max_epsilon=1,max_steps=5000)
# Set up the DataLoaders: the first stream is used for training; the second, which also
# includes rendered images, is used for validation by `DQNTargetLearner`.
params = dict(
    source=['CartPole-v1'],
    agent=agent,
    nsteps=2,
    nskips=2,
    firstlast=True
)
dls = dataloaders((GymDataPipe(**params),GymDataPipe(**params,include_images=True).unbatch()))
# Set up the learner; its `nsteps` is taken from `params` so the two settings stay in sync.
learner = DQNTargetLearner(
    model,
    dls,
    bs=128,
    max_sz=100_000,
    nsteps=params['nsteps'],
    lr=0.01,
    batches=1000,
    target_sync=300
)
learner.fit(N_EPOCHS)

In [None]:
#|eval:false
# Run the validation pipe built from `dls[1]`; `2` is presumably the number of validation epochs — TODO confirm against `LearnerHead.validate`.
learner.validate(2)

In [None]:
#|eval:false
# Re-run validation without displaying it, keeping the collected outputs for the memory viewer below.
sample_run = learner.validate(2,show=False,return_outputs=True)

In [None]:
from fastrl.agents.core import AgentHead,AgentBase
from fastrl.agents.core import SimpleModelRunner
from fastrl.memory.memory_visualizer import MemoryBufferViewer
from fastrl.agents.core import StepFieldSelector

In [None]:
#|eval:false
def DQNValAgent(
    model,
    device='cpu'
)->AgentHead:
    """Build an evaluation agent around `model`.

    The agent selects the `state` field from each incoming step. It runs `model`
    on that field through a `SimpleModelRunner` moved to `device`, and the whole
    stack is wrapped in an `AgentHead`. No epsilon/exploration stage is added
    here, unlike `DQNAgent`.
    """
    runner = SimpleModelRunner(StepFieldSelector(AgentBase(model),field='state'))
    return AgentHead(runner.to(device=device))

# Wrap the trained model in an evaluation agent and inspect the recorded validation run with it.
val_agent = DQNValAgent(model)
MemoryBufferViewer(sample_run,val_agent)

In [None]:
#|hide
#|eval: false
# Export the `#|export` cells to the library's Python modules via nbdev.
!nbdev_export