In [1]:
#|hide
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [2]:
#|default_exp learner.core

In [56]:
#|export
# Python native modules
import os
from contextlib import contextmanager
from typing import List,Union,Dict,Optional,Iterable
# Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import list_dps 
import torch
from torch import nn
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.graph import traverse_dps,DataPipeGraph,DataPipe
# Local modules
from fastrl.torch_core import evaluating
from fastrl.pipes.core import find_dp
from fastrl.loggers.core import Record,EpochCollector

# Learner Core
> Core DataPipes for building Learners

In [65]:
#|export
class LearnerBase(dp.iter.IterDataPipe):
    def __init__(self,
            # The base NN that we getting raw action values out of.
            # This can either be a `nn.Module` or a dict of multiple `nn.Module`s
            # For multimodel training
            model:Union[nn.Module,Dict[str,nn.Module]], 
            # The dataloaders to read data from for training. This can be a single
            # DataLoader2 or an iterable that yields from a DataLoader2.
            fit_dls:Union[DataLoader2,Iterable], 
            # The dataloaders to read data from for validation. This can be a single
            # DataLoader2 or an iterable that yields from a DataLoader2.
            val_dls:Optional[Union[DataLoader2,Iterable]]=None, 
            # By default for reinforcement learning, we want to keep the workers
            # alive so that simluations are not being shutdown / restarted.
            # Epochs are expected to be handled semantically via tracking the number 
            # of batches.
            infinite_dls:bool=True
    ):
        self.model = model
        self.fit_iterable = fit_dls
        self.val_iterable = val_dls
        self.learner_base = self
        self.infinite_dls = infinite_dls
        self._dls = None
        self._ended = False
        self._validating = False

    def __getstate__(self):
        state = {k:v for k,v in self.__dict__.items() if k not in ['_dls']}
        # TODO: Needs a better way to serialize / deserialize states.
        # state['iterable'] = [d.state_dict() for d in state['iterable']]
        if dp.iter.IterDataPipe.getstate_hook is not None:
            return dp.iter.IterDataPipe.getstate_hook(state)

        return state

    def __setstate__(self, state):
        # state['iterable'] = [d.from_state_dict() for d in state['iterable']]
        for k,v in state.items():
            setattr(self,k,v)

    def end(self):
        self._ended = True
   
    def __iter__(self):
        self._ended = False
        for data in (self.val_iterable if self._validating else self.fit_iterable):
            if self._ended:
                break
            yield data

add_docs(
LearnerBase,
"Combines models,dataloaders, and optimizers together for running a training pipeline.",
reset="""If `infinite_dls` is false, then all dls will be reset, otherwise they will be
kept alive.""",
end="When called, will cause the Learner to stop iterating and cleanup."
)

In [68]:
#|export
class LearnerHead(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe
        self.learner_base = find_dp(traverse_dps(self.source_datapipe),LearnerBase)

    def __iter__(self): yield from self.source_datapipe
    
    def fit(self,epochs):
        epocher = find_dp(traverse_dps(self),EpochCollector)
        epocher.epochs = epochs
        
        for iteration in self: 
            pass

    def validate(self,epochs=1,show=True) -> DataPipe:
        with evaluating(self.learner_base.model):
            try:
                self.learner_base._validating = True
                epocher = find_dp(traverse_dps(self),EpochCollector)
                epocher.epochs = epochs
                for iteration in self: 
                    pass
            finally:
                self.learner_base._validating = False

            if show:
                pipes = list_dps(traverse_dps(self.learner_base.val_iterable))
                for pipe in pipes:
                    if hasattr(pipe,'show'):
                        return pipe.show() 
        
add_docs(
LearnerHead,
"""
""",
fit="Runs the `LearnerHead` pipeline for `epochs`",
validate="""If there is more than 1 dl, then run 1 epoch of that dl based on 
`dl_idx` and returns the original datapipe for displaying."""
)  

In [69]:

from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService
from fastrl.loggers.core import EpochCollector

In [70]:
def source_make_dataloaders(source):
    pipe = dp.iter.IterableWrapper(source)
    dl = DataLoader2(pipe,
        reading_service=MultiProcessingReadingService(num_workers = 0)
    )
    return dp.iter.IterableWrapper(dl)

class Printer(dp.iter.IterDataPipe):
    def __init__(self,pipe): 
        self.pipe = pipe

    def __iter__(self):
        for o in self.pipe:
            print(o, end=" ")
            yield o
        print()

def TestLearner(dls,val_dls):
    learner = LearnerBase(nn.Module(),dls,val_dls=val_dls)
    learner = Printer(learner)
    learner = EpochCollector(learner)
    learner = LearnerHead(learner)
    return learner

dl1 = source_make_dataloaders(range(10))
dl2 = source_make_dataloaders(range(10,20))
dl3 = source_make_dataloaders(range(20,30))
print('Concated Dataloaders')
dls = dp.iter.Concater(dl1,dl2)

learn = TestLearner(dls,dl3)
learn.fit(5)

print("Validating Concated Dataloaders:")
learn.validate(1)  # using one epoch for validation by default

dl1 = source_make_dataloaders(range(10))
dl2 = source_make_dataloaders(range(10,20))
dl3 = source_make_dataloaders(range(20,30))
print('Muxed Dataloaders')
dls = dp.iter.Multiplexer(dl1,dl2)

learn = TestLearner(dls,dl3)
learn.fit(5)

print("Validating Muxed Dataloaders:")
learn.validate(1)

Concated Dataloaders
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 
Validating Concated Dataloaders:
20 21 22 23 24 25 26 27 28 29 
Muxed Dataloaders
0 10 1 11 2 12 3 13 4 14 5 15 6 16 7 17 8 18 9 19 
0 10 1 11 2 12 3 13 4 14 5 15 6 16 7 17 8 18 9 19 
0 10 1 11 2 12 3 13 4 14 5 15 6 16 7 17 8 18 9 19 
0 10 1 11 2 12 3 13 4 14 5 15 6 16 7 17 8 18 9 19 
0 10 1 11 2 12 3 13 4 14 5 15 6 16 7 17 8 18 9 19 
Validating Muxed Dataloaders:
20 21 22 23 24 25 26 27 28 29 




> Warning: Pickling the LearnerBase will exclude the '_dls','opt','iterable' fields since
these aren't easily picklable (yet).

In [26]:
# from fastrl.torch_core import *
from fastrl.agents.dqn.basic import DQN
from fastrl.agents.core import AgentBase,AgentHead,StepFieldSelector,SimpleModelRunner,NumpyConverter
from fastrl.agents.discrete import ArgMaxer,PyPrimativeConverter
# from fastrl.data.block import *
from fastrl.envs.gym import GymDataPipe
# from fastrl.loggers.vscode_visualizers import *

In [27]:
# Setup up the core NN
torch.manual_seed(0)
model = DQN(4,2)
# Setup the agent
agent = AgentBase(model,[])
agent = StepFieldSelector(agent,field='state')
# All the things that make this agent unique and special
# In this instance, all this module does is pass the action directly through to the model.
agent = SimpleModelRunner(agent)
agent = ArgMaxer(agent,only_idx=True)
agent = NumpyConverter(agent)
agent = PyPrimativeConverter(agent)
# Bring everything together into the AgentHead where actions will be passed and then run through the pipeline
agent = AgentHead(agent)

In [28]:
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe
from torchdata.dataloader2 import DataLoader2,MultiProcessingReadingService

In [29]:
# Setup the DataBlock
# block = DataBlock(
#     GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,n=100,bs=1),
#     (GymTransformBlock(agent=agent,nsteps=1,nskips=1,firstlast=False,n=100,bs=1,include_images=True),VSCodeTransformBlock())
# )
def gym_block(elements,num_workers,vis=False):
    pipe = GymDataPipe(elements,agent=agent,nsteps=1,nskips=1,firstlast=False,include_images=True,n=100,bs=1)
    if vis:
        pipe = VSCodeDataPipe(pipe)
    dl = DataLoader2(pipe,
        reading_service=MultiProcessingReadingService(num_workers = num_workers)
    )
    return dl

dls = (gym_block(['CartPole-v1']*1,num_workers=0),gym_block(['CartPole-v1']*1,num_workers=0,vis=True))


In [30]:
def TestLearner(model,dls):
    learner = LearnerBase(model,dls)
    learner = LearnerHead(learner)
    return learner

In [31]:
import pickle

In [32]:
learner = TestLearner(model,dls)

out = pickle.dumps(learner)
pickle.loads(out)

LearnerHead

In [33]:
#|hide
#|eval: false
learner.validate(1)

TypeError: 'NoneType' object is not subscriptable

In [None]:
#|export
class StepBatcher(dp.iter.IterDataPipe):
    def __init__(self,
            source_datapipe,
            device=None
        ):
        self.source_datapipe = source_datapipe
        self.device = device
        
    def vstack_by_fld(self,batch,fld):
        try:
            if self.device is None: return torch.vstack(tuple(getattr(step,fld) for step in batch))
            return torch.vstack(tuple(getattr(step,fld) for step in batch)).to(torch.device(self.device))
        except RuntimeError as e:
            print(f'Failed to stack {fld} given batch: {batch}')
            raise
        
    def __iter__(self):
        for batch in self.source_datapipe:
            cls = batch[0].__class__
            yield cls(**{fld:self.vstack_by_fld(batch,fld) for fld in cls._fields})

add_docs(
StepBatcher,
"Converts multiple `StepType` into a single `StepType` with the fields concated.",
vstack_by_fld="vstacks a `fld` in `batch`"
)

In [None]:
#|hide
#|eval: false
!nbdev_export

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
