In [1]:
# Resnet 是跨层求和，Densenet 是跨层特征在通道维度上进行拼接将
# 主要由dense block 构成 
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10

In [2]:
# 主要由dense block 构成 ，首先定义一个卷积块，顺序是bn->relu->conv
def conv_block(in_channel,out_channel):
    layer=nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(True),
        nn.Conv2d(in_channel,out_channel,3,padding=1,bias=False)
    
    
    )
    return layer

In [4]:
# 将每次的卷积输出称为growth_rate,,因为如果输入是in_channel ,有n层，那么输出就是in_channel +n*growth_rate
class dense_block(nn.Module):
    def __init__(self,in_channel,growth_rate,num_layers):
        super(dense_block,self).__init__()
        block=[]
        channel =in_channel
        for i in range(num_layers):
            block.append(conv_block(channel,growth_rate))
            channel +=growth_rate
        self.net=nn.Sequential(*block)
    
    def forward(self,x):
        for layer in self.net:
            out=layer(x)
            x=torch.cat((out,x),dim=1)
        return x
        

In [6]:
# 验证一下输出的channel 是否正确
test_net=dense_block(3,12,3)
test_x=Variable( torch.zeros(1,3,96,96))
print(test_x.shape[1],test_x.shape[2],test_x.shape[3])#同时也将输入的长宽减半，这个过渡层可以)
      
test_y=test_net(test_x)
print(test_y.shape[1],test_y.shape[2],test_y.shape[3] )

3 96 96
39 96 96


In [7]:
# 除了dense block ，DenseNet 中还有一个模块叫做过渡层(transition block) ,因为DenseNet 会不断得对维度进行拼接，所以层数很高的时候，
#输出的通道就会越来越大，参数和计算量也会越来越大，。为了避免
#这个问题，需要引入过渡层将输出通道降低一下，同时输入的长宽减半，这个过渡层可以用1*1 的卷积

def transition(in_channel,out_channel):
    trans_layer=nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(True),
        nn.Conv2d(in_channel,out_channel,1),
        nn.AvgPool2d(2,2)
    
    
    )
    return trans_layer

In [9]:
# 验证一下过渡层是否正确
test_net=transition(3,12)
test_x=Variable(torch.zeros(1,3,96,96))
print(test_x.shape[1],test_x.shape[2],test_x.shape[3])
test_y=test_net(test_x)
print(test_y.shape[1],test_y.shape[2],test_y.shape[3])

3 96 96
12 48 48


In [14]:
# 最后定义DenseNet 
class densenet(nn.Module):
    def __init__(self,in_channel,num_classes,growth_rate=32,block_layers=[6,12,24,16]):
        super(densenet,self).__init__()
        self.block1=nn.Sequential(
            nn.Conv2d(in_channel,64,7,2,3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxUnpool2d(3,2,padding=1)
        )
        channels=64
        block=[]
        for i, layers in enumerate(block_layers):
            block.append(dense_block(channels,growth_rate,layers))
            channels+=layers*growth_rate
            if i !=len(block_layers)-1:
                block.append(transition(channels,channels//2)) # 通过transition 层将大小减半，通道数减半
                channels=channels//2
        self.block2=nn.Sequential(*block)
        self.block2.add_module("bn",nn.BatchNorm2d(channels))  
        self.block2.add_module("relu",nn.ReLU(True))
        self.block2.add_module("avg_pool",nn.AvgPool2d(3))
        self.classifier=nn.Linear(channels,num_classes)
    def forward(self,x):
        x=self.block1(x)
        x=self.block2(x)
        x=x.view(x.shape[0],-1)
        x=self.classifier(x)
        return x
        

In [15]:
test_net=densenet(3,10)
test_x=Variable(torch.zeros(1,3,96,96))
test_y=test_net(test_x)
print(test_y.shape)

TypeError: forward() missing 1 required positional argument: 'indices'

In [None]:
# DenseNet 将残差连接改为了特征拼接，使得网络有了更稠密的连接