In [None]:
# Goal: import the graph of a trained U-Net, stop gradients from flowing back into it
# (so its weights stay frozen), and build a convolutional LSTM (ConvLSTM) head on top
# of the imported graph.

#Relevant Papers:
    # U-Net -->  
        #Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation.
    
    # convLSTM --> 
        # Xingjian SH, Chen Z, Wang H, Yeung DY, Wong WK, Woo WC. Convolutional LSTM network: A machine learning approach for precipitation nowcasting. 
    
    # BDC-LSTM for Segmentation --> 
        #Chen J, Yang L, Zhang Y, Alber M, Chen DZ. Combining fully convolutional and recurrent neural networks for 3d biomedical image segmentation.


import numpy as np
import tensorflow as tf
from scipy import misc
import glob
import imageio
import matplotlib.pyplot as plt
import random
import os

In [None]:
# Get number of .png files in path
def get_num_files(path_ground):
    """Count the .png files directly inside ``path_ground``.

    ``path_ground`` is used as a glob prefix (``path_ground + "*.png"``), so it
    must already end with a path separator.

    Returns:
        int: number of matching files (0 if none or the directory is missing).
    """
    return len(glob.glob(path_ground + "*.png"))

In [None]:
def fix_dims(image, depth=4):
    """Zero-pad axes 1 and 2 (height, width) of a stack up to the next multiple of 2**depth.

    Presumably needed because the U-Net halves the spatial size ``depth`` times,
    so height/width must be divisible by 2**depth for encoder/decoder shapes to
    line up -- TODO confirm against the trained model.

    Args:
        image: array-like of shape [N, height, width, ...] (at least 3-D).
        depth: padding granularity exponent; sizes become multiples of 2**depth.

    Returns:
        ndarray with axes 1 and 2 padded with zeros at the end (bottom/right);
        all other axes are unchanged. Returned as-is when no padding is needed.
    """
    image = np.array(image)
    multiple = 2 ** depth
    # Negative modulo gives the distance up to the next multiple (0 if already aligned).
    pad_h = -image.shape[1] % multiple
    pad_w = -image.shape[2] % multiple
    if pad_h == 0 and pad_w == 0:
        return image

    pad_width = [(0, 0)] * image.ndim
    pad_width[1] = (0, pad_h)
    pad_width[2] = (0, pad_w)
    # One np.pad call instead of concatenating a single row/column per iteration.
    return np.pad(image, pad_width, mode='constant')

In [None]:
def _shift_with_zero_fill(block, dy, dx):
    """Translate ``block`` by dy rows (+ = down) and dx columns (+ = right), zero-filling vacated pixels."""
    height, width = block.shape[0], block.shape[1]
    shifted = np.zeros_like(block)
    if abs(dy) >= height or abs(dx) >= width:
        # The whole content moved out of frame.
        return shifted
    src_rows = slice(max(-dy, 0), height - max(dy, 0))
    dst_rows = slice(max(dy, 0), height - max(-dy, 0))
    src_cols = slice(max(-dx, 0), width - max(dx, 0))
    dst_cols = slice(max(dx, 0), width - max(-dx, 0))
    shifted[dst_rows, dst_cols] = block[src_rows, src_cols]
    return shifted


def random_translate_seg_version(input_block,ground_truth_block,y_dir,x_dir,reflect_flag):
    """Apply the same optional left-right reflection and integer translation to an image and its mask.

    Args:
        input_block: [height, width, channels] image slice.
        ground_truth_block: [height, width, channels] mask aligned with ``input_block``.
        y_dir: vertical shift in pixels; positive moves content up (zeros fill the
            bottom), negative moves it down (zeros fill the top).
        x_dir: horizontal shift in pixels; positive moves content right (zeros fill
            the left), negative moves it left (zeros fill the right).
        reflect_flag: when equal to 1, both blocks are flipped along axis 1 (width)
            before shifting.

    Returns:
        (shifted_input, shifted_ground_truth), each with the same shape as its input.
        Shifts at least as large as the corresponding dimension give all-zero blocks.
    """
    # Reflection about y-axis (width)
    if reflect_flag == 1:
        input_block = np.flip(input_block, 1)
        ground_truth_block = np.flip(ground_truth_block, 1)

    # Convert to "content moves by (dy, dx)" with +dy = down, +dx = right.
    dy = -int(y_dir)
    dx = int(x_dir)
    return (_shift_with_zero_fill(input_block, dy, dx),
            _shift_with_zero_fill(ground_truth_block, dy, dx))

In [None]:
def random_files(num_files):
    """Return the indices 0..num_files-1 in random order as 6-digit zero-padded strings.

    Draws the permutation with ``random.sample`` from the global ``random`` state,
    so ``random.seed`` makes it reproducible. Non-positive counts give ``[]``.
    """
    pool = range(num_files)
    order = random.sample(pool, len(pool))
    return ['{:06d}'.format(idx) for idx in order]

In [None]:
# Layer wrappers
def conv_layer(inputs, channels_in, channels_out, stvs, strides=1, scopename="Conv"):
    """3x3 convolution + bias, per-channel normalization over the spatial axes, then ReLU.

    Args:
        inputs: input tensor with ``channels_in`` channels on the last axis
            (strides are given as [1, s, s, 1], i.e. the default NHWC layout).
        channels_in: number of input channels.
        channels_out: number of filters / output channels.
        stvs: standard deviation of the normal initializer for weights and bias.
        strides: spatial stride of the convolution.
        scopename: name scope; also prefixes the weight/bias variable names
            ("<scopename>_weights", "<scopename>_bias").

    Returns:
        ReLU-activated tensor with ``channels_out`` channels.
    """
    with tf.name_scope(scopename):
        s=''
        weightname=(scopename,'_weights')
        biasname=(scopename,'_bias')
        
        # Weights (not biases) are registered in REGULARIZATION_LOSSES so an L2 term can be built from that collection.
        w=tf.Variable(tf.random_normal([3, 3, channels_in, channels_out],stddev=stvs),name=s.join(weightname))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)
        
        b=tf.Variable(tf.random_normal([channels_out],stddev=stvs),name=s.join(biasname))
        tf.summary.histogram(("Weight" + scopename),w)
        tf.summary.histogram(("Bias" + scopename),b)
        
        x = tf.nn.conv2d(inputs, w, strides=[1, strides, strides, 1], padding='SAME')
        x = tf.nn.bias_add(x, b)
        #x=tf.contrib.layers.layer_norm(x)
        epsilon = 1e-3
        # Learnable per-channel scale/offset; created without explicit names, so TF auto-names them.
        scale = tf.Variable(tf.ones([x.get_shape()[-1]]))
        beta = tf.Variable(tf.zeros([x.get_shape()[-1]]))
        # Moments over axes [1, 2] are per example and per channel, so this is an
        # instance-style normalization rather than true batch normalization,
        # despite the tf.nn.batch_normalization call below.
        batch_mean, batch_var = tf.nn.moments(x,[1,2],keep_dims=True)
        x=tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)
        return tf.nn.relu(x)


def maxpool2d(x, k=2, scopename="Pool"):
    """Non-overlapping k x k max pooling (window == stride, VALID padding) inside a name scope."""
    window = [1, k, k, 1]
    with tf.name_scope(scopename):
        return tf.nn.max_pool(x, ksize=window, strides=window, padding='VALID')

def upconv2d(x, channels_in, channels_out, stvs, stride=2, scopename="Upconv"):
    """Learned 2x2 transposed convolution that upsamples an NHWC tensor.

    Args:
        x: input tensor [batch, height, width, channels_in].
        channels_in: number of channels of ``x``.
        channels_out: number of channels of the result.
        stvs: standard deviation of the normal initializer for the kernel.
        stride: spatial upsampling factor for height and width.
        scopename: name scope for the created variable and ops.

    Returns:
        Tensor [batch, height*stride, width*stride, channels_out].
    """
    with tf.name_scope(scopename):
        # conv2d_transpose filters are laid out as [h, w, out_channels, in_channels].
        w=tf.Variable(tf.random_normal([2, 2, channels_out, channels_in],stddev=stvs))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)
        
        x_shape = tf.shape(x)
        # The output depth must equal the filter's out-channel dimension, and the
        # spatial size must follow the requested stride.
        output_shape = tf.stack([x_shape[0], x_shape[1]*stride, x_shape[2]*stride, channels_out])
        return tf.nn.conv2d_transpose(x, w, output_shape, strides=[1, stride, stride, 1], padding='SAME')

def concatenate(in1, in2, scopename="Concat"):
    """Join two tensors along axis 3 (the channel axis for NHWC tensors) inside a name scope."""
    pair = [in1, in2]
    with tf.name_scope(scopename):
        return tf.concat(pair, axis=3)

In [None]:
def convlstm_wrapper(x, trunc_prop, depthstart, cell_state, hidden_state, cell_state_in, hidden_state_in, scopename, cellname):
    """Unroll one ConvLSTM layer over ``trunc_prop`` time steps.

    New placeholders are created for the initial cell and hidden state so the
    caller can feed the final state of one session run into the next. All four
    gates come from a single 3x3 convolution over concat([input, hidden]).

    Args:
        x: list of per-time-step input tensors (at least ``trunc_prop`` entries).
            Each is concatenated on axis 3 with a [1, H, W, depthstart] hidden
            state and the kernel expects 2*depthstart input channels, so each
            entry must have ``depthstart`` channels.
        trunc_prop: number of time steps to unroll.
        depthstart: channel count of the input, cell state and hidden state.
        cell_state, hidden_state: lists; the final cell / hidden tensors are appended.
        cell_state_in, hidden_state_in: lists; the new initial-state placeholders are appended.
        scopename: name scope; also used in the weight/bias variable names.
        cellname: unused.

    Returns:
        (out_list, cell_state, hidden_state, cell_state_in, hidden_state_in):
        out_list holds the hidden state after each time step; the other four
        are the same list objects passed in, after appending.
    """

    with tf.name_scope(scopename):
        
        # Batch dimension fixed to 1; height and width left dynamic.
        cell_state_temp_in = tf.placeholder(tf.float32, [1, None, None, depthstart]) # Consistent with input dimensions
        hidden_state_temp_in = tf.placeholder(tf.float32, [1, None, None, depthstart]) # Consistent with input dimensions
        
        cell_state_in.append(cell_state_temp_in)
        hidden_state_in.append(hidden_state_temp_in)
        
        # Initialize output tensor list
        out_list=[];
        
        # Define convolutional parameters 
        # Don't use peepholes for variable input dims, and dont use skip connections for now
        strides=1; forget_bias=1; s='';
        # One kernel produces all four gates: in = [x, h] (2*depthstart), out = 4 * depthstart.
        # Note: s.join(<string>) with s='' just returns the string unchanged.
        w=tf.Variable(tf.random_normal([3, 3, depthstart*2, depthstart*4],stddev=0.01), name=s.join("Weights_"+scopename))
        # "New_Vars" is the collection build_graph initializes separately from the restored U-Net variables.
        tf.add_to_collection("New_Vars", w)
        # tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)
        b=tf.Variable(tf.random_normal([depthstart*4],stddev=0.01),name=s.join("Bias_"+scopename))
        tf.add_to_collection("New_Vars", b)
        
        hidden_state_temp=hidden_state_temp_in
        cell_state_temp=cell_state_temp_in 
        
        for j in range(trunc_prop):
            
            feature_block = tf.concat([x[j], hidden_state_temp], 3)
            feature_block = tf.nn.conv2d(feature_block, w, strides=[1, strides, strides, 1], padding='SAME')
            feature_block = tf.nn.bias_add(feature_block, b)
            
            # Gate pre-activations, in channel order: input, candidate, forget, output.
            input_gate, new_input, forget_gate, output_gate = tf.split(feature_block , 4, axis=-1)
            
            # forget_bias is added to the forget-gate pre-activation before the sigmoid.
            new_cell = tf.nn.sigmoid(forget_gate+forget_bias) * cell_state_temp
            new_cell += tf.nn.sigmoid(input_gate) * tf.nn.tanh(new_input)
            
            out = tf.nn.sigmoid(output_gate) * tf.nn.tanh(new_cell)
            
            hidden_state_temp = out
            cell_state_temp = new_cell
            
            # Append out tensor for this time list
            out_list.append(out)
           
        # Need to store cell_state and hidden_state for next go round
        # also hidden state is the out!
        cell_state.append(cell_state_temp)
        hidden_state.append(hidden_state_temp)
        
        # out_list is a list of length [trunc_prop] tensors of the hidden_states at each time step. 
        # We also want to store the latest hidden_state and cell_state to use for the next trunc_prop iteration
        return out_list, cell_state, hidden_state, cell_state_in, hidden_state_in


In [None]:
# Definition of network
def conv_net(inputdepth, trunc_prop, path_to_trained_model):
    """Import the trained U-Net meta-graph and attach a two-layer ConvLSTM head.

    Args:
        inputdepth: not used by the graph construction below (only by the
            commented-out supplemental summary).
        trunc_prop: number of consecutive slices per run; the batch axis of
            conv9b is unstacked into this many ConvLSTM time steps.
        path_to_trained_model: checkpoint prefix; "<prefix>.meta" is imported.

    Returns:
        (loglayer, cell_state, hidden_state, cell_state_in, hidden_state_in, X, Y, transfer_saver):
        loglayer -- logits with num_classes (2) channels;
        cell_state / hidden_state -- final state tensor of each ConvLSTM layer;
        cell_state_in / hidden_state_in -- initial-state placeholder of each layer;
        X, Y -- the imported graph's input and ground-truth placeholders;
        transfer_saver -- Saver returned by import_meta_graph (restores the U-Net weights).
    """
    stvs=0.01
    depthstart=64
    num_classes=2
    
    #Import meta_graph
    print("Importing from " + path_to_trained_model + ".meta")
    transfer_saver=tf.train.import_meta_graph((path_to_trained_model + '.meta'))
    
    new_graph = tf.get_default_graph()
    
    # Identify input node from imported graph
    X = new_graph.get_tensor_by_name('Input/Placeholder:0')
    Y = new_graph.get_tensor_by_name('Ground_Truth/Placeholder:0')
    
    #Identify final conv block tensor of imported graph 
    # NOTE(review): presumably conv9b has depthstart (64) channels -- the ConvLSTM kernel and
    # the 1x1 logits kernel (depthstart*2 inputs) below both assume that. TODO confirm.
    conv9b = new_graph.get_tensor_by_name('conv9b/Relu:0')
    
    #Identify supplemental output from flow graph
    #softmax_supp = new_graph.get_tensor_by_name('Softmax/Reshape_1:0')
    
    #pred_for_plot=tf.split(tf.argmax(softmax_supp, axis=3), trunc_prop, axis=0)   
    #for kk in range(trunc_prop):
    #    tf.summary.image(('Supp_Predict'+ str(kk)), tf.cast(tf.expand_dims(pred_for_plot[kk], 3),tf.float32), inputdepth)
    
    # Cut gradient flow: nothing downstream can back-propagate into the imported U-Net.
    conv9b_sg = tf.stop_gradient(conv9b)  # Conv9b is in format [trunc_prop as BS, height, width, depth]
    print(conv9b_sg)
    
    # Now augment the CLSTM component to the graph
    x=tf.expand_dims(conv9b_sg,axis=0) #Now it is [BS, trunc_prop, height, width, depth]
    unstack=tf.unstack(x,trunc_prop,axis=1)
    
    # Initialize lists of hidden and cell states --
    # This list is for keeping track of cell_state, hidden_state for each ConvLSTM that is called in the network
    # One list for current cell and hidden states and one list for the pointers to input placeholders for the flow graph
    cell_state=[]; hidden_state=[]; 
    cell_state_in=[]; hidden_state_in=[];

    # Note: convlstm_wrapper opens a second name scope with the same name inside this one.
    scopename="CLSTM1"
    cellname="Cell1"
    with tf.name_scope(scopename):
        outputs, cell_state, hidden_state, cell_state_in, hidden_state_in = convlstm_wrapper(unstack, trunc_prop, depthstart, cell_state, hidden_state, cell_state_in, hidden_state_in, scopename, cellname)
        # states.append(states_temp); cell_state.append(cell_state_temp); hidden_state.append(hidden_state_temp)
    
    # Second ConvLSTM layer consumes the first layer's per-step hidden states.
    scopename="CLSTM2"
    cellname="Cell2"
    with tf.name_scope(scopename):
        outputs, cell_state, hidden_state, cell_state_in, hidden_state_in=convlstm_wrapper(outputs, trunc_prop, depthstart, cell_state, hidden_state, cell_state_in, hidden_state_in, scopename, cellname)
        # states.append(states_temp); cell_state.append(cell_state_temp); hidden_state.append(hidden_state_temp)
        
        # Concatenate the list of time steps back into batch mode --> [trunc_prop as batch, height, width, depth]
        out_map=tf.concat(outputs,0)
        
    # Skip connection: frozen U-Net features alongside the ConvLSTM features on the channel axis.
    out_map2=concatenate(conv9b_sg,out_map,scopename="concat_final")   
    print(out_map2)
         
    # Reduce depth to num_classes
    with tf.name_scope("Logs_CLSTM"):
        w=tf.Variable(tf.random_normal([1, 1, depthstart*2, num_classes], stddev=stvs))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)
        tf.add_to_collection("New_Vars", w)
        
        b=tf.Variable(tf.random_normal([num_classes],stddev=stvs))
        tf.add_to_collection("New_Vars", b)

        loglayer = tf.nn.conv2d(out_map2, w, strides=[1, 1, 1, 1], padding='SAME')
        loglayer = tf.nn.bias_add(loglayer,b)
        
    return loglayer, cell_state, hidden_state, cell_state_in, hidden_state_in, X, Y, transfer_saver
    

In [None]:
# Dice loss function -- Defined according to the V-net paper by Milletari et al
def dice_loss(logits, onehot_labels, foreground_only=False):
    """Negative soft Dice coefficient between softmax(logits) and one-hot labels.

        loss = -2 * sum(p * g) / (eps + sum(p) + sum(g))

    with the sums taken over every element (batch, pixels and channels).

    Args:
        logits: unnormalized class scores, classes on the last axis.
        onehot_labels: one-hot ground truth, same shape as ``logits``.
        foreground_only: if True, drop channel 0 (background) before summing.
            With the default (all channels) every pixel's probabilities and
            labels each sum to 1, so the denominator is ~2 * num_pixels and the
            loss reduces to minus the mean probability of the correct class --
            a soft pixel accuracy dominated by background. Restricting to the
            foreground channel(s) gives the usual segmentation Dice.

    Returns:
        Scalar tensor in [-1, 0]; more negative is better.
    """
    with tf.name_scope("Dice_Loss_CLSTM"):

        eps = 1e-5
        prediction = tf.nn.softmax(logits)
        if foreground_only:
            # Channel 0 is the background (complement) channel produced by get_one_hot.
            prediction = prediction[..., 1:]
            onehot_labels = onehot_labels[..., 1:]
        intersection = tf.reduce_sum(prediction * onehot_labels)
        union = eps + tf.reduce_sum(prediction) + tf.reduce_sum(onehot_labels)
        return -(2 * intersection / union)

In [None]:
def get_one_hot(image, num_classes):  # image: 2-D [m, n] mask
    """Turn a 2-D binary mask into a one-hot array of shape [m, n, num_classes].

    A 0/255-style mask (any 0..max scale) is first divided by its maximum and
    truncated to int, so for non-negative masks only pixels equal to the maximum
    become foreground. Channel 1 holds the foreground, channel 0 its complement;
    any further channels stay zero.
    """
    # Rescale only when at least one pixel is positive.
    if (image > 0).max():
        image = image / image.max()
    labels = image.astype(int)

    one_hot = np.zeros((labels.shape[0], labels.shape[1], num_classes))
    one_hot[..., 1] = labels
    one_hot[..., 0] = np.abs(labels - 1)
    return one_hot

In [None]:
def build_graph(input_depth, num_classes, trunc_prop, path_to_model):
    """Build the full training graph: imported U-Net + ConvLSTM head, loss, optimizer, metrics, summaries.

    NOTE(review): reads the module-level global ``learning_rate`` (not a
    parameter), so the hyperparameter cell must run before this is called.

    Args:
        input_depth: passed to conv_net and used as max_outputs for image summaries.
        num_classes: number of one-hot classes (confusion-matrix size).
        trunc_prop: slices per session run (ConvLSTM unroll length).
        path_to_model: checkpoint prefix of the trained U-Net.

    Returns:
        (X, Y, cell_state, hidden_state, cell_state_in, hidden_state_in, logits,
         prediction, loss_op, train_op, batch_confusion, dice, writer, summ,
         init_new_vars, saver, transfer_saver)
    """
    
    #with tf.name_scope("Ground_Truth"):
    #    Y = tf.placeholder(tf.float32, [None, None, None, num_classes])
    
    # Define flow graph
    logits, cell_state, hidden_state, cell_state_in, hidden_state_in, X, Y, transfer_saver = conv_net(input_depth, trunc_prop, path_to_model)

    # Prediction function for evaluating accuracy
    with tf.name_scope("Softmax_CLSTM"):
        prediction = tf.nn.softmax(logits)

    # Define Loss
    with tf.name_scope("Loss_CLSTM"):
        loss_op=dice_loss(logits,Y)
        #regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)
        #reg_variables = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        #reg_term = tf.contrib.layers.apply_regularization(regularizer, reg_variables)
        #loss_op += reg_term
        tf.summary.scalar("Loss_CLSTM",loss_op)

    # Define optimizer
    # minimize() uses all trainable variables; the U-Net ones receive no gradient
    # because conv_net wraps conv9b in tf.stop_gradient.
    with tf.name_scope("Optimizer_CLSTM"):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)    
        train_op = optimizer.minimize(loss_op)
        
    with tf.name_scope("Metrics_CLSTM"):
        # Dice similarity coefficient
        # Same formula as dice_loss (without the sign), summed over all channels including background.
        eps = 1e-5
        intersection = tf.reduce_sum(prediction * Y)
        union =  eps + tf.reduce_sum(prediction) + tf.reduce_sum(Y)
        dice = (2 * intersection/ (union))
        tf.summary.scalar("Dice_Similarity_Coefficient_CLSTM",dice)

        # IoU
        #iou=tf.metrics.mean_iou(Y,prediction,num_classes)
        #tf.summary.scalar("IoU",iou)
        
        # Images for TensorBoard
        #logs_for_plot=tf.split(tf.argmax(logits ,axis=3), 4, axis=-1)
        #ground_for_plot=tf.split(tf.argmax(Y ,axis=3), 4, axis=-1)
        #slice_for_plot=tf.split(X, 4, axis=-1)
        
        # Images for TensorBoard: one summary per slice of the trunc_prop batch.
        # The third positional argument of tf.summary.image is max_outputs.
        logs_for_plot=tf.split(tf.argmax(logits ,axis=3), trunc_prop, axis=0)
        ground_for_plot=tf.split(tf.argmax(Y ,axis=3),  trunc_prop, axis=0)
        slice_for_plot=tf.split(X,  trunc_prop, axis=0)
        
        for kk in range(trunc_prop):
            tf.summary.image(('Predict'+ str(kk) + "_CLSTM"), tf.cast(tf.expand_dims(logs_for_plot[kk], 3),tf.float32), input_depth)
            tf.summary.image(('Ground'+ str(kk) + "_CLSTM"), tf.cast(tf.expand_dims( ground_for_plot[kk], 3),tf.float32), input_depth) 
            tf.summary.image(('Slice'+ str(kk) + "_CLSTM"), tf.cast(slice_for_plot[kk],tf.float32), input_depth)  
        
        # Confusion Matrix (rows: ground-truth class, columns: predicted class)
        batch_confusion = tf.confusion_matrix(tf.reshape(tf.argmax(Y,axis=3),[-1]),tf.reshape(tf.argmax(prediction,axis=3),[-1]),num_classes=num_classes,name='Batch_confusion_CLSTM')

    # Define writer for Tensorboard (output directory is hard-coded)
    writer=tf.summary.FileWriter("/output/1")
    summ=tf.summary.merge_all()

    # Initialize the variables
    # Only the "New_Vars" collection (ConvLSTM + logits layer) -- the U-Net
    # variables are restored from the checkpoint via transfer_saver instead.
    #init = tf.global_variables_initializer()
    init_new_vars = tf.variables_initializer(tf.get_collection("New_Vars"))

    #Define saver for model saver
    saver = tf.train.Saver()
    
    return X, Y, cell_state, hidden_state, cell_state_in, hidden_state_in, logits, prediction, loss_op, train_op, batch_confusion, dice, writer, summ, init_new_vars, saver, transfer_saver


In [None]:
#Set hyperparameters, etc.

# Optimization
learning_rate = 0.00001      # Adam step size; build_graph reads this global directly
num_epochs = 50
batch_size = 1               # NOTE(review): not referenced in the visible cells

# Data / network shape
input_depth, num_classes = 1, 2   # channels per input slice; one-hot classes (background, foreground)

# Logging / checkpoint intervals -- NOTE(review): not referenced by the visible training loop
display_step, validation_step, save_step = 480, 480, 960

# Consecutive slices the ConvLSTM is unrolled over in one session run
trunc_prop = 3

In [None]:
# Some flags
mode='j'
plane_view='Sagittal';
#plane_view='Coronal';
#plane_view='Transverse';
if plane_view=='Sagittal':
    x_stvs=15; y_stvs=15; 
else:
    x_stvs=5; y_stvs=5; 



# Distinguish paths between running on local notebook vs. running in the cloud
if mode=='j':    
    imdpath="/path/to/dir/Knee_MRI4/Femur/Train/" + plane_view + "/"
    pxdpath="/path/to/dir/Knee_MRI4/Femur/Train/" + plane_view + "/"

    imdpath_val="/path/to/dir/Knee_MRI4/Femur/Val/" + plane_view + "/"
    pxdpath_val="/path/to/dir/Knee_MRI4/Femur/Val/" + plane_view + "/"
    
    imdpath_test="/path/to/dir/Knee_MRI4/Femur/Test/" + plane_view + "/"
    pxdpath_test="/path/to/dir/Knee_MRI4/Femur/Test/" + plane_view + "/"
    
    # Local path to trained U-Net Architecture onto which the recurrent components are augmented
    path_to_model="/path/to/dir/model_16095_Original"
    
    
    
else:
    imdpath="/mydata/Femur/Train/" + plane_view + "/"
    pxdpath="/mydata/Femur/Train/" + plane_view + "/"

    imdpath_val="/mydata/Femur/Val/" + plane_view + "/"
    pxdpath_val="/mydata/Femur/Val/" + plane_view + "/"
    
    imdpath_test="/mydata/Femur/Test/" + plane_view + "/"
    pxdpath_test="/mydata/Femur/Test/" + plane_view + "/"
    
    # Path to trained U-Net Architecture onto which the recurrent components are augmented
    path_to_model="/models/model_4720_Sagittal"
    
    !ls /mydata
       
train_mode=True
tf.reset_default_graph()
X, Y, cell_state, hidden_state, cell_state_in, hidden_state_in, logits, prediction, loss_op, train_op, batch_confusion, dice, writer, summ, init_new_vars, saver, transfer_saver = build_graph(input_depth, num_classes, trunc_prop, path_to_model)
    
num_subj_train=20; num_subj_val=2; num_subj_test=5;       
# Get num_files in each directory
#num_files = get_num_files(pxdpath)
#num_files_val= get_num_files(pxdpath_val)
#num_files_test = get_num_files(pxdpath_test)
#print(num_files)


# And lastly, get weights from original graph so you can make sure they aren't actually training
#check_weight = tf.get_variable("conv1b/Variable")

In [None]:
def get_pics(path):
    """Load every slice ``<path>NNNNNN.png`` (6-digit zero-padded, starting at 0) into one stack.

    ``path`` is used as a string prefix, so it must end with a path separator.
    The number of slices is the number of .png files in ``path``; they are read
    in index order 000000, 000001, ...

    Returns:
        float64 ndarray of shape [num_files, height, width, 1], with height and
        width taken from slice 000000.

    Raises:
        FileNotFoundError: if ``path`` contains no .png files.
    """
    num_files = len(glob.glob(path + '*.png'))
    if num_files == 0:
        raise FileNotFoundError("No .png files found in " + path)

    def _slice_path(index):
        # Slices are named by their zero-padded 6-digit index.
        return path + '{:06d}'.format(index) + '.png'

    # The first slice doubles as the size template (no need to shuffle file names for it).
    first = imageio.imread(_slice_path(0))
    images = np.empty((num_files, first.shape[0], first.shape[1], 1))
    images[0, :, :, 0] = first
    for i in range(1, num_files):
        images[i, :, :, 0] = imageio.imread(_slice_path(i))
    # Return 4d tensor
    return images

In [None]:
# Train Session
# Trains the ConvLSTM head on top of the frozen U-Net. Each subject's slice stack is
# processed trunc_prop slices per session run, carrying the LSTM state between runs.


def _load_subject(img_dir, gt_dir):
    """Load one subject's image and mask stacks; mean-centre the images and pad H/W with fix_dims."""
    images = get_pics(img_dir)
    images = images - np.mean(images)
    grounds = get_pics(gt_dir)
    return fix_dims(images), fix_dims(grounds)


def _one_hot_stack(grounds):
    """One-hot encode each mask of a [N, H, W, 1] stack -> [N, H, W, num_classes]."""
    encoded = np.empty((grounds.shape[0], grounds.shape[1], grounds.shape[2], num_classes))
    for j in range(grounds.shape[0]):
        encoded[j] = get_one_hot(grounds[j, :, :, 0], num_classes)
    return encoded


def _zero_states(height, width):
    """All-zero initial (cells, hiddens): one [1, H, W, depth] float32 array per ConvLSTM layer."""
    def zeros_for(placeholder):
        # Depth comes from the placeholder's static shape (depthstart in conv_net).
        depth = placeholder.get_shape().as_list()[-1]
        return np.zeros((1, height, width, depth), dtype=np.float32)
    return [zeros_for(p) for p in cell_state_in], [zeros_for(p) for p in hidden_state_in]


def _chunk_feed(bx, by, cells, hiddens):
    """feed_dict for one chunk: inputs, labels, and the state carried over from the previous chunk."""
    feed = {X: bx, Y: by}
    feed.update(zip(cell_state_in, cells))
    feed.update(zip(hidden_state_in, hiddens))
    return feed


def _show_chunk(bx, by):
    """Plot the chunk's last slice next to its foreground mask as a visual sanity check of the data."""
    f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(6, 2))
    ax1.set_title("Slice")
    ax1.imshow(np.squeeze(bx[-1, :, :, 0]), cmap='gray')
    ax2.imshow(np.squeeze(by[-1, :, :, 1]), cmap='gray')
    ax2.set_title("Ground Truth")
    f.subplots_adjust(hspace=5.0)
    plt.show()


ct = 0  # training slices processed so far; TensorBoard step and checkpoint-name suffix

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Option to restore from model checkpoint to resume training:
    # saver.restore(sess, '/models/model_17055_Original')

    # Import variables from trained U-Net
    transfer_saver.restore(sess, path_to_model)

    # Initialize the rest of the variables
    sess.run(init_new_vars)

    writer.add_graph(sess.graph)

    for epoch in range(num_epochs):
        print("EPOCH " + "{:d}".format(epoch))

        for subj in range(1, 1 + num_subj_train):
            print("Subject " + str(subj))

            images, grounds = _load_subject(imdpath + str(subj) + '/Imd/', pxdpath + str(subj) + '/Pxd/')

            # One random translation / reflection shared by every slice of this subject.
            x_t = int(x_stvs * np.random.randn()); y_t = int(y_stvs * np.random.randn())
            reflect_flag = np.random.binomial(1, 0.5, 1)
            for j in range(images.shape[0]):
                # Signature is (..., y_dir, x_dir, ...): pass the vertical shift first.
                images[j], grounds[j] = random_translate_seg_version(images[j], grounds[j], y_t, x_t, reflect_flag)

            labels = _one_hot_stack(grounds)
            current_cell, current_hidden = _zero_states(images.shape[1], images.shape[2])

            i = 0
            # The ConvLSTM is statically unrolled over trunc_prop slices, so walk the
            # subject one chunk at a time (trailing slices of an incomplete chunk are skipped).
            while i + trunc_prop <= images.shape[0]:
                bx = images[i:i + trunc_prop]
                by = labels[i:i + trunc_prop]
                feed = _chunk_feed(bx, by, current_cell, current_hidden)

                if epoch > 0 or np.max(by[:, :, :, 1]) > 0:  # Only train on positive examples for the first epoch
                    sess.run(train_op, feed_dict=feed)
                    # One post-update forward pass yields both the summaries and the carried state.
                    DSC, summary, current_cell, current_hidden = sess.run(
                        [dice, summ, cell_state, hidden_state], feed_dict=feed)
                    writer.add_summary(summary, ct)
                else:
                    current_cell, current_hidden = sess.run([cell_state, hidden_state], feed_dict=feed)

                if i == 30:
                    _show_chunk(bx, by)

                i += trunc_prop
                ct += trunc_prop

            # VALIDATION (every 20th training subject)
            if subj > 0 and subj % 20 == 0:
                checkpointname = ''.join(('./model_', str(ct), '_', plane_view, '_CLSTM'))
                save_path = saver.save(sess, checkpointname)
                print("Model saved in file: %s" % save_path)

                for subj_val in range(1, 1 + num_subj_val):
                    DSC_val_vec = []
                    print("Validation Subject " + str(subj_val))

                    images, grounds = _load_subject(imdpath_val + str(subj_val) + '/Imd/',
                                                    pxdpath_val + str(subj_val) + '/Pxd/')
                    labels = _one_hot_stack(grounds)
                    current_cell, current_hidden = _zero_states(images.shape[1], images.shape[2])

                    i = 0
                    while i + trunc_prop <= images.shape[0]:
                        bx = images[i:i + trunc_prop]
                        by = labels[i:i + trunc_prop]
                        feed = _chunk_feed(bx, by, current_cell, current_hidden)

                        DSC, current_cell, current_hidden = sess.run([dice, cell_state, hidden_state], feed_dict=feed)
                        DSC_val_vec.append(DSC)

                        if i == 30:
                            _show_chunk(bx, by)

                        # ct is not advanced here: it counts training slices only.
                        i += trunc_prop

                    print("Validation Subject " + str(subj_val) + " Dice Similarity Coeff = " + str(np.mean(DSC_val_vec)) + " +- " + str(np.std(DSC_val_vec)))
                    print(' ')

print("We out here")


In [1]:
#Test session
# Restore the trained CLSTM checkpoint and report the Dice coefficient per test subject
# (mean +- std over trunc_prop-slice chunks) and over all chunks of all subjects.

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ('/models2/model_23094_' + plane_view + "_CLSTM"))

    DSC_test_vec_total = []
    for subj_test in range(1, 1 + num_subj_test):
        DSC_test_vec = []
        print("Test Subject " + str(subj_test))

        # Image stack (mean-centred) and mask stack, both [N, H, W, 1] and padded by fix_dims.
        images = get_pics(imdpath_test + str(subj_test) + '/Imd/')
        images = images - np.mean(images)
        grounds = get_pics(pxdpath_test + str(subj_test) + '/Pxd/')
        images = fix_dims(images)
        grounds = fix_dims(grounds)

        # One-hot encode slice by slice -> [N, H, W, num_classes]
        labels = np.empty((grounds.shape[0], grounds.shape[1], grounds.shape[2], num_classes))
        for j in range(grounds.shape[0]):
            labels[j] = get_one_hot(grounds[j, :, :, 0], num_classes)

        # Zero initial cell/hidden state per ConvLSTM layer, depth taken from the placeholders.
        height, width = images.shape[1], images.shape[2]
        current_cell = [np.zeros((1, height, width, p.get_shape().as_list()[-1]), dtype=np.float32)
                        for p in cell_state_in]
        current_hidden = [np.zeros((1, height, width, p.get_shape().as_list()[-1]), dtype=np.float32)
                          for p in hidden_state_in]

        i = 0
        # Walk the subject trunc_prop slices at a time; slices past the last full chunk are not scored.
        while i + trunc_prop <= images.shape[0]:
            bx = images[i:i + trunc_prop]
            by = labels[i:i + trunc_prop]
            feed = {X: bx, Y: by}
            feed.update(zip(cell_state_in, current_cell))
            feed.update(zip(hidden_state_in, current_hidden))

            # One forward pass gives both the metric and the state for the next chunk.
            DSC, current_cell, current_hidden = sess.run([dice, cell_state, hidden_state], feed_dict=feed)
            DSC_test_vec.append(DSC)
            DSC_test_vec_total.append(DSC)

            if i == 30:
                # Verify data by displaying ground truth with original image that it is trying to predict
                f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(6, 2))
                ax1.set_title("Slice")
                ax1.imshow(np.squeeze(bx[-1, :, :, 0]), cmap='gray')
                ax2.imshow(np.squeeze(by[-1, :, :, 1]), cmap='gray')
                ax2.set_title("Ground Truth")
                f.subplots_adjust(hspace=5.0)
                plt.show()

            i += trunc_prop

        print("Test Subject " + str(subj_test) + " Dice Similarity Coeff = " + str(np.mean(DSC_test_vec)) + " +- " + str(np.std(DSC_test_vec)))
        print(' ')

print("Test Statistics: Dice Similarity Coeff = " + str(np.mean(DSC_test_vec_total)) + " +- " + str(np.std(DSC_test_vec_total)))
               





In [None]:
def padding_function(images, trunc_prop):
    """Pad a slice stack so its length is a multiple of ``trunc_prop``.

    The static-RNN graph consumes exactly ``trunc_prop`` slices per run, so the
    sequence (axis 0) is extended by repeating its last slice until the slice
    count divides evenly.

    Args:
        images: array whose first axis indexes slices, e.g. (slices, H, W, C).
            Any number of trailing axes / channels is supported.
        trunc_prop: positive chunk length (number of unrolled RNN steps).

    Returns:
        ``images`` itself when its length is already a multiple of
        ``trunc_prop`` (including an empty stack); otherwise a new array of the
        same dtype with the final slice repeated at the end.

    Raises:
        ValueError: if ``trunc_prop`` is not a positive integer.
    """
    images = np.asarray(images)
    if trunc_prop <= 0:
        raise ValueError("trunc_prop must be a positive integer, got %r" % (trunc_prop,))

    # Number of extra slices needed to reach the next multiple of trunc_prop.
    n_pad = (-images.shape[0]) % trunc_prop
    if n_pad == 0:
        return images

    # Repeat the final slice n_pad times in one allocation; keeps every
    # trailing axis (so multi-channel stacks work) and the input dtype.
    tail = np.repeat(images[-1:], n_pad, axis=0)
    return np.concatenate((images, tail), axis=0)
    

In [None]:
# Inference session: restore the trained U-Net + convLSTM checkpoint, run every
# test subject through the network [trunc_prop] slices at a time while carrying
# the convLSTM cell/hidden state across chunks, and write the foreground
# probability map of each original slice to a PNG.
#
# Relies on names defined in earlier cells: graph tensors (X, prediction,
# cell_state, hidden_state, cell_state_in, hidden_state_in), saver, plane_view,
# num_subj_test, imdpath_test, trunc_prop, get_pics, fix_dims, padding_function.

INFERENCE_ROOT = '/output/Inference/'
MODEL_CHECKPOINT = '/models2/model_23094_' + plane_view + "_CLSTM"
NUM_CLSTM_LAYERS = 2       # number of convLSTM layers whose state is fed/fetched
CLSTM_FILTERS = 64         # channel depth of each convLSTM cell/hidden state
FILE_ID_LENGTH = 6         # zero-padded width of exported slice file names
PREVIEW_SLICE_START = 30   # chunk start index at which a preview figure is shown


def _state_feed(bx, current_cell, current_hidden):
    """Build the feed_dict for one chunk: input slices plus the carried convLSTM state."""
    feed = {X: bx}
    for layer in range(NUM_CLSTM_LAYERS):
        feed[cell_state_in[layer]] = current_cell[layer]
        feed[hidden_state_in[layer]] = current_hidden[layer]
    return feed


ct = 0
with tf.Session() as sess:
    # Initialise every variable, then overwrite them from the trained checkpoint.
    # (Alternative workflow: restore only the U-Net via transfer_saver and
    # initialise the new convLSTM variables separately.)
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, MODEL_CHECKPOINT)

    os.makedirs(INFERENCE_ROOT, exist_ok=True)

    for subj_test in range(1, 1 + num_subj_test):
        # Output layout: <INFERENCE_ROOT><subject>/<plane_view>/Inf/
        pathout = INFERENCE_ROOT + str(subj_test) + '/' + plane_view + '/Inf/'
        os.makedirs(pathout, exist_ok=True)

        print("Test Subject " + str(subj_test))
        print(pathout)

        # Stack of input slices for this subject -- 4-D array from get_pics.
        images = get_pics((imdpath_test + str(subj_test) + '/Imd/'))
        original_dims = images.shape
        images = images - np.mean(images)              # zero-centre intensities
        images = fix_dims(images)                      # pad H/W up to a multiple of 2**depth
        images = padding_function(images, trunc_prop)  # pad slice count to a multiple of trunc_prop
        bxx = np.asarray(images, dtype=np.float64)

        # Fresh convLSTM state for each subject (batch of 1, full spatial size).
        state_shape = (1, bxx.shape[1], bxx.shape[2], CLSTM_FILTERS)
        current_cell = [np.zeros(state_shape, dtype=np.float32) for _ in range(NUM_CLSTM_LAYERS)]
        current_hidden = [np.zeros(state_shape, dtype=np.float32) for _ in range(NUM_CLSTM_LAYERS)]

        i = 0
        # The graph is a static_rnn unrolled over trunc_prop steps, so walk the
        # subject trunc_prop slices at a time, carrying state between chunks.
        while i + trunc_prop <= images.shape[0]:
            bx = bxx[i:i + trunc_prop, :, :, :]

            # A single graph evaluation yields both the prediction and the state
            # to carry into the next chunk (same feed, so no second run needed).
            probability_map, current_cell, current_hidden = sess.run(
                [prediction, cell_state, hidden_state],
                feed_dict=_state_feed(bx, current_cell, current_hidden))
            if i == 0:
                print(probability_map.shape)

            # Export channel 1 of the probability map, cropped back to the
            # original slice size; skip slices that exist only as sequence padding.
            for kk in range(trunc_prop):
                if i + kk < original_dims[0]:
                    fileID = str(i + kk).zfill(FILE_ID_LENGTH)
                    if i == 0:
                        print((pathout + fileID + '.png'))
                    imageio.imwrite(pathout + fileID + ".png",
                                    np.squeeze(probability_map[kk, :original_dims[1], :original_dims[2], 1]))

            # Preview: input slice next to the network's probability map for the
            # same slice (no ground truth is loaded in this cell). Only fires when
            # trunc_prop divides PREVIEW_SLICE_START.
            if i == PREVIEW_SLICE_START:
                f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(6, 2))
                ax1.set_title("Slice")
                ax1.imshow(np.squeeze(bx[-1, :, :, 0]), cmap='gray')
                ax2.imshow(np.squeeze(probability_map[-1, :, :, 1]), cmap='gray')
                ax2.set_title("Prediction")
                f.subplots_adjust(hspace=5.0)
                plt.show()
                plt.close(f)

            i += trunc_prop
            ct += trunc_prop


print("We out here")