In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
# Reference model built from the stock nn.Sequential container.
# LazyLinear defers creating its weight until the first forward pass,
# where the input width (20 here) is inferred from `x`.
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
x = torch.randn(size=(2, 20))
net(x).shape  # expected: torch.Size([2, 10])



torch.Size([2, 10])

In [5]:
class MLP(nn.Module):
    """Two-layer perceptron: LazyLinear(256) -> ReLU -> LazyLinear(10).

    Both layers are lazy, so the input feature width is inferred from the
    first batch passed to forward().
    """

    def __init__(self) -> None:
        super().__init__()
        # The attribute names `hidden` / `out` become the parameter prefixes
        # in state_dict(), so they are kept as-is.
        self.hidden = nn.LazyLinear(256)
        self.out = nn.LazyLinear(10)

    def forward(self, x):
        hidden_act = F.relu(self.hidden(x))
        return self.out(hidden_act)

In [6]:
# Forward a (2, 20) batch through the hand-written MLP subclass.
x = torch.randn(size=(2, 20))
net = MLP()
net(x).shape  # expected: torch.Size([2, 10])



torch.Size([2, 10])

In [7]:
class MySequential(nn.Module):
    """Minimal re-implementation of nn.Sequential.

    Each positional argument is registered as a child module named by its
    position ("0", "1", ...). forward() feeds the input through the entries
    of `self._modules` in registration order. A module passed twice is
    applied twice.
    """

    def __init__(self, *args):
        super().__init__()
        for position in range(len(args)):
            self.add_module(str(position), module=args[position])

    def forward(self, x):
        out = x
        # Iterate _modules directly (not children()), so repeated entries
        # are all applied.
        for block in self._modules.values():
            out = block(out)
        return out

In [8]:
# Drive the hand-rolled container with the same three layers used for nn.Sequential.
seq = MySequential(
    nn.LazyLinear(256),
    nn.ReLU(),
    nn.LazyLinear(10),
)
x = torch.randn(size=(2, 20))
seq(x).shape  # expected: torch.Size([2, 10])



torch.Size([2, 10])

In [18]:
class FixedHiddenMLP(nn.Module):
    """MLP that mixes a trainable layer with a fixed random matrix and Python control flow.

    forward():
      1. applies `self.linear` (LazyLinear, output width 20);
      2. multiplies by the fixed 20x20 `rand_weight`, adds 1, applies ReLU;
      3. applies the *same* `self.linear` again (its weights are shared
         between both calls);
      4. halves the activations until their L1 norm is at most 1;
      5. returns their sum as a 0-dim tensor.
    """

    def __init__(self):
        super().__init__()
        # A non-persistent buffer rather than a plain tensor attribute.
        # It is still excluded from parameters() (not trained) and from
        # state_dict(), but it now moves with .to(device) / .double() etc.
        # together with the layer weights. A plain attribute would be left
        # behind and cause a device/dtype mismatch in forward().
        self.register_buffer("rand_weight", torch.rand((20, 20)), persistent=False)
        # Output width must be 20: it feeds the 20x20 matmul and is also the
        # input width of the second call to this same layer.
        self.linear = nn.LazyLinear(20)

    def forward(self, x):
        x = self.linear(x)
        x = F.relu(x @ self.rand_weight + 1)
        # Parameter sharing: reuse the same layer on its own (transformed) output.
        x = self.linear(x)
        # Halving a finite tensor always terminates the loop. An inf entry
        # would loop forever (inf / 2 == inf), so skip normalization in that case.
        if torch.isfinite(x).all():
            while x.abs().sum() > 1:
                x /= 2
        return x.sum()

In [19]:
# Run the custom-forward model on the (2, 20) batch `x` from the previous
# test cell; the result is one 0-dim tensor, not a per-row batch.
net = FixedHiddenMLP()
net(x)

tensor(-0.2180, grad_fn=<SumBackward0>)

In [20]:
class NestMLP(nn.Module):
    """Custom module wrapping an nn.Sequential trunk plus a linear head.

    Trunk: LazyLinear(64) -> ReLU -> LazyLinear(32) -> ReLU.
    Head:  LazyLinear(16).
    """

    def __init__(self):
        super().__init__()
        trunk_layers = [nn.LazyLinear(64), nn.ReLU(), nn.LazyLinear(32), nn.ReLU()]
        # Registered as `net` then `linear`; this order fixes the state_dict() layout.
        self.net = nn.Sequential(*trunk_layers)
        self.linear = nn.LazyLinear(16)

    def forward(self, x):
        features = self.net(x)
        return self.linear(features)

In [21]:
# Modules compose freely: a nested MLP, a plain lazy layer, and the
# custom-control-flow model chained inside one nn.Sequential.
chimera = nn.Sequential(NestMLP(), nn.LazyLinear(20), FixedHiddenMLP())
chimera(x)



tensor(0.0645, grad_fn=<SumBackward0>)