In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def MobileUNet(input_shape=(512, 512, 6)):
    """Build a U-Net-style binary segmentation model with a MobileNetV2 encoder.

    The encoder is a randomly initialised MobileNetV2 (``weights=None``) built
    for 3-channel input. When ``input_shape`` has a different channel count, a
    learned 1x1 convolution first projects the input down to 3 channels.

    Five encoder activations are tapped. The deepest (1/32 resolution) is the
    bottleneck and the other four feed the decoder as skip connections. The
    decoder upsamples back to the full input resolution and ends in a
    1-channel sigmoid map.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. Known
            height/width must be multiples of 32 (the encoder's total stride)
            so that every decoder upsample lines up with its skip tensor.

    Returns:
        A ``tf.keras.Model`` mapping ``(batch, height, width, channels)`` inputs
        to ``(batch, height, width, 1)`` per-pixel probabilities.

    Raises:
        ValueError: if a known height or width is not a multiple of 32.
    """
    for dim in input_shape[:2]:
        if dim is not None and dim % 32 != 0:
            raise ValueError(
                f"MobileUNet needs height and width divisible by 32, got {input_shape}"
            )

    inputs = tf.keras.Input(shape=input_shape)

    # Encoder: randomly initialised MobileNetV2 (no pretrained weights). It is
    # always built for 3 channels because any other channel count is projected
    # to 3 below. Building it with `input_shape` directly made `encoder(x)`
    # reject the projected 3-channel tensor.
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=tuple(input_shape[:-1]) + (3,),
        include_top=False,
        weights=None,
    )

    # Encoder taps, shallowest to deepest; resolution is relative to the input.
    layer_names = [
        'block_1_expand_relu',   # 1/2
        'block_3_expand_relu',   # 1/4
        'block_6_expand_relu',   # 1/8
        'block_13_expand_relu',  # 1/16
        'block_16_project',      # 1/32 (bottleneck)
    ]
    encoder = tf.keras.Model(
        inputs=base_model.input,
        outputs=[base_model.get_layer(name).output for name in layer_names],
    )

    # Learned 1x1 projection from N input bands to the 3 channels the encoder
    # was built for (a trainable mix, not a channel replication).
    if input_shape[-1] != 3:
        x = layers.Conv2D(3, 1, padding='same')(inputs)
    else:
        x = inputs

    features = encoder(x)
    x = features[-1]                         # bottleneck
    decoder_skips = reversed(features[:-1])  # deepest skip first

    # Decoder: each stage doubles resolution and fuses the matching skip.
    up_filters = [512, 256, 128, 64]
    for up_filter, skip in zip(up_filters, decoder_skips):
        x = layers.Conv2DTranspose(up_filter, 3, strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skip])
        x = layers.Conv2D(up_filter, 3, padding='same', activation='relu')(x)
        x = layers.Conv2D(up_filter, 3, padding='same', activation='relu')(x)

    # The shallowest skip is at 1/2 resolution. One more 2x upsample is needed
    # so the predicted mask matches the input height/width.
    final_filters = up_filters[-1] // 2
    x = layers.Conv2DTranspose(final_filters, 3, strides=2, padding='same')(x)
    x = layers.Conv2D(final_filters, 3, padding='same', activation='relu')(x)

    # Per-pixel foreground probability.
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)

    return models.Model(inputs, outputs)


In [None]:
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss, ``1 - Dice``, computed over the whole batch at once.

    All elements of the batch are flattened together. The result is therefore
    one global Dice score, not a per-image average.

    Args:
        y_true: ground-truth mask; presumably 0/1 values (TODO confirm against
            the data pipeline). It is cast to ``y_pred``'s dtype so integer
            masks (e.g. uint8) or float16 predictions under mixed precision do
            not cause a dtype mismatch in the element-wise product.
        y_pred: predicted probabilities with the same number of elements as
            ``y_true``.
        smooth: small constant added to numerator and denominator. It keeps
            the ratio defined (and the loss 0) when both masks are empty.

    Returns:
        Scalar tensor, in ``[0, 1]`` when both inputs lie in ``[0, 1]``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    denominator = tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

# Shared Keras binary cross-entropy loss object, constructed with default
# arguments; the model's sigmoid output layer supplies probabilities.
bce = tf.keras.losses.BinaryCrossentropy()

def bce_dice_loss(y_true, y_pred):
    """Equal-weight blend of binary cross-entropy and soft Dice loss.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted probabilities.

    Returns:
        Scalar tensor: half the BCE term plus half the Dice term.
    """
    bce_term = bce(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    # Scaling by 0.5 is exact in floating point, so factoring it out yields
    # the same value as weighting each term separately.
    return 0.5 * (bce_term + dice_term)


In [None]:
# Build the multi-band segmentation model and attach loss/metrics.
INPUT_SHAPE = (256, 256, 6)  # (height, width, bands); H and W multiples of 32
LEARNING_RATE = 1e-4

model = MobileUNet(input_shape=INPUT_SHAPE)
model.compile(
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE),
    loss=bce_dice_loss,
    metrics=[
        # MeanIoU expects integer class ids. Given sigmoid probabilities, it
        # truncates every value below 1.0 to class 0, so the metric was
        # meaningless. BinaryIoU thresholds at 0.5 first; averaging over
        # classes [0, 1] matches what MeanIoU(num_classes=2) was meant to report.
        tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
    ],
)
model.summary()
