### Laboratory Activity

Instruction: Convert the following CNN architecture diagram into a PyTorch CNN Architecture.

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [69]:
class CNNModel(nn.Module):
    def __init__(self, num_classes):
        super(CNNModel, self).__init__()
        
        self.Conv1 = nn.Conv2d(1, 32, 3, 1, 1)
        self.MaxPool1 = nn.MaxPool2d(2, 2, 1)
        self.Conv2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.Conv3 = nn.Conv2d(64, 128, 3, 1, 1)
        self.Conv4 = nn.Conv2d(128, 256, 3, 1, 1)
        self.MaxPool2 = nn.MaxPool2d(2, 2, 0)
        self.DropOut = nn.Dropout(0.2)

        # Temporary placeholder â€” will replace later after seeing flatten size
        self.FCN1 = nn.Linear(1, 1000)
        self.FCN2 = nn.Linear(1000, 500)
        self.FCN3 = nn.Linear(500, num_classes)

    def forward(self, x):
        x = F.relu(self.Conv1(x))
        x = self.MaxPool1(x)
        x = F.relu(self.Conv2(x))
        x = F.relu(self.Conv3(x))
        x = F.relu(self.Conv4(x))
        x = self.MaxPool2(x)
        x = self.DropOut(x)

        x = torch.flatten(x, 1)
        print(x.shape)  # ðŸ‘ˆ prints the flatten size
        return x


In [70]:
model = CNNModel(num_classes=10)
dummy = torch.randn(1, 1, 28, 28)
model(dummy)
model

torch.Size([1, 12544])


CNNModel(
  (Conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (MaxPool1): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
  (Conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (MaxPool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (DropOut): Dropout(p=0.2, inplace=False)
  (FCN1): Linear(in_features=1, out_features=1000, bias=True)
  (FCN2): Linear(in_features=1000, out_features=500, bias=True)
  (FCN3): Linear(in_features=500, out_features=10, bias=True)
)