In [1]:
import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import backend as K

In [2]:
# Building the convolutional block
def ConvBlock(inputs, filters, kernel_size=(3, 3)):
    """Apply two (Conv2D -> BatchNormalization -> ReLU) stages in sequence.

    Args:
        inputs: Keras tensor to process.
        filters: Number of output channels of each Conv2D layer.
        kernel_size: Kernel size used by both Conv2D layers. Defaults to
            ``(3, 3)``, which is what this block always used before.

    Returns:
        The ReLU-activated output of the second stage. ``padding="same"``
        (with the default stride of 1) keeps the spatial size equal to that
        of ``inputs``.
    """
    x = inputs
    # The second stage consumes the first stage's activation (not the block
    # input). Layers are created in the same order as the unrolled version:
    # conv -> bn -> relu -> conv -> bn -> relu.
    for _ in range(2):
        x = layers.Conv2D(filters, kernel_size=kernel_size, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x

# Building the encoder
def encoder(inputs, filters=64):
    """One contracting-path stage: ConvBlock followed by 2x2 max-pooling.

    Args:
        inputs: Keras tensor entering this stage.
        filters: Number of filters for the stage's ConvBlock.

    Returns:
        A ``(skip, pooled)`` pair: the ConvBlock output at the stage's input
        resolution (kept for the decoder's skip connection) and that same
        tensor downsampled by a factor of 2.
    """
    features = ConvBlock(inputs, filters)
    # pool_size defaults to (2, 2) in Keras; spelled out here for clarity.
    return features, layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(features)

# Building the decoder
def decoder(inputs, skip, filters=64):
    """One expanding-path stage: upsample by 2, merge the skip tensor, refine.

    Args:
        inputs: Keras tensor from the previous (lower-resolution) stage.
        skip: Encoder tensor whose spatial size matches ``inputs`` upsampled 2x.
        filters: Filters for both the transposed conv and the ConvBlock.

    Returns:
        The ConvBlock output computed on ``[upsampled, skip]`` concatenated
        along the last axis.
    """
    upsampled = layers.Conv2DTranspose(
        filters, kernel_size=(2, 2), strides=2, padding="same"
    )(inputs)
    merged = layers.Concatenate()([upsampled, skip])
    return ConvBlock(merged, filters)

In [3]:
# Model input: a 416 x 416 spatial grid with 2 channels.
inputs = layers.Input(shape=(416, 416, 2))

# Contracting path. Each encoder stage returns (skip tensor, pooled tensor).
# The filter count doubles per stage while 2x2 pooling halves the resolution:
#   416 -> 208 -> 104 -> 52 -> 26
skip1, encoder_1 = encoder(inputs, filters=64)      # 64
skip2, encoder_2 = encoder(encoder_1, filters=128)  # 64 * 2
skip3, encoder_3 = encoder(encoder_2, filters=256)  # 64 * 4
skip4, encoder_4 = encoder(encoder_3, filters=512)  # 64 * 8

# Bottleneck at the lowest resolution (26 x 26) with 64 * 16 = 1024 filters.
conv_block = ConvBlock(encoder_4, 1024)

2025-09-22 16:56:29.199893: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-09-22 16:56:30.955438: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30927 MB memory:  -> device: 0, name: Tesla V100-SXM2-32GB-LS, pci bus id: 0000:85:00.0, compute capability: 7.0
2025-09-22 16:56:30.959042: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 30927 MB memory:  -> device: 1, name: Tesla V100-SXM2-32GB-LS, pci bus id: 0000:86:00.0, compute capability: 7.0


In [4]:
# Construct the decoder blocks and decreasing the filters by a factor of 2
# Expanding path: mirror the encoder, halving the filter count at each stage
# and concatenating the matching encoder skip tensor (26 -> 52 -> 104 -> 208 -> 416).
decoder_1 = decoder(conv_block, skip4, 64*8)
decoder_2 = decoder(decoder_1, skip3, 64*4)
decoder_3 = decoder(decoder_2, skip2, 64*2)
decoder_4 = decoder(decoder_3, skip1, 64)

# Per-pixel sigmoid probability map at full input resolution (416 x 416 x 1).
outputs = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(decoder_4)

# Reduce to a 26 x 26 grid by taking the max over each non-overlapping
# 16 x 16 patch. pool_size must be set explicitly: with only strides=(16, 16)
# the window stays at the (2, 2) default, so each output cell would see just
# the top-left 2x2 pixels of its patch. Output shape is 26 x 26 x 1 either way.
outputs = layers.MaxPooling2D(pool_size=(16, 16))(outputs)

model = models.Model(inputs, outputs)

In [5]:
# Print each layer's name/type, output shape, parameter count and inbound connections.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 416, 416, 2  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 416, 416, 64  1216        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 416, 416, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

In [6]:
# Adam learning rate (0.001 is Keras' Adam default); tune for better performance.
LEARNING_RATE = 1e-3

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    # Single sigmoid output, so one loss and a list of metrics (documented form).
    loss="binary_crossentropy",
    metrics=[keras.metrics.AUC()],
)

# Stop when val_loss stops improving and roll back to the best epoch's weights.
early_stopping = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-6,  # minimum decrease in val_loss that counts as an improvement
    patience=2,  # epochs without improvement to wait before stopping
    restore_best_weights=True,
    mode="auto",  # direction inferred from the monitored name (loss -> minimize)
)

Run the cell below once the training and validation data (`x_train`, `y_train`, `x_val`, `y_val`) are prepared and you want to train the model.

In [7]:
# Set RUN_TRAINING = True once x_train, y_train, x_val and y_val are defined.
# Guarding with a flag (instead of commenting the call out) keeps the code
# syntax-checked and lets "Restart & Run All" succeed without the data.
RUN_TRAINING = False
EPOCHS = 100  # upper bound; early_stopping usually ends training sooner
BATCH_SIZE = 4

if RUN_TRAINING:
    history = model.fit(
        x_train,
        y_train,
        validation_data=(x_val, y_val),
        epochs=EPOCHS,
        callbacks=[early_stopping],
        batch_size=BATCH_SIZE,
    )