In [2]:
import torch.nn as nn

In [3]:
class CNN(nn.Module):
    """Three-layer convolutional classifier with an optional global-average-pooling head.

    Pipeline (see ``forward``)::

        conv1 (3x3, 32) -> ReLU -> conv2 (3x3, 64) -> ReLU
        -> conv3 (3x3, 128) -> MaxPool(2) -> ReLU
        -> [AdaptiveAvgPool2d(GAP_dim) if GAP]
        -> flatten -> fc1 (-> 128) -> ReLU -> fc2 (-> num_classes)

    All convolutions are unpadded with stride 1, so each one shrinks every
    spatial side by 2; the max-pool then halves it (rounding down). For a
    28x28 input the feature maps are 26x26, 24x24, 22x22 and finally 11x11.

    Args:
        input_dim: pair of the two spatial input sizes. Only used to size
            ``fc1`` when ``GAP`` is False; both entries go through the same
            size formula and are multiplied together.
        input_channels: number of channels of the input image.
        num_classes: number of output logits.
        GAP: if True, average-pool the conv3 feature map to ``GAP_dim`` before
            flattening, so the size of ``fc1`` no longer depends on ``input_dim``.
        GAP_dim: adaptive-pooling output size, either a single size ``k``
            (giving ``k x k``) or an ``(h, w)`` pair.
        verbose: if True, print the computed spatial sizes after conv2 and
            after conv3 + pooling.

    Raises:
        ValueError: if ``GAP`` is False and ``input_dim`` is too small to leave a
            non-empty feature map after conv3 + pooling.
    """

    def __init__(self, input_dim, input_channels, num_classes, GAP=False, GAP_dim=1, verbose=False):
        super().__init__()

        self.GAP = GAP
        self.activation = nn.ReLU()
        self.mp = nn.MaxPool2d(kernel_size=2)

        # Feature extraction (CNN modules), tracking the spatial size after each layer.
        self.conv1 = nn.Conv2d(input_channels, out_channels=32, kernel_size=3, stride=1, padding=0)
        outdim_1_w = self.compute_output_size(input_dim[0], kernel_size=3, stride=1, padding=0, pooling=0)
        outdim_1_h = self.compute_output_size(input_dim[1], kernel_size=3, stride=1, padding=0, pooling=0)

        self.conv2 = nn.Conv2d(32, out_channels=64, kernel_size=3, stride=1, padding=0)
        outdim_2_w = self.compute_output_size(outdim_1_w, kernel_size=3, stride=1, padding=0, pooling=0)
        outdim_2_h = self.compute_output_size(outdim_1_h, kernel_size=3, stride=1, padding=0, pooling=0)
        if verbose:
            print(outdim_2_w, outdim_2_h)

        # conv3 is followed by self.mp (2x2) in forward, hence pooling=2 here.
        self.conv3 = nn.Conv2d(64, out_channels=128, kernel_size=3, stride=1, padding=0)
        outdim_3_w = self.compute_output_size(outdim_2_w, kernel_size=3, stride=1, padding=0, pooling=2)
        outdim_3_h = self.compute_output_size(outdim_2_h, kernel_size=3, stride=1, padding=0, pooling=2)
        if verbose:
            print(outdim_3_w, outdim_3_h)

        feature_channels = self.conv3.out_channels
        if GAP:
            self.global_average_pooling = nn.AdaptiveAvgPool2d(GAP_dim)
            # AdaptiveAvgPool2d accepts either a single size or an (h, w) pair.
            if isinstance(GAP_dim, (tuple, list)):
                gap_h, gap_w = GAP_dim
            else:
                gap_h = gap_w = GAP_dim
            flatten_shape = gap_h * gap_w * feature_channels
        else:
            if outdim_3_w <= 0 or outdim_3_h <= 0:
                raise ValueError(
                    f"input_dim={input_dim} is too small: the feature map after "
                    f"conv3 + pooling would be {outdim_3_w}x{outdim_3_h}"
                )
            flatten_shape = outdim_3_w * outdim_3_h * feature_channels

        # A real module (rather than a lambda attribute) keeps the model
        # picklable (e.g. torch.save(model)). Unlike reshape(-1, flatten_shape),
        # it never folds features into the batch dimension, so a spatial-size
        # mismatch fails loudly in fc1. start_dim=-3 flattens (C, H, W) for both
        # batched (N, C, H, W) and unbatched (C, H, W) inputs.
        self.flatten = nn.Flatten(start_dim=-3)

        # Fully connected region (Linear modules).
        self.fc1 = nn.Linear(flatten_shape, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def compute_output_size(self, input_size, kernel_size, stride, padding=0, pooling=0):
        """Return the spatial output size of a convolution, optionally followed by pooling.

        Applies ``floor((input_size - kernel_size + 2*padding) / stride) + 1``,
        the ``nn.Conv2d`` / ``nn.MaxPool2d`` formula with dilation 1. If
        ``pooling`` is non-zero, the result is then passed through a window of
        that size with stride equal to the window and no padding. That matches
        ``nn.MaxPool2d(pooling)`` with its default arguments.

        Args:
            input_size: spatial size along one axis.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            padding: zero-padding added on each side.
            pooling: pooling window size, or 0 for no pooling.

        Returns:
            The output size as an int. It may be <= 0 when the input is too small.
        """
        # Floor division yields the same floor as np.floor(a / b) without needing numpy.
        output_size = int((input_size - kernel_size + 2 * padding) // stride) + 1
        if pooling:
            output_size = self.compute_output_size(output_size, pooling, pooling)
        return output_size

    def forward(self, x, verbose=False):
        """Compute class logits for a batch of images.

        Args:
            x: input tensor, presumably (N, input_channels, H, W) — TODO confirm
                with callers.
            verbose: if True, print intermediate tensor shapes.

        Returns:
            Raw logits from ``fc2`` (no softmax applied).
        """
        if verbose:
            print("Input shape: ", x.shape)
        # Feature Extraction
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.mp(self.conv3(out)))

        if self.GAP:
            if verbose:
                print("Prior GAP: ", out.shape)
            out = self.global_average_pooling(out)
            if verbose:
                print("Post GAP: ", out.shape)

        # Flatten
        if verbose:
            print("Prior flat: ", out.shape)
        out = self.flatten(out)
        if verbose:
            print("Post flat: ", out.shape)

        # Fully Connected
        out = self.activation(self.fc1(out))
        out = self.fc2(out)

        return out