# MNIST DenseNet

In [2]:
import numpy as np
import torch 
import torch.nn as nn 
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

[DenseNet和ResNet区别](https://blog.csdn.net/u013247002/article/details/84857593)

[ReLU1](https://www.zhihu.com/question/29021768)

[ReLU2](https://blog.csdn.net/lee813/article/details/80993355)

In [3]:
class Dense_Block(nn.Module):
    def __init__(self, in_channels):
        super(Dense_Block, self).__init__()

        self.relu = nn.ReLU(inplace = True)
        self.bn = nn.BatchNorm2d(num_features = in_channels)

        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        self.conv4 = nn.Conv2d(in_channels = 96, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
        self.conv5 = nn.Conv2d(in_channels = 128, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)

    def forward(self, x):

        bn = self.bn(x)
        conv1 = self.relu(self.conv1(bn))

        conv2 = self.relu(self.conv2(conv1))
        c2_dense = self.relu(torch.cat([conv1, conv2], 1))

        conv3 = self.relu(self.conv3(c2_dense))
        c3_dense = self.relu(torch.cat([conv1, conv2, conv3], 1))

        conv4 = self.relu(self.conv4(c3_dense))
        c4_dense = self.relu(torch.cat([conv1, conv2, conv3, conv4], 1))

        conv5 = self.relu(self.conv5(c4_dense))
        c5_dense = self.relu(torch.cat([conv1, conv2, conv3, conv4, conv5], 1))

        return c5_dense

class Transition_Layer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Transition_Layer, self).__init__()

        self.relu = nn.ReLU(inplace = True)
        self.bn = nn.BatchNorm2d(num_features = out_channels)
        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, bias = False)
        self.avg_pool = nn.AvgPool2d(kernel_size = 2, stride = 2, padding = 0)

    def forward(self, x):

        bn = self.bn(self.relu(self.conv(x)))
        out = self.avg_pool(bn)

        return out

class DenseNet(nn.Module):
    def __init__(self, nr_classes):
        super(DenseNet, self).__init__()

        self.lowconv = nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 7, padding = 3, bias = False)
        self.relu = nn.ReLU()

        # Make Dense Blocks
        self.denseblock1 = self._make_dense_block(Dense_Block, 64)
        self.denseblock2 = self._make_dense_block(Dense_Block, 128)
        self.denseblock3 = self._make_dense_block(Dense_Block, 128)

        # Make transition Layers
        self.transitionLayer1 = self._make_transition_layer(Transition_Layer, in_channels = 160, out_channels = 128)
        self.transitionLayer2 = self._make_transition_layer(Transition_Layer, in_channels = 160, out_channels = 128)
        self.transitionLayer3 = self._make_transition_layer(Transition_Layer, in_channels = 160, out_channels = 64)

        # Classifier
        self.bn = nn.BatchNorm2d(num_features = 64)
        self.pre_classifier = nn.Linear(64*3*3, 512)
        self.classifier = nn.Linear(512, nr_classes)

    def _make_dense_block(self, block, in_channels):
        layers = []
        layers.append(block(in_channels))
        return nn.Sequential(*layers)

    def _make_transition_layer(self, layer, in_channels, out_channels):
        modules = []
        modules.append(layer(in_channels, out_channels))
        return nn.Sequential(*modules)

    def forward(self, x):
        out = self.relu(self.lowconv(x))
        print(list(out.size()))
        out = self.denseblock1(out)
        print(list(out.size()))
        out = self.transitionLayer1(out)
        print(list(out.size()))
        out = self.denseblock2(out)
        print(list(out.size()))
        out = self.transitionLayer2(out)
        print(list(out.size()))
        out = self.denseblock3(out)
        print(list(out.size()))
        out = self.transitionLayer3(out)
        print(list(out.size()))
        out = self.bn(out)
#         print(out.shape)
        out = out.view(-1, 64*3*3)

        out = self.pre_classifier(out)
        out = self.classifier(out)
        print("="*20)
        return out

In [None]:
model = DenseNet(nr_classes=10)
batch_size = 8
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081,))
])
train_datasets = datasets.MNIST(
    root="../../dataset/mnist",
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_datasets, shuffle=True, batch_size=batch_size)
test_datasets = datasets.MNIST(root="../../dataset/mnist",
                               train=False,
                               download=True,
                               transform=transform)
test_loader = DataLoader(test_datasets, shuffle=True, batch_size=batch_size)


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
#         inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if batch_idx % 500 == 0:
            print('[%d, %5d] loss : %.3f' % (epoch + 1, batch_idx + 1, running_loss/300))
            running_loss = 0.
        break
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data 
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Accuracy on test set : %d %%" % ( 100 *correct / total))
# if __name__ == '__main__':
# model = Net()
for epoch in range(1):
    train(epoch)
    test()


[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[1,     1] loss : 0.007
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7

[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 

[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 

[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 

[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 

[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 28, 28]
[8, 160, 28, 28]
[8, 128, 14, 14]
[8, 160, 14, 14]
[8, 128, 7, 7]
[8, 160, 7, 7]
[8, 1, 28, 28]
[8, 64, 