# Understanding RLModules for custom policies in Ray RLlib

In [1]:
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from typing import Any, Dict
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

import torch
import torch.nn as nn
import numpy as np

Ray RLlib defines [RLModules](https://docs.ray.io/en/latest/rllib/key-concepts.html) as:

> RLModules are framework-specific neural network containers. In a nutshell, they carry the neural networks and define how to use them during three phases that occur in reinforcement learning: Exploration, inference and training. A minimal RL Module can contain a single neural network and define its exploration-, inference- and training logic to only map observations to actions. Since RL Modules can map observations to actions, they naturally implement reinforcement learning policies in RLlib and can therefore be found in the RolloutWorker, where their exploration and inference logic is used to sample from an environment. The second place in RLlib where RL Modules commonly occur is the Learner, where their training logic is used in training the neural network. RL Modules extend to the multi-agent case, where a single MultiRLModule contains multiple RL Modules.

In short, RLModules implement the neural networks that our RL algorithms use for training and inference. To implement our own network architectures and control how a policy maps inputs to outputs, it is therefore essential to understand RLModules. The figure below shows how an RLModule handles the neural network's inputs and outputs.

![RLModules architecture](./imgs/rl_module.png)

So, let's investigate the RLModule structure by implementing a custom RLModule for a discrete action space using the PyTorch framework.

In [2]:
class DiscreteTorchModule(TorchRLModule):  # We inherit from base class TorchRLModule

    def __init__(
        self,
        observation_space,
        action_space,
        inference_only,
        model_config,
        catalog_class,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            inference_only=inference_only,
            model_config=model_config,
            catalog_class=catalog_class,
        )

    def setup(self):
        # Build the policy network (a small fully connected neural network).

        # The input dimension equals the observation space dimension.
        input_dim = self.observation_space.shape[0]
        # Number of units in the (single) hidden layer.
        hidden_dim = self.model_config["fcnet_hiddens"][0]
        # One output neuron (logit) per action in the discrete action space.
        output_dim = self.action_space.n

        # Create the neural network with PyTorch.
        self.policy = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        self.input_dim = input_dim

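    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Forward pass for inference (greedy action selection, no exploration).
        # torch.no_grad() disables gradient tracking, which we don't need here
        # because no weight updates happen outside of training.
        with torch.no_grad():
            logits = self.policy(batch["obs"])
            # Pick the action with the highest logit. torch.argmax keeps the
            # result as a tensor instead of round-tripping through NumPy.
            action = torch.argmax(logits, dim=-1)
            return {"actions": action}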

    def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        exploration_epsilon = 0.2
        if np.random.rand() < exploration_epsilon:
            # Random action exploration
            action = np.random.choice(self.action_space.n)
            return {"actions": action}
        else:
            # Use logits from policy network for action selection
            logits = self.policy(batch["obs"])
            return {
                "action_dist_inputs": torch.distributions.Categorical(logits=logits)
            }

    def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Forward pass for training. Unlike inference, gradients are tracked
        # here so that the loss can be backpropagated through the network.
        action_logits = self.policy(batch["obs"])

        # With a discrete action space, the network does not output the action
        # itself but one logit per action. A softmax turns the logits into
        # action probabilities, and an action is then sampled from the
        # resulting categorical distribution.
        # More on Categorical:
        # https://pytorch.org/docs/stable/distributions.html#categorical
        return {
            "action_dist_inputs": torch.distributions.Categorical(logits=action_logits)
        }

We can use the `RLModuleSpec` class to build our final RLModule.

In [3]:
env = gym.make("CartPole-v1")
spec = RLModuleSpec(
    module_class=DiscreteTorchModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config={"fcnet_hiddens": [64]},
)

rlmodule = spec.build()



At this point, the variable `rlmodule` holds an RLModule that wraps our neural network. It has not been trained yet, so the network weights still hold their initial values. Let's interact with it to see how it works.
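
To get a feel for how small this network is, the following sketch counts its trainable parameters by hand. It assumes the sizes implied by the spec above (CartPole-v1 has a 4-dimensional observation and 2 discrete actions, and `fcnet_hiddens` is `[64]`); `linear_param_count` is a hypothetical helper, not part of RLlib or PyTorch.

```python
def linear_param_count(in_features: int, out_features: int) -> int:
    """Parameters of one fully connected layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


# Sizes assumed from the spec above: 4 observation dims, 64 hidden units, 2 actions.
input_dim, hidden_dim, output_dim = 4, 64, 2

total_params = (
    linear_param_count(input_dim, hidden_dim)      # first nn.Linear
    + linear_param_count(hidden_dim, output_dim)   # second nn.Linear
)
print(total_params)  # 450
```

The ReLU layer adds no parameters, so the two linear layers account for all 450 weights and biases.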

Inference does not explore new possibilities; instead, it always chooses the action with the highest logit.

In [4]:
batch = {"obs": torch.from_numpy(env.observation_space.sample())}
print(f"Obs: {batch['obs']}")

# Forward inference
inference = rlmodule.forward_inference(batch)
inference_actions = inference["actions"]
print(f"Action: {inference_actions}")

Obs: tensor([1.3627, 0.1826, 0.1172, 0.1369])
Action: 1


Our exploration function picks a random action with 20% probability; otherwise, it returns a categorical distribution built from the policy logits, from which an action can be sampled.

In [5]:
# Forward exploration
exploration = rlmodule.forward_exploration(batch)
if "actions" in exploration:
    print(f"Action: {exploration['actions']}")
elif "action_dist_inputs" in exploration:
    exploration_action_dist_inputs = exploration["action_dist_inputs"]
    print(f"Action dist: {exploration}")
    print(f"Logits: {exploration_action_dist_inputs.logits}")
    print(f"Probabilities: {exploration_action_dist_inputs.probs}")

Action dist: {'action_dist_inputs': Categorical(logits: torch.Size([2]))}
Logits: tensor([-0.8223, -0.5788], grad_fn=<SubBackward0>)
Probabilities: tensor([0.4394, 0.5606], grad_fn=<SoftmaxBackward0>)
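
The branching in `_forward_exploration` can be isolated in a framework-free sketch to check that roughly 20% of calls take the random branch. This is only an illustration under stated assumptions: `epsilon_greedy_branch` is a hypothetical helper, it uses Python's `random` module rather than NumPy, and the seed is arbitrary.

```python
import random
from typing import Optional


def epsilon_greedy_branch(
    num_actions: int, epsilon: float, rng: random.Random
) -> Optional[int]:
    """Return a uniformly random action with probability epsilon, else None.

    None stands for "defer to the policy network", mirroring the branch in
    which the module returns a distribution instead of a concrete action.
    """
    if rng.random() < epsilon:
        return rng.randrange(num_actions)
    return None


rng = random.Random(0)  # fixed seed, chosen only to make the sketch repeatable
trials = 10_000
outcomes = [
    epsilon_greedy_branch(num_actions=2, epsilon=0.2, rng=rng)
    for _ in range(trials)
]
random_fraction = sum(o is not None for o in outcomes) / trials
print(f"Fraction of random actions: {random_fraction:.3f}")
```

Over many calls the printed fraction should land close to 0.2, which is why a single run of the cell above will usually show the distribution branch rather than a random action.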


When training, we return the probability distribution over actions, with gradient tracking enabled.

In [6]:
# Forward train
train = rlmodule.forward_train(batch)
train_action_dist_inputs = train["action_dist_inputs"]
print(f"Action: {train_action_dist_inputs.sample()}")
print(f"Logits: {train_action_dist_inputs.logits}")
print(f"Probabilities: {train_action_dist_inputs.probs}")

Action: 1
Logits: tensor([-0.8223, -0.5788], grad_fn=<SubBackward0>)
Probabilities: tensor([0.4394, 0.5606], grad_fn=<SoftmaxBackward0>)
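
The printed logits and probabilities are related by a softmax. As a minimal check, the sketch below recomputes the probabilities from the logits in plain Python; the `softmax` helper is written here for illustration (it is not the PyTorch implementation), and the logit values are copied from the output above.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [-0.8223, -0.5788]  # values copied from the printed output above
probs = softmax(logits)
print([round(p, 4) for p in probs])
```

Up to rounding, the result should match the `Probabilities` tensor printed by the notebook, which is what `Categorical` exposes through its `probs` attribute.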
