In [1]:
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
import time
import os
import copy
import torch.nn.functional as F

In [2]:
class SemanticImageExtractor(nn.Module):
    """AlexNet-style network returning a semantic feature vector and class probabilities.

    The convolutional stack (``features``) and ``avgpool`` use the same layer
    layout and attribute names as ``torchvision.models.alexnet``, which is what
    allows AlexNet ``features.*`` weights to be copied into it. The fully
    connected head (``fc06`` -> ``fc07`` -> ``fc08``) is specific to this class.

    The original author notes 64x64x3 images as the expected input; since the
    adaptive average pool always emits a 6x6 map, the head receives a
    256 * 6 * 6 vector regardless of the spatial input size.

    Args:
        output_class_num: Number of classes produced by the final layer.
        feature_size: Width of the semantic feature vector (output of ``fc07``).
        pretrain: Selects how the convolutional backbone is initialised.

            * falsy (default): keep the random initialisation;
            * ``str`` / ``os.PathLike``: load a saved state dict from that path
              and use its ``features.*`` entries;
            * any other truthy value (e.g. ``True``): use the weights of
              torchvision's pretrained AlexNet.
    """

    def __init__(self, output_class_num, feature_size=200, pretrain=False):
        # Initialise nn.Module first so that attribute/sub-module registration
        # is fully set up before anything is assigned on ``self``.
        super(SemanticImageExtractor, self).__init__()
        self.feature_size = feature_size
        self.output_class_num = output_class_num

        self.features = nn.Sequential(
            # Alex1
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Alex2
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Alex3
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            # Alex4
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            # Alex5
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Keeps the channel count but forces the spatial size to 6x6, so the
        # first linear layer always sees 256 * 6 * 6 inputs.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        if pretrain:
            self._load_backbone_weights(pretrain)

        # Finally attach the (always randomly initialised) classification head.
        self.add_classifier()

    def _load_backbone_weights(self, pretrain):
        """Copy AlexNet ``features.*`` weights into ``self.features``.

        Args:
            pretrain: A checkpoint path (``str`` / ``os.PathLike``) holding a
                state dict with ``features.*`` keys, or any other truthy value
                to fetch torchvision's pretrained AlexNet.

        Raises:
            RuntimeError: If the state dict does not match the convolutional
                stack (missing/unexpected keys or wrong tensor shapes).
        """
        if isinstance(pretrain, (str, os.PathLike)):
            # NOTE: torch.load unpickles the file -- only load checkpoints from
            # a trusted source.
            state = torch.load(pretrain, map_location='cpu')
        else:
            # NOTE(review): recent torchvision releases deprecate ``pretrained=``
            # in favour of ``weights=``; kept as in the original for
            # compatibility -- confirm before upgrading torchvision.
            state = torchvision.models.alexnet(pretrained=True).state_dict()

        # Keep only the convolutional weights and strip the prefix so they can
        # be loaded straight into ``self.features``. Filtering by prefix avoids
        # hard-coding the classifier keys that have to be discarded.
        prefix = 'features.'
        backbone = {
            key[len(prefix):]: value
            for key, value in state.items()
            if key.startswith(prefix)
        }
        # Strict load: fails loudly if any convolution weight is absent.
        self.features.load_state_dict(backbone)

    def add_classifier(self):
        """Create the fully connected head (``fc06``, ``fc07``, ``fc08``).

        Calling it again replaces the three sub-modules with freshly
        initialised ones.
        """
        self.fc06 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU()
        )

        self.fc07 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(4096, self.feature_size),
            nn.ReLU()
        )

        self.fc08 = nn.Sequential(
            nn.Linear(self.feature_size, self.output_class_num),
            # Normalise over the class axis explicitly; omitting ``dim`` relies
            # on a deprecated implicit choice and triggers a warning.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image batch; the first convolution requires 3 channels and the
                flatten step treats dimension 0 as the batch axis, i.e.
                ``(N, 3, H, W)``.

        Returns:
            Tuple ``(semantic_features, p_label)``: the ``fc07`` activations of
            shape ``(N, feature_size)`` and the softmax class probabilities of
            shape ``(N, output_class_num)``.
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc06(x)
        semantic_features = self.fc07(x)
        p_label = self.fc08(semantic_features)
        return semantic_features, p_label

In [7]:
# Build the 6-class extractor, forwarding the checkpoint path as `pretrain`,
# switch it to evaluation mode (Module.eval() returns the module itself) and
# show its layer summary as the cell output.
nop_alex = SemanticImageExtractor(
    output_class_num=6,
    pretrain='./weights/original_alex_no_classifier.pth',
).eval()
nop_alex

SemanticImageExtractor(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (fc06): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU()
  )
  (fc07): Sequential(
    (0): D

In [16]:
# def custom_loader(model, state_dict):
#     # Copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
#     missing_keys: List[str] = []
#     unexpected_keys: List[str] = []
#     error_msgs: List[str] = []

#     # copy state_dict so _load_from_state_dict can modify it
#     metadata = getattr(state_dict, '_metadata', None)
#     print(metadata)
#     state_dict = state_dict.copy()
#     if metadata is not None:
#         # mypy isn't aware that "_metadata" exists in state_dict
#         state_dict._metadata = metadata  # type: ignore[attr-defined]

#     def load(module, prefix=''):
#         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
#         module._load_from_state_dict(
#             state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
#         for name, child in module._modules.items():
#             if child is not None:
#                 load(child, prefix + name + '.')
#     load(model)
#     pass


In [17]:
# Randomly initialised 6-class extractor (`pretrain` keeps its default),
# put into evaluation mode; Module.eval() returns the module itself.
nop_alex = SemanticImageExtractor(output_class_num=6).eval()
# custom_loader(nop_alex, torch.load('./weights/original_alex.pth'))
nop_alex

OrderedDict([('', {'version': 1}), ('features', {'version': 1}), ('features.0', {'version': 1}), ('features.1', {'version': 1}), ('features.2', {'version': 1}), ('features.3', {'version': 1}), ('features.4', {'version': 1}), ('features.5', {'version': 1}), ('features.6', {'version': 1}), ('features.7', {'version': 1}), ('features.8', {'version': 1}), ('features.9', {'version': 1}), ('features.10', {'version': 1}), ('features.11', {'version': 1}), ('features.12', {'version': 1}), ('avgpool', {'version': 1}), ('classifier', {'version': 1}), ('classifier.0', {'version': 1}), ('classifier.1', {'version': 1}), ('classifier.2', {'version': 1}), ('classifier.3', {'version': 1}), ('classifier.4', {'version': 1}), ('classifier.5', {'version': 1}), ('classifier.6', {'version': 1})])


SemanticImageExtractor(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (fc06): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU()
  )
  (fc07): Sequential(
    (0): D

In [7]:
# Instantiate torchvision's pretrained AlexNet (the weight donor for the
# extractor's convolutional stack) and put it in evaluation mode.
# The former bare `ori_alex.classifier` line was dropped: its value was
# discarded mid-cell, so it neither displayed nor changed anything.
ori_alex = torchvision.models.alexnet(pretrained=True).eval()
# torch.save(ori_alex.state_dict(), './weights/original_alex.pth')
ori_alex

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [35]:
# Export only the convolutional ("features.*") weights of `ori_alex`.
ori_weight = ori_alex.state_dict()

# Drop every fully connected ("classifier.*") tensor. The keys are snapshotted
# into a list first because the dict is mutated while they are walked.
for key in [k for k in ori_weight if k.startswith('classifier.')]:
    ori_weight.pop(key)

# torch.save does not create missing parent directories, so make sure the
# target folder exists; otherwise a fresh checkout fails on this cell.
os.makedirs('./weights', exist_ok=True)
torch.save(ori_weight, './weights/original_alex_no_classifier.pth')

In [38]:
# Load the backbone-only checkpoint into a fresh 6-class extractor.
nop_alex = SemanticImageExtractor(output_class_num=6)
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
backbone_state = torch.load('./weights/original_alex_no_classifier.pth', map_location='cpu')

# The checkpoint only holds "features.*" tensors, whereas the constructor also
# registers the fc06/fc07/fc08 head, so a strict load would raise on the
# missing head keys. Load non-strictly and verify the mismatch by hand.
load_result = nop_alex.load_state_dict(backbone_state, strict=False)
assert not load_result.unexpected_keys, load_result.unexpected_keys
assert all(k.startswith('fc0') for k in load_result.missing_keys), load_result.missing_keys

nop_alex.eval()

SemanticImageExtractor(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
)

In [39]:
# (Re)attach the fc06/fc07/fc08 head through the class's own builder instead of
# repeating its layer definitions here, so the notebook cell and
# SemanticImageExtractor.add_classifier cannot drift apart.
nop_alex.add_classifier()

