# Test notebook: MaxBlurPool3D and the N2V2 Noise2NoiseUNet3D

In [1]:
# For the matplotlib 
%matplotlib inline
# For reload functions explicitly
%load_ext autoreload
%autoreload 2

## Add the modules to the system path
import os
import sys

# Make the parent directory importable (presumably where `network.py` lives). Resolve it to an
# absolute path so the entry stays valid after a later os.chdir, and only append it once so
# re-running this cell does not keep growing sys.path with duplicates.
MODULE_DIR = os.path.abspath("..")
if MODULE_DIR not in sys.path:
    sys.path.append(MODULE_DIR)

## Test MaxBlurPool

In [2]:
# Imports
import torch

In [6]:
# Pooling window shared by the plain and the max-blur pooling experiments below.
pool_kernel_size = (2, 2, 2)
# All-ones dummy batch in (N, C, D, H, W) layout: 8*16*16*256*256 float32 values = 512 MiB.
test_sample = torch.ones(8, 16, 16, 256, 256)

In [9]:

# Baseline: ordinary max pooling (stride defaults to the kernel size).
normal_pooling = torch.nn.MaxPool3d(pool_kernel_size)
test_sample_maxpool = normal_pooling(test_sample)

In [10]:
print(test_sample.size())
print(test_sample_maxpool.size())

# MaxPool3d with stride == kernel size keeps N and C and floor-divides each spatial axis by its kernel size.
expected_spatial = tuple(size // k for size, k in zip(test_sample.shape[2:], pool_kernel_size))
assert test_sample_maxpool.shape[:2] == test_sample.shape[:2], test_sample_maxpool.shape
assert tuple(test_sample_maxpool.shape[2:]) == expected_spatial, test_sample_maxpool.shape

torch.Size([8, 16, 16, 256, 256])
torch.Size([8, 16, 8, 128, 128])


In [25]:
# Generate MaxBlurPool3D for the N2V2 approach
class MaxBlurPool3D(torch.nn.Module):
    """
    Anti-aliased ("max-blur") 3D pooling as used by the N2V2 setup.

    A dense max pool (stride 1) is followed by a fixed blur that is applied to every channel
    independently (depthwise convolution) and subsamples with stride `pool_kernel_size`.
    For the default (2, 2, 2) pooling, every spatial size n >= 2 becomes n // 2, i.e. the
    same output shape as `nn.MaxPool3d(2)`.

    NOTE(review): the same module is defined again as `MaxBlurPool3d` in the network cell
    further down; keep the two copies in sync.

    Args:
        pool_kernel_size (int or tuple of 3 ints): max-pool window; also the stride of the
            blur/subsampling step. An int k is treated as (k, k, k).
        blur_kernel (tensor-like, optional): 3D kernel applied to each channel. Defaults to a
            normalized 3x3x3 smoothing kernel; a user-supplied kernel is used as given.

    Shape:
        input (N, C, D, H, W) -> output (N, C, D', H', W'); the channel count is unchanged.
    """

    def __init__(self, pool_kernel_size, blur_kernel=None):
        super(MaxBlurPool3D, self).__init__()
        if isinstance(pool_kernel_size, int):
            pool_kernel_size = (pool_kernel_size,) * 3
        self.pool_kernel_size = tuple(pool_kernel_size)

        if blur_kernel is None:
            blur_kernel = torch.tensor(
                [[[0.02595968, 0.03575371, 0.02595968],
                [0.03575371, 0.04924282, 0.03575371],
                [0.02595968, 0.03575371, 0.02595968]],

                [[0.03575371, 0.04924282, 0.03575371],
                [0.04924282, 0.06782107, 0.04924282],
                [0.03575371, 0.04924282, 0.03575371]],

                [[0.02595968, 0.03575371, 0.02595968],
                [0.03575371, 0.04924282, 0.03575371],
                [0.02595968, 0.03575371, 0.02595968]]],
                dtype=torch.float32,
            )
            blur_kernel = blur_kernel / blur_kernel.sum()
        else:
            blur_kernel = torch.as_tensor(blur_kernel, dtype=torch.float32)
        # A buffer follows module.to()/.cuda() instead of being pinned to one device at
        # construction; persistent=False keeps it out of state_dict (checkpoints unchanged).
        self.register_buffer('blur_kernel', blur_kernel, persistent=False)

    def forward(self, x):
        channels = x.size(1)
        # Dense max pooling at stride 1 with k // 2 padding per axis (1 for the default k = 2).
        x = torch.nn.functional.max_pool3d(x,
                                           self.pool_kernel_size,
                                           stride=1,
                                           padding=tuple(k // 2 for k in self.pool_kernel_size))
        # Depthwise blur: weight (C, 1, kD, kH, kW) with groups=C blurs each channel on its own;
        # a (C, C, ...) weight without groups would sum all channels into every output channel.
        # Built per call so it matches x's device, dtype and channel count.
        weight = self.blur_kernel.to(device=x.device, dtype=x.dtype)
        weight = weight[None, None].repeat(channels, 1, 1, 1, 1)
        return torch.nn.functional.conv3d(x, weight=weight, stride=self.pool_kernel_size, groups=channels)

In [27]:
# Same dummy batch through the N2V2 anti-aliased pooling, for comparison with MaxPool3d.
blur_maxpool = MaxBlurPool3D(pool_kernel_size=pool_kernel_size)
test_sample_blur = blur_maxpool(test_sample)

In [28]:
print(test_sample.size())
print(test_sample_blur.size())

# Max-blur pooling is meant as a drop-in for MaxPool3d, so it must give the same output shape.
expected_spatial = tuple(size // k for size, k in zip(test_sample.shape[2:], pool_kernel_size))
assert test_sample_blur.shape[:2] == test_sample.shape[:2], test_sample_blur.shape
assert tuple(test_sample_blur.shape[2:]) == expected_spatial, test_sample_blur.shape

torch.Size([8, 16, 16, 256, 256])
torch.Size([8, 16, 8, 128, 128])


## Test network

In [2]:
from network import Noise2NoiseUNet3D

In [11]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import torch

In [12]:
## Submodules

class SingleConv(nn.Sequential):
    """
    One convolutional unit: a Conv3d plus a non-linearity and optional batchnorm/groupnorm,
    stacked in the sequence given by `order` (the layers are produced by `create_conv`).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        order (string): one character per layer, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int): zero-padding passed to the Conv3d
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='cr', num_groups=8, padding=1):
        super(SingleConv, self).__init__()
        layers = create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding)
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)


class DoubleConv(nn.Sequential):
    """
    Two `SingleConv` layers applied back to back ('SingleConv1' then 'SingleConv2').

    The layer sequence inside each SingleConv comes from `order`; the default 'cr' is
    Conv3d + ReLU. Pass e.g. order='cbe' for Conv3d + BatchNorm3d + ELU, or 'clg' for
    Conv3d + LeakyReLU + GroupNorm. With the default kernel_size=3, SingleConv's padding of 1
    keeps (D, H, W) unchanged, so the decoder path needs no cropping.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): True in the encoder path (the first conv goes to
            max(out_channels // 2, in_channels) channels); False in the decoder path
            (the first conv already reduces to out_channels)
        kernel_size (int): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='cr', num_groups=8):
        super(DoubleConv, self).__init__()
        if encoder:
            # encoder: widen gradually (half the target width), but never shrink below the input width
            mid_channels = max(out_channels // 2, in_channels)
        else:
            # decoder: the first conv already reduces the concatenated input to the target width
            mid_channels = out_channels

        self.add_module('SingleConv1',
                        SingleConv(in_channels, mid_channels, kernel_size, order, num_groups))
        self.add_module('SingleConv2',
                        SingleConv(mid_channels, out_channels, kernel_size, order, num_groups))


def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
    """Build an `nn.Conv3d` with the given zero-padding and an optional learnable bias."""
    layer = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      padding=padding, bias=bias)
    return layer

# Generate MaxBlurPool3D for the N2V2 approach
class MaxBlurPool3d(torch.nn.Module):
    """
    Anti-aliased ("max-blur") 3D pooling as used by the N2V2 setup (Encoder pool_type='maxblur').

    A dense max pool (stride 1) is followed by a fixed blur that is applied to every channel
    independently (depthwise convolution) and subsamples with stride `pool_kernel_size`.
    For the default (2, 2, 2) pooling, every spatial size n >= 2 becomes n // 2, i.e. the
    same output shape as `nn.MaxPool3d(2)`.

    NOTE(review): the test cell near the top of the notebook defines the same module as
    `MaxBlurPool3D`; keep the two copies in sync.

    Args:
        pool_kernel_size (int or tuple of 3 ints): max-pool window; also the stride of the
            blur/subsampling step. An int k is treated as (k, k, k).
        blur_kernel (tensor-like, optional): 3D kernel applied to each channel. Defaults to a
            normalized 3x3x3 smoothing kernel; a user-supplied kernel is used as given.

    Shape:
        input (N, C, D, H, W) -> output (N, C, D', H', W'); the channel count is unchanged.
    """

    def __init__(self, pool_kernel_size, blur_kernel=None):
        super(MaxBlurPool3d, self).__init__()
        if isinstance(pool_kernel_size, int):
            pool_kernel_size = (pool_kernel_size,) * 3
        self.pool_kernel_size = tuple(pool_kernel_size)

        if blur_kernel is None:
            blur_kernel = torch.tensor(
                [[[0.02595968, 0.03575371, 0.02595968],
                [0.03575371, 0.04924282, 0.03575371],
                [0.02595968, 0.03575371, 0.02595968]],

                [[0.03575371, 0.04924282, 0.03575371],
                [0.04924282, 0.06782107, 0.04924282],
                [0.03575371, 0.04924282, 0.03575371]],

                [[0.02595968, 0.03575371, 0.02595968],
                [0.03575371, 0.04924282, 0.03575371],
                [0.02595968, 0.03575371, 0.02595968]]],
                dtype=torch.float32,
            )
            blur_kernel = blur_kernel / blur_kernel.sum()
        else:
            blur_kernel = torch.as_tensor(blur_kernel, dtype=torch.float32)
        # A buffer follows module.to()/.cuda() instead of being pinned to one device at
        # construction; persistent=False keeps it out of state_dict (checkpoints unchanged).
        self.register_buffer('blur_kernel', blur_kernel, persistent=False)

    def forward(self, x):
        channels = x.size(1)
        # Dense max pooling at stride 1 with k // 2 padding per axis (1 for the default k = 2).
        x = torch.nn.functional.max_pool3d(x,
                                           self.pool_kernel_size,
                                           stride=1,
                                           padding=tuple(k // 2 for k in self.pool_kernel_size))
        # Depthwise blur: weight (C, 1, kD, kH, kW) with groups=C blurs each channel on its own;
        # a (C, C, ...) weight without groups would sum all channels into every output channel.
        # Built per call so it matches x's device, dtype and channel count.
        weight = self.blur_kernel.to(device=x.device, dtype=x.dtype)
        weight = weight[None, None].repeat(channels, 1, 1, 1, 1)
        return torch.nn.functional.conv3d(x, weight=weight, stride=self.pool_kernel_size, groups=channels)
       
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
    """
    Build the (name, module) pairs which together form one conv layer with a non-linearity
    and optional batchnorm/groupnorm, in the sequence given by `order`.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        order (string): one character per layer, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
            'b' adds a BatchNorm3d (allowed before or after the conv)
        num_groups (int): number of groups for the GroupNorm; capped at `out_channels`
        padding (int): add zero-padding to the input of the conv

    Return:
        list of tuple (name, module)

    Raises:
        AssertionError: if `order` has no 'c', starts with a non-linearity, or puts 'g' before 'c'
        ValueError: for any character other than 'b', 'g', 'r', 'l', 'e', 'c'
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    # A learnable conv bias is redundant when a batchnorm/groupnorm layer is present.
    conv_bias = 'g' not in order and 'b' not in order

    layers = []
    for position, layer_code in enumerate(order):
        if layer_code == 'c':
            layers.append(('conv', conv3d(in_channels, out_channels, kernel_size, conv_bias, padding=padding)))
        elif layer_code == 'r':
            layers.append(('ReLU', nn.ReLU(inplace=True)))
        elif layer_code == 'l':
            layers.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
        elif layer_code == 'e':
            layers.append(('ELU', nn.ELU(inplace=True)))
        elif layer_code == 'g':
            conv_comes_later = position < order.index('c')
            assert not conv_comes_later, 'GroupNorm MUST go after the Conv3d'
            # GroupNorm cannot have more groups than channels
            groups = min(num_groups, out_channels)
            layers.append(('groupnorm', nn.GroupNorm(num_groups=groups, num_channels=out_channels)))
        elif layer_code == 'b':
            # before the conv it normalizes the input channels, after it the output channels
            norm_channels = in_channels if position < order.index('c') else out_channels
            layers.append(('batchnorm', nn.BatchNorm3d(norm_channels)))
        else:
            raise ValueError(f"Unsupported layer type '{layer_code}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")

    return layers

## Encoder and Decoder Networks

class Encoder(nn.Module):
    """
    One level of the encoder path: an optional pooling layer followed by `basic_module`.

    The pooling window may differ from the standard (2, 2, 2), e.g. for anisotropic volumes;
    the matching decoder must then use a complementary scale_factor.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int): size of the convolving kernel
        apply_pooling (bool): if True, pool (type chosen by `pool_type`) before `basic_module`
        pool_kernel_size (tuple): the size of the pooling window
        pool_type (str): pooling layer: 'max' (MaxPool3d), 'maxblur' (MaxBlurPool3d) or 'avg' (AvgPool3d)
        basic_module(nn.Module): module class built with encoder=True, e.g. DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='cr',
                 num_groups=8):
        super(Encoder, self).__init__()
        assert pool_type in ['max', 'maxblur', 'avg']

        if not apply_pooling:
            self.pooling = None
        elif pool_type == 'max':
            self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
        elif pool_type == 'maxblur':
            self.pooling = MaxBlurPool3d(pool_kernel_size=pool_kernel_size)
        else:
            # 'avg'
            self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)

        self.basic_module = basic_module(
            in_channels,
            out_channels,
            encoder=True,
            kernel_size=conv_kernel_size,
            order=conv_layer_order,
            num_groups=num_groups,
        )

    def forward(self, x):
        pooled = x if self.pooling is None else self.pooling(x)
        return self.basic_module(pooled)


class Decoder(nn.Module):
    """
    A single module for decoder path consisting of the upsample layer
    (either learned ConvTranspose3d or nearest-neighbour interpolation) followed by
    `basic_module` (DoubleConv by default).

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        scale_factor (tuple): used as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
            from the corresponding encoder
        basic_module(nn.Module): module class built with encoder=False; DoubleConv selects
            interpolation + concatenation joining, anything else ConvTranspose3d + summation joining
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm

    Forward args:
        encoder_features: either the skip-connection tensor from the matching encoder level, or
            (N2V2 setup, top level) a spatial size tuple such as `torch.Size([D, H, W])`; in the
            latter case `x` is only upsampled and no skip features are joined.
        x: output of the previous (deeper) decoder level.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='cr', num_groups=8):
        super(Decoder, self).__init__()

        if basic_module == DoubleConv:
            # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling
            self.upsample = None
        else:
            # otherwise use ConvTranspose3d (bear in mind your GPU memory)
            # make sure that the output size reverses the MaxPool3d from the corresponding encoder
            # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0])
            # also scale the number of channels from in_channels to out_channels so that summation joining
            # works correctly
            self.upsample = nn.ConvTranspose3d(in_channels,
                                               out_channels,
                                               kernel_size=kernel_size,
                                               stride=scale_factor,
                                               padding=1,
                                               output_padding=1)
            # adapt the number of in_channels for the ExtResNetBlock
            in_channels = out_channels

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups)

    def forward(self, encoder_features, x):
        # N2V2 setup: the top level has no skip connection, so the caller passes the target
        # spatial size (torch.Size is a tuple subclass) instead of a feature tensor.
        if isinstance(encoder_features, tuple):
            output_size, skip_features = encoder_features, None
        else:
            output_size, skip_features = encoder_features.size()[2:], encoder_features

        if self.upsample is None:
            # use nearest neighbor interpolation
            x = F.interpolate(x, size=output_size, mode='nearest')
            if skip_features is not None:
                # concatenate encoder features with the upsampled input across channel dimension
                x = torch.cat((skip_features, x), dim=1)
        else:
            # use ConvTranspose3d and summation joining (nothing to add without skip features)
            x = self.upsample(x)
            if skip_features is not None:
                x = x + skip_features

        return self.basic_module(x)


## Initialization function
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)   -- network to be initialized; `init_func` is applied to every submodule via `net.apply`
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal (not used by kaiming).

    Modules whose class name contains 'Conv' or 'Linear' and that have a weight get that weight
    re-initialized and their bias (if any) zeroed. BatchNorm layers of any dimensionality
    (BatchNorm1d/2d/3d) get weight ~ N(1, init_gain) and bias 0; non-affine BatchNorm is skipped.
    All other modules (e.g. GroupNorm) keep their current parameters.

    Raises:
        NotImplementedError: if `init_type` is unknown (raised when the first Conv/Linear layer is visited).

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias, 0.0)
        elif 'BatchNorm' in classname:
            # Match every BatchNorm variant: this 3D network builds BatchNorm3d, not BatchNorm2d.
            # BatchNorm layer's weight is not a matrix; only normal distribution applies.
            if weight is not None:
                init.normal_(weight, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>

In [26]:
class Noise2NoiseUNet3D(nn.Module):
    """
    3D U-Net for Noise2Noise-style denoising, optionally in the N2V2 configuration.

    Based on https://arxiv.org/pdf/1706.00120.pdf and https://www.nature.com/articles/s41592-021-01225-0.
    Every encoder and decoder level is a `DoubleConv` (conv + LeakyReLU + GroupNorm, twice);
    decoder levels upsample by nearest-neighbour interpolation and join the matching encoder
    features by channel-wise concatenation. A 1x1x1 conv + ReLU produces the output.

    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels of the final 1x1x1 conv
        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
            of feature maps is given by the geometric progression f_maps * 2 ** k, k = 0..4 (5 levels)
        num_groups (int): number of groups for the GroupNorm
        is_N2V2_setup (bool): Flag if using the N2V2 setup model (source: https://openreview.net/forum?id=IZfQYb4lHVq),
            which drops the top-level skip connection and uses max-blur pooling in the encoder
        **kwargs: accepted for call-site compatibility and ignored
    """

    def __init__(self, in_channels, out_channels, f_maps=16, num_groups=8, is_N2V2_setup=False, **kwargs):
        super(Noise2NoiseUNet3D, self).__init__()

        # Set flag if N2V2 or not
        self.is_N2V2_setup = is_N2V2_setup
        # conv + LeakyReLU + GroupNorm everywhere except the final 1x1x1 conv
        conv_layer_order = 'clg'

        if isinstance(f_maps, int):
            # use 5 levels in the encoder path as suggested in the paper
            f_maps = self.__create_feature_maps(f_maps, number_of_fmaps=5)

        # N2V2 replaces plain max pooling with max-blur pooling in every downsampling step
        pool_type = 'maxblur' if self.is_N2V2_setup else 'max'

        # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
        encoders = []
        for i, out_feature_num in enumerate(f_maps):
            if i == 0:
                encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv,
                                  conv_layer_order=conv_layer_order, num_groups=num_groups)
            else:
                encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv,
                                  conv_layer_order=conv_layer_order, num_groups=num_groups, pool_type=pool_type)
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
        decoders = []
        reversed_f_maps = list(reversed(f_maps))
        last_decoder_index = len(reversed_f_maps) - 2
        for i in range(len(reversed_f_maps) - 1):
            out_feature_num = reversed_f_maps[i + 1]
            if self.is_N2V2_setup and i == last_decoder_index:
                # N2V2: no top-level skip connection, so nothing is concatenated to the upsampled input
                in_feature_num = reversed_f_maps[i]
            else:
                # upsampled features and skip features are concatenated along the channel dimension
                in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
            decoders.append(Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv,
                                    conv_layer_order=conv_layer_order, num_groups=num_groups))
        self.decoders = nn.ModuleList(decoders)

        # 1x1x1 conv + simple ReLU in the final convolution.
        # NOTE(review): the ReLU makes every output >= 0; confirm the (normalized) targets are non-negative.
        self.final_conv = SingleConv(f_maps[0], out_channels, kernel_size=1, order='cr', padding=0)

    def forward(self, x):
        # encoder part; outputs are inserted at the front so they line up with the decoders
        encoders_features = []
        for level, encoder in enumerate(self.encoders):
            x = encoder(x)
            if self.is_N2V2_setup and level == 0:
                # N2V2: keep only the top-level spatial size, which the last Decoder interpolates to
                encoders_features.insert(0, x.size()[2:])
            else:
                encoders_features.insert(0, x)

        # remove the last encoder's output from the list
        # !!remember: it's the 1st in the list
        encoders_features = encoders_features[1:]

        # decoder part: each decoder gets the matching encoder output and the previous decoder's output
        for decoder, encoder_features in zip(self.decoders, encoders_features):
            x = decoder(encoder_features, x)

        return self.final_conv(x)

    def __create_feature_maps(self, init_channel_number, number_of_fmaps):
        return [init_channel_number * 2 ** k for k in range(number_of_fmaps)]

In [3]:
# Seed before construction so the randomly initialised weights are reproducible across re-runs.
torch.manual_seed(0)
model = Noise2NoiseUNet3D(in_channels=1, out_channels=1, is_N2V2_setup=True)

In [6]:
# Seed so the random input batch is reproducible across re-runs.
torch.manual_seed(0)
# (N, C, D, H, W) batch; D, H, W are divisible by 2**4 so all four poolings of the 5-level U-Net divide evenly.
test_batch = torch.rand((8, 1, 16, 64, 64))
# Forward-only smoke test: no_grad avoids keeping activations around for backprop.
with torch.no_grad():
    test_output = model(test_batch)

In [7]:
print(test_batch.size())
print(test_output.size())

# A denoiser with in_channels == out_channels must reproduce the input shape exactly.
assert test_output.shape == test_batch.shape, (test_output.shape, test_batch.shape)

torch.Size([8, 1, 16, 64, 64])
torch.Size([8, 1, 16, 64, 64])
