In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import csv
import json
import os
import os.path
import platform
import random
import uuid
from types import SimpleNamespace

import skimage
import skimage.segmentation
import torch
import numpy as np
import scipy as sp
import scipy.signal
import pandas as pd
import networkx
import networkx.algorithms.approximation
import wfdb
import tqdm
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.stats import lognorm, norm, halfnorm

import utils
import utils.wavelet
import utils.data
import utils.data.augmentation
from utils.signal import StandardHeader


# Define network

In [5]:
class MobileUNet(torch.nn.Module):
    '''Configurable 1-D U-Net.

    Each encoder/decoder level is a stack of ``m_repetitions`` blocks whose
    type is selected by ``m_name`` through ``StemModule``, ``LevelModule``,
    ``AtrousMiddleModule`` and ``OutputModule``. Encoder level 0 additionally
    starts with a stem block, and decoder level 0 ends with the output block.
    Levels are linked by 1-D max/average pooling (encoder side) and
    ``torch.nn.Upsample`` (decoder side).

    Optional variants, all read from the config:

    * ``hyperdense``    -- every block receives the channel-wise concatenation
      of everything produced so far at its level.
    * ``atrous_conv``   -- the last block of the deepest (bottleneck) level is
      an ``AtrousMiddleModule`` block.
    * ``ms_upsampling`` -- deeper feature maps are also upsampled straight to
      shallower decoder levels and concatenated there.
    '''

    def __init__(self, config):
        '''Build every encoder/decoder block from ``config``.

        Args:
            config: object exposing the attributes ``m_name``,
                ``m_repetitions``, ``out_ch``, ``start_ch``, ``kernel_size``,
                ``depth``, ``inc_rate``, ``maxpool``, ``kernel_init``,
                ``hyperdense``, ``atrous_conv``, ``ms_upsampling`` and
                ``regression``.
        '''
        super(MobileUNet, self).__init__()

        # Parameters
        self.m_name         = config.m_name
        self.m_repetitions  = config.m_repetitions
        self.out_ch         = config.out_ch
        self.start_ch       = config.start_ch
        self.kernel_size    = config.kernel_size
        self.depth          = config.depth
        self.inc_rate       = config.inc_rate
        self.maxpool        = config.maxpool
        self.kernel_init    = config.kernel_init
        # Architecture switches. They are used by the channel bookkeeping
        # below, by forward() and by Export(). Previously they were never
        # assigned, so construction always failed with AttributeError.
        self.hyperdense     = config.hyperdense
        self.atrous_conv    = config.atrous_conv
        self.ms_upsampling  = config.ms_upsampling
        self.regression     = config.regression

        # Storing the architecture
        self.encoder_levels = torch.nn.ModuleList([torch.nn.ModuleList() for _ in range(self.depth)])
        self.decoder_levels = torch.nn.ModuleList([torch.nn.ModuleList() for _ in range(self.depth-1)])  # The encoder "contains" the bottleneck level
        self.encoder_transitions = torch.nn.ModuleList([torch.nn.ModuleList() for _ in range(self.depth)])
        self.decoder_transitions = torch.nn.ModuleList([torch.nn.ModuleList() for _ in range(self.depth)])

        # Easing the notation (block factories selected by module name)
        StemBlock = StemModule(self.m_name)
        ConvBlock = LevelModule(self.m_name)
        AtrousBlock = AtrousMiddleModule(self.m_name)
        UpsamplingBlock = torch.nn.modules.Upsample
        MaxAvgPool1d = torch.nn.MaxPool1d if self.maxpool else torch.nn.AvgPool1d
        OutputBlock = OutputModule(self.m_name, self.regression)

        ################### ARCHITECTURE DEFINITION ###################
        # Up and downsampling blocks.
        #   encoder_transitions[i] = [pool(inc_rate**(depth-i-1)), ..., pool(inc_rate)]
        #   decoder_transitions[i] = [up(inc_rate**i), ..., up(inc_rate)]
        # so ``[-1]`` is the single-step transition to the adjacent level, and
        # decoder_transitions[i][j] scales by inc_rate**(i-j) (level i -> level j).
        for i in range(self.depth):
            for j in range(i + 1, self.depth):
                self.encoder_transitions[i].append(MaxAvgPool1d(int(self.inc_rate**(self.depth-j))))
        for i in range(self.depth):
            for j in reversed(range(i)):
                self.decoder_transitions[i].append(UpsamplingBlock(scale_factor=int(self.inc_rate**(j+1))))

        # Channel contribution of different additions
        skipped_con_ch = [int((self.inc_rate**i)*self.start_ch) for i in range(self.depth-1)]
        dense_accum_ch = [0 for _ in range(self.depth)]
        mulscale_up_ch = [0 for _ in range(self.depth)]
        decoder_1st_ch = [0 for _ in range(self.depth-1)]

        #### STEM ####
        # Apply stem (first argument 1 -- presumably the number of input channels)
        self.encoder_levels[0].append(StemBlock(1, int(self.start_ch//self.inc_rate), self.kernel_size, self.kernel_init))

        #### ENCODER ####
        for i in range(self.depth):
            # Output Channels for this level
            out_channels = int((self.inc_rate**i)*self.start_ch)

            for j in range(self.m_repetitions):
                # Input channels for this block (the first block of a level
                # receives the previous level's width)
                in_channels = int((self.inc_rate**(i - (j == 0)))*self.start_ch)
                in_channels += int(self.hyperdense*dense_accum_ch[i])

                if (i != self.depth-1):  # Encoder levels
                    self.encoder_levels[i].append(ConvBlock(in_channels, out_channels, self.kernel_size, self.kernel_init))
                else:  # Embedding level (deepmost)
                    if (j == self.m_repetitions - 1) and self.atrous_conv:
                        self.encoder_levels[i].append(AtrousBlock(in_channels, out_channels, self.kernel_size, self.kernel_init))
                    else:
                        self.encoder_levels[i].append(ConvBlock(in_channels, out_channels, self.kernel_size, self.kernel_init))

                # Update dense connection accumulator
                dense_accum_ch[i] = in_channels

            # Update dense connection accumulator (the channel accounting
            # treats the atrous bottleneck output as 5x out_channels)
            dense_accum_ch[i] += int(out_channels * (1 + 4*(i == self.depth-1)*self.atrous_conv))

        #### DECODER ####
        # Channels leaving the bottleneck. ``i`` and ``out_channels`` still
        # hold their values from the last encoder iteration.
        bottleneck_ch = int(out_channels * (1 + 4*(i == self.depth-1)*self.atrous_conv))
        decoder_1st_ch[-1] = bottleneck_ch

        # Multiscale upsampling: the bottleneck feeds every level two or more steps up
        for j in range(self.depth-3, -1, -1):
            mulscale_up_ch[j] += bottleneck_ch

        # Decoder levels
        for i in reversed(range(self.depth-1)):
            # Output Channels for this level
            out_channels = int((self.inc_rate**i)*self.start_ch)

            for j in range(self.m_repetitions):
                # Input channels for this block: only the first block of a
                # level receives the upsampled input, the skip connection and
                # the multiscale contributions.
                in_channels  = int((j == 0)*decoder_1st_ch[i] + (j != 0)*(self.inc_rate**i)*self.start_ch)
                in_channels += int((j == 0)*skipped_con_ch[i])
                in_channels += int((j == 0)*self.ms_upsampling*mulscale_up_ch[i])
                in_channels += int(self.hyperdense*(dense_accum_ch[i] - (j == 0)*skipped_con_ch[i]))

                self.decoder_levels[i].append(ConvBlock(in_channels, out_channels, self.kernel_size, self.kernel_init))

                # Update dense connection accumulator
                dense_accum_ch[i] = in_channels

            if i != 0:
                decoder_1st_ch[i-1] = out_channels

            # Update dense connection accumulator
            dense_accum_ch[i] += out_channels

            # Multiscale upsampling
            for j in range(i-2, -1, -1):
                mulscale_up_ch[j] += out_channels

        # NOTE(review): the output block is built with 3 channels, not
        # ``self.out_ch``; confirm this is intended before changing it.
        in_channels = int((1-self.hyperdense)*self.start_ch + self.hyperdense*dense_accum_ch[0])
        self.decoder_levels[0].append(OutputBlock(in_channels, 3, kernel_size=3, regression=self.regression, kernel_init='xavier_uniform'))

    def forward(self, x):
        '''Run the network on ``x``.

        Returns the output of the final block of decoder level 0 (the output
        block). The hyperdense and plain paths are kept separate: only the
        hyperdense path needs to keep every intermediate tensor alive.
        '''
        # Multiscale contributions waiting to be concatenated at each decoder level
        upsampl_path = [[] for _ in range(self.depth-1)]

        if self.hyperdense:
            encoder_path = [[] for _ in range(self.depth)]
            decoder_path = [[] for _ in range(self.depth-1)]

            #### ENCODER ####
            for i in range(len(self.encoder_levels)):
                for j in range(len(self.encoder_levels[i])):
                    if (i == 0) and (j == 0):  # Apply the stem
                        encoder_path[i].append(self.encoder_levels[i][j](x))
                    else:  # Dense: feed everything produced so far at this level
                        encoder_path[i].append(self.encoder_levels[i][j](torch.cat(encoder_path[i], 1)))

                if (i != len(self.encoder_levels) - 1):
                    encoder_path[i+1].append(self.encoder_transitions[i][-1](encoder_path[i][-1]))
                else:
                    decoder_path[i-1].append(self.decoder_transitions[i][-1](encoder_path[i][-1]))
                    if self.ms_upsampling:
                        for j in range(i-2, -1, -1):
                            upsampl_path[j].append(self.decoder_transitions[i][j](encoder_path[i][-1]))

            #### DECODER ####
            for i in reversed(range(len(self.decoder_levels))):
                for j in range(len(self.decoder_levels[i])):
                    inputs = encoder_path[i] + decoder_path[i]
                    if self.ms_upsampling:
                        inputs = inputs + upsampl_path[i]
                    decoder_path[i].append(self.decoder_levels[i][j](torch.cat(inputs, 1)))

                if (i != 0):
                    decoder_path[i-1].append(self.decoder_transitions[i][-1](decoder_path[i][-1]))

                # Multiscale Upsampling
                if self.ms_upsampling:
                    for j in range(i-2, -1, -1):
                        upsampl_path[j].append(self.decoder_transitions[i][j](decoder_path[i][-1]))

            return decoder_path[0][-1]

        # Plain (non-hyperdense) U-Net
        skipped_path = []

        #### ENCODER ####
        for i in range(len(self.encoder_levels)):
            for block in self.encoder_levels[i]:
                x = block(x)

            skipped_path.append(x)

            if (i != len(self.encoder_levels) - 1):
                x = self.encoder_transitions[i][-1](x)
            else:
                if self.ms_upsampling:
                    for j in range(i-2, -1, -1):
                        upsampl_path[j].append(self.decoder_transitions[i][j](x))
                x = self.decoder_transitions[i][-1](x)

        #### DECODER ####
        for i in reversed(range(len(self.decoder_levels))):
            for j, block in enumerate(self.decoder_levels[i]):
                if j == 0:
                    # First block of the level: skip connection + upsampled input (+ multiscale)
                    inputs = [skipped_path[i], x]
                    if self.ms_upsampling:
                        inputs = inputs + upsampl_path[i]
                    x = block(torch.cat(inputs, 1))
                else:
                    x = block(x)

            # Multiscale Upsampling
            if self.ms_upsampling:
                for j in range(i-2, -1, -1):
                    upsampl_path[j].append(self.decoder_transitions[i][j](x))
            if (i != 0):
                x = self.decoder_transitions[i][-1](x)

        return x

    def Export(self, path):
        '''Write the architecture hyper-parameters to ``path``.

        The file holds one CSV row ``(key, value)`` per setting. It is opened
        in text mode with ``newline=''`` as the ``csv`` module requires; the
        previous binary mode made ``csv.writer`` raise TypeError.
        '''
        print_dict                                               = dict()
        print_dict['Architecture_IsHyperDense:               ']  = self.hyperdense
        print_dict['Architecture_HasAtrousLayer:             ']  = self.atrous_conv
        print_dict['Architecture_HasMultiscaleUpsampling:    ']  = self.ms_upsampling
        print_dict['Architecture_Depth:                      ']  = self.depth
        print_dict['Architecture_OutputChannels:             ']  = self.out_ch
        print_dict['Architecture_StartingChannels:           ']  = self.start_ch
        print_dict['Architecture_ChannelIncrementRate:       ']  = self.inc_rate
        print_dict['Architecture_HasMaxPool:                 ']  = self.maxpool
        print_dict['Architecture_KernelInitializer:          ']  = self.kernel_init
        print_dict['Module_Name:                             ']  = self.m_name
        print_dict['Module_Repetitions:                      ']  = self.m_repetitions
        print_dict['Module_KernelSize:                       ']  = self.kernel_size

        with open(path, 'w', newline='') as f:
            csv.writer(f).writerows(print_dict.items())





In [None]:
# Build the network defined above as a smoke test.
# NOTE(review): this cell used to call `MobileNetV2()`, which is not defined
# anywhere in this notebook. The class defined above is `MobileUNet`, and its
# constructor needs a config object exposing the attributes listed below.
# TODO: the values are placeholders. Replace them with the real experiment
# configuration, in particular `m_name`, which selects the block types via
# StemModule / LevelModule / AtrousMiddleModule / OutputModule.
mobile_unet_config = SimpleNamespace(
    m_name='TODO-module-name',
    m_repetitions=2,
    out_ch=3,                      # the output block is currently built with 3 channels
    start_ch=32,
    kernel_size=3,
    depth=5,
    inc_rate=2,
    maxpool=True,
    kernel_init='xavier_uniform',
    hyperdense=False,
    atrous_conv=False,
    ms_upsampling=False,
    regression=False,
)
model = MobileUNet(mobile_unet_config)
model