# AutoLayers in PyTorch

AutoLayers automatically set their input size, device and dtype based on the first input passed to them. For example, AutoConv2d does not need to be told its input channel count; it infers that from its first input.

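It works by saving the arguments meant for the actual layer and deferring initialization until the first forward pass.
In that pass, the input tensor supplies the missing size (the channel or feature count), and the layer is initialized with it and moved to the input's dtype and device.
To do this, I use Python's `__class__` assignment trick: the AutoLayer object's class is switched to the corresponding standard PyTorch layer (e.g. `AutoConv2d` becomes `nn.Conv2d`), and that layer's `__init__` is then run on the same object.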
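The class-swapping idea does not depend on PyTorch. Below is a minimal plain-Python sketch of the same mechanism: store the constructor arguments, and on the first call switch the object's class and run the real `__init__` with a size taken from the input. `Scale` and `AutoScale` are hypothetical toy classes made up for this illustration; they are not part of PyTorch.

```python
from copy import deepcopy


class Scale:
    """Hypothetical stand-in for a real layer: one weight per input feature."""

    def __init__(self, in_features, init=1.0):
        self.in_features = in_features
        self.weight = [init] * in_features

    def forward(self, x):
        return [w * v for w, v in zip(self.weight, x)]

    def __call__(self, x):
        return self.forward(x)


class AutoScale:
    """Deferred version of Scale: in_features comes from the first input."""

    _autocls = Scale

    def __init__(self, *args, **kw):
        # only remember the arguments; the real layer does not exist yet
        self._autoargs = deepcopy(args)
        self._autokw = deepcopy(kw)

    def __call__(self, x):
        # swap the class of this very object, then run Scale.__init__ on it
        self.__class__ = self._autocls
        self.__init__(len(x), *self._autoargs, **self._autokw)
        return self.forward(x)


layer = AutoScale(init=2.0)
print(type(layer).__name__)  # AutoScale
print(layer([1, 2, 3]))      # [2.0, 4.0, 6.0]
print(type(layer).__name__)  # Scale
print(layer([1, 1, 1]))      # [2.0, 2.0, 2.0], a plain Scale call from now on
```

Because the object itself changes class, every existing reference to it (such as a parent module's attribute) sees the real layer afterwards, so nothing has to be replaced in the parent.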

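You can freely mix standard PyTorch layers and AutoLayers in a module. Just remember to pass one example input through the model to initialize the AutoLayers, and create the optimizer only after that: until the first forward pass, the AutoLayers have no parameters.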

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

In [2]:
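class AutoLayer(nn.Module):
    # index of the input dimension holding the inferred size (channels for conv/BN)
    _autodim = 1

    def __init__(self, *args, **kw):
        super().__init__()
        # save the remaining constructor arguments for the real layer
        self._autoargs = deepcopy(args)
        self._autokw = deepcopy(kw)

    def forward(self, *args, **kw):
        x = args[0]
        # read what we need from the AutoLayer before switching class
        in_size = x.shape[self._autodim]

        # convert self from AutoLayer to actual layer
        self.__class__ = self._autocls

        # initialize layer now, then match the input's device and dtype
        self.__init__(in_size, *self._autoargs, **self._autokw)
        self.to(x)

        # run forward pass as if nothing happened
        return self.forward(*args, **kw)

    def extra_repr(self):
        # '_' stands for the size that will be inferred from the first input
        alist = ['_'] + [repr(a) for a in self._autoargs]
        alist += [k + '=' + repr(v) for k, v in self._autokw.items()]
        return ', '.join(alist)

In [3]:
class AutoConv2d(AutoLayer):
    _autocls = nn.Conv2d

class AutoLinear(AutoLayer):
    _autocls = nn.Linear
    _autodim = -1  # nn.Linear acts on the last dimension

class AutoBatchNorm2d(AutoLayer):
    _autocls = nn.BatchNorm2d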

That's all it takes to define AutoLayers!
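A CPU-only sketch of mixing AutoLayers with standard layers in an `nn.Sequential`. It repeats a compact AutoLayer so that it runs on its own, and it uses a `_autodim` class attribute to select which input dimension supplies the inferred size, so that `AutoLinear` reads the last dimension as `nn.Linear` does. The layer sizes and input shape are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from copy import deepcopy


class AutoLayer(nn.Module):
    # Compact AutoLayer so this block runs on its own.
    _autodim = 1  # input dimension that holds the inferred size

    def __init__(self, *args, **kw):
        super().__init__()
        self._autoargs = deepcopy(args)
        self._autokw = deepcopy(kw)

    def forward(self, *args, **kw):
        x = args[0]
        in_size = x.shape[self._autodim]
        self.__class__ = self._autocls  # become the real layer
        self.__init__(in_size, *self._autoargs, **self._autokw)
        self.to(x)  # match the input's device and dtype
        return self.forward(*args, **kw)


class AutoLinear(AutoLayer):
    _autocls = nn.Linear
    _autodim = -1  # nn.Linear acts on the last dimension


# Standard layers and AutoLayers mixed in one container
model = nn.Sequential(AutoLinear(16), nn.ReLU(), AutoLinear(4))
print(sum(p.numel() for p in model.parameters()))  # 0: nothing to optimize yet

# One example input of shape (batch, seq, features) initializes every AutoLayer
model(torch.randn(2, 5, 8))
print(model)

# Only now do the parameters exist, so build the optimizer afterwards
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(sum(p.numel() for p in model.parameters()))  # 8*16 + 16 + 16*4 + 4 = 212
```

Creating the optimizer before the example input would fail, because at that point the model has no parameters to hand to it.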

Now we define an example CNN. For demonstration, it uses a skip-concat connection. The activations are also moved to a different device and dtype twice along the way (to the GPU, then to float64), and we will see the layers after each change adjust automatically to match. Running this example as written requires a CUDA-capable GPU.

In [12]:
class Net1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = AutoConv2d(32, kernel_size=3, padding=1, bias=False)
        self.bn1 = AutoBatchNorm2d()
        self.conv2 = AutoConv2d(64, kernel_size=3, padding=1, bias=False)
        self.bn2 = AutoBatchNorm2d()
        self.fc = AutoLinear(10)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = F.relu(y)

        # move to GPU
        y = y.to('cuda')
        y1 = F.max_pool2d(y, 2)

        y = self.conv2(y1)
        y = self.bn2(y)
        y = F.relu(y)
        y = torch.cat([y1, y], dim=1)

        y = F.max_pool2d(y, 2)
        y = y.view(y.shape[0], -1)

        # do fully connected layer in higher precision for fun
        y = y.to(torch.float64)

        return self.fc(y)
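The size that `AutoLinear` has to infer here can also be traced by hand. The sketch below does the arithmetic for the 1×3×32×32 example input used next; `pool2` is a helper made up for this illustration.

```python
def pool2(n):
    # F.max_pool2d(y, 2) uses stride 2 by default, halving (with floor) each spatial size
    return n // 2


h = w = 32       # spatial size of the 1x3x32x32 example input
c1, c2 = 32, 64  # out_channels of conv1 and conv2; a 3x3 kernel with padding=1 keeps H and W

h, w = pool2(h), pool2(w)  # first max-pool (y1): 16x16
c_cat = c1 + c2            # torch.cat([y1, y], dim=1): 96 channels
h, w = pool2(h), pool2(w)  # second max-pool: 8x8

in_features = c_cat * h * w  # length of each row after y.view(y.shape[0], -1)
print(in_features)           # 6144
```

This bookkeeping is what AutoLayers spare you: changing the input resolution or a channel count would otherwise mean redoing it by hand.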

Create the model as usual. The only difference is that the model starts out "empty": its AutoLayers have no parameters yet. After one input is passed through it, the model is properly initialized.

In [5]:
model = Net1()
model

Net1(
  (conv1): AutoConv2d(_, 32, kernel_size=3, padding=1, bias=False)
  (bn1): AutoBatchNorm2d(_)
  (conv2): AutoConv2d(_, 64, kernel_size=3, padding=1, bias=False)
  (bn2): AutoBatchNorm2d(_)
  (fc): AutoLinear(_, 10)
)

In [6]:
y = model(torch.randn(1, 3, 32, 32))
y

tensor([[ 0.1460, -0.0849, -0.5889,  0.0922, -0.3834, -0.4427,  0.2605,  0.1484,
         -0.4818, -0.2454]], device='cuda:0', dtype=torch.float64,
       grad_fn=<AddmmBackward>)

In [7]:
model

Net1(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=6144, out_features=10, bias=True)
)

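Notice how each AutoLayer has been replaced by its standard counterpart (`AutoConv2d` became `Conv2d`, `AutoBatchNorm2d` became `BatchNorm2d`, `AutoLinear` became `Linear`), with the inferred input size filled in. The device and dtype of each layer now also match the input it first received, as the weight types of `conv1`, `conv2` and `fc` show.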
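The device and dtype matching relies on the `self.to(...)` call in `AutoLayer.forward`: `nn.Module.to` also accepts a tensor and uses that tensor's dtype and device. A small CPU-only check with standard layers (sizes chosen arbitrarily):

```python
import torch
import torch.nn as nn

example = torch.zeros(1, 4, dtype=torch.float64)  # a CPU float64 tensor

# Module.to(tensor) moves/casts parameters and buffers to the tensor's device and dtype
layer = nn.Linear(4, 2).to(example)
print(layer.weight.dtype)   # torch.float64
print(layer.weight.device)  # cpu

bn = nn.BatchNorm2d(3).to(example)
print(bn.running_mean.dtype)         # torch.float64
print(bn.num_batches_tracked.dtype)  # torch.int64
```

Only floating-point parameters and buffers are cast, which is why BatchNorm's integer `num_batches_tracked` counter keeps its type.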

In [8]:
model.conv1.weight.type()

'torch.FloatTensor'

In [9]:
model.conv2.weight.type()

'torch.cuda.FloatTensor'

In [10]:
model.fc.weight.type()

'torch.cuda.DoubleTensor'