# Test Model

In [1]:
import os
import glob
import scipy.io
import sys
import h5py
import yaml
import sklearn
import warnings
import tqdm
import csv
import cv2
import argparse
from skimage import io
import seaborn as sns
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import torch.utils.data as data
import sys
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.autograd import Variable 
# Seed PyTorch's random number generator so repeated runs are reproducible.
torch.manual_seed(0)

# Restrict CUDA to the GPU with index 1.
# NOTE(review): this is set after `import torch`; it presumably still takes
# effect as long as CUDA has not been initialised yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="1"



In [2]:
# Import h5py with FutureWarning messages silenced for the duration of the
# import. h5py is already imported at the top of the file, so this statement
# re-binds the cached module.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore",category=FutureWarning)
    import h5py
    
def dice(x_predict, y_predict, empty_score=1.0):
    """Calculate the Dice coefficient between two binary masks.

        Args:
            x_predict (array-like): ground truth mask; any non-zero entry
                counts as foreground
            y_predict (array-like): predicted mask, same shape as x_predict
            empty_score (float): value returned when both masks are empty
                (the coefficient is 0/0 in that case). Defaults to 1.0,
                i.e. two empty masks are treated as a perfect match.

        Returns:
            dice coefficient, 2*|X & Y| / (|X| + |Y|)

        Raises:
            ValueError: if the two masks do not have the same shape

    """
    # Coerce to boolean so that .sum() counts pixels rather than adding up
    # raw label/intensity values when a non-bool mask is passed in.
    truth = np.asarray(x_predict, dtype=bool)
    predicted = np.asarray(y_predict, dtype=bool)

    if truth.shape != predicted.shape:
        raise ValueError('Shape mismatch')

    denominator = truth.sum() + predicted.sum()
    if denominator == 0:
        # Both masks are empty: avoid 0/0 (NaN + RuntimeWarning).
        return empty_score

    intersection = np.logical_and(truth, predicted)
    return 2. * intersection.sum() / denominator

def evaluate(x_predict, y_predict):
    """Compute the evaluation metric for one ground-truth/prediction pair.

        Args:
            x_predict (np.bool): ground truth
            y_predict (np.bool): predicted

        Returns:
            the dice coefficient of the two masks, as returned by dice()

    """
    return dice(x_predict, y_predict)

def statistics(metric_layer_dict,results_folder, results_name):
    """Save a box plot of the dice coefficients to <results_folder>/<results_name>.png.

        Only the second entry of metric_layer_dict (insertion order) is
        plotted, which leaves the first entry out of the figure.
        NOTE(review): the first entry is presumably the background class --
        confirm against the layer mapping in the config.

        Args:
            metric_layer_dict (dict): maps layer name -> list of dice coefficients
            results_folder (str): location to save plot
            results_name (str): name of plot (without extension)

        Returns:
            None

    """
    metric_layer = list(metric_layer_dict.values())
    layer_list = list(metric_layer_dict.keys())

    # plt.figure(1, figsize=...) ignores figsize when figure 1 already exists,
    # so fetch the figure, wipe whatever was drawn on it and size it explicitly.
    fig = plt.figure(1)
    fig.clf()
    fig.set_size_inches(40, 15)

    ax = fig.add_subplot(111)
    ax.boxplot(metric_layer[1:2])
    ax.set_xticklabels(layer_list[1:2])
    ax.set_ylim([0, 1])  # dice coefficients live in [0, 1]
    plt.setp(ax.get_xticklabels(), rotation=15, horizontalalignment='right')
    fig.savefig(os.path.join(results_folder, results_name + '.png'))
    
    
class TestData(data.Dataset):
    """Dataset wrapper that serves the items of an indexable array as tensors."""

    def __init__(self, X):
        # X is indexed with an integer and each item is handed to
        # torch.from_numpy, so items are expected to be numpy arrays.
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # The returned tensor shares memory with the underlying numpy item.
        return torch.from_numpy(self.X[index])
    
def _model_device(model):
    """Return the device of the model's first parameter.

    Falls back to CUDA when available (CPU otherwise) for objects that expose
    no parameters.
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_predictions(BScans, dimensions, relaynet_model, batch_size=1):
    """Run the pretrained model on every B-scan and collect class probabilities.

        Args:
            BScans (numpy array): image array; the first axis indexes the
                scans and each slice along it is fed to the model as a batch
            dimensions (dict): contains number of bscans, layers, height and
                width under the keys 'bscans', 'layers', 'height', 'width'
            relaynet_model : pytorch model
            batch_size (int): number of scans passed to the model at once

        Returns:
            stitched_stack (numpy array): float32 array of shape
                (bscans, layers, height, width) holding the softmax over the
                layer axis of the model output

        Raises:
            ValueError: if batch_size is smaller than 1

    """
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')

    stitched_stack = np.zeros((dimensions['bscans'], dimensions['layers'],
                               dimensions['height'], dimensions['width']),
                              dtype=np.float32)

    # Send the input to wherever the model lives instead of assuming a GPU.
    device = _model_device(relaynet_model)

    with torch.no_grad():
        for start in range(0, len(BScans), batch_size):
            stop = start + batch_size
            batch = torch.from_numpy(BScans[start:stop]).float().to(device)
            probabilities = F.softmax(relaynet_model(batch), dim=1)
            stitched_stack[start:stop] = probabilities.cpu().numpy()

    return stitched_stack

def build_mask(annotations,HEIGHT,WIDTH): 
    """Build the per-layer mask for one scan.

        Args:
            annotations (numpy array): 2-D array indexed as
                [surface, column]; values are 1-based row positions
            HEIGHT (int): mask height
            WIDTH (int): mask width

        Returns:
            layers (numpy array): float32 array with one channel per row of
                annotations. Channel k marks the rows between surface k and
                surface k+1; the last channel is whatever is left over
                (background).

    """
    n_surfaces = annotations.shape[0]
    n_columns = annotations.shape[1]
    layers = np.zeros((n_surfaces, HEIGHT, WIDTH), dtype=np.float32)

    for upper in range(n_surfaces - 1):
        for col in range(n_columns):
            # -1 converts the matlab (1-based) index to a python index
            top = int(np.round(annotations[upper, col])) - 1
            if top <= 0:
                continue
            bottom = int(np.round(annotations[upper + 1, col])) - 1
            if bottom <= 0:
                continue
            layers[upper, top:bottom, col] = 1

    # background = every pixel not claimed by one of the layer channels
    layers[-1] = np.ones((HEIGHT, WIDTH)) - layers[:-1].sum(axis=0)
    return layers

def convert_label_png_to_mask(annotation, HEIGHT, WIDTH, num_classes=3):
    """Convert a 2-dimensional label image into a one-hot encoded target mask.

        Args:
            annotation (numpy array): integer label image; the first two axes
                are (row, column) and every value is a class index
            HEIGHT (int): number of rows to convert
            WIDTH (int): number of columns to convert
            num_classes (int): number of channels in the mask (default 3)

        Returns:
            layers (numpy array): float64 array of shape
                (num_classes, HEIGHT, WIDTH) with layers[c, h, w] == 1 where
                annotation[h, w] == c and 0 elsewhere

        Raises:
            IndexError: if annotation is smaller than HEIGHT x WIDTH or
                contains a class index >= num_classes

    """
    labels = np.asarray(annotation)
    if labels.ndim < 2 or labels.shape[0] < HEIGHT or labels.shape[1] < WIDTH:
        raise IndexError('annotation is smaller than the requested HEIGHT x WIDTH')
    labels = labels[:HEIGHT, :WIDTH]
    if labels.dtype == bool:
        # a boolean array would be interpreted as a mask, not as class indices
        labels = labels.astype(np.intp)

    layers = np.zeros((num_classes, HEIGHT, WIDTH))

    # Row/column index grids that broadcast against `labels`; one fancy-indexed
    # assignment replaces the per-pixel Python loop.
    trailing = (1,) * (labels.ndim - 2)
    rows = np.arange(HEIGHT).reshape((HEIGHT, 1) + trailing)
    cols = np.arange(WIDTH).reshape((1, WIDTH) + trailing)
    layers[labels, rows, cols] = 1

    return layers

def prepare_dataset(image_path, label_path=None):
    """Load an image stack and, optionally, build its layer masks.

        Args:
            image_path (str): file read with skimage's io.imread; the first
                axis of the result is treated as the scan axis
            label_path (str, optional): .mat file holding 'surface_matrix'
                and 'undefine_surface'

        Returns:
            images_array when label_path is None, otherwise the tuple
            (images_array, lmap_array) where lmap_array stacks one
            build_mask() result per scan

    """
    images_array = io.imread(image_path)

    # scale integer images by their maximum representable value
    pixel_type = images_array.dtype
    if pixel_type == 'uint8':
        images_array = images_array/255
    elif pixel_type == 'uint16':
        images_array = images_array*(1.0 / 65535.0)

    if label_path is None:
        return images_array

    mat = scipy.io.loadmat(label_path)
    # element-wise product of the surface positions with 'undefine_surface'
    annotations = np.multiply(mat['surface_matrix'], mat['undefine_surface'])

    lmap_array = np.array([
        build_mask(annotations[:, :, scan],
                   images_array[scan].shape[0],
                   images_array[scan].shape[1])
        for scan in range(images_array.shape[0])
    ])
    return images_array, lmap_array
    
def get_coordinates_for_crop(mip, height, y_top):
    """Find the vertical window of the image that contains the most foreground.

        The image is binarised with Otsu's threshold, cleaned with a 10x10
        morphological opening and split into connected components.
        Components whose bounding box touches any image border are discarded.
        Windows of `height` rows, starting every 20 rows from `y_top`, are
        then scored by the fraction of remaining foreground pixels.

        Args:
            mip (numpy array): 2-D image, converted to uint8 before
                thresholding
            height (int): height of the crop window in rows
            y_top (int): first candidate row for the top of the window

        Returns:
            y_top_idx (int): top row of the best-scoring window
            ret (float): threshold value returned by cv2.threshold

        Raises:
            ValueError: if y_top lies at or beyond the last image row, so
                that there is no candidate window

    """
    # With THRESH_OTSU the threshold is chosen automatically; `ret` is that value.
    ret, thresh1 = cv2.threshold(np.uint8(mip), 10, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    kernel = np.ones((10,10),np.uint8)
    thresh_opening = cv2.morphologyEx(thresh1, cv2.MORPH_OPEN, kernel)
    # stat matrix contains information about the bounding box of every component
    (num_labels, labelled_matrix, stat_matrix, centroid_matrix) = cv2.connectedComponentsWithStats(np.uint8(thresh_opening))

    image_height, image_width = labelled_matrix.shape[:2]
    left = stat_matrix[:, cv2.CC_STAT_LEFT]
    top = stat_matrix[:, cv2.CC_STAT_TOP]
    right = left + stat_matrix[:, cv2.CC_STAT_WIDTH]    # exclusive right edge
    bottom = top + stat_matrix[:, cv2.CC_STAT_HEIGHT]   # exclusive bottom edge

    # Drop every component that touches one of the four image borders.
    touches_border = (left == 0) | (top == 0) | (right == image_width) | (bottom == image_height)
    labelled_matrix[np.isin(labelled_matrix, np.flatnonzero(touches_border))] = 0

    y_top_intervals = range(y_top, image_height, 20)
    if len(y_top_intervals) == 0:
        raise ValueError('y_top is outside the image: no candidate crop window')

    window_area = height * labelled_matrix.shape[1]
    all_area = [np.count_nonzero(labelled_matrix[y_start:y_start + height]) / window_area
                for y_start in y_top_intervals]
    y_top_idx = y_top_intervals[int(np.argmax(all_area))]
    return y_top_idx, ret

def filter_image_quality(IM, save_imquality, dataset, threshold = 0.01, height_ratio = 0.25):
    """Score every scan by its bright-pixel ratio inside a crop window and save the result.

        The crop window is located on the maximum projection over all scans
        (first channel) with get_coordinates_for_crop(). For each scan the
        ratio of pixels above the returned threshold inside that window is
        computed; scans whose ratio exceeds `threshold` are written to
        <save_imquality>/<dataset>.csv.

        Args:
            IM (numpy array): 4-D array indexed as [scan, channel, row, column];
                only channel 0 is evaluated.
                NOTE(review): intensities are assumed to lie in [0, 1] (they
                are multiplied by 255 before thresholding) -- confirm.
            save_imquality (str): folder the csv file is written to
            dataset (str): dataset name, used as the csv file name
            threshold (float): minimum ratio for a scan to be kept
            height_ratio (float): window height as a fraction of the image height

        Returns:
            df (pandas DataFrame): columns 'scan' and 'ratio' of the kept scans

    """
    y_top = 0
    height = round(IM.shape[2]*height_ratio)
    IM_MAX = np.max(IM, axis=0)

    y_top_idx, ret = get_coordinates_for_crop(IM_MAX[0]*255, height, y_top)

    # Index channel 0 explicitly instead of squeezing: squeezing collapses a
    # single-scan stack to a 0-d array that cannot be iterated.
    window = IM[:, 0, y_top_idx:y_top_idx+height]
    ratio = np.count_nonzero(window > ret/255, axis=(1, 2))/(height*IM.shape[3])

    filtered_scans = [[scan, r] for scan, r in enumerate(ratio) if r > threshold]
    df = pd.DataFrame(filtered_scans, columns= ['scan', 'ratio'])

    csv_path = os.path.join(save_imquality, dataset +'.csv')
    # dataset may contain sub-folders; make sure the target folder exists
    os.makedirs(os.path.dirname(csv_path) or '.', exist_ok=True)
    df.to_csv(csv_path, index=False)
    return df


In [None]:
# ---------------------------------------------------------------------------
# Inference / evaluation driver.
# Reads the YAML config, runs the model on every dataset, writes the argmax
# label images (+ a visualisation and a .npy stack) and, when a label path is
# configured, the per-scan dice coefficients and a summary box plot.
# ---------------------------------------------------------------------------
with open("./test_preclinical_evaluate.yaml") as file:
    # NOTE(review): FullLoader can construct more than plain data types;
    # yaml.safe_load would be the safer choice for a plain config file.
    config = yaml.load(file, Loader=yaml.FullLoader)

filepaths = config['filepaths']
save_filepaths = config['save_filepaths']
layer_mapping = config['general']['layers_mapping']
test_dataset_list = filepaths['test_dataset']
get_image_quality = config['general']['get_image_quality']
# Ground-truth labels are optional; everything label related is skipped without them.
has_labels = 'label_path' in filepaths

# NOTE(review): torch.load unpickles the stored object -- only load trusted files.
relaynet_model = torch.load(filepaths['model_path'])
relaynet_model.eval()

if test_dataset_list is not None:
    # one dataset name per line; blank lines are ignored
    with open(test_dataset_list, 'r') as reader:
        all_datasets = [line.rstrip('\r\n') for line in reader if line.strip()]
else:
    # fall back to every sub-folder of the image folder
    all_datasets = [os.path.basename(f)
                    for f in glob.glob(os.path.join(filepaths['image_path'], '*'))
                    if os.path.isdir(f)]

metric_layer_dict = {value: [] for value in layer_mapping.values()}

for dataset in tqdm.tqdm(all_datasets, ascii=True):
    save_pred = os.path.join(save_filepaths['predictions_path'], 'labels', dataset)
    save_eval = os.path.join(save_filepaths['evaluations_path'], dataset)
    save_pred_viz = os.path.join(save_filepaths['predictions_path'], 'viz_labels', dataset)
    save_npy = os.path.join(save_filepaths['predictions_path'], os.path.dirname(dataset))

    for folder in (save_pred, save_eval, save_npy, save_pred_viz):
        Path(folder).mkdir(parents=True, exist_ok=True)

    # images are ordered by the integer after the last '_' in the file name
    image_paths = glob.glob(os.path.join(filepaths['image_path'], dataset, '*' + filepaths['image_ext']))
    image_paths = sorted(image_paths, key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1]))
    if not image_paths:
        warnings.warn('No images found for dataset %r, skipping it' % (dataset,))
        continue

    images = []
    label_maps = []
    for image_path in image_paths:
        img = io.imread(image_path)
        # scale integer images by their maximum representable value
        if img.dtype == 'uint8':
            img = img/255
        elif img.dtype == 'uint16':
            img = img*(1.0 / 65535.0)
        images.append(img)

        if has_labels:
            # the label is a png with the same stem as the image
            label_path = os.path.join(filepaths['label_path'], dataset, Path(image_path).stem + '.png')
            label = io.imread(label_path)
            label_maps.append(convert_label_png_to_mask(label, label.shape[0], label.shape[1]))

    images_array = np.array(images, dtype=np.float32)
    lmap_array = np.array(label_maps)
    # add the channel axis expected by the model: (scan, 1, height, width)
    images_array2 = np.expand_dims(images_array, axis=1)

    if get_image_quality:
        save_imquality = os.path.join(save_filepaths['image_quality_path'])
        Path(save_imquality).mkdir(parents=True, exist_ok=True)
        df = filter_image_quality(images_array2, save_imquality, dataset)

    dimensions = {'bscans': images_array.shape[0],
                  'layers': len(layer_mapping),
                  'height': images_array.shape[1],
                  'width' : images_array.shape[2]
    }
    predicted_stack = get_predictions(images_array2, dimensions, relaynet_model)
    predicted_stack_argmax = np.argmax(predicted_stack, axis=1)

    if has_labels:
        # One csv per dataset, rewritten on every run so that results of an
        # earlier run are never mixed with the current ones.
        results_csv = os.path.join(save_eval, save_filepaths['results_name'] + ".csv")
        with open(results_csv, "w", newline='') as csvfile:
            filewriter = csv.writer(csvfile, delimiter=',', quotechar='|',
                                    quoting=csv.QUOTE_MINIMAL, lineterminator='\n')
            for scan in range(predicted_stack.shape[0]):
                for layer in range(predicted_stack.shape[1]):
                    mask = predicted_stack_argmax[scan] == layer
                    gt = lmap_array[scan, layer].astype(bool)
                    metric = evaluate(gt, mask)
                    metric_layer_dict[layer_mapping[layer]].append(metric)
                    filewriter.writerow([dataset, scan, layer_mapping[layer], str(metric)])

    # label image + colour visualisation for every scan
    for scan in range(predicted_stack_argmax.shape[0]):
        slice_name = 'Slice_1_' + str(scan+1) + '_argmax.png'
        cv2.imwrite(os.path.join(save_pred, slice_name), predicted_stack_argmax[scan])
        plt.figure(figsize=(30,15))
        plt.imshow(predicted_stack_argmax[scan], vmin=0, vmax=dimensions['layers']-1)
        plt.savefig(os.path.join(save_pred_viz, slice_name))
        plt.close()

    # the whole label stack of the dataset, saved once
    np.save(os.path.join(save_filepaths['predictions_path'], dataset + '.npy'), predicted_stack_argmax)

if has_labels:
    statistics(metric_layer_dict, save_filepaths['evaluations_path'], save_filepaths['results_name'])
    total_metric_df = pd.DataFrame(metric_layer_dict)
    total_metric_df.to_csv(os.path.join(save_filepaths['evaluations_path'], save_filepaths['results_name'] + '.csv'), index=False)


        

  3%|#5                                          | 3/87 [00:45<23:04, 16.49s/it]