In [None]:
import numpy as np
import tensorflow as tf
from scipy import misc
import glob
import imageio
import matplotlib.pyplot as plt
import random
import os

# Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation.
# In: International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), 2015 Oct 5 (pp. 234-241). Springer, Cham.

In [None]:
# Get number of .png files in path
def get_num_files(path_ground):
    """Return the number of ``.png`` files matched by ``path_ground + "*.png"``.

    ``path_ground`` is concatenated with the pattern as-is (no separator is added),
    so pass a directory path that already ends in ``/``.
    """
    # A single directory scan; the previous version globbed twice and discarded one result.
    return len(glob.glob(path_ground + "*.png"))

In [None]:
def fix_dims(image, depth=4):
    """Zero-pad an image on the bottom/right so height and width are multiples of 2**depth.

    ``conv_net`` max-pools by 2 four times, so both spatial sizes must be divisible
    by 2**4 = 16 for the skip-connection concatenations to line up.

    Args:
        image: array-like of shape (height, width) or (height, width, ...).
        depth: number of 2x down-sampling stages to accommodate (default 4).

    Returns:
        A new ndarray. If no padding is needed it keeps the input dtype. Otherwise
        zeros are appended after the last row/column and the result is promoted to
        at least float64 (matching the earlier concatenate-with-``np.zeros``
        behaviour).
    """
    image = np.array(image)
    multiple = 2 ** depth
    # Python's modulo of a negative number gives the distance up to the next multiple.
    pad_rows = -image.shape[0] % multiple
    pad_cols = -image.shape[1] % multiple
    if pad_rows == 0 and pad_cols == 0:
        return image

    # Only the first two axes are padded; any trailing (channel) axes are left alone.
    pad_width = [(0, pad_rows), (0, pad_cols)] + [(0, 0)] * (image.ndim - 2)
    padded = np.pad(image, pad_width, mode='constant')
    return padded.astype(np.result_type(image.dtype, np.float64), copy=False)
    
        
    

In [None]:
def _shift_block(block, y_dir, x_dir):
    """Translate ``block`` so that out[r, c] = block[r + y_dir, c - x_dir], zero-filling the rest."""
    block = np.asarray(block)
    height, width = block.shape[0], block.shape[1]
    # Clip so shifts larger than the block give an all-zero result instead of a
    # wrongly-shaped array (the concatenate-based version grew the output).
    dy = int(max(-height, min(height, y_dir)))
    dx = int(max(-width, min(width, x_dir)))
    out = np.zeros(block.shape, dtype=np.result_type(block.dtype, np.float64))
    src_rows = slice(max(dy, 0), height + min(dy, 0))
    dst_rows = slice(max(-dy, 0), height - max(dy, 0))
    src_cols = slice(max(-dx, 0), width - max(dx, 0))
    dst_cols = slice(max(dx, 0), width + min(dx, 0))
    out[dst_rows, dst_cols] = block[src_rows, src_cols]
    return out


def random_translate_seg_version(input_block,ground_truth_block,y_dir,x_dir,reflect_flag):
    """Apply the same reflection and translation to an image and its segmentation mask.

    Args:
        input_block: image array of shape [height, width, depth] (2-D also works).
        ground_truth_block: mask array with the same height and width.
        y_dir: vertical shift in pixels. Positive moves content up (zeros enter at
            the bottom); negative moves it down (zeros enter at the top).
        x_dir: horizontal shift in pixels. Positive moves content right (zeros
            enter on the left); negative moves it left (zeros enter on the right).
        reflect_flag: 1 (or a size-1 array equal to 1) to mirror both blocks about
            the vertical axis (flip along width) before translating.

    Returns:
        (shifted_input, shifted_ground), each with the same shape as its input and
        promoted to at least float64. Shifts of at least the block height/width
        yield all-zero blocks.
    """
    # Reflection about y-axis (width)
    if np.any(np.asarray(reflect_flag) == 1):
        input_block = np.flip(input_block, 1)
        ground_truth_block = np.flip(ground_truth_block, 1)

    return (_shift_block(input_block, y_dir, x_dir),
            _shift_block(ground_truth_block, y_dir, x_dir))

In [None]:
def random_files(num_files):
    """Return file IDs 0 .. num_files-1 in a random order, as zero-padded strings.

    Each ID is left-padded with '0' to six characters (e.g. 7 -> '000007').
    The order is drawn with ``random.sample`` on Python's global RNG, so seeding
    ``random`` makes it reproducible.
    """
    shuffled_ids = random.sample(range(num_files), num_files)
    return [str(file_id).zfill(6) for file_id in shuffled_ids]

In [None]:
# Layer wrappers
def conv_layer(inputs, channels_in, channels_out, stvs, strides=1, scopename="Conv"):
    """3x3 convolution -> bias -> batch normalization -> ReLU, inside a name scope.

    Args:
        inputs: 4-D tensor for tf.nn.conv2d (default NHWC layout, matching the
            [1, strides, strides, 1] stride vector below).
        channels_in: number of input channels (third axis of the 3x3 filter).
        channels_out: number of filters / output channels.
        stvs: standard deviation of the normal initializer for weights and bias.
        strides: spatial stride for both height and width.
        scopename: name scope; also the prefix of the weight/bias variable names.

    Returns:
        ReLU-activated tensor with channels_out channels ('SAME' padding).
    """
    with tf.name_scope(scopename):
        s=''
        weightname=(scopename,'_weights')
        biasname=(scopename,'_bias')
        
        w=tf.Variable(tf.random_normal([3, 3, channels_in, channels_out],stddev=stvs),name=s.join(weightname))
        # Registered so an L2 penalty can be built from this collection (see build_graph).
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)
        
        b=tf.Variable(tf.random_normal([channels_out],stddev=stvs),name=s.join(biasname))
        tf.summary.histogram(("Weight" + scopename),w)
        tf.summary.histogram(("Bias" + scopename),b)
        
        x = tf.nn.conv2d(inputs, w, strides=[1, strides, strides, 1], padding='SAME')
        x = tf.nn.bias_add(x, b)
        #x=tf.contrib.layers.layer_norm(x)
        # Batch normalization with learned per-channel scale (gamma) and offset (beta).
        # NOTE(review): the mean/variance are always the moments of the batch being
        # run (no moving averages are kept), so evaluation also normalizes with the
        # statistics of the evaluated batch -- per-image normalization when fed one
        # image. The per-channel bias added above is cancelled by the mean subtraction.
        # NOTE(review): scale/beta are unnamed, so their variable names (and hence
        # checkpoint compatibility) depend on creation order within the scope.
        epsilon = 1e-3
        scale = tf.Variable(tf.ones([x.get_shape()[-1]]))
        beta = tf.Variable(tf.zeros([x.get_shape()[-1]]))
        batch_mean, batch_var = tf.nn.moments(x,[0,1,2])
        x=tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)
        return tf.nn.relu(x)


def maxpool2d(x, k=2, scopename="Pool"):
    """Non-overlapping k x k max pooling (window equals stride, 'VALID' padding)."""
    window = [1, k, k, 1]  # pool over height and width only, never batch/channels
    with tf.name_scope(scopename):
        pooled = tf.nn.max_pool(x, ksize=window, strides=window, padding='VALID')
    return pooled

def upconv2d(x, channels_in, channels_out, stvs, stride=2, scopename="Upconv"):
    """2x2 transposed convolution used to up-sample on the U-Net expanding path.

    Args:
        x: 4-D NHWC tensor with channels_in channels.
        channels_in: number of channels of ``x``.
        channels_out: number of output channels.
        stvs: standard deviation of the normal weight initializer.
        stride: up-sampling factor for height and width (default 2).
        scopename: name scope for the op and its weight variable.

    Returns:
        Tensor of shape [batch, height*stride, width*stride, channels_out].
    """
    with tf.name_scope(scopename):
        # conv2d_transpose filters are laid out [height, width, out_channels, in_channels].
        w=tf.Variable(tf.random_normal([2, 2, channels_out, channels_in],stddev=stvs))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)

        x_shape = tf.shape(x)
        # Derive the output shape from the arguments. It used to hard-code "*2" and
        # "x_shape[3]//2", silently ignoring `stride` and `channels_out`. The result is
        # unchanged for the defaults used in conv_net.
        output_shape = tf.stack([x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels_out])
        return tf.nn.conv2d_transpose(x, w, output_shape, strides=[1, stride, stride, 1], padding='SAME')

def concatenate(in1, in2, scopename="Concat"):
    """Join two 4-D tensors along axis 3 (the channel axis), ``in1`` first."""
    channel_axis = 3
    with tf.name_scope(scopename):
        merged = tf.concat([in1, in2], channel_axis)
    return merged

In [None]:
# Definition of network
def conv_net(x, inputdepth):
    """U-Net style encoder/decoder returning per-pixel class logits.

    Encoder: four levels of two 3x3 conv_layers followed by 2x2 max pooling, with
    depthstart=64 channels doubling per level. Bottleneck: three conv_layers at
    depthstart*16 channels. Decoder: four upconv2d steps, each concatenated with
    the matching encoder output and followed by two conv_layers. A final 1x1
    convolution maps to num_classes logits.

    Args:
        x: 4-D input tensor [batch, height, width, inputdepth]. Height and width
            must be divisible by 2**4 (four poolings) so the concatenations align.
        inputdepth: number of channels of ``x``.

    Returns:
        Logits tensor with num_classes (=2) channels at the input resolution.
    """
    stvs=0.01  # stddev of all random-normal initializers
    depthstart=64  # channels of the first level
    # NOTE(review): hard-coded; build_graph's num_classes must match it.
    num_classes=2

    # --- Contracting path ---
    conv1a = conv_layer(x, inputdepth, depthstart, stvs,scopename="conv1a") 
    conv1b = conv_layer(conv1a, depthstart, depthstart, stvs, scopename="conv1b")
    pooled1 = maxpool2d(conv1b,scopename="pooled1")
        
    conv2a = conv_layer(pooled1, depthstart, depthstart*2, stvs,scopename="conv2a") 
    conv2b = conv_layer(conv2a, depthstart*2, depthstart*2, stvs, scopename="conv2b")
    pooled2 = maxpool2d(conv2b,scopename="pooled2")
    
    conv3a = conv_layer(pooled2, depthstart*2, depthstart*4, stvs,scopename="conv3a") 
    conv3b = conv_layer(conv3a, depthstart*4, depthstart*4, stvs, scopename="conv3b")
    pooled3 = maxpool2d(conv3b,scopename="pooled3")
    
    conv4a = conv_layer(pooled3, depthstart*4, depthstart*8, stvs,scopename="conv4a") 
    conv4b = conv_layer(conv4a, depthstart*8, depthstart*8, stvs, scopename="conv4b")
    pooled4 = maxpool2d(conv4b,scopename="pooled4")
    
    # --- Bottleneck (three convolutions, one more than the encoder levels) ---
    conv5a = conv_layer(pooled4, depthstart*8, depthstart*16, stvs,scopename="conv5a") 
    conv5b = conv_layer(conv5a, depthstart*16, depthstart*16, stvs, scopename="conv5b")
    conv5c = conv_layer(conv5b, depthstart*16, depthstart*16, stvs, scopename="conv5c")
    
    # Begin upsampling
    # Each step: up-convolve, concatenate the same-resolution encoder output (skip
    # connection, doubling the channels), then two convolutions to halve them again.
    upconv1=upconv2d(conv5c, depthstart*16, depthstart*8, stvs, scopename="upconv1")
    conc1=concatenate(conv4b,upconv1,scopename="concat1")
    conv6a = conv_layer(conc1, depthstart*16, depthstart*8, stvs,scopename="conv6a")
    conv6b = conv_layer(conv6a, depthstart*8, depthstart*8, stvs, scopename="conv6b")
    
    upconv2=upconv2d(conv6b, depthstart*8, depthstart*4, stvs, scopename="upconv2")
    conc2=concatenate(conv3b,upconv2,scopename="concat2")
    conv7a = conv_layer(conc2, depthstart*8, depthstart*4, stvs,scopename="conv7a")
    conv7b = conv_layer(conv7a, depthstart*4, depthstart*4, stvs, scopename="conv7b")
    
    upconv3=upconv2d(conv7b, depthstart*4, depthstart*2, stvs, scopename="upconv3")
    conc3=concatenate(conv2b,upconv3,scopename="concat3")
    conv8a = conv_layer(conc3, depthstart*4, depthstart*2, stvs,scopename="conv8a")
    conv8b = conv_layer(conv8a, depthstart*2, depthstart*2, stvs, scopename="conv8b")
    
    upconv4=upconv2d(conv8b, depthstart*2, depthstart, stvs, scopename="upconv4")
    conc4=concatenate(conv1b,upconv4,scopename="concat4")
    conv9a = conv_layer(conc4, depthstart*2, depthstart, stvs,scopename="conv9a")
    conv9b = conv_layer(conv9a, depthstart, depthstart, stvs, scopename="conv9b")
        
    # Reduce depth to num_classes
    # 1x1 convolution + bias, no activation: the outputs are raw logits.
    with tf.name_scope("Logs"):
        w=tf.Variable(tf.random_normal([1, 1, depthstart, num_classes],stddev=stvs))
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, w)
        
        b=tf.Variable(tf.random_normal([num_classes],stddev=stvs))

        loglayer = tf.nn.conv2d(conv9b, w, strides=[1, 1, 1, 1], padding='SAME')
        loglayer = tf.nn.bias_add(loglayer,b)
        
    return loglayer
    

In [None]:
# Dice loss function -- defined according to the V-Net paper by Milletari et al.:
# Milletari F, Navab N, Ahmadi SA. V-net: Fully convolutional neural networks for volumetric medical image segmentation.
# In: 2016 Fourth International Conference on 3D Vision (3DV), 2016 Oct 25 (pp. 565-571). IEEE.
def dice_loss(logits, onehot_labels, eps=1e-5):
    """Negative soft Dice loss (V-Net formulation), averaged over classes.

    For each class c, summing over batch, height and width:
        dice_c = (2 * sum(p_c * g_c) + eps) / (sum(p_c**2) + sum(g_c**2) + eps)
    The loss is -mean_c(dice_c), so it lies in [-1, 0] and lower is better.

    Why per class: the previous version summed over all classes at once. Softmax
    probabilities and one-hot labels each sum to exactly 1 per pixel, so its
    denominator was the constant 2*B*H*W. The "Dice" loss therefore reduced to the
    mean probability of the correct class, dominated by background pixels.

    Args:
        logits: 4-D tensor [batch, height, width, num_classes] of unnormalized scores.
        onehot_labels: one-hot ground truth with the same shape as ``logits``.
        eps: smoothing constant that keeps the ratio finite when a class is absent
            from both the prediction and the labels.

    Returns:
        Scalar tensor: the negative mean Dice coefficient.
    """
    with tf.name_scope("Dice_Loss"):
        prediction = tf.nn.softmax(logits)
        spatial_axes = [0, 1, 2]  # reduce batch, height, width -> one value per class
        intersection = tf.reduce_sum(prediction * onehot_labels, axis=spatial_axes)
        denominator = (tf.reduce_sum(tf.square(prediction), axis=spatial_axes)
                       + tf.reduce_sum(tf.square(onehot_labels), axis=spatial_axes))
        per_class_dice = (2.0 * intersection + eps) / (denominator + eps)
        return -tf.reduce_mean(per_class_dice)

In [None]:
def get_one_hot(image, num_classes):
    """Convert a label image into a one-hot array of shape (height, width, num_classes).

    Args:
        image: label image of shape (height, width) or (height, width, 1). It is
            divided by its maximum (so a 0/255 mask becomes 0/1) and then truncated
            to int, so this is intended for binary masks.
        num_classes: length of the one-hot axis. Every resulting label must index
            into it.

    Returns:
        float64 array of shape (height, width, num_classes) with exactly one 1 per pixel.
    """
    labels = np.asarray(image)
    if labels.ndim > 2:
        # Accept the (height, width, 1) layout produced by np.expand_dims(..., axis=2).
        # reshape raises if the extra axes are not all of size 1.
        labels = labels.reshape(labels.shape[0], labels.shape[1])

    # Make sure ground truth image is either 0 or 1, not 0 or 255
    peak = np.amax(labels)
    if peak > 0:
        labels = np.divide(labels, peak)
    labels = labels.astype(int)

    # Row k of the identity matrix is the one-hot vector for class k. Fancy indexing
    # replaces the former per-pixel Python double loop.
    return np.eye(num_classes)[labels]
    

In [None]:
def build_graph(input_depth, num_classes, lr=None, logdir="/output/1"):
    """Build the TF1 graph: placeholders, U-Net, loss, optimizer, metrics and summaries.

    Args:
        input_depth: number of channels of the input slices fed to ``X``.
        num_classes: depth of the one-hot ground truth fed to ``Y``.
            NOTE(review): conv_net hard-codes 2 output classes, so this must be 2.
        lr: Adam learning rate. ``None`` (default) falls back to the notebook-level
            ``learning_rate`` from the hyper-parameter cell, as before.
        logdir: directory for the TensorBoard FileWriter.

    Returns:
        Tuple (X, Y, logits, prediction, loss_op, train_op, accuracy,
        batch_confusion, dice, writer, summ, init, saver).
    """
    if lr is None:
        lr = learning_rate  # notebook global from the hyper-parameter cell

    # Define placeholders: [batch, height, width, channels]
    with tf.name_scope("Input"):
        X = tf.placeholder(tf.float32, [None, None, None, input_depth])
    with tf.name_scope("Ground_Truth"):
        Y = tf.placeholder(tf.float32, [None, None, None, num_classes])

    # Define flow graph
    logits = conv_net(X, input_depth)

    # Per-pixel class probabilities
    with tf.name_scope("Softmax"):
        prediction = tf.nn.softmax(logits)

    # Define Loss
    with tf.name_scope("Loss"):
        loss_op = dice_loss(logits, Y)
        # Optional L2 penalty over the REGULARIZATION_LOSSES collection (disabled):
        #regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)
        #reg_variables = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        #reg_term = tf.contrib.layers.apply_regularization(regularizer, reg_variables)
        #loss_op += reg_term
        tf.summary.scalar("Loss", loss_op)

    # Define optimizer
    with tf.name_scope("Optimizer"):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_op = optimizer.minimize(loss_op)

    # Evaluate model: fraction of pixels whose argmax class is correct
    with tf.name_scope("Accuracy"):
        correct_pred = tf.equal(tf.argmax(prediction, axis=3), tf.argmax(Y, axis=3))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        tf.summary.scalar("accuracy", accuracy)

    with tf.name_scope("Metrics"):
        # Dice similarity coefficient of the foreground (any non-background class)
        # on the hard argmax prediction: 2*|A & B| / (|A| + |B|).
        # The previous soft version summed probabilities and one-hot labels over all
        # classes. Both sum to exactly 1 per pixel, so its denominator was the
        # constant 2*B*H*W and it reported the mean probability of the correct
        # class (dominated by background) rather than a Dice score.
        eps = 1e-5
        pred_fg = tf.cast(tf.not_equal(tf.argmax(prediction, axis=3), 0), tf.float32)
        true_fg = 1.0 - Y[:, :, :, 0]
        intersection = tf.reduce_sum(pred_fg * true_fg)
        total = tf.reduce_sum(pred_fg) + tf.reduce_sum(true_fg)
        # eps in numerator and denominator: empty mask + empty prediction scores 1.
        dice = (2.0 * intersection + eps) / (total + eps)
        tf.summary.scalar("Dice_Similarity_Coefficient", dice)

        # IoU
        #iou=tf.metrics.mean_iou(Y,prediction,num_classes)
        #tf.summary.scalar("IoU",iou)

        # Images for TensorBoard (argmax class maps as single-channel images).
        # NOTE(review): the third argument is max_outputs (number of batch images
        # shown); input_depth is passed, which equals 1 here.
        tf.summary.image('Predict', tf.cast(tf.expand_dims(tf.argmax(logits, axis=3), 3), tf.float32), input_depth)
        tf.summary.image('Ground', tf.cast(tf.expand_dims(tf.argmax(Y, axis=3), 3), tf.float32), input_depth)

        # Confusion Matrix (rows: ground truth, columns: prediction)
        batch_confusion = tf.confusion_matrix(tf.reshape(tf.argmax(Y, axis=3), [-1]), tf.reshape(tf.argmax(prediction, axis=3), [-1]), num_classes=num_classes, name='Batch_confusion')

    #Define writer for Tensorboard
    writer = tf.summary.FileWriter(logdir)
    summ = tf.summary.merge_all()

    # Initialize the variables
    init = tf.global_variables_initializer()

    #Define saver for model saver
    saver = tf.train.Saver()

    return X, Y, logits, prediction, loss_op, train_op, accuracy, batch_confusion, dice, writer, summ, init, saver

In [None]:
#Set hyperparameters, etc.

# --- Optimisation -----------------------------------------------------------
learning_rate = 1e-5              # Adam step size used when building the graph
num_epochs, batch_size = 50, 1    # the training loop assumes one image per batch

# --- Data layout ------------------------------------------------------------
input_depth = 1                   # channels per input slice
num_classes = 2                   # depth of the one-hot ground truth

# --- Logging cadence (steps within an epoch) --------------------------------
display_step = validation_step = 480
save_step = 960


In [None]:
train_mode=True
tf.reset_default_graph()
X, Y, logits, prediction, loss_op, train_op, accuracy, batch_confusion, dice, writer, summ, init, saver = build_graph(input_depth, num_classes)

# Some flags
mode = 'j'  # 'j' -> local notebook paths; anything else -> cloud paths under /mydata
plane_view = 'Sagittal'
#plane_view = 'Coronal'
#plane_view = 'Transverse'

# Distinguish paths between running on local notebook vs. running in the cloud.
# Each split lives at <data_root><Split>/<plane_view>/Imd/ (slices) and .../Pxd/ (masks).
data_root = "/path/to/dir/Knee_MRI3/Femur/" if mode == 'j' else "/mydata/Femur/"


def _split_paths(split):
    """Return (image_dir, mask_dir) for one data split; both end with '/'."""
    base = data_root + split + "/" + plane_view
    return base + "/Imd/", base + "/Pxd/"


imdpath, pxdpath = _split_paths("Train")
imdpath_val, pxdpath_val = _split_paths("Val")
imdpath_test, pxdpath_test = _split_paths("Test")

if mode != 'j':
    # Pure-Python listing of the cloud data mount (replaces an IPython `!ls` escape,
    # so the cell also runs as plain Python).
    print(sorted(os.listdir("/mydata")))

# Get num_files in each directory
num_files = get_num_files(pxdpath)
num_files_val = get_num_files(pxdpath_val)
num_files_test = get_num_files(pxdpath_test)
print(num_files)


In [None]:
# Train Session
# Send it
assert batch_size == 1, "The training loop below builds one-image batches; batch_size must be 1"

ct = 0  # global step (images seen): TensorBoard step and checkpoint-name suffix


def _load_slice_pair(img_dir, gt_dir, file_id):
    """Load ``<img_dir><file_id>.png`` and its mask ``<gt_dir><file_id>.png``.

    The slice is mean-centred and both arrays are zero-padded by fix_dims so their
    height and width are divisible by 2**depth. Returns (image, ground) with image
    shaped [height, width, 1] and ground shaped [height, width].
    """
    image = imageio.imread(img_dir + file_id + '.png')  # Import slice
    image = image - np.mean(image, axis=(0, 1))  # Mean-center the slice
    image = np.expand_dims(fix_dims(image), axis=2)  # [height, width, depth=1]
    ground = fix_dims(imageio.imread(gt_dir + file_id + '.png'))
    return image, ground


with tf.Session() as sess:
    sess.run(init)

    #Option to restore from model checkpoint to resume training
    #saver.restore(sess, '/models/model_17055_Original')

    writer.add_graph(sess.graph)
    for epoch in range(num_epochs):

        # Reshuffle files every epoch
        file_list = random_files(num_files)
        print("EPOCH " + "{:d}".format(epoch))

        for i in range(0, num_files, batch_size):
            input_image, ground = _load_slice_pair(imdpath, pxdpath, file_list[i])
            ground = np.expand_dims(ground, axis=2)  # [height, width, depth=1]

            # randomly sample parameters for translation
            x_stvs = 15
            y_stvs = 15
            x_t = int(x_stvs * np.random.randn())
            y_t = int(y_stvs * np.random.randn())
            # randomly sample reflection flag
            reflect_flag = np.random.binomial(1, 0.5, 1)
            # Apply data augmentation (x and y translation and random y-axis reflection).
            # The signature is (..., y_dir, x_dir, ...); x_t/y_t used to be passed swapped.
            input_image_aug, ground_aug = random_translate_seg_version(input_image, ground, y_t, x_t, reflect_flag)

            # Final format for the flow graph: [Batch=1, height, width, depth]
            bx = np.expand_dims(input_image_aug, axis=0)
            by = np.expand_dims(get_one_hot(ground_aug, num_classes), axis=0)

            # One run applies the update and fetches the loss/summaries of the same
            # forward pass, instead of a second forward pass after the update.
            # The weight histograms may reflect either side of the update.
            _, loss, summary = sess.run([train_op, loss_op, summ], feed_dict={X: bx, Y: by})
            writer.add_summary(summary, ct)

            if i % display_step == 0 or i == batch_size:
                # Get metrics of interest and display
                BC, DSC = sess.run([batch_confusion, dice], feed_dict={X: bx, Y: by})
                print(BC)
                print("Step " + str(i) + ", Minibatch Loss= " + "{:.4f}".format(loss) + ", Dice Similarity Coefficient = " + "{:.3f}".format(DSC))

                # Verify data by displaying ground truth with original image
                f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(6, 2))
                ax1.set_title("Slice")
                ax1.imshow(np.squeeze(bx[0, :, :, 0]), cmap='gray')
                ax2.imshow(np.squeeze(by[0, :, :, 1]), cmap='gray')
                ax2.set_title("Ground Truth")
                f.subplots_adjust(hspace=5.0)
                plt.show()
                plt.close(f)  # avoid accumulating figures over a long run

            # Validation Step
            if i % validation_step == 0 and i > 0:
                dsc_vec_val = []
                file_list_val = random_files(num_files_val)
                for j in range(num_files_val):
                    val_image, val_ground = _load_slice_pair(imdpath_val, pxdpath_val, file_list_val[j])
                    # No data augmentation for validation
                    val_bx = np.expand_dims(val_image, axis=0)
                    val_by = np.expand_dims(get_one_hot(val_ground, num_classes), axis=0)
                    dsc_vec_val.append(sess.run(dice, feed_dict={X: val_bx, Y: val_by}))

                print("Validation Cohort Dice Similarity Coefficient = " + "{:.4f}".format(np.mean(dsc_vec_val)) + " (+-" + "{:.4f}".format(np.std(dsc_vec_val)) + ")")

            #Save point
            if i % save_step == 0:
                checkpointname = './model_' + str(ct) + '_' + plane_view
                save_path = saver.save(sess, checkpointname)
                print("Model saved in file: %s" % save_path)

            ct = ct + batch_size

print("We out here")
    

In [None]:
# Test Session
# Separate session for testing an image store on the trained model
ct = 0
dsc_vec_test = []  # Dice similarity coefficient for each test image
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, '/models/model_14175_Original')  # Import trained model
    writer.add_graph(sess.graph)

    file_list = random_files(num_files_test)
    for i in range(0, num_files_test, batch_size):
        # Set up test input
        input_image = imageio.imread(imdpath_test + file_list[i] + '.png')  # Import slice
        input_image = input_image - np.mean(input_image, axis=(0, 1))  # Mean-center the slice
        input_image = fix_dims(input_image)
        input_image = np.expand_dims(input_image, axis=2)  # [height, width, depth]

        # Set up ground truth
        ground = imageio.imread(pxdpath_test + file_list[i] + '.png')
        ground = fix_dims(ground)
        ground_one_hot = get_one_hot(ground, num_classes)

        # No data augmentation for testing; add the batch axis -> [1, height, width, depth]
        bx = np.expand_dims(input_image, axis=0)
        by = np.expand_dims(ground_one_hot, axis=0)

        # Prediction and metric from a single forward pass
        pred_out, DSC = sess.run([prediction, dice], feed_dict={X: bx, Y: by})
        dsc_vec_test.append(DSC)

        if i % 10 == 0 or i == batch_size:
            print(i)
            f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(9, 3))

            ax1.set_title("Slice")
            ax1.imshow(np.squeeze(bx[0, :, :, 0]), cmap='gray')

            ax2.set_title("Ground Truth")
            ax2.imshow(np.squeeze(by[0, :, :, 1]), cmap='gray')

            ax3.set_title("Prediction")
            # `pred_out` (previously misspelled `predout`, a NameError)
            ax3.imshow(np.squeeze(np.argmax(pred_out[0, :, :, :], axis=-1)), cmap='gray')

            f.subplots_adjust(hspace=5.0)
            plt.show()
            plt.close(f)
            print("Dice Similarity Coefficient = " + "{:.4f}".format(DSC))

        ct = ct + batch_size

    print("Test Cohort Dice Similarity Coefficient = " + "{:.4f}".format(np.mean(dsc_vec_test)) + " (+-" + "{:.4f}".format(np.std(dsc_vec_test)) + ")")


print("We out here")

In [None]:
# Inference Session
# Separate session for inferring the model on each subject
ct = 0  # total number of slices written
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, '/models/model_14175_Original')  # Import trained model
    writer.add_graph(sess.graph)

    # Get num subjects: one sub-directory per subject, named 1..num_subj
    if mode == 'j':
        directory = '/path/to/dir/Knee_MRI3/Femur/Inference/'
    else:
        directory = '/mydata/Femur/Inference/'
    num_subj = len(glob.glob(directory + "*/"))

    for kk in range(num_subj):
        subject_id = str(kk + 1)
        # get num slices for this subject
        subj_dir = directory + subject_id + '/'
        num_slices = len(glob.glob(subj_dir + "*.png"))
        print(str(num_slices) + ' slices for subject ' + subject_id)

        # Make directory for this subject: inference/subject/plane/Inf
        # (makedirs: creates parents and tolerates re-running the cell)
        pathout = '/Inference/' + subject_id + '/' + plane_view + '/Inf/'
        os.makedirs(pathout, exist_ok=True)

        for i in range(num_slices):
            fileID = str(i).zfill(6)

            # Set up input from this subject's directory (the slices counted above;
            # previously read from imdpath_test by mistake)
            input_image = imageio.imread(subj_dir + fileID + '.png')  # Import slice
            input_image = input_image - np.mean(input_image, axis=(0, 1))  # Mean-center the slice
            original_dims = input_image.shape
            input_image = fix_dims(input_image)

            # No data augmentation; [1, height, width, 1]
            bx = input_image[np.newaxis, :, :, np.newaxis]

            # Get prediction
            probability_map = sess.run(prediction, feed_dict={X: bx})

            # Keep the foreground (landmark) channel, crop away fix_dims padding and save
            # as 8-bit (probability * 255) instead of relying on imageio's lossy cast.
            foreground = probability_map[0, :original_dims[0], :original_dims[1], 1]
            imageio.imwrite(pathout + fileID + ".png", np.clip(np.rint(foreground * 255), 0, 255).astype(np.uint8))

            ct = ct + 1

print("We out here")