# Understanding RLModules for custom policies in Ray RLlib

In [1]:
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from typing import Any, Dict
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

import torch
import torch.nn as nn
import numpy as np

Ray RLlib defines [RLModules](https://docs.ray.io/en/latest/rllib/key-concepts.html) as:

> RLModules are framework-specific neural network containers. In a nutshell, they carry the neural networks and define how to use them during three phases that occur in reinforcement learning: Exploration, inference and training. A minimal RL Module can contain a single neural network and define its exploration-, inference- and training logic to only map observations to actions. Since RL Modules can map observations to actions, they naturally implement reinforcement learning policies in RLlib and can therefore be found in the RolloutWorker, where their exploration and inference logic is used to sample from an environment. The second place in RLlib where RL Modules commonly occur is the Learner, where their training logic is used in training the neural network. RL Modules extend to the multi-agent case, where a single MultiRLModule contains multiple RL Modules.

In short, RLModules implement the neural networks that our RL algorithms use for training and inference. To implement our own network architectures and define how a policy maps its inputs to outputs, we therefore need to understand RLModules. The figure below illustrates how RLModules handle the neural network's inputs and outputs.

![RLModules architecture](./imgs/rl_module.png)

Let's investigate the RLModule structure by implementing a custom RLModule for a discrete action space with the PyTorch framework.

In [2]:
class DiscreteTorchModule(TorchRLModule):  # We inherit from base class TorchRLModule

    def __init__(
        self,
        observation_space,
        action_space,
        inference_only,
        model_config,
        catalog_class,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            inference_only=inference_only,
            model_config=model_config,
            catalog_class=catalog_class,
        )

    def setup(self):
        # Here we are going to create the policy network (neural network)
        # The network input dimension equals the observation space dimension
        input_dim = self.observation_space.shape[0]
        # Number of hidden units, taken from the model config
        hidden_dim = self.model_config["fcnet_hiddens"][0]
        # One output (logit) per action in the discrete action space
        output_dim = self.action_space.n

        self.policy = nn.Sequential(  # Here we create the neural network using PyTorch
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        self.input_dim = input_dim

    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Forward pass for inference: always pick the greedy (highest-logit) action.
        with torch.no_grad():  # Disable gradient tracking, since we are not training here
            logits = self.policy(batch["obs"])
            # argmax over the last dimension works for single and batched observations
            action = torch.argmax(logits, dim=-1)
            return {"actions": action}

    def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        exploration_epsilon = 0.2
        if np.random.rand() < exploration_epsilon:
            # Random action exploration
            action = np.random.choice(self.action_space.n)
            return {"actions": action}
        else:
            # Otherwise, wrap the policy logits in a Categorical distribution to sample from
            logits = self.policy(batch["obs"])
            return {
                "action_dist_inputs": torch.distributions.Categorical(logits=logits)
            }

    def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Forward pass for training
        # Gradient tracking stays enabled here because this output feeds the training loss.
        action_logits = self.policy(batch["obs"])

        # For a discrete action space, the network does not output the action itself but
        # one logit per action. A softmax over the logits yields the action probabilities,
        # from which an action can be sampled (or the most likely one picked).
        # RLlib's built-in modules return the raw logits under "action_dist_inputs"; here we
        # wrap them in a Categorical distribution so we can inspect them in this notebook.
        # See https://pytorch.org/docs/stable/distributions.html#categorical
        return {
            "action_dist_inputs": torch.distributions.Categorical(logits=action_logits)
        }

We can use the `RLModuleSpec` class to build our final RLModule.
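A spec separates describing a module from building it: it stores the module class, the spaces and the model config, and its `build()` method creates the module. The sketch below is a hypothetical, simplified analogue of that pattern (`MiniModuleSpec` and `TinyPolicy` are our own names, not RLlib classes). One plausible benefit of the pattern is that the same description can be handed to several workers, each building its own independent copy.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type


@dataclass
class MiniModuleSpec:
    """Hypothetical, simplified stand-in for a module spec (not RLlib's RLModuleSpec)."""

    module_class: Type
    observation_dim: int
    num_actions: int
    model_config: Dict[str, Any] = field(default_factory=dict)

    def build(self):
        # Construction is deferred until build(): every call returns a fresh instance
        return self.module_class(
            observation_dim=self.observation_dim,
            num_actions=self.num_actions,
            model_config=self.model_config,
        )


class TinyPolicy:
    """Hypothetical module that only records the layer sizes it was built with."""

    def __init__(self, observation_dim: int, num_actions: int, model_config: Dict[str, Any]):
        self.layer_sizes = [observation_dim, *model_config["fcnet_hiddens"], num_actions]


spec = MiniModuleSpec(TinyPolicy, observation_dim=4, num_actions=2,
                      model_config={"fcnet_hiddens": [64]})
module_a, module_b = spec.build(), spec.build()
print(module_a.layer_sizes, module_a is module_b)  # [4, 64, 2] False
```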

In [3]:
env = gym.make("CartPole-v1")
spec = RLModuleSpec(
    module_class=DiscreteTorchModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config={"fcnet_hiddens": [64]},
)

rlmodule = spec.build()
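For CartPole-v1 the observation is 4-dimensional and there are 2 discrete actions, so with `fcnet_hiddens=[64]` the `setup()` method above builds a 4-64-2 network. The small helper below (our own, hypothetical `mlp_param_count`) counts its weights and biases; the ReLU layer has no parameters. Since `TorchRLModule` is a `torch.nn.Module`, `sum(p.numel() for p in rlmodule.parameters())` should report the same number.

```python
from typing import List


def mlp_param_count(layer_sizes: List[int]) -> int:
    """Weights plus biases of a fully connected network with the given layer sizes."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# CartPole-v1: 4 observation features, 64 hidden units, 2 actions
# (4 * 64 + 64) + (64 * 2 + 2) = 320 + 130
print(mlp_param_count([4, 64, 2]))  # 450
```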



At this point, the variable `rlmodule` holds an RLModule wrapping our neural network. It has not been trained yet, so the network weights still have their random initial values. Let's interact with it to see how it works.

Inference does not explore: it always chooses the action with the highest logit.
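Greedy inference boils down to an argmax over the logits. The axis matters once observations are batched: `np.argmax` without an `axis` argument flattens the array and returns a single index for the whole batch. A minimal NumPy sketch with made-up logits (`greedy_actions` is a hypothetical helper, not an RLlib function):

```python
import numpy as np


def greedy_actions(logits: np.ndarray) -> np.ndarray:
    """Pick the highest-logit action for every row of a (batch, num_actions) array."""
    # axis=-1 keeps one action per observation instead of flattening the batch
    return np.argmax(logits, axis=-1)


batch_logits = np.array([
    [-0.8, -0.6],  # action 1 has the larger logit
    [0.3, -1.2],   # action 0 has the larger logit
])
print(greedy_actions(batch_logits))  # [1 0]
```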

In [4]:
batch = {"obs": torch.from_numpy(env.observation_space.sample())}
print(f"Obs: {batch['obs']}")

# Forward inference
inference = rlmodule.forward_inference(batch)
inference_actions = inference["actions"]
print(f"Action: {inference_actions}")

Obs: tensor([1.3627, 0.1826, 0.1172, 0.1369])
Action: 1


Our exploration function picks a uniformly random action with 20% probability; otherwise, it returns a Categorical distribution built from the policy logits, from which an action can then be sampled.
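The notebook's `_forward_exploration()` leaves the sampling step to whoever consumes the returned distribution. The NumPy sketch below performs both steps in one hypothetical `explore` function, assuming example logits `[0, ln 3]`, whose softmax is `[0.25, 0.75]`. Over many draws, action 1 should appear with frequency close to 0.2 * 0.5 + 0.8 * 0.75 = 0.7.

```python
import numpy as np


def explore(logits: np.ndarray, epsilon: float, rng: np.random.Generator) -> int:
    """Mimic the notebook's exploration logic for a single observation."""
    num_actions = logits.shape[-1]
    if rng.random() < epsilon:
        # With probability epsilon: a uniformly random action
        return int(rng.integers(num_actions))
    # Otherwise: sample from the Categorical distribution defined by softmax(logits)
    z = np.exp(logits - logits.max())
    probs = z / z.sum()
    return int(rng.choice(num_actions, p=probs))


rng = np.random.default_rng(seed=0)
logits = np.array([0.0, np.log(3.0)])  # softmax -> probabilities [0.25, 0.75]
actions = np.array([explore(logits, epsilon=0.2, rng=rng) for _ in range(20_000)])
print(f"Fraction of action 1: {actions.mean():.3f}")
```

Note that, unlike epsilon-greedy, the non-random branch samples instead of taking the argmax, so some exploration happens even there.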

In [5]:
# Forward exploration
exploration = rlmodule.forward_exploration(batch)
if "actions" in exploration:
    print(f"Action: {exploration['actions']}")
elif "action_dist_inputs" in exploration:
    exploration_action_dist_inputs = exploration["action_dist_inputs"]
    print(f"Action dist: {exploration}")
    print(f"Logits: {exploration_action_dist_inputs.logits}")
    print(f"Probabilities: {exploration_action_dist_inputs.probs}")

Action dist: {'action_dist_inputs': Categorical(logits: torch.Size([2]))}
Logits: tensor([-0.8223, -0.5788], grad_fn=<SubBackward0>)
Probabilities: tensor([0.4394, 0.5606], grad_fn=<SoftmaxBackward0>)
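`torch.distributions.Categorical(logits=...)` turns logits into probabilities with a softmax, which is why the probabilities above carry a `SoftmaxBackward0` gradient function. A quick NumPy check reproduces the printed probabilities from the printed logits:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtracting the max improves numerical stability without changing the result
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([-0.8223, -0.5788])  # logits printed by the exploration cell
probs = softmax(logits)
print(probs.round(4))  # [0.4394 0.5606]
```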


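During training, we return the action distribution built from the logits with gradient tracking enabled, so that a loss computed from it can be backpropagated into the network weights. Because the weights and the observation are unchanged, the logits match those of the exploration pass.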

In [6]:
# Forward train
train = rlmodule.forward_train(batch)
train_action_dist_inputs = train["action_dist_inputs"]
print(f"Action: {train_action_dist_inputs.sample()}")
print(f"Logits: {train_action_dist_inputs.logits}")
print(f"Probabilities: {train_action_dist_inputs.probs}")

Action: 1
Logits: tensor([-0.8223, -0.5788], grad_fn=<SubBackward0>)
Probabilities: tensor([0.4394, 0.5606], grad_fn=<SoftmaxBackward0>)
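To see what the tracked gradient is used for, the sketch below computes a vanilla policy-gradient loss, -A log pi(a|s), for one sample: the logits printed above, the sampled action 1, and a hypothetical advantage A = 1.0. It also evaluates the analytic gradient with respect to the logits, A (softmax(z) - onehot(a)). This is an illustration under those assumptions: PPO itself optimizes a clipped surrogate objective, and in the notebook PyTorch's autograd would compute the gradient instead of the hand-derived formula.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max()  # shift for numerical stability
    return z - np.log(np.exp(z).sum())


def pg_loss_and_grad(logits: np.ndarray, action: int, advantage: float):
    """Vanilla policy-gradient loss -A * log pi(a|s) and its gradient w.r.t. the logits."""
    logp = log_softmax(logits)
    loss = -advantage * logp[action]
    onehot = np.zeros_like(logits)
    onehot[action] = 1.0
    # d(-A * log softmax(z)[a]) / dz = A * (softmax(z) - onehot(a))
    grad = advantage * (np.exp(logp) - onehot)
    return loss, grad


logits = np.array([-0.8223, -0.5788])  # logits printed by the training cell
loss, grad = pg_loss_and_grad(logits, action=1, advantage=1.0)  # hypothetical advantage
print(f"loss={loss:.4f}, grad={grad.round(4)}")
```

A positive advantage pushes the logit of the taken action up (negative gradient) and the other logit down.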


Finally, let's take a quick look at the PPO RLModule. Remember that PPO is an actor-critic method, so on top of an actor-critic encoder it uses two heads: one for the actor (`pi`) and another for the critic (`vf`). See the Ray RLlib implementation of the [PPO RLModule](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/ppo_rl_module.py) below:

In [None]:
"""
This file holds framework-agnostic components for PPO's RLModules.
"""

import abc
from typing import List

from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
    override,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.util.annotations import DeveloperAPI


@DeveloperAPI(stability="alpha")
class PPORLModule(RLModule, InferenceOnlyAPI, ValueFunctionAPI, abc.ABC):
    @override(RLModule)
    def setup(self):
        if self.catalog is None and hasattr(self, "_catalog_ctor_error"):
            raise self._catalog_ctor_error

        is_stateful = isinstance(
            self.catalog.actor_critic_encoder_config.base_encoder_config,
            RecurrentEncoderConfig,
        )
        if is_stateful:
            self.inference_only = False

        if self.inference_only and self.framework == "torch":
            self.catalog.actor_critic_encoder_config.inference_only = True

        # Build models from catalog.
        # Here we build the encoder plus the pi (actor) and vf (critic) heads used by PPO
        self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
        self.pi = self.catalog.build_pi_head(framework=self.framework)
        self.vf = self.catalog.build_vf_head(framework=self.framework)

    @override(RLModule)
    def get_initial_state(self) -> dict:
        if hasattr(self.encoder, "get_initial_state"):
            return self.encoder.get_initial_state()
        else:
            return {}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        """Return attributes, which are NOT inference-only (only used for training)."""
        return ["vf"] + (
            []
            if self.model_config.get("vf_share_layers")
            else ["encoder.critic_encoder"]
        )

    # Where are the functions _forward_inference, _forward_exploration and _forward_train?
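`PPORLModule` inherits from `abc.ABC` and leaves the forward passes undefined. The underlying Python pattern, an abstract base class that holds shared setup logic plus a concrete subclass that fills in the framework-specific forward pass, can be sketched in NumPy as follows. All names here (`PolicyModuleBase`, `NumpyPolicyModule`) are hypothetical, and for brevity the subclass creates its own weights, whereas `PPORLModule` builds its components from a catalog for the configured framework.

```python
import abc

import numpy as np


class PolicyModuleBase(abc.ABC):
    """Hypothetical framework-agnostic base: shared setup, abstract forward pass."""

    def __init__(self, obs_dim: int, num_actions: int, hidden_dim: int = 8):
        self.obs_dim = obs_dim
        self.num_actions = num_actions
        self.hidden_dim = hidden_dim
        self.setup()

    def setup(self):
        # Framework-independent logic (here: deriving the layer sizes)
        self.layer_sizes = [self.obs_dim, self.hidden_dim, self.num_actions]

    @abc.abstractmethod
    def _forward(self, obs):
        """Framework-specific forward pass, implemented by subclasses."""


class NumpyPolicyModule(PolicyModuleBase):
    """Hypothetical concrete subclass that implements the forward pass with NumPy."""

    def setup(self):
        super().setup()  # reuse the shared logic, then add framework-specific weights
        rng = np.random.default_rng(0)
        self.w1 = rng.normal(scale=0.1, size=tuple(self.layer_sizes[:2]))
        self.w2 = rng.normal(scale=0.1, size=tuple(self.layer_sizes[1:]))

    def _forward(self, obs):
        hidden = np.maximum(obs @ self.w1, 0.0)  # ReLU
        return {"action_dist_inputs": hidden @ self.w2}


module = NumpyPolicyModule(obs_dim=4, num_actions=2)
print(module._forward(np.ones((5, 4)))["action_dist_inputs"].shape)  # (5, 2)
# PolicyModuleBase(4, 2) would raise TypeError because _forward is abstract
```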

Note that `PPORLModule` does not implement the `_forward_inference()`, `_forward_exploration()` and `_forward_train()` methods. `PPORLModule` is a framework-agnostic base class for the PPO algorithm, whereas the forward passes are framework-specific: they must use the tensor operations of the framework that computes the gradients (PyTorch or TensorFlow in Ray RLlib). Therefore, the class [PPOTorchRLModule](https://github.com/ray-project/ray/blob/master/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py) inherits from `PPORLModule` and implements the forward passes with PyTorch-specific code: it overrides `_forward()`, which serves as the default forward pass for both inference and exploration, and `_forward_train()`. The code for class `PPOTorchRLModule` is:

In [None]:
from typing import Any, Dict, Optional

from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()


class PPOTorchRLModule(TorchRLModule, PPORLModule):
    # We don't need to implement the setup method since it is already implemented in the PPORLModule class
    @override(RLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Default forward pass (used for inference and exploration)."""
        # RLModule's default _forward_inference() and _forward_exploration() call _forward(),
        # so overriding _forward() once covers both phases
        output = {}
        # Encoder forward pass.
        encoder_outs = self.encoder(batch)
        # Stateful encoder?
        if Columns.STATE_OUT in encoder_outs:
            output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]
        # Pi head.
        output[Columns.ACTION_DIST_INPUTS] = self.pi(
            encoder_outs[ENCODER_OUT][ACTOR]
        )  # The action-distribution inputs (logits) come from the pi (actor) head
        return output

    @override(RLModule)
    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Train forward pass (keep embeddings for possible shared value func. call)."""
        output = {}
        encoder_outs = self.encoder(batch)
        output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC]
        if Columns.STATE_OUT in encoder_outs:
            output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT]
        output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR])
        return output

    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        if embeddings is None:
            # Separate vf-encoder.
            if hasattr(self.encoder, "critic_encoder"):
                batch_ = batch
                if self.is_stateful():
                    # The recurrent encoders expect a `(state_in, h)`  key in the
                    # input dict while the key returned is `(state_in, critic, h)`.
                    batch_ = batch.copy()
                    batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC]
                embeddings = self.encoder.critic_encoder(batch_)[ENCODER_OUT]
            # Shared encoder.
            else:
                embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC]

        # Value head.
        vf_out = self.vf(embeddings)
        # Squeeze out last dimension (single node value head).
        return vf_out.squeeze(-1)
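To tie the pieces together, the NumPy sketch below mirrors the structure of `PPOTorchRLModule` under simplifying assumptions: a single shared encoder (no separate critic encoder, no recurrent state), bias-free layers, and hypothetical class names (`MiniRLModule`, `TinyActorCritic`). `MiniRLModule` imitates how, in recent RLlib versions, the default `_forward_inference()`, `_forward_exploration()` and `_forward_train()` fall back to `_forward()`. The sketch also shows why `_forward_train()` returns embeddings: passing them to `compute_values()` avoids running the encoder a second time.

```python
import numpy as np


class MiniRLModule:
    """Hypothetical stand-in for RLModule's dispatch logic (not RLlib's code)."""

    def forward_inference(self, batch):
        return self._forward_inference(batch)

    def forward_exploration(self, batch):
        return self._forward_exploration(batch)

    def forward_train(self, batch):
        return self._forward_train(batch)

    # By default, every phase-specific hook falls back to the generic _forward()
    def _forward_inference(self, batch):
        return self._forward(batch)

    def _forward_exploration(self, batch):
        return self._forward(batch)

    def _forward_train(self, batch):
        return self._forward(batch)

    def _forward(self, batch):
        raise NotImplementedError("Override _forward() or the phase-specific hooks")


class TinyActorCritic(MiniRLModule):
    """Hypothetical shared-encoder actor-critic mirroring PPOTorchRLModule's layout."""

    def __init__(self, obs_dim: int, hidden_dim: int, num_actions: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.w_enc = rng.normal(scale=0.1, size=(obs_dim, hidden_dim))  # shared encoder
        self.w_pi = rng.normal(scale=0.1, size=(hidden_dim, num_actions))  # actor head
        self.w_vf = rng.normal(scale=0.1, size=(hidden_dim, 1))  # critic head
        self.encoder_calls = 0

    def encoder(self, obs):
        self.encoder_calls += 1
        return np.maximum(obs @ self.w_enc, 0.0)  # ReLU features shared by both heads

    def _forward(self, batch):
        # Used for inference and exploration: only the actor head is needed
        return {"action_dist_inputs": self.encoder(batch["obs"]) @ self.w_pi}

    def _forward_train(self, batch):
        # Keep the encoder output so compute_values() can reuse it
        emb = self.encoder(batch["obs"])
        return {"action_dist_inputs": emb @ self.w_pi, "embeddings": emb}

    def compute_values(self, batch, embeddings=None):
        if embeddings is None:
            embeddings = self.encoder(batch["obs"])
        # (batch, 1) -> (batch,): one scalar value per observation
        return (embeddings @ self.w_vf).squeeze(-1)


module = TinyActorCritic(obs_dim=4, hidden_dim=64, num_actions=2)
batch = {"obs": np.random.default_rng(1).normal(size=(3, 4))}
train_out = module.forward_train(batch)
values = module.compute_values(batch, embeddings=train_out["embeddings"])
print(train_out["action_dist_inputs"].shape, values.shape)  # (3, 2) (3,)
print(module.encoder_calls)  # 1: the value pass reused the training embeddings
```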