In [7]:
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
  """Two-stage convolutional classifier followed by a single linear layer.

  Each stage is Conv3x3 (padding=1) -> BatchNorm -> ReLU -> MaxPool(2, stride 2),
  so the spatial size is halved (with flooring) twice before the classifier.

  Args:
    num_classes: number of output scores produced by the final linear layer.
    in_channels: number of channels in the input images (default 3, e.g. RGB).
    input_size: height/width of the square input images (default 64). The
      classifier's input width is derived from it, so inputs of any other
      spatial size will fail with a shape mismatch in ``fc``.

  Raises:
    ValueError: if ``input_size`` is smaller than 4 (nothing would survive
      the two 2x2 poolings).
  """
  def __init__(self, num_classes=20, in_channels=3, input_size=64):
    super().__init__()

    if input_size < 4:
      raise ValueError(f"input_size must be >= 4, got {input_size}")

    # Shared activation. In-place is fine: the BatchNorm output it overwrites
    # is not used again in forward().
    self.relu = nn.ReLU(inplace=True)

    # Convolution Feature Extraction Part
    self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
    self.bn1   = nn.BatchNorm2d(64)
    self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

    self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
    self.bn2   = nn.BatchNorm2d(128)
    self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    # padding=1 keeps the conv output size unchanged; each pool floors it by
    # half, so s -> s // 2 -> s // 4 (e.g. 64 -> 32 -> 16).
    feat_size = input_size // 4
    # Fully Connected Classifier Part (previously hard-coded to 20 outputs,
    # ignoring num_classes).
    self.fc    = nn.Linear(128 * feat_size * feat_size, num_classes)

  def forward(self, x):
    """Compute class scores.

    Args:
      x: batched NCHW tensor of shape (N, in_channels, input_size, input_size).

    Returns:
      Tensor of shape (N, num_classes) with unnormalized scores (no softmax).
    """
    # Convolution Feature Extraction Part
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.pool1(x)

    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.pool2(x)

    # Fully Connected Classifier Part
    x = torch.flatten(x, 1)  # keep batch dim, flatten (C, H, W)
    x = self.fc(x)
    return x

In [8]:
# Network
# Sanity check: push one random image through a fresh model and confirm the
# output has one score per class.
torch.manual_seed(0)  # reproducible weights and input across Restart & Run All
model = SimpleCNN(num_classes=20)

# Random input: a batch of one 3x64x64 image (matches the model's default input size)
x = torch.randn((1, 3, 64, 64))

# Forward in eval mode without autograd so this check neither updates the
# BatchNorm running statistics nor builds a graph; then restore training mode,
# which is the mode a freshly constructed module starts in.
model.eval()
with torch.no_grad():
  out = model(x)
model.train()

# Check the output shape
assert out.shape == (1, 20), f"unexpected output shape {tuple(out.shape)}"
print("Output tensor shape is :", out.shape)

Output tensor shape is : torch.Size([1, 20])
