In [None]:
#|hide
#|eval: false
# Install fastrl (with dev extras) and pyvirtualdisplay only when a `/content` directory exists
# (presumably meaning we are on Colab); then install xvfb for headless rendering.
# Only the apt-get output is redirected to /dev/null; pip is quieted via `-qq`.
# NOTE(review): `python-opengl` may not exist on newer Ubuntu-based Colab images — verify.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # `os`, IN_NOTEBOOK, IN_COLAB and IN_IPYTHON are presumably provided by these star imports.
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Sanity-check that we are in a local (non-Colab) IPython notebook, unless the
    # IN_TEST environment variable is set (presumably by the test runner).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    # NOTE(review): this rebinds the global name `display`, shadowing IPython's `display()`
    # for the rest of the kernel session.
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [None]:
#|default_exp agents.dqn.asynchronous

In [None]:
#|export
# Python native modules
import os
from typing import *
from collections import deque
from copy import deepcopy
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
import torch.multiprocessing as mp
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *
from torchdata.dataloader2.graph import find_dps,traverse,replace_dp,DataPipe
# Local modules
from fastrl.torch_core import *
from fastrl.core import *
from fastrl.agents.core import *
from fastrl.pipes.core import *
from fastrl.data.block import *
from fastrl.memory.experience_replay import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.agents.dqn.basic import *
from fastrl.data.dataloader2 import *
from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator

# DQN Async
> Components that allow for syncing multiple DQN agents running on multiple processes with calculations on the main process.

Using CUDA with the `spawn` start method is a little awkward: PyTorch has a bug (https://github.com/pytorch/pytorch/issues/30401), so passing CUDA tensors between processes through queues is not straightforward.



## Training DataPipes

In [None]:
#|export
class ModelSubscriber(dp.iter.IterDataPipe):
    """Load model weights sent from another process into the local agent's model.

    If an agent is passed to another process and 'spawn' start method is used, then this module is needed.

    Items from `source_datapipe` are passed through unchanged, except for `GetInputItemRequest`
    items whose key starts with `'model_state_dict_publish_'`. Those are treated as a state dict
    (presumably sent by `ModelPublisher`), loaded into the agent's model, and not yielded.
    """
    def __init__(self,source_datapipe):
        super().__init__()
        self.source_datapipe = source_datapipe
        # The model to update belongs to the `AgentBase` that `find_dp` locates upstream.
        self.model = find_dp(traverse(self.source_datapipe,only_datapipe=True),AgentBase).model

    def __iter__(self):
        "Yield source items, applying (and swallowing) any model-state publish requests."
        for x in self.source_datapipe:
            if type(x) is GetInputItemRequest and x.key.startswith('model_state_dict_publish_'):
                self.model.load_state_dict(x.value)
                continue
            yield x

    @classmethod
    def insert_dp(cls,old_dp=InputInjester) -> Callable[[DataPipe],DataPipe]:
        """Return a pipeline augmentation function that wraps an `old_dp` datapipe in `cls`.

        The returned function locates the `old_dp` datapipe in `pipe` via `find_dp`, replaces it
        with `cls(<that datapipe>)`, and returns the root datapipe of the resulting graph.
        """
        def _insert_dp(pipe):
            # Traverse once and reuse the graph rather than re-traversing `pipe` for every argument.
            graph = traverse(pipe,only_datapipe=True)
            target = find_dp(graph,old_dp)
            new_graph = replace_dp(graph,target,cls(target))
            # The graph maps a single root id to (root_datapipe, parent_graph).
            return list(new_graph.values())[0][0]
        return _insert_dp

In [None]:
#|export
class ModelPublisher(dp.iter.IterDataPipe):
    """Publish the learner's model weights to agents running in dataloader worker processes.

    On the first item, the protocol clients of the learner's dataloaders are collected (`_reset`).
    After that, every `publish_freq`-th item triggers a publish (`_publish`): the model's `state_dict`
    is sent with `request_input_item` to every client that is not still waiting on a previous publish.
    `GetInputItemResponse` items whose value starts with `'model_state_dict_publish_'` are treated as
    acknowledgements: they mark that client as ready again and are not yielded downstream.
    All other items pass through unchanged.
    """
    # Prefix of the publish request key / acknowledgement value; the client id is appended to it.
    # Must stay equal to the prefix `ModelSubscriber` checks for.
    _KEY_PREFIX = 'model_state_dict_publish_'

    def __init__(self,
            source_datapipe,
            # Publish on every `publish_freq`-th item of `source_datapipe` (all items are counted,
            # including acknowledgements). Must be >= 1.
            publish_freq:int=1,
            # Sometimes its not possible to share current model due to cuda issues.
            # `do_deepcopy` will copy and move the model to `device` (cpu by default) in order to publish it.
            do_deepcopy:bool=False,
            # Device the copied model is moved to before publishing; only used when `do_deepcopy=True`.
            device:str='cpu'
        ):
        super().__init__()
        if publish_freq < 1:
            raise ValueError(f'publish_freq must be >= 1, got {publish_freq}')
        self.source_datapipe = source_datapipe
        self.model = find_dp(traverse(self,only_datapipe=True),LearnerBase).model
        self.publish_freq = publish_freq
        self.protocol_clients = []
        # `_expect_response[i]` is True while client `i` has not yet acknowledged the last publish.
        self._expect_response = []
        self.initialized = False
        self.do_deepcopy = do_deepcopy
        self.device = device

    @classmethod
    def insert_dp(cls,old_dp=LoggerBasePassThrough,publish_freq=1,do_deepcopy=False) -> Callable[[DataPipe],DataPipe]:
        """Return a pipeline augmentation function that wraps an `old_dp` datapipe in `cls`.

        The returned function locates the `old_dp` datapipe in `pipe` via `find_dp`, replaces it with
        `cls(<that datapipe>, publish_freq=publish_freq, do_deepcopy=do_deepcopy)`, and returns the
        root datapipe of the resulting graph.
        """
        def _insert_dp(pipe):
            # Traverse once and reuse the graph rather than re-traversing `pipe` for every argument.
            graph = traverse(pipe,only_datapipe=True)
            target = find_dp(graph,old_dp)
            new_graph = replace_dp(graph,target,cls(target,publish_freq=publish_freq,do_deepcopy=do_deepcopy))
            # The graph maps a single root id to (root_datapipe, parent_graph).
            return list(new_graph.values())[0][0]
        return _insert_dp

    def _reset(self):
        "Collect a protocol client (initially not awaiting a response) for each worker queue of the learner's dataloaders."
        for dl in find_dp(traverse(self,only_datapipe=True),LearnerBase).iterable:
            # NOTE(review): relies on each dataloader exposing its queue wrappers as
            # `dl.datapipe.iterable.datapipes`, each with a `.protocol` — confirm against fastrl.data.dataloader2.
            for q_wrapper in dl.datapipe.iterable.datapipes:
                self.protocol_clients.append(q_wrapper.protocol)
                self._expect_response.append(False)
        self.initialized = True

    def _publish(self):
        "Send the current model `state_dict` to every client that is not awaiting a response."
        with torch.no_grad():
            # We need to deepcopy the model itself since `cpu` is an inplace op.
            # We cant keep the model in cuda because mp.Manager passes around the
            # tensors too much and causes errors ref: https://github.com/pytorch/pytorch/issues/30401
            # This is also why we cant just call state_dict directly. It returns references
            # to cuda tensors.
            if self.do_deepcopy:
                state = deepcopy(self.model).to(device=self.device).state_dict()
            else:
                state = self.model.state_dict()
            # The same `state` object is sent to every ready client.
            for client_id,client in enumerate(self.protocol_clients):
                if self._expect_response[client_id]: continue
                client.request_input_item(key=f'{self._KEY_PREFIX}{client_id}',value=state)
                self._expect_response[client_id] = True

    def __iter__(self):
        for i,batch in enumerate(self.source_datapipe):
            if not self.initialized: self._reset()
            # Acknowledgement from client N (value 'model_state_dict_publish_N'): that client may
            # receive the next publish. It is consumed here and not yielded downstream.
            if type(batch) is GetInputItemResponse and batch.value.startswith(self._KEY_PREFIX):
                client_num = int(batch.value[len(self._KEY_PREFIX):])
                self._expect_response[client_num] = False
                continue
            # (this batch we should publish) and (there are protocols) and (there are some that are ready)
            if i%self.publish_freq==0 and self.protocol_clients and not all(self._expect_response):
                self._publish()
            yield batch

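Set up training with the basic defaults (the `fit` call is left commented out here; the multiprocess run is done by the script written in the next cell)...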

In [None]:
#|eval:False
import torch
from torch.nn import *
import torch.nn.functional as F
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.data.block import *
from fastrl.envs.gym import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from fastrl.agents.dqn.basic import *

logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                 batch_on_pipe=BatchCollector)

# This demo targets CUDA, but fall back to the CPU when no GPU is available so the
# pipeline can still be constructed. Use one device setting everywhere below.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Setup up the core NN (DQN(4,2): presumably CartPole's 4-value observation and 2 actions)
torch.manual_seed(0)
model = DQN(4,2).to(DEVICE)
# model.share_memory() # This will not work in spawn
# Setup the Agent; `ModelSubscriber` loads weights published by the learner into the agent's model.
agent = DQNAgent(model,max_steps=4000,device=DEVICE,
                 dp_augmentation_fns=[ModelSubscriber.insert_dp()])
# Setup the DataBlock
block = DataBlock(
    GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False),
    GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,include_images=True)
)
# A single CartPole-v1 env with one dataloader worker.
dls = L(block.dataloaders(['CartPole-v1']*1,num_workers=1))
# Setup the Learner; `ModelPublisher` sends the learner's weights to the agents.
learner = DQNLearner(model,dls,batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device=DEVICE,
                     dp_augmentation_fns=[ModelPublisher.insert_dp()]
                     )
# learner.fit(2)

In [None]:
%%writefile ../../external_run_scripts/agents_dqn_async_35.py
# %%python

if __name__=='__main__':
    # The guard keeps spawned worker processes (which re-import this file as `__mp_main__`)
    # from re-running the training setup.
    from torch.multiprocessing import set_start_method

    # CUDA in subprocesses requires the 'spawn' (or 'forkserver') start method.
    try:
        set_start_method('spawn')
    except RuntimeError:
        # The start method was already set; keep it.
        pass

    from fastcore.all import *
    import torch
    from torch.nn import *
    import torch.nn.functional as F
    from fastrl.loggers.core import *
    from fastrl.loggers.jupyter_visualizers import *
    from fastrl.learner.core import *
    from fastrl.data.block import *
    from fastrl.envs.gym import *
    from fastrl.agents.core import *
    from fastrl.agents.discrete import *
    from fastrl.agents.dqn.basic import *
    from fastrl.agents.dqn.asynchronous import *

    from torchdata.dataloader2 import DataLoader2
    from torchdata.dataloader2.graph import traverse
    from fastrl.data.dataloader2 import *

    logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                    batch_on_pipe=BatchCollector)

    # This script targets CUDA, but fall back to the CPU when no GPU is available.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Setup up the core NN (DQN(4,2): presumably CartPole's 4-value observation and 2 actions)
    torch.manual_seed(0)
    model = DQN(4,2).to(DEVICE)
    # model.share_memory() # This will not work in spawn
    # Setup the Agent; `ModelSubscriber` loads weights published by the learner into the agent's model.
    agent = DQNAgent(model,max_steps=4000,device=DEVICE,
                    dp_augmentation_fns=[ModelSubscriber.insert_dp()])
    # Setup the DataBlock
    block = DataBlock(
        GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False),
        GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,include_images=True)
    )
    # A single CartPole-v1 env with one dataloader worker.
    dls = L(block.dataloaders(['CartPole-v1']*1,num_workers=1))
    # Setup the Learner; publish the learner's weights to the agents every 10 items.
    learner = DQNLearner(model,dls,batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device=DEVICE,
                        dp_augmentation_fns=[ModelPublisher.insert_dp(publish_freq=10)])
    learner.fit(20)

Overwriting ../../external_run_scripts/agents_dqn_async_35.py


In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Export the `#|export` cells into the library module (`#|default_exp` above names it
    # agents.dqn.asynchronous).
    from nbdev import nbdev_export
    nbdev_export()