In [1]:
from __future__ import division
import os
import numpy as np
import pprint
import tensorflow as tf
import tensorflow.contrib.slim as slim
import pickle, csv
import sys
import time
import datetime
import pandas as pd
#import nilearn
import nibabel as nib
from scipy.ndimage.morphology import binary_dilation
sys.path.append('..')
from utils import *
from model import UNet3D
#
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Project root directory. Every data/output path in this notebook is built by
# appending a relative path to it (e.g. proj_dir + "data/input/"), so it must
# end with a path separator; os.path.join(..., '') guarantees that.
# Set the SKULL_SEG_PROJ_DIR environment variable to run outside the original cluster.
proj_dir = os.path.join(os.environ.get('SKULL_SEG_PROJ_DIR', '/data/ipl/scratch03/nikhil/skull_seg/'), '')

# Model training

In [3]:
# FLAGS: run configuration via TF1 `tf.app.flags`.
# The bracketed value at the end of each help string is the default defined here.
flags = tf.app.flags
FLAGS = flags.FLAGS


def _define_flag(define_fn, name, default, help_text):
    """Define a flag, first removing any definition left by an earlier run of this cell.

    Defining the same flag twice from the notebook's __main__ module raises a
    DuplicateFlagError, so without this guard the cell fails when re-executed
    in the same kernel.

    Args:
        define_fn: one of the flags.DEFINE_* functions.
        name: flag name.
        default: default value.
        help_text: help string shown for the flag.
    """
    if name in FLAGS:
        delattr(FLAGS, name)
    define_fn(name, default, help_text)


_define_flag(flags.DEFINE_integer, "epoch", 1, "Epoch to train [1]")
_define_flag(flags.DEFINE_string, "train_patch_dir", proj_dir + "data/patches/p64/", "Directory of the training data [data/patches/p64/]")
_define_flag(flags.DEFINE_bool, "split_train", False, "Whether to split the train data into train and val [False]")
_define_flag(flags.DEFINE_string, "train_data_dir", proj_dir + "data/input/", "Directory of the train data [data/input/]")
_define_flag(flags.DEFINE_string, "deploy_data_dir", proj_dir + "data/input/", "Directory of the test data [data/input/]")
_define_flag(flags.DEFINE_string, "deploy_output_dir", proj_dir + "data/predictions/", "Directory name of the output data [data/predictions/]")
_define_flag(flags.DEFINE_integer, "batch_size", 10, "Batch size [10]")
_define_flag(flags.DEFINE_integer, "seg_features_root", 32, "Number of features in the first filter in the seg net [32]")
_define_flag(flags.DEFINE_integer, "conv_size", 3, "Convolution kernel size in encoding and decoding paths [3]")
_define_flag(flags.DEFINE_integer, "layers", 1, "Encoding and decoding layers [1]")
_define_flag(flags.DEFINE_string, "loss_type", "dice", "Loss type in the model [dice]")
_define_flag(flags.DEFINE_float, "dropout", 0.5, "Drop out ratio [0.5]")
_define_flag(flags.DEFINE_string, "checkpoint_dir", "checkpoint2", "Directory name to save the checkpoints [checkpoint2]")
_define_flag(flags.DEFINE_string, "log_dir", "logs2", "Directory name to save logs [logs2]")
_define_flag(flags.DEFINE_boolean, "train", True, "True for training, False for deploying [True]")
_define_flag(flags.DEFINE_boolean, "run_seg", True, "True if run segmentation [True]")
_define_flag(flags.DEFINE_boolean, "run_survival", False, "True if run survival prediction [False]")

# Needed for jupyter to work: presumably because the kernel process is started
# with `-f <connection file>`, which flag parsing would otherwise reject.
_define_flag(flags.DEFINE_string, 'f', '', 'kernel')

In [4]:
# Collect the patch files used for training and for the periodic test
# evaluation, and make sure the checkpoint/log directories exist.

def _list_patch_paths(patch_dir):
    """Return the sorted full paths of the entries in ``patch_dir``.

    Entries whose name contains '.DS' (macOS .DS_Store files) are skipped.
    Sorting makes the path order independent of the filesystem's directory
    listing order, so runs see the files in a reproducible order.
    """
    return sorted(os.path.join(patch_dir, name)
                  for name in os.listdir(patch_dir)
                  if '.DS' not in name)


test_patch_dir = proj_dir + "data/patches/p64/"

training_paths = _list_patch_paths(FLAGS.train_patch_dir)
testing_paths = _list_patch_paths(test_patch_dir)

# NOTE(review): with the current settings the test patches are the training
# patches, so the reported "test" metrics are not a held-out estimate.
# TODO point test_patch_dir at a separate validation split.
if os.path.normpath(test_patch_dir) == os.path.normpath(FLAGS.train_patch_dir):
    print('WARNING: test_patch_dir is the same as train_patch_dir ({}); '
          'test metrics are computed on training patches.'.format(test_patch_dir))

training_ids = [os.path.basename(path) for path in training_paths]

# Survival-prediction placeholders; not passed to the segmentation model below.
training_survival_paths = []
testing_survival_paths = None
training_survival_data = {}
testing_survival_data = None

for required_dir in (FLAGS.checkpoint_dir, FLAGS.log_dir):
    if not os.path.exists(required_dir):
        os.makedirs(required_dir)

In [5]:
# Train the 3D U-Net on the training patches, record the final train/test
# metrics in `trained_model_stats`, then optionally run the trained model on
# the input volumes with `unet.deploy`.

def _last_or_nan(values):
    """Return the last recorded value of a metric history, or NaN if it is empty."""
    return values[-1] if len(values) > 0 else np.nan


# training bookkeeping
stats_columns = ['run_index', 'net_structure', 'training_params',
                 'stage', 'acc', 'dice', 'loss']
trained_model_stats = pd.DataFrame(columns=stats_columns)
run_index = 'UNET_1'
input_size = 64  # patch edge length recorded in net_structure; keep in sync with train_patch_dir (p64)
net_structure = '{}_{}_{}'.format(input_size, FLAGS.layers, FLAGS.seg_features_root)
training_params = '{}_{}_{}'.format(FLAGS.epoch, FLAGS.loss_type, FLAGS.batch_size)

train_network = True
test_network = True

start_time = datetime.datetime.now()
print(start_time)

if FLAGS.run_seg:
    run_config = tf.ConfigProto()
    try:
        with tf.Session(config=run_config) as sess:
            unet = UNet3D(sess, checkpoint_dir=FLAGS.checkpoint_dir, log_dir=FLAGS.log_dir, training_paths=training_paths,
                          testing_paths=testing_paths, batch_size=FLAGS.batch_size, layers=FLAGS.layers,
                          features_root=FLAGS.seg_features_root, conv_size=FLAGS.conv_size,
                          dropout=FLAGS.dropout, loss_type=FLAGS.loss_type)

            if train_network:
                model_vars = tf.trainable_variables()
                slim.model_analyzer.analyze_vars(model_vars, print_info=True)

                train_config = {'epoch': FLAGS.epoch}
                train_metrics, test_metrics = unet.train(train_config)

                # Save the performance at the end of training: one row per stage
                # holding the last recorded value of each metric (NaN if the
                # history is empty, e.g. no test evaluation happened).
                stats_rows = [
                    [run_index, net_structure, training_params, stage,
                     _last_or_nan(metrics['acc']),
                     _last_or_nan(metrics['dice']),
                     _last_or_nan(metrics['loss'])]
                    for stage, metrics in (('train', train_metrics), ('test', test_metrics))
                ]
                trained_model_stats = pd.DataFrame(stats_rows, columns=stats_columns)
                print('Training complete')

            if test_network:
                print('\nTesting trained model with actual MR image')
                if not os.path.exists(FLAGS.deploy_output_dir):
                    os.makedirs(FLAGS.deploy_output_dir)

                unet.deploy(FLAGS.deploy_data_dir, FLAGS.deploy_output_dir)
                print('Test complete. Output at {}'.format(FLAGS.deploy_output_dir))
    finally:
        # Always discard the graph, even when training is interrupted or fails,
        # so that re-running this cell starts from a clean default graph instead
        # of building on top of the one left by the previous attempt.
        tf.reset_default_graph()

end_time = datetime.datetime.now()
print('start: {}, end: {}'.format(start_time, end_time))

2018-11-16 20:30:40.511357
encoding...
decoding...
---------
Variables: name (type shape) [size]
---------
encoding0/w1:0 (float32_ref 3x3x3x1x32) [864, bytes: 3456]
encoding0/b1:0 (float32_ref 32) [32, bytes: 128]
encoding0/batch_normalization/beta:0 (float32_ref 32) [32, bytes: 128]
encoding0/batch_normalization/gamma:0 (float32_ref 32) [32, bytes: 128]
encoding0/w2:0 (float32_ref 3x3x3x32x32) [27648, bytes: 110592]
encoding0/b2:0 (float32_ref 32) [32, bytes: 128]
encoding0/batch_normalization_1/beta:0 (float32_ref 32) [32, bytes: 128]
encoding0/batch_normalization_1/gamma:0 (float32_ref 32) [32, bytes: 128]
bottom/w1:0 (float32_ref 3x3x3x32x64) [55296, bytes: 221184]
bottom/b1:0 (float32_ref 64) [64, bytes: 256]
bottom/batch_normalization/beta:0 (float32_ref 64) [64, bytes: 256]
bottom/batch_normalization/gamma:0 (float32_ref 64) [64, bytes: 256]
bottom/w2:0 (float32_ref 3x3x3x64x64) [110592, bytes: 442368]
bottom/b2:0 (float32_ref 64) [64, bytes: 256]
bottom/batch_normalization_1/b

train order: 211
train order: 212
train order: 213
train order: 214
train order: 215
train order: 216
train order: 217
train order: 218
train order: 219
220:train_loss: 0.53584 test_loss: 0.466057
220:train_acc: 0.958495 test_acc: 0.957271
220:train_dice: 0.626059 test_dice: 0.69933
train order: 220
train order: 221
train order: 222
train order: 223
train order: 224
train order: 225
train order: 226
train order: 227
train order: 228
train order: 229
230:train_loss: 0.456745 test_loss: 0.520775
230:train_acc: 0.964516 test_acc: 0.960309
230:train_dice: 0.687575 test_dice: 0.630264
train order: 230
train order: 231
train order: 232
train order: 233
train order: 234
train order: 235
train order: 236
train order: 237
train order: 238
train order: 239
240:train_loss: 0.526273 test_loss: 0.533973
240:train_acc: 0.963484 test_acc: 0.970473
240:train_dice: 0.619352 test_dice: 0.592829
train order: 240
train order: 241
train order: 242
train order: 243
train order: 244
train order: 245
train or

train order: 483
train order: 484
train order: 485
train order: 486
train order: 487
train order: 488
train order: 489
490:train_loss: 0.495243 test_loss: 0.461144
490:train_acc: 0.960472 test_acc: 0.96149
490:train_dice: 0.657132 test_dice: 0.680098
train order: 490
train order: 491
train order: 492
train order: 493
train order: 494
train order: 495
train order: 496
train order: 497
train order: 498
train order: 499
500:train_loss: 0.463932 test_loss: 0.58554
500:train_acc: 0.965345 test_acc: 0.954832
500:train_dice: 0.669098 test_dice: 0.563455
train order: 500
train order: 501
train order: 502
train order: 503
train order: 504
train order: 505
train order: 506
train order: 507
train order: 508
train order: 509
510:train_loss: 0.499747 test_loss: 0.481058
510:train_acc: 0.963269 test_acc: 0.962867
510:train_dice: 0.639783 test_dice: 0.662345
train order: 510
train order: 511
train order: 512
train order: 513
train order: 514
train order: 515
train order: 516
train order: 517
train or

train order: 755
train order: 756
train order: 757
train order: 758
train order: 759
760:train_loss: 0.475015 test_loss: 0.442623
760:train_acc: 0.954376 test_acc: 0.961946
760:train_dice: 0.687102 test_dice: 0.700202
train order: 760
train order: 761
train order: 762
train order: 763
train order: 764
train order: 765
train order: 766
train order: 767
train order: 768
train order: 769
770:train_loss: 0.466939 test_loss: 0.489929
770:train_acc: 0.964661 test_acc: 0.960729
770:train_dice: 0.666296 test_dice: 0.648981
train order: 770
train order: 771
train order: 772
train order: 773
train order: 774
train order: 775
train order: 776
train order: 777
train order: 778
train order: 779
780:train_loss: 0.498709 test_loss: 0.534694
780:train_acc: 0.955132 test_acc: 0.949751
780:train_dice: 0.652199 test_dice: 0.630957
train order: 780
train order: 781
train order: 782
train order: 783
train order: 784
train order: 785
train order: 786
train order: 787
train order: 788
train order: 789
790:tr

In [7]:
# Simple plots (tensorboard alternative) of the metric histories returned by unet.train.
# NOTE(review): assumes one train entry per training step and one test entry
# every `test_interval` steps, matching the "N:train_loss ... test_loss ..." log
# cadence above. TODO confirm against UNet3D.train.

train_acc = np.array(train_metrics['acc'])
train_dice = np.array(train_metrics['dice'])
train_loss = np.array(train_metrics['loss'])
test_acc = np.array(test_metrics['acc'])
test_dice = np.array(test_metrics['dice'])
test_loss = np.array(test_metrics['loss'])

test_interval = 10
# x positions: 1-based training step for each train entry, and every
# `test_interval`-th step for each test entry. Deriving the test x values from
# len(test_acc) keeps x and y the same length regardless of whether the number
# of training steps is a multiple of `test_interval`.
train_steps = np.arange(1, len(train_acc) + 1)
steps = test_interval * np.arange(1, len(test_acc) + 1)

fig, (ax_metrics, ax_loss) = plt.subplots(1, 2, figsize=(15, 5))

ax_metrics.plot(train_steps, train_acc, label='train_acc')
ax_metrics.plot(steps, test_acc, '-d', label='test_acc')
ax_metrics.plot(train_steps, train_dice, label='train_dice')
ax_metrics.plot(steps, test_dice, '-d', label='test_dice')
ax_metrics.set(title='Accuracy and Dice', xlabel='training step', ylabel='score')
ax_metrics.legend()

ax_loss.plot(train_steps, train_loss, label='train_loss')
ax_loss.plot(steps, test_loss, '-d', label='test_loss')
ax_loss.set(title='Loss', xlabel='training step', ylabel='loss')
ax_loss.legend();

NameError: name 'train_metrics' is not defined

In [7]:
# Summary table filled by the training cell: one row per stage ('train', 'test')
# with the final acc, dice and loss for this run's configuration.
trained_model_stats

Unnamed: 0,run_index,net_structure,training_params,stage,acc,dice,loss
0,UNET_1,32_1_32,1_dice_10,train,0.894873,0.779638,0.604426
1,UNET_1,32_1_32,1_dice_10,test,0.914304,0.719764,0.632419
