In [1]:
import torch
from torch import nn

In [34]:
class VGG(nn.Module):
    """VGG-style image classifier: stacked 3x3 conv blocks followed by a 3-layer MLP head.

    Every convolution is followed by ``BatchNorm2d`` and ``ReLU``; every ``'M'``
    entry of the chosen configuration inserts a 2x2, stride-2 max-pool.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the final output (logits) dimension.
        config_name: A key of ``VGG._config`` (e.g. ``'VGG16_D'``) or its short
            form without the letter suffix (e.g. ``'VGG16'``).
        hidden_features: Width of the two hidden fully-connected layers.
        dropout: Dropout probability applied after each hidden FC layer.

    Raises:
        KeyError: If ``config_name`` does not name a known configuration.
        ValueError: If a configuration entry is neither an int nor ``'M'``.
    """

    # Layer specifications: an int is a 3x3 conv (padding 1) with that many
    # output channels, 'M' is a 2x2 stride-2 max-pool.
    _config = {
        "VGG11_A": (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'),
        "VGG13_B": (64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'),
        "VGG16_D": (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'),
        "VGG19_E": (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'),
    }

    # Spatial size the conv feature map is pooled to before the classifier head.
    _pooled_size = (7, 7)

    def __init__(self, in_channels=3, num_classes=1000, config_name='VGG16',
                 hidden_features=4096, dropout=0.5):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Accepts both 'VGG16_D' and 'VGG16'; a plain dict lookup made the
        # default 'VGG16' raise KeyError.
        self.config = self._resolve_config(config_name)

        self.conv_layers = self._create_conv_layers()

        # Channel count of the last conv layer (512 for every built-in config);
        # falls back to the input channels if the config contains no convs.
        out_channels = next(
            (c for c in reversed(self.config) if isinstance(c, int)),
            self.in_channels,
        )
        pooled_h, pooled_w = VGG._pooled_size

        # With the built-in configs (five max-pools) a 224x224 input already
        # yields a 7x7 map, so this is an identity there; it lets other input
        # sizes (>= 32x32) reach the classifier with the expected width.
        self.avgpool = nn.AdaptiveAvgPool2d(VGG._pooled_size)
        self.flatten = nn.Flatten()
        self.fc_layers = nn.Sequential(
            nn.Linear(out_channels * pooled_h * pooled_w, hidden_features),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_features, self.num_classes),
        )

    @classmethod
    def _resolve_config(cls, config_name):
        """Return the layer spec for ``config_name`` (full key or short prefix).

        Raises:
            KeyError: If the name matches no key (or is an ambiguous prefix).
        """
        if config_name in cls._config:
            return cls._config[config_name]
        # Short form: the part of the key before the '_' suffix, e.g. 'VGG16'.
        matches = [key for key in cls._config if key.partition('_')[0] == config_name]
        if len(matches) == 1:
            return cls._config[matches[0]]
        raise KeyError(
            f"Unknown VGG config {config_name!r}; expected one of "
            f"{sorted(cls._config)} or their short forms (e.g. 'VGG16')"
        )

    def forward(self, X):
        """Map a batch of shape (N, in_channels, H, W) to logits of shape (N, num_classes).

        H and W must be at least 32 for the built-in configs (five 2x2 max-pools).
        """
        X = self.conv_layers(X)
        X = self.avgpool(X)
        X = self.flatten(X)
        return self.fc_layers(X)

    def _create_conv_layers(self):
        """Build the convolutional feature extractor described by ``self.config``.

        Returns:
            nn.Sequential of Conv2d -> BatchNorm2d -> ReLU triples and MaxPool2d layers.

        Raises:
            ValueError: If an entry is neither an int nor ``'M'``.
        """
        layers = []
        in_channels = self.in_channels

        for config_param in self.config:
            if isinstance(config_param, int):
                layers.extend([
                    nn.Conv2d(in_channels, config_param, kernel_size=3, padding=1, stride=1),
                    nn.BatchNorm2d(config_param),
                    nn.ReLU(inplace=True),
                ])
                in_channels = config_param
            elif config_param == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                # Previously any unknown token silently became a max-pool.
                raise ValueError(f"Invalid VGG config entry: {config_param!r}")

        return nn.Sequential(*layers)

In [35]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

# Use the exact config key so this cell works whether or not the class
# accepts short aliases such as 'VGG19'.
model = VGG(config_name='VGG19_E').to(device)
model.eval()  # inference mode for BatchNorm/Dropout during the shape check

data = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = model(data)

assert output.shape == (1, model.num_classes)
print(output.shape)

torch.Size([1, 1000])


In [36]:
# Scratch check: a list of four 256s, the same run that appears in the VGG19_E config.
print([256 for _ in range(4)])

[256, 256, 256, 256]
