In [1]:
import torch
from torch import nn

In [2]:
def conv_block(num_channels):
    return nn.Sequential(
        nn.LazyBatchNorm2d(),nn.ReLU(),
        nn.LazyConv2d(num_channels, kernel_size=3, padding=1)
    )

In [7]:
class DenseBlock(nn.Module):
    def __init__(self, num_convs, num_channels):
        super().__init__()
        layer = []
        for i in range(num_convs):
            layer.append(conv_block(num_channels))
        self.net = nn.Sequential(*layer)
    
    def forward(self, x):
        for blk in self.net:
            Y = blk(x)
            x = torch.cat((x,Y),dim=1)
            print(x.shape)
        return x

In [11]:
blk = DenseBlock(2, 10)
x = torch.randn(4,3,8,8)
Y = blk(x)
Y.shape

torch.Size([4, 13, 8, 8])
torch.Size([4, 23, 8, 8])




torch.Size([4, 23, 8, 8])

In [12]:
def transition_block(num_channels):
    return nn.Sequential(
        nn.LazyBatchNorm2d(),nn.ReLU(),
        nn.LazyConv2d(num_channels, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2)
    )

In [13]:
blk = transition_block(10)
blk(Y).shape

torch.Size([4, 10, 4, 4])

In [14]:
from d2l_common import Classifier


class DenseNet(Classifier):
    def b1(self):
        return nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(), nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

    def __init__(self, num_channels=64, growth_rate=32, arch=(4, 4, 4, 4), lr=0.1, num_classes=10):
        super().__init__()
        self.net = self.b1()
        for i, num_convs in enumerate(arch):
            self.net.add_module(
                f'dense_blk{i+1}', DenseBlock(num_convs, growth_rate))
            num_channels+=num_convs*growth_rate
            if i!=len(arch)-1:
                num_channels//=2
                self.net.add_module(f'trans_blk{i+1}',transition_block(num_channels))
        self.net.add_module('last',nn.Sequential(
            nn.LazyBatchNorm2d(),nn.ReLU(),
            nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),
            nn.LazyLinear(num_classes) 
        ))

In [17]:
model = DenseNet()
# x = torch.randn(1,1,224,224)
model.layer_summary((1,1,224,224))

Conv2d output shape:	 torch.Size([1, 64, 112, 112])
BatchNorm2d output shape:	 torch.Size([1, 64, 112, 112])
ReLU output shape:	 torch.Size([1, 64, 112, 112])
MaxPool2d output shape:	 torch.Size([1, 64, 56, 56])
torch.Size([1, 96, 56, 56])
torch.Size([1, 128, 56, 56])
torch.Size([1, 160, 56, 56])
torch.Size([1, 192, 56, 56])
DenseBlock output shape:	 torch.Size([1, 192, 56, 56])
Sequential output shape:	 torch.Size([1, 96, 28, 28])
torch.Size([1, 128, 28, 28])
torch.Size([1, 160, 28, 28])
torch.Size([1, 192, 28, 28])
torch.Size([1, 224, 28, 28])
DenseBlock output shape:	 torch.Size([1, 224, 28, 28])
Sequential output shape:	 torch.Size([1, 112, 14, 14])
torch.Size([1, 144, 14, 14])
torch.Size([1, 176, 14, 14])
torch.Size([1, 208, 14, 14])
torch.Size([1, 240, 14, 14])
DenseBlock output shape:	 torch.Size([1, 240, 14, 14])
Sequential output shape:	 torch.Size([1, 120, 7, 7])
torch.Size([1, 152, 7, 7])
torch.Size([1, 184, 7, 7])
torch.Size([1, 216, 7, 7])
torch.Size([1, 248, 7, 7])
DenseB

In [None]:
from d2l_common import Trainer,FasionMNIST
trainer = Trainer(max_epochs=10)
data = FasionMNIST(batch_size=128)
trainer.fit(model, data)