In [1]:
import nibabel as nib
import numpy as np
import os
import cv2
import types
import tensorflow as tf
import tensorflow.keras.backend as K
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
from skimage.io import imread, imshow
from skimage.transform import resize
from keras.models import Model
from keras.layers import Input, BatchNormalization, Activation
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.layers import concatenate
from sklearn.model_selection import train_test_split

# -----DEBUG-----
def imports():
    for name, val in globals().items():
        if isinstance(val, types.ModuleType):
            yield val.__name__
print(list(imports()))

2024-08-03 07:16:28.409848: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-03 07:16:28.450245: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


['builtins', 'builtins', 'nibabel', 'numpy', 'os', 'cv2', 'types', 'tensorflow', 'tensorflow.keras.backend', 'matplotlib.pyplot']


In [2]:
# Paths
TRAIN_DATASET_PATH = '/home/user/Tf_script/dataset/ISLES_2022/rawdata/'
TRAINMask_DATASET_PATH = '/home/user/Tf_script/dataset/ISLES_2022/derivatives/'

# Get dataset details
trainfolders = os.listdir(TRAIN_DATASET_PATH)
train_directories = [f.path for f in os.scandir(TRAIN_DATASET_PATH) if f.is_dir()]

# train_ids = [train_directories[i][48:66] for i in range(len(train_directories))]
train_directory_startindex = train_directories[0].find("sub")
train_ids = sorted([train_directories[i][train_directory_startindex:] for i in range(len(train_directories))])

maskfolders = os.listdir(TRAINMask_DATASET_PATH)
mask_directories = [f.path for f in os.scandir(TRAINMask_DATASET_PATH) if f.is_dir()]

# mask_ids = [mask_directories[i][48:66] for i in range(len(mask_directories))]
mask_id_startindex = mask_directories[0].find("sub")
mask_ids = sorted([mask_directories[i][mask_id_startindex:] for i in range(len(mask_directories))])

# -----DEBUG-----
print("Train IDs: ", len(train_ids))
# print(train_ids[0], "to", train_ids[-1])
print(sorted(train_ids)[:5])
print("Mask IDs: ", len(mask_ids))
# print(mask_ids[0], "to", mask_ids[-1])
print(sorted(mask_ids)[:5])

Train IDs:  250
['sub-strokecase0001', 'sub-strokecase0002', 'sub-strokecase0003', 'sub-strokecase0004', 'sub-strokecase0005']
Mask IDs:  250
['sub-strokecase0001', 'sub-strokecase0002', 'sub-strokecase0003', 'sub-strokecase0004', 'sub-strokecase0005']


In [3]:
# train_test_ids, val_ids, train_test_mask, val_mask = train_test_split(train_ids, mask_ids, test_size=0.15, random_state=42)
# train_ids, test_ids, train_mask, test_mask = train_test_split(train_test_ids, train_test_mask, test_size=0.15, random_state=42)
train_test_ids, val_ids,train_test_mask, val_mask = train_test_split(train_ids,mask_ids,test_size=0.2,random_state = 32) 
train_ids,  test_ids, train_mask , test_mask = train_test_split(train_test_ids,train_test_mask,test_size=0.2,random_state = 32)

scaler = MinMaxScaler()

IMG_SIZE = 112

# -----DEBUG-----
print("Dimensions: ", ("{} X {}".format(IMG_SIZE, IMG_SIZE)))

Dimensions:  112 X 112


In [4]:
# Performance metrics
def dice_coeff(y_true, y_pred):
    y_true_new = K.flatten(y_true)
    y_pred_new = K.flatten(y_pred)
    denominator = K.sum(y_true_new) + K.sum(y_pred_new)
    numerator = K.sum(y_true_new * y_pred_new)
    return (2 * numerator + 1) / (denominator + 1)

def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return 1 - (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def iou(y_true, y_pred):
    intersec = K.sum(y_true * y_pred)
    union = K.sum(y_true + y_pred)
    return (intersec + 0.1) / (union - intersec + 0.1)

In [5]:
# Data generator class
class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, list_IDs, dim=(IMG_SIZE, IMG_SIZE), batch_size=1, n_channels=1, shuffle=True):
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(Batch_ids)
        return X, y

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        X = []
        y = []
        for i in Batch_ids:
            case_path = os.path.join(TRAIN_DATASET_PATH, i, 'ses-0001/dwi')
            nii_files = [f for f in os.listdir(case_path) if f.endswith('.nii.gz')]

            if not nii_files:
                print(f"No .nii.gz files found in {case_path}")
                continue

            file_path = os.path.join(case_path, nii_files[0])
            dwi = nib.load(file_path).get_fdata()
            dwi = scaler.fit_transform(dwi.reshape(-1, dwi.shape[-1])).reshape(dwi.shape)
            slices = dwi.shape[2]
            X_case = np.zeros((slices, IMG_SIZE, IMG_SIZE, 1))

            case_path2 = os.path.join(TRAINMask_DATASET_PATH, i)
            data_path_2 = os.path.join(case_path2 + '/ses-0001', f'{i}_ses-0001_msk.nii.gz')

            if not os.path.exists(data_path_2):
                print(f"Mask file not found: {data_path_2}")
                continue

            msk = nib.load(data_path_2).get_fdata()
            msk_slices = msk.shape[2]
            y_case = np.zeros((msk_slices, IMG_SIZE, IMG_SIZE))

            for j in range(slices):
                X_case[j, :, :, 0] = cv2.resize(dwi[:, :, j], (IMG_SIZE, IMG_SIZE))
                y_case[j, :, :] = cv2.resize(msk[:, :, j], (IMG_SIZE, IMG_SIZE))

            X.append(X_case)
            y.append(y_case)

        X = np.concatenate(X, axis=0).astype(np.float32)
        y = np.concatenate(y, axis=0).astype(np.float32)
        mask = tf.one_hot(y, depth=1)
        return X, mask

training_generator = DataGenerator(train_ids, batch_size=1)  # Reduce batch size
val_generator = DataGenerator(val_ids, batch_size=1)
test_generator = DataGenerator(test_ids, batch_size=1)

# -----DEBUG-----
print("Training: ", len(training_generator))
print("Validation: ", len(val_generator))
print("Testing: ", len(test_generator))

Training:  160
Validation:  50
Testing:  40


In [14]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from math import log2
import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model

def mlp(x, cf):
    x = L.Dense(cf["mlp_dim"], activation="gelu")(x)
    x = L.Dropout(cf["dropout_rate"])(x)
    x = L.Dense(cf["hidden_dim"])(x)
    x = L.Dropout(cf["dropout_rate"])(x)
    return x

def transformer_encoder(x, cf):
    skip_1 = x
    x = L.LayerNormalization()(x)
    x = L.MultiHeadAttention(
        num_heads=cf["num_heads"], key_dim=cf["hidden_dim"]
    )(x, x)
    x = L.Add()([x, skip_1])

    skip_2 = x
    x = L.LayerNormalization()(x)
    x = mlp(x, cf)
    x = L.Add()([x, skip_2])

    return x

def conv_block(x, num_filters, kernel_size=3):
    x = L.Conv2D(num_filters, kernel_size=kernel_size, padding="same")(x)
    x = L.BatchNormalization()(x)
    x = L.ReLU()(x)
    return x

def deconv_block(x, num_filters, strides=2):
    x = L.Conv2DTranspose(num_filters, kernel_size=2, padding="same", strides=strides)(x)
    return x

def build_unetr_2d(cf):
    """ Inputs """
    input_shape = (cf["num_patches"], cf["patch_size"]*cf["patch_size"]*cf["num_channels"])
    inputs = L.Input(input_shape) ## (None, 256, 3072)

    """ Patch + Position Embeddings """
    patch_embed = L.Dense(cf["hidden_dim"])(inputs) ## (None, 256, 768)

    positions = tf.range(start=0, limit=cf["num_patches"], delta=1) ## (256,)
    pos_embed = L.Embedding(input_dim=cf["num_patches"], output_dim=cf["hidden_dim"])(positions) ## (256, 768)
    x = patch_embed + pos_embed ## (None, 256, 768)

    """ Transformer Encoder """
    skip_connection_index = [3, 6, 9, 12]
    skip_connections = []

    for i in range(1, cf["num_layers"]+1, 1):
        x = transformer_encoder(x, cf)

        if i in skip_connection_index:
            skip_connections.append(x)

    """ CNN Decoder """
    z3, z6, z9, z12 = skip_connections

    ## Reshaping
    z0 = L.Reshape((cf["image_size"], cf["image_size"], cf["num_channels"]))(inputs)

    shape = (
        cf["image_size"]//cf["patch_size"],
        cf["image_size"]//cf["patch_size"],
        cf["hidden_dim"]
    )
    z3 = L.Reshape(shape)(z3)
    z6 = L.Reshape(shape)(z6)
    z9 = L.Reshape(shape)(z9)
    z12 = L.Reshape(shape)(z12)

    ## Additional layers for managing different patch sizes
    total_upscale_factor = int(log2(cf["patch_size"]))
    upscale = total_upscale_factor - 4

    if upscale >= 2: ## Patch size 16 or greater
        z3 = deconv_block(z3, z3.shape[-1], strides=2**upscale)
        z6 = deconv_block(z6, z6.shape[-1], strides=2**upscale)
        z9 = deconv_block(z9, z9.shape[-1], strides=2**upscale)
        z12 = deconv_block(z12, z12.shape[-1], strides=2**upscale)
        # print(z3.shape, z6.shape, z9.shape, z12.shape)

    if upscale < 0: ## Patch size less than 16
        p = 2**abs(upscale)
        z3 = L.MaxPool2D((p, p))(z3)
        z6 = L.MaxPool2D((p, p))(z6)
        z9 = L.MaxPool2D((p, p))(z9)
        z12 = L.MaxPool2D((p, p))(z12)

    ## Decoder 1
    x = deconv_block(z12, 128)

    s = deconv_block(z9, 128)
    s = conv_block(s, 128)

    x = L.Concatenate()([x, s])

    x = conv_block(x, 128)
    x = conv_block(x, 128)

    ## Decoder 2
    x = deconv_block(x, 64)

    s = deconv_block(z6, 64)
    s = conv_block(s, 64)
    s = deconv_block(s, 64)
    s = conv_block(s, 64)

    x = L.Concatenate()([x, s])
    x = conv_block(x, 64)
    x = conv_block(x, 64)

    ## Decoder 3
    x = deconv_block(x, 32)

    s = deconv_block(z3, 32)
    s = conv_block(s, 32)
    s = deconv_block(s, 32)
    s = conv_block(s, 32)
    s = deconv_block(s, 32)
    s = conv_block(s, 32)

    x = L.Concatenate()([x, s])
    x = conv_block(x, 32)
    x = conv_block(x, 32)

    ## Decoder 4
    x = deconv_block(x, 16)

    s = conv_block(z0, 16)
    s = conv_block(s, 16)

    x = L.Concatenate()([x, s])
    x = conv_block(x, 16)
    x = conv_block(x, 16)

    """ Output """
    outputs = L.Conv2D(1, kernel_size=1, padding="same", activation="sigmoid")(x)

    return Model(inputs, outputs, name="UNETR_2D")

In [15]:
input_shape = (IMG_SIZE, IMG_SIZE, 1)
num_classes = 1  # Binary classification
model = ViT(input_shape, num_classes)
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 112, 112, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_7 (Conv2D)              (None, 28, 28, 64)   1088        ['input_4[0][0]']                
                                                                                                  
 reshape_5 (Reshape)            (None, 784, 64)      0           ['conv2d_7[0][0]']               
                                                                                                  
 tf.__operators__.add_3 (TFOpLa  (None, 784, 64)     0           ['reshape_5[0][0]']        

In [16]:
# Compile the model
model.compile(optimizer='adam', loss=dice_loss, metrics=[dice_coeff, iou])

# Checkpoints and learning rate adjustments
checkpoint = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True, verbose=1)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True, verbose=1)

# Train the model
history = model.fit(training_generator,
                    validation_data=val_generator,
                    epochs=250,
                    steps_per_epoch=len(train_ids),
                    callbacks=[checkpoint, reduce_lr, early_stopping])

# Save the model
model.save("vit_brain_lesion_segmentation.h5")

Epoch 1/250


2024-08-03 07:24:54.912715: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]
2024-08-03 07:25:04.158160: W tensorflow/core/framework/op_kernel.cc:1818] INVALID_ARGUMENT: required broadcastable shapes
2024-08-03 07:25:04.158228: I tensorflow/core/common_runtime/executor.cc:1197] [/job:localhost/replica:0/task:0/device:GPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: required broadcastable shapes
	 [[{{node add_3}}]]
2024-08-03 07:25:04.158255: W tensorflow/core/framework/op_kernel.cc:1818] INVALID_ARGUMENT: required broadcastable shapes
2024-08-03 07:25:04.158285: W tensorflow/core/framework/op_kernel.cc:1818] INVALID_ARGUMENT: required broadcastable shapes


InvalidArgumentError: Graph execution error:

Detected at node 'add_3' defined at (most recent call last):
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/traitlets/config/application.py", line 992, in launch_instance
      app.start()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 736, in start
      self.io_loop.start()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
      self._run_once()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
      handle._run()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
      await self.process_one()
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 505, in process_one
      await dispatch(*args)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
      await result
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
      reply_content = await reply_content
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
      res = shell.run_cell(
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmp/ipykernel_1848780/2965651932.py", line 10, in <module>
      history = model.fit(training_generator,
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/training.py", line 1685, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/training.py", line 1284, in train_function
      return step_function(self, iterator)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/training.py", line 1268, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/training.py", line 1249, in run_step
      outputs = model.train_step(data)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/training.py", line 1055, in train_step
      return self.compute_metrics(x, y, y_pred, sample_weight)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/training.py", line 1149, in compute_metrics
      self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 605, in update_state
      metric_obj.update_state(y_t, y_p, sample_weight=mask)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/utils/metrics_utils.py", line 77, in decorated
      update_op = update_state_fn(*args, **kwargs)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/metrics/base_metric.py", line 140, in update_state_fn
      return ag_update_state(*args, **kwargs)
    File "/home/user/anaconda3/envs/adi_ani_tf/lib/python3.10/site-packages/keras/metrics/base_metric.py", line 691, in update_state
      matches = ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/tmp/ipykernel_1848780/457084516.py", line 17, in iou
      union = K.sum(y_true + y_pred)
Node: 'add_3'
required broadcastable shapes
	 [[{{node add_3}}]] [Op:__inference_train_function_61904]