In [1]:
import torch
import torch.nn as nn

# Trace the `localization` stack on a single 1x28x28 input.
# With no padding and dilation 1, every layer below maps a spatial size as
#   out = floor((in - kernel_size) / stride) + 1
localization_layers = [
    # 1 -> 8 channels, 7x7 kernel: floor((28 - 7) / 1) + 1 = 22, shape (8, 22, 22)
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=7),
    # 2x2 max pooling, stride 2: floor((22 - 2) / 2) + 1 = 11, shape (8, 11, 11)
    nn.MaxPool2d(kernel_size=2, stride=2),
    # max(0, x), computed in place
    nn.ReLU(inplace=True),
    # 8 -> 10 channels, 5x5 kernel: floor((11 - 5) / 1) + 1 = 7, shape (10, 7, 7)
    nn.Conv2d(in_channels=8, out_channels=10, kernel_size=5),
    # 2x2 max pooling, stride 2: floor((7 - 2) / 2) + 1 = floor(2.5) + 1 = 3, shape (10, 3, 3)
    nn.MaxPool2d(kernel_size=2, stride=2),
    # max(0, x), computed in place
    nn.ReLU(inplace=True),
]
localization = nn.Sequential(*localization_layers)

# One random sample laid out as (batch, channels, height, width).
dummy_input = torch.randn(1, 1, 28, 28)
output = localization(dummy_input)

print("Localization output shape:", output.shape)
print("Localization output:", output)

Localization output shape: torch.Size([1, 10, 3, 3])
Localization output: tensor([[[[0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.2192, 0.0000]],

         [[0.0667, 0.0767, 0.4927],
          [0.1802, 0.0000, 0.0755],
          [0.4317, 0.0000, 0.3128]],

         [[0.0000, 0.0300, 0.0000],
          [0.2942, 0.0000, 0.3180],
          [0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000]],

         [[0.2538, 0.0795, 0.3774],
          [0.0087, 0.0883, 0.0000],
          [0.0000, 0.0671, 0.4181]],

         [[0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000]],

         [[0.8234, 0.8853, 0.6888],
          [0.7737, 0.8117, 0.9439],
          [0.4719, 0.6555, 0.8744]],

         [[0.5793, 0.8604, 0.7705],
          [0.4471, 0.5315, 0.7647],
          [0.5420, 0.9703, 0.7290]],

         [[0.1144, 0.0000, 0.0252],
          

In [8]:
import torch

# Batch of two single-channel 28x28 random inputs.
dummy_input = torch.randn(2, 1, 28, 28)

# Feature maps produced by the localization stack for this batch.
x = localization(dummy_input)
print("Input shape:", x.shape)

# Collapse each sample's 10 feature maps of 3x3 into one row of 90 values.
flat_features = 10 * 3 * 3
xs = x.view(-1, flat_features)
print("Flattened shape:", xs.shape)

Input shape: torch.Size([2, 10, 3, 3])
Flattened shape: torch.Size([2, 90])


In [14]:
import torch
import torch.nn.functional as F

# Random batch of two 2x3 affine matrices, one per sample in `dummy_input`.
theta = torch.randn(2, 2, 3)

# `align_corners` is passed explicitly: when it is omitted, both functions emit
# a UserWarning and fall back to False, so False keeps the results unchanged
# while making the sampling convention visible. affine_grid and grid_sample
# must be given the same setting.
grid = F.affine_grid(theta, dummy_input.size(), align_corners=False)

# Sample the input at the grid locations; the result has the input's shape.
x = F.grid_sample(dummy_input, grid, align_corners=False)

theta.shape, grid.shape, x.shape



(torch.Size([2, 2, 3]), torch.Size([2, 28, 28, 2]), torch.Size([2, 1, 28, 28]))

In [16]:
import torch.nn as nn

# Dropout2d with drop probability 0.5; the result keeps the shape of `x`.
conv2_drop = nn.Dropout2d(p=0.5)

dropped = conv2_drop(x)
dropped.shape

torch.Size([2, 1, 28, 28])

In [17]:
import torch
import torch.nn as nn

# Two samples with 10 channels of 24x24 values each.
dummy_input = torch.randn(2, 10, 24, 24)

# 2x2 max pooling with stride 2 halves each spatial dimension: 24 -> 12.
pool = nn.MaxPool2d(kernel_size=2, stride=2)
pool(dummy_input).shape

torch.Size([2, 10, 12, 12])