In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
!pip install -q torchinfo
from torchinfo import summary

In [15]:
## VGG can serve as a general deep Neural Network architecture
## It proposed a convolutional neural network block to repeat use
## In their paper, they prefer deeper and narrow conv layer (3,3)

class VGG(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_arch = [(1,64), (1,128), (2,256), (2,512), (2,512)]


    def VGG_block(self, num_convs, in_channels, out_channels):
        layers = []
        for _ in range(num_convs):
          layers.append(nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1))
          layers.append(nn.ReLU())
          in_channels = out_channels

        layers.append(nn.MaxPool2d(2,2))
        return nn.Sequential(*layers)

    def forward(self, x):
        conv_blocks = []
        in_channels = 1

        for (num_convs, out_channels) in self.conv_arch:
          conv_blocks.append(self.VGG_block(num_convs, in_channels, out_channels))
          in_channels = out_channels

        result = nn.Sequential(
            *conv_blocks,
            nn.Flatten(),
            nn.Linear(out_channels * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 10)

        )

        return result(x)


net = VGG() 

In [17]:
summary(model=net, 
        input_size=(1, 1, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
VGG (VGG)                                [1, 1, 224, 224]     [1, 10]              --                   --
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.20
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.20