In [1]:
import rai

from torch.nn import Module
from torch import Tensor, nn

In [2]:
class MLP(Module):
    """Fully connected feed-forward network with a sigmoid output.

    Each pair of consecutive entries in ``widths`` defines one ``nn.Linear``
    layer. An ``nn.ReLU`` follows every linear layer except the last, and a
    single ``nn.Sigmoid`` is appended after the last linear layer.
    """

    # Layer stack built in __init__; typed as Sequential because append() is called on it.
    mlp: nn.Sequential

    def __init__(self, widths: list[int]):
        """Build the layer stack.

        Args:
            widths: Layer sizes, input dimension first and output dimension
                last. At least two entries are required; exactly two yields
                ``Linear -> Sigmoid`` with no hidden layer.

        Raises:
            AssertionError: If ``widths`` has fewer than two entries.
        """
        super().__init__() # type: ignore
        # Two widths (input, output) are sufficient: they give a single Linear
        # layer followed by the Sigmoid, with no ReLU in between.
        assert len(widths) >= 2, f"Need at least input and output dimensions; got {widths}"
        self.mlp = nn.Sequential()
        last_linear = len(widths) - 2  # index of the final Linear layer
        for i, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            self.mlp.append(nn.Linear(fan_in, fan_out))
            # Only hidden layers get a ReLU; the final Linear feeds the Sigmoid.
            if i < last_linear:
                self.mlp.append(nn.ReLU())
        self.mlp.append(nn.Sigmoid())

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        return self.mlp(x)

In [3]:
# Instantiate the network from four widths: 2 inputs, two hidden layers of 100 units, 1 output.
mlp = MLP([2, 100, 100, 1])
# Bare last expression so the notebook displays the module's repr.
mlp

MLP(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [4]:
# Pass the model through rai.wrap. In the repr recorded below, every Linear layer
# appears inside an RAI module (as `_forward`, alongside an MSELoss `_loss_fn`),
# while the ReLU and Sigmoid layers are shown unchanged.
# NOTE(review): not visible here whether wrap modifies `mlp` in place or returns a copy — confirm.
rai_mlp = rai.wrap(mlp)
# Bare last expression so the notebook displays the wrapped module's repr.
rai_mlp

MLP(
  (mlp): Sequential(
    (0): RAI(
      (_forward): Linear(in_features=2, out_features=100, bias=True)
      (_loss_fn): MSELoss()
    )
    (1): ReLU()
    (2): RAI(
      (_forward): Linear(in_features=100, out_features=100, bias=True)
      (_loss_fn): MSELoss()
    )
    (3): ReLU()
    (4): RAI(
      (_forward): Linear(in_features=100, out_features=1, bias=True)
      (_loss_fn): MSELoss()
    )
    (5): Sigmoid()
  )
)