In [2]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision
from torchvision import models

In [19]:
def init_weights(modules):
    """Re-initialise the parameters of the given layers in place.

    Layers are handled by type; any other module type is skipped:

    * ``nn.Conv2d``      -- Xavier-uniform weight, zero bias.
    * ``nn.BatchNorm2d`` -- weight filled with 1, bias filled with 0.
    * ``nn.Linear``      -- weight drawn from N(0, 0.01), zero bias.

    Parameters that are absent (``None``) are left alone, so layers built
    with ``bias=False`` or ``affine=False`` are accepted.

    Args:
        modules: Iterable of ``nn.Module`` objects, e.g. the result of
            ``some_module.modules()``.

    Returns:
        None. The modules are modified in place.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both weight and bias are None.
            if m.weight is not None:
                init.ones_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, 0, 0.01)
            # With bias=False there is no bias tensor to reset.
            if m.bias is not None:
                init.zeros_(m.bias)

class vgg16_bn(torch.nn.Module):
    """torchvision VGG16-BN feature layers split into five sequential stages.

    ``slice_1`` .. ``slice_4`` hold layers 0-11, 12-18, 19-28 and 29-38 of
    ``torchvision.models.vgg16_bn(...).features``; each layer keeps its
    original index as its sub-module name. ``slice_5`` is a new head made of
    a 3x3 stride-1 max-pool, a dilated 3x3 conv (512 -> 1024, dilation 6)
    and a 1x1 conv (1024 -> 1024).

    Args:
        pretrained: If true, the backbone is created with
            ``VGG16_BN_Weights.DEFAULT``. If false, no weights are loaded
            and ``slice_1`` .. ``slice_4`` are re-initialised with
            ``init_weights``.
        freeze: If true, every parameter of ``slice_1`` gets
            ``requires_grad = False``.
    """

    def __init__(self, pretrained = True, freeze = True):
        # Zero-argument super() does not look up the global name `vgg16_bn`,
        # which notebook code may rebind to an instance.
        super().__init__()

        # Only request (and download) the checkpoint when it will be used;
        # previously the default weights were always loaded and then
        # overwritten when pretrained was false.
        weights = models.VGG16_BN_Weights.DEFAULT if pretrained else None
        vgg_pretrained_features = models.vgg16_bn(weights = weights).features

        self.slice_1 = torch.nn.Sequential()
        self.slice_2 = torch.nn.Sequential()
        self.slice_3 = torch.nn.Sequential()
        self.slice_4 = torch.nn.Sequential()

        # (stage, first layer index, one-past-last layer index)
        stage_bounds = (
            (self.slice_1, 0, 12),
            (self.slice_2, 12, 19),
            (self.slice_3, 19, 29),
            (self.slice_4, 29, 39),
        )
        for stage, start, stop in stage_bounds:
            for x in range(start, stop):
                # Keep the torchvision index as the sub-module name.
                stage.add_module(str(x), vgg_pretrained_features[x])

        # Extra head that is not part of torchvision's VGG; padding equals
        # dilation on the 3x3 conv, and the pool uses stride 1 with padding 1.
        self.slice_5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1),
            nn.Conv2d(512, 1024, kernel_size = 3, padding = 6, dilation = 6),
            nn.Conv2d(1024, 1024, kernel_size = 1)
        )

        if not pretrained:
            for stage in (self.slice_1, self.slice_2, self.slice_3, self.slice_4):
                init_weights(stage.modules())

        # slice_5 never has pretrained weights, so it is always initialised.
        init_weights(self.slice_5.modules())

        if freeze:
            for param in self.slice_1.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Print the shape of ``x``; returns None.

        NOTE(review): this is still a stub -- none of the slices is applied
        to ``x`` yet. TODO: run ``x`` through ``slice_1`` .. ``slice_5`` and
        return the feature maps once the intended outputs are decided.
        """
        print(x.shape)

# Smoke-test setup: one random 3-channel 224x224 image as a batch of size 1,
# and a model built with the constructor defaults written out explicitly.
input_tensor = torch.randn((1, 3, 224, 224))
# NOTE(review): this rebinds the name `vgg16_bn` from the class to an
# instance, so the class is no longer reachable under that name afterwards
# and re-running this cell on its own fails. Consider a separate variable
# name (and update the later cell that calls it).
vgg16_bn = vgg16_bn(pretrained = True, freeze = True)
print(vgg16_bn)

vgg16_bn(
  (slice_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (slice_2): Sequential(
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode

In [20]:
# `vgg16_bn` is the model instance bound in the previous cell (not the
# class), so this call runs forward(), which prints the input's shape.
vgg16_bn(input_tensor)

torch.Size([1, 3, 224, 224])
