In [1]:
import os
from google.colab import drive
drive.mount('/content/drive')

!pip install SimpleITK tqdm

BASE_DIR = '/content/drive/MyDrive/research_intern_med/data_root/brain_tumor/Task01_BrainTumor'
IMAGE_TR_DIR = os.path.join(BASE_DIR, 'imagesTr')
LABEL_TR_DIR = os.path.join(BASE_DIR, 'labelsTr')
IMAGE_TS_DIR = os.path.join(BASE_DIR, 'imagesTs')

Mounted at /content/drive
Collecting SimpleITK
  Downloading simpleitk-2.5.2-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.2 kB)
Downloading simpleitk-2.5.2-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (52.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 MB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.5.2


In [2]:
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl.metadata (5.2 kB)
Collecting tensorboard~=2.19.0 (from tensorflow)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorflow-io-gcs-filesystem>=0.23.1 (from tensorflow)
  Downloading tensorflow_io_gcs_filesystem-0.37.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Collecting wheel<1.0,>=0.23.0 (from astunparse>=1.6.0->tensorflow

In [3]:
import SimpleITK as sitk
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import os

IMAGE_TR_DIR = os.path.join(BASE_DIR, 'imagesTr')
LABEL_TR_DIR = os.path.join(BASE_DIR, 'labelsTr')

preprocessed_data_path = os.path.join(BASE_DIR, 'preprocessed_brain_tumor_data.npz')


def preprocess(image, label):
    image = tf.cast(image, tf.float32)
    label = tf.cast(label, tf.float32)
    mean = tf.reduce_mean(image)
    std = tf.math.reduce_std(image)
    image = (image - mean) / std
    label = tf.expand_dims(label, axis=-1)
    return image, label

def load_dataset(image_dir, label_dir):
    image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.nii.gz') and not f.startswith('._')])
    label_files = sorted([os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith('.nii.gz') and not f.startswith('._')])
    images = []
    labels = []
    print("Loading data...")
    for img_path, lbl_path in tqdm(zip(image_files, label_files), total=len(image_files)):
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        lbl = sitk.GetArrayFromImage(sitk.ReadImage(lbl_path))
        for i in range(lbl.shape[0]):
            lbl_slice = lbl[i, :, :]
            if np.sum(lbl_slice) > 0:
                img_slice_4_channels = img[:, i, :, :]
                img_slice_transposed = np.transpose(img_slice_4_channels, (1, 2, 0))
                img_slice_resized = tf.image.resize(img_slice_transposed, [128, 128])
                lbl_slice_resized = tf.image.resize(np.expand_dims(lbl_slice, axis=-1), [128, 128], method='nearest')
                images.append(img_slice_resized.numpy())
                labels.append(lbl_slice_resized.numpy().squeeze())
    return np.array(images), np.array(labels)


if os.path.exists(preprocessed_data_path):
    print(f"File with preprocessed data exists...")
    with np.load(preprocessed_data_path) as data:
        train_images = data['images']
        train_labels = data['labels']
    print("Completed loading data")

else:
    print(f"Starting data loading and prprocessing...")
    train_images, train_labels = load_dataset(IMAGE_TR_DIR, LABEL_TR_DIR)

    print(f"Saving preprocessed data...")
    np.savez_compressed(preprocessed_data_path, images=train_images, labels=train_labels)
    print(f"Completed saving preprocessed data")


print("\nGenerating Tensorflow dataset...")
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.map(preprocess)
train_dataset = train_dataset.shuffle(buffer_size=1000).batch(16).prefetch(tf.data.AUTOTUNE)

print(f"\nNumber of slices: {len(train_images)}")
for image_batch, label_batch in train_dataset.take(1):
    print(f"Image batch shape: {image_batch.shape}")
    print(f"Label batch shape: {label_batch.shape}")

File with preprocessed data exists...
Completed loading data

Generating Tensorflow dataset...

Number of slices: 33755
Image batch shape: (16, 128, 128, 4)
Label batch shape: (16, 128, 128, 1)


In [10]:
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, Input, Activation, add, concatenate, multiply, BatchNormalization, SpatialDropout2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import tensorflow.keras.backend as K

def dice_coef(y_true, y_pred, smooth=1e-7):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def attention_gate(X, g, inter_channel):
    theta_x = Conv2D(inter_channel, (1, 1), strides=(1, 1), padding='same')(X)
    phi_g = Conv2D(inter_channel, (1, 1), strides=(1, 1), padding='same')(g)

    f = Activation('relu')(add([theta_x, phi_g]))
    psi_f = Conv2D(1, (1, 1), strides=(1, 1), padding='same')(f)
    rate = Activation('sigmoid')(psi_f)

    att_x = multiply([X, rate])
    return att_x

def unet_conv_block(inputs, num_filters, dropout_rate=0.2):
    x = Conv2D(num_filters, (3, 3), padding='same', kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(num_filters, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if dropout_rate > 0:
        x = SpatialDropout2D(dropout_rate)(x)
    return x

def build_attention_unet_optimized(input_shape=(128, 128, 4), num_classes=1):
    inputs = Input(input_shape, name='main_input')

    # Encoder
    c1 = unet_conv_block(inputs, 16)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = unet_conv_block(p1, 32)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = unet_conv_block(p2, 64)
    p3 = MaxPooling2D((2, 2))(c3)

    c4 = unet_conv_block(p3, 128)
    p4 = MaxPooling2D((2, 2))(c4)

    # Bottleneck
    c5 = unet_conv_block(p4, 256, dropout_rate=0.3)

    # Decoder
    u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c5)
    att6 = attention_gate(c4, u6, 128)
    u6 = concatenate([u6, att6])
    c6 = unet_conv_block(u6, 128)

    u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c6)
    att7 = attention_gate(c3, u7, 64)
    u7 = concatenate([u7, att7])
    c7 = unet_conv_block(u7, 64)

    u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c7)
    att8 = attention_gate(c2, u8, 32)
    u8 = concatenate([u8, att8])
    c8 = unet_conv_block(u8, 32)

    u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
    att9 = attention_gate(c1, u9, 16)
    u9 = concatenate([u9, att9])
    c9 = unet_conv_block(u9, 16)

    outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(c9)

    model = Model(inputs=[inputs], outputs=[outputs])

    model.compile(optimizer=Adam(learning_rate=1e-4), loss=dice_coef_loss, metrics=[dice_coef])

    return model

if __name__ == '__main__':
    model = build_attention_unet_optimized()
    model.summary()

    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6, verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)

In [14]:
history = model.fit(
    train_dataset,
    epochs=10,
    verbose=1
)

model.save(os.path.join(BASE_DIR, 'attention_unet_brain_tumor_3_hr.keras'))

Epoch 1/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m934s[0m 443ms/step - dice_coef: 1.1229 - loss: -0.1229
Epoch 2/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m933s[0m 442ms/step - dice_coef: 1.1255 - loss: -0.1255
Epoch 3/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m933s[0m 442ms/step - dice_coef: 1.1291 - loss: -0.1291
Epoch 4/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m938s[0m 444ms/step - dice_coef: 1.1299 - loss: -0.1299
Epoch 5/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m934s[0m 443ms/step - dice_coef: 1.1314 - loss: -0.1314
Epoch 6/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m933s[0m 442ms/step - dice_coef: 1.1333 - loss: -0.1333
Epoch 7/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m936s[0m 444ms/step - dice_coef: 1.1355 - loss: -0.1355
Epoch 8/10
[1m2110/2110[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m934s[0m 442ms/step - dice_coef:

In [15]:
!pip install nibabel

Collecting nibabel
  Downloading nibabel-5.3.2-py3-none-any.whl.metadata (9.1 kB)
Downloading nibabel-5.3.2-py3-none-any.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nibabel
Successfully installed nibabel-5.3.2


In [19]:
import os
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import tensorflow as tf
from tqdm import tqdm

SAVED_MODEL_PATH= os.path.join(BASE_DIR, 'attention_unet_brain_tumor_3_hr.keras')

try:
    model = tf.keras.models.load_model(
        SAVED_MODEL_PATH,
        custom_objects={
            'dice_coef_loss': dice_coef_loss,
            'dice_coef': dice_coef,
        }
    )
    print(f"Successfully loaded model from '{SAVED_MODEL_PATH}'.")
except Exception as e:
    print(f"Error- cannot load model: {e}")
    exit()

def load_test_data(test_dir, num_samples=5):
    test_files = sorted([os.path.join(test_dir, f) for f in os.listdir(test_dir)
                         if f.endswith('.nii.gz') and not f.startswith('._')])

    print(f" NIfTI file num at '{test_dir}': {len(test_files)}")
    if not test_files:
        print(f"Error: Cannot find .nii.gz file inside '{test_dir}'.")
        return []

    test_images = []
    print("Loading test data...")

    for f_path in tqdm(test_files[:num_samples]):
        img_nifti = nib.load(f_path)
        img = img_nifti.get_fdata()

        if img.ndim == 4 and img.shape[-1] == 4:
            mid_slice_idx = img.shape[0] // 2
            img_slice_4_channels = img[mid_slice_idx, :, :, :]
            original_slice_for_display = img[mid_slice_idx, :, :, 0]

        elif img.ndim == 3:
            mid_slice_idx = img.shape[0] // 2
            single_channel_slice = img[mid_slice_idx, :, :]
            img_slice_4_channels = np.stack([single_channel_slice, single_channel_slice, single_channel_slice, single_channel_slice], axis=-1)
            original_slice_for_display = single_channel_slice

        else:
            continue

        img_slice_4_channels = img_slice_4_channels.astype(np.float32)

        mean = np.mean(img_slice_4_channels)
        std = np.std(img_slice_4_channels)
        if std == 0:
            img_slice_norm = img_slice_4_channels - mean
        else:
            img_slice_norm = (img_slice_4_channels - mean) / std

        img_tf_input = tf.convert_to_tensor(np.expand_dims(img_slice_norm, axis=0))
        img_resized = tf.image.resize(img_tf_input, [128, 128])

        test_images.append((original_slice_for_display, img_resized[0].numpy()))

    return test_images

def predict_and_visualize(model, test_images):
    plt.figure(figsize=(15, len(test_images) * 5))
    if not test_images:
        print("No image")
        return

    for i, (original_slice, processed_4_channel_slice) in enumerate(test_images):
        pred_mask = model.predict(np.expand_dims(processed_4_channel_slice, axis=0))[0]

        pred_mask_resized = tf.image.resize(np.expand_dims(pred_mask, axis=0),
                                            [original_slice.shape[0], original_slice.shape[1]]).numpy()[0].squeeze()

        plt.subplot(len(test_images), 3, i*3 + 1)
        plt.imshow(original_slice, cmap='bone')
        plt.title(f"Original Image {i+1}")
        plt.axis('off')

        plt.subplot(len(test_images), 3, i*3 + 2)
        plt.imshow(pred_mask_resized, cmap='jet', alpha=0.5)
        plt.title(f"Predicted Mask {i+1}")
        plt.axis('off')

        plt.subplot(len(test_images), 3, i*3 + 3)
        plt.imshow(original_slice, cmap='bone')
        plt.imshow(pred_mask_resized, cmap='jet', alpha=0.5)
        plt.title(f"Overlay {i+1}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

num_samples = 266
test_samples = load_test_data(IMAGE_TS_DIR, num_samples=num_samples)
predict_and_visualize(model, test_samples)

Output hidden; open in https://colab.research.google.com to view.