In [25]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from PIL import Image

In [47]:
# Load an ImageNet-pretrained ResNet-50 and drop its last two children
# (the global average pool and the fc classifier). What remains is the
# convolutional trunk, so the model returns the final feature map rather
# than class scores.
# NOTE: no layers are frozen and no new classification head is attached here.
from torchsummary import summary

backbone = models.resnet50(pretrained=True)
model = nn.Sequential(*list(backbone.children())[:-2])

# torchsummary defaults to device="cuda", which builds CUDA input tensors whenever
# a GPU is available; this model was never moved off the CPU, so pin the device.
summary(model, (3, 448, 448), device="cpu")



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]           4,096
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
           Conv2d-11        [-1, 256, 112, 112]          16,384
      BatchNorm2d-12        [-1, 256, 112, 112]             512
           Conv2d-13        [-1, 256, 112, 112]          16,384
      BatchNorm2d-14        [-1, 256, 1

In [51]:
class BCNN(nn.Module):
    """Bilinear CNN classifier on top of a ResNet-34 convolutional trunk.

    The trunk's final feature map goes through four steps:

    1. Bilinear pooling: the outer product of the channel vectors, averaged
       over all spatial positions.
    2. A signed square root.
    3. L2 normalisation of each sample's vector.
    4. A single linear classifier.

    Args:
        num_classes: Number of outputs of the final linear layer.
        pretrained: If True, load pretrained weights for the ResNet-34 trunk.
            The trunk's parameters are then frozen, and the classifier is
            initialised with Kaiming-normal weights and zero bias.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        backbone = torchvision.models.resnet34(pretrained=pretrained)
        # Drop the global average pool and the fc layer so self.conv yields the
        # last convolutional feature map (512 channels for ResNet-34).
        self.conv = nn.Sequential(*list(backbone.children())[:-2])
        # Bilinear features are a flattened 512 x 512 channel-interaction matrix.
        self.fc = nn.Linear(512 * 512, num_classes)
        # Explicit dim: the implicit form is deprecated and warns. For the
        # (batch, num_classes) logits, dim=1 is the axis the implicit rule picked.
        self.softmax = nn.Softmax(dim=1)

        if pretrained:
            # Train only the classifier head.
            # NOTE(review): BatchNorm layers inside self.conv still update their
            # running statistics in train() mode even though their parameters
            # are frozen; call self.conv.eval() during training if that is unwanted.
            for parameter in self.conv.parameters():
                parameter.requires_grad = False
            nn.init.kaiming_normal_(self.fc.weight.data)
            nn.init.constant_(self.fc.bias, val=0)

    def forward(self, input):
        """Classify a batch of images.

        Args:
            input: Image batch accepted by the ResNet trunk, (batch, 3, H, W).
                Any H, W works as long as the trunk produces a non-empty
                feature map; its channel count must be 512 to match self.fc.

        Returns:
            Tuple ``(logits, probabilities)``, each of shape (batch, num_classes).
        """
        features = self.conv(input)
        batch, channels, height, width = features.size()
        num_positions = height * width
        # Derive the spatial size from the feature map instead of assuming 14x14
        # (which only held for 448x448 inputs).
        features = features.reshape(batch, channels, num_positions)
        # Bilinear pooling: average over positions of the outer product of each
        # position's channel vector with itself -> (batch, channels, channels).
        features = torch.bmm(features, features.transpose(1, 2)) / num_positions
        features = features.reshape(batch, channels * channels)
        # Signed square root; the epsilon keeps the gradient of sqrt finite at 0.
        features = torch.sign(features) * torch.sqrt(torch.abs(features) + 1e-12)
        # L2-normalise each sample's vector (F.normalize defaults: p=2, dim=1).
        features = torch.nn.functional.normalize(features)

        out = self.fc(features)
        return out, self.softmax(out)

# Output size of the BCNN classifier head (number of target classes).
NUM_CLASSES = 80
model = BCNN(NUM_CLASSES, pretrained=True)

In [52]:
# Re-import so this cell does not depend on the earlier ResNet-50 cell having run.
from torchsummary import summary

# Summarise the BCNN on a 448x448 RGB input. torchsummary defaults to
# device="cuda", which builds CUDA input tensors whenever a GPU is available;
# this model was never moved off the CPU, so pin the device.
summary(model, (3, 448, 448), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]          36,864
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
       BasicBlock-11         [-1, 64, 112, 112]               0
           Conv2d-12         [-1, 64, 112, 112]          36,864
      BatchNorm2d-13         [-1, 64, 112, 112]             128
             ReLU-14         [-1, 64, 1

