## Model DenseNet

### Dataset 

In [1]:
import torch

# Report whether a CUDA device is visible to PyTorch.
cuda_available = torch.cuda.is_available()
cuda_available

False

In [2]:
import torch
from torchvision import datasets, transforms

# Only request a worker process and pinned memory when a GPU is available.
kwargs = {}
if torch.cuda.is_available():
    kwargs = {'num_workers': 1, 'pin_memory': True}

# CIFAR-10 test split: convert images to tensors, then normalize each channel with (mean, std).
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_set = datasets.CIFAR10('./data.cifar10', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=8, shuffle=True, **kwargs)

Files already downloaded and verified


In [3]:
# Pull one (shuffled) batch from the loader; `data` is the [images, labels] pair.
idx, data = 0, next(iter(test_loader))
one_batch = data[0]  # the image tensor of that batch
one_batch.shape

torch.Size([8, 3, 32, 32])

### 1. Basic Block & Transition

In [29]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class channel_selection(nn.Module):
    """
    Select channels from the output of a BatchNorm2d layer. It should be put directly after that layer.

    The number of output channels equals the number of non-zero entries in `self.indexes`.
    """

    def __init__(self, num_channels):
        """
        :param num_channels: number of input channels (the width of the preceding BatchNorm2d).

        `indexes` starts as an all-ones vector of length `num_channels`. During pruning, the
        entries of the channels to be removed are set to 0.
        """
        super(channel_selection, self).__init__()
        # Kept as a Parameter so it is saved in state_dict. forward() only reads `.data`,
        # so no gradient ever flows into it.
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        """
        :param input_tensor: (N, C, H, W) tensor, normally the output of a BatchNorm2d layer.
            C is expected to equal len(self.indexes).
        :return: (N, C', H, W) tensor holding only the channels whose `indexes` entry is non-zero,
            in their original order.
        """
        # flatnonzero always yields a 1-D index array. This removes the need for the old
        # squeeze/resize special case when exactly one channel survives.
        selected_index = np.flatnonzero(self.indexes.data.cpu().numpy())
        return input_tensor[:, selected_index, :, :]
    
class BasicBlock(nn.Module):
    def __init__(self, in_planes, cfg, growth_rate=12, drop_rate=0):
        """
        One DenseNet layer: BN -> channel_selection -> ReLU -> 3x3 conv (-> dropout). The conv
        output is concatenated onto the input, so the block adds `growth_rate` channels.

        :param in_planes: number of input channels (width of the BatchNorm2d and selection mask).
        :param cfg: number of channels kept by `select`, i.e. the conv's input channels.
            For an unpruned network this equals `in_planes`; after pruning it can be smaller.
        :param growth_rate: number of output channels of the 3x3 conv.
        :param drop_rate: dropout probability applied to the conv output (0 disables dropout).

        Note: currently we fix the expansion ratio as the default value (expansion=1).
        """
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.select = channel_selection(in_planes)
        self.conv1 = nn.Conv2d(cfg, growth_rate, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.drop_rate = drop_rate

    def _after_bn(self, x, normed):
        """Run select -> ReLU -> conv (-> dropout) on the BN output `normed` and concatenate it onto `x`."""
        out = self.select(normed)
        out = self.relu(out)
        out = self.conv1(out)
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        return torch.cat((x, out), 1)

    def forward(self, x):
        return self._after_bn(x, self.bn1(x))

    def forward_bn(self, x):
        """
        Same as `forward`, but also return the BatchNorm output taken before channel selection.

        :return: (block output, copy of bn1(x))
        """
        normed = self.bn1(x)
        # Clone so the returned BN activations are independent of the downstream pipeline.
        bn_value = normed.clone()
        return self._after_bn(x, normed), bn_value
    
    
class Transition(nn.Module):
    def __init__(self, in_planes, out_planes, cfg):
        """
        Transition between dense blocks: BN -> channel_selection -> ReLU -> 1x1 conv -> 2x2 average pool.

        :param in_planes: number of input channels (width of the BatchNorm2d and selection mask).
        :param out_planes: number of output channels of the 1x1 conv.
        :param cfg: number of channels kept by `select`, i.e. the 1x1 conv's input channels
            (equals `in_planes` when nothing has been pruned).
        """
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.select = channel_selection(in_planes)
        self.conv1 = nn.Conv2d(cfg, out_planes, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def _after_bn(self, normed):
        """Run select -> ReLU -> 1x1 conv -> 2x2 average pool on the BN output `normed`."""
        out = self.select(normed)
        out = self.relu(out)
        out = self.conv1(out)
        return F.avg_pool2d(out, 2)

    def forward(self, x):
        return self._after_bn(self.bn1(x))

    def forward_bn(self, x):
        """
        Same as `forward`, but also return the BatchNorm output taken before channel selection.

        :return: (transition output, copy of bn1(x))
        """
        normed = self.bn1(x)
        # Clone so the returned BN activations are independent of the downstream pipeline.
        bn_value = normed.clone()
        return self._after_bn(normed), bn_value


### 2. DenseNet

In [30]:
import math


class densenet(nn.Module):
    def __init__(self, depth=40, drop_rate=0, dataset='cifar10', growth_rate=12, compression_rate=1, cfg=None, block_cfg=None):
        """
        DenseNet for CIFAR with a channel_selection layer after every BatchNorm2d (for channel pruning).

        Layout: 3x3 conv -> dense1 -> trans1 -> dense2 -> trans2 -> dense3
                -> BN -> select -> ReLU -> 8x8 average pool -> linear classifier

        :param depth: 3 * n + 4, where n is the number of BasicBlocks in each of the 3 dense blocks.
        :param drop_rate: dropout rate used inside every BasicBlock.
        :param dataset: 'cifar10' (10 classes) or 'cifar100' (100 classes).
        :param growth_rate: number of channels each BasicBlock adds.
        :param compression_rate: a transition maps C input channels to floor(C / compression_rate)
            output channels (1 means no compression).
        :param cfg: flat list of 3 * (n + 1) ints. Each is the number of channels kept by one
            channel_selection layer, in network order:
                cfg[0:n]            -> BasicBlocks of dense1
                cfg[n]              -> trans1
                cfg[n+1:2n+1]       -> BasicBlocks of dense2
                cfg[2n+1]           -> trans2
                cfg[2n+2:3n+2]      -> BasicBlocks of dense3
                cfg[3n+2] (= last)  -> final BN/select; also the classifier's input size
            When None, every channel is kept. For example, depth=40, growth_rate=12, compression_rate=1 gives
                cfg = [24, 36, ..., 156, 168,  168, 180, ..., 300, 312,  312, 324, ..., 444, 456]
        :param block_cfg: unused; accepted only for interface compatibility.
        """
        super(densenet, self).__init__()
        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'  # 3 dense blocks + 4 other layers
        n = (depth - 4) // 3    # number of BasicBlocks in every dense block
        block = BasicBlock

        self.growth_rate = growth_rate
        self.drop_rate = drop_rate

        if cfg is None:
            # Keep every channel. The default is built with the same compression rule as
            # _make_transition, so it also matches the real widths when compression_rate != 1.
            cfg = self._default_cfg(n, growth_rate, compression_rate)
        assert len(cfg) == 3 * (n + 1), 'length of config variable cfg should be 3(n+1)'  # 39 = 3 × (12 + 1)

        # Running channel count, advanced by _make_dense_block / _make_transition as layers are built.
        self.inplanes = growth_rate * 2
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, bias=False)

        self.dense1 = self._make_dense_block(block, num_block=n, cfg=cfg[0:n])
        self.trans1 = self._make_transition(compression_rate, cfg[n])
        self.dense2 = self._make_dense_block(block, n, cfg[n + 1:2 * n + 1])
        self.trans2 = self._make_transition(compression_rate, cfg[2 * n + 1])
        self.dense3 = self._make_dense_block(block, n, cfg[2 * n + 2:3 * n + 2])
        self.bn = nn.BatchNorm2d(self.inplanes)
        self.select = channel_selection(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        # model classifier
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        else:
            raise ValueError('Model `dataset` parameter is Error!')
        # The final select keeps cfg[-1] channels; after pooling, that is the feature size.
        self.classifier = nn.Linear(cfg[-1], num_classes)

        # Weight initialization
        self._initialize_weights()

    @staticmethod
    def _compress(channels, compression_rate):
        """Number of output channels of a transition that receives `channels` input channels."""
        return int(math.floor(channels // compression_rate))

    @staticmethod
    def _default_cfg(n, growth_rate, compression_rate):
        """Unpruned cfg: the full channel count seen by every channel_selection layer, in order."""
        cfg = []
        channels = growth_rate * 2
        for stage in range(3):
            for _ in range(n):
                cfg.append(channels)       # input width of one BasicBlock
                channels += growth_rate
            cfg.append(channels)           # input width of trans1/trans2, or of the final BN
            if stage < 2:
                channels = densenet._compress(channels, compression_rate)
        return cfg

    def _make_dense_block(self, block, num_block, cfg):
        """
        :param block: BasicBlock class (one block holds one conv2d).
        :param num_block: number of blocks (n) in this dense block.
        :param cfg: per-block number of channels kept by the block's channel_selection.
        :return: nn.Sequential of the blocks. Also advances self.inplanes by growth_rate per block.
        """
        layers = []
        assert num_block == len(cfg), 'Length of the cfg parameter is not right.'
        for i in range(num_block):
            layers.append(block(self.inplanes, cfg=cfg[i], growth_rate=self.growth_rate, drop_rate=self.drop_rate))
            self.inplanes += self.growth_rate
        return nn.Sequential(*layers)

    def _make_transition(self, compression_rate, cfg):
        """
        :param compression_rate: see __init__.
        :param cfg: a single int, the number of channels kept by the transition's channel_selection.
        :return: Transition module. Also sets self.inplanes to its output width.
        """
        inplanes = self.inplanes
        outplanes = self._compress(self.inplanes, compression_rate)
        self.inplanes = outplanes
        return Transition(inplanes, outplanes, cfg)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style init using fan-out: std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # NOTE(review): BN scale starts at 0.5 rather than PyTorch's default of 1.
                # Presumably this is a choice made for sparsity-based pruning; confirm.
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def forward(self, x):
        # stem: 1 conv2d
        x = self.conv1(x)

        # 3 dense blocks × n conv2ds, with a transition after the first two
        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)
        x = self.bn(x)
        x = self.select(x)
        x = self.relu(x)

        # 1 pooling, then the classifier
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

In [31]:
# Build the default (unpruned) network and put it in evaluation mode; eval() returns the module itself.
model = densenet().eval()
model

densenet(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (dense1): Sequential(
    (0): BasicBlock(
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(24, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(36, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
    )
    (2): BasicBlock(
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (select): channel_selection()
      (conv1): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
    )
    (3): BasicBlock(
      (bn1)

In [32]:
# Total channel count across every BatchNorm2d layer in the model.
num = sum(
    bn.weight.data.shape[0]
    for bn in model.modules()
    if isinstance(bn, nn.BatchNorm2d)
)
num

9360

In [35]:
# Push random input through the network layer by layer, using forward_bn to add up the width of
# every BatchNorm output. The total should match the BatchNorm2d channel count computed above.
# This relies on model.children() being registered in forward order (conv1, dense1, trans1, ...).
data = torch.randn(14, 3, 32, 32)

num = 0
with torch.no_grad():  # inspection only: no need to build an autograd graph
    for module in model.children():
        if isinstance(module, nn.Sequential):
            for block in module.children():
                data, bn_value = block.forward_bn(data)
                num += bn_value.shape[1]
        elif isinstance(module, Transition):
            data, bn_value = module.forward_bn(data)
            num += bn_value.shape[1]
        elif isinstance(module, nn.BatchNorm2d):
            data = module(data)
            num += module.weight.data.shape[0]
        elif isinstance(module, nn.Linear):
            continue  # the classifier does not contribute a BatchNorm width
        else:
            data = module(data)
num

9360