In [1]:
#|hide
#|eval: false
# The `[ -e /content ]` guard makes the install run only when `/content` exists
# (presumably the Colab filesystem — see the note below). apt-get output is discarded.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
from fastcore.imports import in_colab
# Colab still requires tornado<6, so nbdev is only imported when we are not running there.
if in_colab():
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()
else:
    from nbdev.showdoc import *
    from nbdev.imports import *
    # The environment sanity checks are skipped whenever IN_TEST is set to a non-empty value.
    if not os.environ.get("IN_TEST"):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON

In [3]:
#|default_exp agents.dqn.asynchronous

In [4]:
#|export
# Python native modules
import os
from collections import deque
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
import torch.multiprocessing as mp
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *

from fastai.torch_basics import *
from fastai.torch_core import *
from torchdata.dataloader2.graph import find_dps,traverse
# Local modules

from fastrl.core import *
from fastrl.agents.core import *
from fastrl.pipes.core import *
from fastrl.fastai.data.block import *
from fastrl.memory.experience_replay import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.agents.dqn.basic import *
from fastrl.dataloader2_ext import *
from torchdata.dataloader2 import DataLoader2,DataLoader2Iterator

# DQN Async
> Components that allow for syncing multiple DQN agents running on multiple processes with the calculations done on the
main process.

There is a little weirdness using cuda with spawn. PyTorch has a bug (https://github.com/pytorch/pytorch/issues/30401), so queue usage isn't so simple.



## Training DataPipes

In [5]:
#|export
class ModelSubscriber(dp.iter.IterDataPipe):
    "Needed when an agent is handed to another process started with the 'spawn' start method."
    def __init__(self,source_datapipe):
        super().__init__()
        self.source_datapipe = source_datapipe
        # The model to refresh is the one owned by the `AgentBase` found in the upstream graph.
        pipe_graph = traverse(self.source_datapipe,only_datapipe=True)
        self.model = find_dp(pipe_graph,AgentBase).model

    def __iter__(self):
        for item in self.source_datapipe:
            is_request = type(item)==GetInputItemRequest
            if not (is_request and item.key.startswith('model_state_dict_publish_')):
                # Anything that is not a published state dict is passed downstream untouched.
                yield item
            else:
                # Published weights are consumed here and never yielded.
                self.model.load_state_dict(item.value)

In [6]:
#|export
class ModelPublisher(dp.iter.IterDataPipe):
    "Sends the learner model's `state_dict` to the protocol clients found on the learner's dataloaders."
    def __init__(self,
            source_datapipe, # Upstream pipe; the graph reachable from it must contain a `LearnerBase`
            publish_freq:int=1, # Publish when the index of the current upstream item is a multiple of this
            # Sometimes its not possible to share current model due to cuda issues.
            # `do_deepcopy` will copy and move the model to cpu in order to publish it.
            do_deepcopy:bool=False
        ):
        super().__init__()
        self.source_datapipe = source_datapipe
        self.model = find_dp(traverse(self,only_datapipe=True),LearnerBase).model
        self.publish_freq = publish_freq
        # `protocol_clients[i]` and `_expect_response[i]` describe the same client.
        self.protocol_clients = []
        self._expect_response = []
        self.initialized = False
        self.do_deepcopy = do_deepcopy

    def _reset(self):
        "Gather the `protocol` of every queue wrapper on every dataloader held by the `LearnerBase`."
        # Rebuild from scratch so that calling this more than once never duplicates clients.
        self.protocol_clients,self._expect_response = [],[]
        # Each dataloader is expected to expose `datapipe.iterable.datapipes`.
        for dl in find_dp(traverse(self,only_datapipe=True),LearnerBase).iterable:
            for q_wrapper in dl.datapipe.iterable.datapipes:
                self.protocol_clients.append(q_wrapper.protocol)
                self._expect_response.append(False)
        self.initialized = True

    def __iter__(self):
        for i,batch in enumerate(self.source_datapipe):
            if not self.initialized: self._reset()
            # A response whose value carries the publish prefix is a client's acknowledgement:
            # mark that client as ready for the next publish and do not pass the message on.
            if type(batch)==GetInputItemResponse and batch.value.startswith('model_state_dict_publish_'):
                client_num = int(batch.value.replace('model_state_dict_publish_',''))
                self._expect_response[client_num] = False
                continue
            #  (this batch we should publish) and (there are protocols) and (there are some that are ready)
            if i%self.publish_freq==0 and self.protocol_clients and not all(self._expect_response):
                with torch.no_grad():
                    # We need to deepcopy the model itself since `cpu` is an inplace op.
                    # We cant keep the model in cuda because mp.Manager passes around the
                    # tensors too much and causes errors ref: https://github.com/pytorch/pytorch/issues/30401
                    # This is also why we cant just call state_dict directly. It returns references
                    # to cuda tensors.
                    if self.do_deepcopy:
                        # This class has no `device` attribute; the copy always goes to the cpu.
                        state = deepcopy(self.model).cpu().state_dict()
                    else:
                        state = self.model.state_dict()

                    for client_id,client in enumerate(self.protocol_clients):
                        # Clients that have not acknowledged the previous publish are skipped.
                        if not self._expect_response[client_id]:
                            client.request_input_item(
                                key=f'model_state_dict_publish_{client_id}',value=state
                            )
                            self._expect_response[client_id] = True
            yield batch

In [7]:
#|export
def DQNLearner(
    model, # Model handed to `LearnerBase`; its parameters are given to the optimizer
    dls, # Dataloaders handed to `LearnerBase`
    agent, # Accepted for interface compatibility; not used inside this function
    logger_bases=None, # Loggers for the collector pipes; a falsy value skips the reward/episode/loss collectors
    loss_func=MSELoss(), # Loss function handed to `LearnerBase`
    opt=AdamW, # Optimizer factory, called as `opt(model.parameters(),lr=lr)`
    lr=0.005, # Learning rate given to `opt`
    bs=128, # `bs` for `ExperienceReplay`
    max_sz=10000, # `max_sz` for `ExperienceReplay`
    nsteps=1, # `nsteps` for `QCalc`
    device=None, # `device` for `StepBatcher`
    batches=1000, # `batches` for `LearnerBase`
    publish_freq=1 # `publish_freq` for `ModelPublisher`
) -> LearnerHead:
    "Assemble the DQN learner pipeline, including a `ModelPublisher`, and return its `LearnerHead`."
    # Use the caller's `loss_func` rather than a hard-coded `MSELoss()`; the default is unchanged.
    learner = LearnerBase(model,dls,batches=batches,loss_func=loss_func,opt=opt(model.parameters(),lr=lr))
    learner = ModelPublisher(learner,publish_freq=publish_freq)
    learner = BatchCollector(learner,logger_bases=logger_bases,batch_on_pipe=LearnerBase)
    learner = EpocherCollector(learner,logger_bases=logger_bases)
    # `L(None)` is empty, so this loop is a no-op when no loggers were given.
    for logger_base in L(logger_bases): learner = logger_base.connect_source_datapipe(learner)
    if logger_bases:
        learner = RollingTerminatedRewardCollector(learner,logger_bases)
        learner = EpisodeCollector(learner,logger_bases)
    learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
    learner = StepBatcher(learner,device=device)
    learner = QCalc(learner,nsteps=nsteps)
    learner = ModelLearnCalc(learner)
    if logger_bases:
        learner = LossCollector(learner,logger_bases)
    learner = LearnerHead(learner)
    return learner

In [20]:
#|export
class EpsilonCollector(dp.iter.IterDataPipe):
    "Yields the source pipe's items unchanged while appending its `epsilon` to every logger buffer."
    def __init__(self,
         source_datapipe, # The parent datapipe, likely the one to collect metrics from
         logger_bases:List[LoggerBase] # `LoggerBase`s that we want to send metrics to
        ):
        # Initialise the base datapipe, as the other datapipes in this module do.
        super().__init__()
        self.source_datapipe = source_datapipe
        self.main_buffers = [o.buffer for o in logger_bases]

    def __iter__(self):
        # A `Record('epsilon',None)` is appended once at the start of each iteration
        # (presumably to announce the metric to the loggers — TODO confirm).
        for q in self.main_buffers: q.append(Record('epsilon',None))
        for action in self.source_datapipe:
            # `epsilon` is read from the source pipe after it has produced `action`.
            for q in self.main_buffers:
                q.append(Record('epsilon',self.source_datapipe.epsilon))
            yield action

class CacheLoggerBase(LoggerBase):
    "Short lived logger base meant to dump logs"
    def reset(self):
        # Deliberately does nothing: when used in an agent this logger is exhausted often,
        # and the buffer has to stay alive so the reference to it is not lost.
        return None

    def __iter__(self):
        print('Iterating through buffer of len: ',len(self.buffer))
        # Hand out everything buffered so far, then empty the buffer in place.
        for record in self.buffer:
            yield record
        self.buffer.clear()

In [18]:
#|export   
def DQNAgent(
    model,
    logger_bases=None,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=1000,
    device='cpu'
)->AgentHead:
    "Build the epsilon-greedy DQN agent pipeline, with a `ModelSubscriber` so published weights can be loaded."
    # A private `CacheLoggerBase` is created only when the caller supplied no loggers;
    # otherwise `None` is what gets handed to `AgentHead`.
    _logger_bases = CacheLoggerBase() if logger_bases is None else None

    pipeline = AgentBase(model)
    pipeline = StepFieldSelector(pipeline,field='state')
    pipeline = InputInjester(pipeline)
    pipeline = ModelSubscriber(pipeline)
    pipeline = SimpleModelRunner(pipeline,device=device)
    pipeline = ArgMaxer(pipeline)
    epsilon_selector = EpsilonSelector(
        pipeline,
        min_epsilon=min_epsilon,
        max_epsilon=max_epsilon,
        max_steps=max_steps,
        device=device
    )

    # Epsilon is recorded to the caller's loggers, or to the private cache logger as a fallback.
    pipeline = EpsilonCollector(epsilon_selector,ifnone(logger_bases,[_logger_bases]))

    pipeline = ArgMaxer(pipeline,only_idx=True)
    pipeline = NumpyConverter(pipeline)
    pipeline = PyPrimativeConverter(pipeline)
    return AgentHead(pipeline,_logger_bases)

Try training with basic defaults...

In [10]:
import torch
from torch.nn import *
import torch.nn.functional as F
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.fastai.data.block import *
from fastrl.envs.gym import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from torch.utils.data.dataloader_experimental import DataLoader2

# NOTE(review): the recorded output of this cell is
# "AttributeError: 'DataLoader' object has no attribute 'datapipe'", raised from
# `ModelPublisher.__iter__`. The dataloaders built below by `block.dataloaders(...)` do not
# expose the `datapipe` attribute that `ModelPublisher._reset` reads, so this cell does not
# currently run to completion.
logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                 batch_on_pipe=BatchCollector)

# Setup up the core NN
# Seed torch before the model is created.
torch.manual_seed(0)
model = DQN(4,2).cuda()
# model.share_memory() # This will not work in spawn
# Setup the Agent
agent = DQNAgent(model,max_steps=4000,device='cuda')
# Setup the DataBlock
block = DataBlock(
    blocks = GymTransformBlock(agent=agent,
                               nsteps=1,nskips=1,firstlast=False,
                               # dl_type=partial(DataLoader2,persistent_workers=True)
                              )
)
# pipes = L(block.datapipes(['CartPole-v1']*1,n=10))
# A single CartPole-v1 source, batch size 1, and num_workers=0.
dls = L(block.dataloaders(['CartPole-v1']*1,bs=1,num_workers=0))
# # Setup the Learner
# `[agent]` is passed positionally into `DQNLearner`'s `agent` parameter.
learner = DQNLearner(model,dls,[agent],batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device='cuda')
learner.fit(2)

AttributeError: 'DataLoader' object has no attribute 'datapipe'
This exception is thrown by __iter__ of ModelPublisher()

In [21]:
%%writefile external_run_scripts/agents_dqn_async_35.py
# %%python

# Standalone script: trains a DQN learner whose data comes from a `DataLoader2`
# configured with one worker and the input-item queue protocol client/server types.
# Everything sits under the `__main__` guard, including the imports.
if __name__=='__main__':
    # NOTE(review): `Pool` and `Process` are imported but not used in this script.
    from torch.multiprocessing import Pool, Process, set_start_method
    
    # Request the 'spawn' start method. A RuntimeError is deliberately ignored —
    # presumably it signals that a start method was already set; TODO confirm.
    try:
        set_start_method('spawn')
    except RuntimeError:
        pass
    
    from fastcore.all import *
    import torch
    from torch.nn import *
    import torch.nn.functional as F
    from fastrl.loggers.core import *
    from fastrl.loggers.jupyter_visualizers import *
    from fastrl.learner.core import *
    from fastrl.fastai.data.block import *
    from fastrl.envs.gym import *
    from fastrl.agents.core import *
    from fastrl.agents.discrete import *
    from fastrl.agents.dqn.basic import *
    from fastrl.agents.dqn.asynchronous import *
    
    from torchdata.dataloader2 import DataLoader2
    from fastrl.dataloader2_ext import *
    
    logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                     batch_on_pipe=BatchCollector)
    
    # RollingTerminatedRewardCollector.debug=True

    # Setup up the core NN
    # Seed torch before the model is created.
    torch.manual_seed(0)
    model = DQN(4,2).cuda()
    # model.share_memory() # This will not work in spawn
    # Setup the Agent
    agent = DQNAgent(model,max_steps=8000,device='cuda')
    # Setup the DataBlock
    block = DataBlock(
        blocks = GymTransformBlock(agent=agent,
                                   nsteps=1,nskips=1,firstlast=False
                                  )
    )
    # Build the raw datapipes (one CartPole-v1 source) instead of ready-made dataloaders,
    # so the pipe can be wrapped in a `DataLoader2` with a custom reading service below.
    pipe = L(block.datapipes(['CartPole-v1']*1))
    
    # One worker, with the input-item queue protocol client/server types.
    dl = DataLoader2(
        pipe[0],
        reading_service=PrototypeMultiProcessingReadingService(
            num_workers = 1,
            # persistent_workers=True,
            protocol_client_type = InputItemIterDataPipeQueueProtocolClient,
            protocol_server_type = InputItemIterDataPipeQueueProtocolServer,
            pipe_type = item_input_pipe_type,
            eventloop = SpawnProcessForDataPipeline
        )
    )

    dls = [dl]
    
    # from torchdata.dataloader2.graph import find_dps,traverse
    # print(traverse(dls[0].datapipe))
    
    # dls = L(block.dataloaders(['CartPole-v1']*1,n=1000,bs=1,num_workers=1))
    # print('persistent workers: ',dls[0].persistent_workers)
    # # Setup the Learner
    # `[agent]` is passed positionally into `DQNLearner`'s `agent` parameter;
    # `publish_freq=10` is forwarded to the `ModelPublisher` inside the learner.
    learner = DQNLearner(model,dls,[agent],batches=1000,logger_bases=[logger_base],
                         publish_freq=10,
                         bs=128,max_sz=100_000,device='cuda')
    learner.fit(20)

Overwriting external_run_scripts/agents_dqn_async_35.py


In [22]:
#|hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Outside Colab, run nbdev's export step (presumably writing the `#|export` cells
    # to the module named by `#|default_exp` — TODO confirm).
    from nbdev import nbdev_export
    nbdev_export()