In [1]:
#|hide
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [2]:
#|default_exp loggers.core

In [3]:
#|export
# Python native modules
from typing import Optional,List,Any,Iterable
from collections import deque
# Third party libs
from fastcore.all import add_docs,merge,ifnone
# from torch.multiprocessing import Pool,Process,set_start_method,Manager,get_start_method,Queue
import torchdata.datapipes as dp
from fastprogress.fastprogress import master_bar,progress_bar
from torchdata.dataloader2.graph import find_dps,traverse_dps,list_dps
# from torch.utils.data.datapipes._hook_iterator import _SnapshotState
import numpy as np
# Local modules
# from fastrl.core import *
from fastrl.pipes.core import find_dp,find_dps
from fastrl.core import StepType,Record
# from fastrl.torch_core import *

# Loggers Core
> Utilities used for handling log messages and display over multiple processes.

In [4]:
#|export
def not_record(data:Any):
    "Predicate for `dp.iter.Filter`: keeps everything whose exact type is not `Record`."
    # Exact type comparison on purpose: instances of `Record` subclasses are kept.
    return not (type(data) == Record)

In [5]:
import torch
from fastcore.all import test_eq
from fastrl.core import test_in,test_out

In [6]:
# Five tensors with two `Record`s interleaved; the filter should drop only the records.
mixed_items = [
    torch.ones((1,1)),
    torch.ones((1,1)),
    torch.ones((1,1)),
    Record('loss',0.5),
    torch.ones((1,1)),
    Record('loss',0.5),
    torch.ones((1,1)),
]
input_pipe = dp.iter.IterableWrapper(mixed_items).filter(not_record)

filtered = list(input_pipe)
test_eq(len(filtered),5)
test_in(torch.ones((1,1)),filtered)
test_out(Record('loss',0.5),filtered)

In [7]:
#|export
class LoggerBase(object):
    # Subclasses use this to gate debug printing.
    debug:bool=False
    # Collected `Record`s. `LogCollector.reset` keeps a reference to this exact list,
    # so it must only ever be mutated in place, never rebound.
    buffer:list
    # The upstream pipe this logger reads from.
    source_datapipe:dp.iter.IterDataPipe

    def __init__(self,source_datapipe=None):
        "Sets `source_datapipe` and starts with an empty `buffer`."
        self.source_datapipe = source_datapipe
        self.buffer = []
    
    def dequeue(self): 
        # Pop from the front (instead of iterating a copy) so items appended while
        # this generator is being consumed are still yielded, and the same list
        # object stays shared with any `LogCollector`.
        while self.buffer: yield self.buffer.pop(0)
    
    # def reset(self):
        # Note: trying to decide if this is really needed.
        # if self.debug:
        #     print(self,' resetting buffer.')
        # if self._snapshot_state!=_SnapshotState.Restored:
        #     self.buffer = []
        
add_docs(
    LoggerBase,
    """The `LoggerBase` class is an interface for datapipes that also collect `Record` objects
    for logging purposes.
    """,
    dequeue="Empties the `self.buffer` yielding each of its contents."
)

In [8]:
class A(dp.iter.IterDataPipe,LoggerBase):
    "Bare `LoggerBase` datapipe: just a source and an empty log buffer."
    def __init__(self,source_datapipe):
        self.buffer = []
        self.source_datapipe = source_datapipe

# The source is a plain list, not a datapipe, so the graph is a single node.
logger_base = A([1,2,3,4])
traverse_dps(logger_base)

{140355346068912: (A, {})}

In [28]:
#|export
class LogCollector(object):
    # Print debugging information (e.g. in `reset`) when True.
    debug:bool=False
    # Name attached to every `Record` this collector sends.
    title:Optional[str] = None
    # Buffers of the `LoggerBase`s found in the pipeline; `None` until `reset` fills it.
    main_buffers:Optional[List] = None        

    def _enqueue(self,value:Any):
        "Appends a fresh `Record(self.title, value)` to every buffer in `self.main_buffers`."
        for q in self.main_buffers:
            q.append(Record(self.title,value))

    def enqueue_title(self):
        "Sends an empty `Record` to tell all the `LoggerBase`s about the `LogCollector`'s existence."
        self._enqueue(None)
    
    def enqueue_value(
        self,
        value:Any
    ):
        "Sends a `Record` with `value` to all `LoggerBase`s"
        self._enqueue(value)

    def reset(self):
        # Only the first call does work; later calls are no-ops once buffers are attached.
        if self.main_buffers is None:
            if self.debug: print(f'Resetting {self}')
            logger_bases = list_dps(traverse_dps(self))
            logger_bases = [o for o in logger_bases if isinstance(o,LoggerBase)]
            # Keep references to the loggers' own lists so appends are visible to them.
            self.main_buffers = [o.buffer for o in logger_bases]
            self.enqueue_title()

add_docs(
LogCollector,
"""`LogCollector` specifically manages finding and attaching itself to
`LoggerBase`s found earlier in the pipeline.""",
reset="Grabs buffers from all logger bases in the pipeline."
)  


In [37]:
class A(LoggerBase,dp.iter.IterDataPipe):
    "Pass-through datapipe that owns a log `buffer`."
    def __init__(self,source_datapipe):
        self.buffer = []
        self.source_datapipe = source_datapipe

    def __iter__(self):
        yield from self.source_datapipe

class B(LogCollector,dp.iter.IterDataPipe):
    "Pass-through collector that logs every item it yields."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for item in self.source_datapipe:
            self.reset()
            self.enqueue_value(item)
            yield item

logger_base = A([1,2,3,4])
collector = B(logger_base)

# Draining the collector triggers `reset` (attaching to `logger_base`) and the enqueue calls.
data_collected = list(collector)

# Data passes through unchanged.
test_eq(data_collected, [1, 2, 3, 4])

# The buffer holds one title record (sent once by `reset`), then one record per value.
records = logger_base.buffer
expected_records = [Record(name=None, value=None)] + [Record(name=None, value=v) for v in [1, 2, 3, 4]]

test_eq(records, expected_records)


Notes:

- Users can initialize multiple different logger bases if they want.
- Collectors can then be added manually, customized for specific pipes (for example, collecting rewards).

In [40]:
#|export
class EpocherCollector(LogCollector,dp.iter.IterDataPipe):
    # NOTE: `LogCollector` is listed before `IterDataPipe` on purpose. `IterDataPipe` defines its
    # own default (no-op) `reset`; with `IterDataPipe` first, that no-op shadows
    # `LogCollector.reset` in the MRO, `main_buffers` is never populated, and
    # `enqueue_value` then fails iterating `None`.
    debug:bool=False
    title:str='epoch'

    def __init__(self,
            source_datapipe,
            # Epochs is the number of times we iterate, and exhaust `source_datapipe`.
            # This is expected behavior of more traditional dataset iteration where
            # an epoch is a single full run through of a dataset.
            epochs:int=0
        ):
        self.source_datapipe = source_datapipe
        # Filled in by `reset` with the buffers of the `LoggerBase`s in the pipeline.
        self.main_buffers = None
        self.iteration_started = False
        self.epochs = epochs
        # 0-based index of the epoch currently being iterated.
        self.epoch = 0

    def __iter__(self): 
        for i in range(self.epochs):
            # Only does work while `main_buffers` is still None.
            self.reset() 
            self.epoch = i
            # Announce the new epoch to every attached `LoggerBase` before yielding its data.
            self.enqueue_value(self.epoch)
            yield from self.source_datapipe
            
add_docs(
EpocherCollector,
"""Tracks the number of epochs that the pipeline is currently on.""",
reset="Grabs buffers from all logger bases in the pipeline."
)

In [41]:
# Mock upstream: five data items, then two empty `LoggerBase` pipes concatenated on the end
source_datapipe = dp.iter.IterableWrapper([1, 2, 3, 4, 5])

class A(dp.iter.IterDataPipe,LoggerBase):
    "Pass-through `LoggerBase` datapipe; only its `buffer` matters here."
    def __init__(self,source_datapipe):
        self.buffer = []
        self.source_datapipe = source_datapipe

    def __iter__(self):
        yield from self.source_datapipe

logger1, logger2 = A([]), A([])
source_datapipe = source_datapipe.concat(logger1,logger2)

# Three epochs over the source, with the logger buffers attached by hand
epochs = 3
collector = EpocherCollector(source_datapipe=source_datapipe, epochs=epochs)
collector.main_buffers = [logger1.buffer, logger2.buffer]

data = list(collector)

# Every epoch replays the full source...
test_eq(data, [1, 2, 3, 4, 5]*epochs)

# ...and each logger receives the 0-based epoch index once per epoch
for logger in (logger1, logger2):
    test_eq([record.value for record in logger.buffer], list(range(epochs)))


In [52]:
#|export
class BatchCollector(LogCollector,dp.iter.IterDataPipe):
    # Being an `IterDataPipe` (listed after `LogCollector` so `LogCollector.reset` wins the MRO)
    # makes this a node in the datapipe graph.
    # - `reset` can run `traverse_dps(self)` on it.
    # - Other components can locate it with `find_dp`.
    # - Functional forms are available on it.
    # Without this, `main_buffers` was never populated and `enqueue_value` iterated `None`.
    title:str='batch'

    def __init__(self,
            source_datapipe,
            # Stop once this many non-`Record` items have been yielded.
            batches:Optional[int]=None,
            # If `batches` is None, `BatchCollector` will try to find a `batch_on_pipe` instance
            # upstream and grab its `batches` (or `limit`) field from there.
            batch_on_pipe:dp.iter.IterDataPipe=None 
        ):
        self.source_datapipe = source_datapipe
        # Filled in by `reset` with the buffers of the `LoggerBase`s in the pipeline.
        self.main_buffers = None
        self.iteration_started = False
        self.batches = (
            batches if batches is not None else self.batch_on_pipe_get_batches(batch_on_pipe)
        )
        self.batch = 0
        
    def batch_on_pipe_get_batches(self,batch_on_pipe):
        pipe = find_dp(traverse_dps(self.source_datapipe),batch_on_pipe)
        if hasattr(pipe,'batches'):
            return pipe.batches
        elif hasattr(pipe,'limit'):
            return pipe.limit
        else:
            raise RuntimeError(f'Pipe {pipe} isnt recognized as a batch tracker.')

    def __iter__(self): 
        # Attach to the loggers' buffers; a no-op once `main_buffers` is set.
        self.reset()
        self.batch = 0
        for record in self.source_datapipe: 
            yield record
            # Only real data (not `Record` log messages) counts as a batch.
            if type(record)!=Record:
                self.batch += 1
                self.enqueue_value(self.batch)
            # Checked after the consumer resumes, so iteration stops right after
            # the `batches`-th data item has been yielded.
            if self.batch>=self.batches: 
                break

add_docs(
BatchCollector,
"""Tracks the number of batches that the pipeline is currently on.""",
batch_on_pipe_get_batches="Gets the number of batches from `batch_on_pipe`",
reset="Grabs buffers from all logger bases in the pipeline."
)

In [54]:
# Mock source: a plain list is enough for `BatchCollector`
source_datapipe = [1, 2, 3, 4, 5]

class A(dp.iter.IterDataPipe,LoggerBase):
    "Empty pass-through `LoggerBase`; only its `buffer` matters here."
    def __init__(self,source_datapipe):
        self.buffer = []
        self.source_datapipe = source_datapipe

    def __iter__(self):
        yield from self.source_datapipe

logger1, logger2 = A([]), A([])

# Stop after 3 batches, with the logger buffers attached by hand
batches = 3
collector = BatchCollector(source_datapipe=source_datapipe, batches=batches)
collector.main_buffers = [logger1.buffer, logger2.buffer]

data = list(collector)

# Only the first `batches` items come through...
test_eq(data, source_datapipe[:batches])

# ...and each logger sees the 1-based batch counter for each of them
expected_batch_ids = list(range(1, batches+1))
for logger in (logger1, logger2):
    test_eq([record.value for record in logger.buffer], expected_batch_ids)

# `batches` can also be discovered from an upstream pipe exposing a `batches` field
class B(dp.iter.IterDataPipe):
    "Pass-through pipe exposing a `batches` field for `batch_on_pipe` lookup."
    def __init__(self,source_datapipe,batches):
        self.source_datapipe = source_datapipe
        self.batches = batches

    def __iter__(self):
        yield from self.source_datapipe

source_datapipe = B(dp.iter.IterableWrapper([1, 2, 3, 4, 5]),batches=4)
collector_with_pipe = BatchCollector(source_datapipe=source_datapipe, batch_on_pipe=B)
test_eq(collector_with_pipe.batches, 4)

In [None]:
#|export
class ProgressBarLogger(LoggerBase):
    "Shows `fastprogress` epoch/batch bars and a table of the values sent by `LogCollector`s."
    debug:bool=False

    def __init__(self,
                 # This does not need to be immediately set since we need the `LogCollectors` to 
                 # first be able to reference its queues.
                 source_datapipe=None, 
                 # For automatic pipe attaching, we can designate which pipe this should
                 # reference for information on which epoch we are on
                 epoch_on_pipe:dp.iter.IterDataPipe=EpocherCollector,
                 # For automatic pipe attaching, we can designate which pipe this should
                 # reference for information on which batch we are on
                 batch_on_pipe:dp.iter.IterDataPipe=BatchCollector
                ):
        # Set the `LoggerBase` fields directly rather than through `super().__init__`.
        self.source_datapipe = source_datapipe
        # `LogCollector.reset` stores a reference to this list; never rebind it.
        self.buffer = []
        self.epoch_on_pipe = epoch_on_pipe
        self.batch_on_pipe = batch_on_pipe
        
        # Table column names, fixed when the first real (non-`Record`) item arrives.
        self.collector_keys = None
        # Latest value seen per collector name.
        self.attached_collectors = None

    def filter_record(self,record) -> bool:
        "Returns True when `record` is a `Record` log message rather than pipeline data."
        return type(record)==Record

    def _write_row(self,mbar):
        "Writes the latest value of every known collector as one row of `mbar`'s table."
        collector_values = [self.attached_collectors.get(k,None) for k in self.collector_keys]
        mbar.write([f'{l:.6f}' if isinstance(l, float) else str(l) for l in collector_values], table=True)
    
    def __iter__(self):
        # This logger is not itself a datapipe, so search the graph of its source instead.
        epocher = find_dp(traverse_dps(self.source_datapipe),self.epoch_on_pipe)
        batcher = find_dp(traverse_dps(self.source_datapipe),self.batch_on_pipe)
        mbar = master_bar(range(epocher.epochs)) 
        pbar = progress_bar(range(batcher.batches),parent=mbar,leave=False)

        mbar.update(0)
        i = 0
        for record in self.source_datapipe:
            if self.filter_record(record):
                self.buffer.append(record)
                # We only want to start setting up logging when the data loader starts producing 
                # real data.
                continue
   
            if i==0:
                # First real item: everything buffered so far defines the collector names,
                # which become the table header.
                self.attached_collectors = {o.name:o.value for o in self.dequeue()}
                if self.debug: print('Got initial values: ',self.attached_collectors)
                mbar.write(self.attached_collectors, table=True)
                self.collector_keys = list(self.attached_collectors)
                pbar.update(0)
                    
            attached_collectors = {o.name:o.value for o in self.dequeue()}
            if self.debug: print('Got running values: ',attached_collectors)

            if attached_collectors:
                self.attached_collectors = merge(self.attached_collectors,attached_collectors)
            
            if 'batch' in attached_collectors: 
                pbar.update(attached_collectors['batch'])
                
            if 'epoch' in attached_collectors:
                mbar.update(attached_collectors['epoch'])
                self._write_row(mbar)
                
            i+=1  
            yield record

        # Flush whatever the collectors sent after the last yielded item.
        attached_collectors = {o.name:o.value for o in self.dequeue()}
        if attached_collectors: self.attached_collectors = merge(self.attached_collectors,attached_collectors)

        # `collector_keys` only exists once real data arrived; otherwise there is no table to extend.
        if self.collector_keys is not None:
            self._write_row(mbar)

        pbar.on_iter_end()
        mbar.on_iter_end()
            

In [None]:
# Only checks that the constructor stores its arguments and starts with no collectors.
epoch_pipe = dp.iter.IterableWrapper(range(10, 100))
batch_pipe = dp.iter.IterableWrapper(range(10, 100))

pbl = ProgressBarLogger(source_datapipe=None, epoch_on_pipe=epoch_pipe, batch_on_pipe=batch_pipe)

for attr_name, expected in [
    ('epoch_on_pipe', epoch_pipe),
    ('batch_on_pipe', batch_pipe),
    ('attached_collectors', None),
]:
    test_eq(getattr(pbl, attr_name), expected)

In [None]:
#|export
class RewardCollector(LogCollector,dp.iter.IterDataPipe):
    "Sends the `reward` of every step that passes through to all attached `LoggerBase`s."
    # `IterDataPipe` comes after `LogCollector` so `LogCollector.reset` wins the MRO; being a
    # datapipe lets `reset` traverse the graph and enables functional forms (e.g. `.header`).
    title:str='reward'

    def __init__(self,source_datapipe):
        # Without this, `RewardCollector(pipe)` raised a TypeError (no constructor took arguments).
        self.source_datapipe = source_datapipe
        self.main_buffers = None

    def __iter__(self):
        # Attach to the loggers' buffers; a no-op once `main_buffers` is set.
        self.reset()
        for steps in self.source_datapipe:
            # A `DataChunk` holds several steps; log each one individually.
            for step in (steps if isinstance(steps,dp.DataChunk) else [steps]):
                # `.cpu()` so tensors on an accelerator can be converted to numpy.
                self.enqueue_value(step.reward.cpu().detach().numpy())
            yield steps

In [None]:
#|export
class EpisodeCollector(LogCollector,dp.iter.IterDataPipe):
    # `IterDataPipe` comes after `LogCollector` so `LogCollector.reset` wins the MRO; being a
    # datapipe lets `reset` traverse the graph.
    title:str='episode'

    def __init__(self,source_datapipe):
        # Without this, `EpisodeCollector(pipe)` raised a TypeError (no constructor took arguments).
        self.source_datapipe = source_datapipe
        self.main_buffers = None
    
    def episode_detach(self,step): 
        try:
            v = step.episode_n.cpu().detach().numpy()
            # 0-d tensor -> plain int; otherwise take the first element.
            if len(v.shape)==0: return int(v)
            return v[0]
        except IndexError:
            print(f'Got IndexError getting episode_n which is unexpected: \n{step}')
            raise
    
    def __iter__(self):
        # Attach to the loggers' buffers; a no-op once `main_buffers` is set.
        self.reset()
        for steps in self.source_datapipe:
            # A `DataChunk` holds several steps; log each one individually.
            for step in (steps if isinstance(steps,dp.DataChunk) else [steps]):
                self.enqueue_value(self.episode_detach(step))
            yield steps

add_docs(
EpisodeCollector,
"""Collects the `episode_n` field from steps.""",
episode_detach="Moves the `episode_n` tensor to numpy.",
)

In [None]:
#|export
class RollingTerminatedRewardCollector(LogCollector,dp.iter.IterDataPipe):
    # `IterDataPipe` comes after `LogCollector` so `LogCollector.reset` wins the MRO; being a
    # datapipe lets `reset` traverse the graph.
    debug:bool=False
    title:str='rolling_reward'

    def __init__(self,
         source_datapipe, # The parent datapipe, likely the one to collect metrics from
         # Number of most recent terminated episodes included in the average
         rolling_length:int=100
        ):
        self.source_datapipe = source_datapipe
        self.main_buffers = None
        self.rolling_rewards = deque([],maxlen=rolling_length)
        
    def step2terminated(self,step): return bool(step.terminated)

    def reward_detach(self,step): 
        try:
            v = step.total_reward.cpu().detach().numpy()
            # 0-d tensor -> plain float; otherwise take the first element.
            if len(v.shape)==0: return float(v)
            return v[0]
        except IndexError:
            print(f'Got IndexError getting reward which is unexpected: \n{step}')
            raise

    def _log_if_terminated(self,step):
        "If `step` ended an episode, add its total reward to the window and send the new average."
        if self.step2terminated(step):
            self.rolling_rewards.append(self.reward_detach(step))
            self.enqueue_value(np.average(self.rolling_rewards))

    def __iter__(self):
        # Attach to the loggers' buffers; a no-op once `main_buffers` is set.
        self.reset()
        for steps in self.source_datapipe:
            if self.debug: print(f'RollingTerminatedRewardCollector: ',steps)
            # A `DataChunk` holds several steps; check each one individually.
            for step in (steps if isinstance(steps,dp.DataChunk) else [steps]):
                self._log_if_terminated(step)
            yield steps

add_docs(
RollingTerminatedRewardCollector,
"""Collects the `total_reward` field from steps if `terminated` is true and 
logs a rolling average of size `rolling_length`.""",
reward_detach="Moves the `total_reward` tensor to numpy.",
step2terminated="Casts the `terminated` field in steps to a bool"
)

In [None]:
from torchdata.dataloader2.dataloader2 import DataLoader2
import fastrl.pipes.iter.cacheholder
# from fastrl.data.dataloader2 import *
# import pandas as pd
from fastrl.envs.gym import *
import gymnasium as gym
# from fastrl.pipes.map.transforms import *

In [None]:
# Demo: ten CartPole envs, 10 steps per epoch (from `header(limit=10)`) for 5 epochs,
# reported through a `ProgressBarLogger`.
# NOTE(review): `LoggerBasePassThrough` and `ProgressBarLogger.connect_source_datapipe` are not
# defined in this notebook; presumably they come from another fastrl module (possibly via the
# `fastrl.envs.gym` star import) — confirm, otherwise this cell fails with NameError/AttributeError.
envs = ['CartPole-v1']*10

logger_base = ProgressBarLogger(batch_on_pipe=BatchCollector,epoch_on_pipe=EpocherCollector)

pipe = dp.iter.IterableWrapper(envs)
pipe = pipe.map(gym.make)
pipe = pipe.pickleable_in_memory_cache()
pipe = LoggerBasePassThrough(pipe,[logger_base])
pipe = dp.iter.InMemoryCacheHolder(pipe)
pipe = pipe.cycle()
pipe = GymStepper(pipe,synchronized_reset=True)
pipe = RewardCollector(pipe)
# pipe = InputInjester(pipe)
# pipe = TestSync(pipe)
# NOTE(review): the functional `.header` form is only available if `RewardCollector` is an `IterDataPipe`.
pipe = pipe.header(limit=10)


# `BatchCollector` reads its batch count from the `Header`'s `limit`.
pipe = BatchCollector(pipe,batch_on_pipe=dp.iter.Header)
pipe = EpocherCollector(pipe,epochs=5)
pipe = logger_base.connect_source_datapipe(pipe)
# NOTE(review): no seed is set in this cell, so the next comment may be stale.
# Turn off the seed so that some envs end before others...
steps = list(pipe)

epoch,batch,reward
1,10,1.0
2,10,1.0
3,10,1.0
4,10,1.0
4,10,1.0


  state=torch.tensor(step.next_state),


In [None]:
# dl = DataLoader2(
#     pipe,
#     reading_service=PrototypeMultiProcessingReadingService(
#         num_workers = 1,
#         protocol_client_type = InputItemIterDataPipeQueueProtocolClient,
#         protocol_server_type = InputItemIterDataPipeQueueProtocolServer,
#         pipe_type = item_input_pipe_type,
#         eventloop = SpawnProcessForDataPipeline
#     )
# )

# # dl = logger_base.connect_source_datapipe(dl)

In [None]:
# #|export
# class ActionPublish(dp.iter.IterDataPipe):
#     def __init__(self,
#             source_datapipe, # Pretend this is in the middle of a learner training segment
#             dls
#         ):
#         self.source_datapipe = source_datapipe
#         self.dls = dls
#         self.protocol_clients = []
#         self._expect_response = []
#         self.initialized = False
        
#     def __iter__(self): 
#         for step in self.source_datapipe:
#             if not self.initialized:
#                 for dl in self.dls:
#                     # dataloader.IterableWrapperIterDataPipe._IterateQueueDataPipes,[QueueWrappers]
#                     for q_wrapper in dl.datapipe.iterable.datapipes:
#                         self.protocol_clients.append(q_wrapper.protocol)
#                         self._expect_response.append(False)
#                 self.initialized = True
            
#             if isinstance(step,StepType):
#                 for i,client in enumerate(self.protocol_clients):
#                     if self._expect_response[i]: 
#                         client.get_response_input_item()
#                     else:
#                         client.request_input_item(
#                             'action_augmentation',value=100
#                         )

#             yield step
#         self.protocol_clients = []
#         self._expect_response = []
# add_docs(
#     ActionPublish,
#     """Publishes an action augmentation to the dataloader."""
# )

In [None]:
# #|hide
# #|eval: false
# learn_pipe = ActionPublish(dl,[dl])

# for o in learn_pipe:pass
#     # print('Final Output',o)

# # for i,o in enumerate(dl):
# #     learn_pipe.source_datapipe.append(o)
    
# #     if i==0: print(dl.datapipe)
# #     print(o)

In [None]:
#|export
class CacheLoggerBase(LoggerBase):
    "Short lived logger base meant to dump logs"
    # Gates the per-iteration size print, like `debug` on the other loggers/collectors.
    debug:bool=False

    def reset(self):
        # This logger will be exhausted frequently if used in an agent.
        # We need to keep the buffer alive so we don't lose the reference
        # held by the `LogCollector`s, hence no clearing here.
        pass
    
    def __iter__(self):
        if self.debug: print('Iterating through buffer of len: ',len(self.buffer))
        yield from self.buffer
        # Clear in place (never rebind) so collectors keep appending to the same list.
        self.buffer.clear()

In [None]:
#|hide
#|eval: false
!nbdev_export