import tensorflow as tf


def dice_loss(y_true, y_pred):
    numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=(1,2,3))
    denominator = tf.reduce_sum(y_true + y_pred, axis=(1,2,3))
    return 1 - numerator / denominator

# Define the model architecture
def instance_segmentation_model(S, C, HI, WI):
    inputs = tf.keras.layers.Input(shape=(None, None, 3))
    # Divide input image into SxS uniform grids
    x = tf.keras.layers.Lambda(lambda x: tf.image.extract_patches(x, [1, S, S, 1], [1, S, S, 1], [1, 1, 1, 1], 'SAME'))(inputs)
    # Semantic category prediction for each grid cell
    semantic = tf.keras.layers.Conv2D(C, 1)(x)
    # Position-sensitive encoding of pixel coordinates
    coords = tf.range(S, dtype=tf.float32) / (S - 1)
    x_coords, y_coords = tf.meshgrid(coords, coords)
    x_coords = tf.tile(tf.reshape(x_coords, (1, S, S, 1)), [tf.shape(inputs)[0], 1, 1, 1])
    y_coords = tf.tile(tf.reshape(y_coords, (1, S, S, 1)), [tf.shape(inputs)[0], 1, 1, 1])
    x = tf.concat([x, x_coords, y_coords], axis=-1)
    # Instance mask generation for each grid cell
    x = tf.keras.layers.Conv2D(HI*WI*S*S, 1)(x)
    x = tf.keras.layers.Reshape((S*S, HI*WI))(x)
    outputs = tf.keras.layers.Concatenate(axis=-1)([semantic, x])
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# Define input image dimensions and grid size
input_height = 512
input_width = 512
grid_size = 5
mask_size = 32

# Define the input tensor
input_tensor = tf.keras.layers.Input(shape=(input_height, input_width, 3))

# Add a normalization layer to the input tensor
normalized_input = tf.keras.layers.Lambda(lambda x: x / 255.0 - 0.5)(input_tensor)

# Add a CoordConv layer to the input tensor
coordconv_layer = tf.keras.layers.Lambda(lambda x: tf.stack(tf.meshgrid(tf.linspace(-1.0, 1.0, x.shape[1]), 
                                                            tf.linspace(-1.0, 1.0, x.shape[2])), 
                                             axis=-1))(input_tensor)
input_features = tf.keras.layers.Concatenate(axis=-1)([normalized_input, coordconv_layer])

# Add the backbone network to the input features
backbone = tf.keras.applications.ResNet50(input_tensor=input_features, include_top=False)

# Add the feature pyramid network (FPN) to the backbone output
fpn = tf.keras.models.Model(inputs=backbone.input, outputs=backbone.get_layer('conv5_block3_out').output)
fpn_output = fpn.output

# Add the instance segmentation heads to each FPN feature level
mask_outputs = []
for i in range(3, 6):
    # Instance category prediction branch
    category_branch = tf.keras.layers.Conv2D(80, (3, 3), padding='same', activation='relu', name=f'category_branch_{i}')(fpn.get_layer(f'conv2_block{i}_out').output)
    category_output = tf.keras.layers.Conv2D(grid_size ** 2, (1, 1), padding='valid', activation='sigmoid', name=f'category_output_{i}')(category_branch)

    # Instance mask prediction branch
    mask_branch = tf.keras.layers.Conv2D(80, (3, 3), padding='same', activation='relu', name=f'mask_branch_{i}')(fpn.get_layer(f'conv2_block{i}_out').output)
    mask_output = tf.keras.layers.Conv2D(grid_size ** 2 * mask_size, (1, 1), padding='valid', activation='sigmoid', name=f'mask_output_{i}')(mask_branch)

    mask_outputs.append(mask_output)

In [5]:
import tensorflow as tf


# Define the number of grid cells, the number of classes, and the image size
S = 5
C = 10
H = 224
W = 224

# Define the input layer for the image
input_layer = tf.keras.layers.Input(shape=(H, W, 3))

# Define the pixel coordinate layer
coord_layer = tf.keras.layers.Lambda(lambda x: tf.meshgrid(tf.linspace(-1.0, 1.0, H), tf.linspace(-1.0, 1.0, W)))(input_layer)
coord_layer = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=-1))(coord_layer)
coord_layer = tf.keras.layers.Reshape((H, W, 2))(coord_layer)
coord_layer = tf.keras.layers.Lambda(lambda x: x * 10.0)(coord_layer)  # Scale the coordinates to match the feature map

# Define the backbone network (not shown in the diagram)
backbone = ...

# Define the feature pyramid network (not shown in the diagram)
fpn = ...

# Apply the backbone network to the input image
x = backbone(input_layer)

# Apply the feature pyramid network to the backbone output
features = fpn(x)

# Concatenate the pixel coordinate layer with the feature map
features = tf.keras.layers.Concatenate(axis=-1)([features, coord_layer])

# Create the category prediction head
category_prediction = tf.keras.layers.Conv2D(C, kernel_size=3, padding='same', activation='softmax')(features)

# Create the instance mask generation head
instance_mask = tf.keras.layers.Conv2D(S*S, kernel_size=3, padding='same', activation='sigmoid')(features)

# Reshape the output of the instance mask generation head to HI×WI×S^2
instance_mask = tf.keras.layers.Reshape((H, W, S*S))(instance_mask)

# Define a lambda layer to extract the instance mask at each grid cell
mask_extraction_layer = tf.keras.layers.Lambda(lambda x: tf.stack(
    [tf.gather(tf.transpose(x, perm=[2, 0, 1]), i) for i in range(S*S)], axis=-1))

# Apply the mask extraction layer to the instance mask
instance_mask = mask_extraction_layer(instance_mask)

# Multiply the category prediction by the instance mask to get the final output
output = tf.keras.layers.Multiply()([category_prediction, instance_mask])

# Define the model with the input and output layers
model = tf.keras.models.Model(inputs=input_layer, outputs=output)



def CoordConv(x):
    """Implementation of CoordConv operation."""
    batch_size, height, width, channels = tf.shape(x)
    x_coords = tf.range(-1, 1, 2/height, dtype=tf.float32)
    y_coords = tf.range(-1, 1, 2/width, dtype=tf.float32)
    x_coords = tf.tile(tf.reshape(x_coords, [1, height, 1, 1]), [batch_size, 1, width, 1])
    y_coords = tf.tile(tf.reshape(y_coords, [1, 1, width, 1]), [batch_size, height, 1, 1])
    return tf.concat([x, x_coords, y_coords], axis=3)

def build_SOLO_head(input_tensor, num_classes):
    """Builds the SOLO head network architecture."""
    # Feature pyramid network (FPN)
    C3, C4, C5 = tf.keras.applications.ResNet50(include_top=False, weights=None).outputs
    P3 = C3
    P4 = tf.keras.layers.Add()([tf.keras.layers.UpSampling2D()(P3), C4])
    P5 = tf.keras.layers.Add()([tf.keras.layers.UpSampling2D()(P4), C5])
    P3 = tf.keras.layers.Conv2D(256, kernel_size=1, padding='same')(P3)
    P4 = tf.keras.layers.Conv2D(256, kernel_size=1, padding='same')(P4)
    P5 = tf.keras.layers.Conv2D(256, kernel_size=1, padding='same')(P5)
    P6 = tf.keras.layers.MaxPooling2D(pool_size=1, strides=2, padding='same')(P5)
    features = [P3, P4, P5, P6]

    # Category prediction sub-network
    cls_preds = []
    for feature in features:
        cls_pred = tf.keras.layers.Conv2D(num_classes, kernel_size=3, padding='same')(feature)
        cls_pred = tf.keras.layers.Activation('sigmoid')(cls_pred)
        cls_preds.append(cls_pred)

    # Instance mask segmentation sub-network
    mask_preds = []
    for feature in features:
        mask_feat = CoordConv(feature)
        mask_feat = tf.keras.layers.Conv2D(256, kernel_size=3, padding='same')(mask_feat)
        mask_feat = tf.keras.layers.ReLU()(mask_feat)
        mask_feat = tf.keras.layers.Conv2D(256, kernel_size=3, padding='same')(mask_feat)
        mask_feat = tf.keras.layers.ReLU()(mask_feat)
        mask_pred = tf.keras.layers.Conv2D(1, kernel_size=1, activation='sigmoid')(mask_feat)
        mask_preds.append(mask_pred)

    # Upsampling and concatenation
    for i in range(len(mask_preds)):
        if i == 0:
            mask_preds[i] = tf.keras.layers.UpSampling2D()(mask_preds[i])
        else:
            mask_preds[i] = tf.keras.layers.UpSampling2D()(mask_preds[i])
            mask_preds[i] = tf.keras.layers.Add()([mask_preds[i], mask_preds[i-1]])
    mask_preds[-1] = tf.keras.layers.Conv2DTranspose(num_classes, kernel_size=2, strides=2, padding='same')(mask_preds[-1])
    mask_preds[-1] = tf.keras.layers.Activation('sigmoid')(mask_preds[-1])
    mask_preds.reverse()

    # SOLO head output tensor
    solo_output = tf.keras.layers.Concatenate(axis=-1)([category_output, mask_output])
    
    
    # Reshape the SOLO output tensor to (batch_size, num_grids, output_height, output_width, num_classes+4)
    solo_output = tf.reshape(solo_output, [batch_size, num_grids, output_height, output_width, num_classes+4])

    # Gather the raw instance segmentation results
    raw_results = tf.gather_nd(solo_output, indices)

    # Apply sigmoid to the category scores
    category_scores = tf.sigmoid(raw_results[..., :num_classes])

    # Apply softmax to the category scores
    category_scores = tf.nn.softmax(category_scores, axis=-1)

    # Apply sigmoid to the mask predictions
    mask_predictions = tf.sigmoid(raw_results[..., num_classes:])

    # Apply NMS to obtain the final instance segmentation results
    instance_masks = []
    for i in range(batch_size):
        # Obtain the category scores and mask predictions for the i-th image
        category_scores_i = category_scores[i, ...]
        mask_predictions_i = mask_predictions[i, ...]

        # Apply NMS to the predictions for the i-th image
        nms_indices = tf.image.non_max_suppression(
            boxes=grid_boxes,
            scores=tf.reshape(category_scores_i, [-1]),
            max_output_size=max_instances_per_grid,
            iou_threshold=nms_threshold
        )

        # Gather the mask predictions for the selected indices
        nms_masks = tf.gather(mask_predictions_i, nms_indices, axis=-1)

        # Resize the masks to the original image size
        nms_masks = tf.image.resize(
            nms_masks,
            size=(image_height, image_width),
            method=tf.image.ResizeMethod.BILINEAR
        )

        instance_masks.append(nms_masks)

    # Stack the instance masks for all images in the batch
    instance_masks = tf.stack(instance_masks, axis=0)

    return instance_masks


def dice_loss(y_true, y_pred):
    """Computes the Dice Loss between the ground truth masks (y_true) and the predicted masks (y_pred)"""
    numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=(1, 2))
    denominator = tf.reduce_sum(y_true + y_pred, axis=(1, 2))
    return 1 - (numerator + 1) / (denominator + 1)

def solo_loss(y_true, y_pred):
    """Computes the SOLO training loss, including both the categorical loss and the mask loss"""
    Lcate = tf.keras.losses.CategoricalCrossentropy()(y_true[:, :, :, :-2], y_pred[:, :, :, :-2])
    Npos = tf.reduce_sum(tf.cast(tf.reduce_max(y_true[:, :, :, :-2], axis=-1) > 0, tf.float32))
    mask_targets = tf.reshape(y_true[:, :, :, -2:], [-1, tf.shape(y_true)[1] * tf.shape(y_true)[2], 2])
    mask_preds = tf.reshape(y_pred[:, :, :, -2:], [-1, tf.shape(y_pred)[1] * tf.shape(y_pred)[2], 2])
    Lmask = tf.reduce_sum(dice_loss(mask_targets[:, :, 1], mask_preds[:, :, 1]) * mask_targets[:, :, 0]) / (Npos + 1e-6)
    return Lcate + 3 * Lmask

In [19]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Concatenate, MaxPooling2D, Conv2DTranspose, Dropout, BatchNormalization, Activation, Add
from tensorflow.keras.models import Model



def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    filters1, filters2, filters3 = filters

    x = Conv2D(filters1, (1, 1), strides=strides,
               name=f'res{stage}{block}_branch2a')(input_tensor)
    x = BatchNormalization(axis=3, name=f'bn{stage}{block}_branch2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same',
               name=f'res{stage}{block}_branch2b')(x)
    x = BatchNormalization(axis=3, name=f'bn{stage}{block}_branch2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=f'res{stage}{block}_branch2c')(x)
    x = BatchNormalization(axis=3, name=f'bn{stage}{block}_branch2c')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=f'res{stage}{block}_branch1')(input_tensor)
    shortcut = BatchNormalization(
        axis=3, name=f'bn{stage}{block}_branch1')(shortcut)

    x = Add()([x, shortcut])
    x = Activation('relu')(x)
    return x


def identity_block(input_tensor, kernel_size, filters, stage, block):
    """
    identity block 구현

    Arguments:
    input_tensor -- 입력 tensor
    kernel_size -- middle layer conv filter 크기
    filters -- tuple 형태로 각 conv layer filter 수
    stage -- layer 그룹
    block -- 블록 이름

    Returns:
    output tensor
    """

    # conv layer들의 이름 지정
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # filter 수
    filters1, filters2, filters3 = filters

    # Shortcut path
    x_shortcut = input_tensor

    # 1st conv layer
    x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor)
    x = BatchNormalization(axis=3, name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    # 2nd conv layer
    x = Conv2D(filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x)
    x = BatchNormalization(axis=3, name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    # 3rd conv layer
    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    x = BatchNormalization(axis=3, name=bn_name_base + '2c')(x)

    # shortcut과 합치기
    x = Add()([x, x_shortcut])
    x = Activation('relu')(x)
    return x



def FPN50(inputs):
    # Backbone - ResNet50
    x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv1')(inputs)
    x = BatchNormalization(axis=3, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # Residual blocks
    c2 = conv_block(x, stage=2, block='a')
    c3 = identity_block(x, stage=3, block='a')
    c4 = identity_block(x, stage=4, block='a')
    c5 = identity_block(x, stage=5, block='a')

    # Feature Pyramid Network
    p5 = Conv2D(256, (1, 1), name='fpn_c5p5')(c5)
    p4 = Concatenate()([Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', name='fpn_p5upsampled')(p5), Conv2D(256, (1, 1), name='fpn_c4p4')(c4)])
    p3 = Concatenate()([Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', name='fpn_p4upsampled')(p4), Conv2D(256, (1, 1), name='fpn_c3p3')(c3)])
    p2 = Concatenate()([Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', name='fpn_p3upsampled')(p3), Conv2D(256, (1, 1), name='fpn_c2p2')(c2)])

    # Top-down pathway
    p5 = Conv2D(256, (3, 3), padding='same', name='fpn_p5')(p5)
    p4 = Conv2D(256, (3, 3), padding='same', name='fpn_p4')(p4)
    p3 = Conv2D(256, (3, 3), padding='same', name='fpn_p3')(p3)
    p2 = Conv2D(256, (3, 3), padding='same', name='fpn_p2')(p2)
    
    # Mask branch
    x = p5
    for i in range(4):
        x = Conv2D(256, (3, 3), padding='same', name=f'mask_conv{i}')(x)
        x = BatchNormalization(name=f'mask_bn{i}')(x)
        x = Activation('relu')(x)
    mask_output = Conv2D(1, (1, 1), activation='sigmoid', name='mask_output')(x)

    # Class branch
    x = p5
    for i in range(4):
        x = Conv2D(256, (3, 3), padding='same', name=f'class_conv{i}')(x)
        x = BatchNormalization(name=f'class_bn{i}')(x)
        x = Activation('relu')(x)
    class_output = Conv2D(NUM_CLASSES, (1, 1), activation='softmax', name='class_output')(x)

    # Build model
    model = Model(inputs=inputs, outputs=[mask_output, class_output], name='FPN50')
    return model


def dice_loss_fn(y_true, y_pred):
    numerator = 2 * tf.reduce_sum(y_true * y_pred, axis=-1)
    denominator = tf.reduce_sum(y_true + y_pred, axis=-1)
    dice_loss = 1 - (numerator + 1) / (denominator + 1)
    return dice_loss


def train(model, train_dataset, optimizer, epoch, max_epoch):
    # Define loss objects
    focal_loss = tfa.losses.SigmoidFocalCrossEntropy(reduction=tf.keras.losses.Reduction.NONE)
    dice_loss = dice_loss_fn

    # Define lambda function for combining loss
    def combined_loss(y_true, y_pred):
        categorical_loss = focal_loss(y_true[:, :, :, 0], y_pred[:, :, :, 0])
        mask_loss = dice_loss(y_true[:, :, :, 1:], y_pred[:, :, :, 1:])
        return categorical_loss + 3 * mask_loss

    # Define metrics for monitoring training
    categorical_loss_metric = tf.keras.metrics.Mean(name='categorical_loss')
    mask_loss_metric = tf.keras.metrics.Mean(name='mask_loss')
    total_loss_metric = tf.keras.metrics.Mean(name='total_loss')

    # Iterate over dataset
    for batch_idx, (images, targets) in enumerate(train_dataset):
        # Forward pass
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = combined_loss(targets, predictions)

        # Backward pass
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update metrics
        categorical_loss_metric.update_state(focal_loss(targets[:, :, :, 0], predictions[:, :, :, 0]))
        mask_loss_metric.update_state(dice_loss(targets[:, :, :, 1:], predictions[:, :, :, 1:]))
        total_loss_metric.update_state(loss)

        # Print progress
        if batch_idx % 10 == 0:
            print('Epoch {}/{} Batch {}/{} Categorical loss {:.4f} Mask loss {:.4f} Total loss {:.4f}'
                  .format(epoch + 1, max_epoch, batch_idx, len(train_dataset),
                          categorical_loss_metric.result(), mask_loss_metric.result(), total_loss_metric.result()))

    # Reset metrics
    categorical_loss_metric.reset_states()
    mask_loss_metric.reset_states()
    total_loss_metric.reset_states()

In [24]:
inputs = tf.keras.Input(shape=(512,512, 3))
#def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)
               
               
# Backbone - ResNet50
x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv1')(inputs)
x = BatchNormalization(axis=3, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

# Residual blocks
c2 = conv_block(x, kernel_size=(7,7), filters=(16,32,64), stage=2, block='a')
c3 = identity_block(x, kernel_size=(7,7), filters=(16,32,64), stage=3, block='a')
c4 = identity_block(x, kernel_size=(7,7), filters=(16,32,64), stage=4, block='a')
c5 = identity_block(x, kernel_size=(7,7), filters=(16,32,64), stage=5, block='a')

# Feature Pyramid Network
p5 = Conv2D(256, (1, 1), name='fpn_c5p5')(c5)
p4 = Concatenate()([Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', name='fpn_p5upsampled')(p5), Conv2D(256, (1, 1), name='fpn_c4p4')(c4)])
p3 = Concatenate()([Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', name='fpn_p4upsampled')(p4), Conv2D(256, (1, 1), name='fpn_c3p3')(c3)])
p2 = Concatenate()([Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', name='fpn_p3upsampled')(p3), Conv2D(256, (1, 1), name='fpn_c2p2')(c2)])

# Top-down pathway
p5 = Conv2D(256, (3, 3), padding='same', name='fpn_p5')(p5)
p4 = Conv2D(256, (3, 3), padding='same', name='fpn_p4')(p4)
p3 = Conv2D(256, (3, 3), padding='same', name='fpn_p3')(p3)
p2 = Conv2D(256, (3, 3), padding='same', name='fpn_p2')(p2)

# Mask branch
x = p5
for i in range(4):
    x = Conv2D(256, (3, 3), padding='same', name=f'mask_conv{i}')(x)
    x = BatchNormalization(name=f'mask_bn{i}')(x)
    x = Activation('relu')(x)
mask_output = Conv2D(1, (1, 1), activation='sigmoid', name='mask_output')(x)

# Class branch
x = p5
for i in range(4):
    x = Conv2D(256, (3, 3), padding='same', name=f'class_conv{i}')(x)
    x = BatchNormalization(name=f'class_bn{i}')(x)
    x = Activation('relu')(x)
class_output = Conv2D(NUM_CLASSES, (1, 1), activation='softmax', name='class_output')(x)

# Build model
model = Model(inputs=inputs, outputs=[mask_output, class_output], name='FPN50')

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 256, 256, 256), (None, 128, 128, 256)]