In [1]:
import torch
import models as cifar_models
from torch_receptive_field import receptive_field

In [2]:
# Sanity-check that a CIFAR-sized input flows through every layer of the net.
# debug=True is presumably what makes the forward pass log each intermediate
# tensor shape (the "[AFTER ...]" lines in the cell output) -- confirm in models.py.
model = cifar_models.Net(debug=True)

BATCH_SIZE, CHANNEL_SIZE, height, width = 1, 3, 32, 32

try:
    # Call the module itself instead of .forward() so that any registered
    # hooks run, and disable autograd: only the shapes matter for this probe.
    with torch.no_grad():
        model(torch.rand((BATCH_SIZE, CHANNEL_SIZE, height, width)))
    print("Image size is compatible with layer sizes.")
except RuntimeError as error:
    # Keep the exception object intact; work on a separate message string.
    message = str(error)
    if message.endswith("Output size is too small"):
        print("Image size is too small.")
    elif "shapes cannot be multiplied" in message:
        # The matmul error ends with the operand shapes, e.g.
        # "... shapes cannot be multiplied (1x128 and 256x10)". Anchor on the
        # opening parenthesis rather than the first "x" in the whole message,
        # then take the column count of the first operand ("128").
        operand_shapes = message.partition("(")[2]
        required_shape = operand_shapes.split(" ")[0].partition("x")[2]
        print(f"Linear layer needs to have size: {required_shape}")
    else:
        print(f"Error not understood: {message}")

[INPUT] torch.Size([1, 3, 32, 32])
[AFTER C1] torch.Size([1, 32, 32, 32])
[AFTER C2] torch.Size([1, 32, 32, 32])
[AFTER C3] torch.Size([1, 64, 32, 32])
[AFTER C4] torch.Size([1, 128, 15, 15])
[AFTER GAP] torch.Size([1, 128, 1, 1])
[AFTER Flatten] torch.Size([1, 128])
[AFTER FC1] torch.Size([1, 10])
Image size is compatible with layer sizes.


In [3]:
# Per-layer receptive-field report for the same input resolution that the
# shape check used. input_size omits the batch dimension, so reuse the
# constants defined with the shape check instead of re-typing the literals.
receptive_field(model, input_size=(CHANNEL_SIZE, height, width))

[INPUT] torch.Size([2, 3, 32, 32])
[AFTER C1] torch.Size([2, 32, 32, 32])
[AFTER C2] torch.Size([2, 32, 32, 32])
[AFTER C3] torch.Size([2, 64, 32, 32])
[AFTER C4] torch.Size([2, 128, 15, 15])
[AFTER GAP] torch.Size([2, 128, 1, 1])
[AFTER Flatten] torch.Size([2, 128])
[AFTER FC1] torch.Size([2, 10])
------------------------------------------------------------------------------
        Layer (type)    map size      start       jump receptive_field 
        0               [32, 32]        0.5        1.0             1.0 
        1               [32, 32]        0.5        1.0             7.0 
        2               [32, 32]        0.5        1.0             7.0 
        3               [32, 32]        0.5        1.0             7.0 
        4               [32, 32]        0.5        1.0            11.0 
        5               [32, 32]        0.5        1.0            11.0 
        6               [32, 32]        0.5        1.0            11.0 
        7               [32, 32]        0.5  

OrderedDict([('0',
              OrderedDict([('j', 1.0),
                           ('r', 1.0),
                           ('start', 0.5),
                           ('conv_stage', True),
                           ('output_shape', [-1, 3, 32, 32])])),
             ('1',
              OrderedDict([('j', 1.0),
                           ('r', 7.0),
                           ('start', 0.5),
                           ('input_shape', [-1, 3, 32, 32]),
                           ('output_shape', [-1, 32, 32, 32])])),
             ('2',
              OrderedDict([('j', 1.0),
                           ('r', 7.0),
                           ('start', 0.5),
                           ('input_shape', [-1, 32, 32, 32]),
                           ('output_shape', [-1, 32, 32, 32])])),
             ('3',
              OrderedDict([('j', 1.0),
                           ('r', 7.0),
                           ('start', 0.5),
                           ('input_shape', [-1, 32, 32, 32]),
         

In [4]:
# Layer-by-layer output shapes and parameter counts. input_size includes the
# batch dimension here, so reuse the constants from the shape check rather
# than duplicating the literal (1, 3, 32, 32).
model.summary(input_size=(BATCH_SIZE, CHANNEL_SIZE, height, width))

[INPUT] torch.Size([1, 3, 32, 32])
[AFTER C1] torch.Size([1, 32, 32, 32])
[AFTER C2] torch.Size([1, 32, 32, 32])
[AFTER C3] torch.Size([1, 64, 32, 32])
[AFTER C4] torch.Size([1, 128, 15, 15])
[AFTER GAP] torch.Size([1, 128, 1, 1])
[AFTER Flatten] torch.Size([1, 128])
[AFTER FC1] torch.Size([1, 10])
Layer (type:depth-idx)                   Output Shape              Param #
Net                                      [1, 10]                   --
├─Sequential: 1-1                        [1, 32, 32, 32]           --
│    └─Conv2d: 2-1                       [1, 32, 32, 32]           4,736
│    └─ReLU: 2-2                         [1, 32, 32, 32]           --
│    └─BatchNorm2d: 2-3                  [1, 32, 32, 32]           64
├─Sequential: 1-2                        [1, 32, 32, 32]           --
│    └─Conv2d: 2-4                       [1, 32, 32, 32]           832
│    └─ReLU: 2-5                         [1, 32, 32, 32]           --
│    └─BatchNorm2d: 2-6                  [1, 32, 32, 32]     

In [5]:
# Random 4-D tensor with a 1x1 spatial extent, used below to compare
# ways of flattening it down to (N, C).
x = torch.randn(1, 256, 1, 1)
for label, value in (
    ("x shape", x.shape),
    ("x dim", x.dim()),
    ("x size", x.size()),
):
    print(label, value)

x shape torch.Size([1, 256, 1, 1])
x dim 4
x size torch.Size([1, 256, 1, 1])


In [6]:
# Flatten option 1: reshape to (-1, 256) with view and inspect the result's size.
x.view(-1, 256).size()

torch.Size([1, 256])

In [7]:
# Flatten option 2: collapse every dimension from dim 1 onward, without
# hard-coding the channel count.
torch.flatten(x, start_dim=1).shape

torch.Size([1, 256])