In [2]:
import os
import glob

import tqdm
import random
import numpy as np
import pandas as pd
import nibabel as nib
import scipy.ndimage as sci
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
# from keras_preprocessing.image import ImageDataGenerator

In [3]:
def rotation_3d(img, angle, axis = (1,2), order = 3, isseg = False):
    if isseg == False:
        # interpolate image if the image is not a segmentation mask
        rotated = sci.rotate(img, angle, axis, order = order, reshape=False)
    else:
        # do not interpolate i.e., order = 0
        rotated = sci.rotate(img, angle, axis, order = 0, reshape = False)

    return rotated

# def get_augmentation(patch_size):
#     # CURRENT CONFIG:
#     # For RandomRotate90 -> given image (z,x,y), rotate along z-axis with axes = (2,1)
#     return Compose([
#         # Rotate((0, 0), (-2, 2), (-2, 2), interpolation = 0, always_apply= True, p =1)
#         RandomRotate90(axes = (2,1), always_apply = True, p = 1)
#     ], p=1.0)

Load In Batches
---

In [4]:
scalar = MinMaxScaler()
def generate_brats_batch(file_pattern, contrasts, batch_size=32, tumour='*', patient_ids='*', augment_size=None):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the contrast is treated like a colour channel.
    
    Example:
    x_batch shape: (32, 240, 240, 155, 4)
    y_batch shape: (32, 240, 240, 155)
    
    augment_size must be less than or equal to the batch_size, if None will not augment.
    
    """
    n_classes = 4

    # get list of filenames for every contrast available
    keys = dict(prefix=prefix, tumour=tumour)
    filenames_by_contrast = {}
    for contrast in contrasts:
        filenames_by_contrast[contrast] = glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_ids, **keys)) if patient_ids == '*' else []
        if patient_ids != '*':
            contrast_files = []
            for patient_id in patient_ids:
                contrast_files.extend(glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_id, **keys)))
            filenames_by_contrast[contrast] = contrast_files
    
    # get the shape of one 3D volume and initialize the batch lists
    arbitrary_contrast = contrasts[0]
    shape = nib.load(filenames_by_contrast[arbitrary_contrast][0]).get_fdata().shape
    x_batch = np.empty((batch_size, ) + shape + (len(contrasts), )) #, dtype=np.int32)
    y_batch = np.empty((batch_size, ) + shape + (n_classes,)) #, dtype=np.int32)
    num_images = len(filenames_by_contrast[arbitrary_contrast])
    np.random.shuffle(filenames_by_contrast[arbitrary_contrast])
    for bindex in tqdm.tqdm_notebook(range(0, num_images, batch_size), total=num_images):
        filenames = filenames_by_contrast[arbitrary_contrast][bindex:bindex + batch_size]
        for findex, filename in enumerate(filenames):
            for cindex, contrast in enumerate(contrasts):

                # load raw image batches and normalize the pixels
                tmp_img = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                tmp_img = scalar.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
                x_batch[findex, ..., cindex] = tmp_img

                # load mask batches and change to categorical
                tmp_mask = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()
                tmp_mask[tmp_mask==4] = 3
                tmp_mask = to_categorical(tmp_mask, num_classes = 4)
                y_batch[findex] = tmp_mask
        
        if bindex + batch_size > num_images:
            x_batch, y_batch = x_batch[:num_images - bindex], y_batch[:num_images - bindex]
        if augment_size is not None:
            # x_aug, y_aug = augment(x_batch, y_batch, augment_size)
            x_aug = None
            y_aug = None
            yield np.append(x_batch, x_aug), np.append(y_batch, y_aug)
        else:
            yield x_batch, y_batch

Loading Image Parameters
---

In [5]:
tumours = ['LGG','HGG']
prefix = '/Users/jasonfung/Documents/EECE571'
brats_dir = '/MICCAI_BraTS_2018_Data_Training/'
prefix = '/home/atom/Documents/datasets/brats'
file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'
# patient_id = 'Brats18_TCIA09_620_1'
contrasts = ['t1ce', 'flair', 't2']
tumours = ['LGG', 'HGG']

data_list_LGG = os.listdir(os.path.join(prefix+brats_dir,tumours[0]))
data_list_HGG = os.listdir(os.path.join(prefix+brats_dir,tumours[1]))
dataset_file_list = data_list_HGG + data_list_LGG

batch_size = 2

# shuffle and split the dataset file list
import random
random.seed(42)
file_list_shuffled = dataset_file_list.copy()
random.shuffle(file_list_shuffled)
test_ratio = 0.2

train_file, test_file = file_list_shuffled[0:int(len(file_list_shuffled)*(1-test_ratio))], file_list_shuffled[int(len(file_list_shuffled)*(1-test_ratio)):]

train_datagen = generate_brats_batch(file_pattern, contrasts, batch_size=batch_size, patient_ids=train_file) # first iteration
test_datagen = generate_brats_batch(file_pattern, contrasts, batch_size=batch_size, patient_ids=test_file) # first iteration

## Test Train the Model

In [7]:
import segmentation_models_3D as sm 
sm.set_framework('tf.keras')

# data parameters
x_size = None
y_size = None
z_size = None
contrast_channels = 3
input_shape = (x_size, y_size, z_size, contrast_channels)
n_classes = 4

# define Hyper Parameters
LR = 0.0001
activation = 'softmax'
encoder_weights = 'imagenet'
BACKBONE = 'resnet50'
optim = tf.keras.optimizers.Adam(LR)
class_weights = [0.25, 0.25, 0.25, 0.25]

# Define Loss Functions
dice_loss = sm.losses.DiceLoss(class_weights=class_weights)
focal_loss = sm.losses.CategoricalFocalLoss()
total_loss = dice_loss + (1*focal_loss)
metrics = [sm.metrics.IOUScore(threshold = 0.5), sm.metrics.FScore(threshold = 0.5)]

# Define the model being used. In this case, UNet
model = sm.Unet(backbone_name= BACKBONE,
                classes = n_classes,
                input_shape = input_shape,
                encoder_weights = encoder_weights,
                activation = activation,
                decoder_block_type = 'upsampling') #'transpose')

model.compile(optimizer = optim, loss = total_loss, metrics = metrics)

with tf.device('/device:GPU:0'):
    history = model.fit(train_datagen,
                        epochs = 50,
                        verbose = 1,
                        validation_data = test_datagen)


Epoch 1/50


InvalidArgumentError:  ConcatOp : Dimensions of inputs should match: shape[0] = [2,16,16,10,2048] vs. shape[1] = [2,15,15,10,1024]
	 [[node functional_7/decoder_stage0_concat/concat (defined at tmp/ipykernel_270514/1098615544.py:37) ]] [Op:__inference_train_function_50120]

Function call stack:
train_function


Previous Approach Based on Survival Metadata
---

In [2]:
def read_file(patient_id, contrast):
    filenames = [file_pattern for tumour in tumours]
    try:
        x = nib.load(filenames[0])
    except FileNotFoundError:
        x = nib.load(filenames[1])
    filenames.replace()
    nii_data = x.get_fdata()
    return nii_data

In [6]:
metadata = pd.read_csv(training_set)

In [7]:
def create_filename(row, contrast):
    possible_filenames = [file_pattern.format(prefix=prefix, tumour=tumour, patient_id=row['BraTS18ID'], contrast=contrast) for tumour in tumours]
    filename = [filename for filename in possible_filenames if os.path.exists(filename)][0]
    return filename

In [8]:
metadata['filename_flair'] = metadata.apply(create_filename, axis=1, contrast='flair')
metadata['filename_t1ce'] = metadata.apply(create_filename, axis=1, contrast='t1ce')
metadata['filename_seg'] = metadata.apply(create_filename, axis=1, contrast='seg')

In [9]:
metadata

Unnamed: 0,BraTS18ID,Age,Survival,ResectionStatus,filename_flair,filename_t1ce,filename_seg
0,Brats18_TCIA08_167_1,74.907,153,,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
1,Brats18_TCIA08_242_1,66.479,147,,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
2,Brats18_TCIA08_319_1,64.860,254,,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
3,Brats18_TCIA08_469_1,63.899,519,,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
4,Brats18_TCIA08_218_1,57.345,346,,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
...,...,...,...,...,...,...,...
158,Brats18_CBICA_ABB_1,68.493,465,GTR,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
159,Brats18_CBICA_AAP_1,39.068,788,GTR,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
160,Brats18_CBICA_AAL_1,54.301,464,GTR,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...
161,Brats18_CBICA_AAG_1,52.263,616,GTR,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...,/home/atom/Documents/datasets/brats/MICCAI_Bra...


# Jason's Implementation

## Load Data in Batches

In [12]:
def generate_brats_batch(file_pattern, contrasts, patient_list, batch_size=32):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the contrast is treated like a colour channel.
    
    Example:
    x_batch shape: (32, 240, 240, 155, 4)
    y_batch shape: (32, 240, 240, 155)
    """
    
    keys = dict(prefix=prefix, tumour="*", patient_id="*")
    filenames = {contrast: glob.glob(file_pattern.format(contrast=contrast, **keys)) 
                 for contrast in contrasts}
    
    arbitrary_contrast = contrasts[0]
    shape = nib.load(filenames[arbitrary_contrast][0]).get_fdata().shape

    # define empty arrays for batches
    x_batch = np.empty((batch_size, ) + shape + (len(contrasts), )) #, dtype=np.int32)
    y_batch = np.empty((batch_size, ) + shape) #, dtype=np.int32)
    num_images = len(filenames[arbitrary_contrast])

    # shuffle 
    np.random.shuffle(filenames[arbitrary_contrast])

    for bindex in tqdm.tqdm_notebook(range(0, num_images, batch_size), total=num_images):
        filenames = filenames[arbitrary_contrast][bindex:bindex + batch_size]
        for findex, filename in enumerate(filenames):
            for cindex, contrast in enumerate(contrasts):
                x_batch[findex, ..., cindex] = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                y_batch[findex] = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()                
        yield (x_batch, y_batch)
    

In [22]:

# train and test file holds patient ID's

# go through all batch sizes until it reaches the end of the directory 
start_batch = 0
limit = len(train_file)
end_batch = batch_size

while end_batch < limit:
    img_list = []
    mask_list = []

    file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'


    
    
    



In [20]:
import segmentation_models_3D as sm



Segmentation Models: using `keras` framework.


In [24]:
keys = dict(prefix=prefix, tumour="*", patient_id="*")
filenames = {contrast: glob.glob(file_pattern.format(contrast=contrast, **keys)) for contrast in contrasts}

In [25]:
filenames

{'Brats18_CBICA_ARW_1_seg.nii.gz': [],
 'Brats18_CBICA_ARW_1_flair.nii.gz': [],
 'Brats18_CBICA_ARW_1_t1.nii.gz': [],
 'Brats18_CBICA_ARW_1_t2.nii.gz': [],
 'Brats18_CBICA_ARW_1_t1ce.nii.gz': []}