In [10]:
import torch
from torch import nn
from torch.nn import functional as F

from NNUtils.wutorchkeras import wutorchkeras

In [25]:
class Net32(nn.Module):
    """LeNet-style CNN for 32x32 images.

    Two ``conv -> ReLU -> 2x2 max-pool`` stages feed a four-layer MLP.
    Shape trace for a 32x32 input::

        (B, in_channels, 32, 32) -conv1-> (B, 6, 28, 28) -pool-> (B, 6, 14, 14)
        -conv2-> (B, 16, 10, 10) -pool-> (B, 16, 5, 5) -flatten-> (B, 400)
        -fc1-> (B, 128) -fc2-> (B, 64) -fc3-> (B, 32) -fc4-> (B, out_features)

    ``fc1`` is sized for exactly 16 * 5 * 5 = 400 flattened features, so the
    spatial input size must be 32x32 for the layer shapes to line up.

    Args:
        in_channels: number of channels of the input images (default 3).
        out_features: size of the final output vector (default 10).
    """

    def __init__(self, in_channels=3, out_features=10):
        super().__init__()
        # Feature extractor: 5x5 unpadded convolutions; in forward() each is
        # followed by ReLU and a 2x2 / stride-2 max-pool that halves H and W.
        # Registration order is kept stable so state_dict keys and parameter
        # initialization order do not change.
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=0)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.flatten = nn.Flatten()

        # Fully connected head: 400 -> 128 -> 64 -> 32 -> out_features.
        self.fc1 = nn.Linear(16 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, out_features)

    def forward(self, x):
        """Run the network; returns the raw output of ``fc4`` (no activation)."""
        conv_stages = ((self.conv1, self.maxpool1), (self.conv2, self.maxpool2))
        for conv, pool in conv_stages:
            x = pool(F.relu(conv(x)))
        x = self.flatten(x)
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden(x))
        # Final layer is left linear: callers get unactivated scores.
        return self.fc4(x)


# Wrap the network so it can print a Keras-style layer summary. input_shape is
# given without the batch dimension (shown as -1 in the summary); it must be
# 3x32x32 to match Net32's default in_channels and its fc1 input size.
model = wutorchkeras.Model(Net32(in_channels=3, out_features=10))
model.summary(input_shape=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 28, 28]             456
         MaxPool2d-2            [-1, 6, 14, 14]               0
            Conv2d-3           [-1, 16, 10, 10]           2,416
         MaxPool2d-4             [-1, 16, 5, 5]               0
           Flatten-5                  [-1, 400]               0
            Linear-6                  [-1, 128]          51,328
            Linear-7                   [-1, 64]           8,256
            Linear-8                   [-1, 32]           2,080
            Linear-9                   [-1, 10]             330
Total params: 64,866
Trainable params: 64,866
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.011719
Forward/backward pass size (MB): 0.064957
Params size (MB): 0.247444
Estimated Total Size (MB): 0.324120
-----------------------------