In [1]:
#|hide
#|eval: false
# Colab setup: `[ -e /content ]` fails (and short-circuits the whole `&&` chain) when there is
# no /content directory, which presumably means "not running on Colab".
# Installs fastrl with its dev extras plus pyvirtualdisplay, then xvfb/python-opengl via apt
# (presumably for the virtual display started on Colab below). Only apt-get's output is silenced.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Sanity-check the notebook environment, except when IN_TEST is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # A virtual display is needed for colab. It is bound to its own name (not `display`)
    # so IPython's built-in `display()` function stays usable in the notebook namespace;
    # keeping the reference also keeps the display object alive.
    from pyvirtualdisplay import Display
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [3]:
#|default_exp agents.dqn.asynchronous

In [132]:
#|export
# Python native modules
import os
from collections import deque
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
import torch.multiprocessing as mp
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *

from fastai.torch_basics import *
from fastai.torch_core import *
# Local modules

from fastrl.core import *
from fastrl.agents.core import *
from fastrl.pipes.core import *
from fastrl.fastai.data.block import *
from fastrl.memory.experience_replay import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.agents.dqn.basic import *

In [170]:
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe

In [172]:
# NOTE(review): exploratory cell — it only displays the class attribute `IterDataPipe.getstate_hook`;
# nothing else in this notebook reads it. Consider removing it from the final notebook.
IterDataPipe.getstate_hook

# DQN Async
> Components that keep multiple DQN agents, running on separate processes, in sync with the model being trained on the main process.



## Training DataPipes

In [184]:
#|export
class ModelSubscriber(dp.iter.IterDataPipe):
    """Load model weights published from another process into an agent's model.

    Wraps an agent datapipe that contains an `AgentBase`. Every element of
    `source_datapipe` is yielded unchanged; before yielding it, all `state_dict`s
    waiting on `main_queue` (e.g. put there by `ModelPublisher`) are drained and only
    the newest one is loaded into the agent's model. A subscriber that falls behind the
    publisher therefore jumps straight to the latest weights instead of replaying a
    backlog of stale ones, one per element.

    If an agent is passed to another process and 'spawn' start method is used, then
    this module is needed.
    """
    def __init__(self,
                 source_datapipe # Agent datapipe containing an `AgentBase`
                ): 
        super().__init__()
        self.source_datapipe = source_datapipe
        # The model that incoming state dicts are loaded into.
        self.model = find_pipe_instance(self.source_datapipe,AgentBase).model
        self.main_queue = self.initialize_queue()
        
    def initialize_queue(self):
        "If the start method is `spawn` then the queue will need to be managed using a Manager."
        if mp.get_start_method()=='spawn':
            ctx = mp.get_context('spawn')
            manager = ctx.Manager()
            # NOTE(review): only the queue proxy is kept, not `manager` itself — confirm the
            # proxy keeps the manager process alive for as long as the queue is used.
            queue = manager.Queue()
            return queue
        else:
            return mp.Queue()

    def _latest_state(self):
        "Drain `main_queue` without blocking; return the newest item, or None if there was none."
        from queue import Empty
        newest = None
        while True:
            try:
                newest = self.main_queue.get_nowait()
            except Empty:
                return newest
    
    def __iter__(self):
        for x in self.source_datapipe:
            state = self._latest_state()
            if state is not None:
                self.model.load_state_dict(state)
            yield x

In [215]:
#|export
class ModelPublisher(dp.iter.IterDataPipe):
    """Publish the learner's model weights to the `ModelSubscriber`s of `agents`.

    Elements of `source_datapipe` are yielded unchanged. On every `publish_freq`-th
    element (starting with the first), a CPU copy of the `LearnerBase` model's
    `state_dict` is put on the `main_queue` of each agent's `ModelSubscriber`.

    Raises `ValueError` if `agents` is not a list/tuple or `publish_freq` is < 1.
    """
    def __init__(self,source_datapipe,
                 agents=None, # List/tuple of agent pipelines, each containing a `ModelSubscriber`
                 publish_freq:int=1 # Publish once every `publish_freq` elements
                ):
        super().__init__()
        self.source_datapipe = source_datapipe
        if not isinstance(agents,(list,tuple)): raise ValueError(f'Agents must be a list or tuple, not {type(agents)}')
        # Fail at construction instead of with a ZeroDivisionError/never-publishing loop later.
        if publish_freq < 1: raise ValueError(f'publish_freq must be >= 1, not {publish_freq}')
        self.queues = [find_pipe_instance(agent,ModelSubscriber).main_queue for agent in agents]
        self.model = find_pipe_instance(self,LearnerBase).model
        self.publish_freq = publish_freq

    def _cpu_state_dict(self):
        "Return the `state_dict` of a CPU copy of the model."
        # `Module.cpu()` works in place, so copy first to leave the live (possibly CUDA)
        # model untouched and to hand out tensors independent of its parameters.
        with torch.no_grad():
            return deepcopy(self.model).cpu().state_dict()
                
    def __iter__(self):
        for i,batch in enumerate(self.source_datapipe):
            if self.queues and i%self.publish_freq==0:
                # One copy is shared by all subscribers (they only read it) instead of
                # deep-copying the whole model once per queue.
                state = self._cpu_state_dict()
                for q in self.queues: q.put(state)
            yield batch

In [216]:
#|export
def DQNLearner(
    model,
    dls,
    agent,
    logger_bases=None,
    loss_func=MSELoss(),
    opt=AdamW,
    lr=0.005,
    bs=128,
    max_sz=10000,
    nsteps=1,
    device=None
) -> LearnerHead:
    """Build the DQN learner datapipe that publishes weights to asynchronous agents.

    - model: network being trained; its parameters are handed to `opt`.
    - dls: dataloaders; `dls[0].num_workers > 0` turns on `clone_detach` in `ExperienceReplay`.
    - agent: list/tuple of agent pipelines (each containing a `ModelSubscriber`) that
      `ModelPublisher` sends weights to.
    - logger_bases: optional loggers; the reward/episode/loss collectors are only added
      when this is truthy.
    - loss_func: loss given to `LearnerBase`.
    - opt, lr: optimizer class, instantiated as `opt(model.parameters(), lr=lr)`.
    - bs, max_sz: batch size and capacity passed to `ExperienceReplay`.
    - nsteps: passed to `QCalc`.
    - device: passed to `StepBatcher`.

    Returns the pipeline wrapped in a `LearnerHead`.
    """
    # Use the caller's loss function (it was previously always replaced by a new MSELoss()).
    learner = LearnerBase(model,dls,loss_func=loss_func,opt=opt(model.parameters(),lr=lr))
    learner = ModelPublisher(learner,agent)
    learner = BatchCollector(learner,logger_bases=logger_bases,batch_on_pipe=LearnerBase)
    learner = EpocherCollector(learner,logger_bases=logger_bases)
    for logger_base in L(logger_bases): learner = logger_base.connect_source_datapipe(learner)
    if logger_bases: 
        learner = RollingTerminatedRewardCollector(learner,logger_bases)
        learner = EpisodeCollector(learner,logger_bases)
    learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz,clone_detach=dls[0].num_workers>0)
    learner = StepBatcher(learner,device=device)
    learner = QCalc(learner,nsteps=nsteps)
    learner = ModelLearnCalc(learner)
    if logger_bases: 
        learner = LossCollector(learner,logger_bases)
    learner = LearnerHead(learner)
    return learner

In [217]:
#|export   
def DQNAgent(
    model,
    logger_bases=None,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=1000,
    device='cpu'
)->AgentHead:
    """Build a DQN agent datapipe whose model weights can be updated by a `ModelPublisher`.

    - model: network wrapped by `AgentBase`.
    - logger_bases: optional loggers; when not None an `EpsilonCollector` is added.
    - min_epsilon, max_epsilon, max_steps: passed to `EpsilonSelector`.
    - device: passed to `SimpleModelRunner` and `EpsilonSelector`.

    Returns the pipeline wrapped in an `AgentHead`.
    """
    agent = AgentBase(model)
    agent = StepFieldSelector(agent,field='state')
    agent = ModelSubscriber(agent)
    agent = SimpleModelRunner(agent,device=device)
    agent = ArgMaxer(agent)
    # The epsilon selector is always part of the chain so the agent explores whether or
    # not loggers are given; the collector is only an optional observer on top of it.
    agent = EpsilonSelector(agent,min_epsilon=min_epsilon,max_epsilon=max_epsilon,max_steps=max_steps,device=device)
    if logger_bases is not None: agent = EpsilonCollector(agent,logger_bases)
    agent = ArgMaxer(agent,only_idx=True)
    agent = NumpyConverter(agent)
    agent = PyPrimativeConverter(agent)
    agent = AgentHead(agent)
    return agent

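Try training with basic defaults. This cell only builds the agent and the learner (`learner.fit` is left commented out); the actual multi-process training run is in the external script written by the following cell.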

In [218]:
import torch
from torch.nn import *
import torch.nn.functional as F
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.fastai.data.block import *
from fastrl.envs.gym import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from torch.utils.data.dataloader_experimental import DataLoader2

logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                 batch_on_pipe=BatchCollector)

# Use the GPU when one is available so the notebook also runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Setup up the core NN
torch.manual_seed(0)
model = DQN(4,2).to(device)
# model.share_memory() # This will not work in spawn
# Setup the Agent
agent = DQNAgent(model,[logger_base],max_steps=4000,device=device)
# Setup the DataBlock
block = DataBlock(
    blocks = GymTransformBlock(agent=agent,
                               nsteps=1,nskips=1,firstlast=False
                              )
)
dls = L(block.dataloaders(['CartPole-v1']*1,n=1000,bs=1,num_workers=0))
# Setup the Learner
learner = DQNLearner(model,dls,[agent],logger_bases=[logger_base],bs=128,max_sz=100_000,device=device)
# Training is intentionally not run here; the spawn-based training run is the external script below.
# learner.fit(20)

In [219]:
%%writefile external_run_scripts/agents_dqn_async_35.py
# %%python

if __name__=='__main__':
    from torch.multiprocessing import Pool, Process, set_start_method
    
    try:
        set_start_method('spawn')
    except RuntimeError:
        pass
    
    from fastcore.all import *
    import torch
    from torch.nn import *
    import torch.nn.functional as F
    from fastrl.loggers.core import *
    from fastrl.loggers.jupyter_visualizers import *
    from fastrl.learner.core import *
    from fastrl.fastai.data.block import *
    from fastrl.envs.gym import *
    from fastrl.agents.core import *
    from fastrl.agents.discrete import *
    from fastrl.agents.dqn.basic import *
    from fastrl.agents.dqn.asynchronous import *
    from torch.utils.data.dataloader_experimental import DataLoader2
    
    logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                     batch_on_pipe=BatchCollector)

    # Setup up the core NN
    torch.manual_seed(0)
    model = DQN(4,2).cuda()
    # model.share_memory() # This will not work in spawn
    # Setup the Agent
    agent = DQNAgent(model,[logger_base],max_steps=8000,device='cuda')
    # Setup the DataBlock
    block = DataBlock(
        blocks = GymTransformBlock(agent=agent,
                                   nsteps=1,nskips=1,firstlast=False,
                                   dl_type=partial(DataLoader2,persistent_workers=True
                                                  )
                                  )
    )
    # pipes = L(block.datapipes(['CartPole-v1']*1,n=10))
    dls = L(block.dataloaders(['CartPole-v1']*1,n=1000,bs=1,num_workers=1))
    print('persistent workers: ',dls[0].persistent_workers)
    # # Setup the Learner
    learner = DQNLearner(model,dls,[agent],logger_bases=[logger_base],bs=128,max_sz=100_000,device='cuda')
    learner.fit(20)

Overwriting external_run_scripts/agents_dqn_async_35.py


In [220]:
#|hide
from fastcore.imports import in_colab

# nbdev is never imported on colab (colab still requires tornado<6), so the export
# step only runs outside of colab.
if in_colab():
    pass
else:
    from nbdev import nbdev_export
    nbdev_export()