# Load Libraries and Model

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras 
import nibabel as nib
import os
import glob
from tqdm import tqdm
from keras.models import load_model
my_model = load_model('C:/Users/Fungj/Documents/EECE_571F/unet_model_20220401.h5', 
                      compile=False)

In [2]:
def crop_3D(img, new_size):
    img_shape = img.shape
    x_mid = int(img_shape[0]/2)
    y_mid = int(img_shape[1]/2)
    z_mid = int(img_shape[2]/2)

    x_diff = int(abs(new_size[0]-x_mid))
    y_diff = int(abs(new_size[1]-y_mid))
    z_diff = int(abs(new_size[2]-z_mid))

    x_start = x_mid-x_diff
    y_start = y_mid-y_diff
    z_start = z_mid-z_diff

    tmp_img = img[x_start:x_start+new_size[0],y_start:y_start+new_size[1],z_start:z_start+new_size[2]]
    return tmp_img

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()

def test_generate_brats_batch(file_pattern, 
                         contrasts, 
                         batch_size=32, 
                         tumour='*', 
                         patient_ids='*',
                         crop_size = (None,None,None), 
                         augment_size=None,
                         infinite = True):
    """
    Generate arrays for each batch, for x (data) and y (labels), where the contrast is treated like a colour channel.
    
    Example:
    x_batch shape: (32, 240, 240, 155, 4)
    y_batch shape: (32, 240, 240, 155)
    
    augment_size must be less than or equal to the batch_size, if None will not augment.
    
    """
    while True:
        n_classes = 4

        # get list of filenames for every contrast available
        keys = dict(prefix=prefix, tumour=tumour)
        filenames_by_contrast = {}
        for contrast in contrasts:
            filenames_by_contrast[contrast] = glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_ids, **keys)) if patient_ids == '*' else []
            if patient_ids != '*':
                contrast_files = []
                for patient_id in patient_ids:
                    contrast_files.extend(glob.glob(file_pattern.format(contrast=contrast, patient_id=patient_id, **keys)))
                filenames_by_contrast[contrast] = contrast_files

        # get the shape of one 3D volume and initialize the batch lists
        arbitrary_contrast = contrasts[0]
        if crop_size == (None,None,None):
            shape = nib.load(filenames_by_contrast[arbitrary_contrast][0]).get_fdata().shape
        else:
            shape = crop_size

        # initialize empty array of batches
        x_batch = np.empty((batch_size, ) + shape + (len(contrasts), )) #, dtype=np.int32)
        y_batch = np.empty((batch_size, ) + shape + (n_classes,)) #, dtype=np.int32)
        num_images = len(filenames_by_contrast[arbitrary_contrast])
#         np.random.shuffle(filenames_by_contrast[arbitrary_contrast])
        for bindex in tqdm(range(0, num_images, batch_size), total=num_images):
            filenames = filenames_by_contrast[arbitrary_contrast][bindex:bindex + batch_size]
            for findex, filename in enumerate(filenames):
                for cindex, contrast in enumerate(contrasts):
                    # load raw image batches and normalize the pixels
                    tmp_img = nib.load(filename.replace(arbitrary_contrast, contrast)).get_fdata()
                    try:
                        tmp_img = scaler.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
                    except:
                        print(filename)
                        print(contrast)
                    x_batch[findex, ..., cindex] = crop_3D(tmp_img, shape)
                    # load mask batches and change to categorical
                    tmp_mask = nib.load(filename.replace(arbitrary_contrast, 'seg')).get_fdata()
                    tmp_mask[tmp_mask==4] = 3
                    tmp_mask = crop_3D(tmp_mask, crop_size)
                    tmp_mask = to_categorical(tmp_mask, num_classes = 4)
                    y_batch[findex] = tmp_mask

            if bindex + batch_size > num_images:
                x_batch, y_batch = x_batch[:num_images - bindex], y_batch[:num_images - bindex]
            if augment_size is not None:
                # x_aug, y_aug = augment(x_batch, y_batch, augment_size)
                x_aug = None
                y_aug = None
                yield np.append(x_batch, x_aug), np.append(y_batch, y_aug)
            else:
                yield x_batch, y_batch
        if not infinite:
            break


In [5]:
tumours = ['LGG','HGG']

# prefix = '/Users/jasonfung/Documents/EECE571' # Jason's Macbook
prefix = 'C:/Users/Fungj/Documents/EECE_571F' # Jason's Desktop
brats_dir = '/MICCAI_BraTS_2018_Data_Training/'
# prefix = '/home/atom/Documents/datasets/brats' # Adam's Station
file_pattern = '{prefix}/MICCAI_BraTS_2018_Data_Training/{tumour}/{patient_id}/{patient_id}_{contrast}.nii.gz'
# patient_id = 'Brats18_TCIA09_620_1'
contrasts = ['t1ce', 'flair', 't2']
tumours = ['LGG', 'HGG']

data_list_LGG = os.listdir(os.path.join(prefix+brats_dir,tumours[0]))
data_list_HGG = os.listdir(os.path.join(prefix+brats_dir,tumours[1]))
dataset_file_list = data_list_HGG + data_list_LGG

# shuffle and split the dataset file list
import random
random.seed(42)
file_list_shuffled = dataset_file_list.copy()
random.shuffle(file_list_shuffled)
test_ratio = 0.2

train_file, test_file = file_list_shuffled[0:int(len(file_list_shuffled)*(1-test_ratio))], file_list_shuffled[int(len(file_list_shuffled)*(1-test_ratio)):]

while '.DS_Store' in train_file:
    train_file.remove('.DS_Store')
while '.DS_Store' in test_file:
    test_file.remove('.DS_Store')

In [6]:
from keras.metrics import MeanIoU
from tensorflow.keras.utils import to_categorical

batch_size = 1
test_datagen = test_generate_brats_batch(file_pattern, contrasts, batch_size = batch_size, patient_ids = test_file, crop_size= (128,128,128)) # first iteration

# predict on generator
n_classes = 4
IOU_keras = MeanIoU(num_classes=n_classes)



i = 0

while i < len(test_file)//batch_size:
    i += 1
    test_image_batch, test_mask_batch = test_datagen.__next__()
    
    test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)
    test_pred_batch = my_model.predict(test_image_batch)
    test_pred_batch_argmax = np.argmax(test_pred_batch, axis=4)
    IOU_keras.reset_state()
    IOU_keras.update_state(test_pred_batch_argmax, test_mask_batch_argmax)
    print("Mean IoU =", IOU_keras.result().numpy())


  2%|█▍                                                                                 | 1/57 [00:06<06:18,  6.75s/it]

Mean IoU = 0.4278142


  4%|██▉                                                                                | 2/57 [00:08<03:20,  3.64s/it]

Mean IoU = 0.70331657


  5%|████▎                                                                              | 3/57 [00:09<02:24,  2.67s/it]

Mean IoU = 0.7418028


  7%|█████▊                                                                             | 4/57 [00:11<02:01,  2.30s/it]

Mean IoU = 0.45996386


  9%|███████▎                                                                           | 5/57 [00:12<01:44,  2.00s/it]

Mean IoU = 0.45747167


 11%|████████▋                                                                          | 6/57 [00:14<01:33,  1.83s/it]

Mean IoU = 0.24921918


 12%|██████████▏                                                                        | 7/57 [00:16<01:30,  1.81s/it]

Mean IoU = 0.6771083


 14%|███████████▋                                                                       | 8/57 [00:17<01:24,  1.73s/it]

Mean IoU = 0.46701872


 16%|█████████████                                                                      | 9/57 [00:19<01:18,  1.64s/it]

Mean IoU = 0.4962692


 18%|██████████████▍                                                                   | 10/57 [00:20<01:14,  1.59s/it]

Mean IoU = 0.7244271


 19%|███████████████▊                                                                  | 11/57 [00:22<01:10,  1.54s/it]

Mean IoU = 0.65726376


 21%|█████████████████▎                                                                | 12/57 [00:23<01:08,  1.52s/it]

Mean IoU = 0.76312906


 23%|██████████████████▋                                                               | 13/57 [00:25<01:06,  1.51s/it]

Mean IoU = 0.249789


 25%|████████████████████▏                                                             | 14/57 [00:26<01:04,  1.49s/it]

Mean IoU = 0.54841626


 26%|█████████████████████▌                                                            | 15/57 [00:28<01:02,  1.49s/it]

Mean IoU = 0.36724234


 28%|███████████████████████                                                           | 16/57 [00:29<01:00,  1.49s/it]

Mean IoU = 0.76649714


 30%|████████████████████████▍                                                         | 17/57 [00:30<00:59,  1.49s/it]

Mean IoU = 0.47119817


 32%|█████████████████████████▉                                                        | 18/57 [00:32<00:59,  1.52s/it]

Mean IoU = 0.46410137


 33%|███████████████████████████▎                                                      | 19/57 [00:34<00:57,  1.50s/it]

Mean IoU = 0.5848288


 35%|████████████████████████████▊                                                     | 20/57 [00:35<00:54,  1.49s/it]

Mean IoU = 0.44942355


 37%|██████████████████████████████▏                                                   | 21/57 [00:36<00:53,  1.48s/it]

Mean IoU = 0.58150536


 39%|███████████████████████████████▋                                                  | 22/57 [00:38<00:51,  1.47s/it]

Mean IoU = 0.6951019


 40%|█████████████████████████████████                                                 | 23/57 [00:39<00:49,  1.47s/it]

Mean IoU = 0.47167206


 42%|██████████████████████████████████▌                                               | 24/57 [00:41<00:48,  1.46s/it]

Mean IoU = 0.46550623


 44%|███████████████████████████████████▉                                              | 25/57 [00:42<00:47,  1.48s/it]

Mean IoU = 0.5201626


 46%|█████████████████████████████████████▍                                            | 26/57 [00:44<00:46,  1.49s/it]

Mean IoU = 0.49957895


 47%|██████████████████████████████████████▊                                           | 27/57 [00:45<00:45,  1.52s/it]

Mean IoU = 0.56028414


 49%|████████████████████████████████████████▎                                         | 28/57 [00:47<00:43,  1.51s/it]

Mean IoU = 0.561466


 51%|█████████████████████████████████████████▋                                        | 29/57 [00:48<00:42,  1.51s/it]

Mean IoU = 0.73238176


 53%|███████████████████████████████████████████▏                                      | 30/57 [00:50<00:42,  1.58s/it]

Mean IoU = 0.5074925


 54%|████████████████████████████████████████████▌                                     | 31/57 [00:52<00:40,  1.57s/it]

Mean IoU = 0.63949823


 56%|██████████████████████████████████████████████                                    | 32/57 [00:53<00:40,  1.60s/it]

Mean IoU = 0.46822137


 58%|███████████████████████████████████████████████▍                                  | 33/57 [00:55<00:39,  1.63s/it]

Mean IoU = 0.5393058


 60%|████████████████████████████████████████████████▉                                 | 34/57 [00:57<00:37,  1.63s/it]

Mean IoU = 0.36534587


 61%|██████████████████████████████████████████████████▎                               | 35/57 [00:58<00:35,  1.63s/it]

Mean IoU = 0.7354644


 63%|███████████████████████████████████████████████████▊                              | 36/57 [01:00<00:35,  1.70s/it]

Mean IoU = 0.5735866


 65%|█████████████████████████████████████████████████████▏                            | 37/57 [01:02<00:32,  1.63s/it]

Mean IoU = 0.5052252


 67%|██████████████████████████████████████████████████████▋                           | 38/57 [01:03<00:30,  1.58s/it]

Mean IoU = 0.37056613


 68%|████████████████████████████████████████████████████████                          | 39/57 [01:05<00:27,  1.54s/it]

Mean IoU = 0.3317931


 70%|█████████████████████████████████████████████████████████▌                        | 40/57 [01:06<00:26,  1.55s/it]

Mean IoU = 0.4912234


 72%|██████████████████████████████████████████████████████████▉                       | 41/57 [01:08<00:24,  1.53s/it]

Mean IoU = 0.53128195


 74%|████████████████████████████████████████████████████████████▍                     | 42/57 [01:09<00:23,  1.56s/it]

Mean IoU = 1.0


 75%|█████████████████████████████████████████████████████████████▊                    | 43/57 [01:11<00:21,  1.51s/it]

Mean IoU = 0.2562969


 77%|███████████████████████████████████████████████████████████████▎                  | 44/57 [01:12<00:19,  1.51s/it]

Mean IoU = 0.3333149


 79%|████████████████████████████████████████████████████████████████▋                 | 45/57 [01:14<00:17,  1.49s/it]

Mean IoU = 0.7142801


 81%|██████████████████████████████████████████████████████████████████▏               | 46/57 [01:15<00:16,  1.49s/it]

Mean IoU = 0.4781084


 82%|███████████████████████████████████████████████████████████████████▌              | 47/57 [01:17<00:15,  1.53s/it]

Mean IoU = 0.5708096


 84%|█████████████████████████████████████████████████████████████████████             | 48/57 [01:18<00:14,  1.56s/it]

Mean IoU = 0.26351073


 86%|██████████████████████████████████████████████████████████████████████▍           | 49/57 [01:20<00:12,  1.56s/it]

Mean IoU = 0.46874857


 88%|███████████████████████████████████████████████████████████████████████▉          | 50/57 [01:21<00:10,  1.55s/it]

Mean IoU = 0.7969532


 89%|█████████████████████████████████████████████████████████████████████████▎        | 51/57 [01:23<00:09,  1.53s/it]

Mean IoU = 0.49974775


 91%|██████████████████████████████████████████████████████████████████████████▊       | 52/57 [01:24<00:07,  1.52s/it]

Mean IoU = 0.33890224


 93%|████████████████████████████████████████████████████████████████████████████▏     | 53/57 [01:26<00:06,  1.52s/it]

Mean IoU = 0.66474336


 95%|█████████████████████████████████████████████████████████████████████████████▋    | 54/57 [01:27<00:04,  1.51s/it]

Mean IoU = 0.3220101


 96%|███████████████████████████████████████████████████████████████████████████████   | 55/57 [01:29<00:03,  1.50s/it]

Mean IoU = 0.49507034


 98%|████████████████████████████████████████████████████████████████████████████████▌ | 56/57 [01:31<00:01,  1.55s/it]

Mean IoU = 0.48433536
Mean IoU = 1.0


In [6]:
print(test_file)

['Brats18_CBICA_AQG_1', 'Brats18_TCIA02_321_1', 'Brats18_CBICA_AXM_1', 'Brats18_TCIA09_141_1', 'Brats18_TCIA09_428_1', 'Brats18_TCIA02_473_1', 'Brats18_TCIA02_605_1', 'Brats18_2013_16_1', 'Brats18_CBICA_ATB_1', 'Brats18_TCIA02_274_1', 'Brats18_CBICA_AAB_1', 'Brats18_CBICA_AXW_1', 'Brats18_TCIA10_202_1', 'Brats18_CBICA_ALU_1', 'Brats18_TCIA02_198_1', 'Brats18_TCIA01_231_1', 'Brats18_TCIA05_444_1', 'Brats18_2013_23_1', 'Brats18_TCIA08_319_1', 'Brats18_CBICA_ASH_1', 'Brats18_TCIA02_394_1', 'Brats18_CBICA_AWH_1', 'Brats18_TCIA10_632_1', 'Brats18_CBICA_AXL_1', 'Brats18_CBICA_ABB_1', 'Brats18_CBICA_AYA_1', 'Brats18_CBICA_AAP_1', 'Brats18_TCIA13_633_1', 'Brats18_TCIA03_375_1', 'Brats18_TCIA01_150_1', 'Brats18_CBICA_ATV_1', 'Brats18_TCIA10_640_1', 'Brats18_TCIA03_474_1', 'Brats18_TCIA13_624_1', 'Brats18_CBICA_AUN_1', 'Brats18_2013_13_1', 'Brats18_TCIA02_283_1', 'Brats18_TCIA09_493_1', 'Brats18_TCIA01_186_1', 'Brats18_2013_24_1', 'Brats18_CBICA_AZD_1', 'Brats18_2013_26_1', 'Brats18_TCIA10_625_1

In [11]:
# prefix = '/Users/jasonfung/Documents/EECE571' # Jason's Macbook
# prefix = '/home/atom/Documents/datasets/brats' # Adam's Station
prefix = 'C:/Users/Fungj/Documents/EECE_571F' # Jason's Desktop
brats_dir = '/MICCAI_BraTS_2018_Data_Training/'

contrasts = ['t1ce', 'flair', 't2']

# Get random image
import random

random.seed(42)
img_num = random.randint(0,len(test_file))
name = test_file[img_num]

raw_img = np.empty((batch_size, ) + (128,128,128) + (len(contrasts), ))

if name in os.listdir(prefix+brats_dir+'HGG'):
    # load t2 and preprocess the data
    for cindex, contrast in enumerate(contrasts):
        tmp_img = nib.load(os.path.join(prefix+brats_dir+'HGG'+f'/{name}/{name}_{contrast}.nii.gz')).get_fdata()
        tmp_img = scaler.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
        tmp_img = crop_3D(tmp_img,(128,128,128))
        raw_img[0,...,cindex] = tmp_img
    
    ground_truth_mask = np.int_(nib.load(os.path.join(prefix+brats_dir+'HGG'+f'/{name}/{name}_seg.nii.gz')).get_fdata())
    ground_truth_mask[ground_truth_mask==4] = 3
    ground_truth_mask = crop_3D(ground_truth_mask,(128,128,128))
    
else:
    # load t2 and preprocess the data
    for cindex, contrast in enumerate(contrasts):
        tmp_img = nib.load(os.path.join(prefix+brats_dir+'LGG'+f'/{name}/{name}_{contrast}.nii.gz')).get_fdata()
        tmp_img = scaler.fit_transform(tmp_img.reshape(-1, tmp_img.shape[-1])).reshape(tmp_img.shape)
        tmp_img = crop_3D(tmp_img,(128,128,128))
        raw_img[0,...,cindex] = tmp_img
        
    ground_truth_mask = np.int_(nib.load(os.path.join(prefix+brats_dir+'HGG'+f'/{name}/{name}_seg.nii.gz')).get_fdata())
    ground_truth_mask[ground_truth_mask==4] = 3
    ground_truth_mask = crop_3D(ground_truth_mask,(128,128,128))
    

In [12]:
predict_mask = my_model.predict(raw_img) # perform inference
predict_mask = np.argmax(predict_mask, axis=4)[0,:,:,:]

In [13]:
ground_truth_mask.astype(int)
ground_truth_mask.max()

3

# Visualize MRI Volume

In [14]:
import napari
viewer = napari.view_image(raw_img[0,...,0], name = name + '' + contrasts[0])
viewer.add_image(raw_img[0,...,1], name = name + '' + contrasts[1])
viewer.add_image(raw_img[0,...,2], name = name + '' + contrasts[2])

gt_layer = viewer.add_labels(ground_truth_mask, name = "ground truth")
pred_layer = viewer.add_labels(predict_mask, name = "predicted")