# In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Add, SeparableConv2D, UpSampling2D, concatenate
from tensorflow.keras.models import Model

def SpatialPath(input_tensor):
    """Apply three stride-2 Conv -> BatchNorm -> ReLU stages to ``input_tensor``.

    The stages use 64 filters with a 7x7 kernel, then 128 and 256 filters
    with 3x3 kernels. Every convolution uses ``padding='same'`` and
    ``strides=2``, so each stage halves the spatial resolution.

    Args:
        input_tensor: Keras tensor fed to the first convolution.

    Returns:
        The Keras tensor produced by the final ReLU (256 channels).
    """
    # (filters, kernel_size) for each stage, in the order they are applied.
    stage_specs = (
        (64, (7, 7)),
        (128, (3, 3)),
        (256, (3, 3)),
    )

    features = input_tensor
    for filters, kernel_size in stage_specs:
        features = Conv2D(filters, kernel_size, strides=2, padding='same')(features)
        features = BatchNormalization()(features)
        features = ReLU()(features)

    return features

def AttentionRefinementModule(input_tensor):
    """Re-weight the channels of a feature map with a globally pooled sigmoid gate.

    The input is global-average-pooled to one value per channel, passed
    through a sigmoid ``Dense`` layer with as many units as the input has
    channels, and the resulting per-channel gate is multiplied back onto
    the input.

    Args:
        input_tensor: 4-D Keras feature map. Its last (channel) dimension
            must be statically known because it sets the ``Dense`` width.

    Returns:
        Keras tensor: ``input_tensor`` scaled channel-wise by the gate.
    """
    channels = input_tensor.shape[-1]

    # One descriptor per channel, squashed to (0, 1).
    gate = GlobalAveragePooling2D()(input_tensor)
    gate = Dense(channels, activation='sigmoid')(gate)

    # Use a Reshape layer (as FeatureFusionModule does) rather than slicing the
    # symbolic tensor with tf.newaxis, so the graph holds only regular Keras
    # layers. The (1, 1, C) gate broadcasts over height and width in Multiply.
    gate = tf.keras.layers.Reshape((1, 1, channels))(gate)

    return tf.keras.layers.Multiply()([input_tensor, gate])

def ContextPath(input_tensor):
    """Build the Xception-based context branch and return two refined feature maps.

    An ImageNet-pretrained Xception without its classification head is
    attached to ``input_tensor``. The outputs of ``block13_sepconv2_bn`` and
    ``block14_sepconv2_act`` are each passed through
    ``AttentionRefinementModule``. The deeper map additionally gets its own
    global average added to every spatial position. Finally the two maps are
    bilinearly upsampled by 2 and by 4 respectively.

    Args:
        input_tensor: Keras tensor used as the Xception input.

    Returns:
        Tuple ``(feature_13_arm, feature_14_arm)`` of upsampled Keras tensors.
    """
    # Instantiating with weights='imagenet' loads the pretrained backbone.
    base_model = tf.keras.applications.Xception(
        include_top=False, weights='imagenet', input_tensor=input_tensor)

    # Feature maps from the last two stages of the Xception model.
    feature_13 = base_model.get_layer('block13_sepconv2_bn').output
    feature_14 = base_model.get_layer('block14_sepconv2_act').output

    # Attention refinement modules.
    feature_13_arm = AttentionRefinementModule(feature_13)
    feature_14_arm = AttentionRefinementModule(feature_14)

    # Global context: one averaged value per channel, shaped (1, 1, C).
    global_context = GlobalAveragePooling2D()(feature_14_arm)
    global_context = tf.keras.layers.Reshape((1, 1, -1))(global_context)

    # Add broadcasts the (1, 1, C) context over height and width. This yields
    # the same values as first tiling it with a nearest-neighbour upsample,
    # but does not need the spatial size to be known when the graph is built.
    feature_14_arm = tf.keras.layers.Add()([feature_14_arm, global_context])

    # Upsample the output of each ARM.
    feature_13_arm = UpSampling2D(size=(2, 2), interpolation='bilinear')(feature_13_arm)
    feature_14_arm = UpSampling2D(size=(4, 4), interpolation='bilinear')(feature_14_arm)

    return feature_13_arm, feature_14_arm

def FeatureFusionModule(spatial_out, context_out):
    """Fuse spatial and context features and re-weight the result per channel.

    The two inputs are concatenated along the channel axis and reduced to
    256 channels by a 3x3 Conv -> BatchNorm -> ReLU. A gate is then derived
    from the globally averaged result through two 256-unit ``Dense`` layers
    (ReLU, then sigmoid) and multiplied onto the fused features.

    Args:
        spatial_out: Keras feature map from the spatial branch.
        context_out: Keras feature map from the context branch; it is
            concatenated with ``spatial_out`` on the last axis.

    Returns:
        Keras tensor with 256 channels: the gated fused features.
    """
    # NOTE(review): the BiSeNet paper's fusion module also adds the gated
    # features back onto the fused features; this block returns only the
    # product -- confirm that is intended.
    merged = tf.keras.layers.Concatenate()([spatial_out, context_out])

    fused = Conv2D(256, (3, 3), padding='same')(merged)
    fused = BatchNormalization()(fused)
    fused = ReLU()(fused)

    # Channel gate: pooled descriptor -> Dense(relu) -> Dense(sigmoid).
    gate = GlobalAveragePooling2D()(fused)
    for activation in ('relu', 'sigmoid'):
        gate = Dense(256, activation=activation)(gate)

    # (1, 1, 256) so the gate broadcasts over height and width.
    gate = tf.keras.layers.Reshape((1, 1, 256))(gate)

    return tf.keras.layers.Multiply()([fused, gate])

# Input layer: fixed 512x512 RGB images.
input_tensor = Input(shape=(512, 512, 3))

# Spatial Path
spatial_out = SpatialPath(input_tensor)

# Context Path
context_out_13, context_out_14 = ContextPath(input_tensor)

# Use BOTH refined context maps. Previously context_out_13 was computed and
# then dropped, so that whole refinement branch never reached the model.
# ContextPath upsamples the two maps by 2 and 4; they are expected to land on
# the same grid as spatial_out, which the concatenations below require.
context_out = concatenate([context_out_13, context_out_14])

# Feature Fusion Module
fused_out = FeatureFusionModule(spatial_out, context_out)

# Upsample the fused features by 8 (SpatialPath applies three stride-2 convs).
fused_out = UpSampling2D(size=(8, 8), interpolation='bilinear')(fused_out)

# Final classifier: per-pixel softmax over 19 classes.
output_tensor = Conv2D(19, (1, 1), activation='softmax')(fused_out)

# Create the model
model = Model(input_tensor, output_tensor)

# Compile the model. categorical_crossentropy expects one-hot targets.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Model summary
model.summary()
