In [1]:
import tensorflow as tf

# Check available devices.
# tf.config.list_physical_devices does not initialize the TF runtime, whereas
# tf.test.gpu_device_name() does -- and once the runtime is initialized, a later
# tf.config.experimental.set_memory_growth call raises RuntimeError.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
print("GPU Device Name: ", ", ".join(device.name for device in gpu_devices))


Num GPUs Available:  0
GPU Device Name:  


In [None]:
import tensorflow as tf
from tensorflow.python.keras import backend as K

# Reset Keras' global state before building models.
# Use tf.keras.backend rather than tensorflow.python.keras.backend: the latter is
# TF's legacy internal Keras copy, whose clear_session() does not reset the state
# of the tf.keras models built later in this notebook.
tf.keras.backend.clear_session()


In [None]:
import os
# Relax XLA's strict GPU convolution-algorithm picker.
# NOTE(review): presumably only effective if set before XLA is first initialized
# in this process -- confirm.
os.environ.update({"XLA_FLAGS": "--xla_gpu_strict_conv_algorithm_picker=false"})


In [None]:
from tensorflow.keras import mixed_precision
# Mixed precision (float16 compute, float32 variables) is meant to speed up training
# on recent GPUs; per TF's mixed-precision guide it is not expected to help (and is
# typically slower) on CPU, so fall back to plain float32 when no GPU is visible.
# list_physical_devices does not initialize the TF runtime, so GPU memory-growth
# configuration can still be applied afterwards.
_precision = 'mixed_float16' if tf.config.list_physical_devices('GPU') else 'float32'
policy = mixed_precision.Policy(_precision)
mixed_precision.set_global_policy(policy)
print(f"Global dtype policy: {policy.name}")


In [None]:
# Enable GPU memory growth so the runtime does not reserve all GPU memory up front.
# Memory growth must be configured before the GPUs are initialized; if something
# earlier in the session already initialized them, set_memory_growth raises
# RuntimeError, which is reported here instead of aborting the notebook.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
            print(f"Using GPU: {gpu}")
    except RuntimeError as err:
        print(f"Could not enable GPU memory growth (GPUs already initialized?): {err}")

# Debugging logs.
# NOTE(review): TF's C++ logging appears to read TF_CPP_MIN_LOG_LEVEL when the
# library is loaded; tensorflow is already imported above, so this likely has no
# effect here -- move it before the first `import tensorflow` if it matters.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

In [None]:
import os
import nibabel as nib
import numpy as np
from glob import glob
from skimage.transform import resize
from skimage.metrics import structural_similarity as ssim
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, Concatenate
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import datetime
import matplotlib.pyplot as plt

# Function to load NIfTI images
def load_nii(file_path):
    """Read the NIfTI file at ``file_path`` and return its voxel data as a float32 array."""
    volume = nib.load(file_path).get_fdata()
    return volume.astype(np.float32)

# Resize images to a consistent shape
def resize_image(image, target_shape=(240, 240, 240)):
    """Resample ``image`` to ``target_shape`` using skimage's ``resize``.

    Border handling is ``mode='reflect'`` and anti-aliasing is enabled. Values are
    interpolated, so a 0/1 mask generally comes back with fractional values.
    """
    resize_options = {'mode': 'reflect', 'anti_aliasing': True}
    return resize(image, target_shape, **resize_options)

# Normalize image to range [0, 1]
def normalize(image):
    """Min-max scale ``image`` to the range [0, 1].

    Computes ``(image - min) / (max - min)``. A constant image (``max == min``, e.g.
    an all-background volume) would otherwise produce 0/0 = NaN everywhere. That NaN
    would then propagate through masking, resizing and training, so such an image
    is mapped to all zeros instead.

    Args:
        image: Array-like of numbers.

    Returns:
        An array with the same shape as ``image``; float32 input stays float32.
    """
    lo = np.min(image)
    span = np.max(image) - lo
    if span == 0:
        # Every element equals `lo`, so `image - lo` is already all zeros; dividing
        # by 1 keeps the same output dtype as the regular path.
        span = 1
    return (image - lo) / span

# Apply brain mask to image
def apply_mask(image, mask):
    """Return ``image`` multiplied elementwise by ``mask``.

    With a 0/1 mask (presumably the intended input) this zeroes everything outside
    the mask; non-binary mask values simply scale the image. Broadcasting follows
    the normal ``*`` rules.
    """
    masked_image = image * mask
    return masked_image

# Load dataset and preprocess
def load_dataset_with_manual_masks(base_path, mode='Training'):
    """Collect file paths for every patient that has all required inputs.

    Walks ``base_path/<mode>/<center>/<patient>/``. A patient is kept only if all of
    the following exist:

    - ``Preprocessed_Data/T1_preprocessed.nii.gz``
    - ``Preprocessed_Data/GADO_preprocessed.nii.gz``
    - ``Masks/Brain_Mask.nii.gz``
    - at least one ``Masks/ManualSegmentation_*.nii.gz``

    Args:
        base_path: Dataset root directory.
        mode: Split sub-directory name (e.g. ``'Training'`` or ``'Testing'``).

    Returns:
        A list of dicts with keys ``'t1'``, ``'t1ce'``, ``'mask'`` (file paths) and
        ``'manual_masks'`` (sorted list of file paths). Patients are ordered by
        sorted center name, then sorted patient name.

    Raises:
        FileNotFoundError: if ``base_path/mode`` (or a listed directory) is missing.
    """
    from glob import escape as glob_escape

    def _subdirs(parent):
        # Sorted: os.listdir order is arbitrary, which would make the training
        # order and the test cases picked for plotting vary between machines.
        return [
            os.path.join(parent, name)
            for name in sorted(os.listdir(parent))
            if os.path.isdir(os.path.join(parent, name))
        ]

    data = []
    for center in _subdirs(os.path.join(base_path, mode)):
        for patient in _subdirs(center):
            preprocessed_path = os.path.join(patient, 'Preprocessed_Data')
            masks_dir = os.path.join(patient, 'Masks')
            t1_path = os.path.join(preprocessed_path, 'T1_preprocessed.nii.gz')
            t1ce_path = os.path.join(preprocessed_path, 'GADO_preprocessed.nii.gz')
            mask_path = os.path.join(masks_dir, 'Brain_Mask.nii.gz')
            # Escape the directory part so '[', '*' or '?' in folder names match
            # literally; only the file-name part is meant as a wildcard pattern.
            manual_masks = sorted(
                glob(os.path.join(glob_escape(masks_dir), 'ManualSegmentation_*.nii.gz'))
            )

            required = (t1_path, t1ce_path, mask_path)
            if manual_masks and all(os.path.exists(p) for p in required):
                data.append({
                    't1': t1_path,
                    't1ce': t1ce_path,
                    'mask': mask_path,
                    'manual_masks': manual_masks
                })
    return data

# Preprocessing function for a given patient data
def preprocess_with_manual_masks(t1_path, t1ce_path, mask_path, manual_mask_paths):
    """Load and preprocess one patient's T1, T1CE and consensus lesion mask.

    Steps:

    1. Min-max normalize T1 and T1CE.
    2. Build a consensus mask from the manual segmentations: a voxel is kept when
       the mean over segmentations exceeds 0.5.
    3. Multiply all three volumes by the brain mask.
    4. Resize them to ``resize_image``'s default target shape.

    Args:
        t1_path: Path to the preprocessed T1 NIfTI file.
        t1ce_path: Path to the preprocessed contrast-enhanced (T1CE) NIfTI file.
        mask_path: Path to the brain-mask NIfTI file.
        manual_mask_paths: Non-empty sequence of manual segmentation NIfTI paths.

    Returns:
        ``(t1, t1ce, combined_mask)`` as float32 arrays. ``combined_mask`` holds only
        0.0 / 1.0.

    Raises:
        ValueError: if ``manual_mask_paths`` is empty (no consensus can be formed).
    """
    if not manual_mask_paths:
        raise ValueError("preprocess_with_manual_masks needs at least one manual segmentation path")

    t1 = normalize(load_nii(t1_path))
    t1ce = normalize(load_nii(t1ce_path))
    brain_mask = load_nii(mask_path)

    # Majority vote across segmentations (assumes each is a 0/1 volume).
    # NOTE: with exactly two segmentations, "> 0.5" requires both to agree.
    manual_masks = [load_nii(m) for m in manual_mask_paths]
    combined_mask = (np.mean(manual_masks, axis=0) > 0.5).astype(np.float32)

    # Apply brain mask
    t1 = apply_mask(t1, brain_mask)
    t1ce = apply_mask(t1ce, brain_mask)
    combined_mask = apply_mask(combined_mask, brain_mask)

    # Resize to a consistent shape. resize_image interpolates and anti-aliases,
    # which turns a 0/1 mask into fractional values along lesion borders, so the
    # mask is re-binarized afterwards to stay a hard segmentation label.
    t1 = resize_image(t1).astype(np.float32)
    t1ce = resize_image(t1ce).astype(np.float32)
    combined_mask = (resize_image(combined_mask) > 0.5).astype(np.float32)

    return t1, t1ce, combined_mask

# Retina U-Net model
def retina_unet_with_manual_masks(input_shape):
    """Build a small two-level 3-D U-Net producing a per-voxel sigmoid map.

    Architecture:

    - Encoder: two double-conv blocks (64 and 128 filters), each followed by 2x2x2
      max-pooling.
    - Bottleneck: a 256-filter conv.
    - Decoder: two stages of upsample, concatenate with the matching encoder skip,
      then conv.

    The output is a single-channel sigmoid map at the input resolution.

    Args:
        input_shape: Shape excluding the batch axis, ``(D, H, W, C)``. D, H and W
            must be divisible by 4 so the upsampled tensors line up with the skip
            connections in ``Concatenate``.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = Input(shape=input_shape)

    # Encoder
    c1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(inputs)
    c1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(c1)
    p1 = MaxPooling3D((2, 2, 2))(c1)

    c2 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(p1)
    c2 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(c2)
    p2 = MaxPooling3D((2, 2, 2))(c2)

    # Bottleneck
    c3 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(p2)

    # Decoder with skip connections
    u1 = UpSampling3D((2, 2, 2))(c3)
    u1 = Concatenate()([u1, c2])
    c4 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(u1)

    u2 = UpSampling3D((2, 2, 2))(c4)
    u2 = Concatenate()([u2, c1])
    c5 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(u2)

    # Keep the output layer in float32. Under a 'mixed_float16' global policy the
    # sigmoid would otherwise emit float16, which TF's mixed-precision guide advises
    # against for the final layer (numerical stability of the loss). Without a mixed
    # policy float32 is already the default, so nothing changes.
    outputs = Conv3D(1, (1, 1, 1), activation='sigmoid', dtype='float32')(c5)
    return tf.keras.Model(inputs, outputs)

# Synthesis Module
def synthesis_module(input_shape):
    """Build a fully-convolutional 3-D regressor with a single linear output channel.

    The network is four 3x3x3 ReLU convs (64, 64, 128, 128 filters) followed by a
    1x1x1 linear conv. It uses no pooling, so it works at full input resolution; its
    receptive field is 9x9x9 voxels.

    Args:
        input_shape: Shape excluding the batch axis, ``(D, H, W, C)``.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = Input(shape=input_shape)
    c1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(inputs)
    c1 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(c1)

    c2 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(c1)
    c2 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(c2)

    # Regression output kept in float32 so that, under a 'mixed_float16' global
    # policy, predictions and the MSE loss are not computed in float16 (per TF's
    # mixed-precision guide). Without a mixed policy this is already the default.
    outputs = Conv3D(1, (1, 1, 1), activation='linear', dtype='float32')(c2)
    return tf.keras.Model(inputs, outputs)

In [None]:
# Load dataset and preprocess
dataset_path = "dataset"
BATCH_SIZE = 2       # 240^3 volumes are large: keep fit *and* predict batches small
EPOCHS = 50
N_VIS_SAMPLES = 3

train_data = load_dataset_with_manual_masks(dataset_path, mode='Training')
test_data = load_dataset_with_manual_masks(dataset_path, mode='Testing')
if not train_data:
    raise RuntimeError(f"No complete training cases found under {os.path.join(dataset_path, 'Training')!r}")

train_images = [preprocess_with_manual_masks(d['t1'], d['t1ce'], d['mask'], d['manual_masks']) for d in train_data]
test_images = [preprocess_with_manual_masks(d['t1'], d['t1ce'], d['mask'], d['manual_masks']) for d in test_data]


def _stack_channel(volumes):
    """Stack 3-D volumes into a float32 batch with a trailing channel axis."""
    return np.expand_dims(np.asarray(volumes, dtype=np.float32), axis=-1)


def _latest_weights(ckpt_dir):
    """Return ``(path, epoch)`` of the newest ``cp-NNNN.weights.h5`` in ``ckpt_dir``, else ``(None, 0)``.

    ``tf.train.latest_checkpoint`` only reads the TF-format ``checkpoint`` index
    file, so it never finds Keras ``.weights.h5`` files (training would silently
    restart from scratch).
    """
    import re

    best_path, best_epoch = None, 0
    for path in glob(os.path.join(ckpt_dir, "cp-*.weights.h5")):
        match = re.fullmatch(r"cp-(\d+)\.weights\.h5", os.path.basename(path))
        if match and int(match.group(1)) > best_epoch:
            best_path, best_epoch = path, int(match.group(1))
    return best_path, best_epoch


x_train = _stack_channel([img[0] for img in train_images])     # T1 input
t1ce_train = _stack_channel([img[1] for img in train_images])  # T1CE: synthesis target
y_train = _stack_channel([img[2] for img in train_images])     # consensus lesion mask

# TensorBoard and Checkpoint Callbacks.
# Each model logs into its own sub-directory so the two fits don't mix event files.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(log_dir=os.path.join(log_dir, "retina_unet"), profile_batch=2)
synthesis_tensorboard_callback = TensorBoard(log_dir=os.path.join(log_dir, "synthesis_module"), profile_batch=2)

checkpoint_dir = "checkpoints/retina_unet/"
retina_checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(checkpoint_dir, "cp-{epoch:04d}.weights.h5"),
    save_weights_only=True,
    verbose=1
)

# Retina U-Net training (lesion segmentation)
retina_model = retina_unet_with_manual_masks(input_shape=x_train.shape[1:])
retina_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='binary_crossentropy', metrics=['accuracy'])

latest_checkpoint, retina_initial_epoch = _latest_weights(checkpoint_dir)
if latest_checkpoint:
    retina_model.load_weights(latest_checkpoint)
    print(f"Loaded Retina U-Net checkpoint from {latest_checkpoint}")

retina_model.fit(
    x_train, y_train,
    epochs=EPOCHS,
    initial_epoch=retina_initial_epoch,  # continue epoch numbering instead of overwriting cp-0001...
    batch_size=BATCH_SIZE,
    callbacks=[tensorboard_callback, retina_checkpoint_callback]
)

# predict() defaults to batch_size=32, far too large for 240^3 volumes; match fit().
semantic_features = retina_model.predict(x_train, batch_size=BATCH_SIZE)
combined_input = np.concatenate([x_train, semantic_features], axis=-1)

# Synthesis Module (T1 + lesion probability -> T1CE)
synthesis_checkpoint_dir = "checkpoints/synthesis_module/"
synthesis_checkpoint_callback = ModelCheckpoint(
    # Same weights-only .weights.h5 format as the Retina U-Net checkpoints (Keras 3
    # requires this suffix with save_weights_only=True), so _latest_weights finds it.
    filepath=os.path.join(synthesis_checkpoint_dir, "cp-{epoch:04d}.weights.h5"),
    save_weights_only=True,
    verbose=1
)

synthesis_model = synthesis_module(input_shape=combined_input.shape[1:])
synthesis_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='mse')

latest_synthesis_checkpoint, synthesis_initial_epoch = _latest_weights(synthesis_checkpoint_dir)
if latest_synthesis_checkpoint:
    synthesis_model.load_weights(latest_synthesis_checkpoint)
    print(f"Loaded Synthesis Module checkpoint from {latest_synthesis_checkpoint}")

# The target is the real T1CE volume, not the T1 input: regressing onto the input
# would only teach the model to reproduce it.
synthesis_model.fit(
    combined_input, t1ce_train,
    epochs=EPOCHS,
    initial_epoch=synthesis_initial_epoch,
    batch_size=BATCH_SIZE,
    callbacks=[synthesis_tensorboard_callback, synthesis_checkpoint_callback]
)

# Evaluation on test data
if not test_images:
    print("No complete testing cases found; skipping evaluation.")
else:
    x_test = _stack_channel([img[0] for img in test_images])
    semantic_features_test = retina_model.predict(x_test, batch_size=BATCH_SIZE)
    combined_input_test = np.concatenate([x_test, semantic_features_test], axis=-1)
    synthetic_t1ce = synthesis_model.predict(combined_input_test, batch_size=BATCH_SIZE)

    # Visualize results: middle slice along the last spatial axis (120 for 240^3).
    mid_slice = x_test.shape[3] // 2
    for i in range(min(N_VIS_SAMPLES, len(test_images))):
        panels = [
            ("Original T1", x_test[i, :, :, mid_slice, 0]),
            ("Synthetic T1CE", synthetic_t1ce[i, :, :, mid_slice, 0]),
            ("Ground Truth T1CE", test_images[i][1][:, :, mid_slice]),
        ]
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (title, image) in zip(axes, panels):
            ax.set_title(title)
            ax.imshow(image, cmap='gray')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop