In [2]:
import torch
from torch import nn

In [16]:
class BottleNeck(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super().__init__()

        self.conv1x1 = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=4*growth_rate, kernel_size=(1, 1), bias=False)
        )
        self.conv3x3 = nn.Sequential(
            nn.BatchNorm2d(4*growth_rate),
            nn.ReLU(),
            nn.Conv2d(in_channels=4*growth_rate, out_channels=growth_rate, kernel_size=(3, 3), padding=1, bias=False)
        )

    def forward(self, x):
        input_feat = x
        x = self.conv1x1(x)
        x = self.conv3x3(x)
        # :input: [N, in_channels, H, W] + :x: [N, growth_rate, H, W] = [N, in_channels + growth_rate, H, W]  
        return torch.cat([input_feat, x], dim=1)

In [17]:
class Transition(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=(1, 1), bias=False),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.transition(x)
        return x

In [18]:
import torch
from torch import nn

class DenseNet(nn.Module):
    """DenseNet image classifier assembled from four dense blocks.

    Layout: a stem (7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool), four
    dense blocks of ``BottleNeck`` layers, global average pooling and a linear
    classifier. The first three dense blocks each end with a ``Transition``
    layer; the last one ends with BN + ReLU instead.

    Args:
        block_list: Indexable collection; entries 0..3 give the number of
            ``BottleNeck`` layers in dense blocks 1..4.
        growth_rate: Channels added by every ``BottleNeck`` layer. The stem
            outputs ``2 * growth_rate`` channels.
        num_classes: Size of the classifier output.
        input_channels: Channel count of the input images. Defaults to 3,
            the value that was previously hard-coded in the stem.

    Attributes:
        growth_rate: The ``growth_rate`` argument.
        in_channels: Running channel count used while the network is being
            built; once ``__init__`` has finished it equals the number of
            features fed to ``fc``.
    """

    def __init__(self, block_list, growth_rate, num_classes, input_channels=3):
        super().__init__()

        self.growth_rate = growth_rate
        self.in_channels = 2 * self.growth_rate

        # First Conv layer and pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=self.in_channels, kernel_size=7, stride=2, padding=3, bias=False),
            # Use the same channel variable as the conv above so the two cannot drift apart.
            nn.BatchNorm2d(self.in_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Dense Blocks (each call advances self.in_channels for the next one)
        self.denseblock1 = self._make_dense_block(block_list[0])
        self.denseblock2 = self._make_dense_block(block_list[1])
        self.denseblock3 = self._make_dense_block(block_list[2])
        self.denseblock4 = self._make_dense_block(block_list[3], last_block=True)
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_dense_block(self, num_BottleNecks, last_block=False):
        """Build one dense block and update ``self.in_channels`` accordingly.

        Args:
            num_BottleNecks: Number of ``BottleNeck`` layers in the block.
            last_block: If True, finish the block with BN + ReLU; otherwise
                finish it with a ``Transition`` layer.

        Returns:
            An ``nn.Sequential`` holding the block's layers.
        """
        layers = []
        for _ in range(num_BottleNecks):
            layers.append(BottleNeck(self.in_channels, self.growth_rate))
            self.in_channels += self.growth_rate  # each BottleNeck appends growth_rate channels

        if last_block:
            layers.append(nn.BatchNorm2d(self.in_channels))
            layers.append(nn.ReLU())
        else:  # not the last dense block: append a Transition layer
            layers.append(Transition(self.in_channels))
            self.in_channels //= 2  # mirror the channel halving done inside Transition

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to class scores of shape ``[N, num_classes]``."""
        x = self.conv1(x)
        x = self.denseblock1(x)
        x = self.denseblock2(x)
        x = self.denseblock3(x)
        x = self.denseblock4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)  # [N, C, 1, 1] -> [N, C]
        x = self.fc(x)
        return x

In [19]:
# Four DenseNet configurations; each list holds the number of BottleNeck
# layers in dense blocks 1-4. All share growth_rate=32 and a 10-way classifier.
DenseNet_121, DenseNet_169, DenseNet_201, DenseNet_264 = (
    DenseNet(block_list=blocks_per_stage, growth_rate=32, num_classes=10)
    for blocks_per_stage in (
        [6, 12, 24, 16],
        [6, 12, 32, 32],
        [6, 12, 48, 32],
        [6, 12, 64, 48],
    )
)

In [20]:
# Display the nested module tree (repr) of the DenseNet_121 instance.
DenseNet_121

DenseNet(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (denseblock1): Sequential(
    (0): BottleNeck(
      (conv1x1): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (conv3x3): Sequential(
        (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (1): BottleNeck(
      (conv1x1): Sequential(
        (0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
