In [5]:
import tensorflow as tf
from tensorflow.keras import layers, models

def initial_block(input_tensor):
    x = layers.Conv2D(13, (3, 3), strides=2, padding='same')(input_tensor)
    x = layers.PReLU()(x)
    pool = layers.MaxPooling2D((2, 2))(input_tensor)
    concat = layers.concatenate([x, pool], axis=-1)
    return concat

def bottleneck_block(input_tensor, filters, kernel_size=3, downsample=False, dilated=False, asym=False, dilation_rate=(1, 1)):
    stride = 2 if downsample else 1
    
    # Initial 1x1 Convolution
    x = layers.Conv2D(filters // 4, (1, 1), strides=stride, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.PReLU()(x)
    
    # Main Convolution Path
    if dilated:
        x = layers.Conv2D(filters // 4, (kernel_size, kernel_size), padding='same', dilation_rate=dilation_rate)(x)
    elif asym:
        x = layers.Conv2D(filters // 4, (kernel_size, 1), padding='same')(x)
        x = layers.Conv2D(filters // 4, (1, kernel_size), padding='same')(x)
    else:
        x = layers.Conv2D(filters // 4, (kernel_size, kernel_size), padding='same')(x)
    
    x = layers.BatchNormalization()(x)
    x = layers.PReLU()(x)
    
    # Final 1x1 Convolution
    x = layers.Conv2D(filters, (1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.SpatialDropout2D(0.01 if filters < 128 else 0.1)(x)
    
    # Adjust input_tensor if downsampling
    if downsample:
        input_tensor = layers.Conv2D(filters, (1, 1), strides=stride, padding='same')(input_tensor)
    
    # Align dimensions for addition
    if x.shape[1:] != input_tensor.shape[1:]:
        input_tensor = layers.Conv2D(filters, (1, 1), padding='same')(input_tensor)
    
    # Skip connection
    x = layers.add([x, input_tensor])
    x = layers.PReLU()(x)
    return x

def build_enet(input_shape, num_classes):
    input_tensor = layers.Input(shape=input_shape)
    
    x = initial_block(input_tensor)
    
    x = bottleneck_block(x, 64, downsample=True)
    for _ in range(4):
        x = bottleneck_block(x, 64)
    
    x = bottleneck_block(x, 128, downsample=True)
    x = bottleneck_block(x, 128)
    x = bottleneck_block(x, 128, dilated=True)
    x = bottleneck_block(x, 128, asym=True)
    x = bottleneck_block(x, 128, dilated=True, dilation_rate=(4, 4))
    x = bottleneck_block(x, 128)
    x = bottleneck_block(x, 128, dilated=True, dilation_rate=(8, 8))
    x = bottleneck_block(x, 128, asym=True)
    x = bottleneck_block(x, 128, dilated=True, dilation_rate=(16, 16))
    
    x = bottleneck_block(x, 64, downsample=False)
    for _ in range(2):
        x = bottleneck_block(x, 64)
    
    x = bottleneck_block(x, 16, downsample=False)
    x = bottleneck_block(x, 16)
    
    x = layers.Conv2D(num_classes, (1, 1), padding='same')(x)
    x = layers.UpSampling2D((4, 4))(x)
    x = layers.Activation('softmax')(x)
    
    model = models.Model(inputs=input_tensor, outputs=x)
    return model

# Example usage
input_shape = (512, 512, 3)  # Example input shape
num_classes = 20  # Example number of classes

model = build_enet(input_shape, num_classes)
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            [(None, 512, 512, 3) 0                                            
__________________________________________________________________________________________________
conv2d_196 (Conv2D)             (None, 256, 256, 13) 364         input_5[0][0]                    
__________________________________________________________________________________________________
p_re_lu_176 (PReLU)             (None, 256, 256, 13) 851968      conv2d_196[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 256, 256, 3)  0           input_5[0][0]                    
____________________________________________________________________________________________