In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transform
import torchvision.datasets as datasets
from tqdm import tqdm

In [18]:
# Layout of the VGG-16 convolutional stack, one stage per line.
# An int is the output-channel count of a 3x3 convolution; 'M' marks a
# 2x2 max-pooling step (see VGG_net.create_features_extractor).
VGG16 = [
    64, 64, 'M',
    128, 128, 'M',
    256, 256, 256, 'M',
    512, 512, 512, 'M',
    512, 512, 512, 'M',
]

class VGG_net(nn.Module):
    """VGG-style image classifier with batch normalisation.

    The convolutional stack is built from the module-level ``VGG16`` layout
    and is followed by a three-layer fully connected head.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, in_channels=3, num_classes=1000):
        super(VGG_net, self).__init__()
        self.in_channels = in_channels
        self.features_extractor = self.create_features_extractor(VGG16)
        # The classifier expects a 512x7x7 feature map. For 224x224 inputs the
        # five stride-2 poolings already yield 7x7, so this layer is an
        # identity there; for other input sizes it resamples to 7x7 instead of
        # failing with a shape mismatch in the first Linear layer.
        # It has no parameters, so state_dict keys are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: Input batch; the first dimension is treated as the batch
                dimension and kept, everything after the convolutional
                stack is flattened per sample.

        Returns:
            The raw output of the classifier head, one row per sample and
            ``num_classes`` columns.
        """
        features = self.features_extractor(x)
        pooled = self.avgpool(features)
        # Keep the batch dimension, collapse channels and spatial dims.
        flat = pooled.flatten(start_dim=1)
        return self.classifier(flat)

    def create_features_extractor(self, architecture):
        """Build the convolutional stack described by ``architecture``.

        Args:
            architecture: Iterable of layer tokens. An ``int`` adds a
                3x3 convolution (stride 1, padding 1) with that many output
                channels followed by batch norm and ReLU; the string ``'M'``
                adds a 2x2 max-pooling layer with stride 2.

        Returns:
            An ``nn.Sequential`` containing the layers in order.

        Raises:
            ValueError: If a token is neither an ``int`` nor ``'M'``.
        """
        layers = []
        in_channels = self.in_channels

        for token in architecture:
            # bool is a subclass of int but is never a valid channel count.
            if isinstance(token, int) and not isinstance(token, bool):
                out_channels = token
                layers += [
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU()
                ]
                # The next convolution consumes what this one produced.
                in_channels = out_channels
            elif token == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                # Fail loudly instead of silently dropping a misspelled token,
                # which would otherwise build a different network.
                raise ValueError(
                    f"Unsupported architecture entry {token!r}: expected an int or 'M'"
                )
        return nn.Sequential(*layers)


In [19]:
# Smoke test: push one random batch through the network and check that the
# output has one row per sample and one column per class.
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 1000
model = VGG_net(in_channels=3, num_classes=NUM_CLASSES).to(device)
BATCH_SIZE = 3
x = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
# A shape check needs no gradients; disabling autograd avoids keeping every
# intermediate activation alive for a backward pass that never happens.
with torch.no_grad():
    output = model(x)
assert output.shape == (BATCH_SIZE, NUM_CLASSES), output.shape
print(output.shape)

torch.Size([3, 1000])
