![](6.DensNET架構圖.png)

![](DensNET論文各層.png)

In [1]:
import torch as t
import torch.nn as nn
import torchvision 
import torch.utils.data as Date
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable as V

## 卷积块：BN->ReLU->1x1Conv->BN->ReLU->3x3Conv 
## DenseBlock內部的結構

In [2]:
class _DenseLayer(nn.Sequential):
    """Bottleneck dense layer: BN -> ReLU -> 1x1 Conv -> BN -> ReLU -> 3x3 Conv.

    The forward pass returns the input concatenated with the newly computed
    feature maps along dimension 1, so the output has
    ``num_input_features + growth_rate`` channels.

    Args:
        num_input_features: number of channels of the incoming tensor.
        growth_rate: number of feature maps produced by the 3x3 convolution.
        bn_size: bottleneck multiplier; the 1x1 convolution outputs
            ``bn_size * growth_rate`` channels.
        drop_rate: dropout probability applied to the new feature maps;
            dropout is only applied when this value is greater than 0.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super().__init__()

        # Width of the intermediate (bottleneck) representation produced by
        # the 1x1 convolution before the 3x3 convolution reduces it to
        # ``growth_rate`` channels.
        bottleneck_channels = bn_size * growth_rate

        stages = [
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(),
            nn.Conv2d(
                num_input_features,
                bottleneck_channels,
                kernel_size=1,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(),
            nn.Conv2d(
                bottleneck_channels,
                growth_rate,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
        ]
        self.module1 = nn.Sequential(*stages)

        # Dropout probability, consulted in forward().
        self.drop = drop_rate

    def forward(self, x):
        """Return ``x`` concatenated with the new feature maps (dim 1)."""
        produced = self.module1(x)
        if self.drop > 0:
            produced = F.dropout(produced, p=self.drop, training=self.training)

        # Stack the input itself and the extracted features along the
        # channel dimension.
        return t.cat((x, produced), dim=1)
        


In [3]:
class _DenseBlock(nn.Sequential):
    """A stack of ``num_layers`` dense layers applied one after another.

    Each dense layer is registered as ``denselayer1``, ``denselayer2``, ...
    and receives ``growth_rate`` more input channels than the previous one.

    Args:
        num_layers: number of dense layers inside the block.
        num_input_features: number of channels entering the block.
        bn_size: bottleneck multiplier forwarded to every dense layer.
        growth_rate: number of channels each dense layer adds.
        drop_rate: dropout probability forwarded to every dense layer.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super().__init__()

        in_channels = num_input_features
        for index in range(1, num_layers + 1):
            self.add_module(
                'denselayer%d' % index,
                _DenseLayer(in_channels, growth_rate, bn_size, drop_rate),
            )
            # The next layer sees everything so far plus this layer's output.
            in_channels += growth_rate

### 除了 dense block，DenseNet 中還有一個模塊叫過渡層（transition layer）。因為 DenseNet 會不斷地在通道維度上進行拼接，所以當層數很多的時候，輸出的通道數會越來越大，參數和計算量也隨之增加。為了避免這個問題，需要引入過渡層將輸出通道數降低下來，同時也將特徵圖的長寬減半：通道數用 1 x 1 的卷積降低，長寬則用 2 x 2 的平均池化減半。

In [4]:
def _Transition(num_input_features, num_output_features):
    """Build a transition layer: BN -> ReLU -> 1x1 Conv -> 2x2 average pool.

    The 1x1 convolution maps ``num_input_features`` channels to
    ``num_output_features`` channels (the caller chooses the reduction), and
    the average pooling with kernel 2 and stride 2 halves height and width.

    Args:
        num_input_features: number of channels entering the layer.
        num_output_features: number of channels leaving the layer.

    Returns:
        An ``nn.Sequential`` holding the four sub-modules in order.
    """
    return nn.Sequential(
        nn.BatchNorm2d(num_input_features),
        nn.ReLU(),
        nn.Conv2d(num_input_features, num_output_features, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2),
    )

In [5]:
# Smoke test for the transition layer: map 3 channels to 32 and check that
# the spatial size is halved (228 -> 114).
test_net = _Transition(3, 32)
# A plain tensor is passed directly; the autograd Variable wrapper has been
# deprecated since PyTorch 0.4 and simply returns a tensor.
test_x = t.zeros(1, 3, 228, 228)
test_y = test_net(test_x)
test_y.shape



torch.Size([1, 32, 114, 114])

In [6]:
# Default hyper-parameters consumed by the model definition below.
Bn_size = 4                       # bottleneck multiplier for the 1x1 convolution
Growth_rate = 32                  # channels added by every dense layer
Block_config = (6, 12, 24, 16)    # number of dense layers in each dense block
Out_channels = 64                 # channels produced by the stem convolution
Drop_rate = 0                     # dropout probability (0 disables dropout)
Num_classes = 10                  # size of the classifier output

In [7]:
class DesNET(nn.Module):
    """DenseNet-style image classifier.

    The network consists of a convolutional stem, a sequence of dense blocks
    separated by transition layers, a final batch normalisation and a linear
    classifier. All feature-extraction modules are registered inside
    ``self.pre_layer`` (names ``denseblockN``, ``transitionN`` and ``Norm5``).

    Args:
        growth_rate: number of channels every dense layer adds.
        block_config: number of dense layers in each dense block.
        out_channels: number of channels produced by the stem convolution.
        bn_size: bottleneck multiplier used inside the dense layers.
        drop_rate: dropout probability used inside the dense layers.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self,
                 growth_rate=Growth_rate,
                 block_config=Block_config,
                 out_channels=Out_channels,
                 bn_size=Bn_size,
                 drop_rate=Drop_rate,
                 num_classes=Num_classes):

        super(DesNET, self).__init__()

        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, out_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Running count of channels flowing through the network.
        num_features = out_channels
        last_block = len(block_config) - 1

        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
            )
            self.pre_layer.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            # A transition layer follows every dense block except the last
            # one, and it halves the number of feature maps.
            if i != last_block:
                trans = _Transition(
                    num_input_features=num_features,
                    num_output_features=num_features // 2,
                )
                self.pre_layer.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        self.pre_layer.add_module('Norm5', nn.BatchNorm2d(num_features))

        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the classifier outputs, one row per input sample."""
        features = self.pre_layer(x)
        out = F.relu(features, inplace=True)
        # Global average pooling. The adaptive variant equals the former
        # fixed 7x7 pooling when the final feature map is 7x7, and also
        # handles other input resolutions, which previously failed with a
        # shape error in the pooling or the linear layer.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = t.flatten(out, 1)
        out = self.classifier(out)
        return out

        
        
        
        
        

In [8]:
# Instantiate the network with the default hyper-parameters; the bare name on
# the last line makes the notebook display the module structure.
DesNET121 = DesNET()
DesNET121

DesNET(
  (pre_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (module1): Sequential(
          (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
          (2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (4): ReLU()
          (5): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (denselayer2): _DenseLayer(
        (module1): Sequential(
          (0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        

In [9]:
# Smoke test: push one random 3x228x228 input through the full network.
test_input = t.randn(1, 3, 228, 228)
test_out = DesNET121(test_input)