In [1]:
#|hide
#|eval: false
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
#|default_exp loggers.core

In [4]:
#|export
# Python native modules
import os,typing
# Third party libs
from fastcore.all import *
from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method,Queue
import torchdata.datapipes as dp
from fastprogress.fastprogress import *
from torchdata.dataloader2.graph import find_dps,traverse
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
# Local modules
from fastrl.core import *
from fastrl.pipes.core import *

# Loggers Core
> Utilities used for handling log messages and display over multiple processes.

In [5]:
#|export
class LoggerBase(dp.iter.IterDataPipe):
    debug:bool=False
    
    def __init__(self,source_datapipe=None,do_filter=True):
        self.source_datapipe = source_datapipe
        self.buffer = []
        self.do_filter = do_filter
        
    def connect_source_datapipe(self,pipe):
        self.source_datapipe = pipe
        return self
    
    def filter_record(self,record):
        return type(record)==Record and self.do_filter
    
    def dequeue(self): 
        while self.buffer: yield self.buffer.pop(0)
    
    def reset(self):
        # We can chain multiple `LoggerBase`s together, but if we do this, we dont want the 
        # first one in the chain filtering out the Records before the others!
        if issubclass(type(self.source_datapipe),LoggerBase):
            self.source_datapipe.do_filter = False
        # Note: trying to decide if this is really needed.
        # if self.debug:
        #     print(self,' resetting buffer.')
        # if self._snapshot_state!=_SnapshotState.Restored:
        #     self.buffer = []
    
    def __iter__(self):
        raise NotImplementedError
        
add_docs(
    LoggerBase,
    """The `LoggerBase` class outlines simply the `buffer`. 
    It works in combo with `LogCollector` datapipe which will add to the `buffer`.
    
    `LoggerBase` also filters out the log records to as to not disrupt the training pipeline""",
    filter_record="Returns True of `record` is actually a record and that `self` actually is set to filter.",
    connect_source_datapipe="""`LoggerBase` does not need to be part of a `DataPipeGraph` 
    when its initialized, so this method allows for inserting into a `DataPipeGraph` later on.""",
    reset="""Checks if `self.source_datapipe` is also a logger base, and if so will tell `self.source_datapipe`
    not to filter out the log records.""",
    dequeue="Empties the `self.buffer` yielding each of its contents."
)        

In [6]:
#|export
class LoggerBasePassThrough(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe,logger_bases=None):
        self.source_datapipe = source_datapipe
        self.logger_bases = logger_bases

    def __iter__(self):
        yield from self.source_datapipe

add_docs(
LoggerBasePassThrough,
"""Allows for collectors to find `LoggerBase`s early in the pipeline without
worrying about accidently iterating the logger bases at the incorrect time/frequency.

This is mainly used for collectors to call `find_dps` easily on the pipeline.
"""
)    

In [7]:
logger_base = LoggerBase()

In [8]:
traverse(logger_base)

{140572219072976: (LoggerBase, {})}

In [9]:
#|export
class LogCollector(dp.iter.IterDataPipe):
    debug:bool=False
    header:Optional[str]=None

    def __init__(self,
         source_datapipe, # The parent datapipe, likely the one to collect metrics from
        ):
        self.source_datapipe = source_datapipe
        self.main_buffers = None
        
    def __iter__(self): raise NotImplementedError

    def push_header(
            self,
            key:str
        ):
        # self.reset()
        for q in self.main_buffers: q.append(Record(key,None))

    def reset(self):
        if self.main_buffers is None:
            if self.debug: print(f'Resetting {self}')
            logger_bases = find_dps(traverse(self),LoggerBase,include_subclasses=True)
            self.main_buffers = [o.buffer for o in logger_bases]
            self.push_header(self.header)

add_docs(
LogCollector,
"""`LogCollector` specifically manages finding and attaching itself to
`LoggerBase`s found earlier in the pipeline.""",
reset="Grabs buffers from all logger bases in the pipeline.",
push_header="""Should be called after the pipeline is initialized. Sends header
`Record`s to the `LoggerBase`s so they know what logs are going to be sent to them."""
)  


Notes:

User can init multiple different logger bases if they want

We then can manually add Collectors, custom for certain pipes such as for collecting rewards. 

In [10]:
#|export
class ProgressBarLogger(LoggerBase):
    debug:bool=False

    def __init__(self,
                 # This does not need to be immediately set since we need the `LogCollectors` to 
                 # first be able to reference its queues.
                 source_datapipe=None, 
                 # For automatic pipe attaching, we can designate which pipe this should be
                 # referneced for information on which epoch we are on
                 epoch_on_pipe:dp.iter.IterDataPipe=None,
                 # For automatic pipe attaching, we can designate which pipe this should be
                 # referneced for information on which batch we are on
                 batch_on_pipe:dp.iter.IterDataPipe=None
                ):
        super().__init__(source_datapipe=source_datapipe)
        self.epoch_on_pipe = epoch_on_pipe
        self.batch_on_pipe = batch_on_pipe
        
        self.collector_keys = None
        self.attached_collectors = None
    
    def __iter__(self):
        epocher = find_dp(traverse(self),self.epoch_on_pipe)
        batcher = find_dp(traverse(self),self.batch_on_pipe)
        mbar = master_bar(range(epocher.epochs)) 
        pbar = progress_bar(range(batcher.batches),parent=mbar,leave=False)

        mbar.update(0)
        i = 0
        for record in self.source_datapipe:
            if self.filter_record(record):
                self.buffer.append(record)
                # We only want to start setting up logging when the data loader starts producing 
                # real data.
                continue
   
            if i==0:
                self.attached_collectors = {o.name:o.value for o in self.dequeue()}
                if self.debug: print('Got initial values: ',self.attached_collectors)
                mbar.write(self.attached_collectors, table=True)
                self.collector_keys = list(self.attached_collectors)
                pbar.update(0)
                    
            attached_collectors = {o.name:o.value for o in self.dequeue()}
            if self.debug: print('Got running values: ',self.attached_collectors)

            if attached_collectors:
                self.attached_collectors = merge(self.attached_collectors,attached_collectors)
            
            if 'batch' in attached_collectors: 
                pbar.update(attached_collectors['batch'])
                
            if 'epoch' in attached_collectors:
                mbar.update(attached_collectors['epoch'])
                collector_values = {k:self.attached_collectors.get(k,None) for k in self.collector_keys}
                mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in collector_values.values()], table=True)
                
            i+=1  
            yield record

        attached_collectors = {o.name:o.value for o in self.dequeue()}
        if attached_collectors: self.attached_collectors = merge(self.attached_collectors,attached_collectors)

        collector_values = {k:self.attached_collectors.get(k,None) for k in self.collector_keys}
        mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in collector_values.values()], table=True)

        pbar.on_iter_end()
        mbar.on_iter_end()
            

In [11]:
#|export
class RewardCollector(LogCollector):
    header:str='reward'

    def __iter__(self):
        for i,steps in enumerate(self.source_datapipe):
            # if i==0: self.push_header('reward')
            if isinstance(steps,dp.DataChunk):
                for step in steps:
                    for q in self.main_buffers: q.append(Record('reward',step.reward.detach().numpy()))
            else:
                for q in self.main_buffers: q.append(Record('reward',steps.reward.detach().numpy()))
            yield steps

In [21]:
pip show torch

290.85s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


Name: torch
Version: 1.13.0.dev20220819+cu113
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /opt/conda/lib/python3.7/site-packages
Requires: typing-extensions
Required-by: fastai, fastrl, torchvision, torchtext, torchelastic, torchdata
Note: you may need to restart the kernel to use updated packages.


In [20]:
pip show torchdata

279.20s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


Name: torchdata
Version: 0.5.0.dev20220819
Summary: Composable data loading modules for PyTorch
Home-page: https://github.com/pytorch/data
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD
Location: /opt/conda/lib/python3.7/site-packages
Requires: portalocker, requests, torch, urllib3
Required-by: 
Note: you may need to restart the kernel to use updated packages.


In [12]:
#|export
class EpocherCollector(LogCollector):
    debug:bool=False
    header:str='epoch'

    def __init__(self,
            source_datapipe,
            epochs:int=0,
            logger_bases:List[LoggerBase]=None # `LoggerBase`s that we want to send metrics to
        ):
        self.source_datapipe = source_datapipe
        self.main_buffers = None
        self.iteration_started = False
        self.epochs = epochs
        self.epoch = 0

    def __iter__(self): 
        # if self.main_buffers is not None and not self.iteration_started:
        #     self.push_header('epoch')
        #     self.iteration_started = True
            
        for i in range(self.epochs): 
            self.epoch = i
            if self.main_buffers is not None:
                for q in self.main_buffers: q.append(Record('epoch',self.epoch))
            # print('yielding on epoch',self.epoch)
            yield from self.source_datapipe
            
add_docs(
EpocherCollector,
"""Tracks the number of epochs that the pipeline is currently on.""",
reset="Grabs buffers from all logger bases in the pipeline."
)

In [13]:
#|export
class BatchCollector(LogCollector):
    header:str='batch'

    def __init__(self,
            source_datapipe,
            batches:Optional[int]=None,
            # If `batches` is None, `BatchCollector` with try to find: `batch_on_pipe` instance
            # and try to grab a `batches` field from there.
            batch_on_pipe:dp.iter.IterDataPipe=None 
        ):
        self.source_datapipe = source_datapipe
        self.main_buffers = None
        self.iteration_started = False
        self.batches = (
            batches if batches is not None else self.batch_on_pipe_get_batches(batch_on_pipe)
        )
        self.batch = 0
        # self.header = 'batch'
        
    def batch_on_pipe_get_batches(self,batch_on_pipe):
        pipe = find_dp(traverse(self.source_datapipe),batch_on_pipe)
        if hasattr(pipe,'batches'):
            return pipe.batches
        elif hasattr(pipe,'limit'):
            return pipe.limit
        else:
            raise RuntimeError(f'Pipe {pipe} isnt recognized as a batch tracker.')

    def __iter__(self): 
        # if self.main_buffers is not None and not self.iteration_started:
        #     self.push_header('batch')
        #     if self.debug: print('pushing batch',self.main_buffers)
        #     self.iteration_started = True
            
        self.batch = 0
        for batch,record in enumerate(self.source_datapipe): 
            yield record
            if type(record)!=Record:
                self.batch += 1
                if self.main_buffers is not None:
                    # print('posting batch values: ',Record('batch',self.batch))
                    for q in self.main_buffers: q.append(Record('batch',self.batch))
            if self.batch>self.batches: 
                break

add_docs(
BatchCollector,
"""Tracks the number of batches that the pipeline is currently on.""",
batch_on_pipe_get_batches="Gets the number of batches from `batch_on_pipe`",
reset="Grabs buffers from all logger bases in the pipeline."
)

In [14]:
#|export
class TestSync(dp.iter.IterDataPipe):
    def __init__(self,
            source_datapipe
        ):
        self.source_datapipe = source_datapipe
        self.actions_augments = []
        
    def __iter__(self): 
        for step in self.source_datapipe:
            # print('Got step: ',step)
            if isinstance(step,GetInputItemRequest):
                # print('augmenting!!!!!')
                self.actions_augments.append(step.value)
                continue
            elif self.actions_augments:
                step = step.__class__(**{fld:getattr(step,fld)+self.actions_augments.pop(0) 
                                         if fld=='action' else 
                                         getattr(step,fld) for fld in step._fields})
            yield step
add_docs(
    TestSync,
    """Tests getting values from data loader requests."""
)

In [15]:
from fastrl.data.dataloader2 import *

In [16]:
LogCollector.debug=False
ProgressBarLogger.debug=False

In [None]:
import pandas as pd
from fastrl.envs.gym import *
from fastrl.pipes.map.transforms import *

In [18]:
envs = ['CartPole-v1']*10

logger_base = ProgressBarLogger(batch_on_pipe=BatchCollector,epoch_on_pipe=EpocherCollector)

pipe = dp.map.Mapper(envs)
pipe = TypeTransformer(pipe,[GymTypeTransform])
pipe = dp.iter.MapToIterConverter(pipe)
pipe = LoggerBasePassThrough(pipe,[logger_base])
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)
pipe = RewardCollector(pipe)
pipe = InputInjester(pipe)
pipe = TestSync(pipe)
pipe = pipe.header(limit=10)


pipe = BatchCollector(pipe,batch_on_pipe=dp.iter.Header)
pipe = EpocherCollector(pipe,epochs=5)
pipe = logger_base.connect_source_datapipe(pipe)
# Turn off the seed so that some envs end before others...
steps = list(pipe)

TypeError: 'NoneType' object is not iterable
This exception is thrown by __iter__ of RewardCollector()

In [None]:

from torchdata.dataloader2.dataloader2 import DataLoader2

In [None]:
dl = DataLoader2(
    pipe,
    reading_service=PrototypeMultiProcessingReadingService(
        num_workers = 1,
        protocol_client_type = InputItemIterDataPipeQueueProtocolClient,
        protocol_server_type = InputItemIterDataPipeQueueProtocolServer,
        pipe_type = item_input_pipe_type,
        eventloop = SpawnProcessForDataPipeline
    )
)

# dl = logger_base.connect_source_datapipe(dl)

In [None]:
#|export
from fastrl.core import StepType

class ActionPublish(dp.iter.IterDataPipe):
    def __init__(self,
            source_datapipe, # Pretend this is in the middle of a learner training segment
            dls
        ):
        self.source_datapipe = source_datapipe
        self.dls = dls
        self.protocol_clients = []
        self._expect_response = []
        self.initialized = False
        
    def __iter__(self): 
        for step in self.source_datapipe:
            if not self.initialized:
                for dl in self.dls:
                    # dataloader.IterableWrapperIterDataPipe._IterateQueueDataPipes,[QueueWrappers]
                    for q_wrapper in dl.datapipe.iterable.datapipes:
                        self.protocol_clients.append(q_wrapper.protocol)
                        self._expect_response.append(False)
                self.initialized = True
            
            if isinstance(step,StepType):
                for i,client in enumerate(self.protocol_clients):
                    if self._expect_response[i]: 
                        client.get_response_input_item()
                    else:
                        client.request_input_item(
                            'action_augmentation',value=100
                        )

            yield step
        self.protocol_clients = []
        self._expect_response = []
add_docs(
    ActionPublish,
    """Publishes an action augmentation to the dataloader."""
)

In [None]:
learn_pipe = ActionPublish(dl,[dl])

for o in learn_pipe:pass
    # print('Final Output',o)

# for i,o in enumerate(dl):
#     learn_pipe.source_datapipe.append(o)
    
#     if i==0: print(dl.datapipe)
#     print(o)

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()