In [1]:
import torch
import torch.nn as nn

### Implementation details
The purpose of this notebook is to provide details about the implementation of DefectSegNet in defectsegnet.py. 

The model is described in the paper as follows:

"The DefectSegNet architecture, shown [below], consists of a total of 19 hidden layers. On the encoder side, max pooling is performed after each dense block, enabling the succeeding block to extract higher level, more contextual (and abstract) features from the defect images. For the decoder, to recover the resolution we employed the transposed convolutions, a more sophisticated operator than bilinear interpolation, for up-sampling. There are equal numbers of max pooling layers and transposed convolution layers, so the output probability map has the same spatial resolution as the input image. For the design of skip connections, besides those already introduced in dense blocks, feature maps created during encoding are input to all the decoder layers of the same spatial resolution. This allows the feature maps of a certain spatial resolution to connect cross the encoder-decoder performing in a similar manner to a single dense block. The incorporation of these skip connections both within and across blocks is the primary difference between our DefectSegNet and the U-Net and the fully convolutional DenseNet. Lastly, the final hidden layer is a 3 × 3 convolutional layer with a sigmoid activation function for classification."

![model arch](arch.png "DefectSegNet architecture")

Notice that the model architecture (except the last 2 layers) consists of repeating blocks of either two convolution layers and a max pooling layer, or two convolution layers and a transposed convolution layer. There are three blocks of each type, followed by 2 convolution layers. A block of layers is implemented in the DenseConvBlock class:

In [2]:
class DenseConvBlock(nn.Module):
    """Two densely connected 3x3 convolutions followed by a 2x resampling layer.

    Each convolution sees the block input concatenated (along the channel axis)
    with every feature map produced earlier in the block. The concatenation of
    the block input and both convolution outputs is then either max-pooled
    (``transpose = False``) or up-sampled with a transposed convolution
    (``transpose = True``), followed by batch normalisation and ReLU.

    Args:
        in_channels: Total number of channels entering the block, i.e. the
            channels of ``x`` plus the channels of every skip tensor passed
            to ``forward`` through ``concat_channels``.
        out_channels: Number of channels produced by each of the two
            convolutions.
        transpose: If True the block ends with a stride-2 transposed
            convolution producing ``out_channels // 2`` channels; otherwise
            it ends with a 2x2 max pooling that keeps ``block_out_channels``
            channels.

    Attributes:
        block_in_channels: Same as ``in_channels``.
        out_channels: Same as the constructor argument.
        block_out_channels: ``2 * out_channels + in_channels``, the number of
            channels fed to the resampling layer.
        transpose: Same as the constructor argument.
    """

    def __init__(self, in_channels, out_channels, transpose = False):
        super().__init__()
        self.block_in_channels = in_channels
        self.out_channels = out_channels
        # Block input plus the two convolution outputs, all concatenated.
        self.block_out_channels = 2 * out_channels + in_channels
        # Remembered so that forward() uses the resampling layer that was actually built.
        self.transpose = transpose

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = (3, 3), stride = (1, 1), padding = (1, 1))
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.act1 = nn.ReLU()

        # The second convolution also receives the output of the first one.
        self.conv2 = nn.Conv2d(in_channels + out_channels, out_channels, kernel_size = (3, 3), stride = (1, 1), padding = (1, 1))
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.act2 = nn.ReLU()

        if transpose:
            self.tconv = nn.ConvTranspose2d(self.block_out_channels, out_channels // 2, kernel_size = (2, 2), stride = (2, 2))
            self.bn3 = nn.BatchNorm2d(self.tconv.out_channels)
        else:
            self.pool1 = nn.MaxPool2d(kernel_size = (2, 2), stride = (2, 2))
            # Pooling does not change the channel count.
            self.bn3 = nn.BatchNorm2d(self.block_out_channels)

        self.act3 = nn.ReLU()

    def forward(self, x, concat_channels = None):
        """Run the block.

        Args:
            x: Input tensor; channels are on dimension 1.
            concat_channels: Optional iterable of skip tensors that are
                concatenated to ``x`` along dimension 1 before the first
                convolution. ``None`` and an empty list both mean "no skips".

        Returns:
            A tuple ``(c1, c2, out)`` with the activated outputs of the two
            convolutions and the activated output of the resampling layer.
        """
        skips = [] if concat_channels is None else list(concat_channels)
        # Everything that enters the dense part of the block, as one tensor.
        dense_in = torch.cat([x, *skips], dim = 1) if skips else x

        c1 = self.act1(self.bn1(self.conv1(dense_in)))
        c2 = self.act2(self.bn2(self.conv2(torch.cat([dense_in, c1], dim = 1))))

        # Choose the resampling layer from the construction-time flag rather than
        # from whether skips were supplied, so every block works with or without them.
        resample = self.tconv if self.transpose else self.pool1
        out = self.act3(self.bn3(resample(torch.cat([dense_in, c1, c2], dim = 1))))

        return c1, c2, out

When creating a DenseConvBlock object with transpose = True, the block ends with a transposed convolution layer instead of a max pooling layer. In the forward pass, the concat_channels list holds the tensors that are concatenated with the block input as skip connections (if any). Each block returns the outputs of its two convolution layers along with its final output, so that they can be reused in later skip connections.

3 blocks with max pooling, 3 blocks with transposed convolution and 2 convolution layers are combined to form the DefectSegNet architecture below. The elements of concat_channels are based on the official implementation of the model. The in_channels of each transposed convolution block is the sum of the number of channels in the input from the previous block and in the elements of its concat_channels list.

In [3]:
class DefectSegNet(nn.Module):
    """Encoder-decoder segmentation network built from DenseConvBlock units.

    The encoder is three max-pooling blocks, the decoder is three
    transposed-convolution blocks, and the head is two 3x3 convolutions, the
    last one followed by a sigmoid. Every decoder stage (and the head) also
    receives the encoder feature maps of the same spatial resolution.

    Because the input is halved three times and doubled three times before
    being concatenated with same-resolution encoder maps, the input height and
    width need to be multiples of 8 for the concatenations to line up.

    Args:
        in_channels: Number of channels of the input image (default 1).
    """

    def __init__(self, in_channels = 1):
        super().__init__()

        # Encoder: each block is fed the full concatenated output of the previous one.
        self.convblock1 = DenseConvBlock(in_channels = in_channels, out_channels = 4)
        enc1_width = self.convblock1.block_out_channels
        self.convblock2 = DenseConvBlock(in_channels = enc1_width, out_channels = 16)
        enc2_width = self.convblock2.block_out_channels
        self.convblock3 = DenseConvBlock(in_channels = enc2_width, out_channels = 32)
        enc3_width = self.convblock3.block_out_channels

        # Decoder: input width = same-resolution pooled map + the two encoder conv maps
        # at that resolution + the up-sampled map (half of the previous out_channels).
        self.tconvblock1 = DenseConvBlock(in_channels = enc3_width, out_channels = 64, transpose = True)

        dec2_width = int(enc2_width + (2 * self.convblock3.out_channels) + (self.tconvblock1.out_channels / 2))
        self.tconvblock2 = DenseConvBlock(in_channels = dec2_width, out_channels = 32, transpose = True)

        dec3_width = int(enc1_width + (2 * self.convblock2.out_channels) + (self.tconvblock2.out_channels / 2))
        self.tconvblock3 = DenseConvBlock(in_channels = dec3_width, out_channels = 16, transpose = True)

        # Head: raw input + both convblock1 conv maps + last up-sampled map.
        head_width = int(self.convblock1.block_in_channels + (2 * self.convblock1.out_channels) + (self.tconvblock3.out_channels / 2))
        self.conv13 = nn.Conv2d(head_width, 4, kernel_size = (3, 3), stride = (1, 1), padding = (1, 1))
        self.act19 = nn.ReLU()

        # The final convolution additionally sees the output of conv13.
        final_width = int(self.conv13.in_channels + self.conv13.out_channels)
        self.conv14 = nn.Conv2d(final_width, 1, kernel_size = (3, 3), stride = (1, 1), padding = (1, 1))
        self.act20 = nn.Sigmoid()

    def forward(self, x):
        """Map an image batch to a single-channel map of sigmoid outputs.

        Args:
            x: Input tensor with channels on dimension 1.

        Returns:
            The sigmoid-activated output of the last convolution (one channel).
        """
        # Encoder: keep both conv outputs and the pooled output of every stage.
        enc1_a, enc1_b, down1 = self.convblock1(x)
        enc2_a, enc2_b, down2 = self.convblock2(down1)
        enc3_a, enc3_b, down3 = self.convblock3(down2)

        # Decoder: only the up-sampled output of each stage is used further on.
        # The bottleneck stage has no skips but still passes an (empty) list.
        _, _, up1 = self.tconvblock1(down3, concat_channels = [])
        _, _, up2 = self.tconvblock2(up1, concat_channels = [down2, enc3_a, enc3_b])
        _, _, up3 = self.tconvblock3(up2, concat_channels = [down1, enc2_a, enc2_b])

        # Head: full-resolution features, densely connected to the last two convolutions.
        head_in = torch.cat([x, enc1_a, enc1_b, up3], dim = 1)
        head_mid = self.act19(self.conv13(head_in))
        head_out = self.act20(self.conv14(torch.cat([head_in, head_mid], dim = 1)))

        return head_out

In [4]:
# Build the network with its default single-channel input, then end the cell with the
# bare name so the notebook displays the module tree.
defectsegnet = DefectSegNet()
defectsegnet

DefectSegNet(
  (convblock1): DenseConvBlock(
    (conv1): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU()
    (conv2): Conv2d(5, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU()
    (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (bn3): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act3): ReLU()
  )
  (convblock2): DenseConvBlock(
    (conv1): Conv2d(9, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU()
    (conv2): Conv2d(25, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_r