In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import random
import os
import sys
import itertools
sys.path.append('src/')
import nn
import process_data
import cv2

from __future__ import division, print_function, absolute_import
from sklearn.metrics import confusion_matrix
import scipy.sparse
from scipy.misc import imrotate
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import rotate
from skimage import exposure
from skimage.io import imread, imsave
os.environ["CUDA_VISIBLE_DEVICES"]="0"

from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
print(local_device_protos)

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 2961736824169864948
]


# Modeling Data Setup

### Load Data from File

In [2]:
def one_hot_encode(L, class_labels):
    """
    2D array (image) of segmentation labels -> .npy
    # One Hot Encode the label 2d array -> .npy files with dim (h, w, len(class_labels))
    # num classes will be 8? but currently dynamically allocated based on num colors in all scans.
    """
    h, w = L.shape  # Should be 482, 395
    try:
        encoded = np.array([list(map(class_labels.index, L.flatten()))])
    except Exception as e:
        print(e)
    L = encoded.reshape(h, w)

    Lhot = np.zeros((L.shape[0], L.shape[1], len(class_labels)))
    for i in range(L.shape[0]):
        for j in range(L.shape[1]):
            Lhot[i,j,L[i,j]] = 1
    return Lhot  # Should be shape (482, 395, 9)
    
def uncode_one_hot(npy_file):
    """
    .npy file -> JPEG
    """
    pass

def show_images(images, cols = 1, titles = None):
    """Display a list of images in a single figure with matplotlib.
    
    Parameters
    ---------
    images: List of np.arrays compatible with plt.imshow.
    
    cols (Default = 1): Number of columns in figure (number of rows is 
                        set to np.ceil(n_images/float(cols))).
    
    titles: List of titles corresponding to each image. Must have
            the same length as titles.
    """
    assert((titles is None)or (len(images) == len(titles)))
    n_images = len(images)
    if titles is None: titles = ['Image (%d)' % i for i in range(1,n_images + 1)]
    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        a = fig.add_subplot(cols, np.ceil(n_images/float(cols)), n + 1)
#         if image.ndim == 2:
#             plt.gray()
        plt.imshow(image)
        a.set_title(title)
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()
    
def load_sparse_csr(filename):
    # Sparse matrix reading function to read our raw .npz files
    assert filename.endswith('.npz')
    loader = np.load(filename)  # filename must end with .npz
    return scipy.sparse.csr_matrix((loader['data'], loader['indices'], loader['indptr']),
                      shape=loader['shape'])

def get_raw_pixel_classes():
    import nibabel as nib
    base_data_dir = "/Users/kireet/ucb/HART Research/Muscle Segmentation/raw_nifti_scan"
    example_segmentation = os.path.join(base_data_dir, 'trial10_30_w1_seg2_TRANS.nii')
    scan_voxel = nib.load(example_segmentation)
    struct_arr = scan_voxel.get_data()
    n, h, w = struct_arr.shape
    class_labels = list(np.unique(struct_arr))

In [3]:
raw_pixel_classes =[0, 7, 8, 9, 45, 51, 52, 53, 68]  # Expected raw grayscale values for each pixel
directory = "/Users/kireet/ucb/HART Research/Muscle Segmentation/cleaned_images_test"
filenames = []  # Stores all filenames
raw_images = []  # Stores X (Raw cross section images as 2D np.ndarray)
segmentations = []  # Stores Y (Labeled/Segmented image as one-hot-encoded NumClasses-D np.ndarray)

for folder in os.listdir(directory):
    class_labels = set()
    if not folder.startswith('.'):
        path = os.path.join(directory, folder)
        print(path)
        files = sorted([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and not f.startswith('.')])
        
        # Class label sanity check
        for f in files:
            if 'label' in f:
                img = imread(os.path.join(path, f), flatten=True)
                class_labels = class_labels.union(np.unique(img))
        if not class_labels.issubset(raw_pixel_classes):
            print("Class labels found in labeled images do not match the expected classes for scan {}".format(folder))
            print("Expected {}".format(raw_pixel_classes))
            print("Received {}".format(sorted(class_labels)))
            break
        
        # Sanity image read and show some images in pairs (play with the range inputs)
#         for f in range(0, len(files), 2):
#             label_name = files[f]
#             raw_name = files[f+1]
#             label_img = imread(os.path.join(path, label_name), flatten=True)
#             raw_img = load_sparse_csr(os.path.join(path, raw_name)).toarray()  # Load sparse csr mat img -> to 2D numpy array
#             show_images([label_img, raw_img], titles=[label_name, raw_name])
        
        # Set up Datasets (X, Y) pairs of data ->
        # files are sorted by the name: either '#_label' or '#_raw'
        for f in files:
            if 'raw' in f:
                raw_images.append(img)
            elif 'label' in f:
                encoded_img = one_hot_encode(img, raw_pixel_classes)
                segmentations.append(encoded_img)
            filenames.append(os.path.join(folder, f))
# print(filenames)

#     image = imresize(imread(directory + folder + '/' + folder + '.jpg', flatten = True),(h, w))
#     images.append(image)
#     filenames.append(folder)
#     seg = np.load(directory+folder+'/seg.npy')
#     temp = np.zeros((h,w,1))
#     temp[:,:,1] = imresize(seg[:,:,1],(h,w), interp='nearest')/255.0
#     segmentations.append(temp)
            
# images = np.array(images)
# segmentations = np.round(np.array(segmentations)).astype('uint8')


# study_num = int(2)
# train_lst = np.load('data/splits/train_lst_' + str(study_num) + '.npy')
# val_lst = np.load('data/splits/val_lst_' + str(study_num) + '.npy')

/Users/kireet/ucb/HART Research/Muscle Segmentation/cleaned_images_test/trial10_30_w1


### Split into Training, Cross Validation and Test sets

In [4]:
"""
TODO: Same Scan cannot be used across Train, Validation and Test sets
TODO: Different weight conditions and angles may be used to segment other raw_scans
TODO: Bounding Box, image resizing, padding edges
"""
# raw_images holds our X data
# segmentations holds out Y data
h, w = 482, 395
x_train, y_train = [], []
x_val, y_val = [], []
x_test, y_test = [], []

percent_train, percent_val, percent_test = 70, 0, 30
num_train = np.round(len(raw_images) * percent_train/100).astype(np.int)
num_val = np.round(num_train + len(raw_images) * percent_val/100).astype(np.int)
num_test = np.round(num_val + len(raw_images) * percent_test/100).astype(np.int)

assert len(raw_images) == len(segmentations)
rand_indices = list(np.random.choice(len(raw_images), len(raw_images), replace=False))

for i in rand_indices[:num_train]:
    x_train.append(raw_images[i])
    y_train.append(segmentations[i])
for j in rand_indices[num_train:num_val]:
    x_val.append(raw_images[j])
    y_val.append(segmentations[j])
for k in rand_indices[num_val:num_test]:
    x_test.append(raw_images[k])
    y_test.append(segmentations[k])
                
x_train = np.array(x_train).reshape((len(x_train),h,w,1))
x_test = np.array(x_test).reshape((len(x_test),h,w,1))
y_train = np.array(y_train)
y_test = np.array(y_test)

print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)

# Fix data padding to create square 482 by 482 matrix
npad = ((0, 0), (15, 15), (59, 58), (0, 0))
x_train = np.pad(x_train, pad_width=npad, mode='constant', constant_values=0)
x_test = np.pad(x_test, pad_width=npad, mode='constant', constant_values=0)
y_train = np.pad(y_train, pad_width=npad, mode='constant', constant_values=0)
y_test = np.pad(y_test, pad_width=npad, mode='constant', constant_values=0)

print()
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)
# (x, 482, 482, 1)
# (x, 482, 482, 9)

(6, 482, 395, 1)
(2, 482, 395, 1)
(6, 482, 395, 9)
(2, 482, 395, 9)

(6, 512, 512, 1)
(2, 512, 512, 1)
(6, 512, 512, 9)
(2, 512, 512, 9)


# U-Net Model

In [5]:
h, w = 512, 512
class Unet(object):        
    def __init__(self, mean, weight_decay, learning_rate, label_dim = 8, dropout = 0.9):
        self.x_train = tf.placeholder(tf.float32, [None, h, w, 1])
        self.y_train = tf.placeholder(tf.float32, [None, h, w, 9])
        self.x_test = tf.placeholder(tf.float32, [None, h, w, 1])
        self.y_test = tf.placeholder(tf.float32, [None, h, w, 9])
        
        self.label_dim = label_dim
        self.weight_decay = weight_decay
        self.learning_rate = learning_rate
        self.dropout = dropout

        self.output = self.unet(self.x_train, mean, keep_prob=self.dropout)
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = self.output, labels = self.y_train))
        self.opt = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
        
        self.pred = self.unet(self.x_test, mean, reuse = True, keep_prob = 1.0)
        self.loss_summary = tf.summary.scalar('loss', self.loss)
    
    # Gradient Descent on mini-batch
    def fit_batch(self, sess, x_train, y_train):
        _, loss, loss_summary = sess.run((self.opt, self.loss, self.loss_summary), feed_dict={self.x_train: x_train, self.y_train: y_train})
        return loss, loss_summary
    
    def predict(self, sess, x):
        prediction = sess.run((self.pred), feed_dict={self.x_test: x})
        return prediction

    
    def unet(self, input, mean, keep_prob = 0.9, reuse = None):
        with tf.variable_scope('vgg', reuse=reuse):
            input = input - mean
            pool_ = lambda x: nn.max_pool(x, 2, 2)
            conv_ = lambda x, output_depth, name, padding = 'SAME', relu = True, filter_size = 3: nn.conv(x, filter_size, output_depth, 1, self.weight_decay, 
                                                                                                           name=name, padding=padding, relu=relu)
            deconv_ = lambda x, output_depth, name: nn.deconv(x, 2, output_depth, 2, self.weight_decay, name=name)
            
            conv_1_1 = conv_(input, 64, 'conv1_1')
            conv_1_2 = conv_(conv_1_1, 64, 'conv1_2')

            pool_1 = pool_(conv_1_2)

            conv_2_1 = conv_(pool_1, 128, 'conv2_1')
            conv_2_2 = conv_(conv_2_1, 128, 'conv2_2')

            pool_2 = pool_(conv_2_2)

            conv_3_1 = conv_(pool_2, 256, 'conv3_1')
            conv_3_2 = conv_(conv_3_1, 256, 'conv3_2')

            pool_3 = pool_(conv_3_2)

            conv_4_1 = conv_(pool_3, 512, 'conv4_1')
            conv_4_2 = conv_(conv_4_1, 512, 'conv4_2')

            pool_4 = pool_(conv_4_2)

            conv_5_1 = conv_(pool_4, 1024, 'conv5_1')
            conv_5_2 = conv_(conv_5_1, 1024, 'conv5_2')
            
            pool_5 = pool_(conv_5_2)
            
            conv_6_1 = tf.nn.dropout(conv_(pool_5, 2048, 'conv6_1'), keep_prob)
            conv_6_2 = tf.nn.dropout(conv_(conv_6_1, 2048, 'conv6_2'), keep_prob)
            
            up_7 = tf.concat([deconv_(conv_6_2, 1024, 'up7'), conv_5_2], 3)  # Error here rn
            
            conv_7_1 = conv_(up_7, 1024, 'conv7_1')
            conv_7_2 = conv_(conv_7_1, 1024, 'conv7_2')

            up_8 = tf.concat([deconv_(conv_7_2, 512, 'up8'), conv_4_2], 3)
            
            conv_8_1 = conv_(up_8, 512, 'conv8_1')
            conv_8_2 = conv_(conv_8_1, 512, 'conv8_2')

            up_9 = tf.concat([deconv_(conv_8_2, 256, 'up9'), conv_3_2], 3)
            
            conv_9_1 = conv_(up_9, 256, 'conv9_1')
            conv_9_2 = conv_(conv_9_1, 256, 'conv9_2')

            up_10 = tf.concat([deconv_(conv_9_2, 128, 'up10'), conv_2_2], 3)
            
            conv_10_1 = conv_(up_10, 128, 'conv10_1')
            conv_10_2 = conv_(conv_10_1, 128, 'conv10_2')

            up_11 = tf.concat([deconv_(conv_10_2, 64, 'up11'), conv_1_2], 3)
            
            conv_11_1 = conv_(up_11, 64, 'conv11_1')
            conv_11_2 = conv_(conv_11_1, 64, 'conv11_2')
            
            conv_12 = conv_(conv_11_2, 9, 'conv12_2', filter_size = 1, relu = False)
            return conv_12

In [6]:
# Hyperparameters
mean = 24
weight_decay = 1e-6
learning_rate = 1e-4
label_dim = 8
maxout = False

# Create TF graph and initialize variables
tf.reset_default_graph()
sess = tf.Session()
model = Unet(mean, weight_decay, learning_rate, label_dim , dropout = 0.5)
sess.run(tf.global_variables_initializer())

In [7]:
# Restore old model
# saver = tf.train.Saver()
# saver.restore(sess, '/media/deoraid03/jeff/models/a4c_experiments/deep_256_2')

In [8]:
# Train Model
nn.train(sess, model, x_train, y_train, x_test, y_test, epochs = 1, batch_size = 1)

Epoch 0 | Iter 5 | Loss: 1.694 | Data: 5/6 | Time 1.2e+02     

  dice = 2*np.sum(overlap)/(np.sum(gt) + np.sum(pred))


Epoch 0 | Iter 5 | Loss: 1.7 | Acc: [ 0.       nan  0.     0.     0.     0.696  0.014  0.   ] | Time 1.2e+02     


In [9]:
# IOU Accuracies for each label
nn.validate(sess, model, x_test, y_test)

[0.0, nan, 0.0, 0.0, 0.0, 0.69592735277930651, 0.013969732246798603, 0.0]

In [None]:
# Save model
# saver = tf.train.Saver()
# saver.save(sess, '/media/deoraid03/jeff/models/a4c_experiments/deep_256_2')