In [1]:
# Modification of https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py

import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from typing import Any, List, Dict

from torchsummary import summary


__all__ = [
    'AlexNet',
    'AlexNetR',
    'alexnet',
]


pretrained_model_urls = {
    'alexnetr': 'https://download.pytorch.org/models/alexnet-owt-7be5be79.pth'
}


class AlexNet(nn.Module):
    """
    Original AlexNet:
    Input: image of size 3 x 227 x 227
    Output: number class (1000-numbers)
    """
    def __init__(
        self,
        num_channels: int = 3,
        num_classes: int = 1000
    ) -> None:
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(num_channels, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(6*6*256, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

        self.num_channels = num_channels
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x


class AlexNetR(nn.Module):
    """
    Modified AlexNet:
    Input: image of size 3 x 227 x 227
    Output: number class (1000-numbers)
    """
    def __init__(
        self,
        num_channels: int = 3,
        num_classes: int = 10
    ) -> None:
        super(AlexNetR, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(6*6*256, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

        self.num_channels = num_channels
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x


def alexnet(
    net_type: str = 'revised',
    flag_pretrained: bool = False,
    flag_download_progress: bool = True,
    idx_model_layer: List = [0],
    idx_pretrained_model_layer: List = [0],
    **kwargs: Any
) ->AlexNetR:
    """Alexnet network architecture

    Args:
        net_type (str): alexnet varients ('original', 'revised')
        flag_pretrained (bool): False - if "True" initialize with pre-trained model weights
        flag_download_progress: True - progress bar will show during pre-trained model download
        idx_model_layer (List): [0] - adjust 0-th layer's wights
        idx_pretrained_model_layer (List): [0] - use 0-th layer's wights of pre-trained model to adjust the current model
    """
    if net_type=='original':
        model = AlexNet(**kwargs)
    else:
        model = AlexNetR(**kwargs)
        if flag_pretrained:
            dict_pretrained_state = load_state_dict_from_url(pretrained_model_urls['alexnetr'], progress=flag_download_progress)
            print('-'*70)
            print('Using weight from the pretrained model from: {}' .format(pretrained_model_urls['alexnetr']))
            print('-'*70)
            dict_model = weight_transform_layer_pos(model.state_dict(), dict_pretrained_state, idx_model_layer, idx_pretrained_model_layer)
            print('-'*70)
            model.load_state_dict(dict_model)

    return model


def weight_transform_layer_pos(
    dict_model: Dict[str, torch.Tensor],
    dict_pretrained_state: Dict[str, torch.Tensor],
    idx_model_layer: List[int] = [0],
    idx_pretrained_model_layer: List[int] = [0]
) ->Dict:
    """Weights update of the contom layers based on the pre-trained weights on ImageNet

    Args:
        dict_model (state-dict): new (custom) model weights state dict
        dict_pretrained_state (state-dict): pretrained model weights state dict
        idx_model_layer (list of intergers): index of the new (custom) model state dict want to change
        idx_pretrained_model_layer (list of intergers): index of the pre-trained model state dict from where weights will be replaced
    """
    # first copy all weights with same name and size
    for k, v in dict_pretrained_state.items():
        if k in dict_model:# matched weights name
            if len(v.shape)==len(dict_model[k].shape):# matched weights size
                dims_same = 1
                for i in range(len(v.shape)):
                    if v.shape[i]!=dict_model[k].shape[i]:
                        dims_same = 0
                if dims_same:
                    print('weight: model[{}] <= pretrained_model[{}]' .format(k, k))
                    dict_model[k] = v
    # adjust the weights of the custom layers (by mean of the pretrained weights)
    for idx in range(len(idx_model_layer)):
        key_model = list(dict_model.keys())[idx_model_layer[idx]]
        key_pretrained_model = list(dict_pretrained_state.keys())[idx_pretrained_model_layer[idx]]
        print('weight: model[{}] <= pretrained_model[{}]' .format(key_model, key_pretrained_model))
        w_org = dict_model[key_model]
        w_trans = dict_pretrained_state[key_pretrained_model]

        if w_org.shape[1] != w_trans.shape[1]:

            w_trans = w_trans.mean(axis=1)
            for i in range(w_org.shape[1]):
                w_org[:,i,:,] = w_trans
        dict_model[key_model] = w_org

    return dict_model


if __name__ == '__main__':

    num_channels = 1
    H, W = 227, 227
    num_classes = 10
    batch_size = 5
    model = alexnet(flag_pretrained=True, num_channels=num_channels, num_classes=num_classes)

    print('-'*70)
    print('Network architechture (num channels-{}, num classes- {}) as follows:' .format(num_channels, num_classes))
    print('-'*70)
    print(model)
    print('-'*70)

    print('Network summary:')
    print('-'*70)
    summary(model, (num_channels, H, W, ), device=str("cpu"))
    print('-'*70)

    print('Network input output dims check')
    print('-'*70)
    x = torch.randn(batch_size, num_channels, H, W)
    y = model(x)
    print('input shape(batch_size x num_channels x height x width): {}\noutput shape(batch_size x num_classes): {}' .format(x.shape, y.shape))
    print('-'*70)





Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:03<00:00, 66.2MB/s]


----------------------------------------------------------------------
Using weight from the pretrained model from: https://download.pytorch.org/models/alexnet-owt-7be5be79.pth
----------------------------------------------------------------------
weight: model[features.0.bias] <= pretrained_model[features.0.bias]
weight: model[features.3.weight] <= pretrained_model[features.3.weight]
weight: model[features.3.bias] <= pretrained_model[features.3.bias]
weight: model[features.6.weight] <= pretrained_model[features.6.weight]
weight: model[features.6.bias] <= pretrained_model[features.6.bias]
weight: model[features.8.weight] <= pretrained_model[features.8.weight]
weight: model[features.8.bias] <= pretrained_model[features.8.bias]
weight: model[features.10.weight] <= pretrained_model[features.10.weight]
weight: model[features.10.bias] <= pretrained_model[features.10.bias]
weight: model[classifier.1.weight] <= pretrained_model[classifier.1.weight]
weight: model[classifier.1.bias] <= pretrain

In [None]:
# just to check the pretrained weights loaded correctly or not!
dict_model = model.state_dict()
dict_pretrained_state = load_state_dict_from_url(pretrained_model_urls['alexnetr'], progress=True)

dict_model = model.state_dict()
idx = 'classifier.1.weight'
idx = 'features.0.bias'
idx = 'features.0.weight'

T1 = torch.abs(dict_model[idx] - dict_pretrained_state[idx]).sum()

print(T1)


tensor(996.3445, grad_fn=<SumBackward0>)
