In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model

def build_encoder(input_layer, num_filters=64, num_blocks=4):
    """Contracting path: ``num_blocks`` conv/pool levels followed by a bottleneck.

    Level ``i`` applies two 3x3 ReLU convolutions (``padding='same'``) with
    ``num_filters * 2**i`` filters, records the result as a skip tensor, then
    downsamples with ``MaxPooling2D(2)``. Two further 3x3 ReLU convolutions,
    twice as wide as the deepest level, form the bottleneck.

    Args:
        input_layer: Keras tensor to encode.
        num_filters: Filter count of the first (shallowest) level.
        num_blocks: Number of conv/pool levels.

    Returns:
        ``(bottleneck, skips)``: the bottleneck tensor and a list of the
        pre-pooling tensors, ordered shallowest first.
    """
    def conv_pair(tensor, filters):
        # Two consecutive 3x3 ReLU convolutions at the same width.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    level_widths = [num_filters * 2 ** level for level in range(num_blocks)]

    skip_tensors = []
    features = input_layer
    for width in level_widths:
        features = conv_pair(features, width)
        skip_tensors.append(features)
        features = MaxPooling2D(2)(features)

    # One doubling per completed level (none if the loop did not run).
    bottleneck = conv_pair(features, num_filters * 2 ** len(level_widths))
    return bottleneck, skip_tensors

def build_decoder(bottleneck, encoder1_skips, encoder2_skips, num_filters=512, num_blocks=4):
    """Expanding path that fuses skip connections from two encoders.

    Level ``i`` (0-based) upsamples with ``UpSampling2D(2)`` and takes the
    ``-(i+1)``-th tensor of each skip list, so the deepest skips are used first.
    It concatenates along the last axis in the order
    (upsampled, encoder1 skip, encoder2 skip). It then applies two 3x3 ReLU
    convolutions (``padding='same'``) with ``num_filters // 2**i`` filters.

    Args:
        bottleneck: Keras tensor to start upsampling from.
        encoder1_skips: Skip tensors from the first encoder, shallowest first.
        encoder2_skips: Skip tensors from the second encoder, shallowest first.
        num_filters: Filter count of the first (deepest) decoder level.
        num_blocks: Number of upsampling levels.

    Returns:
        The output tensor of the last level (``bottleneck`` itself when
        ``num_blocks`` is 0).

    Raises:
        IndexError: If ``num_blocks`` exceeds the length of either skip list.
    """
    def conv_pair(tensor, filters):
        # Two consecutive 3x3 ReLU convolutions at the same width.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        return tensor

    features = bottleneck
    for level in range(num_blocks):
        features = UpSampling2D(2)(features)
        depth = -(level + 1)  # walk the skip lists from the deepest entry
        fused_skips = Concatenate()([encoder1_skips[depth], encoder2_skips[depth]])
        features = Concatenate()([features, fused_skips])
        features = conv_pair(features, num_filters // 2 ** level)
    return features

def y_net(input_shape=(256, 256, 3), num_filters=64, num_blocks=4):
    """Build a two-input encoder-decoder ("Y-Net") Keras model.

    Each input goes through its own encoder; every ``build_encoder`` call
    creates fresh layers, so the two encoders do not share weights. The two
    bottlenecks are concatenated, and one decoder upsamples the result,
    fusing the skip tensors of *both* encoders at every level. A 1x1
    convolution with a sigmoid activation produces a single-channel output.

    Args:
        input_shape: Shape of each input without the batch axis. Assumed
            channels_last, i.e. ``(height, width, channels)``; TODO confirm if
            the Keras ``image_data_format`` is ever changed. Spatial
            dimensions may be ``None``; known ones must be divisible by
            ``2 ** num_blocks``.
        num_filters: Filters in the first encoder level; doubled per level.
        num_blocks: Number of downsampling/upsampling levels.

    Returns:
        An uncompiled ``Model`` with inputs ``[input1, input2]``.

    Raises:
        ValueError: If ``num_blocks`` is negative, or a known spatial
            dimension is not divisible by ``2 ** num_blocks``.
    """
    if num_blocks < 0:
        raise ValueError(f"num_blocks must be non-negative, got {num_blocks}")

    # MaxPooling2D(2) floors odd sizes, so after upsampling the decoder tensors
    # would no longer match the skip tensors and Concatenate would fail deep
    # inside the graph. Reject such shapes early with a clear message.
    factor = 2 ** num_blocks
    for dim in input_shape[:-1]:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"spatial dimensions of input_shape={tuple(input_shape)} must be "
                f"divisible by 2**num_blocks={factor}"
            )

    # Decoder starts at the width of each encoder's bottleneck (1024 by default).
    decoder_filters = num_filters * factor

    input1 = Input(shape=input_shape)
    input2 = Input(shape=input_shape)

    encoder1_output, encoder1_skips = build_encoder(input1, num_filters=num_filters, num_blocks=num_blocks)
    encoder2_output, encoder2_skips = build_encoder(input2, num_filters=num_filters, num_blocks=num_blocks)

    merged_bottleneck = Concatenate()([encoder1_output, encoder2_output])

    decoder_output = build_decoder(
        merged_bottleneck, encoder1_skips, encoder2_skips,
        num_filters=decoder_filters, num_blocks=num_blocks,
    )
    outputs = Conv2D(1, 1, activation='sigmoid')(decoder_output)

    model = Model(inputs=[input1, input2], outputs=outputs)
    return model

In [2]:
# Build the two-input model; the input shape is spelled out for readability.
model = y_net(input_shape=(256, 256, 3))

In [3]:
# Print the layer-by-layer output shapes and parameter counts of the model.
model.summary()