In [6]:
#Let's implement the Impala small and large architectures from the paper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import gymnasium as gym
from ray.rllib.models import ModelCatalog
import torch
import torch.nn as nn
import torch.nn.functional as F

# From the RLLIB docs:
Custom PyTorch Models
Similarly, you can create and register custom PyTorch models by subclassing TorchModelV2 and implementing the __init__() and forward() methods. forward() takes a dict of tensor inputs (mapping str to PyTorch tensor types), whose keys and values depend on the view requirements of the model. Usually, the dict contains only the current observation obs and an is_training boolean flag, as well as an optional list of RNN states. forward() should return the model output (of size self.num_outputs) and - if applicable - a new list of internal states (in case of RNNs or attention nets). You can also override extra methods of the model, such as value_function, to implement a custom value branch.

Additional supervised/self-supervised losses can be added via the TorchModelV2.custom_loss method.

See these examples of fully connected, convolutional, and recurrent torch models.

Example of a Conv model:
https://github.com/ray-project/ray/blob/master/rllib/models/torch/visionnet.py

In [7]:


class ImpalaSmall(TorchModelV2, nn.Module):
    """"Small" IMPALA convolutional network as an RLlib ``TorchModelV2``.

    Architecture: Conv(16 filters, 8x8, stride 4) -> ReLU ->
    Conv(32 filters, 4x4, stride 2) -> ReLU -> flatten -> Linear(256) -> ReLU,
    followed by a policy head of size ``num_outputs`` and a scalar value head.

    All layers are created once in ``__init__`` so that their weights are
    registered with the module, trained by the optimizer and saved in
    checkpoints.

    Args:
        obs_space: Observation space; its ``shape`` must be channels-last
            ``(height, width, channels)``.
        action_space: Action space (passed through to ``TorchModelV2``).
        num_outputs: Size of the policy output (action-distribution inputs).
        model_config: RLlib model config dict (passed through).
        name: Model name (passed through).

    Raises:
        ValueError: If ``obs_space.shape`` is not 3-dimensional.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # Both bases must be initialised explicitly: TorchModelV2.__init__ does
        # not chain to nn.Module.__init__, and without the latter no submodule
        # can be assigned to ``self``.
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        obs_shape = tuple(obs_space.shape)
        if len(obs_shape) != 3:
            raise ValueError(
                "ImpalaSmall expects image observations of shape "
                f"(height, width, channels), got {obs_shape}"
            )
        height, width, in_channels = obs_shape

        self._convs = nn.Sequential(
            # First Conv2D layer with 16 8x8 filters with stride 4
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            # Second Conv2D layer with 32 4x4 filters with stride 2
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened feature size from the actual observation shape
        # instead of hard-coding 32 * 9 * 9 (which is only valid for 84x84).
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, height, width)
            flat_size = self._convs(probe).shape[1]

        # Fully connected layer with 256 units
        self._hidden = nn.Linear(flat_size, 256)
        # Policy head: RLlib requires the model output to have num_outputs units.
        self._logits = nn.Linear(256, num_outputs)
        # Value head, read by value_function().
        self._value_branch = nn.Linear(256, 1)

        # Hidden features of the most recent forward() call.
        self._features = None

    def forward(self, input_dict, state=None, seq_lens=None):
        """Compute policy outputs for a batch of observations.

        Args:
            input_dict: Dict holding the observation batch under ``"obs"`` with
                layout ``(batch, height, width, channels)``. A bare tensor with
                that layout is accepted as well.
            state: List of RNN state tensors; returned unchanged because this
                model is not recurrent.
            seq_lens: Unused (no recurrence).

        Returns:
            Tuple ``(logits, state)`` where ``logits`` has shape
            ``(batch, num_outputs)``.
        """
        obs = input_dict["obs"] if isinstance(input_dict, dict) else input_dict
        # Scale pixels to [0, 1] -- assumes raw 0..255 pixel values, as the
        # original code did.
        obs = obs.float() / 255.0
        # nn.Conv2d expects channels-first (N, C, H, W).
        obs = obs.permute(0, 3, 1, 2)

        features = F.relu(self._hidden(self._convs(obs)))
        self._features = features  # cached for value_function()
        logits = self._logits(features)
        return logits, ([] if state is None else state)

    def value_function(self):
        """Return value estimates, shape ``(batch,)``, for the last forward() batch.

        Raises:
            RuntimeError: If called before ``forward()``.
        """
        if self._features is None:
            raise RuntimeError("value_function() called before forward()")
        return self._value_branch(self._features).reshape(-1)


# Register the class with RLlib's ModelCatalog under the name "impala_small",
# so it can be referred to by that string (e.g. as the `custom_model` entry of
# an RLlib model config).
ModelCatalog.register_custom_model("impala_small", ImpalaSmall)
