In [None]:
# Make the Thesis project folder (two levels above the notebook's starting
# directory) the working directory.
# NOTE(review): `%cd "../.."` is relative to the *current* directory, so
# re-running this cell without restarting the kernel moves up two more levels.
%cd "../.."

# Import packages

In [None]:
import random
import tensorflow as tf
import numpy as np

import tensorflow.keras.layers as tfkl
import tensorflow.keras as tfk 
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Input
from tensorflow.keras import Model
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.preprocessing.image import ImageDataGenerator

## UNeXt

In [None]:
class UNeXt:
    """Builder for a U-Net-shaped Keras model with ConvNeXt-inspired blocks.

    Each encoder/decoder block is a 7x7 convolution, a 1x1 expansion to 4x
    the channels (with activation), a 1x1 projection back, and a residual add
    around the two pointwise convolutions.
    NOTE(review): unlike ConvNeXt, the 7x7 conv is a regular (not depthwise)
    convolution and the block has no normalization layer.

    Usage: ``model = UNeXt(input_shape).build_model(filters_list)``.

    Args:
        input_shape: shape passed to ``tfkl.Input`` (without batch dimension),
            e.g. ``(None, None, 1)``.
        activ_encod: activation of the expansion conv in encoder blocks.
        activ_decod: activation of the expansion conv in decoder blocks and of
            the two final 3x3 convolutions.
        activ_out: activation of the single-channel output convolution.
        kern_init: kernel initializer for every Conv2D layer. The
            Conv2DTranspose up-sampling layers do not receive it and use the
            Keras default.
    """

    def __init__(self, input_shape, activ_encod='gelu', activ_decod='gelu', activ_out='sigmoid', kern_init='HeUniform'):
        # Activation functions for the encoder, decoder and output layers.
        self.activ_encod = activ_encod
        self.activ_decod = activ_decod
        self.activ_out = activ_out

        # Kernel initializer shared by all Conv2D layers.
        self.kern_init = kern_init

        # Symbolic input tensor; it is the input of every model built here.
        self.input = tfkl.Input(input_shape)

        # Output tensors of the most recent encoder (bottleneck) and decoder;
        # updated by build_model.
        self.encoder = self.input
        self.decoder = None

        # Pre-pooling output of each encoder block, highest resolution first.
        # build_model pairs these (in reverse) with the decoder blocks as skip
        # connections.
        self.pool_layers_list = []

    def _residual_conv(self, inp, filters, activ):
        """7x7 conv -> 1x1 conv (4x filters, ``activ``) -> 1x1 conv, plus residual.

        All convolutions use stride 1 and 'same' padding, so the result has
        ``filters`` channels and the same spatial size as ``inp``.
        """
        conv1 = tfkl.Conv2D(filters=filters, kernel_size=7, strides=1, padding='same', activation=None, kernel_initializer=self.kern_init)(inp)
        conv2 = tfkl.Conv2D(filters=filters*4, kernel_size=1, strides=1, padding='same', activation=activ, kernel_initializer=self.kern_init)(conv1)
        conv3 = tfkl.Conv2D(filters=filters, kernel_size=1, strides=1, padding='same', activation=None, kernel_initializer=self.kern_init)(conv2)
        # Residual connection around the pointwise expansion/projection pair.
        return tfkl.Add()([conv1, conv3])

    def Down_Conv_block(self, inp, filters, encoder, activ):
        """Encoder block: residual conv stack followed by 2x2 max-pooling.

        The pre-pooling tensor is appended to ``self.pool_layers_list`` so it
        can serve as a skip connection in the decoder.

        Args:
            inp: input tensor.
            filters: channel count of the block.
            encoder: unused; kept for backward compatibility with callers.
            activ: activation of the 4x expansion conv.

        Returns:
            The max-pooled tensor (spatial size halved).
        """
        features = self._residual_conv(inp, filters, activ)
        self.pool_layers_list.append(features)
        return tfkl.MaxPool2D(pool_size=(2, 2), strides=2)(features)

    def Up_Conv_block(self, inp, filters, respective_down_layer, activ):
        """Decoder block: residual conv stack, 2x up-sampling, skip concatenation.

        Args:
            inp: input tensor.
            filters: channel count of the residual stack; the transposed conv
                outputs ``filters // 2`` channels.
            respective_down_layer: encoder skip tensor at the up-sampled
                resolution.
            activ: activation of the 4x expansion conv.

        Returns:
            ``Concatenate()([respective_down_layer, up_sampled])``.
        """
        features = self._residual_conv(inp, filters, activ)
        # Stride-2 transposed conv doubles the spatial size.
        up_conv = tfkl.Conv2DTranspose(filters=filters//2, kernel_size=2, strides=2, padding='same')(features)
        return tfkl.Concatenate()([respective_down_layer, up_conv])

    def build_model(self, filters_list):
        """Build and return the Keras model.

        Args:
            filters_list: channel counts per resolution level, highest
                resolution first, e.g. ``[8, 16, 32, 64, 128]``.
                - Every entry except the last defines an encoder block, and
                  each encoder block halves the spatial size.
                - Every entry except the first, taken in reverse order,
                  defines a decoder block, and each decoder block doubles the
                  spatial size.
                - ``filters_list[0]`` is also used by the two final 3x3
                  convolutions.

        Returns:
            A ``tfk.Model`` mapping ``self.input`` to a 1-channel output with
            activation ``self.activ_out``.

        Notes:
            Spatial input dimensions must be divisible by
            ``2 ** (len(filters_list) - 1)``. Otherwise the skip-connection
            concatenations get mismatched sizes.
            Each call builds a fresh graph from ``self.input``, so calling
            this method more than once on the same instance is safe.
        """
        # Start from a clean slate. Otherwise a repeated call would stack new
        # blocks on the previous bottleneck and pair them with stale skips.
        self.pool_layers_list = []

        # Encoder: one down block per entry except the last.
        encoded = self.input
        for filters in filters_list[:-1]:
            encoded = self.Down_Conv_block(encoded, filters, encoder=0, activ=self.activ_encod)
        self.encoder = encoded

        # Decoder: entries from the last down to (excluding) the first, each
        # paired with the skip tensor of the matching encoder level, deepest
        # first.
        decoded = encoded
        for filters, skip in zip(filters_list[:0:-1], reversed(self.pool_layers_list)):
            decoded = self.Up_Conv_block(decoded, filters, skip, activ=self.activ_decod)
        self.decoder = decoded

        # Two full-resolution 3x3 convolutions with filters_list[0] channels.
        layer = tfkl.Conv2D(filters=filters_list[0], kernel_size=3, strides=1, padding='same', activation=self.activ_decod, kernel_initializer=self.kern_init)(self.decoder)
        layer = tfkl.Conv2D(filters=filters_list[0], kernel_size=3, strides=1, padding='same', activation=self.activ_decod, kernel_initializer=self.kern_init)(layer)

        # Single-channel output map.
        out = tfkl.Conv2D(filters=1, kernel_size=3, strides=1, padding='same', activation=self.activ_out, kernel_initializer=self.kern_init)(layer)

        return tfk.Model(inputs=[self.input], outputs=out)

In [None]:
# Variable height/width and a single channel, i.e. (height, width, channels)
# assuming Keras's default channels_last layout.
# Height and width must be divisible by 2 ** (len(filters) - 1) at run time.
input_shape = (None, None, 1)
# Channel count per resolution level, doubling at each level: 8, 16, ..., 128.
filters = [8 * 2 ** level for level in range(5)]

In [None]:
# Instantiate the builder and build the Keras model in one step.
model = UNeXt(input_shape=input_shape).build_model(filters_list=filters)

In [None]:
# Print the layer-by-layer architecture and parameter counts of the built model.
model.summary()