In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [12]:
device = 'cuda' if torch.cuda.is_available else 'cpu'
IMG_SIZE = 224

In [13]:
class VGG_net(nn.Module):
    def __init__(self, architecture, in_channels=3, num_class=10):
        super(VGG_net, self).__init__()
        self.num_pooling = 0
        self.in_channels = in_channels
        self.conv_layers = self.create_conv_layers(architecture)
        self.fc_layers = nn.Sequential(
            nn.Linear(512 * IMG_SIZE//(2**self.num_pooling) * IMG_SIZE//(2**self.num_pooling), 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            
            nn.Linear(4096, num_class)
        )
        
        
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc_layers(x)
        return x
    
    
    def create_conv_layers(self, architecture):
        layers = []
        in_channels = self.in_channels
        
        for x in architecture:
            if type(x) == int:
                out_channels = x
                layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),\
                           nn.BatchNorm2d(x),\
                           nn.ReLU()]
                in_channels = x
            elif x == 'M':
                self.num_pooling += 1
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                
        return nn.Sequential(*layers)

In [14]:
VGG16 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

In [15]:
model = VGG_net(VGG16, in_channels=3, num_class=10).to(device)
x = torch.randn(10, 3, IMG_SIZE, IMG_SIZE).to(device)
model(x)

tensor([[ 0.3763, -0.0303,  0.0065, -0.1575,  0.3476,  0.0674,  0.2578, -0.0704,
         -0.3504, -0.2901],
        [ 0.0431, -0.1185, -0.0736, -0.2434,  0.4924, -0.4610,  0.0632, -0.0484,
         -0.1141,  0.1746],
        [ 0.0951,  0.4222, -0.1675,  0.0180,  0.0664, -0.1375, -0.1733, -0.2234,
         -0.4054, -0.0530],
        [-0.0689,  0.0801,  0.1954, -0.1408,  0.2169,  0.3348, -0.1144, -0.1439,
         -0.2228, -0.3067],
        [ 0.1803, -0.1716, -0.1095,  0.1325,  0.2146, -0.2766,  0.0600, -0.1177,
         -0.4977, -0.3634],
        [ 0.4258,  0.0939, -0.1024, -0.1377,  0.2563,  0.4514,  0.0478, -0.0517,
         -0.0296,  0.0195],
        [ 0.0665,  0.3284, -0.3230,  0.0137,  0.1747, -0.0195, -0.0428, -0.4560,
          0.1536,  0.0232],
        [-0.0749, -0.2347, -0.1689,  0.0582,  0.3567,  0.1214, -0.2004, -0.5142,
          0.2490, -0.4662],
        [-0.1895, -0.2625, -0.0105,  0.1862,  0.0600,  0.1485, -0.5182, -0.3595,
         -0.3979,  0.5013],
        [ 0.2364, -