## Underwater image enhancement model with PyTorch

This notebook contains the model architecture of the generator.

The generator is based on the ResNet architecture (residual blocks).

In [2]:
import torch
import functools
from torch import nn
from torch.utils.tensorboard import SummaryWriter

In [9]:
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, norm_layer=nn.BatchNorm2d, padding_type='reflect', use_dropout=True, use_bias=True):
        """Initialize the Resnet block

        Construct a convolutional block.

        :param dim: the number of channels in the conv layer.
        :param norm_layer: normalization layer
        :param padding_type: the name of padding layer: reflect | replicate | zero
        :param use_dropout: if use dropout layers.
        :param use_bias: if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))


        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, norm_layer, padding_type, use_dropout, use_bias)
        
        
    def build_conv_block(self, dim, norm_layer, padding_type, use_dropout, use_bias):
        """Initialize the Resnet block

        Construct a convolutional block.

        :param dim: the number of channels in the conv layer.
        :param norm_layer: normalization layer
        :param padding_type: the name of padding layer: reflect | replicate | zero
        :param use_dropout: if use dropout layers.
        :param use_bias: if the conv layer uses bias or not
        :return: conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []

        # First Reflection Padding
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        # First conv layer
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]

        # First dropout layer
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # First Reflection Padding
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        
        out = x + self.conv_block(x)  # add skip connections
        return out

In [10]:
# Build a 3-channel block with default settings so the notebook prints its layer layout.
ResnetBlock(dim=3)

ResnetBlock(
  (conv_block): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (2): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Dropout(p=0.5, inplace=False)
    (5): ReflectionPad2d((1, 1, 1, 1))
    (6): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (7): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [11]:
class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, padding_type='reflect',
                 use_dropout=False, n_blocks=9, n_down_sampling=2):
        """Construct a Resnet-based generator.

        Layout: 7x7 stem conv -> ``n_down_sampling`` stride-2 convs -> ``n_blocks``
        ResnetBlocks -> ``n_down_sampling`` stride-2 transposed convs -> 7x7 output conv -> Tanh.
        The output has the same spatial size as the input when H and W are divisible
        by ``2 ** n_down_sampling``.

        :param input_nc: the number of channels in input images
        :param output_nc: the number of channels in output images
        :param ngf: the number of filters in the first conv layer (the channel count doubles
                    at every down-sampling step)
        :param norm_layer: normalization layer constructor (optionally a ``functools.partial``)
        :param padding_type: the name of padding layer used in the Resnet blocks: reflect | replicate | zero
        :param use_dropout: if use dropout layers inside the Resnet blocks
        :param n_blocks: the number of ResNet blocks
        :param n_down_sampling: the number of stride-2 down-sampling (and matching up-sampling) layers
        :raises ValueError: if ``n_blocks`` or ``n_down_sampling`` is negative
        """
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if n_blocks < 0:
            raise ValueError(f'n_blocks must be non-negative, got {n_blocks}')
        if n_down_sampling < 0:
            raise ValueError(f'n_down_sampling must be non-negative, got {n_down_sampling}')

        # Conv bias is enabled only when the norm layer is InstanceNorm2d (possibly wrapped
        # in functools.partial); for any other norm layer the convs are built without bias.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls is nn.InstanceNorm2d

        # Stem: 7x7 conv at full resolution producing ngf channels
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Down-sampling: each stride-2 conv halves H and W and doubles the channel count
        for i in range(n_down_sampling):
            in_ch = ngf * 2 ** i
            out_ch = in_ch * 2
            model += [nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(out_ch),
                      nn.ReLU(True)]

        # Residual blocks at the bottleneck resolution
        bottleneck_ch = ngf * 2 ** n_down_sampling
        model += [ResnetBlock(bottleneck_ch, norm_layer=norm_layer, padding_type=padding_type,
                              use_dropout=use_dropout, use_bias=use_bias)
                  for _ in range(n_blocks)]

        # Up-sampling: each transposed conv doubles H and W and halves the channel count
        for i in range(n_down_sampling, 0, -1):
            in_ch = ngf * 2 ** i
            out_ch = in_ch // 2
            model += [nn.ConvTranspose2d(in_ch, out_ch,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(out_ch),
                      nn.ReLU(True)]

        # Output head: 7x7 conv to output_nc channels (keeps its default bias, no norm follows),
        # then Tanh to squash values into [-1, 1]
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Standard forward: apply the sequential generator to ``x``."""
        return self.model(x)

In [12]:
# Build an RGB -> RGB generator with dropout enabled so the notebook prints its layer layout.
ResnetGenerator(input_nc=3, output_nc=3, use_dropout=True)

ResnetGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): Dropout

In [8]:
# Smoke test: push one random 256x256 RGB image through the generator and check the output shape.
# torch.randn gives an initialized input; torch.Tensor(*shape) returns uninitialized memory
# (arbitrary values, possibly NaN/inf), and torch.autograd.Variable is deprecated.
model = ResnetGenerator(3, 3, use_dropout=True)
dummy_input = torch.randn(1, 3, 256, 256, requires_grad=True)
res = model(dummy_input)
res.shape

torch.Size([1, 3, 256, 256])

In [9]:
# Log the traced generator graph to TensorBoard (default SummaryWriter log directory).
# try/finally guarantees the writer is closed even if tracing the model fails.
writer = SummaryWriter()
try:
    writer.add_graph(model, torch.randn(1, 3, 256, 256, requires_grad=True))
finally:
    writer.close()