## ResNet model: modules & parameters

In [1]:
import torch

# Sanity check: report whether PyTorch can see a CUDA device in this environment.
torch.cuda.is_available()

True

In [2]:
import torch.nn as nn
import numpy

class channel_selection(nn.Module):
    """
    Select channels from the output of a BatchNorm2d layer. It should be placed directly after that BatchNorm2d layer.
    The number of output channels equals the number of non-zero entries in `self.indexes`.
    """
    def __init__(self, num_channels):
        """
        Initialize `indexes` as an all-ones vector whose length equals the number of channels.
        During pruning, the entries of `indexes` that correspond to the channels to be pruned are set to 0.
        """
        super(channel_selection, self).__init__()
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        """
        Parameter
        ---------
        input_tensor: (N, C, ...) tensor, normally the (N, C, H, W) output of a BatchNorm2d layer.
            Channel indices are taken along dim 1, so C is expected to equal len(self.indexes).

        Returns
        -------
        `input_tensor` restricted (along dim 1) to the channels whose `indexes` entry is non-zero.
        If every entry is zero the result has 0 channels.
        """
        # Done in torch rather than numpy: this cell imports `numpy` but not `np`, so the
        # former np.* calls only worked because a later cell happened to bind `np`.
        # `.data` keeps the mask out of the autograd graph, as before.
        selected_index = torch.nonzero(self.indexes.data).flatten()
        # Move the (long) index tensor next to the input so this works on any device.
        return input_tensor[:, selected_index.to(input_tensor.device)]

In [3]:
import math
import torch.nn as nn
import numpy as np

"""
Pre-activation ResNet with bottleneck design.
"""


class Bottleneck(nn.Module):
    """Pre-activation bottleneck block whose input channels can be pruned.

    Residual path: bn1 -> select -> ReLU -> 1x1 conv1 -> bn2 -> ReLU -> 3x3 conv2
    -> bn3 -> ReLU -> 1x1 conv3. Shortcut: the input, or ``downsample(input)``
    when a downsample module is given.

    :param inplanes: channels of the block input (size of ``bn1`` and ``select``).
    :param planes: base width; the block outputs ``planes * expansion`` channels.
    :param cfg: three channel counts ``[conv1_in, conv1_out, conv2_out]``;
        ``cfg[0]`` must equal the number of channels kept by ``self.select``.
    :param stride: stride of the 3x3 conv ``conv2``.
    :param downsample: optional module applied to the input to build the shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, cfg, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        # Sits directly after bn1 so pruning can drop input channels of conv1.
        self.select = channel_selection(inplanes)
        self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(cfg[1])
        self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(cfg[2])
        self.conv3 = nn.Conv2d(cfg[2], planes * self.expansion, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def _forward_impl(self, x, bn_value=None):
        """Shared body of ``forward`` and ``forward_bn``.

        When ``bn_value`` is a list, a clone of each BatchNorm output (bn1, bn2,
        bn3, in that order) is appended to it. The clone keeps the recorded
        values safe from the in-place ReLU applied afterwards.
        """
        residual = x

        # group1
        out = self.bn1(x)
        if bn_value is not None:
            bn_value.append(out.clone())
        out = self.select(out)
        out = self.relu(out)
        out = self.conv1(out)

        # group2
        out = self.bn2(out)
        if bn_value is not None:
            bn_value.append(out.clone())
        out = self.relu(out)
        out = self.conv2(out)

        # group3
        out = self.bn3(out)
        if bn_value is not None:
            bn_value.append(out.clone())
        out = self.relu(out)
        out = self.conv3(out)

        # down sample
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out

    def forward(self, x):
        """Return the block output (residual path + shortcut)."""
        return self._forward_impl(x)

    def forward_bn(self, x):
        """Return ``(output, bn_value)``, where ``bn_value`` holds clones of the bn1/bn2/bn3 outputs."""
        bn_value = []
        out = self._forward_impl(x, bn_value)
        return out, bn_value

    def mask_bn(self, index, cfg_mask):
        """Multiply one BatchNorm layer's weight and bias in place by ``cfg_mask``.

        :param index: 0, 1 or 2, selecting bn1, bn2 or bn3.
        :param cfg_mask: tensor broadcastable to that layer's weight/bias
            (presumably a 0/1 per-channel mask — confirm against callers).
        :raises ValueError: if ``index`` is not 0, 1 or 2.
        """
        if index == 0:
            bn = self.bn1
        elif index == 1:
            bn = self.bn2
        elif index == 2:
            bn = self.bn3
        else:
            raise ValueError("mask_bn index must be 0, 1 or 2, got {!r}".format(index))
        bn.weight.data.mul_(cfg_mask)
        bn.bias.data.mul_(cfg_mask)


class resnet(nn.Module):
    """Pre-activation bottleneck ResNet whose channel counts come from a (prunable) cfg."""
    def __init__(self, depth=164, dataset='cifar10', cfg=None, conv_cfg=None):
        """
        :param depth:
            must satisfy depth = 9n + 2, where n = (depth - 2) // 9 is the number of
            Bottleneck blocks in every layer (9 = 3 layers × 3 conv2ds per block).
            e.g. 164 => 1 stem conv2d + 3 layers × 18 blocks × 3 conv2ds + 1 final layer
        :param dataset:
            'cifar10' (10 classes) or 'cifar100' (100 classes); anything else raises ValueError.
        :param cfg:
            flat list of channel counts with 9n + 1 entries (163 for depth = 164):
            3 entries per Bottleneck, consumed layer by layer, and cfg[-1] is the
            number of channels kept by the final `select`, i.e. the fc input size.
            None builds the default (unpruned) configuration.
        :param conv_cfg:
            1-based block indexes (≤ n) within every layer whose outputs forward()
            also returns, e.g. for n = 18:
                3 indexes / layer: [4, 9, 14] or [6, 12, 18]
                2 indexes / layer: [9, 18]
                1 index / layer: [18]
            None or empty disables collection.

        number of BatchNorm2d:
            163 = 162 (3 layers × 18 Bottlenecks × 3 BatchNorm2ds) + 1 BatchNorm2d
        """
        super(resnet, self).__init__()
        assert (depth - 2) % 9 == 0, 'depth should be 9n+2'

        # Resolve the classifier size first so an unknown dataset fails before the
        # whole network is built.
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            raise ValueError('Model `dataset` parameter is Error!')

        n = (depth - 2) // 9  # depth = 164, n = 18
        block = Bottleneck
        self.block_cfg = conv_cfg
        self.inplanes = 16

        # model config: default channel counts, flattened to one list
        if cfg is None:
            cfg = [[16, 16, 16], [64, 16, 16] * (n - 1),
                   [64, 32, 32], [128, 32, 32] * (n - 1),
                   [128, 64, 64], [256, 64, 64] * (n - 1), [256]]
            cfg = [item for sub_list in cfg for item in sub_list]

        # model feature
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, n, cfg=cfg[0: 3 * n])
        self.layer2 = self._make_layer(block, 32, n, cfg=cfg[3 * n: 6 * n], stride=2)
        self.layer3 = self._make_layer(block, 64, n, cfg=cfg[6 * n:9 * n], stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.select = channel_selection(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        # 8x8 window: after the two stride-2 layers this covers the whole map of a 32x32 input.
        self.avgpool = nn.AvgPool2d(8)

        # model classifier
        self.fc = nn.Linear(cfg[-1], num_classes)

        # model initialize weight
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / fan_out); named fan_out so it no
                # longer shadows the blocks-per-layer count `n`.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, cfg, stride=1):
        """
        :param block: Bottleneck class
        :param planes: base width of the layer; the layer outputs
                        planes * expansion channels,
                        like 16 * 4 = 64 (the first layer output)
        :param blocks: number of Bottlenecks in the layer
        :param cfg:  channel config of all blocks in the layer
                        3 cfg items / Bottleneck
        :param stride: stride of the first block (default = 1)
        :return: nn.Sequential of the blocks; also advances self.inplanes
        """
        downsample = None
        # Projection shortcut whenever the first block changes resolution or width.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False))
        layers = [block(self.inplanes, planes, cfg[0:3], stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, cfg[3 * i: 3 * (i + 1)]))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Run the network.

        :return: the fc output y, or (y, block_value) when conv_cfg is set and at
                 least one block output was collected; block_value lists the
                 selected block outputs in execution order (layer1, layer2, layer3).
        """
        x = self.conv1(x)

        block_value = []
        if self.block_cfg:
            # Same 1-based block indexes are checked in every layer.
            for layer in (self.layer1, self.layer2, self.layer3):
                for idx, item in enumerate(layer, start=1):
                    x = item(x)
                    if idx in self.block_cfg:
                        block_value.append(x)
        else:
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)

        x = self.bn(x)
        x = self.select(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        y = self.fc(x)

        if block_value:
            return y, block_value
        return y


In [22]:
import os

def resume_model(resume_file, map_location=None):
    """Load a checkpoint file and unpack its commonly used entries.

    :param resume_file: path to a checkpoint written with ``torch.save`` (a dict).
    :param map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` loads a
        GPU-saved checkpoint on a machine without CUDA. ``None`` keeps torch's default.
    :return: ``(state_dict, opti_dict, start_epoch, best_prec1, cfg)`` taken from the
        keys 'state_dict', 'optimizer', 'epoch', 'best_prec1' and 'cfg'; each is
        ``None`` when its key is missing.
    :raises ValueError: if ``resume_file`` is not an existing file.
    """
    if not os.path.isfile(resume_file):
        raise ValueError("Resume model file is not found at '{}'".format(resume_file))
    print("=> loading checkpoint '{}'".format(resume_file))
    # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(resume_file, map_location=map_location)

    start_epoch = checkpoint.get('epoch')
    best_prec1 = checkpoint.get('best_prec1')
    state_dict = checkpoint.get('state_dict')
    opti_dict = checkpoint.get('optimizer')

    if 'cfg' in checkpoint:
        cfg = checkpoint['cfg']
        print("-> model cfg is loading...\n cfg: {}".format(list(cfg)))
    else:
        cfg = None
        print("-> not found model cfg...")
    print("=>  epoch {} Prec1: {}".format(start_epoch, best_prec1))
    return state_dict, opti_dict, start_epoch, best_prec1, cfg

# root1
# Candidate checkpoint locations. These are machine-specific absolute Windows paths;
# edit them for the local environment.
root_path = r'D:\Project\Pycharm\network-slimming\logs'
file_name = 'model_best.pth.tar'
# Run directory names under root_path; one is picked by index below.
name = [
    'ft_inherit_bn_resnet164_cifar100_percent_0.4_seed_2', 
    'ori_sparsity_resnet164_cifar100_s_1e_5'
]
# NOTE(review): file_path is built but not passed to resume_model in this section.
file_path = os.path.join(root_path, name[1], file_name)

# root2
root_path2 = r'D:\Project\Gitee\network-slimming\logs'
file_name2 = 'model_best.pt'
# NOTE(review): pruned_name is not used in this section.
pruned_name = 'pruned.pth.tar'
# Run directory names under root_path2; name2[0] is the one loaded below.
name2 = [
    'ft_at_resnet164_cifar100_percent_0.6_seed_2',
    'bn_prune_resnet164_cifar100_percent_0.6',
    'at_prune_resnet164_cifar100_percent_0.6'
]
file_path2 = os.path.join(root_path2, name2[0], file_name2)


# Load the selected checkpoint; any entry missing from the file comes back as None.
state_dict, opti_dict, start_epoch, best_prec1, cfg = resume_model(file_path2)

=> loading checkpoint 'D:\Project\Gitee\network-slimming\logs\ft_at_resnet164_cifar100_percent_0.6_seed_2\model_best.pt'
-> model cfg is loading...
 cfg: [5, 8, 13, 8, 9, 11, 13, 12, 15, 9, 10, 10, 9, 13, 13, 10, 14, 14, 9, 11, 11, 14, 14, 14, 8, 10, 12, 13, 12, 10, 9, 10, 12, 10, 11, 16, 2, 5, 7, 5, 8, 13, 11, 12, 13, 11, 12, 16, 2, 6, 8, 6, 10, 15, 7, 10, 29, 21, 26, 28, 24, 25, 24, 17, 25, 26, 26, 25, 25, 26, 27, 20, 22, 22, 28, 25, 27, 29, 28, 30, 23, 23, 25, 29, 26, 22, 27, 19, 28, 27, 24, 27, 30, 25, 27, 27, 27, 24, 28, 27, 26, 29, 26, 27, 28, 30, 22, 27, 10, 25, 64, 54, 55, 60, 54, 59, 54, 42, 58, 62, 56, 57, 61, 49, 56, 62, 45, 58, 62, 50, 56, 60, 51, 60, 62, 38, 60, 62, 44, 59, 64, 34, 60, 63, 39, 61, 62, 42, 61, 61, 40, 58, 62, 41, 60, 63, 42, 58, 60, 47, 60, 61, 24]
=>  epoch 129 Prec1: 0.745


In [24]:
# Rebuild the network with the per-layer channel counts (cfg) read from the checkpoint,
# so the layer shapes match the saved weights. depth/dataset must match how the
# checkpoint was trained (assumed from the run name — confirm).
model = resnet(depth=164, dataset='cifar100', cfg=cfg, conv_cfg=None)
# NOTE(review): state_dict is None when the checkpoint lacks a 'state_dict' key,
# in which case this call fails.
model.load_state_dict(state_dict)
# Inference mode: BatchNorm uses its running statistics.
model.eval()

resnet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(5, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(8, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(13, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection

In [25]:
from torchstat import stat

# Print torchstat's per-layer MAdd / FLOPs / memory report for one 3x32x32 input.
# torchstat has no handler for the custom channel_selection module (it prints the
# "not supported" warnings below), so those layers are presumably left out of its totals.
stat(model, (3, 32, 32))

[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: ch

[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: ch

[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: channel_selection is not supported!
[Memory]: channel_selection is not supported!
[MAdd]: channel_selection is not supported!
[Flops]: ch