In [None]:
#|hide
# Notebook setup helper from fastrl's test utilities; what it configures is
# defined in `fastrl.test_utils`, not in this notebook.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp learner.core

In [None]:
#|export
# Python native modules
import os
from contextlib import contextmanager
from typing import List,Union,Dict,Optional,Iterable,Tuple
# Third party libs
from fastcore.all import add_docs
import torchdata.datapipes as dp
from torchdata.dataloader2.graph import list_dps 
import torch
from torch import nn
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.graph import traverse_dps,DataPipeGraph,DataPipe
# Local modules
from fastrl.torch_core import evaluating
from fastrl.pipes.core import find_dp
from fastrl.loggers.core import Record,EpochCollector,BatchCollector

# Learner Core
> Core DataPipes for building Learners

In [None]:
#|export
class LearnerBase(dp.iter.IterDataPipe):
    def __init__(self,
            # The base NN that we are getting raw action values out of.
            # This can either be a `nn.Module` or a dict of multiple `nn.Module`s
            # for multimodel training.
            model:Union[nn.Module,Dict[str,nn.Module]], 
            # The dataloaders to read data from for training. This can be a single
            # DataLoader2 or an iterable that yields from a DataLoader2.
            dls:Union[DataLoader2,Iterable], 
            # By default for reinforcement learning, we want to keep the workers
            # alive so that simulations are not being shutdown / restarted.
            # Epochs are expected to be handled semantically via tracking the number 
            # of batches.
            infinite_dls:bool=True
    ):
        self.model = model
        self.iterable = dls
        # Self reference kept under a fixed attribute name.
        # NOTE(review): presumably so other pipes can reach the base learner; confirm.
        self.learner_base = self
        self.infinite_dls = infinite_dls
        self._dls = None
        # Flipped by `end()`; checked between items in `__iter__`.
        self._ended = False

    def __getstate__(self):
        # Everything is serialized except `_dls`.
        state = {k:v for k,v in self.__dict__.items() if k != '_dls'}
        # TODO: Needs a better way to serialize / deserialize states.
        # state['iterable'] = [d.state_dict() for d in state['iterable']]
        if dp.iter.IterDataPipe.getstate_hook is not None:
            return dp.iter.IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        # `__getstate__` strips `_dls`, so recreate it with the same default
        # `__init__` uses; otherwise a restored instance has no `_dls` attribute.
        self._dls = None
        # state['iterable'] = [d.from_state_dict() for d in state['iterable']]
        for k,v in state.items():
            setattr(self,k,v)

    def end(self):
        # Takes effect the next time `__iter__` resumes, i.e. before the
        # following item would be yielded.
        self._ended = True
   
    def __iter__(self):
        # Every new iteration starts "not ended", so a learner can be reused
        # after an earlier `end()` call.
        self._ended = False
        for data in self.iterable:
            if self._ended:
                break
            yield data

# NOTE: `LearnerBase` defines no `reset` method of its own, so no `reset=` entry
# is passed here. A docstring supplied for it would be attached to whichever
# inherited `reset` attribute is found on the class, not to anything in `LearnerBase`.
add_docs(
LearnerBase,
"Combines models,dataloaders, and optimizers together for running a training pipeline.",
end="When called, will cause the Learner to stop iterating and cleanup."
)

In [None]:
#|export
class LearnerHead(dp.iter.IterDataPipe):
    # Holds the source datapipes of a learner and iterates over the one that
    # `dp_idx` currently selects: `fit` selects index 0, `validate` index 1.
    def __init__(
            self,
            source_datapipes:Tuple[dp.iter.IterDataPipe],
            model
        ):
        # A lone datapipe is wrapped in a 1-tuple so indexing by `dp_idx` works.
        self.source_datapipes = (
            source_datapipes if isinstance(source_datapipes,tuple) else (source_datapipes,)
        )
        self.dp_idx = 0
        self.model = model

    def __iter__(self):
        # Delegate entirely to the currently selected source datapipe.
        yield from self.source_datapipes[self.dp_idx]
    
    def fit(self,epochs):
        self.dp_idx = 0
        train_pipe = self.source_datapipes[self.dp_idx]
        # Set the epoch count on the EpochCollector found in the training graph.
        find_dp(traverse_dps(train_pipe),EpochCollector).epochs = epochs
        # `model` may be a single module or a tuple of modules.
        modules = self.model if isinstance(self.model,tuple) else (self.model,)
        for module in modules:
            module.train()
        # Exhaust the pipeline; the items themselves are discarded.
        for _ in self:
            pass

    def validate(self,epochs=1,batches=100,show=True) -> DataPipe:
        self.dp_idx = 1
        valid_pipe = self.source_datapipes[self.dp_idx]
        # Configure the collectors found in the validation graph.
        find_dp(traverse_dps(valid_pipe),EpochCollector).epochs = epochs
        find_dp(traverse_dps(valid_pipe),BatchCollector).batches = batches
        with evaluating(self.model):
            for _ in self:
                pass
            if not show:
                return None
            # First pipe in the validation graph that exposes `show`, if any.
            showable = next(
                (p for p in list_dps(traverse_dps(valid_pipe)) if hasattr(p,'show')),
                None
            )
            if showable is not None:
                return showable.show()
        
add_docs(
LearnerHead,
"""Holds one or more source datapipes and iterates over the one selected by
`dp_idx`. `fit` selects index 0 and `validate` selects index 1.""",
fit="Runs the `LearnerHead` pipeline for `epochs`",
validate="""Runs the datapipe at index 1 of `source_datapipes` for `epochs` epochs of
`batches` batches inside the `evaluating` context. If `show` is true, returns the result
of calling `show()` on the first datapipe in that graph that has a `show` attribute."""
)

In [None]:
from fastrl.dataloading.core import dataloaders
from fastrl.loggers.core import EpochCollector

In [None]:
class Printer(dp.iter.IterDataPipe):
    "Pass-through pipe that prints each element (space separated) before yielding it."
    def __init__(self,pipe):
        self.pipe = pipe

    def __iter__(self):
        for element in self.pipe:
            print(element, end=" ")
            yield element
        # Terminate the printed line once the wrapped pipe is exhausted.
        print()

def TestLearner(train_dl,valid_dls):
    "Build a `LearnerHead` with a printing train pipe and a printing validation pipe around a bare `nn.Module`."
    model = nn.Module()
    # Training: LearnerBase -> Printer -> EpochCollector
    train_pipe = EpochCollector(Printer(LearnerBase(model,train_dl)))
    # Validation: LearnerBase -> Printer -> BatchCollector -> EpochCollector
    valid_pipe = EpochCollector(
        BatchCollector(Printer(LearnerBase(model,valid_dls)),batches=1000)
    )
    return LearnerHead((train_pipe,valid_pipe),model)

# Demo 1: two source pipes (0-9 and 10-19) passed to `dataloaders` with
# `do_concat=True`, plus a separate validation dataloader over 20-29.
dls = dataloaders((
        dp.iter.IterableWrapper(range(10)),
        dp.iter.IterableWrapper(range(10,20))
    ),
    do_concat=True
)
# `dataloaders` returns a sequence here; unpack its single element.
(dl3,) = dataloaders(dp.iter.IterableWrapper(range(20,30)))
print('Concated Dataloaders')

# `Printer` inside `TestLearner` echoes every element as it flows through.
learn = TestLearner(dls,dl3)
learn.fit(5)

print("Validating Concated Dataloaders:")
learn.validate(1)  # one validation epoch (this is also `validate`'s default)

# Demo 2: same sources and validation data, but with `do_multiplex=True`.
dls = dataloaders((
        dp.iter.IterableWrapper(range(10)),
        dp.iter.IterableWrapper(range(10,20))
    ),
    do_multiplex=True
)
(dl3,) = dataloaders(dp.iter.IterableWrapper(range(20,30)))
print('Muxed Dataloaders')

learn = TestLearner(dls,dl3)
learn.fit(5)

print("Validating Muxed Dataloaders:")
learn.validate(1)

> Warning: Pickling the LearnerBase currently excludes only the `_dls` field, since it
isn't easily picklable (yet). Proper serialization of `iterable` (the dataloaders) is still a TODO.

In [None]:
from fastrl.agents.dqn.basic import DQN
from fastrl.agents.core import AgentBase,AgentHead,StepFieldSelector,SimpleModelRunner,NumpyConverter
from fastrl.agents.discrete import ArgMaxer,PyPrimativeConverter
from fastrl.envs.gym import GymDataPipe

In [None]:
# Set up the core NN. The seed is fixed before the model is constructed.
torch.manual_seed(0)
# NOTE(review): presumably 4 state inputs and 2 actions (CartPole); confirm against `DQN`.
model = DQN(4,2)
# Set up the agent pipeline, starting from the base that holds the model
agent = AgentBase(model,[])
agent = StepFieldSelector(agent,field='state')
# All the things that make this agent unique and special
# In this instance, all this module does is pass the action directly through to the model.
agent = SimpleModelRunner(agent)
agent = ArgMaxer(agent,only_idx=True)
agent = NumpyConverter(agent)
agent = PyPrimativeConverter(agent)
# Bring everything together into the AgentHead where actions will be passed and then run through the pipeline
agent = AgentHead(agent)

In [None]:
from fastrl.loggers.vscode_visualizers import VSCodeDataPipe

In [None]:
# Setup the DataBlock
def gym_block(num_workers=0,vis=False):
    "Build a single-env CartPole `GymDataPipe` driven by the global `agent`; wrap it in `VSCodeDataPipe` when `vis` is truthy."
    # NOTE(review): `num_workers` is accepted but not used in this body.
    env_pipe = GymDataPipe(
        ['CartPole-v1']*1,
        agent=agent,
        nsteps=1,
        nskips=1,
        firstlast=False,
        include_images=True,
        n=100,
        bs=1
    )
    return VSCodeDataPipe(env_pipe) if vis else env_pipe

# Dataloaders over two gym pipes: a plain one and one wrapped in `VSCodeDataPipe`.
train_dl = dataloaders((gym_block(),gym_block(vis=True)))


In [None]:
def TestLearner(model,train_dl,valid_dls):
    "Build a `LearnerHead` around `model` with a printing train pipe and a printing validation pipe."
    # NOTE: this redefinition replaces the earlier two-argument `TestLearner`.
    # Training: LearnerBase -> Printer -> EpochCollector
    train_pipe = EpochCollector(Printer(LearnerBase(model,train_dl)))
    # Validation: LearnerBase -> Printer -> BatchCollector -> EpochCollector
    valid_pipe = EpochCollector(
        BatchCollector(Printer(LearnerBase(model,valid_dls)),batches=100)
    )
    return LearnerHead((train_pipe,valid_pipe),model)

In [None]:
import pickle

In [None]:
# The same dataloaders are passed for both training and validation.
learner = TestLearner(model,train_dl,train_dl)

# Round-trip the whole learner pipeline through pickle; the restored object is
# the cell's displayed result and is not kept.
out = pickle.dumps(learner)
pickle.loads(out)

In [None]:
# Run the training pipe (source index 0) for 2 epochs.
learner.fit(2)

In [None]:
# Run the validation pipe (source index 1) for 1 epoch; with the default
# `show=True` the result of the first pipe's `show()` is displayed.
learner.validate(1)

In [None]:
#|export
class StepBatcher(dp.iter.IterDataPipe):
    def __init__(self,
            # Iterable of batches; each batch is an indexable collection of steps
            source_datapipe,
            # Optional device (anything `torch.device` accepts) to move the
            # stacked tensors to. `None` leaves them where `vstack` put them.
            device=None
        ):
        self.source_datapipe = source_datapipe
        self.device = device
        
    def vstack_by_fld(self,batch,fld):
        try:
            # `detach()` rather than assigning `requires_grad = False`: when any
            # input requires grad the stacked result is a non-leaf tensor, and
            # PyTorch raises on changing that flag for non-leaf tensors.
            t = torch.vstack(tuple(getattr(step,fld) for step in batch)).detach()
            if self.device is not None:
                t = t.to(torch.device(self.device))
            return t
        except RuntimeError:
            # Report which field / batch failed, then let the error propagate.
            print(f'Failed to stack {fld} given batch: {batch}')
            raise
        
    def __iter__(self):
        for batch in self.source_datapipe:
            # Rebuild the same step type with every field stacked. This relies on
            # the class exposing `_fields` (namedtuple-style).
            cls = batch[0].__class__
            batched_step = cls(**{fld:self.vstack_by_fld(batch,fld) for fld in cls._fields})
            yield batched_step 

# Attach the class and method docstrings for `StepBatcher`.
add_docs(
    StepBatcher,
    "Converts multiple `StepType` into a single `StepType` with the fields concated.",
    vstack_by_fld="vstacks a `fld` in `batch`",
)

In [None]:
#|hide
#|eval: false
# Shell out to the `nbdev_export` command (exports the `#|export` cells).
!nbdev_export