In [3]:
#|hide
#|eval: false
# The install chain below only runs when `/content` exists (presumably the Colab
# working directory -- TODO confirm): pip installs fastrl[dev] and pyvirtualdisplay,
# then apt installs xvfb and python-opengl with all apt output discarded.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [4]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, nbdev is only imported when we are NOT on colab.
if in_colab():
    # A virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()
else:
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Outside of test runs (no `IN_TEST` env var), make sure we really are
    # in a local IPython notebook.
    if not os.environ.get("IN_TEST"):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON

In [5]:
#|default_exp learner.core

In [6]:
#|export
# Python native modules
import os
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
import torch
from fastai.torch_basics import *
from fastai.torch_core import *
from torchdata.dataloader2 import DataLoader2
from torchdata.dataloader2.graph import find_dps,traverse
# Local modules
from fastrl.core import *
from fastrl.pipes.core import *
from fastrl.loggers.core import *
from fastrl.data.dataloader2 import *

# Learner Core
> Core DataPipes for building Learners

In [28]:
#|export
class LearnerBase(dp.iter.IterDataPipe):
    "Datapipe that iterates the dataloaders in `dls` and yields what they produce."
    def __init__(self,
            model:Module, # The base NN that we are getting raw action values out of.
            dls:List[DataLoader2], # The dataloaders to read data from for training
            device=None, # Not stored or read by `LearnerBase` itself
            loss_func=None, # The loss function to use
            opt=None, # The optimizer to use
            # LearnerBase will yield each dl individually by default. If `zipwise=True`
            # every round calls `next()` once on each dl and yields a single tuple
            # `(next(dl1),next(dl2),...)`, with `None` in the slot of any exhausted dl.
            zipwise:bool=False,
            # For reinforcement learning, the iterables/workers will live forever and so we dont want
            # to shut them down. We still want a concept of "batch" and "epoch" so this param
            # can handle that.
            batches:int=None
    ):
        self.loss_func = loss_func
        self.opt = opt
        self.model = model
        self.iterable = dls
        self.zipwise = zipwise
        self.learner_base = self
        self.infinite_dls = False
        # Live iterators over `self.iterable`; (re)built by `reset`.
        self._dls = None
        if batches is not None:
            # An explicit batch budget marks the dataloaders as never-ending.
            self.batches = batches
            self.infinite_dls = True
        else:
            # Otherwise take the budget from the `limit` of the `Header` pipe found
            # in the first dataloader's datapipe graph.
            self.batches = find_dp(traverse(dls[0].datapipe,only_datapipe=True),dp.iter.Header).limit

    def __getstate__(self):
        "Return the parent state minus the '_dls', 'opt' and 'iterable' entries."
        state = super().__getstate__()
        # TODO: Needs a better way to serialize / deserialize states.
        # state['iterable'] = [d.state_dict() for d in state['iterable']]
        return {k:v for k,v in state.items() if k not in ['_dls','opt','iterable']}

    def __setstate__(self, state):
        "Restore `state` through the parent class."
        # NOTE(review): relies on a parent class providing `__setstate__` -- confirm
        # against the installed `IterDataPipe`.
        # state['iterable'] = [d.from_state_dict() for d in state['iterable']]
        super().__setstate__(state)

    def reset(self):
        "Create the dataloader iterators; infinite dataloaders keep their existing ones."
        if not self.infinite_dls or self._dls is None:
            self._dls = [iter(dl) for dl in self.iterable]

    def increment_batch(self,value):
        "Whether `value` counts towards the batch budget (`Record`/`GetInputItemResponse` do not)."
        # I dont make this inline, because there is a likelihood we will have additional conditions
        # and I want to actually be able to read and understand each one...
        if type(value)==Record:               return False
        if type(value)==GetInputItemResponse: return False
        return True

    def __iter__(self):
        "Yield from the dataloaders, one after another or (if `zipwise`) one tuple per round."
        self.reset()
        dl_batch_tracker = [0 for _ in self._dls]
        if self.zipwise:
            exhausted = set()
            while len(exhausted)!=len(self._dls):
                zip_list,produced = [],False
                for i,dl in enumerate(self._dls):
                    if i in exhausted:
                        zip_list.append(None)
                        continue
                    try:
                        value = next(dl)
                    except StopIteration:
                        exhausted.add(i)
                        zip_list.append(None)
                        continue
                    zip_list.append(value)
                    produced = True
                    if self.increment_batch(value): dl_batch_tracker[i]+=1
                    # Same rule as the sequential branch: the value that pushes the
                    # counter past `batches` is still emitted, then the dl is retired.
                    if self.infinite_dls and dl_batch_tracker[i]>self.batches:
                        exhausted.add(i)
                # Skip the final round in which every remaining dl turned out to be empty.
                if produced: yield tuple(zip_list)
        else:
            for i,dl in enumerate(self._dls):
                while True:
                    try: v = next(dl)
                    except StopIteration: break
                    if self.increment_batch(v): dl_batch_tracker[i]+=1
                    yield v
                    # For infinite dls, stop once more than `batches` counted values were seen.
                    if self.infinite_dls and dl_batch_tracker[i]>self.batches: break

In [29]:
#|export
class LearnerHead(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe
        # Keep a handle on the `LearnerBase` found in the upstream datapipe graph.
        self.learner_base = find_dp(traverse(self.source_datapipe),LearnerBase)

    def __iter__(self): yield from self.source_datapipe

    def fit(self,epochs):
        # Set the epoch count on the `EpocherCollector` found in this pipeline's graph...
        epocher = find_dp(traverse(self),EpocherCollector)
        epocher.epochs = epochs
        # ...then drain the pipeline; the yielded values themselves are not used here.
        for _ in self: pass

# Docstrings are attached via `add_docs`, so they are kept out of the class body.
add_docs(
    LearnerHead,
    """Final datapipe of a learner pipeline: iterating it yields from `source_datapipe`,
    and `fit` runs the whole pipeline.
    """,
    fit="Runs the `LearnerHead` pipeline for `epochs`"
)

> Warning: Pickling the LearnerBase will exclude the '_dls','opt','iterable' fields since
these aren't easily picklable (yet).

In [30]:
from fastai.torch_basics import *
from fastai.torch_core import *
from fastrl.agents.dqn.basic import *
from fastrl.agents.core import *

In [31]:
# Set up the core NN (torch is seeded first)
torch.manual_seed(0)
model = DQN(4,2)
# Build the agent pipeline inside-out:
#   AgentBase         - wraps the model,
#   SimpleModelRunner - in this instance, just passes the input directly through to the model,
#   AgentHead         - brings everything together; actions are requested from it and run
#                       through the pipeline.
agent = AgentHead(
    SimpleModelRunner(
        AgentBase(model,[])
    )
)

If we pass a list of tensors, we will get a list of actions:

In [32]:
# Show the state dict exposed by the `AgentHead` (recorded output below).
agent.__getstate__()

{'source_datapipe': SimpleModelRunner, 'agent_base': AgentBase}

In [33]:
# A list holding a single float tensor: one action is printed.
for action in agent([tensor([1,2,3,4]).float()]):
    print(action)

tensor([[-0.2909, -1.0357]], grad_fn=<AddmmBackward0>)


In [34]:
# The same tensor repeated three times: three actions are printed.
for action in agent([tensor([1,2,3,4]).float()]*3):
    print(action)

tensor([[-0.2909, -1.0357]], grad_fn=<AddmmBackward0>)
tensor([[-0.2909, -1.0357]], grad_fn=<AddmmBackward0>)
tensor([[-0.2909, -1.0357]], grad_fn=<AddmmBackward0>)


In [35]:
from fastrl.pipes.core import *
from fastrl.pipes.map.transforms import *
from fastrl.data.block import *
from fastrl.envs.gym import *

def baseline_test(envs,total_steps,seed=0):
    "Build a `GymStepper` pipeline over `envs` and collect up to `total_steps` items from it."
    # Assemble the pipeline stage by stage: map source -> type transform -> iter pipe
    # -> in-memory cache -> header(limit=10) -> stepper.
    env_pipe     = dp.map.Mapper(envs)
    typed_pipe   = TypeTransformer(env_pipe,[GymTypeTransform])
    iter_pipe    = dp.iter.MapToIterConverter(typed_pipe)
    cached_pipe  = dp.iter.InMemoryCacheHolder(iter_pipe)
    limited_pipe = cached_pipe.header(limit=10)
    stepper      = GymStepper(limited_pipe,seed=seed)

    # Pairing with `range(total_steps)` caps how many items are pulled from the stepper.
    collected = []
    for _,step in zip(range(total_steps),stepper):
        collected.append(step)
    return collected, stepper


In [36]:
# With `total_steps=0` nothing is pulled from the pipeline, so `steps` is empty.
steps, pipe = baseline_test(['CartPole-v1'],0)
steps

[]

In [37]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Only reached outside colab: import and run nbdev's export step.
    from nbdev import nbdev_export
    nbdev_export()