In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

An `AutoLayer` saves the arguments that will be used to initialize the actual layer.
The actual layer is initialized when the first input is provided; its input size is read from the shape of that input.

In [6]:
class AutoLayer(nn.Module):
    """Placeholder module that turns into a real layer on its first input.

    The constructor only records the arguments it was given.  On the first
    call to `forward`, the instance switches its class to `_autocls` and runs
    that class's `__init__` with the size of the input's `_autodim` dimension
    as the first positional argument, followed by the recorded arguments.
    Later calls go straight to the real layer's `forward`.

    Subclasses must define `_autocls` (the concrete layer class) and may
    override `_autodim`.
    """

    # Index of the input dimension whose size is passed as the first
    # constructor argument of `_autocls`.
    _autodim = 1

    def __init__(self, *args, **kw):
        """Record the constructor arguments without building any parameters."""
        super().__init__()
        # Deep-copy so later mutation of the caller's objects cannot leak in.
        self._autoargs = deepcopy(args)
        self._autokw = deepcopy(kw)

    def forward(self, *args, **kw):
        """Build the real layer from the first input, then apply it.

        Args:
            *args: Inputs for the real layer; `args[0]` must expose `.shape`
                and `.device`.
            **kw: Keyword inputs passed unchanged to the real layer.

        Returns:
            Whatever the real layer's `forward` returns for these inputs.

        Raises:
            AttributeError: If the subclass does not define `_autocls`.
            IndexError: If no positional input is given.
        """
        # Gather everything needed *before* touching the instance, so a bad
        # call leaves the placeholder intact and retryable.
        placeholder_cls = type(self)
        target_cls = self._autocls
        first = args[0]
        in_size = first.shape[self._autodim]
        ctor_args, ctor_kw = self._autoargs, self._autokw
        was_training = self.training

        self.__class__ = target_cls
        try:
            target_cls.__init__(self, in_size, *ctor_args, **ctor_kw)
        except Exception:
            # Construction failed: go back to being a placeholder.
            self.__class__ = placeholder_cls
            self.training = was_training
            raise

        # The recorded arguments have served their purpose.
        del self._autoargs, self._autokw
        # Re-running __init__ puts the module back in training mode; restore
        # the mode chosen (e.g. via model.eval()) before the first input.
        self.train(was_training)
        # Freshly created parameters live on the default device; move them
        # next to the data that triggered the initialization.
        self.to(first.device)
        # NOTE(review): hooks registered on the placeholder are presumably
        # reset by the re-initialization as well -- confirm if that matters.
        return self.forward(*args, **kw)

    def extra_repr(self):
        """Show the recorded arguments, with `_` standing in for the input size."""
        parts = ['_']
        parts.extend(repr(a) for a in self._autoargs)
        parts.extend(f'{k}={v!r}' for k, v in self._autokw.items())
        return ', '.join(parts)

In [7]:
class AutoConv2d(AutoLayer):
    """Placeholder that becomes an `nn.Conv2d` on its first input.

    Pass the usual `nn.Conv2d` arguments except the leading input-channel
    count, which is read from the first input's shape.
    """
    _autocls = nn.Conv2d

class AutoLinear(AutoLayer):
    """Placeholder that becomes an `nn.Linear` on its first input.

    Pass the usual `nn.Linear` arguments except the leading input-feature
    count, which is read from the first input's shape.
    """
    _autocls = nn.Linear

class AutoBatchNorm2d(AutoLayer):
    """Placeholder that becomes an `nn.BatchNorm2d` on its first input.

    Pass the usual `nn.BatchNorm2d` arguments except the leading feature
    count, which is read from the first input's shape.
    """
    _autocls = nn.BatchNorm2d

Sample CNN with skip-concat connections:

In [8]:
class Net1(nn.Module):
    """Small CNN with a skip-concat connection, built from Auto* layers.

    Only output sizes are given in the constructor; every layer takes its
    input size from the first tensor it receives.
    """

    def __init__(self):
        super().__init__()
        # Two conv + batch-norm stages followed by a linear head.
        self.conv1 = AutoConv2d(32, kernel_size=3, padding=1, bias=False)
        self.bn1 = AutoBatchNorm2d()
        self.conv2 = AutoConv2d(64, kernel_size=3, padding=1, bias=False)
        self.bn2 = AutoBatchNorm2d()
        self.fc = AutoLinear(10)

    def forward(self, x):
        """Run `x` through both conv stages and the linear head `fc`."""
        # Stage 1: conv -> batch norm -> ReLU -> 2x2 max-pool.
        skip = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2)

        # Stage 2: conv -> batch norm -> ReLU on the pooled stage-1 features.
        branch = F.relu(self.bn2(self.conv2(skip)))

        # Skip connection: concatenate stage-1 and stage-2 maps along dim 1,
        # then pool once more.
        merged = F.max_pool2d(torch.cat([skip, branch], dim=1), 2)

        # Flatten everything except the batch dimension for the linear head.
        flat = merged.view(merged.shape[0], -1)
        return self.fc(flat)

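Create the model as usual. The only difference is that the model starts out "empty": the layers hold no parameters until they have seen data. After one input is provided, it gets properly initialized, so run one forward pass before passing `model.parameters()` to an optimizer.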

In [9]:
# Build the network. No input has been seen yet, so every layer is still an
# Auto* placeholder and its repr shows `_` in place of the input size.
model = Net1()
model

Net1(
  (conv1): AutoConv2d(_, 32, kernel_size=3, padding=1, bias=False)
  (bn1): AutoBatchNorm2d(_)
  (conv2): AutoConv2d(_, 64, kernel_size=3, padding=1, bias=False)
  (bn2): AutoBatchNorm2d(_)
  (fc): AutoLinear(_, 10)
)

In [6]:
# First forward pass, on a random tensor of shape (1, 3, 32, 32). This call is
# what makes each placeholder layer initialize itself from its input.
y = model(torch.randn(1, 3, 32, 32))
y

tensor([[ 0.4880, -0.2052,  0.4830,  0.2529, -0.9341, -0.6224,  0.1815, -0.1040,
         -0.6797, -0.4987]], grad_fn=<AddmmBackward>)

In [7]:
# Display the model again now that it has processed one input.
model

Net1(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=6144, out_features=10, bias=True)
)

Notice how the layers have changed from `AutoConv2d` to `Conv2d`.