In [None]:
from __future__ import print_function

import json
import os
import random
import subprocess
from errno import EEXIST
from glob import glob
from os import makedirs
from os.path import isdir
from os.path import basename

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import progressbar
from medpy.io import save as savemha
from skimage import color, img_as_float
from skimage import io
from skimage.color import rgb2gray
from skimage.exposure import adjust_gamma
from skimage.filters.rank import entropy
from skimage.io import imsave, imread
from skimage.morphology import disk
from skimage.transform import rotate
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.metrics import classification_report
from tqdm import tqdm, tqdm_notebook

from keras import backend as K
from keras.callbacks import EarlyStopping
from keras.engine import Model
from keras.initializers import glorot_normal
from keras.layers import Lambda
from keras.layers import Dropout, LeakyReLU, Input, Activation, BatchNormalization, Concatenate, multiply, Flatten
from keras.layers import Dense, Conv2D, MaxPool2D, AveragePooling2D, add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.losses import mae
from keras.models import Sequential
from keras.models import model_from_json
from keras.optimizers import RMSprop
from keras.optimizers import SGD
from keras.regularizers import l1, l2
from keras.utils import np_utils
from keras.utils.vis_utils import plot_model


In [None]:
progress = progressbar.ProgressBar(widgets=[progressbar.Bar('*', '[', ']'), progressbar.Percentage(), ' '])
# Seed BOTH generators: patch sampling (PatchLibrary) draws with the stdlib `random`
# module (random.choice), which np.random.seed does not affect.
np.random.seed(5)  # for reproducibility
random.seed(5)

def mkdir_p(path):
    """
    Create ``path`` together with any missing parent directories, like ``mkdir -p``.

    An already existing directory is not an error; any other OSError (for example a
    regular file sitting at ``path``) is re-raised.

    :param path: directory to create
    :return: None
    """
    try:
        makedirs(path)
    except OSError as err:  # Python >2.5
        directory_already_exists = err.errno == EEXIST and isdir(path)
        if not directory_already_exists:
            raise


def normalize(slice_el):
    """
    Clip an image to its 1st..99th percentile range, then standardize it.

    :param slice_el: image (array) to normalize
    :return: the clipped image shifted to zero mean and scaled to unit standard deviation,
        or just the clipped image when its standard deviation is 0 (constant image)
    """
    lower = np.percentile(slice_el, 1)
    upper = np.percentile(slice_el, 99)
    clipped = np.clip(slice_el, lower, upper)
    spread = np.std(clipped)
    if spread == 0:
        return clipped
    return (clipped - np.mean(clipped)) / spread


class BrainPipeline(object):
    """
    A class for processing brain scans for one patient.

    Loads the flair, t1, t1c, t2 and ground-truth (gt) volumes of one patient directory,
    regroups them slice by slice and normalizes every non-gt slice.
    """

    # Volume geometry assumed by every method: slices per volume, slice height, slice width.
    # NOTE(review): hard-coded for this dataset's (176, 216, 160) volumes -- confirm for new data.
    _N_SLICES = 176
    _HEIGHT = 216
    _WIDTH = 160
    # Output directory per save_patient mode; any other mode (e.g. 'n4') uses 'n4_PNG/'.
    _SAVE_DIRS = {'norm': 'Norm_PNG/', 'reg': 'Training_PNG/'}

    def __init__(self, path, n4itk=False, n4itk_apply=False):
        """
        :param path: path to directory of one patient. Contains following mha files:
            flair, t1, t1c, t2, ground truth (gt)
        :param n4itk: True to use the n4itk normed t1/t1c scans (files ending in '_n.mha').
            Defaults to False.
        :param n4itk_apply: True to apply and save the n4itk filter to the t1 and t1c scans of
            this patient before loading them. Defaults to False.
        """
        self.path = path
        self.n4itk = n4itk
        self.n4itk_apply = n4itk_apply
        self.modes = ['flair', 't1', 't1c', 't2', 'gt']
        # slices_by_mode: (5 modes, 176 slices, 216, 160)
        self.slices_by_mode, n = self.read_scans()
        # slices_by_slice: (176 slices, 5 modes, 216, 160)
        self.slices_by_slice = n
        self.normed_slices = self.norm_slices()

    def read_scans(self):
        """
        Load every modality of the patient directory and regroup the volumes per slice.

        :return: tuple (slices_by_mode, slices_by_slice) with shapes (5, 176, 216, 160) and
            (176, 5, 216, 160). A volume that cannot be stored in that shape is left as
            zeros and a warning is printed.
        """
        print('Loading scans...')
        n_modes = len(self.modes)
        slices_by_mode = np.zeros((n_modes, self._N_SLICES, self._HEIGHT, self._WIDTH))
        flair = glob(self.path + '/*Flair*/*.mha')
        t2 = glob(self.path + '/*_T2*/*.mha')
        gt = glob(self.path + '/*more*/*.mha')
        t1s = glob(self.path + '/**/*T1*.mha')
        t1_n4 = glob(self.path + '/*T1*/*_n.mha')
        # NOTE(review): relies on glob order putting the T1 scan before the T1c scan -- confirm.
        t1 = [scan for scan in t1s if scan not in t1_n4]
        scans = [flair[0], t1[0], t1[1], t2[0], gt[0]]  # one file per modality (5 total)
        if self.n4itk_apply:
            print('-> Applyling bias correction...')
            for t1_path in t1:
                self.n4itk_norm(t1_path)  # writes <name>_n.mha next to the input
            # Re-glob: the corrected files did not exist yet when t1_n4 was first collected.
            t1_n4 = glob(self.path + '/*T1*/*_n.mha')
            scans = [flair[0], t1_n4[0], t1_n4[1], t2[0], gt[0]]
        elif self.n4itk:
            scans = [flair[0], t1_n4[0], t1_n4[1], t2[0], gt[0]]
        for scan_idx, scan_path in enumerate(scans):
            # Read each volume once (it used to be read a second time just to print its shape).
            volume = io.imread(scan_path, plugin='simpleitk').astype(float)
            print(volume.shape)
            print(scan_path)
            print('*' * 100)
            try:
                slices_by_mode[scan_idx] = volume
            except ValueError:
                print('Warning: unexpected shape {} for {}; modality left as zeros.'.format(
                    volume.shape, scan_path))
        # (mode, slice, h, w) -> (slice, mode, h, w)
        slices_by_slice = np.ascontiguousarray(slices_by_mode.transpose(1, 0, 2, 3))
        return slices_by_mode, slices_by_slice

    def norm_slices(self):
        """
        Normalize each slice in self.slices_by_slice, excluding the gt (last mode).

        Every non-gt mode is clipped to its 1st..99th percentile and standardized with
        ``normalize``; the gt mode is copied unchanged.

        :return: array of shape (176, 5, 216, 160)
        """
        print('Normalizing slices...')
        normed_slices = np.zeros((self._N_SLICES, len(self.modes), self._HEIGHT, self._WIDTH))
        for slice_ix in range(self._N_SLICES):
            normed_slices[slice_ix][-1] = self.slices_by_slice[slice_ix][-1]
            for mode_ix in range(len(self.modes) - 1):
                normed_slices[slice_ix][mode_ix] = normalize(self.slices_by_slice[slice_ix][mode_ix])
        print('Done.')
        return normed_slices

    @staticmethod
    def _scale_strip(strip, rescale_negative=True):
        """
        Return a rescaled float copy of ``strip``; the input array is never modified.

        Divides by the maximum when it is non-zero. When ``rescale_negative`` is True and the
        minimum is then <= -1, also divides by its absolute value so that values stay >= -1.
        """
        scaled = np.array(strip, dtype=float)  # copy: the pipeline's own arrays stay untouched
        peak = np.max(scaled)
        if peak != 0:
            scaled /= peak
        if rescale_negative and np.min(scaled) <= -1:
            scaled /= abs(np.min(scaled))
        return scaled

    def save_patient(self, reg_norm_n4, patient_num):
        """
        Save every slice as a PNG strip with the 5 modes stacked vertically (1080 x 160).

        :param reg_norm_n4: 'reg' saves the original images to Training_PNG/, 'norm' saves the
            normalized images to Norm_PNG/, any other value (e.g. 'n4') saves the normalized
            images to n4_PNG/
        :param patient_num: unique identifier for each patient; files are named
            '<patient_num>_<slice>.png'
        :return: None
        """
        print('Saving scans for patient {}...'.format(patient_num))
        progress.currval = 0
        out_dir = self._SAVE_DIRS.get(reg_norm_n4, 'n4_PNG/')
        is_reg = reg_norm_n4 == 'reg'
        source = self.slices_by_slice if is_reg else self.normed_slices
        mkdir_p(out_dir)
        strip_shape = (len(self.modes) * self._HEIGHT, self._WIDTH)  # (1080, 160)
        for slice_ix in progress(range(self._N_SLICES)):
            # 'reg' strips are only scaled by their maximum; normalized strips are also kept >= -1.
            strip = self._scale_strip(source[slice_ix].reshape(strip_shape),
                                      rescale_negative=not is_reg)
            io.imsave('{}{}_{}.png'.format(out_dir, patient_num, slice_ix), strip)

    def n4itk_norm(self, path, n_dims=3, n_iters='[20,20,10,5]'):
        """
        Run n4_bias_correction.py on one scan, writing '<path without .mha>_n.mha'.

        :param path: path to mha T1 or T1c file
        :param n_dims: param for n4itk filter
        :param n_iters: param for n4itk filter
        :return: exit code of the n4_bias_correction.py process
        """
        import subprocess  # local import keeps this method self-contained

        output_fn = path[:-4] + '_n.mha'
        # Argument list instead of a shell string: paths with spaces work, and the shell can no
        # longer glob-expand the bracketed n_iters value.
        return subprocess.call(['python', 'n4_bias_correction.py', path, str(n_dims), n_iters, output_fn])

In [None]:
def save_patient_slices(patients, type):
    '''
    Export every slice of every listed patient as PNG strips.

    INPUT   (1) list 'patients': paths to the patient directories to save, e.g. glob("Training/HGG/**")
            (2) string 'type': 'reg' (non-normalized), 'norm' (normalized, no bias correction)
                or 'n4' (written to n4_PNG/)
    Files land in Training_PNG/, Norm_PNG/ or n4_PNG/ and are named <patient-num>_<slice-num>.png.
    '''
    # NOTE(review): the pipeline is always built with n4itk=False, so 'n4' output contains the same
    # normalized (not bias-corrected) data as 'norm'. Confirm whether n4itk=True was intended here.
    for patient_num, patient_dir in enumerate(patients):
        BrainPipeline(patient_dir).save_patient(type, patient_num)


def save_labels(fns):
    '''
    Save every slice of every label volume as 'Labels/<label_idx>_<slice_idx>L.png'.

    INPUT list 'fns': filepaths to all label volumes (read with the simpleitk plugin)
    '''
    progress.currval = 0
    mkdir_p('Labels/')  # create once up front instead of retrying every failed save
    # Iterate the argument itself; the old body read the notebook-global `labels` instead.
    for label_idx in progress(range(len(fns))):
        slices = io.imread(fns[label_idx], plugin='simpleitk')
        for slice_idx, label_slice in enumerate(slices):
            io.imsave('Labels/{}_{}L.png'.format(label_idx, slice_idx), label_slice)

In [None]:
# Forward slashes work on both Windows and POSIX; the old pattern mixed '\' and '/',
# which only matches on Windows.
labels = glob('BRATS-2/Image_Data/HG/**/*more*/**.mha')
save_labels(labels)

In [None]:
# Forward slashes work on both Windows and POSIX (the old pattern mixed '\' and '/').
patients = glob('BRATS-2/Image_Data/HG/**')
save_patient_slices(patients, 'reg')
save_patient_slices(patients, 'norm')
save_patient_slices(patients, 'n4')

In [None]:
# Re-seed both generators before patch extraction: PatchLibrary samples with the stdlib
# `random` module, which np.random.seed does not touch.
np.random.seed(5)
random.seed(5)
def rotate_patch(patch, angle):
    """
    Rotate the first four channels of a patch by the same angle.

    :param patch: array-like of at least 4 channels, e.g. shape (4, 33, 33)
    :param angle: rotation passed to skimage.transform.rotate (degrees)
    :return: np.ndarray stacking the four rotated channels (same size, resize=False)
    """
    rotated_channels = []
    for channel_ix in range(4):
        rotated_channels.append(rotate(patch[channel_ix], angle, resize=False))
    return np.array(rotated_channels)


def get_right_order(filename):
    """
    Sort key for patch files named '<index>.<ext>': return the integer index.

    Both '/' and '\\' are treated as path separators, so the key also works on glob
    results produced on Windows (e.g. './patches/class_0\\12.png').

    :param filename: path to image, e.g. './patches/class_0/12.png'
    :return: int index parsed from the file name with its extension removed
    """
    last_part = filename.replace('\\', '/').split('/')[-1]
    return int(os.path.splitext(last_part)[0])


class PatchLibrary(object):
    """
    class for creating patches and subpatches from training data to use as input for segmentation models.

    Extracted patches are cached as PNG strips under ./patches/class_<k>/ (rotated copies under
    ./patches/class_<k>/rotations/) and reused on later runs.
    """

    def __init__(self, patch_size=(33, 33), train_data='empty', num_samples=1000, augmentation_angle=0):
        """
        :param patch_size: tuple, size (in voxels) of patches to extract. Use (33,33) for sequential model
        :param train_data: list of filepaths to all training data saved as pngs. images should have shape (5, 216, 160)
        :param num_samples: the number of patches to collect from training data (split evenly over classes)
        :param augmentation_angle: rotation step in degrees used to produce extra (rotated) patches;
            0 or any multiple of 360 disables augmentation
        :raises ValueError: if no training data is given
        """
        # Raise instead of calling exit(), which would terminate the interactive kernel.
        if len(train_data) == 0 or 'empty' in train_data:
            raise ValueError('insert a path for patch extraction')
        self.patch_size = patch_size
        self.augmentation_angle = augmentation_angle % 360
        # Copies per patch (original + rotations). Derived from the *reduced* angle so that e.g. 450
        # behaves like 90 instead of giving a multiplier of 0 (labels/patches length mismatch).
        if self.augmentation_angle != 0:
            self.augmentation_multiplier = int(360.0 / self.augmentation_angle)
        else:
            self.augmentation_multiplier = 1

        self.num_samples = num_samples
        self.train_data = train_data
        self.h = self.patch_size[0]
        self.w = self.patch_size[1]

    def _read_patch_png(self, patch_path):
        """Read a cached (4*h, w) patch strip back as a float (4, h, w) array, each channel scaled to <= 1."""
        patch = np.array(rgb2gray(imread(patch_path, dtype=float)), dtype=float).reshape(4, self.h, self.w)
        for channel in patch:
            peak = np.max(channel)
            if peak > 1:  # presumably 8-bit PNG values; bring back to <= 1
                channel /= peak
        return patch

    def _load_cached_patches(self, class_num, num_patches):
        """Load up to num_patches cached patches of class_num; return (patches, number of cached files)."""
        if not (isdir('patches/') and isdir('patches/class_{}/'.format(class_num))):
            mkdir_p('patches/class_{}'.format(class_num))
            return [], 0
        cached_paths = sorted(glob('./patches/class_{}/**.png'.format(class_num)), key=get_right_order)
        patches = []
        for path_index, patch_path in enumerate(cached_paths[:num_patches]):
            patches.append(self._read_patch_png(patch_path))
            print('*---> patch {} loaded and added '.format(path_index))
        return patches, len(cached_paths)

    def _extract_new_patches(self, class_num, start_index, num_patches):
        """Sample random training slices until patches start_index..num_patches-1 of class_num are saved."""
        h, w = self.h, self.w
        patches = []
        ct = start_index
        # NOTE(review): loops forever if no training slice contains >= 10 pixels of class_num.
        while ct < num_patches:
            print('searching for patch {}...'.format(ct))
            im_path = random.choice(self.train_data)
            fn = basename(im_path)
            try:
                label = np.array(imread('Labels/' + fn[:-4] + 'L.png'))
            except Exception:  # missing or unreadable label slice: pick another slice
                continue
            # resample if class_num covers fewer than 10 pixels of the selected slice
            if np.count_nonzero(label == class_num) < 10:
                continue
            # select centerpix (p) and patch (p_ix)
            img = imread(im_path).reshape(5, 216, 160)[:-1].astype('float')
            p = random.choice(np.argwhere(label == class_num))
            p_ix = (p[0] - (h // 2), p[0] + ((h + 1) // 2), p[1] - (w // 2), p[1] + ((w + 1) // 2))
            patch = np.array([i[p_ix[0]:p_ix[1], p_ix[2]:p_ix[3]] for i in img])

            # resample if the patch is cut off by the image border, or (except for class 0)
            # if more than 3 of its 4 channels' worth of pixels are zero
            wrong_shape = patch.shape != (4, h, w)
            mostly_empty = np.count_nonzero(patch == 0) > (3 * h * w)
            if wrong_shape or (mostly_empty and class_num != 0):
                continue

            for channel in patch:
                peak = np.max(channel)
                if peak != 0:
                    channel /= peak
            imsave('./patches/class_{}/{}.png'.format(class_num, ct), patch.reshape((4 * h, w)))
            patches.append(patch)
            print('*---> patch {} saved and added'.format(ct))
            ct += 1
        return patches

    def _augment_patches(self, class_num, patches):
        """Append to `patches` rotated copies (augmentation_angle * j, j = 1..multiplier-1) of every patch."""
        print('_*_*_*_ proceed with data augmentation  for class {} _*_*_*_'.format(class_num))
        print()
        rotation_dir = './patches/class_{}/rotations'.format(class_num)
        if isdir(rotation_dir):
            print("rotations folder present ")
        else:
            mkdir_p(rotation_dir)
            print("rotations folder created")
        n_originals = len(patches)  # only rotate the original patches, not freshly appended ones
        for el_index in range(n_originals):
            for j in range(1, self.augmentation_multiplier):
                angle = self.augmentation_angle * j
                cache_path = './patches/class_{}/rotations/{}_{}.png'.format(class_num, el_index, angle)
                try:
                    patch_rotated = self._read_patch_png(cache_path)
                except Exception:  # not cached (or unreadable): rotate now and cache the result
                    patch_rotated = rotate_patch(np.array(patches[el_index]), angle)
                    imsave(cache_path, np.array(patch_rotated).reshape(4 * self.h, self.w))
                    print('*---> patch {} saved and added '
                          'with rotation of {} degrees '.format(el_index, angle))
                else:
                    print('*---> patch {} loaded and added '
                          'with rotation of {} degrees'.format(el_index, angle))
                patches.append(patch_rotated)
        print('augmentation done \n')

    def find_patches(self, class_num, num_patches):
        """
        Helper function for sampling slices with evenly distributed classes.

        Reuses cached patches first, extracts new ones for the remainder, then appends
        rotated copies when augmentation is enabled.

        :param class_num: class to sample from choice of {0, 1, 2, 3, 4}.
        :param num_patches: number of (non-augmented) patches to extract
        :return: tuple (patches, labels): patches has num_patches * augmentation_multiplier
            entries of shape (4, h, w); labels is a float array of the same length filled with class_num
        """
        labels = np.full(num_patches * self.augmentation_multiplier, class_num, 'float')
        print('Finding patches of class {}...'.format(class_num))

        patches, n_cached = self._load_cached_patches(class_num, num_patches)
        if n_cached < num_patches:
            patches.extend(self._extract_new_patches(class_num, n_cached, num_patches))

        if self.augmentation_angle != 0:
            self._augment_patches(class_num, patches)

        return np.array(patches), labels

    def make_training_patches(self, classes=None):
        """
        Creates datas(X) and labels(y) for training CNN, taking num_samples // len(classes)
        patches from every class.

        :param classes: list of classes to sample from. Defaults to [0, 1, 2, 3, 4].
        :return: tuple (datas, labels): datas has shape (n, 4, h, w) and labels shape (n,), with
            n = (num_samples // len(classes)) * len(classes) * augmentation_multiplier
        """
        # The whole body used to sit under `if classes is None`, so passing classes returned None.
        if classes is None:
            classes = [0, 1, 2, 3, 4]
        per_class = self.num_samples // len(classes)
        patches, labels = [], []
        progress.currval = 0
        for i in progress(range(len(classes))):
            p, l = self.find_patches(classes[i], per_class)
            patches.append(p)
            labels.append(l)
        # Concatenating per-class blocks matches the old fixed-size reshape when num_samples divides
        # evenly, and no longer fails when it does not.
        return (np.concatenate(patches).reshape(-1, 4, self.h, self.w),
                np.concatenate(labels))


In [None]:
# Build the balanced patch set (10000 patches over 5 classes; each also rotated by 180 degrees).
train_data = glob(r'n4_PNG/**')
patch_extractor = PatchLibrary(
    patch_size=(33, 33),
    train_data=train_data,
    num_samples=10000,
    augmentation_angle=180,
)
patches, labels = patch_extractor.make_training_patches()

In [None]:
# Sanity check: dataset shapes, then the first channel of one sample patch.
print(patches.shape, labels.shape)
sample_channel = patches[666, 0, :, :]
io.imshow(sample_channel)
plt.show()

In [None]:
class Brain_tumor_segmentation_model(object):

    """
    A class for compiling/loading, fitting and saving various models,
     viewing segmented images and analyzing results
    """

    def __init__(self, n_chan=4, loaded_model=False, model_name=None):
        """
        :param model_name: if loaded_model is True load the model name specified
            (asked interactively when None)
        :param n_chan:number of channels being assessed. defaults to 4
        :param loaded_model: True if loading a pre-existing model. defaults to False
        """
        self.n_chan = n_chan
        self.loaded_model = loaded_model
        self.model = None

        if not self.loaded_model:
            self.model_name = None
            self._make_model()
            self._compile_model()
            print('model for {} ready and compiled, waiting for training'.format(self.model_name))
        else:
            if model_name is None:
                try:
                    read_line = raw_input  # Python 2
                except NameError:
                    read_line = input  # Python 3: raw_input no longer exists
                model_to_load = str(read_line('Which model should I load? '))
            else:
                model_to_load = model_name
            self.model = self.load_model_weights(model_to_load)

    def _make_model(self):
        """
        Build the residual CNN into self.model and print its summary.

        Input: (4, 33, 33) channels-first patches. Two residual stages (64 then 128 filters),
        each followed by 3x3/stride-2 max pooling, then two Dense(256) layers with dropout and
        a 5-way softmax.
        """
        dropout_rate = 0.2

        inp = Input(shape=(4, 33, 33))
        model_to_make = Conv2D(64, (3, 3),
                               kernel_initializer=glorot_normal(),
                               bias_initializer='zeros',
                               padding='same',
                               data_format='channels_first',
                               input_shape=(4, 33, 33)
                               )(inp)
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)
        m1_res = model_to_make  # residual path

        model_to_make = Conv2D(filters=64,
                               kernel_size=(3, 3),
                               padding='same',
                               data_format='channels_first',
                               input_shape=(64, 33, 33))(model_to_make)
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)
        model_to_make = Conv2D(filters=64,
                               kernel_size=(3, 3),
                               padding='same',
                               data_format='channels_first',
                               input_shape=(64, 33, 33))(model_to_make)

        model_to_make = add([model_to_make, m1_res])
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)

        # 33x33 -> 16x16
        model_to_make = MaxPool2D(pool_size=(3, 3),
                                  strides=(2, 2),
                                  data_format='channels_first',
                                  input_shape=(64, 33, 33))(model_to_make)

        model_to_make = Conv2D(filters=128,
                               kernel_size=(3, 3),
                               padding='same',
                               data_format='channels_first',
                               input_shape=(64, 16, 16))(model_to_make)
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)
        m1_res = model_to_make  # residual path

        # NOTE(review): unlike the first stage there is no activation between these two
        # convolutions (it was commented out) -- confirm this is intended.
        model_to_make = Conv2D(filters=128,
                               kernel_size=(3, 3),
                               padding='same',
                               data_format='channels_first',
                               input_shape=(128, 16, 16))(model_to_make)
        model_to_make = Conv2D(filters=128,
                               kernel_size=(3, 3),
                               padding='same',
                               data_format='channels_first',
                               input_shape=(128, 16, 16))(model_to_make)

        model_to_make = add([model_to_make, m1_res])
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)

        # 16x16 -> 7x7, so Flatten yields 128 * 7 * 7 = 6272 features
        model_to_make = MaxPool2D(pool_size=(3, 3),
                                  strides=(2, 2),
                                  data_format='channels_first',
                                  input_shape=(128, 16, 16))(model_to_make)

        model_to_make = Flatten()(model_to_make)
        model_to_make = Dense(units=256, input_dim=6272)(model_to_make)
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)
        model_to_make = Dropout(dropout_rate)(model_to_make)

        model_to_make = Dense(units=256, input_dim=256)(model_to_make)
        model_to_make = LeakyReLU(alpha=0.333)(model_to_make)
        model_to_make = Dropout(dropout_rate)(model_to_make)

        model_to_make = Dense(units=5,
                              input_dim=256)(model_to_make)
        model_to_make = Activation('softmax')(model_to_make)

        self.model = Model(inputs=inp, outputs=model_to_make)
        self.model.summary()

    def _compile_model(self):
        """Compile self.model with Nesterov SGD and categorical cross-entropy; plot it if possible."""
        sgd = SGD(lr=3e-3,
                  decay=0,
                  momentum=0.9,
                  nesterov=True)
        print(sgd)
        self.model.compile(optimizer=sgd,
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])
        # The architecture plot is a convenience; a missing pydot/graphviz must not
        # prevent the model from being built.
        try:
            plot_model(self.model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)
        except (ImportError, OSError) as exc:
            print('Skipping model plot: {}'.format(exc))

    @staticmethod
    def load_model_weights(model_name):
        """
        :param model_name: filepath to model and weights, not including extension
        :return: Model with loaded weights. can fit on model using loaded_model=True in fit_model method
        """
        print('Loading model {}'.format(model_name))
        model_to_load = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        # save_model stores the architecture as a JSON-encoded string, so decode it once here.
        with open(model_to_load, "r") as f:
            m = f.read()
        model_comp = model_from_json(json.loads(m))
        model_comp.load_weights(weights)
        print('Done.')
        return model_comp

    def fit_model(self, X_train, y_train, n_epochs=20, batch_size=128):
        """
        :param X_train: list of patches to train on in form (n_sample, n_channel, h, w)
        :param y_train: list of labels corresponding to X_train patches in form (n_sample,)
        :param n_epochs: number of training epochs (default 20, the previous hard-coded value)
        :param batch_size: mini-batch size (default 128, the previous hard-coded value)
        :return: Fits specified model
        """
        print(X_train.shape)
        print('*' * 100)
        print(y_train.shape)
        print('*' * 100)
        Y_train = np_utils.to_categorical(y_train, 5)

        # Shuffle patches and one-hot labels with one shared permutation instead of building a
        # Python list of (patch, label) tuples. (The former EarlyStopping object was never passed
        # to fit, and fit has no validation data for it to monitor, so it was dropped.)
        order = np.random.permutation(len(Y_train))
        X_train = np.asarray(X_train)[order]
        Y_train = Y_train[order]

        self.model.fit(X_train, Y_train, epochs=n_epochs, batch_size=batch_size, verbose=1)

    def save_model(self, model_name):
        """
        Saves current model as json and weights as hdf5 file
        :param model_name: name to save model and weights under, including filepath but not extension
        :return:
        """
        model_to_save = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        json_string = self.model.to_json()
        # Create the *parent* directory of the output files; previously a directory named after
        # the model itself was created, and only after a failed save.
        target_dir = os.path.dirname(model_name)
        if target_dir and not isdir(target_dir):
            mkdir_p(target_dir)
        self.model.save_weights(weights)

        with open(model_to_save, 'w') as f:
            json.dump(json_string, f)

    def predict_image(self, test_img):
        """
        predicts classes of input image
        :param test_img: filepath to a (5*216, 160) PNG strip; the first four modes are used as input
        :return: (184, 128) array of predicted classes, one per 33x33 patch centre
        Side effect: writes the last mode (ground truth) to ./labels/<last 5 chars of test_img>.
        """
        imgs = mpimg.imread(test_img).astype('float')
        imgs = rgb2gray(imgs).reshape(5, 216, 160)

        plist = []
        # create patches_to_predict from an entire slice
        for img in imgs[:-1]:
            if np.max(img) != 0:
                img /= np.max(img)
            plist.append(extract_patches_2d(img, (33, 33)))
        # (n_patches, 4, 33, 33): the four modes of each 33x33 window side by side
        patches_to_predict = np.stack(plist, axis=1)
        # predict classes of each pixel based on model
        full_pred = np.argmax(self.model.predict(patches_to_predict), axis=1)

        mkdir_p("./labels/")
        io.imsave("./labels/" + test_img[-5:], imgs[-1])
        # 216 - 33 + 1 = 184 rows and 160 - 33 + 1 = 128 columns of patch centres
        return full_pred.reshape(184, 128)

    def save_segmented_image(self, index, test_img, save=False):
        """
        Creates an image of original brain with segmentation overlay
        :param index: index used in the output file names
        :param test_img: filepath to test image for segmentation, including file extension
        :param save: If True, writes ./results/result_<index>.png and .mha and returns None.
            (defaults to False)
        :return: the overlay image (h, w, 3) when save is False, otherwise None
        """
        segmentation = self.predict_image(test_img)
        # Pad the (184, 128) prediction back to the full (216, 160) slice.
        img_mask = np.pad(segmentation, (16, 16), mode='edge')

        test_im = mpimg.imread(test_img).astype('float')
        # second-to-last mode (t2 in the BrainPipeline strip order) is used as the background
        test_back = rgb2gray(test_im).reshape(5, 216, 160)[-2]
        gray_img = img_as_float(test_back)

        # adjust gamma of image
        image = adjust_gamma(color.gray2rgb(gray_img), 0.65)
        sliced_image = image.copy()

        # change colors of segmented classes (class id -> RGB)
        class_colours = ((1, [1, 0.2, 0.2]),       # red
                         (2, [0.35, 0.75, 0.25]),  # green
                         (3, [0, 0.25, 0.9]),      # blue
                         (4, [1, 1, 0.25]))        # yellow
        for class_id, colour in class_colours:
            sliced_image[img_mask == class_id] = colour

        if save:
            mkdir_p('./results/')
            result_stem = './results/result' + '_' + str(index)
            io.imsave(result_stem + '.png', sliced_image)
            savemha(img_mask, result_stem + '.mha')
        else:
            return sliced_image

    def get_dice_coef(self, test_img, label):
        '''
        Print per-region agreement between the predicted segmentation and the ground truth.

        NOTE(review): despite the name these are not Dice scores. "Total Slice" is pixel accuracy
        over the whole slice; "Advancing Tumor" / "Core Tumor" are the fraction of ground-truth
        pixels of that region (class 4 / classes 1, 3, 4) predicted with the same class.
        A region absent from the ground truth is reported as nan.

        INPUT   (1) str 'test_img': filepath to slice to predict on
                (2) str 'label': filepath whose last 5 characters name the label PNG that
                    predict_image writes to ./labels/
        :return: tuple (total, adv, core) of the printed values
        '''
        label = "./labels/" + label[-5:]
        segmentation = self.predict_image(test_img)  # also (re)writes the label PNG read below
        seg_full = np.pad(segmentation, (16, 16), mode='edge')
        gt = io.imread(label).astype('int')
        gt_max = np.max(gt)
        if gt_max != 0:
            # rescale stored intensities back to class ids 0..4 (an all-zero slice stays zero)
            gt = np.round(np.divide(gt, gt_max) * 4).astype('int')

        total = np.count_nonzero(seg_full == gt) / float(gt.size)

        def region_agreement(gt_mask):
            # fraction of pixels in gt_mask where prediction equals ground truth; nan if empty
            n_region = np.count_nonzero(gt_mask)
            if n_region == 0:
                return float('nan')
            return np.count_nonzero(seg_full[gt_mask] == gt[gt_mask]) / float(n_region)

        adv = region_agreement(gt == 4)
        core = region_agreement((gt == 4) | (gt == 3) | (gt == 1))

        print('Region_______________________| Dice Coefficient')
        print('Total Slice__________________| {0:.2f}'.format(total))
        print('Advancing Tumor______________| {0:.2f}'.format(adv))
        print('Core Tumor___________________| {0:.2f}'.format(core))
        print(' ')
        return total, adv, core


In [None]:
# Build and compile a fresh residual network, then train it on the extracted patches.
model = Brain_tumor_segmentation_model()
model.fit_model(X_train=patches, y_train=labels)

In [None]:
# Save the trained model and weights under the name 'residual_model';
# the loader cell at the end of the notebook restores it with
# Brain_tumor_segmentation_model(loaded_model=True, model_name='residual_model').
model.save_model('residual_model')

In [None]:
# Segment every test image in ./experiment and save the result.
# The listing is non-recursive ('**' only recurses when glob is called with
# recursive=True). Sub-folders such as 'experiment/all slices' are skipped
# because only image files can be segmented. glob order is kept unchanged so
# `index` still matches the ground-truth colouring cell, which enumerates the
# same folder.
tests = [path for path in glob(os.path.join('experiment', '*')) if os.path.isfile(path)]
segmented_images = [
    model.save_segmented_image(index, test_img=slice_img, save=True)
    for index, slice_img in enumerate(tests)
]

In [None]:
# Cell to find the DICE similarity coefficient between predicted and ground-truth images.
# Relies on `tests` from the segmentation cell above. Note that the all-slices
# cell below rebinds `tests`, so run order matters.
for slice_img in tests:
    # Label the report with the file name. A fixed character offset into the
    # path breaks as soon as the folder name or depth changes.
    print('-----Dice Coefficient for slice {} -----'.format(os.path.basename(slice_img)))
    # The image path is passed as `label` too. Presumably get_dice_coef derives
    # the ground-truth file from it ('./labels/' + last characters of the path)
    # — confirm against the method.
    model.get_dice_coef(test_img=slice_img, label=slice_img)

In [None]:
# Cell to segment all 155 slices of a particular patient.
# The pattern is built with os.path.join. A hard-coded backslash separator only
# works on Windows; on POSIX it is a literal character and glob matches nothing.
# NOTE(review): sorted() is lexicographic ('10.png' sorts before '2.png').
# Confirm the slice file names are zero-padded if slice order matters.
tests = sorted(
    path for path in glob(os.path.join('experiment', 'all slices', '*'))
    if os.path.isfile(path)
)
segmented_images = []
for index, slice_img in enumerate(tests):
    segmented_images.append(model.save_segmented_image(index, test_img=slice_img, save=True))

In [None]:
# Cell to colour the ground truth of each training image and save it to ./results/.
# Class id -> RGB colour painted over the gamma-adjusted background slice.
GT_CLASS_COLOURS = {
    1: [1, 0.2, 0.2],       # red
    2: [0.35, 0.75, 0.25],  # green
    3: [0, 0.25, 0.9],      # blue
    4: [1, 1, 0.25],        # yellow
}

# Regular files only: sub-folders such as 'experiment/all slices' are not images.
# glob order is kept so `index` matches the segmentation cell's numbering.
tests1 = [path for path in glob(os.path.join('experiment', '*')) if os.path.isfile(path)]
mkdir_p('./results/')  # loop-invariant; mkdir_p tolerates an existing folder
for index, slice_img in enumerate(tests1):
    imgs = mpimg.imread(slice_img).astype('float')
    # NOTE(review): assumes each file stacks 5 slices of 216x160 vertically, with
    # the ground-truth mask last and the background modality second-to-last.
    imgs = rgb2gray(imgs).reshape(5, 216, 160)

    # Rescale the mask to integer class ids 0..4. An all-zero mask (no labelled
    # pixels) would otherwise compute 0/0 and cast NaN to int.
    img_mask = imgs[-1]
    mask_max = np.max(img_mask)
    if mask_max > 0:
        img_mask = np.round(img_mask / mask_max * 4).astype('int')
    else:
        img_mask = np.zeros(img_mask.shape, dtype='int')

    # Background: the second-to-last slice as RGB with its gamma adjusted
    # (adjust_gamma returns a new array, so it can be painted directly).
    gray_img = img_as_float(imgs[-2])
    sliced_image = adjust_gamma(color.gray2rgb(gray_img), 0.65)

    # Paint every pixel of each class with that class's colour (class masks are disjoint).
    for class_id, colour in GT_CLASS_COLOURS.items():
        sliced_image[img_mask == class_id] = colour

    io.imsave('./results/resultGT_{}.png'.format(index), sliced_image)

  .format(dtypeobj_in, dtypeobj_out))


In [None]:
# Run this cell to load the saved model, then segment and score (Dice) every
# image in the experiment folder.
model1 = Brain_tumor_segmentation_model(loaded_model=True, model_name='residual_model')
# The pattern is built with os.path.join because a hard-coded backslash only
# works on Windows. Only regular files are kept, since sub-folders such as
# 'experiment/all slices' cannot be segmented.
location = [path for path in glob(os.path.join('.', 'experiment', '*')) if os.path.isfile(path)]
for index, slice_img in enumerate(location):
    model1.save_segmented_image(index, test_img=slice_img, save=True)
    model1.get_dice_coef(test_img=slice_img, label=slice_img)