In [None]:
#-*- coding: utf-8 -*-
import numpy as np
import time
import os
from os.path import join, isdir, isfile
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
import matplotlib.pyplot as plt
import tensorflow as tf
import nibabel as nib
import random
import scipy.io
from skimage.util import view_as_blocks
from itertools import compress

In [None]:
### Training steps
# Bounds for the cross-validation folds and for the epochs run per fold, written as
# (start, max) pairs — presumably half-open ranges [START, MAX).
# NOTE(review): confirm the training cell actually reads these instead of literals;
# EPOCH_START in particular is not used by any training loop shown in this notebook.
CROSS_START, CROSS_MAX = 0, 10
EPOCH_START, EPOCH_MAX = 0, 150

In [None]:
### Network parameters
# --- data / patch geometry ---
patch_size = 128     # side length of the square 2-D patches fed to the network
patch_stride = 128   # NOTE(review): not read by any cell shown; patches come from view_as_blocks, i.e. non-overlapping
slice_size = 512     # NOTE(review): not read by any cell shown here
INPUT_CH = 3         # channel count of the batch_data placeholder
NUM_LABEL = 2        # NOTE(review): the network/one-hot code hard-codes 2 classes rather than reading this

# --- optimisation ---
NUM_BATCH = 8        # mini-batch size used when slicing the stacked training/validation arrays
TR_LR = 0.5          # initial learning rate fed to the Adadelta optimiser

In [None]:
def set_dir(_dir):
    """Make sure a directory exists at ``_dir`` and hand the path back.

    Parent directories are created as needed. When something already exists at
    ``_dir`` nothing is created. The path is returned unchanged so calls can be
    chained, e.g. ``log_dir = set_dir(cross_dir + 'log/')``.
    """
    if os.path.exists(_dir):
        return _dir
    os.makedirs(_dir)
    return _dir

In [None]:
def _volume_to_patches(volume):
    """Tile every axial slice of a 3-D volume into non-overlapping square patches.

    Returns an array of shape (patch_size, patch_size, n_patches). Patches are
    ordered by (row block, column block, slice), the C order of the block grid
    produced by view_as_blocks.
    """
    blocks = view_as_blocks(volume, (patch_size, patch_size, 1))
    return np.moveaxis(blocks.squeeze().reshape(-1, patch_size, patch_size), 0, -1)


def extract_nifti_array_from_subject(_dict):
    """Load one subject's volumes and return a class-balanced, shuffled set of patches.

    Args:
        _dict: mapping with keys 'FLAIR', 'T1', 'GT' and 'SEG'; each value is passed
            to ``nib.load`` (presumably a NIfTI file path — confirm against caller).

    Returns:
        Tuple (FLAIR, T1, GT, SEG), each of shape (patch_size, patch_size, K).
        FLAIR and T1 are z-scored over the whole volume; GT and SEG are int32 and
        flipped along axis 0. The K patches are every patch whose mean SEG value
        is > 0, plus an equal number of randomly drawn patches whose mean SEG value
        is <= 0. If there are fewer such patches, all of them are used. The
        patches are in random order, and the same order is used for all four outputs.

    Raises:
        ValueError: from view_as_blocks, if a volume is not 3-D or its first two
            dimensions are not multiples of patch_size.
    """
    flair = nib.load(_dict['FLAIR']).get_data()
    t1 = nib.load(_dict['T1']).get_data()
    # NOTE(review): only the label volumes are flipped along axis 0, the intensity
    # images are not. Presumably GT/SEG are stored mirrored relative to FLAIR/T1 — confirm.
    gt = np.flip(nib.load(_dict['GT']).get_data().astype(np.int32), axis=0)
    seg = np.flip(nib.load(_dict['SEG']).get_data().astype(np.int32), axis=0)

    # Per-volume z-score normalisation of the intensity images.
    flair = (flair - flair.mean()) / flair.std()
    t1 = (t1 - t1.mean()) / t1.std()

    flair, t1, gt, seg = [_volume_to_patches(v) for v in (flair, t1, gt, seg)]

    # Split patch indices by whether the SEG patch has a positive mean.
    nz_mask = np.mean(seg, axis=(0, 1)) > 0
    nz_index = np.flatnonzero(nz_mask).tolist()
    z_index = np.flatnonzero(~nz_mask).tolist()

    # Balance the two groups by sub-sampling the "empty" patches. The count is capped
    # at what is available, because random.sample raises ValueError when asked for
    # more items than the population holds.
    z_index = random.sample(z_index, min(len(nz_index), len(z_index)))

    sel_index = nz_index + z_index
    random.shuffle(sel_index)

    return (flair[:, :, sel_index],
            t1[:, :, sel_index],
            gt[:, :, sel_index],
            seg[:, :, sel_index])

In [None]:
def model_data(batch_data, training=False, num_label=2):
    """Build a three-level 2-D U-Net and return per-pixel class logits.

    Every level applies two units of (3x3 same-padded conv + ReLU -> batch norm).
    The encoder downsamples with 2x2/stride-2 max-pooling. The decoder upsamples
    with 2x2/stride-2 transposed convolutions and concatenates the encoder
    features of the same resolution (skip connections).

    Args:
        batch_data: 4-D NHWC float tensor [N, H, W, C]. H and W must be divisible
            by 4; otherwise the upsampled maps do not line up with the skip
            connections and tf.concat fails.
        training: forwarded to every tf.layers.batch_normalization. The default
            False matches the previous behaviour: BN normalises with its moving
            statistics. With True, batch statistics are used and the moving-average
            updates go into tf.GraphKeys.UPDATE_OPS, which the train op must depend on.
        num_label: number of output channels (classes). Default 2.

    Returns:
        Logits tensor [N, H, W, num_label] with no activation applied.

    Intermediate shapes:
        down_1 [N, H,   W,   128]
        down_2 [N, H/2, W/2, 256]
        down_3 [N, H/4, W/4, 512]
        up_1   [N, H/2, W/2, 256]
        up_2   [N, H,   W,   128]
    """

    def conv_bn(inputs, filters):
        # conv2d is created before batch_normalization, exactly as in the inline
        # version, so the auto-generated layer/variable names are unchanged.
        conv = tf.layers.conv2d(inputs=inputs, filters=filters, kernel_size=3,
                                activation=tf.nn.relu, use_bias=True, padding='same')
        return tf.layers.batch_normalization(conv, training=training)

    with tf.variable_scope('down_1'):
        down_1 = conv_bn(batch_data, 128)
        down_1 = conv_bn(down_1, 128)

    with tf.variable_scope('down_2'):
        down_2 = tf.layers.max_pooling2d(inputs=down_1, pool_size=2, strides=2)
        down_2 = conv_bn(down_2, 256)
        down_2 = conv_bn(down_2, 256)

    with tf.variable_scope('down_3'):
        down_3 = tf.layers.max_pooling2d(inputs=down_2, pool_size=2, strides=2)
        down_3 = conv_bn(down_3, 512)
        down_3 = conv_bn(down_3, 512)

    with tf.variable_scope('up_1'):
        up_1 = tf.layers.conv2d_transpose(inputs=down_3, filters=256, kernel_size=2,
                                          strides=2, activation=tf.nn.relu)
        up_1 = tf.concat([down_2, up_1], axis=-1)
        up_1 = conv_bn(up_1, 256)
        up_1 = conv_bn(up_1, 256)

    with tf.variable_scope('up_2'):
        up_2 = tf.layers.conv2d_transpose(inputs=up_1, filters=128, kernel_size=2,
                                          strides=2, activation=tf.nn.relu)
        up_2 = tf.concat([down_1, up_2], axis=-1)
        up_2 = conv_bn(up_2, 128)
        up_2 = conv_bn(up_2, 128)

    with tf.variable_scope('model_data_output'):
        # Linear output: these are logits for softmax cross-entropy. A ReLU here
        # (previously used) clamps every logit at >= 0, and units stuck at zero
        # receive no gradient.
        batch_output = tf.layers.conv2d(inputs=up_2, filters=num_label, kernel_size=1,
                                        activation=None, use_bias=False, padding='same')

    return batch_output

In [None]:
def model_label(batch_label, depth=2):
    """Turn an integer label map into one-hot float targets.

    Args:
        batch_label: integer tensor whose last axis has size 1, e.g. [N, H, W, 1].
            That axis is removed after one-hot encoding.
        depth: number of classes. Default 2, which matches the previous hard-coded value.

    Returns:
        float32 tensor of shape batch_label.shape[:-1] + [depth]. It holds 1.0 at
        each pixel's label index and 0.0 elsewhere.
    """
    with tf.variable_scope('model_label_output'):
        one_hot = tf.one_hot(indices=batch_label, depth=depth,
                             on_value=1.0, off_value=0.0, axis=-1)
        # one_hot appends the class axis after the trailing size-1 axis; drop it.
        label_output = tf.squeeze(one_hot, [-2])

    return label_output

In [None]:
# Data cropping for visualizing

def model_data_crop(batch_data, crop_start=None, crop_range=None):
    """Cut the same square window out of every image in an NHWC batch.

    Args:
        batch_data: 4-D tensor [N, H, W, C].
        crop_start: first row and column index of the window (same for both axes).
            Defaults to the module-level ``batch_crop_label_start``.
        crop_range: side length of the window. Defaults to the module-level
            ``batch_crop_label_range``.

    Returns:
        Tensor [N, crop_range, crop_range, C]. All batch items and channels are kept.

    NOTE(review): neither module-level default is defined by any cell shown in this
    notebook, so calling without explicit arguments raises NameError on a fresh kernel.
    """
    # Globals are looked up only when needed, so existing callers that rely on them
    # behave exactly as before.
    if crop_start is None:
        crop_start = batch_crop_label_start
    if crop_range is None:
        crop_range = batch_crop_label_range

    batch_crop_data = tf.slice(
        batch_data,
        [0, crop_start, crop_start, 0],
        [-1, crop_range, crop_range, -1])

    return batch_crop_data

In [None]:
# Graph inputs, fed through feed_dict on every sess.run:
#   batch_data  - float32 patches, [batch, patch_size, patch_size, INPUT_CH]
#   batch_label - int32 label map,  [batch, patch_size, patch_size, 1]
#   batch_rate  - float32 learning rate for the optimiser (shape left unconstrained)
batch_data = tf.placeholder(dtype=tf.float32,
                            shape=[None, patch_size, patch_size, INPUT_CH])
batch_label = tf.placeholder(dtype=tf.int32,
                             shape=[None, patch_size, patch_size, 1])
batch_rate = tf.placeholder(dtype=tf.float32)

In [None]:
# Per-pixel softmax cross-entropy between network logits and one-hot labels.
# NOTE(review): the name scope says 'dice_loss', but no Dice term is computed here.
# The scope string is kept unchanged so graph op names do not move.
with tf.name_scope('dice_loss'), tf.device('/gpu:0'):
    logits = model_data(batch_data)
    labels = model_label(batch_label)
    batch_loss = tf.losses.softmax_cross_entropy(onehot_labels=labels,
                                                 logits=logits)

In [None]:
# Optimiser step. Adadelta's learning rate comes from the batch_rate placeholder,
# so it can be changed between steps without rebuilding the graph.
# tf.layers.batch_normalization in training mode registers its moving mean/variance
# updates in tf.GraphKeys.UPDATE_OPS. The train op must depend on them, or those
# statistics never change. When BN runs in inference mode (the tf.layers default)
# the collection is empty and this dependency is a no-op.
with tf.name_scope('train'), tf.device('/gpu:0'):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdadeltaOptimizer(batch_rate).minimize(batch_loss)

In [None]:
saver_model = tf.train.Saver(max_to_keep=None)

config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
config.gpu_options.allow_growth = True

# Build the initializer op once. Calling tf.global_variables_initializer() inside
# the fold loop would add a new op to the graph on every fold.
init_op = tf.global_variables_initializer()

# NOTE(review): result_dir, seg_dir, load_list and stack_data are not defined in any
# cell of this notebook. They must be provided elsewhere, or Restart & Run All fails here.
with tf.Session(config=config) as sess:

    # One independently initialised model per cross-validation fold.
    for c_cross in range(CROSS_START, CROSS_MAX):

        cross_dir = set_dir(result_dir + str(c_cross) + '/')
        log_dir = set_dir(cross_dir + 'log/')

        train_list = load_list(list_path=seg_dir + str(c_cross) + '/train_list.txt')
        valid_list = load_list(list_path=seg_dir + str(c_cross) + '/valid_list.txt')

        train_writer = tf.summary.FileWriter(log_dir + 'train', sess.graph)
        valid_writer = tf.summary.FileWriter(log_dir + 'valid')

        try:
            # Reset all variables (weights and optimiser slots) for this fold.
            sess.run(init_op)

            # Stack the fold's subjects into in-memory arrays.
            tr_total_data, tr_total_label = stack_data(train_list, channel=INPUT_CH, flag=0)
            vl_total_data, vl_total_label = stack_data(valid_list, channel=INPUT_CH, flag=0)
            n_valid = vl_total_data.shape[0]

            rate_current = TR_LR

            for c_epoch in range(EPOCH_MAX):

                start_time = time.time()

                # Train on full mini-batches only; a trailing partial batch is skipped.
                train_total_losses = []
                tr_batch_step_total = tr_total_data.shape[0] // NUM_BATCH
                print('%s epochs, training... %s' % (c_epoch, rate_current))
                for c_step in range(tr_batch_step_total):
                    rows = slice(c_step * NUM_BATCH, (c_step + 1) * NUM_BATCH)
                    tr_batch_loss, _ = sess.run(
                        [batch_loss, optimizer],
                        feed_dict={batch_data: tr_total_data[rows],
                                   batch_label: tr_total_label[rows],
                                   batch_rate: rate_current})
                    train_total_losses.append(tr_batch_loss)

                # Validate on every sample, including a trailing partial batch. Each
                # batch loss is weighted by its size, so the epoch value is a per-sample
                # mean (assuming batch_loss is a mean over the batch, as tf.losses'
                # default reduction gives).
                print('%s epochs, validating...' % c_epoch)
                valid_loss_sum = 0.0
                for start in range(0, n_valid, NUM_BATCH):
                    stop = min(start + NUM_BATCH, n_valid)
                    # Fetch the tensor itself (not a one-element list) to get a scalar.
                    vl_batch_loss = sess.run(
                        batch_loss,
                        feed_dict={batch_data: vl_total_data[start:stop],
                                   batch_label: vl_total_label[start:stop]})
                    valid_loss_sum += float(vl_batch_loss) * (stop - start)

                end_time = time.time()

                train_total_loss = np.mean(train_total_losses)
                valid_total_loss = valid_loss_sum / n_valid if n_valid else float('nan')

                train_writer.add_summary(
                    tf.Summary(value=[tf.Summary.Value(tag='train loss', simple_value=train_total_loss)]),
                    c_epoch)
                valid_writer.add_summary(
                    tf.Summary(value=[tf.Summary.Value(tag='valid loss', simple_value=valid_total_loss)]),
                    c_epoch)

                print('%s epoch train_total_loss: %s epoch valid_total_loss: %s %s seconds %s'
                      % (c_epoch, train_total_loss, valid_total_loss,
                         end_time - start_time, c_cross))

                # Every 25th epoch: halve the learning rate and keep a checkpoint.
                if c_epoch % 25 == 24:
                    rate_current = 0.5 * rate_current
                    saver_model.save(sess, cross_dir + 'saver_model.ckpt', global_step=c_epoch)
        finally:
            # Flush pending events and release this fold's event files.
            train_writer.close()
            valid_writer.close()