In [2]:
from __future__ import division
import os
import numpy as np
import pprint
import tensorflow as tf
import tensorflow.contrib.slim as slim
import pickle, csv
import sys

sys.path.append('..')
from utils import *
from model import UNet3D

In [4]:
# Sanity-check one saved patch: count the distinct values found in channel 1 of
# the array (presumably the label channel -- TODO confirm).
# TODO: label values still need fixing.
sample_patch_file = '/Users/nikhil/code/git_repos/skull_seg/test_output/sub04/0.npy'
mr_patch = np.load(sample_patch_file)
label_channel = mr_patch[:, :, :, 1]
labels = np.unique(label_channel)
len(labels)

2

In [5]:
# FLAGS
# Run configuration. The bracketed value at the end of each help string is the
# flag's actual default.
# NOTE(review): defining a flag that already exists raises an error, so re-running
# this cell in a live kernel presumably requires a kernel restart -- confirm for
# the installed TF version.

# Root of the local skull_seg checkout. Set the SKULL_SEG_ROOT environment variable
# instead of editing the absolute paths; the fallback keeps the original defaults.
SKULL_SEG_ROOT = os.environ.get('SKULL_SEG_ROOT', '/Users/nikhil/code/git_repos/skull_seg')

flags = tf.app.flags
flags.DEFINE_integer("epoch", 1, "Epoch to train [1]")
flags.DEFINE_string("train_patch_dir", os.path.join(SKULL_SEG_ROOT, "test_output"),
                    "Directory of the training data patches [$SKULL_SEG_ROOT/test_output]")
flags.DEFINE_boolean("split_train", False, "Whether to split the train data into train and val [False]")
# The trailing '' keeps the trailing slash of the original default paths.
flags.DEFINE_string("train_data_dir", os.path.join(SKULL_SEG_ROOT, "test_input", ""),
                    "Directory of the train data [$SKULL_SEG_ROOT/test_input/]")
flags.DEFINE_string("deploy_data_dir", os.path.join(SKULL_SEG_ROOT, "test_input", ""),
                    "Directory of the test data [$SKULL_SEG_ROOT/test_input/]")
flags.DEFINE_string("deploy_output_dir", os.path.join(SKULL_SEG_ROOT, "valid_input", ""),
                    "Directory name of the output data [$SKULL_SEG_ROOT/valid_input/]")
flags.DEFINE_string("train_csv", "../BraTS17TrainingData/survival_data.csv", "CSV path of the training data")
flags.DEFINE_string("deploy_csv", "../BraTS17ValidationData/survival_evaluation.csv", "CSV path of the validation data")
flags.DEFINE_integer("batch_size", 10, "Batch size [10]")
flags.DEFINE_integer("seg_features_root", 8, "Number of features in the first filter in the seg net [8]")
flags.DEFINE_integer("survival_features", 8, "Number of features in the survival net [8]")
flags.DEFINE_integer("conv_size", 3, "Convolution kernel size in encoding and decoding paths [3]")
flags.DEFINE_integer("layers", 3, "Encoding and decoding layers [3]")
flags.DEFINE_string("loss_type", "cross_entropy", "Loss type in the model [cross_entropy]")
flags.DEFINE_float("dropout", 0.5, "Drop out ratio [0.5]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("log_dir", "logs", "Directory name to save logs [logs]")
flags.DEFINE_boolean("train", True, "True for training, False for deploying [True]")
flags.DEFINE_boolean("run_seg", True, "True if run segmentation [True]")
flags.DEFINE_boolean("run_survival", False, "True if run survival prediction [False]")
FLAGS = flags.FLAGS

In [6]:
# Train
# One path per subject: every immediate sub-directory of FLAGS.train_data_dir.
subject_dirs = next(os.walk(FLAGS.train_data_dir))[1]
all_train_paths = [os.path.join(FLAGS.train_data_dir, subject) for subject in subject_dirs]

# Decide which patch directories are used for training (and, with
# FLAGS.split_train, for validation), plus the survival CSV rows for each subset.
# The split is cached in train_patch_dir/files.log; it is read from and written
# to that same location so later runs reuse it.
files_log_path = os.path.join(FLAGS.train_patch_dir, 'files.log')
if FLAGS.split_train:
    if os.path.exists(files_log_path):
        # NOTE(review): pickle.load can execute arbitrary code; only load a
        # files.log written by this notebook.
        with open(files_log_path, 'rb') as f:
            training_paths, testing_paths = pickle.load(f)
    else:
        # Subjects are the sub-directories of train_data_dir (all_train_paths);
        # their patches live under train_patch_dir/<subject>. Using directory
        # names only keeps stray files (e.g. an old files.log) out of the split.
        subject_names = sorted(os.path.basename(p) for p in all_train_paths)
        all_paths = [os.path.join(FLAGS.train_patch_dir, name) for name in subject_names]
        np.random.shuffle(all_paths)
        n_training = int(len(all_paths) * 4 / 5)  # 80% train / 20% validation
        training_paths = all_paths[:n_training]
        testing_paths = all_paths[n_training:]
        # Save the split where the branch above looks for it (binary mode is
        # required by pickle on Python 3).
        with open(files_log_path, 'wb') as f:
            pickle.dump([training_paths, testing_paths], f)

    training_ids = [os.path.basename(p) for p in training_paths]
    testing_ids = [os.path.basename(p) for p in testing_paths]

    # subject id -> (row[1], row[2]) of the survival CSV, split by subset.
    training_survival_data = {}
    testing_survival_data = {}
    with open(FLAGS.train_csv, 'r') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue  # skip blank lines instead of failing on row[0]
            if row[0] in training_ids:
                training_survival_data[row[0]] = (row[1], row[2])
            elif row[0] in testing_ids:
                testing_survival_data[row[0]] = (row[1], row[2])

    training_survival_paths = [p for p in all_train_paths if os.path.basename(p) in training_survival_data]
    testing_survival_paths = [p for p in all_train_paths if os.path.basename(p) in testing_survival_data]
else:
    # No split: every entry of train_patch_dir except log files is used for
    # training. Sorted so the order does not depend on the filesystem.
    training_paths = sorted(os.path.join(FLAGS.train_patch_dir, name)
                            for name in os.listdir(FLAGS.train_patch_dir)
                            if '.log' not in name)
    testing_paths = None

    training_ids = [os.path.basename(p) for p in training_paths]
    # Survival data is only loaded when the data is split (branch above).
    training_survival_paths = []
    testing_survival_paths = None
    training_survival_data = {}
    testing_survival_data = None

# Create the checkpoint and log directories up front if they are missing.
for required_dir in (FLAGS.checkpoint_dir, FLAGS.log_dir):
    if not os.path.exists(required_dir):
        os.makedirs(required_dir)

In [7]:
# Display the basenames of the selected training patch paths.
training_ids

['sub04', 'sub05']

In [None]:
# Segmentation net: build the 3D U-Net in a fresh default graph, then either
# train it (FLAGS.train) or segment FLAGS.deploy_data_dir into deploy_output_dir.
tf.reset_default_graph()
if FLAGS.run_seg:
    run_config = tf.ConfigProto()
    with tf.Session(config=run_config) as sess:
        unet = UNet3D(sess,
                      checkpoint_dir=FLAGS.checkpoint_dir,
                      log_dir=FLAGS.log_dir,
                      training_paths=training_paths,
                      testing_paths=testing_paths,
                      batch_size=FLAGS.batch_size,
                      layers=FLAGS.layers,
                      features_root=FLAGS.seg_features_root,
                      conv_size=FLAGS.conv_size,
                      dropout=FLAGS.dropout,
                      loss_type=FLAGS.loss_type)

        if not FLAGS.train:
            # Deploy: make sure the output directory exists before writing to it.
            if not os.path.exists(FLAGS.deploy_output_dir):
                os.makedirs(FLAGS.deploy_output_dir)
            unet.deploy(FLAGS.deploy_data_dir, FLAGS.deploy_output_dir)
        else:
            # Print a per-variable size summary of the trainable parameters, then train.
            model_vars = tf.trainable_variables()
            slim.model_analyzer.analyze_vars(model_vars, print_info=True)
            train_config = {'epoch': FLAGS.epoch}
            unet.train(train_config)

    tf.reset_default_graph()

---------
Variables: name (type shape) [size]
---------
encoding0/w1:0 (float32_ref 3x3x3x1x8) [216, bytes: 864]
encoding0/b1:0 (float32_ref 8) [8, bytes: 32]
encoding0/bn1/beta:0 (float32_ref 8) [8, bytes: 32]
encoding0/w2:0 (float32_ref 3x3x3x8x8) [1728, bytes: 6912]
encoding0/b2:0 (float32_ref 8) [8, bytes: 32]
encoding0/bn2/beta:0 (float32_ref 8) [8, bytes: 32]
encoding1/w1:0 (float32_ref 3x3x3x8x16) [3456, bytes: 13824]
encoding1/b1:0 (float32_ref 16) [16, bytes: 64]
encoding1/bn1/beta:0 (float32_ref 16) [16, bytes: 64]
encoding1/w2:0 (float32_ref 3x3x3x16x16) [6912, bytes: 27648]
encoding1/b2:0 (float32_ref 16) [16, bytes: 64]
encoding1/bn2/beta:0 (float32_ref 16) [16, bytes: 64]
encoding2/w1:0 (float32_ref 3x3x3x16x32) [13824, bytes: 55296]
encoding2/b1:0 (float32_ref 32) [32, bytes: 128]
encoding2/bn1/beta:0 (float32_ref 32) [32, bytes: 128]
encoding2/w2:0 (float32_ref 3x3x3x32x32) [27648, bytes: 110592]
encoding2/b2:0 (float32_ref 32) [32, bytes: 128]
encoding2/bn2/beta:0 (flo

In [10]:
# Inspect the model's accuracy summary tensor.
# NOTE(review): the segmentation cell closes its session and calls
# tf.reset_default_graph() before this runs, so this tensor belongs to a discarded
# graph and cannot be evaluated; `unet` also only exists if that cell ran with
# FLAGS.run_seg. Consider removing this debugging cell.
unet.accuracy_summary

<tf.Tensor 'accuracy:0' shape=() dtype=string>