In [4]:
# DenseNet和ResNet的思路类似，最明显的区别就是，跨层
# 的数据不再通过加法进行连接，而是通过concat进行拼接，
# 确保上一层的信息可以完整地进入下一层当中

# DenseNet的卷积块使用ResNet的改进版本BN->Relu->Conv
# 每个卷积的输出通道数被称为growth_rate
# 假定输入为in_channels，并且有layers层数，则输出的
# 通道数为in_channels + growth_rate * layers


from mxnet import nd
from mxnet.gluon import nn


def conv_block(channels):
    """Build one DenseNet convolution unit: BatchNorm -> ReLU -> 3x3 Conv2D.

    :param channels: number of output channels produced by the convolution.
    :return: an ``nn.Sequential`` holding the three layers in that order.
    """
    block = nn.Sequential()
    block.add(nn.BatchNorm())
    block.add(nn.Activation('relu'))
    # A 3x3 kernel with padding=1 (and the default stride) keeps the
    # spatial size, so the result can be concatenated with the input.
    block.add(nn.Conv2D(channels, kernel_size=3, padding=1))
    return block


class DenseBlock(nn.Block):
    """A stack of conv blocks whose outputs are concatenated onto their input.

    Each of the ``layers`` conv blocks emits ``growth_rate`` channels, and
    every output is appended to the running feature map along ``dim=1``.
    For an input with ``in_channels`` channels the result therefore has
    ``in_channels + growth_rate * layers`` channels.
    """

    def __init__(self, layers, growth_rate, *args, **kwargs):
        """Create ``layers`` conv blocks, each producing ``growth_rate`` channels.

        Extra positional/keyword arguments are forwarded to ``nn.Block``.
        """
        super(DenseBlock, self).__init__(*args, **kwargs)
        self.net = nn.Sequential()
        self.net.add(*[conv_block(growth_rate) for _ in range(layers)])

    def forward(self, x):
        """Run every conv block, concatenating its output onto the features.

        :param x: input array; ``dim=1`` is used as the concatenation axis.
        :return: the input with all conv-block outputs appended on ``dim=1``.
        """
        features = x
        for conv in self.net:
            # Each block sees everything produced so far, not just the
            # previous block's output.
            new_features = conv(features)
            features = nd.concat(features, new_features, dim=1)
        return features

In [5]:
# Smoke test: 2 conv blocks with growth rate 10 on a 3-channel input
# should yield 3 + 2 * 10 = 23 output channels at the same spatial size.
dblk = DenseBlock(layers=2, growth_rate=10)
dblk.initialize()
x = nd.random_uniform(shape=(4, 3, 8, 8))
dblk(x).shape

(4, 23, 8, 8)

In [6]:
# 过渡块： 由于使用拼接的缘故，每一次dense之后，输出的通道数可能会激增
# 为控制模型复杂度，引入过渡块，把输入的长宽减半，同时使用1×1的卷积来改变通道数

def transition_block(channels):
    """Build a transition unit: BatchNorm -> ReLU -> 1x1 Conv2D -> 2x2 AvgPool.

    The 1x1 convolution sets the channel count to ``channels`` and the
    stride-2 average pooling halves height and width.

    :param channels: number of output channels of the 1x1 convolution.
    :return: an ``nn.Sequential`` holding the four layers in that order.
    """
    block = nn.Sequential()
    block.add(nn.BatchNorm())
    block.add(nn.Activation('relu'))
    block.add(nn.Conv2D(channels, kernel_size=1))
    block.add(nn.AvgPool2D(pool_size=2, strides=2))
    return block

In [7]:
# Smoke test on the (4, 3, 8, 8) input `x` created in an earlier cell:
# channels should become 10 and height/width should be halved to 4x4.
tblk = transition_block(channels=10)
tblk.initialize()
tblk(x).shape

(4, 10, 4, 4)

In [8]:
# DenseNet的主体就是交替串联使用稠密块（DenseBlock）和过渡块
# 过渡层利用过渡块，每次将通道数减半

# Hyper-parameters read as globals by dense_net().
# Output channels of the stem convolution; also the starting channel count.
init_channel = 64
# Channels added by each conv block inside a DenseBlock.
growth_rate = 32
# Number of conv blocks in each successive DenseBlock.
block_layers = [6, 12, 24, 16]

# Number of units in the final Dense (classification) layer.
num_classes = 10


def dense_net():
    """Assemble the DenseNet classifier from the module-level hyper-parameters.

    Layout: a stem (7x7 stride-2 conv, BatchNorm, ReLU, 3x3 stride-2 max
    pool), then one ``DenseBlock`` per entry of ``block_layers`` with a
    transition block between consecutive dense blocks, then BatchNorm,
    ReLU, pooling, flatten and a ``Dense`` layer of ``num_classes`` units.

    Reads the globals ``init_channel``, ``growth_rate``, ``block_layers``
    and ``num_classes``.

    :return: an uninitialised ``nn.Sequential`` network.
    """
    net = nn.Sequential()
    with net.name_scope():
        # Stem.
        net.add(nn.Conv2D(init_channel, kernel_size=7, strides=2, padding=3),
                nn.BatchNorm(), nn.Activation('relu'),
                nn.MaxPool2D(pool_size=3, strides=2, padding=1))

        # Running count of the channels flowing through the network.
        channels = init_channel
        last_index = len(block_layers) - 1
        for i, layers in enumerate(block_layers):
            net.add(DenseBlock(layers, growth_rate))
            channels += layers * growth_rate
            if i != last_index:
                # Halve the channel count AND remember the halved value.
                # Without updating `channels` here the running count drifts
                # from the real width and later transitions no longer halve
                # their actual input.
                channels //= 2
                net.add(transition_block(channels))

        # Classification head.
        # NOTE(review): AvgPool2D(pool_size=1) only acts as a global average
        # pool if the feature map is already 1x1 at this point -- confirm
        # against the input size used for training.
        net.add(nn.BatchNorm(), nn.Activation('relu'),
                nn.AvgPool2D(pool_size=1), nn.Flatten(),
                nn.Dense(num_classes))
    return net

In [None]:
from mxnet import gluon
from mxnet import init
import utils

# Data iterators and the compute context come from the project-local
# `utils` helpers.
train_data, test_data = utils.load_data_fashion_mnist_new(batch_size=64, resize=32)
ctx = utils.try_gpu()

# Build the network and initialise its parameters with Xavier on `ctx`.
net = dense_net()
net.initialize(init=init.Xavier(), ctx=ctx)

# Softmax cross-entropy loss, optimised with SGD at learning rate 0.1.
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        optimizer_params={'learning_rate': 0.1})
utils.train(train_data, test_data, net, loss, trainer, ctx, num_epochs=1)