## Readme
We tried a variety of models to evaluate their performance and applicability to medical services.

This is a simple copy of F. Wang's repository (github.com/woodywff/brats_2019).

We updated the code to run on **TensorFlow 2.2 and Windows 10**.
Some print() calls were added to show progress.

InstanceNormalization is imported from tensorflow-addons instead of keras_contrib.

During implementation we ran into a 'dead kernel' (i.e. a Windows fatal exception) when entering the validation steps, and it took us quite a long time to fix. (The same problem also occurs in ellisdg's code.)

We found that TensorFlow runs training_data_generator and validation_data_generator in two different threads, and the problem was caused by both threads accessing a single h5 data file.
As a workaround, we forced TensorFlow to use only one inter-op thread (`tf.config.threading.set_inter_op_parallelism_threads(1)`), and it worked.
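We did not try this, but a common alternative to restricting TensorFlow to one thread is to serialize access to the shared file explicitly. The sketch below wraps a read function in a `threading.Lock` so that only one thread touches the file handle at a time. `LockedReader` and `generator_worker` are hypothetical names, and a plain Python list stands in for the h5 file; whether a lock alone avoids the Windows crash described above is an untested assumption.

```python
import threading


class LockedReader:
    """Hypothetical wrapper that lets several threads share one non-thread-safe reader."""

    def __init__(self, read_fn):
        # read_fn(index) would wrap the real h5 access (assumption, not the notebook's code).
        self._read_fn = read_fn
        self._lock = threading.Lock()

    def read(self, index):
        # Only one thread at a time may call into the underlying file handle.
        with self._lock:
            return self._read_fn(index)


# A list stands in for the shared h5 file in this sketch.
fake_file = list(range(100))
reader = LockedReader(lambda i: fake_file[i])
results = {"train": [], "val": []}


def generator_worker(name, indices):
    # Plays the role of training_data_generator / validation_data_generator.
    for i in indices:
        results[name].append(reader.read(i))


threads = [
    threading.Thread(target=generator_worker, args=("train", range(0, 80))),
    threading.Thread(target=generator_worker, args=("val", range(80, 100))),
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(results["train"]), len(results["val"]))  # 80 20
```

The trade-off is the same as with the single-thread setting: reads are serialized, but the rest of the pipeline could in principle still run in parallel.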

Having little time for this work, we could only **train for ~60 epochs** and could not apply tricks such as cross-validation or test-time augmentation (TTA). We also did not include the validation images (20% of the total training images) in training. Even so, the trained model predicts quite accurately.

With this seg_model_1.3 (which uses almost the same hyperparameters as the original code, e.g. patch size, depth and n_seg_levels), we achieved **mean Dice scores of 0.69, 0.87 and 0.75 for enhancing tumor (ET), whole tumor (WT) and tumor core (TC)**, respectively, when we submitted predictions for the validation dataset to https://ipp.cbica.upenn.edu/.

We expect higher scores to be achievable with more data and more augmentation.
Optimizing the model to reduce outlier predictions should also help, given that the **median Dice scores (0.82, 0.90 and 0.85, respectively)** are noticeably higher than the mean Dice scores.
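To illustrate why a gap between mean and median points to outliers, here is a small example with made-up per-case Dice scores (these are NOT our actual results): a couple of near-total failures pull the mean well below the median, while the median barely moves.

```python
import statistics

# Hypothetical per-case Dice scores (NOT our actual results): most cases are good,
# two cases fail almost completely.
dice_per_case = [0.90, 0.88, 0.85, 0.86, 0.84, 0.00, 0.10]

mean_dice = statistics.mean(dice_per_case)
median_dice = statistics.median(dice_per_case)
print(f"mean={mean_dice:.3f}, median={median_dice:.3f}")  # mean=0.633, median=0.850
```

Fixing only the two failed cases would raise the mean of the remaining five scores to 0.866, which is why reducing outlier predictions looks like the cheapest way to improve the mean.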

FYI, we used an AMD Ryzen 3600X CPU and an NVIDIA RTX 2070 Super GPU (8 GB VRAM) for training.

We thank the BraTS community members for their contributions to data preparation, code sharing and all other major and minor work.

In [1]:
# Prevent the access violation error caused by train_generator and validation_generator reading the h5 file from different threads.
import tensorflow as tf
tf.config.threading.set_inter_op_parallelism_threads(1)
# tf.config.threading.set_intra_op_parallelism_threads(2)

In [2]:
# from tensorflow.keras.mixed_precision import experimental as mixed_precision
# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_policy(policy)

# print('Compute dtype: %s' % policy.compute_dtype)
# print('Variable dtype: %s' % policy.variable_dtype)

# Model

In [3]:
##### unet3d/metrics.py #####

from functools import partial
from tensorflow.keras import backend as K

def dice_coefficient(y_true, y_pred, smooth=1.):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coefficient_loss(y_true, y_pred):
    return -dice_coefficient(y_true, y_pred)

def weighted_dice_coefficient(y_true, y_pred, axis=(-3, -2, -1), smooth=0.00001):
    """
    Weighted dice coefficient. Default axis assumes a "channels first" data structure
    :param smooth:
    :param y_true:
    :param y_pred:
    :param axis:
    :return:
    """
    return K.mean(2. * (K.sum(y_true * y_pred, axis=axis) + smooth/2)/
                       (K.sum(y_true, axis=axis) + K.sum(y_pred,axis=axis) + smooth))

def weighted_dice_coefficient_loss(y_true, y_pred):
    return -weighted_dice_coefficient(y_true, y_pred)
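The two metrics above differ in how labels are weighted. The NumPy sketch below re-implements both formulas (the names `dice_global` and `dice_weighted` are hypothetical) on a tiny channels-first tensor: one large label predicted perfectly and one single-voxel label missed entirely. `dice_coefficient` pools all voxels of all labels, so the large label dominates; `weighted_dice_coefficient` computes a Dice per (sample, label) and averages them, so the missed small label counts just as much.

```python
import numpy as np


def dice_global(y_true, y_pred, smooth=1.0):
    # Mirrors dice_coefficient: one Dice over all labels and voxels flattened together.
    t, p = y_true.ravel(), y_pred.ravel()
    inter = np.sum(t * p)
    return (2.0 * inter + smooth) / (np.sum(t) + np.sum(p) + smooth)


def dice_weighted(y_true, y_pred, axis=(-3, -2, -1), smooth=1e-5):
    # Mirrors weighted_dice_coefficient: Dice per (sample, label) over the spatial
    # axes of a channels-first tensor, then the unweighted mean over those.
    inter = np.sum(y_true * y_pred, axis=axis)
    denom = np.sum(y_true, axis=axis) + np.sum(y_pred, axis=axis)
    return np.mean(2.0 * (inter + smooth / 2) / (denom + smooth))


# Shape (batch=1, labels=2, x=2, y=2, z=1)
y_true = np.zeros((1, 2, 2, 2, 1))
y_pred = np.zeros((1, 2, 2, 2, 1))
y_true[0, 0] = 1.0            # label 0: large region, predicted perfectly
y_pred[0, 0] = 1.0
y_true[0, 1, 0, 0, 0] = 1.0   # label 1: a single voxel, missed entirely

print(dice_global(y_true, y_pred))    # 0.9
print(dice_weighted(y_true, y_pred))  # ~0.5
```

This label-balancing behaviour is presumably why the Isensee model below uses `weighted_dice_coefficient_loss` by default: small regions such as ET are not drowned out by WT.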

In [4]:
##### unet3d/model/unet.py #####

import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, BatchNormalization,ReLU, PReLU, Conv3DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.optimizers import Adam

# from unet3d.metrics import dice_coefficient_loss, get_label_dice_coefficient_function, dice_coefficient
# from unet3d.metrics import dice_coefficient_loss, dice_coefficient

K.set_image_data_format("channels_first")

def create_convolution_block(input_layer, n_filters, batch_normalization=False, kernel=(3, 3, 3), activation=ReLU,
                             padding='same', strides=(1, 1, 1), instance_normalization=False):
    """
    :param strides:
    :param input_layer:
    :param n_filters:
    :param batch_normalization:
    :param kernel:
    :param activation: Keras activation layer to use. (default is 'relu')
    :param padding:
    :return:
    """
    layer = Conv3D(n_filters, kernel, padding=padding, strides=strides)(input_layer)
    
    if batch_normalization:
        layer = BatchNormalization(axis=1)(layer)
    elif instance_normalization:
        try:
            from tensorflow_addons.layers import InstanceNormalization
        except ImportError:
            raise ImportError("Install tensorflow_addons in order to use instance normalization")
        layer = InstanceNormalization(axis=1)(layer)
    return activation()(layer)
    
##### unet3d/model/isensee2017.py #####

from functools import partial
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import LeakyReLU, Add, UpSampling3D, Activation, SpatialDropout3D, Conv3D
from tensorflow.keras.optimizers import Adam
# from .unet import create_convolution_block, concatenate
# from ..metrics import weighted_dice_coefficient_loss, dice_coefficient

create_convolution_block = partial(create_convolution_block, activation=LeakyReLU, instance_normalization=True)


def isensee2017_model(input_shape=(4, 128, 128, 128), n_base_filters=16, depth=5, dropout_rate=0.3,
                      n_segmentation_levels=3, n_labels=4, optimizer=Adam, initial_learning_rate=5e-4,
                      loss_function=weighted_dice_coefficient_loss, activation_name="sigmoid",metrics=dice_coefficient):
    """
    This function builds a model proposed by Isensee et al. for the BRATS 2017 competition:
    https://www.cbica.upenn.edu/sbia/Spyridon.Bakas/MICCAI_BraTS/MICCAI_BraTS_2017_proceedings_shortPapers.pdf

    This network is highly similar to the model proposed by Kayalibay et al. "CNN-based Segmentation of Medical
    Imaging Data", 2017: https://arxiv.org/pdf/1701.03056.pdf

    :param input_shape:
    :param n_base_filters:
    :param depth:
    :param dropout_rate:
    :param n_segmentation_levels:
    :param n_labels:
    :param optimizer:
    :param initial_learning_rate:
    :param loss_function:
    :param activation_name:
    :return:
    """
    inputs = Input(input_shape)

    current_layer = inputs
    level_output_layers = list()
    level_filters = list()
    
    for level_number in range(depth):
        n_level_filters = (2**level_number) * n_base_filters #number of filters in each level(depth)
        level_filters.append(n_level_filters) 

        if current_layer is inputs:
            in_conv = create_convolution_block(current_layer, n_level_filters)
        else:
            in_conv = create_convolution_block(current_layer, n_level_filters, strides=(2, 2, 2)) #

        context_output_layer = create_context_module(in_conv, n_level_filters, dropout_rate=dropout_rate)

        summation_layer = Add()([in_conv, context_output_layer])
        level_output_layers.append(summation_layer)
        current_layer = summation_layer
    
    #J.Lee: print(level_filters)
    
    segmentation_layers = list()
    for level_number in range(depth - 2, -1, -1):
        up_sampling = create_up_sampling_module(current_layer, level_filters[level_number])
        concatenation_layer = concatenate([level_output_layers[level_number], up_sampling], axis=1)
        localization_output = create_localization_module(concatenation_layer, level_filters[level_number])
        current_layer = localization_output
        if level_number < n_segmentation_levels:
            segmentation_layers.insert(0, Conv3D(n_labels, (1, 1, 1))(current_layer))

    output_layer = None
    for level_number in reversed(range(n_segmentation_levels)):
        segmentation_layer = segmentation_layers[level_number]
        if output_layer is None:
            output_layer = segmentation_layer
        else:
            output_layer = Add()([output_layer, segmentation_layer])

        if level_number > 0:
            output_layer = UpSampling3D(size=(2, 2, 2))(output_layer)

    activation_block = Activation(activation_name)(output_layer)

    model = Model(inputs=inputs, outputs=activation_block)

    if not isinstance(metrics, list):
        metrics = [metrics]
#     model.compile(optimizer=optimizer(lr=initial_learning_rate), loss=loss_function)        
    model.compile(optimizer=optimizer(epsilon=1e-7, lr=initial_learning_rate), loss=loss_function, metrics=metrics)
    return model


def create_localization_module(input_layer, n_filters):
    convolution1 = create_convolution_block(input_layer, n_filters)
    convolution2 = create_convolution_block(convolution1, n_filters, kernel=(1, 1, 1))
    return convolution2


def create_up_sampling_module(input_layer, n_filters, size=(2, 2, 2)):
    up_sample = UpSampling3D(size=size)(input_layer)
    convolution = create_convolution_block(up_sample, n_filters)
    return convolution


def create_context_module(input_layer, n_level_filters, dropout_rate=0.3, data_format="channels_first"):
    convolution1 = create_convolution_block(input_layer=input_layer, n_filters=n_level_filters)
    dropout = SpatialDropout3D(rate=dropout_rate, data_format=data_format)(convolution1)
    convolution2 = create_convolution_block(input_layer=dropout, n_filters=n_level_filters)
    return convolution2
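As a quick sanity check of the architecture built above, the sketch below reproduces the per-level bookkeeping of `isensee2017_model` in plain Python (the helper names are hypothetical): filters double at each level while each stride-2 convolution halves the spatial size, and the deep-supervision outputs are taken from the three shallowest decoder levels before being upsampled and summed.

```python
def encoder_levels(input_size=128, n_base_filters=16, depth=5):
    """Filter count and spatial size per encoder level, mirroring the loop in isensee2017_model."""
    # Every level below the first starts with a stride-2 convolution, so the
    # input size must be divisible by 2**(depth - 1) for the skip connections to line up.
    if input_size % (2 ** (depth - 1)) != 0:
        raise ValueError("input_size must be divisible by 2**(depth - 1)")
    levels = []
    for level in range(depth):
        n_filters = (2 ** level) * n_base_filters  # same formula as n_level_filters
        size = input_size // (2 ** level)          # halved by each stride-2 conv
        levels.append((level, n_filters, size))
    return levels


def segmentation_sizes(input_size=128, n_segmentation_levels=3):
    """Spatial size of each deep-supervision output before it is upsampled and summed."""
    return [input_size // (2 ** level) for level in range(n_segmentation_levels)]


for level, n_filters, size in encoder_levels():
    print(f"level {level}: {n_filters:3d} filters, {size}^3 voxels")
print(segmentation_sizes())  # [128, 64, 32]
```

The divisibility check matters when changing `patch_shape`: with depth 5, patch sides must be multiples of 16, otherwise the upsampled decoder tensors no longer match the encoder tensors they are concatenated with.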

In [5]:
#model = isensee2017_model()

In [6]:
#model.summary()

# Data

In [7]:
config = dict()
config["overwrite"] = False 
config["all_modalities"] = ["t1", "t1ce", "flair", "t2"]
config["training_modalities"] = config["all_modalities"]

config["image_shape"] = (240,240,155)  # This determines what shape the images will be cropped/resampled to.
config["patch_shape"] = (128, 128, 128)     # switch to None to train on the whole image
config["training_patch_start_offset"] = (4, 4, 4)  # randomly offset the first patch index by up to this offset
config["validation_patch_overlap"] = 32

config["data_file"] = 'C:/IAMEDIC/Jaeho_code/data/data_N4_norm.h5' #os.path.abspath("../data/data.h5")
config['mean_std_file'] =  'C:/IAMEDIC/Jaeho_code/data/mean_std.pkl' #os.path.abspath('../data/mean_std.pkl')

config['val_data_file'] = 'C:/IAMEDIC/Jaeho_code/data/val_data.h5' #os.path.abspath("../data/val_data.h5")
config['val_predict_dir'] = 'C:/IAMEDIC/Jaeho_code/prediction/val_prediction' #os.path.abspath("val_prediction")
config['val_index_list'] = 'C:/IAMEDIC/Jaeho_code/data/val_index_list.pkl' #os.path.abspath('../data/val_index_list.pkl')

In [8]:
##### dev_tools/my_tools.py #####
def print_red(something):
    print("\033[1;31m{}\033[0m".format(something))
def pad_image(img_npy, target_image_shape):
    '''
    image: ndarray
    target_image_shape: tuple or list
    '''
    source_shape = np.asarray(img_npy.shape)
    target_image_shape = np.asarray(target_image_shape)
    edge = (target_image_shape - source_shape)/2
    pad_width = tuple((i,j) for i,j in zip(np.floor(edge).astype(int),np.ceil(edge).astype(int)))
    padded_img = np.pad(img_npy,pad_width,'constant',constant_values=0)
    return padded_img, pad_width

def sec2hms(seconds):    
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    d, h = divmod(h, 24)
    return str(int(d))+' days, '+str(int(h))+' hours, '+str(int(m))+' mins, '+str(round(s,3))+' secs.'
#     print("%d:%02d:%02d" % (h, m, s))
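To make the padding arithmetic in `pad_image` concrete, the sketch below repeats its pad-width computation (as a hypothetical standalone function, `centered_pad_width`) on a tiny 2-D array. The size difference is split per axis, and when it is odd the extra voxel goes on the trailing side.

```python
import numpy as np


def centered_pad_width(source_shape, target_shape):
    # Same arithmetic as pad_image: half of the difference on each side,
    # floor before and ceil after, so an odd difference pads one more at the end.
    edge = (np.asarray(target_shape) - np.asarray(source_shape)) / 2
    return tuple((int(lo), int(hi)) for lo, hi in zip(np.floor(edge), np.ceil(edge)))


img = np.ones((2, 3))
pw = centered_pad_width(img.shape, (4, 6))
padded = np.pad(img, pw, 'constant', constant_values=0)
print(pw)            # ((1, 1), (1, 2))
print(padded.shape)  # (4, 6)
```

Note that `pad_image` assumes every target dimension is at least as large as the source; a negative difference would give negative pad widths, which `np.pad` rejects.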

In [9]:
##### brats_19/demo_task1/preprocess.py #####

import glob
import os
import warnings
import shutil

import SimpleITK as sitk
import numpy as np
from nipype.interfaces.ants import N4BiasFieldCorrection

import pdb
#from train_model import config


from tqdm import tqdm
#from dev_tools.my_tools import print_red


def get_image(subject_folder, name):
    file_card = os.path.join(subject_folder, "*" + name + ".nii.gz")
    try:
        return glob.glob(file_card)[0]
    except IndexError:
        raise RuntimeError("Could not find file matching {}".format(file_card))
    return

def correct_bias(in_file, out_file, image_type=sitk.sitkFloat64):
    """
    Corrects the bias using ANTs N4BiasFieldCorrection. If this fails, will then attempt to correct bias using SimpleITK
    :param in_file: input file path
    :param out_file: output file path
    :return: file path to the bias corrected image
    """
    correct = N4BiasFieldCorrection()
    correct.inputs.input_image = in_file
    correct.inputs.output_image = out_file
    try:
        done = correct.run()
        return done.outputs.output_image
    except IOError:
        warnings.warn(RuntimeWarning("ANTs N4BiasFieldCorrection could not be found. "
                                     "Will try using SimpleITK for bias field correction,"
                                     " which will take much longer. To fix this problem, add N4BiasFieldCorrection"
                                     " to your PATH system variable (example: export PATH=${PATH}:/path/to/ants/bin)."))
        input_image = sitk.ReadImage(in_file, image_type)
        output_image = sitk.N4BiasFieldCorrection(input_image, input_image > 0)
        sitk.WriteImage(output_image, out_file)
        return os.path.abspath(out_file)


def normalize_image(in_file, out_file, bias_correction=True):
    if not os.path.exists(out_file):
        if bias_correction:
            correct_bias(in_file, out_file)
        else:
            shutil.copy(in_file, out_file)
    return out_file

def check_origin(in_file, in_file2):
    """
    check origin of in_file1 and in_file2
    if origins are not same, in_file1's origin will be overwritten with in_file2's origin
    """
    image = sitk.ReadImage(in_file)
    image2 = sitk.ReadImage(in_file2)
    if not image.GetOrigin() == image2.GetOrigin(): 
        image.SetOrigin(image2.GetOrigin())
        sitk.WriteImage(image, in_file)

def convert_brats_folder(in_folder, out_folder, truth_name='seg', no_bias_correction_modalities=None, bias_correct=True):
#     pdb.set_trace()
    for name in config["all_modalities"]:
        try:
            image_file = get_image(in_folder, name)
        except RuntimeError as error:
            if name == 't1ce':
                print_red(in_folder)
                image_file = get_image(in_folder, 't1Gd')
                truth_name = "GlistrBoost_ManuallyCorrected"
            else:
                raise error

        out_file = os.path.abspath(os.path.join(out_folder, name + ".nii.gz"))
        
        if bias_correct:
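            # Bias-correct every modality except those listed in no_bias_correction_modalities
            # (None or an empty list means: correct all modalities).
            perform_bias_correction = not (no_bias_correction_modalities and name in no_bias_correction_modalities)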
            normalize_image(image_file, out_file, bias_correction=perform_bias_correction)
        else:
            if not os.path.exists(out_file):
                shutil.copy(image_file, out_file)
    
    # copy the truth file only for training dataset
    if in_folder.split('/')[-2] == 'val':
        return
    try:
        truth_file = get_image(in_folder, truth_name)
    except RuntimeError:
        truth_file = get_image(in_folder, truth_name.split("_")[0])

    out_file = os.path.abspath(os.path.join(out_folder, "truth.nii.gz"))
    if not os.path.exists(out_file):
        shutil.copy(truth_file, out_file)
    check_origin(out_file, get_image(in_folder, config["all_modalities"][0]))
    
    return

def convert_brats_data(brats_folder, out_folder, bias_correct=True, overwrite=True, no_bias_correction_modalities=("flair",)):
    """
    Preprocesses the BRATS data and writes it to a given output folder. 
    :param brats_folder: folder containing the original brats data
    :param out_folder: output folder to which the preprocessed data will be written
    :param bias_correct: if False, just copy the original images to preprocessed folders.
    :param overwrite: set to True in order to redo all the preprocessing
    :param no_bias_correction_modalities: performing bias correction could reduce the signal of certain modalities. If
    concerned about a reduction in signal for a specific modality, specify by including the given modality in a list
    or tuple.
    :return:
    """
#     pdb.set_trace()
    
    for subject_folder in tqdm(glob.glob(os.path.join(brats_folder, "*", "*"))):
#         continue
        if os.path.isdir(subject_folder):
            subject = os.path.basename(subject_folder)
            new_subject_folder = os.path.join(out_folder, os.path.basename(os.path.dirname(subject_folder)),
                                              subject)
            if not os.path.exists(new_subject_folder) or overwrite:
                if not os.path.exists(new_subject_folder):
                    os.makedirs(new_subject_folder)
                convert_brats_folder(subject_folder, new_subject_folder,
                                     no_bias_correction_modalities=no_bias_correction_modalities,bias_correct=bias_correct)
        else:
            print(subject_folder)

    return
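The folder conversion above expects `brats_folder/<grade>/<subject>/` and mirrors that layout under `out_folder`, renaming the images to `<modality>.nii.gz` and the label to `truth.nii.gz`. The hypothetical helper below only computes the output paths, assuming '/'-separated paths (the `'val'` check in `convert_brats_folder` relies on that too); subjects whose parent folder is named `val` get no truth file.

```python
import posixpath


def planned_outputs(subject_folder, out_folder, modalities=("t1", "t1ce", "flair", "t2")):
    """Hypothetical helper: the files convert_brats_data/convert_brats_folder would write
    for one subject, assuming '/'-separated paths."""
    grade = posixpath.basename(posixpath.dirname(subject_folder))   # e.g. HGG, LGG or val
    subject = posixpath.basename(subject_folder)
    new_subject_folder = posixpath.join(out_folder, grade, subject)
    files = [posixpath.join(new_subject_folder, m + ".nii.gz") for m in modalities]
    # Ground truth is only copied for training subjects, not for the 'val' folder.
    if grade != "val":
        files.append(posixpath.join(new_subject_folder, "truth.nii.gz"))
    return files


for f in planned_outputs("data/original/HGG/BraTS19_2013_10_1", "data/preprocessed"):
    print(f)
```

This is the same layout that `fetch_training_data_files` later globs for (`preprocessed_N4/*/*`).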

In [10]:
##### dev_tools/my_tools.py #####
# from dev_tools.my_tools import minmax_normalize

def minmax_normalize(img_npy):
    '''
    img_npy: ndarray
    '''
    min_value = np.min(img_npy)
    max_value = np.max(img_npy)
    return (img_npy - min_value)/(max_value - min_value)

##### unet3d/normalize.py #####
import pickle
from progressbar import ProgressBar

def normalize_data_storage(data_storage, offset=0.1, mul_factor=100, save_file='../data/mean_std.pkl'):
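    '''
    data_storage: list of modality storages (one h5 array per modality, i.e. modality_storage_list).
    For each modality:
    1. z-score with the dataset-wide mean/std estimated from the nonzero (brain) voxels
       of all images of that modality (averaged per image),
    2. min-max normalize each image individually,
    3. add offset and multiply by mul_factor so brain voxels stay distinct from the zero background.
    The per-modality mean/std values are pickled to save_file.
    '''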
#     pdb.set_trace()
    print('normalize_data_storage...')
    mean_std_values = {}
    for modality_storage in data_storage:
        means = []
        pbar = ProgressBar().start()
        print('calculate mean value...')
        n_subs = modality_storage.shape[0]
        for i in range(n_subs):
            means.append(np.mean(np.ravel(modality_storage[i])[np.flatnonzero(modality_storage[i])]))
            pbar.update(int(i*100/(n_subs-1)))
        pbar.finish()
        mean = np.mean(means)
        mean_std_values[modality_storage.name + '_mean'] = mean 
        print('mean=',mean)
        
        std_means = []
        pbar = ProgressBar().start()
        print('calculate std value...')
        for i in range(n_subs):
            std_means.append(np.mean(np.power(np.ravel(modality_storage[i])[np.flatnonzero(modality_storage[i])]-mean,2)))
            pbar.update(int(i*100/(n_subs-1)))
        pbar.finish()
        std = np.sqrt(np.mean(std_means))
        mean_std_values[modality_storage.name + '_std'] = std
        print('std=',std)
        
#         pdb.set_trace()
        for i in range(n_subs):
            brain_index = np.nonzero(modality_storage[i])
            temp_img = np.copy(modality_storage[i])
            temp_img[brain_index] = (minmax_normalize((modality_storage[i][brain_index]-mean)/std) + offset)*mul_factor
            modality_storage[i] = temp_img
    print('normalization FINISHED')
    with open(save_file,'wb') as f:
        pickle.dump(mean_std_values,f)

def normalize_data_storage_val(data_storage, offset=0.1, mul_factor=100, save_file='../data/mean_std.pkl'):
    print('normalize validation data storage...')
    if not os.path.exists(save_file):
        print_red('There\'s no mean_std.pkl file.')
        return
    with open(save_file,'rb') as f:
        mean_std_values = pickle.load(f)
    for modality_storage in data_storage:
        n_subs = modality_storage.shape[0]
        mean = mean_std_values[modality_storage.name + '_mean']
        std = mean_std_values[modality_storage.name + '_std']
        
#         pdb.set_trace()
        for i in tqdm(range(n_subs)):
            brain_index = np.nonzero(modality_storage[i])
            temp_img = np.copy(modality_storage[i])
            temp_img[brain_index] = (minmax_normalize((modality_storage[i][brain_index]-mean)/std) + offset)*mul_factor
            modality_storage[i] = temp_img
    print('normalization FINISHED')
    return
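The per-image loop in `normalize_data_storage` can be reduced to a single-image function for illustration. The sketch below (the name `normalize_brain_voxels` is hypothetical) applies the same three steps to a 2x2 image whose zero voxel plays the background.

```python
import numpy as np


def normalize_brain_voxels(img, mean, std, offset=0.1, mul_factor=100):
    """Single-image version of the per-subject loop in normalize_data_storage."""
    out = np.array(img, dtype=np.float64)    # copy; background zeros stay zero
    brain = np.nonzero(img)                  # nonzero voxels are treated as brain
    z = (out[brain] - mean) / std            # 1. dataset-wide z-score
    z = (z - z.min()) / (z.max() - z.min())  # 2. per-image min-max to [0, 1]
    out[brain] = (z + offset) * mul_factor   # 3. shift and scale, here to [10, 110]
    return out


img = np.array([[0.0, 200.0], [400.0, 600.0]])
a = normalize_brain_voxels(img, mean=400.0, std=200.0)
b = normalize_brain_voxels(img, mean=0.0, std=1.0)
print(a)
print(np.allclose(a, b))  # True
```

As the second call shows, min-max scaling is unchanged by any positive affine transform, so step 1 does not affect the final voxel values; as far as we can see, the stored mean/std only matter in that they are saved to `mean_std.pkl`. Also note that an image whose brain voxels all share one value would divide by zero in step 2.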


##### unet3d/data.py #####
# uc: unchanged
import os

import numpy as np
import tables
import nibabel as nib
from tqdm import tqdm
# from .normalize import normalize_data_storage

import sys
sys.path.append('..')
# from dev_tools.my_tools import pad_image, print_red


def create_data_file(out_file, n_samples, image_shape, modality_names):
    '''
    create storage in data.h5
    
    :param: out_file      : directory path of the h5 file to be generated
    :param: n_samples     : number of samples. e.g: 1
    :param: image_shape   : e.g: (155,240,240)
    :param: modality_names: 
    :
    '''
#     pdb.set_trace()
    hdf5_file = tables.open_file(out_file, mode='w')
    filters = tables.Filters(complevel=5, complib='blosc')
    modality_shape = tuple([0, 1] + list(image_shape)) # e.g. (0,1,155,240,240)
    truth_shape =    tuple([0, 1] + list(image_shape)) # e.g. (0,1,155,240,240)
    brain_width_shape = (0,2,3)
    
    
    modality_storage_list = [hdf5_file.create_earray(hdf5_file.root, modality_name, tables.Float32Atom(), shape=modality_shape,
                             filters=filters, expectedrows=n_samples) for modality_name in modality_names]
    
    truth_storage = hdf5_file.create_earray(hdf5_file.root, 'truth', tables.UInt8Atom(), shape=truth_shape,
                                            filters=filters, expectedrows=n_samples)
    
    brain_width_storage = hdf5_file.create_earray(hdf5_file.root, 'brain_width', tables.UInt8Atom(), shape=brain_width_shape,
                                            filters=filters, expectedrows=n_samples)
    tumor_width_storage = hdf5_file.create_earray(hdf5_file.root, 'tumor_width', tables.UInt8Atom(), shape=brain_width_shape,
                                            filters=filters, expectedrows=n_samples)
    
    return hdf5_file, [modality_storage_list, truth_storage, brain_width_storage, tumor_width_storage]



def write_image_data_to_file(image_files, storage_list,
                             image_shape, modality_names, truth_dtype=np.uint8, trivial_check = True):
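    '''
    Load every subject's images, compute the brain/tumor outlines and append them to the h5 storages.
    1. (trivial_check) check that each subject's file order matches modality_names + ['truth'].
    2. load each image with nibabel and keep the affine of the first image.
    3. compute the brain outline (the box enclosing the outlines of all modalities)
       and the tumor outline (from the truth image).
    4. append images, truth and outlines to the storages via add_data_to_storage().
    5. save the first affine to affine_N4_norm.npy.
    :param: image_files   : list of tuples of .nii.gz paths, one tuple (modalities..., truth) per subject
    :param: storage_list  : [modality_storage_list, truth_storage, brain_width_storage, tumor_width_storage]
    :param: image_shape   : expected image shape (not used inside this function)
    :param: modality_names: modality names in the order they are stored
    :param: truth_dtype   : dtype used when storing the truth and outline arrays
    :param: trivial_check : 
        to see if all images share the same affine info and pad_width; non-compliant file names
        are printed in red (the affine comparison is currently commented out).
        Also checks the order of modalities when added to the .h5
    '''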
#     pdb.set_trace()

    # 0. affine settings
    # J.Lee:
    # it seems affine_0 and save_affine could be passed as arguments in previous versions,
    # which may be why 'trivial_check' is included in the code.
    # The affine argument seems to be required when 'augment=True'.
    affine_0 = None
    save_affine = True
    print('write_image_data_to_file...')
    
    #J.Lee_start
    print(len(image_files))
    print(f'image_files will be loaded from image_files \n image_files[0]:{image_files[0]}')
    #J.Lee_end
    
    # 1. check compliance of h5 files's modality order with modality_names argument.
    for set_of_files in tqdm(image_files):
        if trivial_check:
            if not [os.path.basename(img_file).split('.')[0] for img_file in set_of_files] == modality_names + ['truth']:
                print('wrong order of modalities')
                print_red(set_of_files)
        subject_data = []
        brain_widths = []
        for i, image_nii_path in enumerate(set_of_files):
            img = nib.load(image_nii_path)
            affine = img.affine #J.Lee: all nii file has its affine value.
            if affine_0 is None:
                affine_0 = affine
#             if trivial_check:
#                 if np.sum(affine_0 - affine):
#                     print('affine incompliance:')
#                     print_red(image_nii_path)
#                     save_affine = False
            img_npy = img.get_fdata()
            subject_data.append(img_npy)
            
            if i < len(set_of_files)-1: # we don't calculate brain_width for truth
                brain_widths.append(cal_outline(img_npy))
            else:
                tumor_width = cal_outline(img_npy)
                
        start_edge = np.min(brain_widths,axis=0)[0]
        end_edge = np.max(brain_widths,axis=0)[1]
        brain_width = np.vstack((start_edge,end_edge))
        
        if add_data_to_storage(storage_list,
                               subject_data, brain_width, tumor_width, truth_dtype, modality_names = modality_names):
            print_red('modality_storage.name != modality_name')
            print_red(set_of_files)
    print('write_image_data_to_file...FINISHED')
    if save_affine:
        np.save('affine_N4_norm',affine_0)
    return 


def cal_outline(img_npy):
    '''
    return a (2,3) array indicating the outline
    J.Lee: i.e. the lowest and highest coordinates of the non-zero region (widened by one voxel and clipped to the image) are returned.
    '''
    # J.Lee
    # np.nonzero returns 3d coordinates with non-zero values.
    # when (1,3),(1,2),(4,2) are returned, it means only (1,1,4) and (3,2,2) coordinates have non-zero values.
    brain_index = np.asarray(np.nonzero(img_npy)) 
    start_edge = np.maximum(np.min(brain_index,axis=1)-1,0) #
    end_edge = np.minimum(np.max(brain_index,axis=1)+1,img_npy.shape)
    
    return np.vstack((start_edge,end_edge))


def add_data_to_storage(storage_list,
                        subject_data, brain_width, tumor_width, truth_dtype, modality_names):
#     pdb.set_trace()
    modality_storage_list,truth_storage,brain_width_storage,tumor_width_storage = storage_list
    for i in range(len(modality_names)):
        if modality_storage_list[i].name != modality_names[i]:
            print_red('modality_storage.name != modality_name')
            return 1
        modality_storage_list[i].append(np.asarray(subject_data[i])[np.newaxis][np.newaxis])
    if truth_storage.name != 'truth':
        print_red('truth_storage.name != truth')
        return 1
    truth_storage.append(np.asarray(subject_data[-1], dtype=truth_dtype)[np.newaxis][np.newaxis])
    brain_width_storage.append(np.asarray(brain_width, dtype=truth_dtype)[np.newaxis])
    tumor_width_storage.append(np.asarray(tumor_width, dtype=truth_dtype)[np.newaxis])
    return 0

def write_data_to_file(training_data_files, out_file, image_shape, modality_names, truth_dtype=np.uint8, subject_ids=None,
                       normalize=True, mean_std_file='../data/mean_std.pkl'):
#     pdb.set_trace()
    n_samples = len(training_data_files)

    hdf5_file, storage_list = create_data_file(out_file,
                                               n_samples=n_samples,
                                               image_shape=image_shape,
                                               modality_names = modality_names)
    modality_storage_list = storage_list[0]
    write_image_data_to_file(training_data_files, 
                             storage_list,
                             image_shape, truth_dtype=truth_dtype, modality_names = modality_names)
    if subject_ids:
        hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
    if normalize:
        normalize_data_storage(modality_storage_list, save_file=mean_std_file)
    hdf5_file.close()
    return out_file


def open_data_file(filename, readwrite="r"):
    return tables.open_file(filename, readwrite)
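The bounding-box logic of `cal_outline` is easiest to see on the two-voxel example from its comment. The sketch below (a hypothetical standalone copy named `outline`) places nonzero voxels at (1, 1, 4) and (3, 2, 2) in a 5x5x5 volume.

```python
import numpy as np


def outline(img):
    # Mirrors cal_outline: bounding box of the nonzero voxels, grown by one voxel
    # on each side and clipped to the array bounds. Row 0 = start, row 1 = end.
    idx = np.asarray(np.nonzero(img))
    start = np.maximum(idx.min(axis=1) - 1, 0)
    end = np.minimum(idx.max(axis=1) + 1, img.shape)
    return np.vstack((start, end))


vol = np.zeros((5, 5, 5))
vol[1, 1, 4] = 1
vol[3, 2, 2] = 1
print(outline(vol))
# [[0 0 1]
#  [4 3 5]]
```

Because `brain_width` and `tumor_width` are stored with `tables.UInt8Atom()`, every coordinate must fit in 0..255; that holds for the 240x240x155 BraTS volumes. Note also that a subject with an empty truth image would make `np.min` fail on an empty index array.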



In [11]:
##### demo_task1/train_model.py #####

import os
import glob
import pickle

# from unet3d.data import write_data_to_file, open_data_file
# from unet3d.generator import get_training_and_validation_generators
# from unet3d.model import isensee2017_model
# from unet3d.training import load_old_model, train_model
import pdb
import time
#from dev_tools.my_tools import sec2hms

def fetch_training_data_files(return_subject_ids=False):
    import os
    os.path.sep = '/'    
    training_data_files = list()
    subject_ids = list()
    for subject_dir in glob.glob(os.path.join("C:/IAMEDIC/Jaeho_code/data", "preprocessed_N4", "*", "*")):
    #for subject_dir in glob.glob("C:/IAMEDIC/Jaeho_code/data/preprocessed/*/*"): #J.Lee
        #subject_dir = '/'.join(subject_dir.split('\\')) #J.Lee:
        subject_ids.append(os.path.basename(subject_dir))
        subject_files = list()
        for modality in config["training_modalities"] + ["truth"]:
            subject_files.append(os.path.join(subject_dir, modality + ".nii.gz"))
        subject_files = ['/'.join(i.split('\\')) for i in subject_files] #J.Lee
        training_data_files.append(tuple(subject_files))
    if return_subject_ids:
        return training_data_files, subject_ids
    else:
        return training_data_files
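On Windows, `glob` and `os.path.join` produce mixed separators here, and the explicit `'\\'`-to-`'/'` replacement is, as far as we can tell, what yields the forward slashes in the output below (assigning `os.path.sep` does not change the separator CPython's `ntpath.join` inserts). A possible alternative from the standard library is `pathlib.PureWindowsPath`, which accepts both separators and can emit forward slashes via `as_posix()`; the helper name `to_forward_slashes` is hypothetical.

```python
from pathlib import PureWindowsPath


def to_forward_slashes(path):
    # PureWindowsPath treats both "\\" and "/" as separators, so a mixed path
    # such as glob may return on Windows comes out uniformly with "/".
    return PureWindowsPath(path).as_posix()


p = "C:/IAMEDIC/Jaeho_code/data\\preprocessed_N4\\HGG\\BraTS19_2013_10_1\\t1.nii.gz"
print(to_forward_slashes(p))
# C:/IAMEDIC/Jaeho_code/data/preprocessed_N4/HGG/BraTS19_2013_10_1/t1.nii.gz
print(PureWindowsPath(p).parent.name)  # BraTS19_2013_10_1
```

Being a "pure" path class, `PureWindowsPath` does no filesystem access, so it behaves the same on any operating system.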

In [12]:
print(fetch_training_data_files(return_subject_ids=True)[0][0])
print(fetch_training_data_files(return_subject_ids=True)[1][0])

('C:/IAMEDIC/Jaeho_code/data/preprocessed_N4/HGG/BraTS19_2013_10_1/t1.nii.gz', 'C:/IAMEDIC/Jaeho_code/data/preprocessed_N4/HGG/BraTS19_2013_10_1/t1ce.nii.gz', 'C:/IAMEDIC/Jaeho_code/data/preprocessed_N4/HGG/BraTS19_2013_10_1/flair.nii.gz', 'C:/IAMEDIC/Jaeho_code/data/preprocessed_N4/HGG/BraTS19_2013_10_1/t2.nii.gz', 'C:/IAMEDIC/Jaeho_code/data/preprocessed_N4/HGG/BraTS19_2013_10_1/truth.nii.gz')
BraTS19_2013_10_1


In [13]:

##### from preprocess.py #####

"""
cascade of functions:
convert_brats_data -> convert_brats_folder
"""
# convert_brats_data('C:/IAMEDIC/Jaeho_code/data/original',
#                    'C:/IAMEDIC/Jaeho_code/data/preprocessed',
#                    no_bias_correction_modalities=['flair', 't1', 't1ce', 't2'] )

##### demo_task1/train_model.py #####
# from main()
# convert input images into an hdf5 file

"""
J.Lee:
cascade of functions:
main() 
    -> fetch_training_data_files()
    -> write_data_to_file()
        -> create_data_file()
            ->hdf5_file = tables.open_file()
            ->hdf5_file.create_array()
        -> write_image_data_to_file(): 
            -> add_data_to_storage()
        -> if normalize: normalize_data_storage()
"""

# overwrite = True
overwrite = False

if overwrite or not os.path.exists(config["data_file"]):
    training_files, subject_ids = fetch_training_data_files(return_subject_ids=True)
    #J.Lee_start
    print("training_files[0]:") 
    for i in training_files[0]: print(i)
    print("training_files[-1]:")
    for i in training_files[-1]: print(i)
    #J.Lee_end
    write_data_to_file(training_files, 
                       config["data_file"],
                       image_shape=config["image_shape"], 
                       modality_names = config['all_modalities'],
                       subject_ids=subject_ids,
                       mean_std_file = config['mean_std_file'],
                       normalize = True )
    
# data_file_opened = open_data_file(config["data_file"])

# Generator

In [12]:
# from unet3d/utils.py

def pickle_dump(item, out_file):
    with open(out_file, "wb") as opened_file:
        pickle.dump(item, opened_file)


def pickle_load(in_file):
    with open(in_file, "rb") as opened_file:
        return pickle.load(opened_file)
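For orientation, the sketch below re-implements the grid arithmetic of `compute_patch_indices` (with `start=None`, defined in the next cell) as a standalone function, `grid_patch_starts` (a hypothetical name), and applies it to the configured image shape (240, 240, 155), patch shape (128, 128, 128) and validation overlap of 32, which we assume are the values passed in for validation patches.

```python
import numpy as np


def grid_patch_starts(image_shape, patch_size, overlap):
    # Same arithmetic as compute_patch_indices(start=None): enough patches per axis
    # to cover the image with the given overlap, with the overflow beyond the
    # image split evenly between both sides, plus one extra centre patch.
    image_shape = np.asarray(image_shape)
    patch_size = np.asarray(patch_size)
    step = patch_size - overlap
    n_patches = np.ceil(image_shape / step)
    overflow = step * n_patches - image_shape + overlap
    start = -np.ceil(overflow / 2)
    stop = image_shape + start
    grid = np.mgrid[start[0]:stop[0]:step[0],
                    start[1]:stop[1]:step[1],
                    start[2]:stop[2]:step[2]].reshape(3, -1).T.astype(int)
    centre = (image_shape - patch_size) // 2
    return np.vstack((grid, centre))


starts = grid_patch_starts((240, 240, 155), (128, 128, 128), 32)
print(len(starts))                              # 19
print(starts[0].tolist(), starts[-1].tolist())  # [-40, -40, -35] [56, 56, 13]
```

With these settings the grid has 3 x 3 x 2 = 18 patches plus the centre patch. Negative starts, and patches running past the far edge, mean the grid overhangs the image on both sides, so the patch-extraction code has to handle out-of-bounds regions.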

In [13]:
# from unet3d/patches.py

def compute_patch_indices(image_shape, patch_size, overlap, start=None):
#     pdb.set_trace()
    if isinstance(overlap, int):
        overlap = np.asarray([overlap] * len(image_shape))
    if start is None: # spread the patches evenly, splitting the overflow equally on both sides
        n_patches = np.ceil(image_shape / (patch_size - overlap))
        overflow = (patch_size - overlap) * n_patches - image_shape + overlap
        start = -np.ceil(overflow/2)
    elif isinstance(start, int):
        start = np.asarray([start] * len(image_shape))
    stop = image_shape + start
    step = patch_size - overlap
    patches = get_set_of_patch_indices(start, stop, step)
    # add the center patch:
    patches = np.vstack((patches, (image_shape - patch_size)//2))
    return patches
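
With this notebook's settings (patch size 128, overlap 32), neighboring patch corners sit 96 voxels apart. The following minimal re-implementation of the grid arithmetic in `compute_patch_indices` shows how the negative start centers the grid over the image. The helper name `patch_grid` and the 128³ box are illustrative. Negative corners are later zero-padded by `fix_out_of_bound_patch_attempt`.

```python
import numpy as np


def patch_grid(image_shape, patch_size, overlap):
    # Same arithmetic as compute_patch_indices with start=None.
    image_shape = np.asarray(image_shape)
    patch_size = np.asarray(patch_size)
    step = patch_size - overlap
    n_patches = np.ceil(image_shape / step)
    overflow = step * n_patches - image_shape + overlap
    start = -np.ceil(overflow / 2)          # shift the grid so it is centered
    stop = image_shape + start
    grid = np.mgrid[start[0]:stop[0]:step[0],
                    start[1]:stop[1]:step[1],
                    start[2]:stop[2]:step[2]].reshape(3, -1).T.astype(int)
    # Append the center patch, as the original function does.
    center = (image_shape - patch_size) // 2
    return np.vstack((grid, center))


patches = patch_grid((128, 128, 128), (128, 128, 128), overlap=32)
print(len(patches))                              # 9: a 2x2x2 grid plus the center
print(sorted(set(patches[:-1, 0].tolist())))     # [-48, 48]
```
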

def compute_patch_indices_for_prediction(image_shape, patch_size, center_patch=True):
#     pdb.set_trace()
    pdb_set = False
    if pdb_set:
        if np.any(np.array(2*np.array(patch_size) - np.array(image_shape))<=0):
            print_red('error patch: too large')
        if  np.any(np.array(image_shape-patch_size)<=0):
            print_red('error patch: too small')
    start_2 = np.asarray(image_shape - patch_size)
    start_2[start_2 < 0] = 0
    patches = np.array([[0,         0,         0         ],
                        [start_2[0],0,         0         ],
                        [0,         start_2[1],0         ],
                        [0,         0,         start_2[2]],
                        [start_2[0],start_2[1],0         ],
                        [start_2[0],start_2[1],start_2[2]],
                        [start_2[0],0,         start_2[2]],
                        [0,         start_2[1],start_2[2]]])
    if center_patch:
        patches = np.vstack((patches, (image_shape - patch_size)//2))
    return patches
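
`compute_patch_indices_for_prediction` lists the eight corners of the bounding box: on each axis a patch starts either at 0 or at `image_shape - patch_size`. The center patch is optional. Below is a sketch that builds the same set with `itertools.product`. The helper name `corner_patches` and the 160 × 192 × 144 box are illustrative assumptions:

```python
import itertools

import numpy as np


def corner_patches(image_shape, patch_size, center_patch=True):
    # Each axis independently starts at 0 or at the far edge (clamped at 0).
    far = np.maximum(np.asarray(image_shape) - np.asarray(patch_size), 0)
    corners = [list(choice) for choice in itertools.product(*[(0, int(f)) for f in far])]
    if center_patch:
        center = (np.asarray(image_shape) - np.asarray(patch_size)) // 2
        corners.append([int(c) for c in center])
    return np.array(corners)


patches = corner_patches((160, 192, 144), (128, 128, 128))
print(len(patches))          # 9
print(patches[-1].tolist())  # [16, 32, 8]
```
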


def get_set_of_patch_indices(start, stop, step):
#     pdb.set_trace()
    # one row of integer corner coordinates per patch
    return np.asarray(np.mgrid[start[0]:stop[0]:step[0], start[1]:stop[1]:step[1],
                               start[2]:stop[2]:step[2]].reshape(3, -1).T, dtype=int)


def get_random_nd_index(index_max):
    return tuple([np.random.choice(index_max[index] + 1) for index in range(len(index_max))])


def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Returns a patch from a numpy array.
    :param data: numpy array from which to get the patch.
    :param patch_shape: shape/size of the patch.
    :param patch_index: corner index of the patch.
    :return: numpy array take from the data with the patch shape specified.
    """
    patch_index = np.asarray(patch_index, dtype=np.int16)
    patch_shape = np.asarray(patch_shape)
    image_shape = data.shape[-3:]
    if np.any(patch_index < 0) or np.any((patch_index + patch_shape) > image_shape):
        data, patch_index = fix_out_of_bound_patch_attempt(data, patch_shape, patch_index)
    return data[..., patch_index[0]:patch_index[0]+patch_shape[0], patch_index[1]:patch_index[1]+patch_shape[1],
                patch_index[2]:patch_index[2]+patch_shape[2]]

def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
    """
    Pads the data and alters the patch index so that a patch will be correct.
    :param data:
    :param patch_shape:
    :param patch_index:
    :return: padded data, fixed patch index
    """
    image_shape = data.shape[-ndim:]
    pad_before = np.abs((patch_index < 0) * patch_index)
    pad_after = np.abs(((patch_index + patch_shape) > image_shape) * ((patch_index + patch_shape) - image_shape))
    pad_args = np.stack([pad_before, pad_after], axis=1)
    if pad_args.shape[0] < len(data.shape):
        pad_args = [[0, 0]] * (len(data.shape) - pad_args.shape[0]) + pad_args.tolist()
#     data = np.pad(data, pad_args, mode="edge")
    data = np.pad(data, pad_args, 'constant',constant_values=0)
    patch_index += pad_before
    return data, patch_index
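
A patch corner can be negative (the centered grid produces these), or the patch can run past the far edge. In either case the data is zero-padded, the index is shifted by the padding, and only then is the patch sliced out. Here is a minimal sketch of that idea on a small 4 × 4 × 4 volume with a leading channel axis. `padded_patch` is a hypothetical helper, not part of the original code:

```python
import numpy as np


def padded_patch(data, patch_shape, patch_index):
    # data: (channels, x, y, z); patch_index: corner, possibly negative.
    patch_index = np.asarray(patch_index)
    patch_shape = np.asarray(patch_shape)
    image_shape = np.asarray(data.shape[-3:])
    pad_before = np.maximum(-patch_index, 0)
    pad_after = np.maximum(patch_index + patch_shape - image_shape, 0)
    pad_width = [(0, 0)] + list(zip(pad_before.tolist(), pad_after.tolist()))
    data = np.pad(data, pad_width, mode="constant", constant_values=0)
    start = patch_index + pad_before   # index into the padded array
    end = start + patch_shape
    return data[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]]


volume = np.arange(1, 65).reshape(1, 4, 4, 4)   # values 1..64, no zeros
patch = padded_patch(volume, (3, 3, 3), (-1, -1, -1))
print(patch.shape)          # (1, 3, 3, 3)
print(patch[0, 0, 0, 0])    # 0 (padding)
print(patch[0, 1, 1, 1])    # 1, i.e. volume[0, 0, 0, 0]
```
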

In [14]:
##### from unet3d/augment.py #####
import random
import itertools
import numpy as np
import nibabel as nib
from nilearn.image import new_img_like, resample_to_img

def scale_image(image, scale_factor):
    scale_factor = np.asarray(scale_factor)
    new_affine = np.copy(image.affine)
    new_affine[:3, :3] = image.affine[:3, :3] * scale_factor
    new_affine[:, 3][:3] = image.affine[:, 3][:3] + (image.shape * np.diag(image.affine)[:3] * (1 - scale_factor)) / 2
    return new_img_like(image, data=image.get_fdata(), affine=new_affine)
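
`scale_image` multiplies the voxel-to-world matrix by the scale factor and shifts the translation by `shape * spacing * (1 - s) / 2`. For a diagonal affine, this keeps the world coordinate of the volume center (voxel `shape / 2`) fixed, so the scaling acts about the center rather than the origin. Below is a small numeric check in plain NumPy. It assumes a diagonal 1 mm affine, a hypothetical origin, and an illustrative 240 × 240 × 155 shape:

```python
import numpy as np


def scaled_affine(affine, shape, scale_factor):
    # Same formula as scale_image, without building a NIfTI image.
    scale_factor = np.asarray(scale_factor, dtype=float)
    new_affine = np.copy(affine)
    new_affine[:3, :3] = affine[:3, :3] * scale_factor
    new_affine[:3, 3] = affine[:3, 3] + (np.asarray(shape) * np.diag(affine)[:3] * (1 - scale_factor)) / 2
    return new_affine


affine = np.eye(4)
affine[:3, 3] = [-120.0, -120.0, -77.0]  # hypothetical origin
shape = (240, 240, 155)
new_affine = scaled_affine(affine, shape, (1.2, 0.9, 1.0))

center_voxel = np.append(np.asarray(shape) / 2, 1.0)
print(affine @ center_voxel)      # world coordinates of the center
print(new_affine @ center_voxel)  # same point, up to floating-point rounding
```
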


def flip_image(image, axis):
    try:
        new_data = np.copy(image.get_fdata())
        for axis_index in axis:
            new_data = np.flip(new_data, axis=axis_index)
    except TypeError:
        new_data = np.flip(image.get_fdata(), axis=axis)
    return new_img_like(image, data=new_data)


def random_flip_dimensions(n_dimensions):
    axis = list()
    for dim in range(n_dimensions):
        if random_boolean():
            axis.append(dim)
    return axis


def random_scale_factor(n_dim=3, mean=1, std=0.25):
    return np.random.normal(mean, std, n_dim)


def random_boolean():
    return np.random.choice([True, False])

def distort_image(image, flip_axis=None, scale_factor=None):
    if flip_axis:
        image = flip_image(image, flip_axis)
    if scale_factor is not None:
        image = scale_image(image, scale_factor)
    return image

def get_image(data, affine, nib_class=nib.Nifti1Image):
    return nib_class(dataobj=data, affine=affine)

def augment_data(data, truth, affine, scale_deviation=None, flip=True):
    n_dim = len(truth.shape)
    if scale_deviation:
        scale_factor = random_scale_factor(n_dim, std=scale_deviation)
    else:
        scale_factor = None
    if flip:
        flip_axis = random_flip_dimensions(n_dim)
    else:
        flip_axis = None
    data_list = list()
    for data_index in range(data.shape[0]):
        image = get_image(data[data_index], affine)
        data_list.append(resample_to_img(distort_image(image, flip_axis=flip_axis,
                                                       scale_factor=scale_factor), image,
                                         interpolation="nearest").get_fdata())
#                                          interpolation="continuous").get_fdata())
    data = np.asarray(data_list)
    truth_image = get_image(truth, affine)
    truth_data = resample_to_img(distort_image(truth_image, flip_axis=flip_axis, scale_factor=scale_factor),
                                 truth_image, interpolation="nearest").get_fdata()
    return data, truth_data


def generate_permutation_keys():
    """
    This function returns a set of "keys" that represent the 48 unique rotations &
    reflections of a 3D matrix.

    Each item of the set is a tuple:
    ((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)

    As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
    rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
    transposed.

    48 unique rotations & reflections:
    https://en.wikipedia.org/wiki/Octahedral_symmetry#The_isometries_of_the_cube
    """
    return set(itertools.product(
        itertools.combinations_with_replacement(range(2), 2), range(2), range(2), range(2), range(2)))
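
The count of 48 in the docstring comes straight from the product. `combinations_with_replacement(range(2), 2)` yields three rotation pairs, (0, 0), (0, 1) and (1, 1), and each of the four binary flags doubles the count, giving 3 × 2⁴ = 48 keys. A quick check:

```python
import itertools

keys = set(itertools.product(
    itertools.combinations_with_replacement(range(2), 2),
    range(2), range(2), range(2), range(2)))

rotation_pairs = sorted({key[0] for key in keys})
print(rotation_pairs)  # [(0, 0), (0, 1), (1, 1)]
print(len(keys))       # 48 = 3 * 2**4
```
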


def random_permutation_key():
    """
    Generates and randomly selects a permutation key. See the documentation for the
    "generate_permutation_keys" function.
    """
    return random.choice(list(generate_permutation_keys()))

def permute_data(data, key):
    """
    Permutes the given data according to the specification of the given key. Input data
    must be of shape (n_modalities, x, y, z).

    Input key is a tuple: ((rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose)

    As an example, ((0, 1), 0, 1, 0, 1) represents a permutation in which the data is
    rotated 90 degrees around the z-axis, then reversed on the y-axis, and then
    transposed.
    """
    data = np.copy(data)
    (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key

    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(1, 3))
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(2, 3))
    if flip_x:
        data = data[:, ::-1]
    if flip_y:
        data = data[:, :, ::-1]
    if flip_z:
        data = data[:, :, :, ::-1]
    if transpose:
        for i in range(data.shape[0]):
            data[i] = data[i].T
    return data

def random_permutation_x_y(x_data, y_data):
    """
    Performs random permutation on the data.
    :param x_data: numpy array containing the data. Data must be of shape (n_modalities, x, y, z).
    :param y_data: numpy array containing the data. Data must be of shape (n_modalities, x, y, z).
    :return: the permuted data
    """
    key = random_permutation_key()
    return permute_data(x_data, key), permute_data(y_data, key)
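
`random_permutation_x_y` draws one key and applies it to both the image channels and the label volume, so every voxel stays aligned with its label. The sketch below repeats `permute_data`'s operations with a fixed key. It assumes a small random 8³ cube and a label map derived by thresholding one channel (both illustrative), then checks that the label map still matches the transformed image:

```python
import numpy as np


def permute_data(data, key):
    # Same operations as permute_data above; data is (channels, x, y, z), cubic.
    data = np.copy(data)
    (rotate_y, rotate_z), flip_x, flip_y, flip_z, transpose = key
    if rotate_y != 0:
        data = np.rot90(data, rotate_y, axes=(1, 3))
    if rotate_z != 0:
        data = np.rot90(data, rotate_z, axes=(2, 3))
    if flip_x:
        data = data[:, ::-1]
    if flip_y:
        data = data[:, :, ::-1]
    if flip_z:
        data = data[:, :, :, ::-1]
    if transpose:
        for i in range(data.shape[0]):
            data[i] = data[i].T
    return data


rng = np.random.default_rng(0)
image = rng.normal(size=(4, 8, 8, 8))                  # 4 modalities, 8^3 cube
label = (image[2] > 1.0).astype(np.int8)[np.newaxis]   # hypothetical label from channel 2

key = ((0, 1), 1, 0, 1, 1)                             # one fixed key for reproducibility
image_p = permute_data(image, key)
label_p = permute_data(label, key)

# The label map is still exactly "channel 2 above threshold" after the transform.
print(np.array_equal(label_p[0], (image_p[2] > 1.0).astype(np.int8)))  # True
```
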



In [15]:
##### unet3d/generator.py #####

import os
import copy
from random import shuffle
import itertools

import numpy as np

# from .utils import pickle_dump, pickle_load
#from .patches import compute_patch_indices, get_random_nd_index, get_patch_from_3d_data, compute_patch_indices_for_prediction
#from .augment import augment_data, random_permutation_x_y

import pdb
#from dev_tools.my_tools import print_red
from tqdm import tqdm
import time

"""
J.Lee:
cascade of functions:
get_training_and_validation_generators()
    ->get_validation_split()
        ->split_list()
        ->pickle_dump(),pickle_load()
    ->data_generator(tr)
        ->if patch: create_patch_index_list()
        ->while:
            add_data()
            if len(idx_list) == batch_size: yield convert_data()
    ->data_generator(val) 
    ->get_number_of_patches(tr)
    ->get_number_of_steps(tr)
    ->get_number_of_patches(val)
    ->get_number_of_steps(val)
    
    
get_training_and_validation_generators
    ->get_validation_split
        ->split_list
    ->data_generator
        ->create_patch_index_list
            ->get_random_nd_index
            ->compute_patch_indices
                ->get_set_of_patch_indices
        ->while:
          add_data
            ->if patch_shape:
              get_data_from_file
                ->get_patch_from_3d_data
                    ->fix_out_of_bound_patch_attempt
            ->augment_data
            ->random_permutation_x_y
        ->if:
          yield convert_data()
    ->get_number_of_patches
        ->if patch_shape: create_patch_index_list
    ->get_number_of_steps
    
"""
    

def get_training_and_validation_generators(data_file, batch_size, n_labels, training_keys_file, validation_keys_file,
                                           data_split=0.8, overwrite=False, labels=None, augment=False,
                                           augment_flip=True, augment_distortion_factor=0.25, patch_shape=None,
                                           validation_patch_overlap=0, training_patch_start_offset=None,
                                           validation_batch_size=None, skip_blank=True, permute=False,num_model=1,
                                           pred_specific=False, overlap_label=True,
                                           for_final_val=False):
#     pdb.set_trace()
    if not validation_batch_size:
        validation_batch_size = batch_size

    training_list, validation_list = get_validation_split(data_file,
                                                          data_split=data_split,
                                                          overwrite=overwrite,
                                                          training_file=training_keys_file,
                                                          validation_file=validation_keys_file)
    if for_final_val:
        training_list = training_list + validation_list

    training_generator = data_generator(data_file, training_list,
                                        batch_size=batch_size,
                                        n_labels=n_labels,
                                        labels=labels,
                                        augment=augment,
                                        augment_flip=augment_flip,
                                        augment_distortion_factor=augment_distortion_factor,
                                        patch_shape=patch_shape,
                                        patch_overlap=validation_patch_overlap,
                                        patch_start_offset=training_patch_start_offset,
                                        skip_blank=skip_blank,
                                        permute=permute,
                                        num_model=num_model,
                                        pred_specific=pred_specific,
                                        overlap_label=overlap_label)
    
    validation_generator = data_generator(data_file, validation_list,
                                          batch_size=validation_batch_size,
                                          n_labels=n_labels,
                                          labels=labels,
                                          patch_shape=patch_shape,
                                          patch_overlap=validation_patch_overlap,
                                          skip_blank=skip_blank,
                                          num_model=num_model,
                                          pred_specific=pred_specific,
                                          overlap_label=overlap_label)

    # Set the number of training and testing samples per epoch correctly
#     pdb.set_trace()

    # J.Lee: Counting patches takes a long time, so the counts are cached in .npy files.
    # For quick tests, comment this out and set num_training_steps to a small constant such as 8.
    if os.path.exists('num_patches_training_N4_norm_1.3.npy'):
        num_patches_training = int(np.load('num_patches_training_N4_norm_1.3.npy'))
    else:
        num_patches_training = get_number_of_patches(data_file, training_list, patch_shape,
                                                       skip_blank=skip_blank,
                                                       patch_start_offset=training_patch_start_offset,
                                                       patch_overlap=validation_patch_overlap,
                                                       pred_specific=pred_specific)
        np.save('num_patches_training_N4_norm_1.3', num_patches_training)
    num_training_steps = get_number_of_steps(num_patches_training, batch_size)
    print("Number of training steps in each epoch: ", num_training_steps)

    if os.path.exists('num_patches_val_N4_norm_1.3.npy'):
        num_patches_val = int(np.load('num_patches_val_N4_norm_1.3.npy'))
    else:
        num_patches_val = get_number_of_patches(data_file, validation_list, patch_shape,
                                                 skip_blank=skip_blank,
                                                 patch_overlap=validation_patch_overlap,
                                                 pred_specific=pred_specific)
        np.save('num_patches_val_N4_norm_1.3', num_patches_val)
    num_validation_steps = get_number_of_steps(num_patches_val, validation_batch_size)
    print("Number of validation steps in each epoch: ", num_validation_steps)

    return training_generator, validation_generator, num_training_steps, num_validation_steps
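
Counting patches calls `add_data` for every candidate patch, so the counts are cached on disk. `np.save` appends `.npy` when the file name lacks it, which is why the existence checks test `...npy` while the `np.save` calls omit the extension. Here is a minimal sketch of the caching pattern. `cached_count`, `slow_count` and the temporary directory are illustrative, not part of the original code:

```python
import os
import tempfile

import numpy as np


def cached_count(cache_stem, count_fn):
    # np.save(cache_stem, ...) writes cache_stem + ".npy".
    cache_file = cache_stem + ".npy"
    if os.path.exists(cache_file):
        return int(np.load(cache_file))
    count = count_fn()
    np.save(cache_stem, count)
    return count


calls = []


def slow_count():
    calls.append(1)  # record how often the expensive count runs
    return 1660


with tempfile.TemporaryDirectory() as tmp_dir:
    stem = os.path.join(tmp_dir, "num_patches_training")
    first = cached_count(stem, slow_count)
    second = cached_count(stem, slow_count)

print(first, second, len(calls))  # 1660 1660 1
```

The cache is keyed only by file name. If the split or the patch settings change, the stale count is reused silently; deleting the `.npy` files forces a recount.
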



def get_number_of_steps(n_samples, batch_size):
    if n_samples <= batch_size:
        return n_samples
    elif np.remainder(n_samples, batch_size) == 0:
        return n_samples//batch_size
    else:
        return n_samples//batch_size + 1
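
When `n_samples > batch_size`, `get_number_of_steps` is plain ceiling division. When `n_samples <= batch_size`, it returns `n_samples` itself instead of 1. With `batch_size = 1`, as configured here, both branches agree and the step count equals the patch count, so the 1660 training steps in the log further down correspond to 1660 counted patches. The sketch below compares the function with `math.ceil` (the `np.remainder` call is replaced by `%`, which is equivalent for integers):

```python
import math


def get_number_of_steps(n_samples, batch_size):
    if n_samples <= batch_size:
        return n_samples
    elif n_samples % batch_size == 0:
        return n_samples // batch_size
    else:
        return n_samples // batch_size + 1


for n, b in [(1660, 1), (10, 4), (8, 4), (3, 4)]:
    print(n, b, get_number_of_steps(n, b), math.ceil(n / b))
# 1660 1 1660 1660
# 10 4 3 3
# 8 4 2 2
# 3 4 3 1   <- differs when n_samples <= batch_size
```
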


def get_validation_split(data_file, training_file, validation_file, data_split=0.8, overwrite=False):
    '''
    Splits the data into the training and validation indices list.
    '''
    if overwrite or not os.path.exists(training_file):
        print("Creating validation split...")
        nb_samples = data_file.root.truth.shape[0]
        sample_list = list(range(nb_samples))
        training_list, validation_list = split_list(sample_list, split=data_split)
        pickle_dump(training_list, training_file)
        pickle_dump(validation_list, validation_file)
        return training_list, validation_list
    else:
        print("Loading previous validation split...")
        return pickle_load(training_file), pickle_load(validation_file)


def split_list(input_list, split=0.8, shuffle_list=True):
    if shuffle_list:
        shuffle(input_list)
    n_training = int(len(input_list) * split)
    training = input_list[:n_training]
    testing = input_list[n_training:]
    return training, testing
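
`split_list` shuffles the list in place and cuts it at `int(len * split)`. With `data_split=0.8` and 10 subjects, for example, 8 go to training and 2 to validation. The sketch seeds Python's `random` module so the check is reproducible. That seed is an assumption: the original code does not seed it.

```python
import random


def split_list(input_list, split=0.8, shuffle_list=True):
    if shuffle_list:
        random.shuffle(input_list)  # in place: the caller's list is reordered
    n_training = int(len(input_list) * split)
    return input_list[:n_training], input_list[n_training:]


random.seed(42)  # illustrative seed, not used in the notebook
training, validation = split_list(list(range(10)), split=0.8)
print(len(training), len(validation))                    # 8 2
print(sorted(training + validation) == list(range(10)))  # True
```
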


def data_generator(data_file, index_list, batch_size=1, n_labels=1, labels=None, augment=False, augment_flip=True,
                   augment_distortion_factor=0.25, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                   shuffle_index_list=True, skip_blank=True, permute=False, num_model=1, pred_specific=False,overlap_label=False):
#     pdb.set_trace()

    orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                 patch_overlap, patch_start_offset,pred_specific=pred_specific)
        else:
            index_list = copy.copy(orig_index_list)

        if shuffle_index_list:
            shuffle(index_list)
        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
#                 convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model)
                x_list = list()
                y_list = list()
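
Within one pass, `data_generator` pops indices from the end of the (shuffled) list and lets `add_data` decide whether each sample is kept; blank patches are skipped. It yields whenever the batch is full, or when the list is exhausted with a partial batch pending. The outer `while True` then rebuilds the index list for the next epoch. Below is a stripped-down sketch of one pass, with shuffling disabled and data loading omitted. `one_pass` and its `keep` predicate (standing in for `add_data`) are hypothetical:

```python
def one_pass(index_list, batch_size, keep=lambda index: True):
    # Mirrors the inner loop of data_generator without data loading.
    index_list = list(index_list)
    batch = []
    while len(index_list) > 0:
        index = index_list.pop()          # pops from the end, like the original
        if keep(index):                   # add_data may append nothing
            batch.append(index)
        if len(batch) == batch_size or (len(index_list) == 0 and len(batch) > 0):
            yield batch
            batch = []


print(list(one_pass([0, 1, 2, 3, 4], batch_size=2)))
# [[4, 3], [2, 1], [0]]
print(list(one_pass([0, 1, 2, 3, 4], batch_size=2, keep=lambda i: i != 2)))
# [[4, 3], [1, 0]]
```
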



def get_number_of_patches(data_file, index_list, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                          skip_blank=True,pred_specific=False):
    if patch_shape:
        index_list = create_patch_index_list(index_list, data_file, patch_shape, patch_overlap,
                                             patch_start_offset,pred_specific=pred_specific)
        count = 0
        for index in tqdm(index_list):
            x_list = list()
            y_list = list()
            add_data(x_list, y_list, data_file, index, skip_blank=skip_blank, patch_shape=patch_shape)
            if len(x_list) > 0:
                count += 1
        return count
    else:
        return len(index_list)


def create_patch_index_list(index_list, data_file, patch_shape, patch_overlap, patch_start_offset=None, pred_specific=False):
    patch_index = list()
    for index in index_list:
        brain_width = data_file.root.brain_width[index]
        image_shape = brain_width[1] - brain_width[0] + 1
        if pred_specific:
            patches = compute_patch_indices_for_prediction(image_shape, patch_shape)
        else:
            if patch_start_offset is not None:
                random_start_offset = np.negative(get_random_nd_index(patch_start_offset))
                patches = compute_patch_indices(image_shape, patch_shape, overlap=patch_overlap, start=random_start_offset)
            else:
                patches = compute_patch_indices(image_shape, patch_shape, overlap=patch_overlap)
        patch_index.extend(itertools.product([index], patches))
    return patch_index
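
`brain_width` holds the inclusive lower and upper corners of each subject's brain bounding box, so the box size is `upper - lower + 1`. `itertools.product` then pairs every patch corner with its subject index. The sketch below uses two hypothetical subjects. A stand-in patch function, `corner_and_center` (also hypothetical), replaces `compute_patch_indices`:

```python
import itertools

import numpy as np


def corner_and_center(image_shape, patch_shape):
    # Hypothetical stand-in: origin corner plus the centered patch.
    return [np.zeros(3, dtype=int), (image_shape - patch_shape) // 2]


brain_width = {
    0: np.array([[40, 30, 10], [199, 221, 153]]),  # hypothetical bounding boxes
    1: np.array([[50, 40, 20], [177, 167, 147]]),
}
patch_shape = np.array([128, 128, 128])

patch_index = []
for index in [0, 1]:
    image_shape = brain_width[index][1] - brain_width[index][0] + 1
    patches = corner_and_center(image_shape, patch_shape)
    patch_index.extend(itertools.product([index], patches))

for subject, corner in patch_index:
    print(subject, corner.tolist())
# 0 [0, 0, 0]
# 0 [16, 32, 8]
# 1 [0, 0, 0]
# 1 [0, 0, 0]
```

Each `(subject, corner)` entry is what `get_data_from_file` later unpacks with `index, patch_index = index`.
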


def add_data(x_list, y_list, data_file, index, augment=False, augment_flip=False, augment_distortion_factor=0.25,
             patch_shape=False, skip_blank=True, permute=False):
    '''
    add qualified x,y to the generator list
    '''
#     pdb.set_trace()
    data, truth = get_data_from_file(data_file, index, patch_shape=patch_shape)
    
    if np.sum(truth) == 0:
        return
    if augment:
        affine = np.load('affine_N4_norm.npy')
        data, truth = augment_data(data, truth, affine, flip=augment_flip, scale_deviation=augment_distortion_factor)

    if permute:
        if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
            raise ValueError("To utilize permutations, data array must be in 3D cube shape with all dimensions having "
                             "the same length.")
        data, truth = random_permutation_x_y(data, truth[np.newaxis])
    else:
        truth = truth[np.newaxis]

    if not skip_blank or np.any(truth != 0):
        x_list.append(data)
        y_list.append(truth)


def get_data_from_file(data_file, index, patch_shape=None):
#     pdb.set_trace()
    if patch_shape:
        index, patch_index = index
        data, truth = get_data_from_file(data_file, index, patch_shape=None)
        x = get_patch_from_3d_data(data, patch_shape, patch_index)
        y = get_patch_from_3d_data(truth, patch_shape, patch_index)
    else:
        brain_width = data_file.root.brain_width[index]
        x = np.array([modality_img[index,0,
                                   brain_width[0,0]:brain_width[1,0]+1,
                                   brain_width[0,1]:brain_width[1,1]+1,
                                   brain_width[0,2]:brain_width[1,2]+1] 
                      for modality_img in [data_file.root.t1,
                                           data_file.root.t1ce,
                                           data_file.root.flair,
                                           data_file.root.t2]])
        y = data_file.root.truth[index, 0,
                                 brain_width[0,0]:brain_width[1,0]+1,
                                 brain_width[0,1]:brain_width[1,1]+1,
                                 brain_width[0,2]:brain_width[1,2]+1]
    return x, y


def convert_data(x_list, y_list, n_labels=1, labels=None, num_model=1,overlap_label=False):
#     pdb.set_trace()
    x = np.asarray(x_list)
    y = np.asarray(y_list)
    if n_labels == 1:
        y[y > 0] = 1
    elif n_labels > 1:
        if overlap_label:
            y = get_multi_class_labels_overlap(y, n_labels=n_labels, labels=labels)
        else:
            y = get_multi_class_labels(y, n_labels=n_labels, labels=labels)
    if num_model == 1:
        return x, y
    else:
        return [x]*num_model, y


def get_multi_class_labels_overlap(data, n_labels=3, labels=(1,2,4)):
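    """
    Converts a BraTS label map into nested (overlapping) binary regions:
    channel 0: TC (tumor core)      = labels 1 + 4
    channel 1: WT (whole tumor)     = labels 1 + 2 + 4
    channel 2: ET (enhancing tumor) = label 4
    The regions are hard-coded, so `labels` is ignored and n_labels must be 3.
    """
#     pdb.set_trace()
    new_shape = [data.shape[0], n_labels] + list(data.shape[2:])
    y = np.zeros(new_shape, np.int8)
    
    y[:,0][np.logical_or(data[:,0] == 1,data[:,0] == 4)] = 1    # TC: labels 1, 4
    # np.logical_or accepts only two operands (a third positional argument is its `out` array),
    # so the three-label union for WT uses np.isin instead.
    y[:,1][np.isin(data[:,0], (1, 2, 4))] = 1                    # WT: labels 1, 2, 4
    y[:,2][data[:,0] == 4] = 1                                   # ET: label 4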
    return y

##### from ellisdg's unet3d/generator.py #####
def get_multi_class_labels(data, n_labels, labels=None):
    """
    Translates a label map into a set of binary labels.
    :param data: numpy array containing the label map with shape: (n_samples, 1, ...).
    :param n_labels: number of labels.
    :param labels: integer values of the labels.
    :return: binary numpy array of shape: (n_samples, n_labels, ...)
    """
    new_shape = [data.shape[0], n_labels] + list(data.shape[2:])
    y = np.zeros(new_shape, np.int8)
    for label_index in range(n_labels):
        if labels is not None:
            y[:, label_index][data[:, 0] == labels[label_index]] = 1
        else:
            y[:, label_index][data[:, 0] == (label_index + 1)] = 1
    return y
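
The two label converters encode the BraTS labels differently. `get_multi_class_labels` produces one mutually exclusive channel per label value (1, 2, 4). The overlapping variant instead builds the nested regions TC (1, 4), WT (1, 2, 4) and ET (4). The sketch below compares both on a hypothetical five-voxel label map. It uses `np.isin` for the region unions, because `np.logical_or` accepts only two operands and treats a third positional argument as its `out` array:

```python
import numpy as np

# Toy label map with shape (n_samples, 1, n_voxels): background, NCR/NET, edema, ET, ET.
truth = np.array([[[0, 1, 2, 4, 4]]])


def exclusive_labels(data, labels=(1, 2, 4)):
    y = np.zeros([data.shape[0], len(labels)] + list(data.shape[2:]), np.int8)
    for channel, value in enumerate(labels):
        y[:, channel][data[:, 0] == value] = 1
    return y


def nested_regions(data):
    y = np.zeros([data.shape[0], 3] + list(data.shape[2:]), np.int8)
    y[:, 0][np.isin(data[:, 0], (1, 4))] = 1     # TC
    y[:, 1][np.isin(data[:, 0], (1, 2, 4))] = 1  # WT
    y[:, 2][data[:, 0] == 4] = 1                 # ET
    return y


print(exclusive_labels(truth)[0].tolist())
# [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 1]]
print(nested_regions(truth)[0].tolist())
# [[0, 1, 0, 1, 1], [0, 1, 1, 1, 1], [0, 0, 0, 1, 1]]
```
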

In [16]:
##### demo_task1/train_model.py #####

config = dict()
config["overwrite"] = False              # To overwrite data.h5.
# config["pool_size"] = (2, 2, 2)          # pool size for the max pooling operations

#J.lee: input_shape
config["image_shape"] = (128,128,64) #(240,240,155)  # This determines what shape the images will be cropped/resampled to.

#J.lee: patching
config["patch_shape"] = (128, 128, 128)     # switch to None to train on the whole image
config["training_patch_start_offset"] = (4, 4, 4)  # randomly offset the first patch index by up to this offset
config["validation_patch_overlap"] = 32                # overlap (voxels) between neighboring patches; used for both the training and the validation patch grids
config['pred_specific'] = False          # If True, train with the prediction patching strategy (eight corner patches plus the center patch).
config['center_patch'] = True            # To include the center patch in the patching strategy.

#J.lee: batch
config["batch_size"] = 1
config["validation_batch_size"] = 1 # 2
config["n_epochs"] = 25 # 300

#J.lee: paths for files.
config["data_file"] = 'C:/IAMEDIC/Jaeho_code/data/data_N4_norm.h5' # os.path.abspath("../data/data.h5")
# config["model_file"] = 'C:/IAMEDIC/Jaeho_code/woodywff_seg_model.h5'# os.path.abspath("seg_model.h5")
config["model_file"] = 'C:/IAMEDIC/Jaeho_code/seg_model_1.3.h5'# os.path.abspath("seg_model.h5")
config['mean_std_file'] = 'C:/IAMEDIC/Jaeho_code/data/mean_std.pkl' #os.path.abspath('../data/mean_std.pkl')

#
config["training_file"] = "C:/IAMEDIC/Jaeho_code/data/list_cv1.3_train.pkl"
config["validation_file"] = "C:/IAMEDIC/Jaeho_code/data/list_cv1.3_val.pkl"


config['for_final_val'] = True
#--------------------------------------------------------------------------------
config['logging_file'] = 'C:/IAMEDIC/Jaeho_code/training.log' #os.path.abspath('training.log')

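# truth.shape = (240,240,155) with labels 0 (background), 1, 2 and 4. With overlapping labels, the nested regions are used
# (ET = 4 lies inside TC = 1+4, which lies inside WT = 1+2+4); otherwise each label value gets its own exclusive channel.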
# config['overlap_label_generator'] = False
# config['overlap_label_predict'] = False
config['overlap_label_generator'] = True
config['overlap_label_predict'] = True


config["labels"] = (1, 2, 4)             # the label numbers on the input image
config["n_labels"] = len(config["labels"])
config["all_modalities"] = ["t1", "t1ce", "flair", "t2"]
config["training_modalities"] = config["all_modalities"]  # change this if you want to only use some of the modalities
config["nb_channels"] = len(config["training_modalities"])

config["n_base_filters"] = 16

if "patch_shape" in config and config["patch_shape"] is not None:
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))
else:
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["image_shape"]))

config["truth_channel"] = config["nb_channels"]
config["deconvolution"] = True           # if False, will use upsampling instead of deconvolution


config["patience"] = 5    # learning rate will be reduced after this many epochs if the validation loss is not improving
config["early_stop"] = 50  # training will be stopped after this many epochs without the validation loss improving
config["initial_learning_rate"] = 5e-4
config["learning_rate_drop"] = 0.5  # factor by which the learning rate will be reduced
config["validation_split"] = 0.8    # portion of the data that will be used for training

config["flip"] = False              # if True, augments the data by randomly flipping axes during training
# config["flip"] = True
# config["permute"] = False
config["permute"] = True  # data shape must be a cube. Augments the data by permuting in various directions
# config["distort"] = None  # switch to None if you want no distortion
config["distort"] = 0.25
config["augment"] = config["flip"] or config["distort"]

config["skip_blank"] = True                           # if True, then patches without any target will be skipped
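
The derived entries follow directly from the settings above. Note that `config["augment"] = config["flip"] or config["distort"]` evaluates to the distortion factor `0.25`, not `True`. The value is still truthy, so augmentation is enabled. The quick check below reproduces only the relevant keys:

```python
config = {
    "labels": (1, 2, 4),
    "all_modalities": ["t1", "t1ce", "flair", "t2"],
    "patch_shape": (128, 128, 128),
    "image_shape": (128, 128, 64),
    "flip": False,
    "distort": 0.25,
}
config["n_labels"] = len(config["labels"])
config["nb_channels"] = len(config["all_modalities"])
if config.get("patch_shape") is not None:
    # channels_first: (channels, x, y, z) of one training patch
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["patch_shape"]))
else:
    config["input_shape"] = tuple([config["nb_channels"]] + list(config["image_shape"]))
config["augment"] = config["flip"] or config["distort"]

print(config["n_labels"], config["input_shape"], config["augment"])
# 3 (4, 128, 128, 128) 0.25
```
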



In [19]:
#J.Lee
# print('n_train_steps:', n_train_steps)
# print('n_validation_steps:', n_validation_steps)
# print()
# print('np.shape(next(train_generator)[0]):')
# print(f'{np.shape(next(train_generator)[0])}')
# print()
# print('np.shape(next(validation_generator)[0]):')
# print(f'{np.shape(next(validation_generator)[0])}')

## Training

In [18]:
import math
from functools import partial
import pdb
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler, ReduceLROnPlateau, EarlyStopping
from tensorflow_addons.callbacks import TQDMProgressBar
from tensorflow.keras.models import load_model

# from unet3d.metrics import dice_coefficient, dice_coefficient_loss, weighted_dice_coefficient_loss, weighted_dice_coefficient

K.set_image_data_format('channels_first')

##### unet3d/training.py #####
def step_decay(epoch, initial_lrate, drop, epochs_drop):
    return initial_lrate * math.pow(drop, math.floor((1+epoch)/float(epochs_drop)))
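
`step_decay` multiplies the initial rate by `drop` once every `epochs_drop` epochs. Keras passes a 0-based epoch index, so because of the `1 + epoch` the first drop takes effect at index `epochs_drop - 1`. The schedule is only used when `learning_rate_epochs` is set. This notebook leaves it unset, so `get_callbacks` falls back to `ReduceLROnPlateau`. A worked example follows, using the notebook's initial rate and drop factor with an illustrative `epochs_drop = 10`:

```python
import math


def step_decay(epoch, initial_lrate, drop, epochs_drop):
    return initial_lrate * math.pow(drop, math.floor((1 + epoch) / float(epochs_drop)))


rates = {epoch: step_decay(epoch, initial_lrate=5e-4, drop=0.5, epochs_drop=10)
         for epoch in [0, 8, 9, 19, 29]}
print(rates)
# epochs 0 and 8: 5e-4; epoch 9: 2.5e-4; epoch 19: 1.25e-4; epoch 29: 6.25e-5
```
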

def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1,
                  early_stopping_patience=None):
    callbacks = list()
    callbacks.append(ModelCheckpoint(model_file, save_best_only=True))
    callbacks.append(CSVLogger(logging_file, append=True))
    if learning_rate_epochs:
        callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
                                                       drop=learning_rate_drop, epochs_drop=learning_rate_epochs)))
    else:
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                           verbose=verbosity))
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    callbacks.append(TQDMProgressBar())
    return callbacks
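
With `learning_rate_epochs` unset, the schedule used here is `ReduceLROnPlateau(factor=0.5, patience=5)`, taken from `config["learning_rate_drop"]` and `config["patience"]`. By default it monitors `val_loss`. Below is a simplified pure-Python sketch of that rule, run on a made-up loss curve. It ignores the Keras `min_delta`, `cooldown` and `min_lr` details, so it approximates the callback rather than reproducing it:

```python
def plateau_schedule(val_losses, lr, factor=0.5, patience=5):
    # Simplified ReduceLROnPlateau: scale lr by `factor` after `patience` epochs without a new best.
    best = float("inf")
    wait = 0
    lrs = []
    for loss in val_losses:
        lrs.append(lr)          # rate used during this epoch
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor
                wait = 0
    return lrs


losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.36, 0.38, 0.36, 0.34, 0.35]  # made-up values
lrs = plateau_schedule(losses, lr=5e-4)
print(lrs)  # eight epochs at 5e-4, then 2.5e-4 after five epochs without improvement
```

With `n_epochs = 25` and `early_stop = 50`, the `EarlyStopping` callback cannot trigger within a single run.
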

def load_old_model(model_file):
#     pdb.set_trace()
    print("Loading pre-trained model")
    custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient,
                      'weighted_dice_coefficient': weighted_dice_coefficient,
                      'weighted_dice_coefficient_loss': weighted_dice_coefficient_loss}
    try:
        from tensorflow_addons.layers import InstanceNormalization
        custom_objects["InstanceNormalization"] = InstanceNormalization
    except ImportError:
        pass
    try:
        return load_model(model_file, custom_objects=custom_objects)
    except ValueError as error:
        if 'InstanceNormalization' in str(error):
            raise ValueError(str(error) + "\n\nInstall tensorflow_addons in order to use instance normalization\n")
        else:
            raise error
            
def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None, logging_file = 'training.log'):
    model.fit(x=training_generator,
              steps_per_epoch=steps_per_epoch,
              epochs=n_epochs,
              validation_data=validation_generator,
              validation_steps=validation_steps,
              callbacks=get_callbacks(model_file,
                                      initial_learning_rate=initial_learning_rate,
                                      learning_rate_drop=learning_rate_drop,
                                      learning_rate_epochs=learning_rate_epochs,
                                      learning_rate_patience=learning_rate_patience,
                                      logging_file=logging_file,
                                      early_stopping_patience=early_stopping_patience))

In [21]:
data_file_opened = open_data_file(config["data_file"])

In [21]:

##### demo_task1/train_model.py #####

#from main()
overwrite = False

data_file_opened = open_data_file(config["data_file"])

if not overwrite and os.path.exists(config["model_file"]):
    model = load_old_model(config["model_file"])
else:
    # instantiate new model
    model = isensee2017_model(input_shape=config["input_shape"], n_labels=config["n_labels"],
                              depth=5, #J.Lee added
                              n_segmentation_levels = 3, #J.Lee added
                              loss_function=weighted_dice_coefficient_loss,
                              initial_learning_rate=config["initial_learning_rate"],
                              n_base_filters=config["n_base_filters"])

train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
    data_file_opened,
    batch_size= config["batch_size"],
    data_split= config["validation_split"],
    overwrite= False, #overwrite,
    validation_keys_file=config["validation_file"],
    training_keys_file=config["training_file"],
    n_labels=config["n_labels"],
    labels=config["labels"],
    patch_shape=config["patch_shape"],
    validation_batch_size=config["validation_batch_size"],
    validation_patch_overlap=config["validation_patch_overlap"],
    training_patch_start_offset=config["training_patch_start_offset"],
    permute=config["permute"],
    augment=config["augment"],
    #skip_blank=config["skip_blank"],
    #augment_flip=config["flip"],
    augment_distortion_factor=config["distort"],
    #pred_specific=config['pred_specific'],
    #overlap_label=config['overlap_label_generator'],
    #for_final_val=config['for_final_val']
    )



Loading pre-trained model
Loading previous validation split...
Number of training steps in each epoch:  1660
Number of validation steps in each epoch:  555
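The "Loading previous validation split..." message means the generator function reused split files saved by an earlier run rather than drawing a new split. The behaviour can be sketched as below. `get_or_create_split` is a hypothetical helper, not the function called above. It assumes `data_split` is the fraction of cases kept for training; a value of 0.8 would match the 20% validation share mentioned in the Readme.

```python
import os
import pickle
import random


def get_or_create_split(keys, training_file, validation_file, data_split=0.8,
                        overwrite=False, seed=None):
    """Hypothetical helper: reuse a pickled train/validation split if one exists."""
    if overwrite or not os.path.exists(training_file):
        shuffled = list(keys)
        random.Random(seed).shuffle(shuffled)
        n_training = int(len(shuffled) * data_split)  # data_split = training fraction (assumption)
        training, validation = shuffled[:n_training], shuffled[n_training:]
        # Persist both lists so later sessions train and validate on the same cases.
        with open(training_file, "wb") as f:
            pickle.dump(training, f)
        with open(validation_file, "wb") as f:
            pickle.dump(validation, f)
        return training, validation
    print("Loading previous validation split...")
    with open(training_file, "rb") as f:
        training = pickle.load(f)
    with open(validation_file, "rb") as f:
        validation = pickle.load(f)
    return training, validation
```

Keeping the split on disk is what allows the notebook to be restarted several times, as in the cells below, without validation cases leaking into training.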


In [22]:
# not overwrite and os.path.exists('C:\IAMEDIC\Jaeho_code\woodywff_seg_model.h5')

In [23]:
##### demo_task1/train_model.py #####
time_0 = time.time()
train_model(model=model,
            model_file=config["model_file"],
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=n_train_steps,
            validation_steps=n_validation_steps,
            initial_learning_rate=config["initial_learning_rate"],
            learning_rate_drop=config["learning_rate_drop"],
            learning_rate_patience=config["patience"],
            early_stopping_patience=config["early_stop"],
            n_epochs=config["n_epochs"],
            logging_file = config['logging_file'])
print('Training time:', sec2hms(time.time() - time_0))
# data_file_opened.close()

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25

Training time: 0 days, 16 hours, 12 mins, 20.264 secs.
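`sec2hms` is defined outside this excerpt. The following is a hypothetical re-implementation that produces the same format as the line above, assuming it splits elapsed seconds into days, hours, minutes and fractional seconds:

```python
def sec2hms(seconds):
    """Hypothetical re-implementation: format elapsed seconds as days/hours/mins/secs."""
    minutes, secs = divmod(seconds, 60)          # secs keeps the fractional part
    hours, minutes = divmod(int(minutes), 60)
    days, hours = divmod(hours, 24)
    return "{} days, {} hours, {} mins, {:.3f} secs.".format(days, hours, minutes, secs)


print('Training time:', sec2hms(58340.264))
```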


In [23]:
model = load_old_model(config["model_file"])
train_model(model=model,
            model_file=config["model_file"],
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=n_train_steps,
            validation_steps=n_validation_steps,
            initial_learning_rate=0.0002,  # was config["initial_learning_rate"]; only read by the step-decay scheduler
            learning_rate_drop=config["learning_rate_drop"],
            learning_rate_patience=config["patience"],
            early_stopping_patience=config["early_stop"],
            n_epochs=config["n_epochs"],
            logging_file = config['logging_file'])
# data_file_opened.close()



Loading pre-trained model


Epoch 1/25
Epoch 2/25
  15/1660 [..............................] - ETA: 35:11 - loss: -0.6188 - dice_coefficient: 0.7046

KeyboardInterrupt: 
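The progress bars report a negative `loss` next to a positive `dice_coefficient`. This is expected if `weighted_dice_coefficient_loss` is the negated mean soft Dice over the label channels, the form used in ellisdg's 3DUnetCNN code. The definition is not part of this excerpt, so the NumPy sketch below is an assumption about its form (channels-first arrays, `smooth=1e-5`):

```python
import numpy as np


def weighted_dice_coefficient(y_true, y_pred, axis=(-3, -2, -1), smooth=1e-5):
    """Soft Dice computed per sample and per label channel, then averaged."""
    intersection = np.sum(y_true * y_pred, axis=axis)
    denominator = np.sum(y_true, axis=axis) + np.sum(y_pred, axis=axis)
    return np.mean(2.0 * (intersection + smooth / 2) / (denominator + smooth))


def weighted_dice_coefficient_loss(y_true, y_pred):
    # Negated so that minimising the loss maximises Dice; the value lies in [-1, 0].
    return -weighted_dice_coefficient(y_true, y_pred)
```

Under this form, a loss of about -0.62 corresponds to a mean per-label soft Dice of about 0.62.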

In [24]:
train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
    data_file_opened,
    batch_size= config["batch_size"],
    data_split= config["validation_split"],
    overwrite= False, #overwrite,
    validation_keys_file=config["validation_file"],
    training_keys_file=config["training_file"],
    n_labels=config["n_labels"],
    labels=config["labels"],
    patch_shape=config["patch_shape"],
    validation_batch_size=config["validation_batch_size"],
    validation_patch_overlap=config["validation_patch_overlap"],
    training_patch_start_offset=config["training_patch_start_offset"],
    permute=config["permute"],
    augment=config["augment"],
    skip_blank=config["skip_blank"],
    #augment_flip=config["flip"],
    augment_distortion_factor=config["distort"],
    pred_specific=config['pred_specific'],
    overlap_label=config['overlap_label_generator'],
    #for_final_val=config['for_final_val']
    )

Loading previous validation split...
Number of training steps in each epoch:  1660
Number of validation steps in each epoch:  555


In [25]:
model = load_old_model(config["model_file"])
train_model(model=model,
            model_file=config["model_file"],
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=n_train_steps,
            validation_steps=n_validation_steps,
            initial_learning_rate=0.0002,  # was config["initial_learning_rate"]; only read by the step-decay scheduler
            learning_rate_drop=config["learning_rate_drop"],
            learning_rate_patience=config["patience"],
            early_stopping_patience=config["early_stop"],
            n_epochs=config["n_epochs"],
            logging_file = config['logging_file'])



Loading pre-trained model


Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 00013: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Epoch 14/25
  47/1660 [..............................] - ETA: 34:03 - loss: -0.6945 - dice_coefficient: 0.6659

KeyboardInterrupt: 

In [22]:
train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
    data_file_opened,
    batch_size= config["batch_size"],
    data_split= config["validation_split"],
    overwrite= False, #overwrite,
    validation_keys_file=config["validation_file"],
    training_keys_file=config["training_file"],
    n_labels=config["n_labels"],
    labels=config["labels"],
    patch_shape=config["patch_shape"],
    validation_batch_size=config["validation_batch_size"],
    validation_patch_overlap=config["validation_patch_overlap"],
    training_patch_start_offset=config["training_patch_start_offset"],
    permute=config["permute"],
    augment=config["augment"],
    skip_blank=config["skip_blank"],
    #augment_flip=config["flip"],
    augment_distortion_factor=config["distort"],
    pred_specific=config['pred_specific'],
    overlap_label=config['overlap_label_generator'],
    #for_final_val=config['for_final_val']
    )

Loading previous validation split...
Number of training steps in each epoch:  1660
Number of validation steps in each epoch:  555


In [23]:
model = load_old_model(config["model_file"])
model.compile(optimizer=Adam(learning_rate=0.0002), loss=weighted_dice_coefficient_loss, metrics=[dice_coefficient])



Loading pre-trained model


In [24]:
import datetime
from tensorflow.keras.callbacks import TensorBoard
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1,
                  early_stopping_patience=None):
    callbacks = list()
    callbacks.append(ModelCheckpoint(model_file, save_best_only=True))
    callbacks.append(CSVLogger(logging_file, append=True))
    if learning_rate_epochs:
        callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
                                                       drop=learning_rate_drop, epochs_drop=learning_rate_epochs)))
    else:
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                           verbose=verbosity))
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    callbacks.append(tensorboard_callback)
    return callbacks

def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None, logging_file = 'training.log'):
    model.fit(x=training_generator,
                steps_per_epoch=steps_per_epoch,
                epochs=n_epochs,
                validation_data=validation_generator,
                validation_steps=validation_steps,
                callbacks=get_callbacks(model_file,
                                        initial_learning_rate=initial_learning_rate,
                                        learning_rate_drop=learning_rate_drop,
                                        learning_rate_epochs=learning_rate_epochs,
                                        learning_rate_patience=learning_rate_patience,
                                        logging_file = logging_file,
                                        early_stopping_patience=early_stopping_patience))
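With `learning_rate_epochs=None`, `get_callbacks` never uses `initial_learning_rate`. Instead it adds `ReduceLROnPlateau`, which multiplies whatever learning rate the optimizer currently holds by `learning_rate_drop` once the validation loss has not improved for `learning_rate_patience` epochs. Below is a simplified pure-Python sketch of that rule. It assumes the Keras defaults `mode='min'` and `min_delta=1e-4` with no cooldown, and `PlateauSchedule` is a hypothetical name, not a Keras class.

```python
class PlateauSchedule:
    """Simplified sketch of the ReduceLROnPlateau rule ('min' mode, no cooldown)."""

    def __init__(self, lr, factor=0.5, patience=10, min_delta=1e-4, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_delta = min_delta
        self.min_lr = min_lr
        self.best = float("inf")
        self.wait = 0

    def on_epoch_end(self, val_loss):
        """Update the state after one epoch and return the learning rate for the next one."""
        if val_loss < self.best - self.min_delta:
            # Improvement: remember it and restart the patience counter.
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                # Plateau: shrink the current rate, whatever value it started from.
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.wait = 0
        return self.lr


schedule = PlateauSchedule(2e-4, factor=0.5, patience=2)
for val_loss in [0.30, 0.28, 0.29, 0.285, 0.27]:
    # The rate drops to 1e-4 after two epochs without improvement.
    print(schedule.on_epoch_end(val_loss))
```

Because the rule only rescales the optimizer's current value, one way to choose the starting rate of a loaded model is to recompile it with a new `Adam` optimizer, as the surrounding cells do.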

In [25]:
%load_ext tensorboard

In [26]:
train_model(model=model,
            model_file=config["model_file"],
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=n_train_steps,
            validation_steps=n_validation_steps,
            initial_learning_rate= 0.0002, # config["initial_learning_rate"],
            learning_rate_drop=config["learning_rate_drop"],
            learning_rate_patience=config["patience"],
            early_stopping_patience=config["early_stop"],
            n_epochs=config["n_epochs"],
            logging_file = config['logging_file'])

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 00008: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-05.
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
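The long decimals in the `ReduceLROnPlateau` messages appear consistent with a learning rate stored as a float32 variable, read back as a Python float and multiplied by 0.5. Under that assumption, the sketch below reproduces the logged values. It also suggests that the earlier run resumed with `initial_learning_rate=0.0002` was actually halved from 0.0005, the rate still held by the loaded optimizer.

```python
import numpy as np


def halve_float32_lr(lr, factor=0.5):
    """Apply one plateau drop to a learning rate held in a float32 variable (assumption)."""
    # Reading the float32 back as a Python float exposes its binary value
    # (0.0005 is stored as 0.0005000000237...); the product is taken in float64.
    return float(np.float32(lr)) * factor


for start in (0.0005, 0.0002, 0.0001):
    print(start, "->", halve_float32_lr(start))
```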


In [28]:
train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
    data_file_opened,
    batch_size= config["batch_size"],
    data_split= config["validation_split"],
    overwrite= False, #overwrite,
    validation_keys_file=config["validation_file"],
    training_keys_file=config["training_file"],
    n_labels=config["n_labels"],
    labels=config["labels"],
    patch_shape=config["patch_shape"],
    validation_batch_size=config["validation_batch_size"],
    validation_patch_overlap=config["validation_patch_overlap"],
    training_patch_start_offset=config["training_patch_start_offset"],
    permute=config["permute"],
    augment=config["augment"],
    skip_blank=config["skip_blank"],
    #augment_flip=config["flip"],
    augment_distortion_factor=config["distort"],
    pred_specific=config['pred_specific'],
    overlap_label=config['overlap_label_generator'],
    #for_final_val=config['for_final_val']
    )

model = load_old_model(config["model_file"])
model.compile(optimizer=Adam(learning_rate=0.0001), loss=weighted_dice_coefficient_loss, metrics=[dice_coefficient])

import datetime
from tensorflow.keras.callbacks import TensorBoard
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

def get_callbacks(model_file, initial_learning_rate=0.0001, learning_rate_drop=0.5, learning_rate_epochs=None,
                  learning_rate_patience=50, logging_file="training.log", verbosity=1,
                  early_stopping_patience=None):
    callbacks = list()
    callbacks.append(ModelCheckpoint(model_file, save_best_only=True))
    callbacks.append(CSVLogger(logging_file, append=True))
    if learning_rate_epochs:
        callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
                                                       drop=learning_rate_drop, epochs_drop=learning_rate_epochs)))
    else:
        callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
                                           verbose=verbosity))
    if early_stopping_patience:
        callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
    callbacks.append(tensorboard_callback)
    return callbacks

def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
                initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
                learning_rate_patience=20, early_stopping_patience=None, logging_file = 'training.log'):
    model.fit(x=training_generator,
                steps_per_epoch=steps_per_epoch,
                epochs=n_epochs,
                validation_data=validation_generator,
                validation_steps=validation_steps,
                callbacks=get_callbacks(model_file,
                                        initial_learning_rate=initial_learning_rate,
                                        learning_rate_drop=learning_rate_drop,
                                        learning_rate_epochs=learning_rate_epochs,
                                        learning_rate_patience=learning_rate_patience,
                                        logging_file = logging_file,
                                        early_stopping_patience=early_stopping_patience))



Loading previous validation split...
Number of training steps in each epoch:  1660
Number of validation steps in each epoch:  555
Loading pre-trained model


In [29]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [30]:
train_model(model=model,
            model_file=config["model_file"],
            training_generator=train_generator,
            validation_generator=validation_generator,
            steps_per_epoch=n_train_steps,
            validation_steps=n_validation_steps,
            initial_learning_rate= 0.0001, # config["initial_learning_rate"],
            learning_rate_drop=config["learning_rate_drop"],
            learning_rate_patience=config["patience"],
            early_stopping_patience=config["early_stop"],
            n_epochs=config["n_epochs"],
            logging_file = config['logging_file'])

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 00009: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 00014: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Epoch 15/25
Epoch 16/25
 200/1660 [==>...........................] - ETA: 31:51 - loss: -0.7459 - dice_coefficient: 0.7540

KeyboardInterrupt: 

In [31]:
model.save(filepath='C:/IAMEDIC/Jaeho_code/seg_model_1.31.h5',
           overwrite=True, include_optimizer=True, save_format=None,
           signatures=None, options=None)

In [48]:
##### unet3d/prediction.py #####

def run_validation_case(data_index, output_dir, model, data_file, training_modalities,
                        threshold=0.5, labels=None, overlap=16, 
                        permute=False, center_patch=True, overlap_label=True, 
                        final_val=False):
#     pdb.set_trace()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    affine = np.load('./affine.npy')
    test_data = np.array([modality_img[data_index,0] 
                      for modality_img in [data_file.root.t1,
                                           data_file.root.t1ce,
                                           data_file.root.flair,
                                           data_file.root.t2]])[np.newaxis]
    for i in range(test_data.shape[1]):
        if i == 0:
            brain_mask = np.copy(test_data[0,i])
            brain_mask[np.nonzero(brain_mask)] = True
        else:
            temp_mask = np.copy(test_data[0,i])
            temp_mask[np.nonzero(temp_mask)] = True
            brain_mask = np.logical_or(brain_mask,temp_mask)
    
    
    for i, modality in enumerate(training_modalities):
        image = nib.Nifti1Image(test_data[0, i], affine)
        image.to_filename(os.path.join(output_dir, "data_{0}.nii.gz".format(modality)))
    
    if not final_val:
        test_truth = nib.Nifti1Image(data_file.root.truth[data_index][0], affine)
        test_truth.to_filename(os.path.join(output_dir, "truth.nii.gz"))
    
    brain_width = data_file.root.brain_width[data_index]

    patch_shape = tuple([int(dim) for dim in model.input.shape[-3:]])
    if patch_shape == test_data.shape[-3:]:
        prediction = predict(model, test_data, permute=permute)
    else:
        prediction = patch_wise_prediction(model=model, data=test_data, brain_width=brain_width,
                                           overlap=overlap, permute=permute, center_patch=center_patch)[np.newaxis]
#     pdb.set_trace()
    prediction_image = prediction_to_image(prediction, affine, brain_mask,
                                           threshold=threshold, labels=labels,output_dir=output_dir,overlap_label=overlap_label)
    if isinstance(prediction_image, list):
        for i, image in enumerate(prediction_image):
            image.to_filename(os.path.join(output_dir, "prediction_{0}.nii.gz".format(i + 1)))
    else:
        prediction_image.to_filename(os.path.join(output_dir, data_file.root.subject_ids[data_index].decode()+'.nii.gz'))


def run_validation_cases(validation_keys_file, model_file, training_modalities, labels, hdf5_file,
                         output_dir=".", threshold=0.5, overlap=16, 
                         permute=False,center_patch=True, overlap_label=True, final_val = False):
    validation_indices = pickle_load(validation_keys_file)
    model = load_old_model(model_file)
    data_file = tables.open_file(hdf5_file, "r")
    
    for index in tqdm(validation_indices):
        if 'subject_ids' in data_file.root:
            case_directory = os.path.join(output_dir, data_file.root.subject_ids[index].decode('utf-8'))
        else:
            case_directory = os.path.join(output_dir, "validation_case_{}".format(index))
        run_validation_case(data_index=index, output_dir=case_directory, model=model, data_file=data_file,
                            training_modalities=training_modalities, labels=labels,
                            threshold=threshold, overlap=overlap, permute=permute, center_patch=center_patch,
                            overlap_label=overlap_label,
                            final_val=final_val)
    data_file.close()
#     pdb.set_trace()
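In `run_validation_case`, the loop over modalities builds `brain_mask` as the union of voxels that are non-zero in any modality. The same mask can be computed in one vectorised call. The sketch below uses hypothetical helper names and checks the vectorised version against a loop written the same way.

```python
import numpy as np


def compute_brain_mask(test_data):
    """Boolean mask of voxels that are non-zero in at least one modality.

    test_data has shape (1, n_modalities, x, y, z), as built in run_validation_case.
    """
    return np.any(test_data[0] != 0, axis=0)


def compute_brain_mask_loop(test_data):
    # Same union, accumulated one modality at a time like the original loop.
    brain_mask = test_data[0, 0] != 0
    for i in range(1, test_data.shape[1]):
        brain_mask = np.logical_or(brain_mask, test_data[0, i] != 0)
    return brain_mask
```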

In [None]:
##### demo_task1/data_for_val.py
# uc: unchanged
import os
import pdb
import numpy as np
import tables
import nibabel as nib
from tqdm import tqdm

import sys
sys.path.append('..')
#from dev_tools.my_tools import print_red
#from unet3d.normalize import normalize_data_storage_val
#from unet3d.data import cal_outline

def create_data_file(out_file, n_samples, image_shape, modality_names):
#     pdb.set_trace()
    hdf5_file = tables.open_file(out_file, mode='w')
    filters = tables.Filters(complevel=5, complib='blosc')
    modality_shape = tuple([0, 1] + list(image_shape))
    brain_width_shape = (0,2,3)
    
    
    modality_storage_list = [hdf5_file.create_earray(hdf5_file.root, modality_name, tables.Float32Atom(), shape=modality_shape,
                             filters=filters, expectedrows=n_samples) for modality_name in modality_names]
    
    brain_width_storage = hdf5_file.create_earray(hdf5_file.root, 'brain_width', tables.UInt8Atom(), shape=brain_width_shape,
                                            filters=filters, expectedrows=n_samples)
    
    return hdf5_file, modality_storage_list, brain_width_storage



def write_image_data_to_file(image_files, data_storage,brain_width_storage, 
                             image_shape, modality_names, trivial_check = True):
    '''
    trivial_check: if True, check that every image has the same affine as affine.npy and that
                   each subject's files follow the order of modality_names; the names of
                   non-compliant files are printed in red.
    '''
#     pdb.set_trace()
    affine_0 = np.load('affine.npy')
    
#     temp = 0
    print('write_image_data_to_file...')
    for set_of_files in tqdm(image_files):
        if trivial_check:
            if not [os.path.basename(img_file).split('.')[0] for img_file in set_of_files] == modality_names:
                print('wrong order of modalities')
                print_red(set_of_files)
        subject_data = []
        brain_widths = []
        for i,image_nii_path in enumerate(set_of_files):
            img = nib.load(image_nii_path)
            affine = img.affine
            if trivial_check:
                if not np.allclose(affine_0, affine):
                    print('affine incompliance:')
                    print_red(image_nii_path)
            img_npy = np.asanyarray(img.dataobj)  # raw array with the on-disk dtype, like the deprecated get_data()
            subject_data.append(img_npy)
            
            brain_widths.append(cal_outline(img_npy))
                
        start_edge = np.min(brain_widths,axis=0)[0]
        end_edge = np.max(brain_widths,axis=0)[1]
        brain_width = np.vstack((start_edge,end_edge))
        
        if add_data_to_storage(data_storage, brain_width_storage, 
                               subject_data, brain_width, modality_names = modality_names):
            print_red('modality_storage.name != modality_name')
            print_red(set_of_files)
    print('write_image_data_to_file...FINISHED')
    return data_storage


def add_data_to_storage(data_storage, brain_width_storage, 
                        subject_data, brain_width, modality_names):
#     pdb.set_trace()
    for i in range(len(modality_names)):
        if data_storage[i].name != modality_names[i]:
            print_red('modality_storage.name != modality_name')
            return 1  # non-zero so that the caller reports the mismatch
        data_storage[i].append(np.asarray(subject_data[i])[np.newaxis][np.newaxis])
    
    brain_width_storage.append(np.asarray(brain_width, dtype=np.uint8)[np.newaxis])
    return 0

def write_data_to_file(training_data_files, out_file, image_shape, modality_names, subject_ids=None,
                       normalize=True, mean_std_file='../data/mean_std.pkl'):

#     pdb.set_trace()
    n_samples = len(training_data_files)

    hdf5_file, data_storage, brain_width_storage = create_data_file(out_file,
                                                                      n_samples=n_samples,
                                                                      image_shape=image_shape,
                                                                      modality_names = modality_names)

    write_image_data_to_file(training_data_files, 
                                data_storage, brain_width_storage, 
                                image_shape, modality_names = modality_names)
    if subject_ids:
        hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
    if normalize:
        normalize_data_storage_val(data_storage, save_file = mean_std_file)
    hdf5_file.close()
    return out_file
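`write_image_data_to_file` relies on `cal_outline`, whose import is commented out above, and stores one `(2, 3)` `brain_width` per subject as `uint8`. The sketch below assumes `cal_outline` returns the inclusive start and end indices of the non-zero bounding box. Under that assumption, it and the per-subject merge could look like this; `bounding_box` and `merge_brain_widths` are hypothetical names.

```python
import numpy as np


def bounding_box(volume):
    """Hypothetical stand-in for cal_outline: first and last non-zero index per axis.

    Returns a (2, 3) array; row 0 holds the start indices, row 1 the (inclusive) end
    indices. An all-zero volume is not handled.
    """
    coords = np.argwhere(volume != 0)
    return np.vstack((coords.min(axis=0), coords.max(axis=0)))


def merge_brain_widths(brain_widths):
    # Same reduction as write_image_data_to_file: earliest start and latest end over modalities.
    brain_widths = np.asarray(brain_widths)
    return np.vstack((brain_widths.min(axis=0)[0], brain_widths.max(axis=0)[1]))
```

Because the storage atom is `uint8`, every stored index must stay below 256 along each axis.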

In [None]:
##### demo_task1/run_validation.py #####
import os
#from train_model import config
import pdb
import glob
#from data_for_val import write_data_to_file
# from unet3d.prediction import run_validation_cases
import pickle
#from dev_tools.my_tools import my_mkdir, my_makedirs
from tqdm import tqdm
import shutil

def fetch_val_data_files(return_subject_ids=True):
#     pdb.set_trace()
    val_data_files = list()
    subject_ids = list()
    for subject_dir in glob.glob(os.path.join("C:/IAMEDIC/Jaeho_code/data", "preprocessed_val_data", "*", "*")):
        subject_ids.append(os.path.basename(subject_dir))
        subject_files = list()
        for modality in config['all_modalities']:
            subject_files.append(os.path.join(subject_dir, modality + ".nii.gz"))
        val_data_files.append(tuple(subject_files))
    if return_subject_ids:
        return val_data_files, subject_ids
    else:
        return val_data_files


def gen_val_h5():
    if os.path.exists(config['val_data_file']):
        print(config['val_data_file'],'exists already!')
        return

    val_files, subject_ids = fetch_val_data_files()

    write_data_to_file(val_files, 
                        config['val_data_file'], 
                        image_shape=config["image_shape"], 
                        modality_names = config['all_modalities'],
                        subject_ids=subject_ids,
                       mean_std_file = config['mean_std_file'])
    return
    
def mv_results(source_dir,target_dir):
#     print('moving for upload...')
    os.makedirs(target_dir, exist_ok=True)
    for sub_id in tqdm(os.listdir(source_dir)):
        source_name = os.path.join(source_dir,sub_id,sub_id+'.nii.gz')
        target_name = os.path.join(target_dir,sub_id+'.nii.gz')
        if not os.path.exists(target_name):
            shutil.move(source_name,target_name)
    
def main_run():
    config['num_val_subjects'] = len(os.listdir('C:/IAMEDIC/Jaeho_code/data/preprocessed_val_data/val'))
    
    gen_val_h5()
    
    if not os.path.exists(config['val_index_list']):
        with open(config['val_index_list'],'wb') as f:
            pickle.dump(list(range(config['num_val_subjects'])),f)
    print('Validation dataset prediction starts...')        
    run_validation_cases(validation_keys_file=config['val_index_list'],
                         model_file=config["model_file"],
                         training_modalities=config["training_modalities"],
                         labels=config["labels"],
                         hdf5_file=config["val_data_file"],
                         output_dir=config['val_predict_dir'],
                         center_patch=config['center_patch'],
                         overlap_label=config['overlap_label_predict'],
                         final_val = True)
    mv_results(config['val_predict_dir'],config['val_to_upload'])
    print('Validation dataset prediction finished.')
    return

# def predict_training_dataset():
#     if not os.path.exists(config['training_index_list']):
#         with open(config['training_index_list'],'wb') as f:
#             pickle.dump(list(range(config['num_training_subjects'])),f)
#     print('Training dataset prediction starts...')        
#     run_validation_cases(validation_keys_file=config['training_index_list'],
#                          model_file=config["model_file"],
#                          training_modalities=config["training_modalities"],
#                          labels=config["labels"],
#                          hdf5_file=config["data_file"],
#                          output_dir=config['training_predict_dir'],
#                          center_patch=config['center_patch'],
#                          overlap_label=config['overlap_label_predict'],
#                          final_val = True)
#     mv_results(config['training_predict_dir'],config['training_to_upload'])
#     print('Training dataset prediction finished.')
#     return
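`fetch_val_data_files` expects one directory per subject two levels below `preprocessed_val_data`, each holding `<modality>.nii.gz` files. `mv_results` then flattens `<val_predict_dir>/<id>/<id>.nii.gz` into a single upload folder. The sketch below covers both steps with hypothetical helper names. Unlike the original, it sorts the glob result so the subject order (and hence the HDF5 row order) is reproducible, and it uses `os.makedirs(..., exist_ok=True)` for the target folder.

```python
import glob
import os
import shutil


def fetch_case_files(root, modalities):
    """Collect one tuple of modality paths per subject directory found at root/*/*."""
    subject_dirs = sorted(glob.glob(os.path.join(root, "*", "*")))
    subject_ids = [os.path.basename(d) for d in subject_dirs]
    files = [tuple(os.path.join(d, m + ".nii.gz") for m in modalities) for d in subject_dirs]
    return files, subject_ids


def flatten_predictions(source_dir, target_dir):
    """Move <source>/<id>/<id>.nii.gz to <target>/<id>.nii.gz, skipping existing targets."""
    os.makedirs(target_dir, exist_ok=True)
    for sub_id in os.listdir(source_dir):
        source_name = os.path.join(source_dir, sub_id, sub_id + ".nii.gz")
        target_name = os.path.join(target_dir, sub_id + ".nii.gz")
        if os.path.exists(source_name) and not os.path.exists(target_name):
            shutil.move(source_name, target_name)
```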

In [None]:
main_run()

In [19]:
model = load_old_model(config["model_file"])



Loading pre-trained model


None
