In [None]:
#|hide
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp agents.core

In [None]:
#|export
# Python native modules
import os
from typing import List,Optional
# Third party libs
from fastcore.all import add_docs,ifnone
import torchdata.datapipes as dp
import torch
from torch import nn
from torchdata.dataloader2.graph import traverse_dps
import torch.multiprocessing as mp
# Local modules
from fastrl.core import StepTypes,SimpleStep
from fastrl.torch_core import evaluating,Module
from fastrl.pipes.core import find_dps,find_dp

# Agent Core
> Minimum Agent DataPipes, objects, and utilities

In [None]:
#|export
# Create a manager for shared objects
# manager = mp.Manager()
# shared_model_dict = manager.dict()
shared_model_dict = dict()

def share_model(model: nn.Module, name="default"):
    """Call `model.share_memory()` and register `model` under `name`.

    The registry is the plain module-level `shared_model_dict`; an existing
    entry with the same `name` is replaced.
    """
    # TODO(josiahls): This will not survive multiprocessing. We will need to us something
    # like ray to better sync models.
    model.share_memory()
    shared_model_dict.update({name: model})

def get_shared_model(name="default"):
    """Return the model registered under `name` by `share_model`.

    Raises `KeyError` when nothing has been registered under `name`.
    """
    registered_model = shared_model_dict[name]
    return registered_model

class AgentBase(dp.iter.IterDataPipe):
    # NOTE: the class docstring and the docstring of `to` are attached by the
    # `add_docs` call below, so they are intentionally not written inline here.
    def __init__(self,
            model:Optional[nn.Module], # The base NN that we getting raw action values out of.
            action_iterator:Optional[list]=None, # A reference to an iterator that contains actions to process.
            logger_bases=None
    ):
        self.model = model
        # When a list is passed in, that same list object is kept, so items
        # appended to it elsewhere (e.g. by `AgentHead.__call__`) are seen by `__iter__`.
        self.iterable = ifnone(action_iterator,[])
        self.agent_base = self
        self.logger_bases = logger_bases
        # Key under which `__getstate__` registers the model via `share_model`.
        self._mem_name = 'agent_model'
        
    def to(self,*args,**kwargs):
        # Forward positional arguments too: previously only `**kwargs` reached
        # the model, so a call such as `agent.to('cuda')` was silently a no-op.
        if self.model is not None:
            self.model.to(*args,**kwargs)
        # Return `self` so calls can be chained, matching `SimpleModelRunner.to`.
        return self

    def __iter__(self):
        "Yield queued items in FIFO order, removing each one from `self.iterable` as it is yielded."
        while self.iterable:
            yield self.iterable.pop(0)

    def __getstate__(self):
        "Register the model with `share_model` (if there is one) and return a shallow copy of the instance dict."
        if self.model is not None:
            share_model(self.model,self._mem_name)
        # Shallow copy of every attribute; note that `model` is still part of it.
        state = self.__dict__.copy()
        return state

    def __setstate__(self, state):
        "Restore attributes, then swap the restored model for the one registered under `self._mem_name`."
        self.__dict__.update(state)
        # Assume a globally shared model instance or a reference method to retrieve it
        if self.model is not None:
            self.model = get_shared_model(self._mem_name)
            
add_docs(
AgentBase,
"""Acts as the footer of the Agent pipeline. 
Maintains important state such as the `model` being used for get actions from.
Also optionally allows passing a reference list of `action_iterator` which is a
persistent list of actions for the entire agent pipeline to process through.

> Important: Must be at the start of the pipeline, and be used with AgentHead at the end.

> Important: `action_iterator` is stored in the `iterable` field. However the recommended
way of passing actions to the pipeline is to call an `AgentHead` instance.
""",
to=torch.Tensor.to.__doc__
) 

In [None]:
#|export               
class AgentHead(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe
        # Look up the `AgentBase` in the upstream graph; `__call__` queues work on it.
        upstream_graph = traverse_dps(self.source_datapipe)
        self.agent_base = find_dp(upstream_graph,AgentBase)

    def __call__(self,steps:list):
        # A bare step object (instead of a list of inputs) is rejected up front.
        passed_single_step = issubclass(steps.__class__,StepTypes.types)
        if passed_single_step:
            raise Exception(f'Expected List[{StepTypes.types}] object got {type(steps)}\n{steps}')
        # Queue the inputs on the base pipe, then hand back `self` for iteration.
        pending = self.agent_base.iterable
        pending.extend(steps)
        return self

    def __iter__(self):
        yield from self.source_datapipe

    def augment_actions(self,actions):
        # Identity by default.
        return actions

    def create_step(self,**kwargs):
        return SimpleStep(**kwargs,batch_size=[])

add_docs(
    AgentHead,
    """Acts as the head of the Agent pipeline. 
    Used for conveniently adding actions to the pipeline to process.
    
    > Important: Must be paired with `AgentBase`
    """,
    augment_actions="""Called right before being fed into the env. 
    
    > Important: The results of this function will not be kept / used in the step or forwarded to 
    any training code.

    There are cases where either the entire action shouldn't be fed into the env,
    or the version of the action that we want to train on would be compat with the env.
    
    This is also useful if we want to train on the original raw values of the action prior to argmax being run on it for example.
    """,
    create_step="Creates the step used by the env for running, and used by the model for training."
)

In [None]:
#|export
class SimpleModelRunner(dp.iter.IterDataPipe):
    "Feeds each item of `source_datapipe` to the model held by the pipeline's `AgentBase` and yields the result."
    def __init__(self,
                 source_datapipe
                ):
        self.source_datapipe = source_datapipe
        upstream_graph = traverse_dps(self.source_datapipe)
        self.agent_base = find_dp(upstream_graph,AgentBase)
        # No device move happens until `to(device=...)` has been called.
        self.device = None

    def to(self,*args,**kwargs):
        # Only a keyword `device` is recorded; positional arguments are ignored.
        if 'device' in kwargs:
            self.device = kwargs['device']
        return self

    def __iter__(self):
        for batch in self.source_datapipe:
            if self.device is not None:
                batch = batch.to(self.device)
            # A rank-1 input gets a leading dimension before the forward pass.
            if len(batch.shape)==1:
                batch = batch.unsqueeze(0)
            # Gradients are disabled for the call; `evaluating` presumably puts the
            # model into eval mode for the duration (defined in fastrl.torch_core).
            with torch.no_grad(), evaluating(self.agent_base.model):
                output = self.agent_base.model(batch)
            yield output

Check that the 1x4 tensor successfully pushes through the model and that we get the expected outputs...

In [None]:
torch.manual_seed(0)

class DQN(Module):
    "Small fully connected net: `state_sz` inputs -> `hidden` units with ReLU -> `action_sz` outputs."
    def __init__(self,state_sz:int,action_sz:int,hidden=512):
        # Layers are created in the same order as they are applied.
        input_layer  = nn.Linear(state_sz,hidden)
        output_layer = nn.Linear(hidden,action_sz)
        self.layers  = nn.Sequential(input_layer,nn.ReLU(),output_layer)

    def forward(self,x):
        return self.layers(x)


In [None]:
import pickle
from fastcore.all import test_eq

In [None]:
# Setup up the core NN
torch.manual_seed(0)
model = DQN(4, 2)
# Setup the agent
agent = AgentBase(model)
agent = SimpleModelRunner(agent)
agent = AgentHead(agent)

# Snapshot the parameter *values* before pickling. Cloning matters here: keeping
# references to the parameter objects themselves would make the comparison below
# compare each tensor with itself, which passes no matter what happened.
original_params = [param.detach().clone() for param in agent.agent_base.model.parameters()]

# Pickle and unpickle the agent
pickled_agent = pickle.dumps(agent)
unpickled_agent = pickle.loads(pickled_agent)

# Modify the parameters of the unpickled model
with torch.no_grad():
    for param in unpickled_agent.agent_base.model.parameters():
        param += 1.0

# Extract model parameters after modification
modified_params = list(agent.agent_base.model.parameters())
unpickled_params = list(unpickled_agent.agent_base.model.parameters())
assert len(original_params) == len(modified_params) == len(unpickled_params)

# Ensure that the original model's parameters have changed in the same way as the unpickled model's
for orig_param, modif_param, unpickled_param in zip(original_params, modified_params, unpickled_params):
    assert torch.equal(orig_param + 1.0, modif_param), "Model parameters didn't change after modification!"
    assert torch.equal(modif_param, unpickled_param), "Original and unpickled model parameters diverged!"

input_tensor = torch.tensor([1, 2, 3, 4]).float()

for action in agent([input_tensor]):
    print(action)

# The input tensor must not be modified by the forward pass.
test_eq(input_tensor, torch.tensor([1., 2., 3., 4.]))


In [None]:
# Push one 4-element observation through the agent and print what comes out.
input_tensor = torch.tensor([1,2,3,4]).float()
queued_agent = agent([input_tensor])

for action in queued_agent:
    print(action)

# The forward pass must leave the input untouched.
expected_input = torch.tensor([1., 2., 3., 4.])
test_eq(input_tensor,expected_input)

In [None]:
#|export
class StepFieldSelector(dp.iter.IterDataPipe):
    "Grabs `field` from `source_datapipe` to push to the rest of the pipeline."
    def __init__(self,
         source_datapipe, # datapipe whose next(source_datapipe) -> `StepTypes`
         field='state' # A field in `StepTypes` to grab
        ): 
        # TODO: support multi-fields
        self.source_datapipe = source_datapipe
        self.field = field
    
    def __iter__(self):
        "Yield `getattr(step, self.field)` for each step; raises `Exception` for items that are not a `StepTypes.types` instance."
        for step in self.source_datapipe:
            if not issubclass(step.__class__,StepTypes.types):
                # Report the types that are actually checked against (same wording
                # style as `AgentHead.__call__`) instead of a hard-coded type name.
                raise Exception(f'Expected one of {StepTypes.types} object got {type(step)}\n{step}')
            yield getattr(step,self.field)

Check that using `StepFieldSelector`, we can grab the `state` field from the `SimpleStep` to push through the model...

In [None]:
# Pipeline: AgentBase -> StepFieldSelector('state') -> SimpleModelRunner -> AgentHead
agent = AgentHead(
    SimpleModelRunner(
        StepFieldSelector(AgentBase(model),field='state')
    )
)

state_step = SimpleStep.random(state=torch.tensor([1.,2.,3.,4.]),batch_size=[])
for action in agent([state_step]):
    print(action)

In [None]:
# Fresh, seeded network for the examples that follow
torch.manual_seed(0)
model = DQN(4,2)
# Assemble the agent from the inside out:
#   AgentBase         - holds the model and the (initially empty) list of queued inputs
#   SimpleModelRunner - passes each queued input straight through the model
#   AgentHead         - entry point: call it with a list of inputs, then iterate it
agent = AgentHead(SimpleModelRunner(AgentBase(model,[])))

If we pass a list of tensors, we will get a list of actions:

In [None]:
# One queued tensor -> one printed model output
single_input = [torch.tensor([1,2,3,4]).float()]
for action in agent(single_input):
    print(action)

In [None]:
# The same tensor queued three times -> three printed model outputs
repeated_inputs = [torch.tensor([1,2,3,4]).float()]*3
for action in agent(repeated_inputs):
    print(action)
traverse_dps(agent); # Check that we can traverse it

In [None]:
# from fastrl.pipes.core import *
# from fastrl.pipes.map.transforms import *
# from fastrl.data.block import *
# from fastrl.envs.gym import *

In [None]:

# def baseline_test(envs,total_steps,seed=0):
#     pipe = dp.map.Mapper(envs)
#     pipe = TypeTransformer(pipe,[GymTypeTransform])
#     pipe = dp.iter.MapToIterConverter(pipe)
#     pipe = dp.iter.InMemoryCacheHolder(pipe)
#     pipe = pipe.cycle()
#     pipe = GymStepper(pipe,seed=seed)

#     steps = [step for _,step in zip(*(range(total_steps),pipe))]
#     return steps, pipe


In [None]:
# steps, pipe = baseline_test(['CartPole-v1'],0)

In [None]:
#|export
class NumpyConverter(dp.iter.IterDataPipe):
    # Set to True to print every item via `debug_display` before it is converted.
    debug=False

    def __init__(self,source_datapipe): 
        self.source_datapipe = source_datapipe
        
    def debug_display(self,step):
        print(f'Step: {step}')
    
    # The previous `-> torch.LongTensor` annotation was dropped: this generator
    # yields numpy arrays, not tensors.
    def __iter__(self):
        "Yield `step.detach().cpu().numpy()` for each tensor; raises `Exception` for non-tensor items."
        for step in self.source_datapipe:
            if not isinstance(step,torch.Tensor):
                raise Exception(f'Expected Tensor to  convert to numpy, got {type(step)}\n{step}')
            if self.debug: self.debug_display(step)
            # detach + cpu first so tensors that require grad or live on an
            # accelerator can be converted as well.
            yield step.detach().cpu().numpy()

add_docs(
NumpyConverter,
"""Given input `Tensor` from `source_datapipe` returns a numpy array with the same shape and values (detached and moved to the cpu).""",
debug_display="Display the step being processed"
)

In [None]:
# Ten single-element CPU tensors should all convert without raising.
tensors = []
for _ in range(10):
    tensors.append(torch.tensor([4]))
pipe = NumpyConverter(tensors)
list(pipe);

In [None]:
#|eval:false
# Same check with tensors moved to the GPU (this cell is skipped during evaluation).
tensors = []
for _ in range(10):
    tensors.append(torch.tensor([4]).to(device='cuda'))
pipe = NumpyConverter(tensors)
list(pipe);

In [None]:
#|hide
#|eval: false
!nbdev_export