# Networks

> networks


In [1]:
#| default_exp networks

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export

from fastai.vision.all import ConvLayer, Lambda, MaxPool, NormType, np
from torch import cat as torch_cat
import torch.nn as nn
from torch.nn import functional as F, init

from Noise2Model.utils import attributesFromDict

import os
from importlib import import_module

In [4]:
from torch import randn as torch_randn
from fastai.vision.all import test_eq

In [5]:
#| export

# Registry mapping lower-cased class names to network classes.
network_class_dict = {}

def regist_network(model_class):
    """Register `model_class` in `network_class_dict` and return it unchanged.

    Intended to be used as a class decorator. The registry key is the
    lower-cased class name.

    Parameters:
    model_class - class to register.

    Returns:
    The same `model_class`, so the decorated name keeps referring to the class.

    Raises:
    AssertionError - if a class with the same lower-cased name is already registered.
    """
    model_name = model_class.__name__.lower()
    # Raise explicitly rather than via `assert`: assert statements are removed
    # under `python -O`, which would let a second class silently replace the first.
    if model_name in network_class_dict:
        raise AssertionError('there is already registered model: %s in network_class_dict.' % model_name)
    network_class_dict[model_name] = model_class

    return model_class

def get_network_class(model_name:str):
    """Return the class stored in `network_class_dict` for `model_name`.

    The name is lower-cased before the lookup, so the lookup is case-insensitive.

    Parameters:
    model_name - (str) name of a registered network class, in any letter case.

    Raises:
    KeyError - if no class is registered under that name; the message lists
               the names that are registered.
    """
    key = model_name.lower()
    try:
        return network_class_dict[key]
    except KeyError:
        available = ', '.join(sorted(network_class_dict)) or '<none>'
        # `from None` drops the bare inner KeyError so only the descriptive one is shown
        raise KeyError(f'unknown network: {model_name!r}; registered networks: {available}') from None

In [6]:
# Smoke-test the imported fastai building blocks on a random 5-D tensor.
x = torch_randn(16, 1, 32, 64, 64)
xdim = x.dim() - 2  # dimensions remaining after the two leading ones
full_shape = [16, 1, 32, 64, 64]

# a default ConvLayer keeps the shape
conv = ConvLayer(1, 1, ndim=xdim)
test_eq(conv(x).shape, full_shape)

# pooling with size 2 halves the last three dimensions
halve = MaxPool(2, ndim=xdim)
test_eq(halve(x).shape, [16, 1, 16, 32, 32])

# an element-wise Lambda keeps the shape
add_eps = Lambda(lambda t: t + np.float32(1e-3))
shifted = add_eps(x)
test_eq(shifted.shape, full_shape)

# concatenating along dim 1 doubles that dimension
test_eq(torch_cat((x, shifted), 1).shape, [16, 2, 32, 64, 64])

## DnCNN

In [7]:
#| export

@regist_network
class DnCNN(nn.Module):
    """Stack of 3x3 2-D convolutions in the DnCNN layout.

    The stack is Conv+ReLU, followed by `num_of_layers-2` groups of
    Conv+BatchNorm+ReLU, followed by one Conv that maps back to `channels`.
    `self.residual` is set to True in `__init__`, so `forward` returns the input
    minus the output of the stack.

    Parameters:
    channels - (int) number of input channels; the output has the same number.
    num_of_layers - (int) total number of convolutions in the stack.
    features - (int) number of channels between the first and last convolution.
    """

    def __init__(self, channels, num_of_layers=18,features=64):
        super().__init__()
        self.bias = True
        self.residual = True
        ks = 3

        def conv(c_in, c_out):
            # padding of ks//2 keeps the spatial size for this odd kernel size
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=ks,
                             padding=ks // 2, bias=self.bias)

        stack = [conv(channels, features), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            # NOTE(review): in PyTorch `momentum` is the weight of the newest batch
            # statistic, so 0.9 follows recent batches closely - confirm this is intended.
            stack += [
                conv(features, features),
                nn.BatchNorm2d(features, momentum=0.9, eps=1e-04, affine=True),
                nn.ReLU(inplace=True),
            ]
        stack.append(conv(features, channels))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x, kwargs=None):
        """Run the stack on `x`; `kwargs` is accepted but not used."""
        prediction = self.dncnn(x)
        return x - prediction if self.residual else prediction
    


In [8]:
#| export

# class MyDnCNN(nn.Module):
#     def __init__(self, channels, num_of_layers=9, features=64, kernel_size=3):
#         super(DnCNN, self).__init__()
#         padding = 1
#         layers = []
#         layers.append(ConvLayer(channels, features,
#                       ks=kernel_size, padding=padding, norm_type=None))
#         for _ in range(num_of_layers-2):
#             layers.append(ConvLayer(features, features,
#                           ks=kernel_size, padding=padding))
#         layers.append(nn.Conv2d(in_channels=features, out_channels=channels,
#                       kernel_size=kernel_size, padding=padding, bias=False))
#         self.dncnn = nn.Sequential(*layers)

#     def forward(self, x):
#         residual = self.dncnn(x)
#         denoised = x - residual
#         return denoised

In [9]:
# DnCNN must return a tensor with the shape of its input.
x = torch_randn(16, 1, 32, 64)

denoiser = DnCNN(1)
test_eq(denoiser(x).shape, list(x.shape))
print(denoiser)

DnCNN(
  (dncnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(64, eps=0.0001, momentum=0.9, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(64, eps=0.0001, momentum=0.9, affine=True, track_running_stats=True)
    (7): ReLU(inplace=True)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(64, eps=0.0001, momentum=0.9, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(64, eps=0.0001, momentum=0.9, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchN

## My UNet

In [10]:
#| export

def SubNetConv(ks=3,
               stride=1,
               padding=None,
               bias=None,
               ndim=2,
               norm_type=NormType.Batch,
               bn_1st=True,
               act_cls=nn.ReLU,
               transpose=False,
               init='auto',
               xtra=None,
               bias_std=0.01,
               dropout=0.0,
               ):
    """Return a factory `_conv(n_in, n_out, n_conv=1)` that builds convolution stacks.

    Every argument except `dropout` is forwarded unchanged to each `ConvLayer`
    the factory creates. When `dropout` is not None and greater than 0, each
    `ConvLayer` is wrapped as `nn.Sequential(ConvLayer, nn.Dropout(dropout))`.

    The factory chains `n_conv` such units: the first maps `n_in` to `n_out`
    channels, the others map `n_out` to `n_out`. Each additional unit is attached
    as `nn.Sequential(previous_stack, unit)`, so the result is nested, not flat.
    """

    def _unit(c_in, c_out):
        # one ConvLayer, followed by Dropout when dropout is enabled
        layer = ConvLayer(c_in, c_out, ks=ks, stride=stride, padding=padding, bias=bias,
                          ndim=ndim, norm_type=norm_type, bn_1st=bn_1st, act_cls=act_cls,
                          transpose=transpose, init=init, xtra=xtra, bias_std=bias_std)
        if dropout is None or dropout <= 0:
            return layer
        return nn.Sequential(layer, nn.Dropout(dropout))

    def _conv(n_in, n_out, n_conv=1):
        stack = _unit(n_in, n_out)
        for _ in range(1, n_conv):
            stack = nn.Sequential(stack, _unit(n_out, n_out))
        return stack

    return _conv

In [11]:
x = torch_randn(16, 1, 32, 64, 64)
xdim = x.dim() - 2

# reduce: two chained stride-2 convolutions divide each of the last three sizes by 4
make_down = SubNetConv(3, padding=1, stride=2, ndim=xdim,
                       norm_type=NormType.Batch, dropout=.1)
down = make_down(1, 2, 2)
y = down(x)
test_eq(y.shape, [16, 2, 8, 16, 16])
print(down)

# upsample: one transposed convolution with ks=4, stride=4 restores the input shape
make_up = SubNetConv(ks=4, padding=0, stride=4, ndim=xdim,
                     norm_type=NormType.Batch, transpose=True)
up = make_up(2, 1)  # to double the size, the kernel cannot be odd
test_eq(up(y).shape, [16, 1, 32, 64, 64])
print(up)
del y
# ConvLayer(2*n_out_channels, n_out_channels, ks=ks, transpose=True, padding=(ks-1)//2)

Sequential(
  (0): Sequential(
    (0): ConvLayer(
      (0): Conv3d(1, 2, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Dropout(p=0.1, inplace=False)
  )
  (1): Sequential(
    (0): ConvLayer(
      (0): Conv3d(2, 2, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Dropout(p=0.1, inplace=False)
  )
)
ConvLayer(
  (0): ConvTranspose3d(2, 1, kernel_size=(4, 4, 4), stride=(4, 4, 4), bias=False)
  (1): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)


In [12]:
#| export

class _Net_recurse(nn.Module):
    def __init__(self,
                 depth=4,                   # depth of the UNet network
                 mult_chan=32,              # number of filters at first layer
                 in_channels=1,             # number of input channels
                 kernel_size=3,             # kernel size of convolutional layers
                 ndim=2,                    # number of spatial dimensions of the input data
                 n_conv_per_depth=2,        # number of convolutions per layer
                 activation=nn.ReLU,        # activation function used in convolutional layers
                 norm_type=NormType.Batch,
                 dropout=0.0,
                 pool=MaxPool,
                 pool_size=2,
                 ):
        """One level of the recursively defined U-network.

        Parameters:
        depth - (int) number of levels below this one; if 0, this subnet is only
                the convolutions that multiply the channel count by `mult_chan`.
        mult_chan - (int) factor applied to `in_channels` to get this level's
                    channel count; every lower level is built with a factor of 2.
        in_channels - (int) number of channels for input.
        kernel_size, ndim, norm_type, activation, dropout - forwarded to `SubNetConv`.
        n_conv_per_depth - (int) number of convolutions in each block of this level.
        pool - callable invoked as `pool(ks=pool_size, ndim=ndim)` to build the
               downsampling layer.
        pool_size - pooling size, also used as the `scale_factor` of the upsampling.
        """
        super().__init__()
        self.depth = depth
        n_out = in_channels * mult_chan

        # Layer types
        downsample = pool(ks=pool_size, ndim=ndim)
        upsample = nn.Upsample(scale_factor=pool_size, mode='nearest')
        make_convs = SubNetConv(ks=kernel_size, stride=1, padding=None, bias=None,
                                ndim=ndim, norm_type=norm_type, bn_1st=True,
                                act_cls=activation, transpose=False, dropout=dropout)

        # Blocks
        self.sub_conv_more = make_convs(in_channels, n_out, n_conv_per_depth)
        if self.depth > 0:
            lower_level = _Net_recurse(depth=self.depth - 1,
                                       mult_chan=2,
                                       in_channels=n_out,
                                       kernel_size=kernel_size,
                                       ndim=ndim,
                                       n_conv_per_depth=n_conv_per_depth,
                                       activation=activation,
                                       norm_type=norm_type,
                                       dropout=dropout,
                                       pool=pool,
                                       pool_size=pool_size)
            # shrink the image, run the lower unet level, grow the image back
            self.sub_u = nn.Sequential(downsample, lower_level, upsample)
            # the lower level returns 2*n_out channels; the skip connection adds n_out
            self.sub_conv_less = make_convs(3 * n_out, n_out, n_conv_per_depth)

    def forward(self, x):
        # convolutions with increasing number of channels
        widened = self.sub_conv_more(x)
        if self.depth == 0:
            return widened
        from_below = self.sub_u(widened)
        # concatenate the upsampled output of the lower level with this level's features
        merged = torch_cat((from_below, widened), 1)
        # convolutions with decreasing number of channels
        return self.sub_conv_less(merged)

In [13]:
#| export

@regist_network
class MyUNet(nn.Module):
    def __init__(self,
                 depth=4,                   # depth of the UNet network
                 mult_chan=32,              # number of filters at first layer
                 in_channels=1,             # number of input channels
                 out_channels=1,            # number of output channels
                 last_activation=None,      # last activation before final result
                 kernel_size=3,             # kernel size of convolutional layers
                 ndim=2,                    # number of spatial dimensions of the input data
                 n_conv_per_depth=2,        # number of convolutions per layer
                 activation='ReLU',         # activation function used in convolutional layers
                 norm_type=NormType.Batch,
                 dropout=0.0,
                 pool=MaxPool,
                 pool_size=2,
                 residual=False,
                 prob_out=False,
                 eps_scale=1e-3,
                 ):
        """U-Net built from `_Net_recurse` plus a final output convolution.

        Parameters:
        depth, mult_chan, in_channels, kernel_size, ndim, n_conv_per_depth,
        norm_type, dropout, pool, pool_size - forwarded to `_Net_recurse`.
        out_channels - (int) number of channels produced by the output convolution.
        activation - (str) name of a class in `torch.nn` (e.g. 'ReLU') used in the
                     convolution blocks.
        last_activation - (str or None) name of a function in `torch.nn.functional`
                          applied to the final result; when None, the lower-cased
                          `activation` name is used instead.
        residual - (bool) add the input to the output before the last activation;
                   needs `in_channels == out_channels`.
        prob_out - (bool) also return a scale map, concatenated to the output
                   along dim 1.
        eps_scale - (float) constant added to the scale map.
        """
        super().__init__()
        last_activation = getattr(F, f"{activation.lower()}") if last_activation is None else getattr(
            F, f"{last_activation.lower()}")
        activation = getattr(nn, f"{activation}")
        # No extra locals may be created above this call: everything in locals() is stored.
        attributesFromDict(locals())		# stores all the input parameters in self

        self.net_recurse = _Net_recurse(depth, mult_chan, in_channels, kernel_size, ndim,
                                        n_conv_per_depth, activation, norm_type, dropout, pool, pool_size)
        # Derive the padding from the kernel size (equals the previous fixed value 1 for
        # kernel_size=3) so the output keeps the spatial size for other odd kernel sizes.
        self.conv_out = ConvLayer(mult_chan*in_channels, out_channels, ndim=ndim,
                                  ks=kernel_size, norm_type=None, act_cls=None,
                                  padding=(kernel_size - 1)//2)
        if prob_out:
            # The scale head must be a registered submodule so its weights are trained,
            # saved and moved with the model. Its input is the feature map returned by
            # net_recurse, which has mult_chan*in_channels channels.
            self.conv_scale = ConvLayer(mult_chan*in_channels, out_channels, ndim=ndim,
                                        ks=1, norm_type=None, act_cls=nn.Softplus)

    def forward(self, x):
        """Return the network output for `x`.

        With `prob_out`, the result is the output and the scale map concatenated
        along dim 1.

        Raises:
        ValueError - if `residual` is set and `in_channels != out_channels`.
        """
        x_rec = self.net_recurse(x)
        final = self.conv_out(x_rec)

        if self.residual:
            if self.out_channels != self.in_channels:
                raise ValueError(
                    "number of input and output channels must be the same for a residual net.")
            final = final + x
        final = self.last_activation(final)

        if self.prob_out:
            # Softplus output plus a small constant
            scale = self.conv_scale(x_rec) + np.float32(self.eps_scale)
            final = torch_cat((final, scale), 1)

        return final

In [14]:
# show_doc(UNet)

In [15]:
# A residual MyUNet must return a tensor with the shape of its input.
x = torch_randn(16, 1, 32, 64, 64)
xdim = x.dim() - 2

unet = MyUNet(depth=1, ndim=xdim, n_conv_per_depth=1, residual=True)
mods = list(unet.children())
print(mods)
test_eq(unet(x).shape, list(x.shape))

[_Net_recurse(
  (sub_conv_more): ConvLayer(
    (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (sub_u): Sequential(
    (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): _Net_recurse(
      (sub_conv_more): ConvLayer(
        (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (2): Upsample(scale_factor=2.0, mode='nearest')
  )
  (sub_conv_less): ConvLayer(
    (0): Conv3d(96, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
), ConvLayer(
  (0): Conv3d(32, 1, kernel_size=(3, 3, 3), stride=(1, 

## UNet

In [16]:
#| export

@regist_network
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=5,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode='upconv',
        residual=True,
        drop_p=0.15
    ):
        """
        2-D U-Net after
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597
        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf;
                      level i of the down path uses 2**(wf+i)
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
            residual (bool): if True, add the last `n_classes` channels of the
                             input to the result of the final convolution
            drop_p (float): dropout probability handed to the down-path blocks
                            (the up-path blocks are built without it)
        """
        super().__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        self.residual = residual
        self.n_classes = n_classes

        # channel count of every level, from the first to the deepest one
        widths = [2 ** (wf + level) for level in range(depth)]
        prev_channels = in_channels

        self.down_path = nn.ModuleList()
        for width in widths:
            self.down_path.append(
                UNetConvBlock(prev_channels, width, padding, batch_norm, drop_p)
            )
            prev_channels = width

        # the up path walks back through every level except the deepest one
        self.up_path = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.up_path.append(
                UNetUpBlock(prev_channels, width, up_mode, padding, batch_norm)
            )
            prev_channels = width

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x, kwargs=None):
        """Run the network on `x`; `kwargs` is accepted but not used."""
        x_in = x.clone()

        skips = []
        deepest = len(self.down_path) - 1
        for level, down in enumerate(self.down_path):
            x = down(x)
            if level == deepest:
                continue
            # keep the features for the skip connection, then halve the resolution
            skips.append(x)
            x = F.max_pool2d(x, 2)

        # the most recently stored features are consumed first
        for up, bridge in zip(self.up_path, reversed(skips)):
            x = up(x, bridge)

        output = self.last(x)
        if self.residual:
            output = output + x_in[:, -self.n_classes:, ...]

        return output


class UNetConvBlock(nn.Module):
    """Two 3x3 2-D convolutions, each followed by ReLU.

    With `batch_norm`, a BatchNorm2d follows each ReLU. When `drop_p` is above
    1e-5, a Dropout2d is inserted after the second ReLU, before its BatchNorm2d.
    `padding` is converted with `int()` and used as the padding of both convolutions.
    """

    def __init__(self, in_size, out_size, padding, batch_norm, drop_p=0.15):
        super().__init__()
        pad = int(padding)

        layers = [nn.Conv2d(in_size, out_size, kernel_size=3, padding=pad), nn.ReLU()]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_size))

        layers += [nn.Conv2d(out_size, out_size, kernel_size=3, padding=pad), nn.ReLU()]
        if drop_p > 1e-5:
            layers.append(nn.Dropout2d(p=drop_p))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class UNetUpBlock(nn.Module):
    """Up-sampling step of the U-Net: upsample, join with the skip features, convolve.

    Args:
        in_size (int): channels of the incoming (lower-resolution) tensor
        out_size (int): channels after upsampling and after the block
        up_mode (str): 'upconv' for a stride-2 transposed convolution,
                       'upsample' for bilinear upsampling followed by a 1x1 convolution
        padding, batch_norm: forwarded to `UNetConvBlock`

    Raises:
        ValueError: if `up_mode` is neither 'upconv' nor 'upsample'
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        else:
            # fail here rather than in forward(), where self.up would be missing
            raise ValueError(f"up_mode must be 'upconv' or 'upsample', got {up_mode!r}")

        # in_size is passed as the input width of the block applied to the concatenation
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Cut the central `target_size` (height, width) window out of a 4-D `layer`."""
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        """Upsample `x`, concatenate the centre-cropped `bridge` along dim 1, convolve."""
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        # `torch` itself is never imported in this file; only `torch_cat` is bound
        out = torch_cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

## ResNet 1D

In [17]:
#| export

class ResidualBlock(nn.Module):
    """A general-purpose residual block. Works only with 1-dim inputs.

    The block applies (optional BatchNorm1d) -> activation -> Linear twice, with
    optional dropout before the second Linear, and adds the result to its input.
    When a context is passed to `forward`, the result is first gated with
    `F.glu` using `context_layer(context)`.

    Parameters:
    features - (int) size of the input and output feature dimension.
    context_features - (int or None) size of the context; when None no
                       `context_layer` is created.
    activation - callable applied before each Linear layer.
    dropout_probability - (float) dropout is used only when this is > 0.
    use_batch_norm - (bool) apply BatchNorm1d before each activation.
    zero_initialization - (bool) draw the weight and bias of the second Linear
                          layer uniformly from [-1e-3, 1e-3].
    """

    def __init__(self,
                 features,
                 context_features,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False,
                 zero_initialization=True):
        super().__init__()
        self.activation = activation

        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList([
                #nn.BatchNorm1d(features, eps=1e-3, track_running_stats=False)
                nn.BatchNorm1d(features, eps=1e-3)
                for _ in range(2)
            ])
        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)
        self.linear_layers = nn.ModuleList([
            nn.Linear(features, features)
            for _ in range(2)
        ])
        if dropout_probability > 0.:
            self.dropout = nn.Dropout(p=dropout_probability)
        else:
            self.dropout = None
        if zero_initialization:
            init.uniform_(self.linear_layers[-1].weight, -1e-3, 1e-3)
            init.uniform_(self.linear_layers[-1].bias, -1e-3, 1e-3)

    def forward(self, inputs, context=None):
        """Return `inputs` plus the transformed (and, with `context`, gated) features."""
        temps = inputs
        if self.use_batch_norm:
            temps = self.batch_norm_layers[0](temps)
        temps = self.activation(temps)
        temps = self.linear_layers[0](temps)
        if self.use_batch_norm:
            temps = self.batch_norm_layers[1](temps)
        temps = self.activation(temps)
        if self.dropout is not None:
            temps = self.dropout(temps)
        temps = self.linear_layers[1](temps)
        if context is not None:
            # `torch` itself is never imported in this file; only `torch_cat` is bound.
            # glu splits dim 1 in half: the first half is gated by sigmoid(second half).
            temps = F.glu(
                torch_cat(
                    (temps, self.context_layer(context)),
                    dim=1
                ),
                dim=1
            )
        return inputs + temps


In [18]:
#| export

@regist_network
class ResidualNet(nn.Module):
    """A general-purpose residual network. Works only with 1-dim inputs.

    The network is a Linear input layer, `num_blocks` `ResidualBlock`s and a
    Linear output layer. When `context_features` is given, the input layer
    expects the inputs and the context concatenated along dim 1.

    Parameters:
    in_features - (int) size of the input feature dimension.
    out_features - (int) size of the output feature dimension.
    hidden_features - (int) width of the residual blocks.
    context_features - (int or None) size of the optional context.
    num_blocks - (int) number of residual blocks.
    activation, dropout_probability, use_batch_norm - forwarded to each
        `ResidualBlock`.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 context_features=None,
                 num_blocks=2,
                 activation=F.relu,
                 dropout_probability=0.,
                 use_batch_norm=False):
        super().__init__()
        self.hidden_features = hidden_features
        self.context_features = context_features
        if context_features is not None:
            self.initial_layer = nn.Linear(in_features + context_features, hidden_features)
        else:
            self.initial_layer = nn.Linear(in_features, hidden_features)
        self.blocks = nn.ModuleList([
            ResidualBlock(
                features=hidden_features,
                context_features=context_features,
                activation=activation,
                dropout_probability=dropout_probability,
                use_batch_norm=use_batch_norm,
            ) for _ in range(num_blocks)
        ])
        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, inputs, context=None):
        """Map `inputs` (optionally together with `context`) to the output features."""
        if context is None:
            temps = self.initial_layer(inputs)
        else:
            # `torch` itself is never imported in this file; only `torch_cat` is bound
            temps = self.initial_layer(
                torch_cat((inputs, context), dim=1)
            )
        for block in self.blocks:
            temps = block(temps, context=context)
        outputs = self.final_layer(temps)
        return outputs

In [19]:
#| hide
import nbdev; nbdev.nbdev_export()