In [1]:
import torch
from torch import nn

# Classic Simple PyTorch CNN Model
- This model has 3 conv layers and 4 fully connected (fc) layers

## Model Arch
<img src="Arch.svg" width="1000" height="400">

# Torch Model
- Takes a 3-channel RGB image as input and produces one score per class (10 classes in this notebook)
- Has 3 conv layers and 4 fc layers

In [2]:
class TorchModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(in_features=256, out_features=128),
            nn.BatchNorm1d(128),
            nn.ReLU(),

            nn.Linear(in_features=128, out_features=64),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.Linear(in_features=64, out_features=num_classes)
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        x_flat = x.flatten(start_dim=1)
        logits = self.fc_layers(x_flat)
        return logits

In [7]:
# Print a layer-by-layer summary (output shapes and parameter counts) of the
# 10-class model for a batch of 64 RGB images of size 48x48, evaluated on the CPU.
from torchsummary import summary

model = TorchModel(num_classes=10)
summary(model, input_size=(3, 48, 48), batch_size=64, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [64, 16, 46, 46]             448
       BatchNorm2d-2           [64, 16, 46, 46]              32
              ReLU-3           [64, 16, 46, 46]               0
         MaxPool2d-4           [64, 16, 23, 23]               0
            Conv2d-5           [64, 32, 21, 21]           4,640
       BatchNorm2d-6           [64, 32, 21, 21]              64
              ReLU-7           [64, 32, 21, 21]               0
         MaxPool2d-8           [64, 32, 10, 10]               0
            Conv2d-9             [64, 32, 8, 8]           9,248
      BatchNorm2d-10             [64, 32, 8, 8]              64
             ReLU-11             [64, 32, 8, 8]               0
        MaxPool2d-12             [64, 32, 4, 4]               0
           Linear-13                  [64, 256]         131,328
      BatchNorm1d-14                  [