In [None]:
import torch.nn as nn

class Vgg(nn.Module):
    """VGG-style CNN with 9 conv layers and 1 fully-connected layer (10 weight layers).

    Every conv stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU``.
    The 3x3 / padding-1 convolutions keep the spatial size unchanged. Each of
    the four 2x2 max-pools (stride 2) halves the spatial size, rounding down.
    With the default 32x32 input the map goes 32 -> 16 -> 8 -> 4 -> 2, so the
    classifier sees 512 * 2 * 2 = 2048 features.

    Args:
        num_classes: number of output logits produced by the classifier.
        in_channels: number of channels of the input images (default 3).
        input_size: height/width of the square input images (default 32).
            It is used only to size the final linear layer. It must be at
            least 16 so the four poolings leave at least a 1x1 feature map.

    Raises:
        ValueError: if ``input_size`` is too small for the four pooling stages.
    """

    # Layer plan: an int is the out-channel count of a Conv-BN-ReLU stage,
    # "M" is a 2x2 max-pool with stride 2. Order matches the original
    # hand-written Sequential, so the features.<idx> checkpoint keys are unchanged.
    _CFG = (64, 64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M")

    def __init__(self, num_classes=10, in_channels=3, input_size=32):
        super(Vgg, self).__init__()
        layers = []
        channels = in_channels
        spatial = input_size
        for item in self._CFG:
            if item == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                # MaxPool2d with ceil_mode=False floors: out = in // 2.
                spatial //= 2
            else:
                layers += [
                    nn.Conv2d(channels, item, kernel_size=3, padding=1),
                    nn.BatchNorm2d(item),
                    nn.ReLU(),
                ]
                channels = item
        if spatial < 1:
            raise ValueError(
                f"input_size={input_size} is too small: four 2x2 max-pools "
                "reduce it to an empty feature map (need input_size >= 16)"
            )
        self.features = nn.Sequential(*layers)
        # Flattened feature size, e.g. 512 * 2 * 2 = 2048 for 32x32 inputs.
        self.classifier = nn.Linear(channels * spatial * spatial, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for an image batch x of shape (N, C, H, W)."""
        x = self.features(x)
        # Feature map is NCHW: [batch_size, channel, height, width], e.g.
        # [128, 512, 2, 2] for a batch of 128 32x32 images. Flatten every dim
        # except the batch dim, e.g. -> [128, 512 * 2 * 2].
        x = x.flatten(1)
        return self.classifier(x)