In [1]:
import torch
from torchvision import datasets, transforms

# Tunable settings for data loading.
DATA_ROOT = './data'
BATCH_SIZE = 64

# Keep the transform instance under its own name: rebinding `transforms` would
# shadow the imported torchvision.transforms module for every later cell.
to_tensor = transforms.ToTensor()

# MNIST train / test splits, downloaded into DATA_ROOT if not already present.
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=to_tensor)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=to_tensor)

# Shuffle only the training batches; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class CNNModel(nn.Module):
    """Three-stage convolutional classifier for 1x28x28 images.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU and
    a 2x2 max-pool, shrinking the spatial size 28 -> 14 -> 7 -> 3. The
    resulting 64x3x3 feature map is flattened and passed through
    fc3 (576 -> 128) and fc2 (128 -> 10). The output is raw logits; no softmax
    is applied.
    """

    def __init__(self):
        super().__init__()

        # convolutional layers: channels 1 -> 16 -> 32 -> 64, spatial size preserved
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)

        # fc1 is sized for the 32x7x7 map that exists after the second pooling
        # stage. forward() does not use it; it is kept only so the registered
        # module/parameter names stay unchanged.
        self.fc1 = nn.Linear(32*7*7, 128)
        # output layer: 128 features -> 10 class logits
        self.fc2 = nn.Linear(128, 10)
        # hidden layer fed by the flattened 64x3x3 map
        self.fc3 = nn.Linear(64*3*3, 128)

        self.maxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) holding unnormalised class scores.
        """
        # (batch, 1, 28, 28) -> (batch, 16, 14, 14)
        x = self.maxPool1(F.relu(self.conv1(x)))

        # -> (batch, 32, 7, 7)
        x = self.maxPool2(F.relu(self.conv2(x)))

        # -> (batch, 64, 3, 3); a 2x2 / stride-2 pool floors 7 down to 3.
        # The pool is applied functionally so no new submodule is registered.
        x = F.max_pool2d(F.relu(self.conv3(x)), kernel_size=2, stride=2)

        # Flatten per sample. Unlike view(-1, ...), this can never silently
        # change the batch dimension when the feature size is unexpected.
        x = x.flatten(start_dim=1)

        # (batch, 576) -> (batch, 128) -> (batch, 10)
        x = F.relu(self.fc3(x))
        return self.fc2(x)

# Build the network and print its layer-by-layer summary.
model = CNNModel()
print(model)


CNNModel(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (fc3): Linear(in_features=576, out_features=128, bias=True)
  (maxPool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (maxPool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
