In [None]:
# Required modules
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import scipy.ndimage
import tensorflow as tf

%matplotlib notebook

### Parameters

In [None]:
# Volume geometry and class count. These values are the defaults of the
# pre-processing helpers and of the UNetwork constructor defined further down.
INPUT_SIZE = 64 # Input feature width/height
INPUT_DEPTH = 64 # Input depth
INPUT_CHANNEL = 1 # Channels per input voxel
OUTPUT_SIZE = 64 # Output feature width/height
OUTPUT_DEPTH = 64 # Output depth
OUTPUT_CHANNEL = 1 # Channels per output voxel
OUTPUT_CLASSES = 4 # Number of output classes in dataset

### Pre-processing

In [None]:
OFF_IMAGE_FILL = 0 # Value used to pad an image volume along its last axis up to the requested depth
OFF_LABEL_FILL = 0 # Value used to pad a label volume along its last axis up to the requested depth

# Ratio between the output and the input in-plane size
io_zoom = OUTPUT_SIZE/INPUT_SIZE
# All-zero volume with the input dimensions (not referenced by any other cell visible in this notebook)
zero_chk = np.zeros((INPUT_SIZE, INPUT_SIZE, INPUT_DEPTH))

def get_scaled_input(data, min_i = INPUT_SIZE, min_o = OUTPUT_SIZE, depth = INPUT_DEPTH, 
                    depth_out = OUTPUT_DEPTH, image_fill = OFF_IMAGE_FILL, 
                    label_fill = OFF_LABEL_FILL, n_classes = OUTPUT_CLASSES, norm_max = 500):
    """Resize, pad and normalise one raw (image, label) pair for the network.

    Args:
        data: Pair ``(image, label)`` of 3-D arrays. Both scale factors are derived
            from ``image.shape[0]``.
        min_i: Target size of the image along its first axis after zooming.
        min_o: Target size of the label along its first axis after zooming.
        depth: Size the image's last axis is padded up to.
        depth_out: Size the label's last axis is padded up to.
        image_fill: Value used for the image padding.
        label_fill: Value used for the label padding.
        n_classes: Unused; kept for interface compatibility.
        norm_max: Value the image maximum is rescaled to (skipped for an all-zero image).

    Returns:
        ``(vox_zoom, lbl_zoom)``: the processed image and label, each with a
        leading singleton axis added and then swapped with the last axis, i.e.
        shaped ``(depth, size, size, 1)``.

    Note:
        A volume whose zoomed last axis is already longer than ``depth`` /
        ``depth_out`` is not cropped; the padding step then fails.
    """
    image, label = data[0], data[1]
    input_scale_factor = min_i/image.shape[0]
    output_scale_factor = min_o/image.shape[0]

    if input_scale_factor != 1:
        # Order 1 is linear interpolation - fast and good enough for intensities
        vox_zoom = scipy.ndimage.zoom(image, input_scale_factor, order = 1)
    else:
        vox_zoom = image

    if output_scale_factor != 1:
        # Order 0 is nearest neighbour: VERY IMPORTANT as it keeps the labels discrete
        lbl_zoom = scipy.ndimage.zoom(label, output_scale_factor, order = 0)
    else:
        lbl_zoom = label

    # Pad the label along its last axis up to depth_out
    lbl_pad = label_fill*np.ones((min_o, min_o, depth_out - lbl_zoom.shape[-1]))
    lbl_zoom = np.concatenate((lbl_zoom, lbl_pad), 2)
    lbl_zoom = lbl_zoom[np.newaxis, :, :, :]

    # Pad the image along its last axis up to depth
    vox_pad = image_fill*np.ones((min_i, min_i, depth - vox_zoom.shape[-1]))
    vox_zoom = np.concatenate((vox_zoom, vox_pad), 2)

    # Rescale intensities so that the maximum becomes norm_max
    max_val = np.max(vox_zoom)
    if max_val != 0:
        vox_zoom = vox_zoom * norm_max/max_val

    vox_zoom = vox_zoom[np.newaxis, :, :, :]

    # Swap the new leading axis with the last one: (1, s, s, d) -> (d, s, s, 1)
    vox_zoom = np.swapaxes(vox_zoom, 0, -1)
    lbl_zoom = np.swapaxes(lbl_zoom, 0, -1)

    return vox_zoom, lbl_zoom

def upscale_segmentation(lbl, shape_desired):
    """Resize a 3-D label volume to ``shape_desired`` (needed for the IOU metrics).

    Every axis is zoomed by the same factor, derived from the first axis only.
    The last axis is then cropped to ``shape_desired[-1]`` or padded with
    ``OFF_LABEL_FILL`` when it ends up too short.

    Args:
        lbl: 3-D array of discrete class ids.
        shape_desired: Target shape (a 3-tuple).

    Returns:
        The resized label volume, with the same dtype as ``lbl``.
    """
    scale_factor = shape_desired[0]/lbl.shape[0]
    # Order 0 (nearest neighbour) is essential here: labels must stay discrete
    lbl_upscale = scipy.ndimage.zoom(lbl, scale_factor, order = 0)
    lbl_upscale = lbl_upscale[:, :, :shape_desired[-1]]
    if lbl_upscale.shape[-1] < shape_desired[-1]:
        pad_shape = (shape_desired[0], shape_desired[1], shape_desired[2] - lbl_upscale.shape[-1])
        # np.full actually honours the fill value (multiplying zeros by it never did)
        # and keeps the label dtype instead of promoting the volume to float
        pad = np.full(pad_shape, OFF_LABEL_FILL, dtype = lbl_upscale.dtype)
        lbl_upscale = np.concatenate((lbl_upscale, pad), axis = -1)
    return lbl_upscale

def get_label_accuracy(pred, lbl_original):
    """Return the percentage of voxels whose predicted label equals the target (DEMO metric).

    The prediction is first passed through swap_axes and then resized to the
    shape of ``lbl_original`` before the voxel-wise comparison.
    """
    restored = swap_axes(pred)
    resized = upscale_segmentation(restored, np.shape(lbl_original))
    n_matching = np.sum(np.equal(resized, lbl_original))
    n_voxels = np.prod(lbl_original.shape)
    return 100*n_matching/n_voxels

def get_mean_iou(pred, lbl_original, num_classes = OUTPUT_CLASSES, ret_full = False, reswap = False):
    """Intersection-over-union between a prediction and its target labels.

    The prediction is always passed through swap_axes and resized to the label
    shape, so callers must make sure the label is in the matching orientation.

    Args:
        pred: Predicted class ids for one volume.
        lbl_original: Target class ids for the same volume.
        num_classes: Class ids ``0 .. num_classes - 1`` are evaluated.
        ret_full: Return the per-class list instead of its mean.
        reswap: Also pass ``lbl_original`` through swap_axes. Use this when the
            label is in the same layout as ``pred`` (e.g. the array that was
            fed to the network).

    Returns:
        A list with one IOU per class when ``ret_full`` is true, otherwise the
        mean over all classes. A class that is absent from both the prediction
        and the label scores 1.

    Raises:
        ValueError: If the resized prediction and the label differ in shape.
    """
    pred = swap_axes(pred)
    if reswap:
        lbl_original = swap_axes(lbl_original)
    lbl_original = np.asarray(lbl_original)
    pred_upscale = upscale_segmentation(pred, np.shape(lbl_original))
    if pred_upscale.shape != lbl_original.shape:
        raise ValueError('Prediction shape {} does not match label shape {}'.format(
            pred_upscale.shape, lbl_original.shape))

    iou = [1]*num_classes
    for i in range(num_classes):
        # Boolean masks avoid allocating two float volumes per class
        pred_mask = pred_upscale == i
        lbl_mask = lbl_original == i
        union = int(np.count_nonzero(pred_mask | lbl_mask))
        if union != 0:
            intersection = int(np.count_nonzero(pred_mask & lbl_mask))
            iou[i] = intersection/union
    if ret_full:
        return iou
    return np.mean(iou)
    
def swap_axes(pred):
    """Exchange the first and last axes of ``pred`` and drop all singleton axes."""
    return np.squeeze(np.swapaxes(pred, -1, 0))

### Load dataset

In [None]:
loc = './dataset' # Directory that contains dataset.mat (read by the next cell)

In [None]:
mat_contents = sio.loadmat(os.path.join(loc,'dataset.mat'))
# loadmat adds '__header__', '__version__' and '__globals__' bookkeeping entries.
# Select the first real variable by name instead of relying on its position in the
# sorted key list, which breaks as soon as the variable name sorts before '__'
# (e.g. any name starting with an upper-case letter).
variable_names = sorted(key for key in mat_contents if not key.startswith('__'))
if not variable_names:
    raise KeyError('dataset.mat does not contain any variable')
# shape: batch x 2 x h x w x d x ch (index 0 of the second axis is the volume, index 1 its ground truth)
data = mat_contents[variable_names[0]]
del mat_contents

print('data shape: ', data.shape)

# The first 2004 samples are used for training, the remainder for testing
train = data[0:2004] 
test = data[2004:] 


print('train shape: ', train.shape)
print('test shape: ', test.shape)

del data

'''
data shape:  (2338, 2, 64, 64, 64, 1)
train shape:  (2004, 2, 64, 64, 64, 1)
test shape:  (334, 2, 64, 64, 64, 1)
'''

In [None]:
# Extract train raw and label
train_raw = train[:,0] #dtype: float32
train_label = train[:,1] #dtype: int16
train_label = train_label.astype('int8') #dtype: int8

# Extract test raw and label
test_raw = test[:,0] #dtype: float32
test_label = test[:,1] #dtype: int16
test_label = test_label.astype('int8') #dtype: int8

del train, test

print('Train shape: ', train_raw.shape)
print('Test shape: ', test_raw.shape)

'''
Train shape:  (2004, 64, 64, 64, 1)
Test shape:  (334, 64, 64, 64, 1)
'''

### Normalization

In [None]:
def normalize(data_set, min_int, max_int):
    """Cast ``data_set`` to float32 and map [min_int, max_int] linearly onto [0, 1]."""
    as_float = data_set.astype('float32')
    value_range = max_int - min_int
    return (as_float - min_int)/value_range


# Intensity bounds used for scaling (stated by the original author to be the dataset min and max)
min_int = -1000
max_int = 3095

# NOTE: this overwrites train_raw/test_raw, so running the cell twice normalises twice
train_raw, test_raw = (normalize(volumes, min_int, max_int) for volumes in (train_raw, test_raw))

## Hyperparameters

In [None]:
LEARNING_RATE = 0.001 # Model learning rate
NUM_STEPS = 100000 # Number of train steps per model train
BATCH_SIZE = 3 # Number of volumes per training/evaluation batch
SAVE_PATH = './tf' # Directory for model checkpoints
LOGS_PATH = './tf_logs' # Directory for the TensorFlow graph/summary logs
LOAD_MODEL = True # Try to resume from a previously saved model before training
# exist_ok makes the cell safely re-runnable without a separate existence check
os.makedirs(SAVE_PATH, exist_ok=True)
os.makedirs(LOGS_PATH, exist_ok=True)
MODEL_NAME = 'model' # Model name to LOAD FROM (LOOKS IN SAVE_PATH DIRECTORY)

### Attention UNET Model

In [None]:
from tensorflow.compat.v1.keras import backend as K
from tensorflow.keras.layers import UpSampling3D
from tensorflow.keras import models, layers, regularizers

class UNetwork():
    """3D U-Net with attention-gated skip connections, built as a TF1-style static graph.

    Creating an instance adds the following to the current default graph, inside
    the '3DuNet' variable scope:

    * ``training``     - bool placeholder forwarded to batch-norm and dropout
    * ``model_input``  - float32 placeholder, shape (None, in_depth, in_size, in_size, 1)
    * ``model_labels`` - int32 placeholder with the same shape
    * ``predictions``  - per-voxel argmax over the class logits, with a trailing axis of 1
    * ``loss``         - mean class-weighted softmax cross entropy
    * ``train_op``     - Adam step that also runs the collected batch-norm update ops

    The methods use the TF1 API (``tf.layers``, ``tf.placeholder``, ...), so the
    module-level name ``tf`` must refer to ``tensorflow.compat.v1`` with v2 behaviour
    disabled by the time the class is instantiated.
    """
    
    def conv_batch_relu(self, tensor, filters, kernel = [3,3,3], stride = [1,1,1], is_training = True):
        """Apply Conv3D -> batch normalisation -> ReLU and return the result.

        Padding is 'same' when ``self.should_pad`` is set, otherwise 'valid'.
        ``is_training`` is forwarded to the batch-normalisation layer.
        """
        padding = 'valid'
        if self.should_pad: padding = 'same'
    
        conv = tf.layers.conv3d(tensor, filters, kernel_size = kernel, strides = stride, padding = padding,
                                kernel_initializer = self.base_init, kernel_regularizer = self.reg_init)
        conv = tf.layers.batch_normalization(conv, training = is_training)
        conv = tf.nn.relu(conv) 
        return conv

    def upconvolve(self, tensor, filters, kernel = 2, stride = 2, scale = 4, activation = None):
        """Upsample ``tensor`` with a bias-free transposed 3D convolution.

        ``scale`` and ``activation`` are accepted but not used.
        """
        padding = 'valid'
        if self.should_pad: padding = 'same'

        conv = tf.layers.conv3d_transpose(tensor, filters, kernel_size = kernel, strides = stride, padding = padding, use_bias=False, 
                                          kernel_initializer = self.base_init,  kernel_regularizer = self.reg_init)
        return conv
    
    # ************************************** START: Attention mechanism ************************************** #
    def repeat_elem(self, tensor, rep):
        """Repeat ``tensor`` ``rep`` times along axis 4 (the channel axis of a 5-D tensor)."""
        return layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=4),
                             arguments={'repnum': rep})(tensor) # pay attention to axis
        
    def gating_signal(self, input, out_size, batch_norm=False):
        """Build the gating signal: 1x1x1 convolution to ``out_size`` channels, optional batch-norm, ReLU."""
        x = tf.layers.conv3d(inputs=input,filters=out_size, kernel_size=(1, 1, 1), strides=(1, 1, 1), padding='same')
        if batch_norm:
            # NOTE(review): no `training` argument is passed here, unlike in
            # conv_batch_relu, so this layer ignores the `training` placeholder.
            # Left unchanged because altering it changes how existing checkpoints behave.
            x = tf.layers.batch_normalization(x)
        x = tf.nn.relu(x)
        return x

    def attention_block(self, x, gating, inter_shape):
        """Gate the skip connection ``x`` with an attention map derived from ``gating``.

        ``x`` is reduced with a stride-2 convolution, summed with the projected
        gating signal, squashed to a single-channel sigmoid map, upsampled back
        to the spatial size of ``x`` and multiplied with ``x``. The product goes
        through a 1x1x1 convolution and batch normalisation.
        """
        shape_x = K.int_shape(x)
        shape_g = K.int_shape(gating)

        # Getting the x signal to the same shape as the gating signal
        theta_x = tf.layers.conv3d(x, inter_shape, (2, 2, 2), strides=(2, 2, 2), padding='same')  # 16
        shape_theta_x = K.int_shape(theta_x)

        # Getting the gating signal to the same number of filters as the inter_shape
        phi_g = tf.layers.conv3d(gating, inter_shape, (1, 1, 1), padding='same')

        # Bring the gating signal to the spatial size of theta_x
        upsample_g = tf.layers.conv3d_transpose(phi_g, inter_shape, (3, 3, 3),
                                     strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2], shape_theta_x[3] // shape_g[3]),
                                     padding='same')  # 16

        concat_xg = tf.keras.layers.add([upsample_g, theta_x]) 
        act_xg = tf.nn.relu(concat_xg)
        psi = tf.layers.conv3d(act_xg, 1, (1, 1, 1), padding='same')
        sigmoid_xg = tf.nn.sigmoid(psi)
        shape_sigmoid = K.int_shape(sigmoid_xg)
        # Upsample the attention map back to the spatial size of x
        upsample_psi = UpSampling3D(size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2], shape_x[3] // shape_sigmoid[3]))(sigmoid_xg)  # 32
                                    
        # Copy the single-channel map across every channel of x
        upsample_psi = self.repeat_elem(upsample_psi, shape_x[-1])  

        y = tf.keras.layers.multiply([upsample_psi, x])

        # NOTE(review): for the 5-D tensors used here shape_x[3] is a spatial size,
        # while the channel count is shape_x[4] (the value repeat_elem receives above),
        # so the filter count is tied to a spatial dimension. This is presumably a
        # leftover from a 2-D implementation - confirm. It is left unchanged because
        # fixing it alters the variable shapes and would make existing checkpoints unloadable.
        result = tf.layers.conv3d(y, shape_x[3], (1, 1, 1), padding='same')
        # NOTE(review): batch-norm without a `training` argument, as in gating_signal
        result_bn = tf.layers.batch_normalization(result)
        return result_bn
    
    # ************************************** END: Attention mechanism ************************************** #


    def centre_crop_and_concat(self, prev_conv, up_conv):
        """Centre-crop ``prev_conv`` to the spatial size of ``up_conv`` and concatenate on the channel axis."""
        # If concatenating two different sized Tensors, centre crop the first Tensor to the right size and concat
        # Needed if you don't have padding
        p_c_s = prev_conv.get_shape()
        u_c_s = up_conv.get_shape()
        offsets =  np.array([0, (p_c_s[1] - u_c_s[1]) // 2, (p_c_s[2] - u_c_s[2]) // 2, 
                             (p_c_s[3] - u_c_s[3]) // 2, 0], dtype = np.int32)
        # -1 keeps the whole batch axis
        size = np.array([-1, u_c_s[1], u_c_s[2], u_c_s[3], p_c_s[4]], np.int32)
        prev_conv_crop = tf.slice(prev_conv, offsets, size)
        up_concat = tf.concat((prev_conv_crop, up_conv), 4)
        return up_concat
        
    def __init__(self, base_filt = 8, in_depth = INPUT_DEPTH, out_depth = OUTPUT_DEPTH,
                 in_size = INPUT_SIZE, out_size = OUTPUT_SIZE, num_classes = OUTPUT_CLASSES,
                 learning_rate = LEARNING_RATE, print_shapes = True, drop = 0.2, should_pad = False):
        """Build the complete training graph.

        Args:
            base_filt: Filter count of the first convolution; deeper levels use multiples of it.
            in_depth: First spatial dimension of the input and label placeholders.
            out_depth: Unused; kept for interface compatibility.
            in_size: Second and third spatial dimension of the placeholders.
            out_size: Unused; kept for interface compatibility.
            num_classes: Number of output classes. The hard-coded loss weights have
                four entries, so the weighted loss only builds for ``num_classes == 4``.
            learning_rate: Learning rate of the Adam optimiser.
            print_shapes: Print the input/label/output tensor shapes while building.
            drop: Dropout rate applied after the second convolution of most levels.
            should_pad: Use 'same' instead of 'valid' padding in the main convolutions.
        """
        
        self.base_init = tf.truncated_normal_initializer(stddev=0.1) # Initialise weights
        # NOTE(review): this regulariser is attached to the conv kernels, but the
        # loss built below consists of the cross-entropy term only.
        self.reg_init = tf.keras.regularizers.l2(0.1) # Initialise regularisation (was useful)
        
        self.should_pad = should_pad # To pad or not to pad, that is the question
        self.drop = drop # Set dropout rate
        
        with tf.variable_scope('3DuNet'):
            self.training = tf.placeholder(tf.bool)
            self.do_print = print_shapes
            self.model_input = tf.placeholder(tf.float32, shape = (None, in_depth, in_size, in_size, 1))  
            # Define placeholders for feed_dict
            self.model_labels = tf.placeholder(tf.int32, shape = (None, in_depth, in_size, in_size, 1))
            # one_hot appends the class axis; the singleton channel axis before it is then removed
            labels_one_hot = tf.squeeze(tf.one_hot(self.model_labels, num_classes, axis = -1), axis = -2)
            
            if self.do_print: 
                print('Input features shape', self.model_input.get_shape())
                print('Labels shape', labels_one_hot.get_shape())
                
            # ---------------- Contracting path ----------------
            # Level zero
            conv_0_1 = self.conv_batch_relu(self.model_input, base_filt, is_training = self.training)
            conv_0_2 = self.conv_batch_relu(conv_0_1, base_filt*2, is_training = self.training)
            # Level one
            max_1_1 = tf.layers.max_pooling3d(conv_0_2, [2,2,2], [2,2,2]) 
            conv_1_1 = self.conv_batch_relu(max_1_1, base_filt*2, is_training = self.training)
            conv_1_2 = self.conv_batch_relu(conv_1_1, base_filt*4, is_training = self.training)
            conv_1_2 = tf.layers.dropout(conv_1_2, rate = self.drop, training = self.training)
            # Level two
            max_2_1 = tf.layers.max_pooling3d(conv_1_2, [2,2,2], [2,2,2]) 
            conv_2_1 = self.conv_batch_relu(max_2_1, base_filt*4, is_training = self.training)
            conv_2_2 = self.conv_batch_relu(conv_2_1, base_filt*8, is_training = self.training)
            conv_2_2 = tf.layers.dropout(conv_2_2, rate = self.drop, training = self.training) 
            # Level three
            max_3_1 = tf.layers.max_pooling3d(conv_2_2, [2,2,2], [2,2,2]) 
            conv_3_1 = self.conv_batch_relu(max_3_1, base_filt*8, is_training = self.training)
            conv_3_2 = self.conv_batch_relu(conv_3_1, base_filt*16, is_training = self.training)
            conv_3_2 = tf.layers.dropout(conv_3_2, rate = self.drop, training = self.training)
            
            # ---------------- Expanding path with attention-gated skips ----------------
            # Level two
            gating_1 = self.gating_signal(conv_3_2, base_filt*8, batch_norm=True)          
            att_1 = self.attention_block(conv_2_2, gating_1, base_filt*8)           
            up_conv_3_2 = self.upconvolve(conv_3_2, base_filt*16, kernel = 2, stride = [2,2,2]) 
            concat_2_1 = self.centre_crop_and_concat(att_1, up_conv_3_2)          
            conv_2_3 = self.conv_batch_relu(concat_2_1, base_filt*8, is_training = self.training)
            conv_2_4 = self.conv_batch_relu(conv_2_3, base_filt*8, is_training = self.training)
            conv_2_4 = tf.layers.dropout(conv_2_4, rate = self.drop, training = self.training)
            
            # Level one
            gating_2 = self.gating_signal(conv_2_4, base_filt*4, batch_norm=True) 
            att_2 = self.attention_block(conv_1_2, gating_2, base_filt*4)        
            up_conv_2_1 = self.upconvolve(conv_2_4, base_filt*8, kernel = 2, stride = [2,2,2]) 
            concat_1_1 = self.centre_crop_and_concat(att_2, up_conv_2_1)
            conv_1_3 = self.conv_batch_relu(concat_1_1, base_filt*4, is_training = self.training)
            conv_1_4 = self.conv_batch_relu(conv_1_3, base_filt*4, is_training = self.training)
            conv_1_4 = tf.layers.dropout(conv_1_4, rate = self.drop, training = self.training)
            
            # Level zero
            gating_3 = self.gating_signal(conv_1_4, base_filt*2, batch_norm=True)
            att_3 = self.attention_block(conv_0_2, gating_3, base_filt*2)           
            up_conv_1_0 = self.upconvolve(conv_1_4, base_filt*4, kernel = 2, stride = [2,2,2])  
            concat_0_1 = self.centre_crop_and_concat(att_3, up_conv_1_0)
            conv_0_3 = self.conv_batch_relu(concat_0_1, base_filt*2, is_training = self.training)
            conv_0_4 = self.conv_batch_relu(conv_0_3, base_filt*2, is_training = self.training)
            conv_0_4 = tf.layers.dropout(conv_0_4, rate = self.drop, training = self.training)
            # Use the constructor argument (not the OUTPUT_CLASSES global) so the logits
            # always have the same number of classes as labels_one_hot
            conv_out = tf.layers.conv3d(conv_0_4, num_classes, [1,1,1], [1,1,1], padding = 'same')
            self.predictions = tf.expand_dims(tf.argmax(conv_out, axis = -1), -1)
            
            if self.do_print: 
                print('Model Convolution output shape', conv_out.get_shape())
                print('Model Argmax output shape', self.predictions.get_shape())
            
            do_weight = True
            loss_weights = [1, 150, 100, 1.0] # One weight per class id
            # Weighted cross entropy: approach adapts following code: https://stackoverflow.com/questions/44560549/unbalanced-data-and-weighted-cross-entropy
            ce_loss = tf.nn.softmax_cross_entropy_with_logits(logits=conv_out, labels=labels_one_hot)
            if do_weight:
                weighted_loss = tf.reshape(tf.constant(loss_weights), [1, 1, 1, 1, num_classes]) # Format to the right size
                # Picks, for every voxel, the weight of its true class
                weighted_one_hot = tf.reduce_sum(weighted_loss*labels_one_hot, axis = -1)
                ce_loss = ce_loss * weighted_one_hot
            
            self.loss = tf.reduce_mean(ce_loss) # Get loss
            
            self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            
            self.extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # Ensure correct ordering for batch-norm to work
            with tf.control_dependencies(self.extra_update_ops):
                self.train_op = self.trainer.minimize(self.loss)
            

### On-the-fly Augmentation

In [None]:
from aug import rotation, flip, rot90, rot180

In [None]:
def get_data_raw_sample_aug(raw, label, size, aug_choices=(0, 1)):
    """Draw a random mini-batch and apply one randomly chosen augmentation to it.

    Args:
        raw: Indexable collection of input volumes.
        label: Indexable collection of label volumes, aligned with ``raw``.
        size: Number of (raw, label) pairs to draw, without replacement.
        aug_choices: Augmentation codes to choose from; one code is drawn per
            call and applied to every sample of the batch:
            0 = none, 1 = rotation by a random angle in [-10, 10] per axis,
            2 = 90 degree rotation about a random axis,
            3 = 180 degree rotation about a random axis,
            4 = flip along the first or the third axis.
            The default keeps the original behaviour (codes 0 and 1 only).

    Returns:
        ``(raws, labels)``: two lists of length ``size``.

    Raises:
        ValueError: If the drawn code is not one of 0-4.
    """
    idx = random.sample(range(len(raw)), k=size) # randomly pick sample indices

    # np.array copies the selected samples, so the in-place augmentation below
    # never modifies the caller's data
    selected_raw = np.array([raw[i] for i in idx])
    selected_label = np.array([label[i] for i in idx])

    aug_val = int(np.random.choice(aug_choices, 1)[0])
    if aug_val not in (0, 1, 2, 3, 4):
        raise ValueError('Unknown augmentation code: {}'.format(aug_val))

    if aug_val != 0:
        for i in range(size):
            if aug_val == 1:
                # randint's upper bound is exclusive, hence 11 for angles in [-10, 10]
                angles = np.random.randint(low=-10, high=11, size=3)
                augmented = rotation(selected_raw[i], selected_label[i], tuple(angles))
            elif aug_val == 2:
                axis = np.random.choice([0,1,2], 1) # randomly pick an axis
                augmented = rot90(selected_raw[i], selected_label[i], axis)
            elif aug_val == 3:
                axis = np.random.choice([0,1,2], 1) # randomly pick an axis
                augmented = rot180(selected_raw[i], selected_label[i], axis)
            else:
                # flip_val 0 flips along the first axis, 2 along the third axis
                flip_val = random.choice([0,2])
                augmented = flip(selected_raw[i], selected_label[i], flip_val)
            selected_raw[i], selected_label[i] = augmented

    return list(selected_raw), list(selected_label)
  
    # print('random index: ', idx)
    # print('max', np.max(selected_raw), np.max(selected_label))
    # print(selected_raw.shape, selected_label.shape) # I added

#-------------------------------------------------------------------------
    # x_y_data = random.sample(list(data), size)
    # return [x[0] for x in x_y_data], [y[1] for y in x_y_data]

# x, y = get_data_raw_sample_aug(train_raw, train_label, BATCH_SIZE) # Draw samples from batch



### Train model

In [None]:
# Rebind `tf` to the TF1 compatibility API: UNetwork and the cells below rely on
# graph-mode calls such as tf.placeholder, tf.layers and tf.Session.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

In [None]:
tf.reset_default_graph()
unet = UNetwork(drop = 0.15, base_filt = 32, should_pad = True) # MODEL DEFINITION
init = tf.global_variables_initializer()
# max_to_keep=None keeps every periodic checkpoint instead of only the most recent ones
saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
config = tf.ConfigProto()
# Checkpoints are written as <SAVE_PATH>/<MODEL_NAME>-<step>. The previous
# SAVE_PATH + MODEL_NAME concatenation lacked the path separator and wrote them
# next to (not into) SAVE_PATH, where the loader below never looked.
checkpoint_prefix = os.path.join(SAVE_PATH, MODEL_NAME)
st_loss = []
with tf.Session(config=config) as sess:
    writer = tf.summary.FileWriter(LOGS_PATH, graph=tf.get_default_graph())

    # Restore into the variables of `unet` itself. Importing the meta graph would
    # add a second copy of the network to this graph and leave unet's own
    # variables uninitialised.
    latest_checkpoint = None
    if LOAD_MODEL:
        print('Trying to load saved model...')
        latest_checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
    if latest_checkpoint is not None:
        print('Loading from: ', latest_checkpoint)
        saver.restore(sess, latest_checkpoint)
        print("Model sucessfully restored")
    else:
        if LOAD_MODEL:
            print("No previous model found, running default init")
        # Also reached when LOAD_MODEL is False, which previously skipped initialisation
        sess.run(init)

    t_loss = []
    for i in range(NUM_STEPS):
        print('Current iter: ', i, end='\r')
        x, y = get_data_raw_sample_aug(train_raw, train_label, BATCH_SIZE) #runtime augmentation function called
        train_dict = {
            unet.training: True,
            unet.model_input: x,
            unet.model_labels: y
        }
        _, loss = sess.run([unet.train_op, unet.loss], feed_dict = train_dict) # Get loss
        t_loss.append(loss) # Loss store
        # Store loss
        if i % 500 == 0 and i > 0:
            st_loss.append([i, loss])
        if i > 30000 and i % 500 == 0:
            print('Saving model at iter: ', i) # Save periodically
            saver.save(sess, checkpoint_prefix, global_step = i)

            print('Iteration', i, 'Loss: ', np.mean(t_loss)) # Get periodic progress reports
            t_loss = []
            x, y = get_data_raw_sample_aug(train_raw, train_label, BATCH_SIZE)
            train_dict = {
                unet.training: False,
                unet.model_input: x,
                unet.model_labels: y
            }
            # Drop only the trailing singleton axis so the batch axis survives for any batch size
            preds = sess.run(unet.predictions, feed_dict = train_dict)[..., 0]
            # get_mean_iou works on a single volume, so score each sample and average per class.
            # reswap=True because the labels are in the same layout as the predictions.
            iou = np.mean([get_mean_iou(p, np.squeeze(t, axis = -1), ret_full = True, reswap = True)
                           for p, t in zip(preds, y)], axis = 0)
            print('Train IOU (on SCALED anns): ', iou, 'Mean: ', np.mean(iou[:OUTPUT_CLASSES-1]))
            print('######################')
  
    saver.save(sess, checkpoint_prefix, global_step = NUM_STEPS) # Final save
    writer.close()

In [None]:
# Store loss
# Persist the [iteration, loss] pairs collected in st_loss by the training cell
sio.savemat('loss.mat', {'ls':st_loss}, do_compression=True)

### Test model

In [None]:
# For testing
MODEL_PATH = './CBCT_Dental_NCI'
MODEL_NAME = 'tfmodel-99500'

# Activate tensorflow v1
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Load model
TEST_MODEL_NAME = MODEL_NAME
tf.reset_default_graph()
unet = UNetwork(drop = 0.15, base_filt = 32, should_pad = True) # MODEL DEFINITION

config = tf.ConfigProto()
test_predictions = []
restorer = tf.train.Saver(tf.global_variables())
with tf.Session(config=config) as sess:
    # tf.initialize_all_variables().run()
    print('Loading saved model ...')
#   try
    graph = tf.get_default_graph()
    restorer.restore(sess, SAVE_PATH +'/'+ TEST_MODEL_NAME)
#     restorer = tf.train.import_meta_graph(MODEL_PATH +'/'+ TEST_MODEL_NAME + '.meta')    
#     restorer.restore(sess, tf.train.latest_checkpoint(MODEL_PATH))
    print("Model sucessfully restored")
    pred_out = []
    y_orig = []
    x_orig = []
    x_in = []
    y_in = []
    i = 0
    iou_out = []

    while i < len(test_raw):
        x_batch = []
        y_batch = []
        for j in range(i, min(len(test_raw), i + BATCH_SIZE)):
            y_orig.append(np.copy(test_label[j]))
            x_orig.append(np.copy(test_raw[j]))
            x_cur, y_cur = test_raw[j], test_label[j]

            x_batch.append(x_cur)
            y_batch.append(y_cur)
        if len(x_batch) == 0: break
        print('Processing ', i)
        x_in = x_in + x_batch
        y_in = y_in + y_batch
        test_dict = {
            unet.training: True, # Whether to perform batch-norm at inference (Paper says this would be useful)
            unet.model_input: x_batch,
            unet.model_labels: y_batch
        }
        test_predictions = np.squeeze(sess.run([unet.predictions], feed_dict = test_dict))
        if len(x_batch) == 1:
            pred_out.append(test_predictions)
        else:
            pred_out.extend([np.squeeze(test_predictions[z, :, :, :]) for z in list(range(len(x_batch)))])
        i += BATCH_SIZE

    for i in range(len(y_orig)):
        iou = get_mean_iou(pred_out[i], np.squeeze(y_orig[i]), ret_full = True)
        print('Test IOU: ', iou, 'Mean: ', np.mean(iou[:OUTPUT_CLASSES-1]))
        iou_out.append(np.mean(iou[:OUTPUT_CLASSES-1]))

    print('Mean test IOU', np.mean(iou_out), 'Var IOU', np.var(iou_out))


#     except Exception as e:
#         print('Something went wrong!', e)

pred_out = np.array(pred_out).astype('int8')
sio.savemat(os.path.join(MODEL_PATH, 'pred_' + MODEL_NAME + '.mat'), {'p':pred_out}, do_compression=True)

