### Multi-task TRANSUnet for segmentation and classification

In [1]:
import cv2
import numpy as np
import pandas as pd
import os
from glob import glob
import random
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime
from tensorflow.keras import layers, Model, regularizers
from tensorflow.keras.metrics import Recall, Precision, Accuracy, IoU
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping, TensorBoard
import tensorflow.keras.backend as K
from sklearn.model_selection import KFold
import re
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
masterlist_path = '/content/drive/MyDrive/MasterlistAug30-2017.xlsx'
df = pd.read_excel(masterlist_path)

outcome_counts = df['Outcome'].value_counts()
print("Outcome distribution:\n", outcome_counts)
print("Outcome:\n0 not implanted\n1 implanted\n2 unknown")
df = df[df['Outcome'] != 2]
image_to_outcome = dict(zip(df['File Name'], df['Outcome']))
print(image_to_outcome)
outcome_counts = df['Outcome'].value_counts()

def get_outcome_from_path(path):
    fname = os.path.basename(path)
    base_fname = os.path.splitext(fname)[0]
    # Remove trailing aug index, if exists
    base_key = re.sub(r'_\d+$', '', base_fname)
    outcome = image_to_outcome.get(base_key, None)
    if outcome is None or pd.isna(outcome):
        pass
    return outcome

Outcome distribution:
 Outcome
1.0    96
2.0    78
0.0    75
Name: count, dtype: int64
Outcome:
0 not implanted
1 implanted
2 unknown
{nan: nan, 'Blast_PCRM_R12-0137': 1.0, 'Blast_PCRM_R12-0160': 1.0, 'Blast_PCRM_R12-0173a': 0.0, 'Blast_PCRM_R12-0173b': 0.0, 'Blast_PCRM_R12-0221a': 1.0, 'Blast_PCRM_R12-0221b': 1.0, 'Blast_PCRM_R12-0223a': 1.0, 'Blast_PCRM_R12-0223b': 1.0, 'Blast_PCRM_R12-0236': 1.0, 'Blast_PCRM_R12-0254b': 0.0, 'Blast_PCRM_R12-0259': 1.0, 'Blast_PCRM_R12-0266': 1.0, 'Blast_PCRM_R12-0268a': 1.0, 'Blast_PCRM_R12-0268b': 1.0, 'Blast_PCRM_R12-0293a': 0.0, 'Blast_PCRM_R12-0296': 1.0, 'Blast_PCRM_R12-0306a': 0.0, 'Blast_PCRM_R12-0306b': 0.0, 'Blast_PCRM_R12-0315': 0.0, 'Blast_PCRM_R12-0316a': 1.0, 'Blast_PCRM_R12-0316b': 1.0, 'Blast_PCRM_R12-0326b': 1.0, 'Blast_PCRM_R12-0335': 1.0, 'Blast_PCRM_R12-0338a': 0.0, 'Blast_PCRM_R12-0338b': 0.0, 'Blast_PCRM_R12-0347a': 1.0, 'Blast_PCRM_R12-0348a': 1.0, 'Blast_PCRM_R12-0354': 0.0, 'Blast_PCRM_R12-0358a': 0.0, 'Blast_PCRM_R12-0358b':

In [3]:
""" Seeding """
np.random.seed(42)
tf.random.set_seed(42)

In [14]:
SIZE=256
def read_image_tf(path):
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (SIZE, SIZE))
    x = x / 255.0
    return x.astype(np.float32)

def read_mask_tf(path):
    path = path.decode()
    y = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    y = cv2.resize(y, (SIZE, SIZE))
    y[y != 0] = 255
    y = y / 255.0
    y = (y > 0.5).astype(np.float32)
    y = np.expand_dims(y, axis=-1)
    return y

def read_multiclass_mask_tf(te_path, zp_path, icm_path):
    te = read_mask_tf(te_path)
    zp = read_mask_tf(zp_path)
    icm = read_mask_tf(icm_path)
    mask = np.concatenate([te, zp, icm], axis=-1)  # (H, W, 3)
    return mask

def tf_parse_multi_with_cls(x_path, te_path, zp_path, icm_path, cls_label):
    def _parse(x_path, te_path, zp_path, icm_path, cls_label):
        x = read_image_tf(x_path)
        y = read_multiclass_mask_tf(te_path, zp_path, icm_path)
        return x, y, cls_label

    x, y, cls_label = tf.numpy_function(
        _parse,
        [x_path, te_path, zp_path, icm_path, cls_label],
        [tf.float32, tf.float32, tf.float32]
    )
    x.set_shape([SIZE, SIZE, 3])
    y.set_shape([SIZE, SIZE, 3])  # 3 channels: TE, ZP, ICM
    cls_label.set_shape([])
    return x, (y, cls_label)

def tf_dataset_multi_with_cls(x_paths, te_paths, zp_paths, icm_paths, cls_labels, batch_size=16, shuffle=True):
  dataset = tf.data.Dataset.from_tensor_slices((x_paths, te_paths, zp_paths, icm_paths, cls_labels))
  if shuffle:
      dataset = dataset.shuffle(buffer_size=1000)
  dataset = dataset.map(tf_parse_multi_with_cls, num_parallel_calls=tf.data.AUTOTUNE)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
  return dataset

def filter_dataset_by_outcome(images, y1, y2, y3, image_to_outcome):
    filtered_x = []
    filtered_y1 = []
    filtered_y2 = []
    filtered_y3 = []
    filtered_labels = []

    for i, image_path in enumerate(images):
        outcome = get_outcome_from_path(image_path)
        if outcome in [0.0, 1.0]:
            filtered_x.append(image_path)
            filtered_y1.append(y1[i])
            filtered_y2.append(y2[i])
            filtered_y3.append(y3[i])
            filtered_labels.append(outcome)
    return filtered_x, filtered_y1, filtered_y2, filtered_y3, filtered_labels

*   Number of training images with outcome 0: 285
*   Number of training images with outcome 1: 390
*   Number of training images with outcome 2: 320

In [5]:
images = '/content/drive/MyDrive/PFA_Final/new_data/train/images'
temask = '/content/drive/MyDrive/PFA_Final/new_data/train/GT_TE'
zpmask = '/content/drive/MyDrive/PFA_Final/new_data/train/GT_ZP'
icmmask = '/content/drive/MyDrive/PFA_Final/new_data/train/GT_ICM'

images_test = '/content/drive/MyDrive/PFA_Final/new_data/test/images'
temask_test = '/content/drive/MyDrive/PFA_Final/new_data/test/GT_TE'
zpmask_test = '/content/drive/MyDrive/PFA_Final/new_data/test/GT_ZP'
icmmask_test = '/content/drive/MyDrive/PFA_Final/new_data/test/GT_ICM'

images_valid = '/content/drive/MyDrive/PFA_Final/new_data/valid/images'
temask_valid = '/content/drive/MyDrive/PFA_Final/new_data/valid/GT_TE'
zpmask_valid = '/content/drive/MyDrive/PFA_Final/new_data/valid/GT_ZP'
icmmask_valid = '/content/drive/MyDrive/PFA_Final/new_data/valid/GT_ICM'

train_x = sorted(glob(os.path.join(images, "*.bmp")))
train_y1 = sorted(glob(os.path.join(temask, "*.bmp")))
train_y2 = sorted(glob(os.path.join(zpmask, "*.bmp")))
train_y3 = sorted(glob(os.path.join(icmmask, "*.bmp")))

test_x = sorted(glob(os.path.join(images_test, "*.bmp")))
test_y1 = sorted(glob(os.path.join(temask_test, "*.bmp")))
test_y2 = sorted(glob(os.path.join(zpmask_test, "*.bmp")))
test_y3 = sorted(glob(os.path.join(icmmask_test, "*.bmp")))

valid_x = sorted(glob(os.path.join(images_valid, "*.bmp")))
valid_y1 = sorted(glob(os.path.join(temask_valid, "*.bmp")))
valid_y2 = sorted(glob(os.path.join(zpmask_valid, "*.bmp")))
valid_y3 = sorted(glob(os.path.join(icmmask_valid, "*.bmp")))

In [6]:
filtered_train_x,filtered_train_y1,filtered_train_y2,filtered_train_y3,filtered_train_labels = filter_dataset_by_outcome(
    train_x, train_y1, train_y2, train_y3, image_to_outcome)

print("Training data:", Counter(filtered_train_labels))

# Undersample majority class (1)
indices_by_class = defaultdict(list)
for idx, label in enumerate(filtered_train_labels):
    indices_by_class[label].append(idx)

minority_class_size = min(len(indices_by_class[0]), len(indices_by_class[1]))
undersampled_indices = indices_by_class[0] + random.sample(indices_by_class[1], minority_class_size)
random.shuffle(undersampled_indices)

undersampled_train_x = [filtered_train_x[i] for i in undersampled_indices]
undersampled_train_y1 = [filtered_train_y1[i] for i in undersampled_indices]
undersampled_train_y2 = [filtered_train_y2[i] for i in undersampled_indices]
undersampled_train_y3 = [filtered_train_y3[i] for i in undersampled_indices]
undersampled_train_labels = [filtered_train_labels[i] for i in undersampled_indices]
print("After undersampling:", Counter(undersampled_train_labels))

train_ds = tf_dataset_multi_with_cls(
    undersampled_train_x,
    undersampled_train_y1,
    undersampled_train_y2,
    undersampled_train_y3,
    undersampled_train_labels,
    batch_size=8,
)

Training data: Counter({1.0: 390, 0.0: 285})
After undersampling: Counter({0.0: 285, 1.0: 285})


In [7]:
# Validation data
valid_filtered_x, valid_filtered_y1, valid_filtered_y2, valid_filtered_y3, valid_filtered_labels = filter_dataset_by_outcome(
    valid_x, valid_y1, valid_y2, valid_y3, image_to_outcome)

valid_ds = tf_dataset_multi_with_cls(
    valid_filtered_x,
    valid_filtered_y1,
    valid_filtered_y2,
    valid_filtered_y3,
    valid_filtered_labels,
    batch_size=8,
    shuffle=False
)

# Test data
test_filtered_x, test_filtered_y1, test_filtered_y2, test_filtered_y3, test_filtered_labels = filter_dataset_by_outcome(
    test_x, test_y1, test_y2, test_y3, image_to_outcome)

test_ds = tf_dataset_multi_with_cls(
    test_filtered_x,
    test_filtered_y1,
    test_filtered_y2,
    test_filtered_y3,
    test_filtered_labels,
    batch_size=8,
    shuffle=False
)
print(test_ds.take(1))

<_TakeDataset element_spec=(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), (TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.float32, name=None)))>


In [8]:
smooth = 1e-15

cls_callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_classification_loss',
        factor=0.5,       # halve LR (less aggressive than 0.3, more stable)
        patience=3,       # classification is stalling quickly, so react faster
        min_lr=5e-6,
        mode='min',
        verbose=1
    )
]


seg_callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_segmentation_loss',
        factor=0.1,
        patience=10,
        min_lr=5e-6,
        mode='min',
        verbose=1
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_segmentation_loss',
        patience=50,
        restore_best_weights=True,
        mode='min'
    )
]

def iou_multiclass(y_true, y_pred):
    y_true = tf.round(y_true)
    y_pred = tf.round(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred, axis=[0,1,2])  # sum per channel
    union = tf.reduce_sum(y_true, axis=[0,1,2]) + tf.reduce_sum(y_pred, axis=[0,1,2]) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return tf.reduce_mean(iou)

def dice_coef_multiclass(y_true, y_pred):
    intersection = tf.reduce_sum(y_true * y_pred, axis=[0,1,2])
    denominator = tf.reduce_sum(y_true, axis=[0,1,2]) + tf.reduce_sum(y_pred, axis=[0,1,2])
    dice = (2. * intersection + smooth) / (denominator + smooth)
    return tf.reduce_mean(dice)

def dice_loss(y_true, y_pred):
    return 1.0 - dice_coef_multiclass(y_true, y_pred)

In [9]:
weight_decay=1e-8

def conv_block(x, filters, dropout=False):
    x = layers.Conv2D(filters, (3, 3), padding="same", kernel_initializer='he_normal')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(filters, (3, 3), padding="same", kernel_initializer='he_normal')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    if dropout:
        x = layers.Dropout(0.5)(x)

    return x

def encoder_block(x, filters):
    c = conv_block(x, filters)
    p = layers.MaxPooling2D((2, 2))(c)
    return c, p

def decoder_block(x, skip, filters):
    x = layers.Conv2DTranspose(filters, (2, 2), strides=2, padding="same")(x)
    x = layers.Concatenate()([x, skip])
    x = conv_block(x, filters)
    return x

In [10]:
def classification_head_v1(bottleneck_feat):
  gap = layers.GlobalAveragePooling1D()(bottleneck_feat)
  x = layers.Dense(128, activation='relu')(gap)
  x = layers.Dropout(0.4)(x)
  return layers.Dense(1, activation='sigmoid', name="classification")(x)

def classification_head_v2(bottleneck_feat, dropout_rate=0.3, l2_reg=1e-4):
    x = layers.GlobalAveragePooling2D()(bottleneck_feat)
    x = layers.Dense(128, activation='relu',
                     kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(64, activation='relu',
                     kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.Dropout(dropout_rate)(x)
    output = layers.Dense(1, activation='sigmoid', name="classification")(x)
    return output

def cls_head_midfeature(mid_feat):
    x = layers.GlobalAveragePooling2D()(mid_feat)
    x = layers.LayerNormalization()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    cls_output = layers.Dense(1, activation='sigmoid', name='classification')(x)
    return cls_output

def classification_head_combined(mid_feat, bottleneck_feat, dropout_rate=0.5, l2_reg=1e-4):
    mid_pool = layers.GlobalAveragePooling2D()(mid_feat)
    b1_reshaped = layers.Reshape((-1, bottleneck_feat.shape[-1]))(bottleneck_feat)  # (B, N_patches, D)
    bottleneck_pool = layers.GlobalAveragePooling1D()(b1_reshaped)  # shape -> (B, D)

    x = layers.Concatenate()([mid_pool, bottleneck_pool])  # shape -> (B, C_mid + D)

    x = layers.Dense(128, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(dropout_rate)(x)

    x = layers.Dense(64, kernel_regularizer=regularizers.l2(l2_reg))(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(dropout_rate)(x)

    return layers.Dense(1, activation='sigmoid', name="classification")(x)

In [11]:
# image sequentialization: shaping the input x into a sequence of flattened 2D patches
def patch_embedding(x, patch_size, embed_dim):
    x = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size, padding='valid')(x)
    # (B, H/P, W/P, D=embed_dim), reshaped to (B, N, D), N=H*W/P²
    x = layers.Reshape((-1, embed_dim))(x)
    return x

def transformer_block(x, embed_dim, num_heads, mlp_dim, dropout=0.1, return_attention=False):
    norm1 = layers.LayerNormalization(epsilon=1e-6)(x)
    mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=dropout)

    if return_attention:
        attn_output, attn_scores = mha(norm1, norm1, return_attention_scores=True)
    else:
        attn_output = mha(norm1, norm1)

    x = layers.Add()([x, attn_output])

    norm2 = layers.LayerNormalization(epsilon=1e-6)(x)
    mlp_output = layers.Dense(mlp_dim, activation='gelu')(norm2)
    mlp_output = layers.Dropout(dropout)(mlp_output)
    mlp_output = layers.Dense(embed_dim)(mlp_output)
    mlp_output = layers.Dropout(dropout)(mlp_output)
    x = layers.Add()([x, mlp_output])

    if return_attention:
        return x, attn_scores
    return x

def transformer_bottleneck(x, patch_size=1, num_layers=4, embed_dim=512, num_heads=8, mlp_dim=1024, return_all_attention=False):
    h, w, c = x.shape[1:]

    # Image Sequentialization.
    x = patch_embedding(x, patch_size, embed_dim)  # (B, N, D)

    # Positional encoding
    pos_emb = tf.Variable(tf.random.normal([1, x.shape[1], embed_dim]), trainable=True)
    x = x + pos_emb # Patch Embedding.

    attn_maps = []  # attention scores
    for _ in range(num_layers):
        if return_all_attention:
            x, attn = transformer_block(x, embed_dim, num_heads, mlp_dim, return_attention=True)
            attn_maps.append(attn)
        else:
            x = transformer_block(x, embed_dim, num_heads, mlp_dim)

    # Reshape back to 2D
    new_h, new_w = h // patch_size, w // patch_size
    x = layers.Reshape((new_h, new_w, embed_dim))(x)

    if return_all_attention:
        return x, attn_maps
    return x

def build_transunet(input_shape, return_attention=False):
    inputs = tf.keras.Input(input_shape)

    # Encoder
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4 = conv_block(p3, 512, dropout=True)
    p4 = layers.MaxPooling2D((2, 2))(s4)

    # Transformer bottleneck
    if return_attention:
        b1, attention_maps = transformer_bottleneck(p4, patch_size=1, num_layers=4,
                                                    embed_dim=512, num_heads=8,
                                                    mlp_dim=1024, return_all_attention=True)
    else:
        b1 = transformer_bottleneck(p4, patch_size=1, num_layers=4,
                                    embed_dim=512, num_heads=8,
                                    mlp_dim=1024, return_all_attention=False)

    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    seg_output = layers.Conv2D(3, 1, padding="same", activation="sigmoid", name="segmentation")(d4)
    cls_output = classification_head_combined(p3, b1)

    if return_attention:
        return Model(inputs, [seg_output, cls_output, attention_maps])
    return Model(inputs, [seg_output, cls_output])

In [None]:
X = np.array(undersampled_train_x)
y1 = np.array(undersampled_train_y1)
y2 = np.array(undersampled_train_y2)
y3 = np.array(undersampled_train_y3)
y_cls = np.array(undersampled_train_labels).astype(np.float32) # Cast to float32 here

kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_no = 1
val_scores = []

for train_idx, val_idx in kf.split(X):
    print(f"\n--- Fold {fold_no} ---")

    train_ds_fold = tf_dataset_multi_with_cls(
        X[train_idx], y1[train_idx], y2[train_idx], y3[train_idx], y_cls[train_idx],
        batch_size=8, shuffle=True
    )

    val_ds_fold = tf_dataset_multi_with_cls(
        X[val_idx], y1[val_idx], y2[val_idx], y3[val_idx], y_cls[val_idx],
        batch_size=8, shuffle=False
    )

    model = build_transunet(input_shape=(SIZE, SIZE, 3), return_attention=False)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss={
            "segmentation": dice_loss,
            "classification": tf.keras.losses.BinaryCrossentropy(from_logits=False)
        },
        loss_weights={"segmentation": 1.0, "classification": 2.5},  # adjust if needed
        metrics={
            "segmentation": [dice_coef_multiclass, iou_multiclass],
            "classification": [
                tf.keras.metrics.AUC(name="auprc", curve="PR"),
                tf.keras.metrics.Precision(name="precision"),
                tf.keras.metrics.Recall(name="recall"),
                tf.keras.metrics.BinaryAccuracy(name="accuracy")
            ]
        }
    )

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=8, min_lr=1e-7, verbose=1
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=25, restore_best_weights=True, verbose=1
    )

    history = model.fit(
        train_ds_fold,
        validation_data=val_ds_fold,
        epochs=200,
        callbacks=[reduce_lr, early_stop],
        verbose=1
    )

    scores = model.evaluate(val_ds_fold, verbose=0)
    print(f"Fold {fold_no} scores: {scores}")
    val_scores.append(scores)

    fold_no += 1

val_scores = np.array(val_scores)
print("\n=== 5-Fold Cross-Validation Results ===")
print("Mean:", val_scores.mean(axis=0))
print("Std:", val_scores.std(axis=0))


--- Fold 1 ---
Epoch 1/200
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m633s[0m 10s/step - classification_accuracy: 0.5329 - classification_auprc: 0.5640 - classification_loss: 0.8022 - classification_precision: 0.5237 - classification_recall: 0.6324 - loss: 2.7742 - segmentation_dice_coef_multiclass: 0.2618 - segmentation_iou_multiclass: 0.1832 - segmentation_loss: 0.7382 - val_classification_accuracy: 0.4474 - val_classification_auprc: 0.5322 - val_classification_loss: 0.8187 - val_classification_precision: 0.4474 - val_classification_recall: 1.0000 - val_loss: 2.8653 - val_segmentation_dice_coef_multiclass: 0.2607 - val_segmentation_iou_multiclass: 0.2232 - val_segmentation_loss: 0.7393 - learning_rate: 1.0000e-04
Epoch 2/200
[1m57/57[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 634ms/step - classification_accuracy: 0.5538 - classification_auprc: 0.5270 - classification_loss: 0.8346 - classification_precision: 0.5293 - classification_recall: 0.6411 - loss: 2