In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Net(nn.Module):
    """Small CNN: two (conv -> ReLU -> max-pool) stages followed by three linear layers.

    The input width of the first linear layer is derived from ``img_spat_dim``
    by tracking how each conv/pool stage shrinks the image. That bookkeeping
    assumes the convolutions use stride 1 and no padding, and that the last
    element of ``w1``/``w2`` is an integer (square) kernel size.

    Args:
        img_spat_dim: spatial size ``[height, width]`` of the input images.
        w1: positional arguments for the first ``nn.Conv2d``, laid out as
            ``[in_channels, out_channels, kernel_size]``.
        pool1window: integer max-pool window applied after the first convolution.
        w2: positional arguments for the second ``nn.Conv2d`` (same layout as ``w1``).
        pool2window: integer max-pool window applied after the second convolution.
        fc1dim: output width of the first linear layer.
        fc2dim: output width of the second linear layer.
        fc3dim: output width of the final linear layer (size of the network output).

    Raises:
        ValueError: if the image becomes empty (a spatial size <= 0) after
            either conv/pool stage.
    """

    def __init__(self, img_spat_dim, w1, pool1window, w2, pool2window, fc1dim, fc2dim, fc3dim):
        super().__init__()
        self.pool1window_ = pool1window
        self.pool2window_ = pool2window

        # Stage 1: convolution, then pooling with pool1window.
        self.conv1 = nn.Conv2d(*w1)
        img_spat_dim = self._conv_pool_out_dim(img_spat_dim, w1[-1], pool1window)

        # Stage 2: must use conv2's own kernel size and pool2window; reusing the
        # stage-1 values is only correct when both stages happen to be identical.
        self.conv2 = nn.Conv2d(*w2)
        img_spat_dim = self._conv_pool_out_dim(img_spat_dim, w2[-1], pool2window)

        # Affine layers (y = Wx + b). fc1 consumes the flattened conv2 output:
        # height * width * conv2 output channels.
        self.fc1 = nn.Linear(img_spat_dim[0] * img_spat_dim[1] * w2[1], fc1dim)
        self.fc2 = nn.Linear(fc1dim, fc2dim)
        self.fc3 = nn.Linear(fc2dim, fc3dim)

    @staticmethod
    def _conv_pool_out_dim(spat_dim, kernel, pool_window):
        """Return the spatial size after an unpadded stride-1 conv and a max-pool."""
        # conv (stride 1, no padding): d -> d - kernel + 1
        # max-pool (stride == window, floor mode): d -> d // pool_window
        out_dim = [(d - kernel + 1) // pool_window for d in spat_dim]
        if any(d <= 0 for d in out_dim):
            raise ValueError(
                "input of spatial size {} is too small for kernel {} and pool window {}: "
                "resulting size would be {}".format(list(spat_dim), kernel, pool_window, out_dim)
            )
        return out_dim

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last linear layer."""
        # Each stage: convolution -> ReLU -> max pooling with that stage's window.
        x = F.max_pool2d(F.relu(self.conv1(x)), self.pool1window_)
        x = F.max_pool2d(F.relu(self.conv2(x)), self.pool2window_)

        # Flatten everything except the batch dimension: (batch, features).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the last layer; raw scores are returned.
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        """Return the number of elements per sample, i.e. the product of all
        dimensions of ``x`` except the first (batch) dimension."""
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features


# Instantiate the network for 32x32 single-channel inputs and show its layers.
net = Net(
    img_spat_dim=[32, 32],
    w1=[1, 6, 5],
    pool1window=2,
    w2=[6, 16, 5],
    pool2window=2,
    fc1dim=120,
    fc2dim=84,
    fc3dim=10,
)
print(net)


Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [3]:
# Materialize the network's learnable tensors so they can be counted and indexed.
params = [*net.parameters()]
print(len(params))
print(params[0].shape)  # conv1's .weight

10
torch.Size([6, 1, 5, 5])


In [4]:
# Feed one random 1x1x32x32 tensor through the network.
# NOTE(review): `input` shadows the Python builtin of the same name; the name is
# kept here in case other cells refer to it.
input = torch.randn((1, 1, 32, 32))
out = net(input)
print(out)


tensor([[-0.0213, -0.0988,  0.0805,  0.0915,  0.1777,  0.0264,  0.2425, -0.0454,
         -0.0319,  0.0763]], grad_fn=<AddmmBackward>)
