In [None]:
import keras
from keras.layers import (
    BatchNormalization,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    Input,
    MaxPool2D,
    MaxPooling2D,
    Rescaling,
)

In [None]:
# Two-branch encoder/decoder CNN built with the Keras functional API.
# A shared, rescaled input feeds a "centroids" branch and a "heads" branch.
# Each branch downsamples four times with 2x2 max pooling and upsamples four
# times with stride-2 transposed convolutions, so both outputs have the same
# spatial size as the input. The centroids output is computed from the
# concatenated final features of BOTH branches; the heads output only from its own.


def _conv_bn(tensor, filters):
    """Apply a 3x3 'same'-padded ReLU convolution (He-uniform init), then batch normalization.

    Args:
        tensor: Keras tensor to transform.
        filters: number of convolution filters (output channels).

    Returns:
        The batch-normalized output tensor (same spatial size as `tensor`).
    """
    tensor = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu',
                    kernel_initializer='he_uniform')(tensor)
    return BatchNormalization()(tensor)


def _upconv_bn(tensor, filters):
    """Apply a stride-2 3x3 transposed convolution (ReLU, He-uniform init), then batch normalization.

    Args:
        tensor: Keras tensor to upsample.
        filters: number of transposed-convolution filters (output channels).

    Returns:
        The batch-normalized output tensor; with 'same' padding and stride 2
        each spatial dimension is doubled.
    """
    tensor = Conv2DTranspose(filters=filters, kernel_size=3, strides=(2,2), padding='same',
                             activation='relu', kernel_initializer='he_uniform')(tensor)
    return BatchNormalization()(tensor)


# Input layer
input_layer = Input(shape=(512,640,3), name='input_layer')
x = Rescaling(scale=1/255.0)(input_layer)  # multiplies every input value by 1/255

# Hidden layers - two branches
# Encoder
# The heads branch keeps the pre-pooling tensor (x_heads_N) under its own name;
# the centroids branch overwrites its variable with the pooled tensor.
# Stage 1: 16 filters
x_centroids_1 = _conv_bn(x, 16)
x_centroids_1 = MaxPooling2D(pool_size=(2,2))(x_centroids_1)
x_heads_1 = _conv_bn(x, 16)
x_heads_1_pool = MaxPooling2D(pool_size=(2,2))(x_heads_1)

# Stage 2: 32 filters
x_centroids_2 = _conv_bn(x_centroids_1, 32)
x_centroids_2 = MaxPooling2D(pool_size=(2,2))(x_centroids_2)
x_heads_2 = _conv_bn(x_heads_1_pool, 32)
x_heads_2_pool = MaxPooling2D(pool_size=(2,2))(x_heads_2)

# Stage 3: 64 filters
x_centroids_3 = _conv_bn(x_centroids_2, 64)
x_centroids_3 = MaxPooling2D(pool_size=(2,2))(x_centroids_3)
x_heads_3 = _conv_bn(x_heads_2_pool, 64)
x_heads_3_pool = MaxPooling2D(pool_size=(2,2))(x_heads_3)

# Stage 4: 128 filters, followed by one extra conv+BN at the lowest resolution
x_centroids_4 = _conv_bn(x_centroids_3, 128)
x_centroids_4 = MaxPooling2D(pool_size=(2,2))(x_centroids_4)
x_centroids_4 = _conv_bn(x_centroids_4, 128)
x_heads_4 = _conv_bn(x_heads_3_pool, 128)
x_heads_4_pool = MaxPooling2D(pool_size=(2,2))(x_heads_4)
x_heads_4_pool = _conv_bn(x_heads_4_pool, 128)

# Decoder
# Four stride-2 upsampling stages undo the four 2x2 poolings of the encoder.
x_centroids_5 = _upconv_bn(x_centroids_4, 64)
x_heads_5 = _upconv_bn(x_heads_4_pool, 64)

x_centroids_6 = _upconv_bn(x_centroids_5, 32)
x_heads_6 = _upconv_bn(x_heads_5, 32)

x_centroids_7 = _upconv_bn(x_centroids_6, 16)
x_heads_7 = _upconv_bn(x_heads_6, 16)

x_centroids_8 = _upconv_bn(x_centroids_7, 8)
x_heads_8 = _upconv_bn(x_heads_7, 8)


# Output layer
# Each output is a single-channel map produced by a 1x1 (pointwise) convolution.
heads_output = Conv2D(filters=1, kernel_size=1, strides=1, activation='relu', kernel_initializer='he_uniform', name='heads_layer')(x_heads_8)  #pointwise convolution
# The centroids head sees the final features of both branches (8 + 8 channels).
centroids_output = Concatenate(axis=-1)([x_centroids_8, x_heads_8])
centroids_output = Conv2D(filters=1, kernel_size=1, strides=1, activation='relu', kernel_initializer='he_uniform', name='centroids_layer')(centroids_output)

# Defining the model
# Use the already-imported `keras` namespace: `tf` is never imported in this
# notebook, and the layers above come from `keras`, not `tf.keras`.
model = keras.Model(input_layer, [heads_output, centroids_output])

In [None]:
# Compiling
# Both outputs are trained with mean squared error; the loss dict is keyed by
# the names given to the two output layers ('heads_layer', 'centroids_layer').
# The optimizer comes from the imported `keras` package (the notebook does not
# import `tf`, so `tf.keras...` would raise NameError on a fresh kernel).
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001),
              loss={'heads_layer': 'mse',
                    'centroids_layer': 'mse'})

In [None]:
# Fitting
no_epochs = 12
batch_size = 3

# Ground-truth targets, keyed by the name of the output layer they supervise.
train_targets = {
    'heads_layer': gt_data_heads,
    'centroids_layer': gt_data_centroids,
}
validation_targets = {
    'heads_layer': gt_validation_heads,
    'centroids_layer': gt_validation_centroids,
}

# Train with shuffling and evaluate on the validation set; the returned
# History object is kept in `history`.
history = model.fit(
    x=train_data,
    y=train_targets,
    validation_data=(validation_data, validation_targets),
    batch_size=batch_size,
    epochs=no_epochs,
    shuffle=True,
)