# In [5]:
import torch
import torch.nn as nn

class VGG16(nn.Module):
    """VGG-16 style image classifier: 13 conv layers + 3 fully connected layers.

    The feature extractor is five stages of 3x3 convolutions (64, 128, 256,
    512, 512 output channels). Each conv is followed by an in-place ReLU, and
    each stage ends with a 2x2 max-pool. The result is adaptively
    average-pooled to 7x7 and fed to a three-layer classifier head.

    Layers are created in a fixed order, so module indices are stable. The
    convs are ``features.0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28`` and
    the linears are ``classifier.0, 3, 6``.

    Args:
        in_channels: Number of channels in the input images (3 for RGB).
        num_classes: Number of output logits produced by the final Linear.
        dropout: Dropout probability used in the classifier head.
    """

    # Output channels of each 3x3 conv, in order; "M" marks a 2x2 max-pool.
    _CFG = (
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, "M",
        512, 512, 512, "M",
        512, 512, 512, "M",
    )

    def __init__(self, in_channels=3, num_classes=10, dropout=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.features = self._make_features(self._CFG, in_channels)  # convs 1-13
        # Pools whatever spatial size remains to 7x7, so the first Linear
        # always receives 512 * 7 * 7 features regardless of input H/W.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096, bias=True),  # 14
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout, inplace=False),
            nn.Linear(4096, 4096, bias=True),  # 15
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout, inplace=False),
            nn.Linear(4096, num_classes, bias=True),  # 16
        )

    @staticmethod
    def _make_features(cfg, in_channels):
        """Build the Conv2d/ReLU/MaxPool2d stack described by *cfg*.

        Integers in *cfg* become a padded 3x3 conv (spatial size preserved)
        followed by an in-place ReLU; the string "M" becomes a 2x2, stride-2
        max-pool (spatial size halved, rounding down).
        """
        layers = []
        channels = in_channels
        for item in cfg:
            if item == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
            else:
                layers.append(
                    nn.Conv2d(channels, item, kernel_size=(3, 3),
                              stride=(1, 1), padding=(1, 1))
                )
                layers.append(nn.ReLU(inplace=True))
                channels = item
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return unnormalized class logits for a batch of images.

        Args:
            x: Batched input of shape (N, in_channels, H, W). H and W must be
                at least 32, because the five stride-2 pools each halve the
                spatial size (rounding down).

        Returns:
            Tensor of shape (N, num_classes). No softmax is applied.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, 512, 7, 7) -> (N, 512 * 7 * 7)
        return self.classifier(x)
