In [1]:
import numpy as np
import tensorflow as tf
from scipy import misc
import glob
import imageio
import matplotlib.pyplot as plt
import random
import csv
from tensorflow.contrib import rnn

#A DenseNet in which the last fully connected layer is replaced by a recurrent (LSTM) layer.
#Huang G, Liu Z, Weinberger KQ, van der Maaten L. Densely connected convolutional networks.
#In Proceedings of the IEEE conference on computer vision and pattern recognition 2017 Jul 1 (Vol. 1, No. 2, p. 3).

In [2]:
def random_files(num_files):
    """Return the IDs 0 .. num_files-1 in a random order, each zero-padded to 5 digits.

    The permutation is drawn with the module-level ``random`` generator, so seeding
    ``random`` beforehand makes the order reproducible.

    Example: random_files(3) might return ['00002', '00000', '00001'].
    """
    shuffled_ids = random.sample(range(num_files), num_files)
    return [str(file_id).zfill(5) for file_id in shuffled_ids]

In [3]:
def return_list(path_ground):
    """Read a per-slice class-list CSV into a one-hot label matrix.

    Only the first field of each row is inspected: a value of exactly '0' maps to
    [1, 0]; any other value maps to [0, 1]. Blank lines are skipped.

    Args:
        path_ground: path to the CSV file.

    Returns:
        float64 ndarray of shape (num_rows, 2).
    """
    labels = []
    # newline='' is the csv-module recommended way to open files for csv.reader;
    # the context manager closes the file even if parsing raises.
    with open(path_ground, 'r', newline='') as f:
        for row in csv.reader(f):
            if not row:
                continue  # blank line: previously crashed on row[0]
            labels.append((1, 0) if row[0] == '0' else (0, 1))
    # Shape is fixed by the 2-class one-hot encoding above rather than by the
    # notebook-global `num_classes` (defined in a later cell), so this works on a
    # fresh kernel. reshape keeps the (0, 2) shape for an empty file.
    return np.array(labels, dtype=float).reshape(-1, 2)
    
def return_images(path,picsize,inputdepth,files_per_slice=7):
    """Load the ``<id>D.png`` image of every slice in a subject directory.

    The slice count is derived from the total number of PNGs in ``path`` divided by
    ``files_per_slice`` (the directory is assumed to hold that many PNGs per slice;
    only the ``D`` file is read). Slice IDs are assumed to be contiguous and
    zero-padded to 5 digits, starting at 00000.

    Args:
        path: directory prefix, including the trailing separator.
        picsize: height/width each image is written into.
        inputdepth: accepted for API compatibility; the result always has 1 channel.
        files_per_slice: number of PNG files per slice in ``path`` (default 7).

    Returns:
        float64 ndarray of shape (num_slices, picsize, picsize, 1).
    """
    # Previously `inputdepth` was silently overwritten with 7 here and never used.
    num_slices = len(glob.glob(path + '*.png')) // files_per_slice

    images = np.empty((num_slices, picsize, picsize, 1))
    for idx in range(num_slices):
        images[idx, :, :, 0] = imageio.imread(path + str(idx).zfill(5) + 'D.png')  # Import slice

    return images
    


In [4]:
def random_translate(input_block,y_dir,x_dir):
    """Shift an image block by whole pixels, filling vacated pixels with zeros.

    Args:
        input_block: array whose first two axes are (rows, cols), e.g. (H, W, C).
        y_dir: vertical shift in pixels. Positive moves content up (toward row 0),
            negative moves it down.
        x_dir: horizontal shift in pixels. Positive moves content right, negative
            moves it left.

    Returns:
        A new array with exactly the shape of ``input_block`` (float64 for real
        input). If a shift's magnitude reaches the image size everything is shifted
        out and the result is all zeros; the previous concatenation-based version
        returned an array of the wrong shape in that case.
    """
    dy = int(y_dir)
    dx = int(x_dir)
    height, width = input_block.shape[0], input_block.shape[1]
    out = np.zeros(input_block.shape, dtype=np.result_type(input_block.dtype, np.float64))

    if abs(dy) >= height or abs(dx) >= width:
        return out

    # out[r, c] = input_block[r + dy, c - dx] wherever both indices are in range.
    src_rows = slice(dy, height) if dy >= 0 else slice(0, height + dy)
    dst_rows = slice(0, height - dy) if dy >= 0 else slice(-dy, height)
    src_cols = slice(0, width - dx) if dx >= 0 else slice(-dx, width)
    dst_cols = slice(dx, width) if dx >= 0 else slice(0, width + dx)
    out[dst_rows, dst_cols] = input_block[src_rows, src_cols]
    return out

In [5]:
#Get number of files in training set and also median frequency class weights for cross entropy loss
def get_num_files(path_ground):  
    s=''
    classtrue=[]
    f = open(path_ground, 'r')
    reader = csv.reader(f)
    for row in reader:
        if row[0]=='0':
            sublist=(1,0)
        else:
            sublist=(0,1)
    
        classtrue.append(sublist)
    f.close()
    ground_vec=classtrue

    return len(ground_vec), ground_vec

In [6]:
# Layer wrappers
def conv_layer(inputs, channels_in, channels_out,is_training,pad_val='SAME',stvs=0.01,filter_size=1,strides=1,scopename="conv"):
    """Square 2-D convolution plus bias, with no normalization and no activation.

    Creates a random-normal kernel (registered in REGULARIZATION_LOSSES) and a
    random-normal bias, both with stddev ``stvs``, and logs histograms of each.

    Args:
        inputs: input tensor (assumes NHWC, TensorFlow's default layout).
        channels_in / channels_out: kernel input / output channel counts.
        is_training: accepted for API compatibility; unused here.
        pad_val: conv2d padding ('SAME' or 'VALID').
        stvs: initializer stddev.
        filter_size: spatial kernel size.
        strides: spatial stride.
        scopename: name scope and prefix for the variable names.
    """
    kernel_shape = [filter_size, filter_size, channels_in, channels_out]
    stride_spec = [1, strides, strides, 1]
    with tf.name_scope(scopename):
        kernel = tf.Variable(tf.random_normal(kernel_shape, stddev=stvs), name=scopename + '_weights')
        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, kernel)
        bias = tf.Variable(tf.random_normal([channels_out], stddev=stvs), name=scopename + '_bias')
        tf.summary.histogram("Weight" + scopename, kernel)
        tf.summary.histogram("Bias" + scopename, bias)

        conv_out = tf.nn.conv2d(inputs, kernel, strides=stride_spec, padding=pad_val)
        return tf.nn.bias_add(conv_out, bias)
        
def BNR(inputs,is_training,scopename="BNR"):
    """Normalization followed by ReLU.

    Despite the name ("batch norm + ReLU"), the normalization used here is
    tf.contrib.layers.layer_norm. ``is_training`` is accepted for API compatibility
    but is not used.
    """
    with tf.name_scope(scopename):
        normalized = tf.contrib.layers.layer_norm(inputs)
        activated = tf.nn.relu(normalized)
    return activated
    
def shortcut_function(inputs,channels_in,num_filter_out,stvs=0.01,strides=1,scopename="shortcut"):
    """1x1 projection convolution with no bias, VALID padding and the given stride.

    The kernel is random-normal initialised with stddev ``stvs`` and named
    ``<scopename>_weights``. Not called from any of the visible cells.
    """
    kernel_shape = [1, 1, channels_in, num_filter_out]
    with tf.name_scope(scopename):
        projection = tf.Variable(tf.random_normal(kernel_shape, stddev=stvs), name=scopename + '_weights')
        return tf.nn.conv2d(inputs, projection, strides=[1, strides, strides, 1], padding='VALID')
    
def maxpool2d(x, k=2, stride=2,scopename="pool"):
    """k x k max pooling with the given stride and VALID padding (NHWC window spec)."""
    window = [1, k, k, 1]
    step = [1, stride, stride, 1]
    with tf.name_scope(scopename):
        return tf.nn.max_pool(x, ksize=window, strides=step, padding='VALID')

def Denseblock(inputs,num_dense_layers,k,channels_current, is_training,scopename='Denseblock'):
    """DenseNet block: each layer sees the channel-concatenation of all earlier outputs.

    The first dense layer is applied to ``inputs`` directly and is always built.
    Each later layer consumes the concatenation (axis 3) of the block input and every
    previous layer's output. Every layer adds ``k`` channels.

    Returns:
        (concatenation of the input and all layer outputs, resulting channel count)
    """
    with tf.name_scope(scopename):
        features = [inputs]
        # max(..., 1): the first layer is built even if num_dense_layers < 1.
        for layer_idx in range(max(num_dense_layers, 1)):
            layer_in = inputs if layer_idx == 0 else tf.concat(features, 3)
            new_maps = Denselayer(layer_in, channels_current, k, is_training, scopename=scopename)
            features.append(new_maps)
            channels_current += k

        return tf.concat(features, 3), channels_current
    
def Denselayer(inputs,channels_current,k, is_training,scopename='Denselayer'):
    """Bottleneck dense layer: BNR -> 1x1 conv to 4k channels -> BNR -> 3x3 conv to k channels."""
    bottleneck_width = 4 * k
    with tf.name_scope(scopename):
        h = BNR(inputs, is_training, scopename=scopename)
        h = conv_layer(h, channels_current, bottleneck_width, is_training, filter_size=1, scopename=scopename)
        h = BNR(h, is_training, scopename=scopename)
        return conv_layer(h, bottleneck_width, k, is_training, filter_size=3, scopename=scopename)
    
def Transitionlayer(inputs, theta, channels_current, is_training,scopename='Tranny'):
    """Transition layer: 1x1 conv compressing channels by ``theta``, then 2x2/stride-2 average pool.

    Returns:
        (pooled tensor, int(channels_current * theta))
    """
    channels_out = int(channels_current * theta)
    with tf.name_scope(scopename):
        compressed = conv_layer(inputs, channels_current, channels_out, is_training, filter_size=1, scopename=scopename)
        pooled = tf.nn.pool(compressed, [2,2], "AVG", "SAME", strides=[2,2])
    return pooled, channels_out
    
    
# Batch normalization wrapper to distinguish between training and testing
# REF: https://r2rt.com/implementing-batch-normalization-in-tensorflow.html
def batch_norm_wrapper(inputs, is_training, decay = 0.99):
    """Batch normalization with per-channel population statistics.

    Training (``is_training`` truthy): normalizes with the batch moments over axes
    [0, 1, 2] and, via control dependencies, updates the population mean/variance
    as exponential moving averages with factor ``decay``.
    Inference: normalizes with the stored population statistics.
    Not called from any of the visible cells (its use in conv_layer is commented out).
    """
    epsilon = 1e-3
    num_channels = inputs.get_shape()[-1]
    # Creation order is kept (scale, beta, mean, var) so auto-generated variable names are stable.
    scale = tf.Variable(tf.ones([num_channels]))
    beta = tf.Variable(tf.zeros([num_channels]))
    pop_mean = tf.Variable(tf.zeros([num_channels]), trainable=False)
    pop_var = tf.Variable(tf.ones([num_channels]), trainable=False)

    if not is_training:
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)

    batch_mean, batch_var = tf.nn.moments(inputs, [0, 1, 2])
    update_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
    update_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
    with tf.control_dependencies([update_mean, update_var]):
        return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
    
    

    

In [7]:
def densenet(x, inputdepth, picsize, is_training, densenodes, trunc_prop, num_classes=2):
    """DenseNet-style per-slice encoder followed by an LSTM across consecutive slices.

    The batch dimension of ``x`` is treated as time. A chunk of ``trunc_prop``
    slices is encoded by the convolutional trunk, flattened, and unrolled as
    ``trunc_prop`` time steps of a single sequence (LSTM batch size 1) through a
    BasicLSTMCell with ``densenodes`` units. Every LSTM output is projected to
    ``num_classes`` logits.

    Args:
        x: input images, reshaped to [-1, picsize, picsize, inputdepth].
        inputdepth: channels per input slice (previously overwritten with 1).
        picsize: square spatial size. Assumes it is divisible by 32 (five halvings
            before flattening) - TODO confirm for other sizes.
        is_training: forwarded to the layer helpers.
        densenodes: LSTM hidden size.
        trunc_prop: static-RNN unroll length (previously overwritten with 10). The
            fed batch must contain exactly this many slices.
        num_classes: number of output classes (default 2).

    Returns:
        (logits [trunc_prop, num_classes], final LSTMStateTuple,
         cell_state placeholder, hidden_state placeholder)
    """
    k = 12          # growth rate: channels added by each dense layer
    theta = 0.5     # compression factor of the transition layers
    stvs = 0.01     # stddev of the random-normal initializers
    BS = 1          # the whole chunk is ONE sequence, so the LSTM batch is always 1

    x = tf.reshape(x, shape=[-1, picsize, picsize, inputdepth])

    # Initial conv_block
    with tf.name_scope("Initial_Conv"):
        conv_initial = conv_layer(x, inputdepth, 2*k, is_training, filter_size=7, scopename="conv_initial")
        channels_current = 2*k

    # Initial max pooling: halves the spatial size (256 -> 128 for picsize=256)
    with tf.name_scope("Max_Pool1"):
        pool_initial = maxpool2d(conv_initial, scopename="pool_initial")

    # (outer name scope, dense-block name, dense layers, following transition or None).
    # NOTE: dense block 3 has always been built under a second "Tranny2" name scope
    # (TF uniquifies it to "Tranny2_1"). That is kept on purpose: name scopes prefix
    # tf.Variable names, so changing it would likely break restoring existing checkpoints.
    stages = (
        ("Denseblock1", "Denseblock1", 6, "Tranny1"),
        ("Denseblock2", "Denseblock2", 12, "Tranny2"),
        ("Tranny2", "Denseblock3", 24, "Tranny3"),
        ("Denseblock4", "Denseblock4", 16, None),
    )
    x = pool_initial
    for outer_scope, block_name, num_dense_layers, transition_name in stages:
        with tf.name_scope(outer_scope):
            print(block_name)
            x, channels_current = Denseblock(x, num_dense_layers, k, channels_current, is_training, scopename=block_name)
            print(channels_current)
        if transition_name is not None:
            # Each transition halves the spatial size and compresses channels by theta.
            with tf.name_scope(transition_name):
                print(transition_name)
                x, channels_current = Transitionlayer(x, theta, channels_current, is_training, scopename=transition_name)
                print(channels_current)

    # Average pool: fifth and last spatial halving
    with tf.name_scope("Average_Pool_final"):
        avg_pool = tf.nn.pool(x, [2,2], "AVG", "SAME", strides=[2,2])

    # Recurrent layer in place of the usual fully connected head
    with tf.name_scope("Dense"):
        # Five halvings so far: initial max pool, three transitions, final average pool.
        nodes_in = int(channels_current*(picsize/(2**5))*(picsize/(2**5)))
        print(nodes_in)
        flatten = tf.reshape(avg_pool, [-1, nodes_in])

        # Rearrange to [batch=1, time_steps, features]: the slices of the chunk become time steps.
        flat_extend = tf.expand_dims(flatten, axis=0)

        # static_rnn expects a Python list with one tensor per time step.
        flat_unstack = tf.unstack(flat_extend, trunc_prop, axis=1)

        # Initial (cell, hidden) state is fed by the caller so it can be carried across chunks.
        cell_state = tf.placeholder(tf.float32, [BS, densenodes])
        hidden_state = tf.placeholder(tf.float32, [BS, densenodes])
        init_state = rnn.LSTMStateTuple(cell_state, hidden_state)

        lstm_cell = rnn.BasicLSTMCell(densenodes, state_is_tuple=True)

        outputs, states = rnn.static_rnn(lstm_cell, flat_unstack, init_state, dtype=tf.float32)

    # Compress to num_classes for prediction (shared weights across time steps)
    with tf.name_scope("Out_Layer"):
        outvars = {'weights': tf.Variable(tf.random_normal([densenodes, num_classes], stddev=stvs)),
                   'biases': tf.Variable(tf.random_normal([num_classes], stddev=stvs))}

        tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, outvars['weights'])

        outlayer = [tf.matmul(output, outvars['weights']) + outvars['biases'] for output in outputs]

    return tf.concat(outlayer, 0), states, cell_state, hidden_state
    

In [8]:
def build_graph(trainmode, inputdepth, num_classes, picsize, densenodes, trunc_prop,
                lr=None, train_logdir="/output/12", val_logdir="/output/4"):
    """Build the complete train/eval graph around ``densenet``.

    Args:
        trainmode: forwarded to ``densenet`` as ``is_training``.
        inputdepth, num_classes, picsize: define the X / Y placeholder shapes.
        densenodes, trunc_prop: forwarded to ``densenet`` (LSTM size, unroll length).
        lr: Adam learning rate. If None, the notebook-level ``learning_rate``
            global is used (the previous, implicit behaviour).
        train_logdir, val_logdir: TensorBoard directories for ``writer`` / ``writer_val``.

    Returns:
        The 17-tuple (X, Y, logits, states, cell_state, hidden_state, prediction,
        loss_op, train_op, accuracy, batch_confusion, confusion_image, writer,
        writer_val, summ, init, saver), in that order.
    """
    if lr is None:
        # Backward compatibility: fall back to the config-cell constant.
        lr = learning_rate

    #Define placeholders
    with tf.name_scope("Input"):
        X = tf.placeholder(tf.float32, [None, picsize, picsize, inputdepth])
    with tf.name_scope("Ground_Truth"):
        Y = tf.placeholder(tf.float32, [None, num_classes])

    #Define flow graph
    logits, states, cell_state, hidden_state = densenet(X, inputdepth, picsize, trainmode, densenodes, trunc_prop)

    #Prediction function for evaluating accuracy
    with tf.name_scope("Softmax"):
        prediction = tf.nn.softmax(logits)

    #Cross-entropy plus L2 penalty on every tensor registered in REGULARIZATION_LOSSES
    with tf.name_scope("Loss"):
        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))

        regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)
        reg_variables = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        reg_term = tf.contrib.layers.apply_regularization(regularizer, reg_variables)
        loss_op += reg_term

        tf.summary.scalar("Loss", loss_op)

    #Define optimizer
    with tf.name_scope("Optimizer"):
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        train_op = optimizer.minimize(loss_op)

    # Evaluate model
    with tf.name_scope("Accuracy"):
        correct_pred = tf.equal(tf.argmax(prediction, axis=1), tf.argmax(Y, axis=1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        tf.summary.scalar("accuracy", accuracy)

    #Confusion matrix, also logged as a tiny image for TensorBoard
    with tf.name_scope("Confusion_Matrix"):
        batch_confusion = tf.confusion_matrix(tf.reshape(tf.argmax(Y, axis=1), [-1]), tf.reshape(tf.argmax(prediction, axis=1), [-1]), num_classes=num_classes, name='batch_confusion')
        confusion_image = tf.reshape(tf.cast(batch_confusion, tf.float32), [1, num_classes, num_classes, 1])
        tf.summary.image('confusion', confusion_image)

    #Define writers for Tensorboard
    writer = tf.summary.FileWriter(train_logdir)
    writer_val = tf.summary.FileWriter(val_logdir)
    summ = tf.summary.merge_all()

    # Initialize the variables
    init = tf.global_variables_initializer()

    #Define saver for model saver
    saver = tf.train.Saver()

    return X, Y, logits, states, cell_state, hidden_state, prediction, loss_op, train_op, accuracy, batch_confusion, confusion_image, writer, writer_val, summ, init, saver

In [9]:
# Training parameters
learning_rate = 1e-5          # Adam step size (read by build_graph)
batch_size = 1
BS = batch_size               # batch dimension of the LSTM state arrays fed per chunk
display_step = 10 * batch_size  # not referenced in the visible cells
hm_epochs = 30

# Model / data geometry
picsize = 256                 # slices are picsize x picsize
inputdepth = 1                # channels per slice
num_classes = 2
densenodes = 1000             # LSTM hidden units
trunc_prop = 10               # slices per truncated-BPTT chunk (static RNN unroll length)

In [10]:
train_mode = True
tf.reset_default_graph()
(X, Y, logits, states, cell_state, hidden_state, prediction, loss_op, train_op,
 accuracy, batch_confusion, confusion_image, writer, writer_val, summ, init,
 saver) = build_graph(train_mode, inputdepth, num_classes, picsize, densenodes, trunc_prop)

print('Graph built')

# Number of subjects per split; each subject's slices live in a numbered subdirectory.
num_subj_train = 17
num_subj_val = 2
num_subj_test = 2

# Data layout. mode == 'j' selects the alternate machine where label CSVs sit under a
# different root than the images; any other value uses /mydata/Femur/ for both.
# The pxdpath* whole-split label files are not read by the visible cells (they use
# per-subject CSVs under `root`).
mode = ''
if mode == 'j':
    image_root = "/path/to/dir/Knee_MRI2/Femur/"
    label_root = "/path/to/label/csv/Femur/"
else:
    image_root = "/mydata/Femur/"
    label_root = image_root

imdpath = image_root + "Train/ImdTrain/"
pxdpath = label_root + "Train/fem_mri_classlist_Train.csv"

imdpath_val = image_root + "Val/ImdVal/"
pxdpath_val = label_root + "Val/fem_mri_classlist_Val.csv"

imdpath_test = image_root + "Test/ImdTest/"
pxdpath_test = label_root + "Test/fem_mri_classlist_Test.csv"

root = image_root


Denseblock1
96
Tranny1
48
Denseblock2
192
Tranny2
96
Denseblock3
384
Tranny3
192
Denseblock4
384
24576
Graph built


In [None]:
# Training loop. Each epoch walks the training subjects one at a time; a subject's slices
# are fed to the recurrent DenseNet in chunks of `trunc_prop`, carrying the LSTM
# (cell, hidden) state from chunk to chunk within a subject and resetting it per subject.
# `ct` counts slices processed and is used as the TensorBoard step and checkpoint suffix.
ct=0
with tf.Session() as sess:
    # Variables are freshly initialised every time this cell runs.
    sess.run(init)
    writer.add_graph(sess.graph)
    #saver.restore(sess, '/models/model_30112') #Option to restore a model to resume training
    
    for epoch in range(hm_epochs):
        print("EPOCH " + "{:d}".format(epoch))
        
        for subj in range(num_subj_train):
            # Per-subject traces for the prediction plot at the end of the subject.
            prediction_vec=[]; logits_vec=[]; ground_vec=[];
            print("Subject " + str(subj))
        
            # One-hot labels and slices for this subject; slices are mean-centred per subject.
            byy = return_list(root + 'Train/fem_mri_classlist_Train_subj' + str(subj) + '.csv')
            images = return_images((imdpath+ str(subj) + '/'),picsize,inputdepth)
            images=images-np.mean(images)
            
            # Augmentation: one random integer shift (std 15 px per axis) shared by every
            # slice of the subject. NOTE(review): x_t is passed as random_translate's y_dir
            # and y_t as its x_dir; harmless only because both use the same distribution.
            x_stvs=15; y_stvs=15; 
            x_t=int(x_stvs*np.random.randn()); y_t=int(y_stvs*np.random.randn());
            bxx=np.empty((images.shape[0],picsize,picsize,1))
            
            for j in range(images.shape[0]):
                bxx[j,:,:,:]=random_translate(images[j,:,:,:],x_t,y_t)
                
            #Initialize cell and hidden for this subject
            current_cell = np.zeros((BS, densenodes))
            current_hidden = np.zeros((BS, densenodes))

            i=0
            # Because we are using a static_rnn, iterate through the subject "trunc_prop" slices at a time
            while i+trunc_prop<=images.shape[0]:
                
                # NOTE(review): these two np.empty allocations are immediately overwritten below.
                bx=np.empty((trunc_prop,picsize,picsize,inputdepth))
                by=np.empty((trunc_prop,num_classes))
                
                bx=bxx[i:i+trunc_prop,:,:,:]
                by=byy[i:i+trunc_prop,:]
                
                # NOTE(review): three forward passes per chunk (train step, metrics, state).
                # Metrics and the carried-over state are computed with the weights *after*
                # this update, not the weights used for the gradient step.
                sess.run(train_op,feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden})
                loss, acc, summary, pred=sess.run([loss_op,accuracy,summ,prediction],feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden})
                
                for kk in range(pred.shape[0]):
                    prediction_vec.append(np.argmax(pred[kk,:]))
                    ground_vec.append(by[kk,1])
                    logits_vec.append(pred[kk,1])
                
                #After training iteration, retrieve cell and hidden states for next iteration in current subject
                current_total=sess.run(states,{X: bx, Y: by, cell_state: current_cell , hidden_state: current_hidden})
                
                #Unpack the state
                current_cell, current_hidden = current_total
                
                writer.add_summary(summary, ct)
                
                i+=trunc_prop
                ct+=trunc_prop
            
            # For the remaining slices in the subject, we need it to have the same amount of slices as the other truncated sequence lengths
            # Pad on the final slice until we reach the correct sequence length
            # NOTE(review): the padded duplicate rows are trained on and are also appended to
            # the plotted traces. `i` is advanced in three places below but is not read afterwards.
            if images.shape[0]%trunc_prop!=0:
            
                trunc_prop_temp=images.shape[0]-i

                bx=np.empty((trunc_prop,picsize,picsize,inputdepth))
                by=np.empty((trunc_prop,num_classes))

                bx[:trunc_prop_temp,:,:,:]=bxx[i:i+trunc_prop_temp,:,:,:]
                by[:trunc_prop_temp,:]=byy[i:i+trunc_prop_temp,:]
                i+=trunc_prop_temp

                temp_ct=0;
                while trunc_prop_temp+temp_ct<trunc_prop:
                    bx[trunc_prop_temp+temp_ct,:,:,:]=bx[trunc_prop_temp-1,:,:,:]
                    by[trunc_prop_temp+temp_ct,:]=by[trunc_prop_temp-1,:]
                    i+=1; temp_ct+=1


                sess.run(train_op,feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden})
                loss, acc, summary, pred=sess.run([loss_op,accuracy,summ,prediction],feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden})

                for kk in range(pred.shape[0]):
                    prediction_vec.append(np.argmax(pred[kk,:]))
                    logits_vec.append(pred[kk,1])
                    ground_vec.append(by[kk,1])
                    
                # After training iteration, retrieve cell and hidden states for next iteration in current subject
                current_total=sess.run(states,{X: bx, Y: by, cell_state: current_cell , hidden_state: current_hidden})

                # Unpack the state
                current_cell, current_hidden = current_total

                writer.add_summary(summary, ct)

                i+=trunc_prop_temp
                ct+=trunc_prop_temp
                
            # Red: predicted class, green: softmax probability of class 1, dashed black: ground truth.
            plt.plot(prediction_vec,'r',logits_vec,'g',ground_vec,'k--')
            plt.ylabel('Subject Predictions')
            plt.show()

            
            
            # Checkpoint when subj is a positive multiple of 10 (with 17 training subjects: only subj 10).
            # NOTE(review): nothing is saved after the final epoch, so the last subjects of the
            # last epoch are never in a checkpoint.
            if subj>0 and subj%10==0:# Save when subj is a positive multiple of 10
                
                s=''
                checkpointnamelist=('./model_',str(ct))
                checkpointname= s.join(checkpointnamelist)
                save_path = saver.save(sess,checkpointname)
                
                #print("Model saved in file: %s" % save_path)
                #print(" Step " + str(i) + ", Minibatch Loss= " + "{:.4f}".format(loss) + ", Training Accuracy = " + "{:.3f}".format(acc)) 
                #print(sess.run(batch_confusion,feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden}))
                
                #Show example image from batch with ground truth
                #temp=np.squeeze(bx[0,:,:,:])
                #plt.imshow(temp,cmap='gray')
                #plt.show()
                #print(by[0,:])
    
            
            #VALIDATION (same trigger as the checkpoint above). No weights are updated here.
            # NOTE(review): validation summaries are never written to writer_val, and the
            # tail-chunk accuracy is computed over all trunc_prop rows (including padded
            # duplicates) but weighted by the number of real rows.
            if subj>0 and subj%10==0:    
            
                acc_vec_val=[]; ct_val=0;
                seq_vec=[];
                for j in range(num_subj_val): # Import data for that mini-batch
                    print("Val Subject " + str(j))
                    prediction_vec=[]; logits_vec=[]; ground_vec=[];
                    
                    byy = return_list(root + 'Val/fem_mri_classlist_Val_subj' + str(j) + '.csv')
                    images = return_images((imdpath_val+ str(j) + '/'),picsize,inputdepth)
                    images=images-np.mean(images)
                    bxx=images

                    #Initialize cell and hidden for this subject
                    current_cell = np.zeros((BS, densenodes))
                    current_hidden = np.zeros((BS, densenodes))

                    ii=0
                    #Because we are using a static_rnn, iterate through the subject trunc_prop slices at a time
                    while ii+trunc_prop<=images.shape[0]:
                
                        bx=np.empty((trunc_prop,picsize,picsize,inputdepth))
                        by=np.empty((trunc_prop,num_classes))
                
                        bx=bxx[ii:ii+trunc_prop,:,:,:]
                        by=byy[ii:ii+trunc_prop,:]
                
                        loss, acc, summary, pred=sess.run([loss_op,accuracy,summ, prediction],feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden})
                
                        for kk in range(pred.shape[0]):
                            prediction_vec.append(np.argmax(pred[kk,:]))
                            logits_vec.append(pred[kk,1])
                            ground_vec.append(by[kk,1])
                
                        #Retrieve cell and hidden states for the next chunk (same inputs and weights as the run above, so this repeats the forward pass)
                        current_total=sess.run(states,{X: bx, Y: by, cell_state: current_cell , hidden_state: current_hidden})
                
                        #Unpack the state
                        current_cell, current_hidden = current_total
                
                        ii+=trunc_prop
                        
                        acc_vec_val.append(acc)
                        seq_vec.append(trunc_prop)
                    
                    if images.shape[0]%trunc_prop!=0:
                        #Run remaining slices through an iteration, padded with the last slice
                        trunc_prop_temp=images.shape[0]-ii

                        bx=np.empty((trunc_prop,picsize,picsize,inputdepth))
                        by=np.empty((trunc_prop,num_classes))

                        bx[:trunc_prop_temp,:,:,:]=bxx[ii:ii+trunc_prop_temp,:,:,:]
                        by[:trunc_prop_temp,:]=byy[ii:ii+trunc_prop_temp,:]
                        ii+=trunc_prop_temp

                        temp_ct=0;
                        while trunc_prop_temp+temp_ct<trunc_prop:
                            bx[trunc_prop_temp+temp_ct,:,:,:]=bx[trunc_prop_temp-1,:,:,:]
                            by[trunc_prop_temp+temp_ct,:]=by[trunc_prop_temp-1,:]
                            ii+=1; temp_ct+=1

                        loss, acc, summary, pred=sess.run([loss_op,accuracy,summ, prediction],feed_dict={X: bx, Y: by, cell_state: current_cell ,hidden_state: current_hidden})

                        for kk in range(pred.shape[0]):
                            prediction_vec.append(np.argmax(pred[kk,:]))
                            logits_vec.append(pred[kk,1])
                            ground_vec.append(by[kk,1])
                        
                        #Retrieve cell and hidden states for the next chunk in the current subject
                        current_total=sess.run(states,{X: bx, Y: by, cell_state: current_cell , hidden_state: current_hidden})

                        #Unpack the state
                        current_cell, current_hidden = current_total

                        acc_vec_val.append(acc)
                        seq_vec.append(trunc_prop_temp)

                        ii+=trunc_prop_temp
                        
                    plt.plot(prediction_vec,'r',logits_vec,'g',ground_vec,'k--')
                    plt.ylabel('Subject Predictions')
                    plt.show()
                    
                # Per-chunk accuracies weighted by the number of real slices in each chunk.
                weighted_avg_vec=[];
                for k in range(len(acc_vec_val)):
                    weighted_avg_vec.append(acc_vec_val[k]*seq_vec[k]/sum(seq_vec))
                    
                print("Validation accuracy: " + str(sum(weighted_avg_vec)))
        
    print("We out here")

In [None]:
#Evaluate a saved checkpoint on the test split.
# Each test subject is streamed through the network `trunc_prop` slices at a time,
# carrying the LSTM (cell, hidden) state between chunks. A final partial chunk is padded
# by repeating its last slice so it matches the static-RNN length; the padded rows are
# dropped before plotting and before computing accuracy.
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, '/models/model_38292')

    for j in range(num_subj_test):
        print("Test Subject " + str(j))

        byy = return_list(root + 'Test/fem_mri_classlist_Test_subj' + str(j) + '.csv')
        # Test images come from imdpath_test; this cell previously read imdpath_val,
        # pairing validation images with test labels.
        images = return_images((imdpath_test + str(j) + '/'), picsize, inputdepth)
        images = images - np.mean(images)
        num_slices = images.shape[0]

        prediction_vec = []; logits_vec = []; ground_vec = []
        num_correct = 0

        # Fresh LSTM state for every subject.
        current_cell = np.zeros((BS, densenodes))
        current_hidden = np.zeros((BS, densenodes))

        for start in range(0, num_slices, trunc_prop):
            n_real = min(trunc_prop, num_slices - start)
            bx = images[start:start + n_real]
            by = byy[start:start + n_real]
            if n_real < trunc_prop:
                # The static RNN needs exactly trunc_prop rows: repeat the last real slice/label.
                pad = trunc_prop - n_real
                bx = np.concatenate([bx, np.repeat(bx[-1:], pad, axis=0)], axis=0)
                by = np.concatenate([by, np.repeat(by[-1:], pad, axis=0)], axis=0)

            # Inference only, so predictions and the carried state come from one forward pass.
            pred, (current_cell, current_hidden) = sess.run(
                [prediction, states],
                feed_dict={X: bx, Y: by, cell_state: current_cell, hidden_state: current_hidden})

            # Keep only the real (un-padded) rows for plots and accuracy.
            pred = pred[:n_real]
            by_real = by[:n_real]
            pred_class = np.argmax(pred, axis=1)
            prediction_vec.extend(pred_class.tolist())
            logits_vec.extend(pred[:, 1].tolist())
            ground_vec.extend(by_real[:, 1].tolist())
            num_correct += int(np.sum(pred_class == np.argmax(by_real, axis=1)))

        # Red: predicted class, green: softmax probability of class 1, dashed black: ground truth.
        plt.plot(prediction_vec,'r',logits_vec,'g',ground_vec,'k--')
        plt.ylabel('Subject Predictions')
        plt.show()

        # Fraction of real slices classified correctly (0 for a subject with no slices, as before).
        test_acc = num_correct / num_slices if num_slices else 0.0
        print("Test accuracy: " + str(test_acc))

