In [2]:
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# **mio**

In [2]:
def getConvBlockOne(layerPrec, filters, kernel_size = (1, 1), activation='relu', kernel_initializer='he_normal'):
  l = BatchNormalization()(layerPrec)
  l = Activation('relu')(l)
  l = Conv2D(filters, kernel_size, kernel_initializer=kernel_initializer, padding='same')(l)
  return l

def getConvBlockTwo(layerPrec, filters, kernel_size = (3, 3), activation='relu', kernel_initializer='he_normal'):
  l = BatchNormalization()(layerPrec)
  l = Activation('relu')(l)
  l = Conv2D(filters, kernel_size, kernel_initializer=kernel_initializer, padding='same')(l)
  return l

def getUNet(input, shape, filters):
    # CONTRACTING PHASE
  # RED block
  l1 = getConvBlockOne(input,filters)
  l1 = getConvBlockTwo(l1,filters)
  # ---
  concat1 = Concatenate()([l1,input])
  # ORANGE block
  middle1 = getConvBlockOne(concat1,filters)

  # GREEN block
  l2 = getConvBlockOne(concat1,filters)
  maxPooling1 = MaxPool2D(padding='same')(l2)
  # ---

    # Bottle neck
  l3 = getConvBlockOne(maxPooling1,filters)
  l4 = getConvBlockTwo(l3,filters)
  concat2 = Concatenate()([maxPooling1,l4])

    # EXPANSIVE PATH
  # BLUE block
  l5 = getConvBlockOne(concat2,filters)
  layer = UpSampling2D()(l5)
  layer = Conv2D(filters, kernel_size=(2,2), activation='relu', kernel_initializer='he_normal', padding='same')(layer)
  layer = BatchNormalization()(layer)
  # ---
  concat3 = Concatenate()([middle1,layer])
  
  l6 = getConvBlockOne(concat3,filters)
  l7 = getConvBlockTwo(l6,filters)
  concat4 = Concatenate()([concat3,l7])
  output = getConvBlockOne(concat4,filters)
  return Model(inputs=input, outputs=output)

def getModel(shape=(128, 128, 3), filters=64):
  # m: main feature flow
  # n: generated feature number in a block of U-Net

  input = Input(shape=shape)
  output = getUNet(input, shape, filters=64)
  return output

In [None]:
net = getModel(shape=(128,128,3), filters=32)
net.summary()

In [17]:
def getUnit1(layerPrec, filters, kernel_size = (1, 1), activation='relu', kernel_initializer='he_normal'):
  l = BatchNormalization()(layerPrec)
  l = Activation('relu')(l)
  l = Conv2D(filters, kernel_size, kernel_initializer=kernel_initializer, padding='same')(l)
  return l

def getUnit2(layerPrec, filters, kernel_size = (3, 3), activation='relu', kernel_initializer='he_normal'):
  l = BatchNormalization()(layerPrec)
  l = Activation('relu')(l)
  l = Conv2D(filters, kernel_size, kernel_initializer=kernel_initializer, padding='same')(l)
  return l

def getDownBlock(layerPrec,m,n,indexBlock):
  l = getUnit1(layerPrec,4*n)
  l = getUnit2(l,n)
  nets["layers"][f"down{indexBlock}"].append(l)

  concat = Concatenate()([layerPrec,l])
  l = getUnit1(concat,m)
  maxPooling = MaxPool2D(padding='same')(l)
  return maxPooling, getUnit1(concat,m)

def getUpBlock(layerPrec,skipConn,m,n,indexBlock,upLayers=[]):
  l = getUnit1(layerPrec,m)
  l = UpSampling2D()(layerPrec)
  concat = Concatenate()([skipConn,l]+upLayers)
  l = getUnit1(concat,4*n)
  l = getUnit2(l,n)
  nets["layers"][f"up{indexBlock}"].append(l)
  concat = Concatenate()([concat,l])
  return concat

def getUNet(input,m,n,indexUNet,nUNet, nBlocks):
  layerPrec = input
  listSkipConn = []

  if indexUNet != 0:
    layerPrec = nets[f"unet{indexUNet-1}"]
    #layerPrec = Concatenate()([input,layerPrec]) #l'abbiamo fatto giù con l'if dopo l'up
    layerPrec = getUnit1(layerPrec,m)

  # down
  for i in range(nBlocks):
    layerPrec = Concatenate()([layerPrec]+nets["layers"][f"down{i}"])
    layerPrec, skipConn = getDownBlock(layerPrec,m,n,i)
    listSkipConn.append(skipConn)

  # bottle neck
  layerPrec = Concatenate()([layerPrec]+nets["layers"][f"bn"])

  l = getUnit1(layerPrec,4*n)
  l = getUnit2(l,n)
  nets["layers"]["bn"].append(l)
  concat = Concatenate()([layerPrec,l])

  # up
  layerPrec = concat
  for i in range(nBlocks):
    layerPrec = getUpBlock(layerPrec,listSkipConn[-(i+1)],m,n,i,upLayers=nets["layers"][f"up{i}"])


  if indexUNet != nUNet - 1:
    l = Concatenate()([input,layerPrec])
  else:
    l = getUnit1(layerPrec,16,activation=None)

  return l

def trasformationInput(x, filters):
  x = BatchNormalization()(x)
  x = Activation('relu')(x)
  x = Conv2D(filters, kernel_size=(7,7), strides=(2,2), kernel_initializer='he_normal', padding='same')(x)
  maxPooling = MaxPool2D(padding='same')(x)
  return maxPooling

def getCUNet(shape,m,n,nUNet,nBlocks):
  for i in range(nUNet):
    nets[f"unet{i}"] = None

  for j in range(nBlocks):
    nets["layers"][f"down{j}"] = []
    nets["layers"][f"up{j}"] = []
    
  input = Input(shape=shape)
  
  input = trasformationInput(input,m) # per le heatmap da 64x64
  print(input.shape)
  for i in range(nUNet):
    nets[f"unet{i}"] = getUNet(input,m,n,i,nUNet,nBlocks)

  output = nets[f"unet{nUNet-1}"]
  return Model(inputs=input, outputs=output)

nets = {}
nets["layers"] = {}
nets["layers"]["bn"] = []
shape = (256,256,3)
m = 64
n = 16
nUNet = 4
nBlocks = 4
net = getCUNet(shape,m,n,nUNet,nBlocks)
net.summary()

(None, 64, 64, 64)
Model: "model_7"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_19 (InputLayer)          [(None, 64, 64, 64)  0           []                               
                                ]                                                                 
                                                                                                  
 concatenate_514 (Concatenate)  (None, 64, 64, 64)   0           ['input_19[0][0]']               
                                                                                                  
 batch_normalization_845 (Batch  (None, 64, 64, 64)  256         ['concatenate_514[1][0]']        
 Normalization)                                                                                   
                                                                         