In [1]:
#|hide
#|eval: false
# Colab-only setup: the existence of `/content` is used as a "running on Colab" check,
# so nothing is installed elsewhere. xvfb backs the headless virtual display that the
# next cell starts via pyvirtualdisplay.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
from fastcore.imports import in_colab
# Colab still pins tornado<6, so nbdev is only imported when NOT running on Colab.
if in_colab():
    # Colab has no attached display: start a headless virtual one (backed by the
    # xvfb package installed in the Colab-only setup cell).
    # NOTE: this rebinds the name `display` for the rest of the notebook.
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()
else:
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Outside automated test runs, sanity-check that this is a local IPython notebook.
    if not os.environ.get("IN_TEST"):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON

In [3]:
#|default_exp agents.dqn.asynchronous

In [4]:
#|export
# Python native modules
import os
from collections import deque
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
import torch.multiprocessing as mp
import torch
from torch.nn import *
import torch.nn.functional as F
from torch.optim import *

from fastai.torch_basics import *
from fastai.torch_core import *
from torchdata.dataloader2.graph import find_dps,traverse
# Local modules

from fastrl.core import *
from fastrl.agents.core import *
from fastrl.pipes.core import *
from fastrl.fastai.data.block import *
from fastrl.memory.experience_replay import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.agents.dqn.basic import *
from fastrl.dataloader2_ext import *
from torchdata.dataloader2 import DataLoader2

# DQN Async
> Components that keep multiple DQN agents, running in separate processes, synchronized with the model being trained on the main process.

Using CUDA together with the `spawn` start method is a bit awkward: PyTorch has a known issue (https://github.com/pytorch/pytorch/issues/30401), so passing model tensors between processes through queues isn't straightforward.



## Training DataPipes

In [31]:
#|export
class ModelSubscriber(dp.iter.IterDataPipe):
    """Loads model weights published by the learner into this agent's model.

    Needed when an agent is passed to another process and the 'spawn' start
    method is used: weight updates then arrive inside the data stream as
    `GetInputItemRequest` items whose key starts with 'model_state_dict_pubish'.
    Those items are consumed here; every other item is passed through unchanged.

    Args:
        source_datapipe: upstream agent pipe. An `AgentBase` must be reachable
            in its graph; that agent's `.model` is the one updated.
        device: device string the model is moved to after every weight load.
    """
    def __init__(self,
                 source_datapipe,
                 device:str='cpu'
                ): 
        super().__init__()
        self.source_datapipe = source_datapipe
        pipe_graph = traverse(self.source_datapipe,only_datapipe=True)
        self.model = find_dp(pipe_graph,AgentBase).model
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.buffer = []
        self.device = device
        
    def __iter__(self):
        for item in self.source_datapipe:
            # The 'pubish' spelling must match the key used by ModelPublisher.
            is_weight_update = (
                type(item)==GetInputItemRequest
                and item.key.startswith('model_state_dict_pubish')
            )
            if not is_weight_update:
                yield item
                continue
            # `item.value` is the published state dict; it is applied and not forwarded.
            self.model.load_state_dict(item.value)
            self.model.to(device=torch.device(self.device))

In [36]:
#|export
class ModelPublisher(dp.iter.IterDataPipe):
    """Publishes the learner model's weights to agents running in dataloader worker processes.

    On every batch whose index is a multiple of `publish_freq`, a CPU copy of the
    model's `state_dict` is sent through each worker's protocol client. A client is
    not sent new weights until it has acknowledged the previous ones. The
    acknowledgement arrives in the batch stream as the string
    `'model_state_dict_pubish_<client index>'`; it is consumed here and never
    yielded downstream.

    Args:
        source_datapipe: upstream learner pipe. A `LearnerBase` must be reachable in
            its graph, and its `.iterable` must hold dataloaders exposing `.datapipe`.
        publish_freq: publish on batches whose enumeration index is divisible by this.
    """
    def __init__(self,source_datapipe,
                 publish_freq:int=1
                ):
        super().__init__()
        self.source_datapipe = source_datapipe
        self.model = find_dp(traverse(self,only_datapipe=True),LearnerBase).model
        self.publish_freq = publish_freq
        self.protocol_clients = []
        # _expect_response[i] is True while client i has an unacknowledged request.
        self._expect_response = []
        self.initialized = False
 
    def reset(self):
        "(Re)discover the protocol clients of the learner's dataloader workers."
        if self.initialized: return
        # Carry the "awaiting ack" flag over for clients seen in an earlier iteration,
        # so an ack that arrives after an epoch boundary is still consumed and the
        # client is not re-requested while a request is outstanding.
        awaiting = {id(c):w for c,w in zip(self.protocol_clients,self._expect_response)}
        clients = []
        for dl in find_dp(traverse(self,only_datapipe=True),LearnerBase).iterable:
            # dataloader.IterableWrapperIterDataPipe._IterateQueueDataPipes,[QueueWrappers]
            for q_wrapper in dl.datapipe.iterable.datapipes:
                clients.append(q_wrapper.protocol)
        self.protocol_clients = clients
        self._expect_response = [awaiting.get(id(c),False) for c in clients]
        self.initialized = True

    def _receive_ack(self,batch:str):
        "Consume a worker's acknowledgement so that client can be sent weights again."
        # The 'pubish' spelling is part of the key protocol shared with ModelSubscriber.
        client_num = int(batch.replace('model_state_dict_pubish_',''))
        try:
            if self._expect_response[client_num]:
                self.protocol_clients[client_num].get_response_input_item()
                self._expect_response[client_num] = False
        except Exception:
            print(f'failed on batch: {batch} num {client_num}')
            raise

    def _publish(self):
        "Send a CPU copy of the model weights to every client not awaiting an ack."
        with torch.no_grad():
            # We need to deepcopy the model itself since `cpu` is an in-place op.
            # We can't keep the model on CUDA because mp.Manager passes the tensors
            # around too much and causes errors, ref: https://github.com/pytorch/pytorch/issues/30401
            # This is also why we can't just call `state_dict` directly: it returns
            # references to CUDA tensors.
            state = deepcopy(self.model).cpu().state_dict()
            for client_idx,client in enumerate(self.protocol_clients):
                if self._expect_response[client_idx]: continue
                client.request_input_item(
                    key=f'model_state_dict_pubish_{client_idx}',value=state
                )
                # Hold further publishes to this client until it acknowledges.
                self._expect_response[client_idx] = True

    def __iter__(self):
        try:
            for i,batch in enumerate(self.source_datapipe):
                if isinstance(batch,str) and batch.startswith('model_state_dict_pubish'):
                    self._receive_ack(batch)
                    continue
                # (this batch we should publish) and (there are clients) and (some are ready)
                if i%self.publish_freq==0 and self.protocol_clients and not all(self._expect_response):
                    self._publish()
                yield batch
        finally:
            # Make the next `reset` rediscover the clients. Previously the lists were
            # cleared here while `initialized` stayed True, so later epochs never published.
            self.initialized = False

In [27]:
#|export
def DQNLearner(
    model,
    dls,
    agent,
    logger_bases=None,
    loss_func=MSELoss(),
    opt=AdamW,
    lr=0.005,
    bs=128,
    max_sz=10000,
    nsteps=1,
    device=None,
    batches=1000
) -> LearnerHead:
    """Build the DQN learner pipeline that also publishes model weights to async agents.

    Args:
        model: the network being trained; `ModelPublisher` broadcasts its weights.
        dls: dataloaders feeding experience. For weight publishing, each must expose
            `.datapipe` (as read by `ModelPublisher.reset`).
        agent: not used inside this function; kept for signature compatibility.
        logger_bases: optional loggers. When truthy, reward, episode and loss
            collectors are added to the pipeline.
        loss_func: loss handed to `LearnerBase`.
        opt: optimizer class, instantiated with `model.parameters()` and `lr`.
        lr: learning rate for `opt`.
        bs: batch size drawn from the experience replay.
        max_sz: maximum experience replay size.
        nsteps: forwarded to `QCalc`.
        device: device that `StepBatcher` moves batches to.
        batches: number of batches per epoch, forwarded to `LearnerBase`.

    Returns:
        The `LearnerHead` wrapping the full pipeline.
    """
    # Use the caller-supplied loss; it was previously hard-coded to MSELoss().
    learner = LearnerBase(model,dls,batches=batches,loss_func=loss_func,opt=opt(model.parameters(),lr=lr))
    learner = ModelPublisher(learner)
    learner = BatchCollector(learner,logger_bases=logger_bases,batch_on_pipe=LearnerBase)
    learner = EpocherCollector(learner,logger_bases=logger_bases)
    for logger_base in L(logger_bases): learner = logger_base.connect_source_datapipe(learner)
    if logger_bases: 
        learner = RollingTerminatedRewardCollector(learner,logger_bases)
        learner = EpisodeCollector(learner,logger_bases)
    learner = ExperienceReplay(learner,bs=bs,max_sz=max_sz)
    learner = StepBatcher(learner,device=device)
    learner = QCalc(learner,nsteps=nsteps)
    learner = ModelLearnCalc(learner)
    if logger_bases: 
        learner = LossCollector(learner,logger_bases)
    learner = LearnerHead(learner)
    return learner

In [28]:
#|export   
def DQNAgent(
    model,
    logger_bases=None,
    min_epsilon=0.02,
    max_epsilon=1,
    max_steps=1000,
    device='cpu'
)->AgentHead:
    """Build an epsilon-greedy DQN agent pipeline that accepts published model weights.

    Args:
        model: the Q-network the agent runs. `ModelSubscriber` loads weights
            published by the learner into it.
        logger_bases: optional loggers. When not None, epsilon values are collected.
        min_epsilon, max_epsilon, max_steps: forwarded to `EpsilonSelector`.
        device: device used by the subscriber, model runner and epsilon selector.

    Returns:
        The `AgentHead` wrapping the full pipeline.
    """
    agent = AgentBase(model)
    agent = StepFieldSelector(agent,field='state')
    agent = InputInjester(agent)
    agent = ModelSubscriber(agent,device=device)
    agent = SimpleModelRunner(agent,device=device)
    agent = ArgMaxer(agent)
    # The selector must always stay in the chain; previously it was only kept when
    # loggers were given, so `logger_bases=None` silently disabled exploration.
    agent = EpsilonSelector(agent,min_epsilon=min_epsilon,max_epsilon=max_epsilon,max_steps=max_steps,device=device)
    if logger_bases is not None: agent = EpsilonCollector(agent,logger_bases)
    agent = ArgMaxer(agent,only_idx=True)
    agent = NumpyConverter(agent)
    agent = PyPrimativeConverter(agent)
    agent = AgentHead(agent)
    return agent

Try training with basic defaults...

In [11]:
# In-process smoke test of the async DQN pipeline on CartPole-v1 (requires CUDA).
# NOTE(review): as the recorded output below shows, this currently fails with
# "AttributeError: 'DataLoader' object has no attribute 'datapipe'".
# `block.dataloaders` returns plain DataLoader objects, but ModelPublisher.reset
# reads `dl.datapipe` on every dataloader. The %%writefile script further down
# builds a torchdata DataLoader2 with a multiprocessing reading service instead.
import torch
from torch.nn import *
import torch.nn.functional as F
from fastrl.loggers.core import *
from fastrl.loggers.jupyter_visualizers import *
from fastrl.learner.core import *
from fastrl.fastai.data.block import *
from fastrl.envs.gym import *
from fastrl.agents.core import *
from fastrl.agents.discrete import *
# NOTE(review): this rebinds `DataLoader2` (imported earlier from torchdata.dataloader2)
# to a different class; it is only referenced in the commented-out `dl_type` below.
from torch.utils.data.dataloader_experimental import DataLoader2

logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                 batch_on_pipe=BatchCollector)

# Setup up the core NN
torch.manual_seed(0)
model = DQN(4,2).cuda()
# model.share_memory() # This will not work in spawn
# Setup the Agent
agent = DQNAgent(model,[logger_base],max_steps=4000,device='cuda')
# Setup the DataBlock
block = DataBlock(
    blocks = GymTransformBlock(agent=agent,
                               nsteps=1,nskips=1,firstlast=False,
                               # dl_type=partial(DataLoader2,persistent_workers=True)
                              )
)
# pipes = L(block.datapipes(['CartPole-v1']*1,n=10))
dls = L(block.dataloaders(['CartPole-v1']*1,bs=1,num_workers=0))
# # Setup the Learner
learner = DQNLearner(model,dls,[agent],batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device='cuda')
learner.fit(2)

loss,rolling_reward,epoch,batch,epsilon


AttributeError: 'DataLoader' object has no attribute 'datapipe'
This exception is thrown by __iter__ of BatchCollector()

In [34]:
%%writefile external_run_scripts/agents_dqn_async_35.py
# %%python
# Standalone script: trains a DQN on CartPole-v1 with the agent pipeline running in a
# spawned worker process (DataLoader2 + PrototypeMultiProcessingReadingService,
# num_workers=1) while the learner runs in the main process and pushes model weights
# to the agent via ModelPublisher / ModelSubscriber.
# Presumably written to a file rather than run in the notebook because the 'spawn'
# start method needs an importable, `__main__`-guarded entry point.

if __name__=='__main__':
    # NOTE(review): `Pool` and `Process` are imported but not used below.
    from torch.multiprocessing import Pool, Process, set_start_method
    
    # 'spawn' is used because the model lives on CUDA. A RuntimeError (start method
    # already set) is deliberately ignored.
    try:
        set_start_method('spawn')
    except RuntimeError:
        pass
    
    from fastcore.all import *
    import torch
    from torch.nn import *
    import torch.nn.functional as F
    from fastrl.loggers.core import *
    from fastrl.loggers.jupyter_visualizers import *
    from fastrl.learner.core import *
    from fastrl.fastai.data.block import *
    from fastrl.envs.gym import *
    from fastrl.agents.core import *
    from fastrl.agents.discrete import *
    from fastrl.agents.dqn.basic import *
    from fastrl.agents.dqn.asynchronous import *
    
    from torchdata.dataloader2 import DataLoader2
    from fastrl.dataloader2_ext import *
    
    logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                     batch_on_pipe=BatchCollector)
    
    # RollingTerminatedRewardCollector.debug=True

    # Setup up the core NN
    torch.manual_seed(0)
    model = DQN(4,2).cuda()
    # model.share_memory() # This will not work in spawn
    # Setup the Agent
    agent = DQNAgent(model,[logger_base],max_steps=8000,device='cuda')
    # Setup the DataBlock
    block = DataBlock(
        blocks = GymTransformBlock(agent=agent,
                                   nsteps=1,nskips=1,firstlast=False
                                  )
    )
    pipe = L(block.datapipes(['CartPole-v1']*1))
    
    # The protocol client/server and pipe types come from fastrl.dataloader2_ext; they
    # presumably provide the request/response channel that ModelPublisher uses
    # (`request_input_item` / `get_response_input_item`) to reach the worker.
    dl = DataLoader2(
        pipe[0],
        reading_service=PrototypeMultiProcessingReadingService(
            num_workers = 1,
            # persistent_workers=True,
            protocol_client_type = InputItemIterDataPipeQueueProtocolClient,
            protocol_server_type = InputItemIterDataPipeQueueProtocolServer,
            pipe_type = item_input_pipe_type,
            eventloop = SpawnProcessForDataPipeline
        )
    )

    dls = [dl]
    
    # from torchdata.dataloader2.graph import find_dps,traverse
    # print(traverse(dls[0].datapipe))
    
    # dls = L(block.dataloaders(['CartPole-v1']*1,n=1000,bs=1,num_workers=1))
    # print('persistent workers: ',dls[0].persistent_workers)
    # # Setup the Learner
    learner = DQNLearner(model,dls,[agent],batches=1000,logger_bases=[logger_base],bs=128,max_sz=100_000,device='cuda')
    learner.fit(20)

Overwriting external_run_scripts/agents_dqn_async_35.py


In [35]:
#|hide
# Export the `#|export` cells of this notebook into the library module
# (see `#|default_exp agents.dqn.asynchronous` above).
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()