In [1]:
import torch
from torch import tensor, nn, optim

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# Kept as a plain string: it is only passed to `.to(...)`, which accepts strings.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [2]:
def conv(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):
  """Build a ``Conv2d -> [norm] -> [activation]`` block as an ``nn.Sequential``.

  Args:
    ni: number of input channels.
    nf: number of output channels (also passed to ``norm``).
    ks: kernel size; padding is ``ks//2``, so for odd ``ks`` the output
      height/width is ``ceil(size / stride)``.
    stride: convolution stride (default 2, i.e. downsampling by ~2x).
    act: activation class instantiated with no arguments, or a falsy value
      to omit the activation.
    norm: normalization class called as ``norm(nf)``, or a falsy value to
      omit normalization.
    bias: whether the convolution gets a bias term. An explicit ``True`` /
      ``False`` is honored as given. When ``None`` (default) the bias is
      enabled unless ``norm`` is a BatchNorm class, because BatchNorm's own
      learned shift makes a conv bias redundant.

  Returns:
    ``nn.Sequential`` holding the conv followed by any requested norm and
    activation layers.
  """
  if bias is None:
    # `norm` is a class (it is invoked as norm(nf)), so it must be checked
    # with issubclass; isinstance(norm, nn.BatchNorm2d) would never be True.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    bias = not (isinstance(norm, type) and issubclass(norm, bn_types))
  layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias)]
  if norm: layers.append(norm(nf))
  if act: layers.append(act())
  return nn.Sequential(*layers)

In [10]:
def get_model(act=nn.ReLU, nfs=None, norm=None, n_out=10):
  """Build a small all-convolutional classifier and move it to ``device``.

  Each consecutive pair of channel counts in ``nfs`` becomes one stride-2
  ``conv`` block. A final ``conv`` (no activation, no norm, explicit bias)
  maps to ``n_out`` channels. Adaptive average pooling to 1x1 and flattening
  then produce an output of shape ``(batch, n_out)``, whatever spatial size
  remains after the convs (e.g. 1x1 for 28x28 input, 2x2 for 64x64 input).

  Args:
    act: activation class used in every hidden ``conv`` block.
    nfs: channel sizes, starting with the input channel count. Defaults to
      ``[1, 8, 16, 32, 64]``, i.e. single-channel input.
    norm: normalization class used in every hidden ``conv`` block, or None.
    n_out: number of output channels / classes (default 10).

  Returns:
    ``nn.Sequential`` placed on the notebook-level ``device``.
  """
  if nfs is None: nfs = [1, 8, 16, 32, 64]
  layers = [conv(c_in, c_out, act=act, norm=norm) for c_in, c_out in zip(nfs[:-1], nfs[1:])]
  # Output head: no norm follows it, so keep its bias explicitly.
  head = conv(nfs[-1], n_out, act=None, norm=None, bias=True)
  return nn.Sequential(*layers, head,
                       nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()).to(device)

In [11]:
# Trace one random 28x28 single-channel image through the network and print
# the activation shape after every top-level layer.
x = torch.randn(1, 1, 28, 28, device=device)  # must be on the same device as the model
model = get_model(norm=nn.BatchNorm2d)
model.eval()  # BN uses running stats: no stat mutation, and batch-size-1 inputs are safe
with torch.no_grad():  # shape inspection only; no autograd graph needed
    for i, layer in enumerate(model):
        x = layer(x)
        print(f"After layer {i}: {x.shape}")

After layer 0: torch.Size([1, 8, 14, 14])
After layer 1: torch.Size([1, 16, 7, 7])
After layer 2: torch.Size([1, 32, 4, 4])
After layer 3: torch.Size([1, 64, 2, 2])
After layer 4: torch.Size([1, 10, 1, 1])
After layer 5: torch.Size([1, 10, 1, 1])
After layer 6: torch.Size([1, 10])


In [12]:
# Trace one random 32x32 single-channel image through the network and print
# the activation shape after every top-level layer.
x = torch.randn(1, 1, 32, 32, device=device)  # must be on the same device as the model
model = get_model(norm=nn.BatchNorm2d)
model.eval()  # BN uses running stats: no stat mutation, and batch-size-1 inputs are safe
with torch.no_grad():  # shape inspection only; no autograd graph needed
    for i, layer in enumerate(model):
        x = layer(x)
        print(f"After layer {i}: {x.shape}")

After layer 0: torch.Size([1, 8, 16, 16])
After layer 1: torch.Size([1, 16, 8, 8])
After layer 2: torch.Size([1, 32, 4, 4])
After layer 3: torch.Size([1, 64, 2, 2])
After layer 4: torch.Size([1, 10, 1, 1])
After layer 5: torch.Size([1, 10, 1, 1])
After layer 6: torch.Size([1, 10])


In [13]:
# Trace one random 64x64 single-channel image through the network and print
# the activation shape after every top-level layer; here the convs leave a
# 2x2 map that the adaptive pool reduces to 1x1.
x = torch.randn(1, 1, 64, 64, device=device)  # must be on the same device as the model
model = get_model(norm=nn.BatchNorm2d)
model.eval()  # BN uses running stats: no stat mutation, and batch-size-1 inputs are safe
with torch.no_grad():  # shape inspection only; no autograd graph needed
    for i, layer in enumerate(model):
        x = layer(x)
        print(f"After layer {i}: {x.shape}")

After layer 0: torch.Size([1, 8, 32, 32])
After layer 1: torch.Size([1, 16, 16, 16])
After layer 2: torch.Size([1, 32, 8, 8])
After layer 3: torch.Size([1, 64, 4, 4])
After layer 4: torch.Size([1, 10, 2, 2])
After layer 5: torch.Size([1, 10, 1, 1])
After layer 6: torch.Size([1, 10])
