In [None]:
import fibresegt as fs
import numpy as np
import matplotlib.pyplot as plt


import skimage
import cv2

from scipy.ndimage import label, center_of_mass
from scipy.spatial.distance import cdist

In [None]:
# Sweep grid: each (undersampling factor, noise level) pair selects one noisy
# input image and one output/checkpoint directory in the evaluation loop, and
# indexes one cell of the metric arrays.
undersample_factor = [8, 12]
noise_level = [4, 8, 12, 16, 20, 24]

# Datasets
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
dataset_folder = '/scratch/pooja/fiber_detection/data/final_data'
Mock_UD = '/ScanRef_Glass_Mock_UD/'  # sub-folder appended to dataset_folder
noisy_Mock = 'im_Mock500noisy'  # file-name stem; '_uf<factor>_n<noise>.tiff' is appended per run
mask_Mock = 'mask_Mock500.png'  # label mask from which the reference centroids are computed

def divided_img_blocks(img, n_blocks=(2,2)):
    """Cut ``img`` into a grid of sub-arrays.

    The array is split along axis 0 into ``n_blocks[0]`` bands, and every band
    is then split along axis 1 into ``n_blocks[1]`` pieces. ``np.array_split``
    is used, so sizes that do not divide evenly are allowed.

    Returns:
        A nested list where ``result[r][c]`` is the block in band ``r`` and
        column piece ``c``.
    """
    n_row_bands, n_col_pieces = n_blocks[0], n_blocks[1]
    grid = []
    for band in np.array_split(img, n_row_bands):
        grid.append(np.array_split(band, n_col_pieces, axis=1))
    return grid

def get_coords(mask):
    """Return the centre of mass of every connected component in ``mask``.

    Components are found with ``scipy.ndimage.label`` (non-zero elements are
    foreground) and each centre is computed with
    ``scipy.ndimage.center_of_mass`` using the mask values as weights.

    Args:
        mask: array-like label/segmentation mask.

    Returns:
        Float array of shape ``(num_components, mask.ndim)`` with one centroid
        per row, in array-index order (axis 0 first). When the mask contains
        no component the result has shape ``(0, mask.ndim)`` so callers can
        still index columns or pass it to distance routines.
    """
    mask_np = np.asarray(mask)
    labeled_mask, num_labels = label(mask_np)
    if num_labels == 0:
        # center_of_mass would give an empty list here, which np.array turns
        # into a 1-D array of shape (0,); keep the result 2-D instead.
        return np.empty((0, mask_np.ndim), dtype=float)
    centroids = center_of_mass(mask_np, labels=labeled_mask, index=np.arange(1, num_labels + 1))
    return np.asarray(centroids, dtype=float).reshape(num_labels, mask_np.ndim)

def gauss_filter(s):
    '''Build a sampled 1D Gaussian of standard deviation s and its derivatives.

    The kernel is sampled at integer positions from -ceil(5*s) to +ceil(5*s)
    and normalised to sum to one. All three outputs are column vectors of
    shape (2*ceil(5*s) + 1, 1).

    Returns:
        (g, dg, ddg): the Gaussian, its first derivative and its second
        derivative, the derivatives being computed from the normalised g.
    '''
    half_width = np.ceil(5*s)
    x = np.arange(-half_width, half_width + 1)[:, np.newaxis]
    kernel = np.exp(-x**2/(2*s**2))
    kernel = kernel / np.sum(kernel)
    d_kernel = -x/s**2 * kernel
    curvature_term = -kernel/s**2
    slope_term = x/s**2 * d_kernel
    dd_kernel = curvature_term - slope_term
    return kernel, d_kernel, dd_kernel

def get_pred_coords_gauss(im, s=2.5, min_distance=3, threshold_abs=0.4):
    '''Detect blob centres as local maxima of a Gaussian-smoothed image.

    The image is filtered separably: first with the column-shaped Gaussian
    kernel from gauss_filter(s), then with its transpose. Peaks of the
    smoothed image are then located with skimage.feature.peak_local_max.

    Args:
        im: input image, passed directly to cv2.filter2D.
        s: standard deviation of the Gaussian kernel.
        min_distance: forwarded to peak_local_max.
        threshold_abs: forwarded to peak_local_max; it is applied to the
            smoothed image, so it must be on the same intensity scale as im.

    Returns:
        The coordinate array produced by peak_local_max.
    '''
    kernel_col, _, _ = gauss_filter(s)
    smoothed_once = cv2.filter2D(im, -1, kernel_col)
    im_g = cv2.filter2D(smoothed_once, -1, kernel_col.T)
    return skimage.feature.peak_local_max(im_g, min_distance=min_distance, threshold_abs=threshold_abs)

def get_match_id(true_centroids, pred_centroids, dist_thres=3):
    '''Match reference and predicted points by mutual nearest neighbour.

    A reference point is matched (true positive) when its closest prediction
    also has that distance as its own closest-reference distance, and the
    distance is strictly below dist_thres. Unmatched reference points are
    false negatives; unmatched predictions are false positives.

    Args:
        true_centroids: (n_true, d) array of reference coordinates.
        pred_centroids: (n_pred, d) array of predicted coordinates.
        dist_thres: maximum (exclusive) Euclidean distance for a match.

    Returns:
        (true_matched_id, true_not_matched_id, pred_matched_id, pred_not_matched_id)
        where pred_matched_id is an index array and the other three are
        one-element tuples as produced by np.where. All four can be used
        directly to index the corresponding centroid array.
    '''
    n_true, n_pred = len(true_centroids), len(pred_centroids)
    if n_true == 0 or n_pred == 0:
        # Nothing can be matched; argmin/min over an empty axis would raise,
        # so report every point on the non-empty side as unmatched.
        return (
            (np.array([], dtype=np.intp),),
            (np.arange(n_true, dtype=np.intp),),
            np.array([], dtype=np.intp),
            (np.arange(n_pred, dtype=np.intp),),
        )

    dmat = cdist(true_centroids, pred_centroids, metric='euclidean')
    tid = np.argmin(dmat, axis=1) # for each true coord, id of the closest pred coord
    t_dist = np.min(dmat, axis=1) # for each true coord, distance to the closest pred coord
    p_dist = np.min(dmat, axis=0) # for each pred coord, distance to the closest true coord
    # True where the pair is mutually closest and closer than dist_thres.
    true_match = (t_dist == p_dist[tid]) & (t_dist < dist_thres)
    true_matched_id = np.where(true_match) # ids of true that are correctly matched
    true_not_matched_id = np.where(~true_match) # ids of true that are not matched

    pred_matched_id = tid[true_match] # ids of pred that are correctly matched

    neg = np.ones(dmat.shape[1], dtype=bool) # one flag per pred coord, True while it is unmatched
    neg[pred_matched_id] = False
    pred_not_matched_id = np.where(neg)
    return true_matched_id, true_not_matched_id, pred_matched_id, pred_not_matched_id

def detect_and_match(true_centroids, pred_centroids, dist_thres=3):
    '''Split reference and predicted centroids into matched and unmatched sets.

    The index sets come from get_match_id and are used to slice the two
    input arrays.

    Returns:
        (true_matched, true_not_matched, pred_matched, pred_not_matched):
        coordinate arrays taken from true_centroids and pred_centroids.
    '''
    match_ids = get_match_id(true_centroids, pred_centroids, dist_thres=dist_thres)
    true_hit_id, true_miss_id, pred_hit_id, pred_miss_id = match_ids
    return (
        true_centroids[true_hit_id],
        true_centroids[true_miss_id],
        pred_centroids[pred_hit_id],
        pred_centroids[pred_miss_id],
    )


In [None]:
# One entry per (undersampling factor, noise level) pair: rows follow
# undersample_factor, columns follow noise_level.
metric_shape = (len(undersample_factor), len(noise_level))

# Metrics for the U-Net segmentation.
TP = np.zeros(metric_shape)
precision = np.zeros(metric_shape)
recall = np.zeros(metric_shape)
f1 = np.zeros(metric_shape)

# Metrics for the Gaussian blob-detector baseline.
TP_blob = np.zeros(metric_shape)
precision_blob = np.zeros(metric_shape)
recall_blob = np.zeros(metric_shape)
f1_blob = np.zeros(metric_shape)

# Build the reference centroids once; the same set is compared against every
# prediction in the sweep below.
dataset_folder1 = dataset_folder + Mock_UD
label_file = fs.join(dataset_folder1, mask_Mock)
labelInnerFibre = np.array(fs.imread(label_file))
# Keep only the first channel when the mask was read with more than two dimensions.
if labelInnerFibre.ndim > 2:
    labelInnerFibre = labelInnerFibre[:,:,0]
# Split the mask into a 2x2 grid and reassemble it with the bottom-left block
# zeroed, so only the other three quadrants contribute reference centroids.
# NOTE(review): presumably the bottom-left quadrant was used for training — confirm.
divided_label = divided_img_blocks(labelInnerFibre, n_blocks=(2,2))
empty_block = np.zeros_like(np.array(divided_label[1][0]))
test_label = np.block([[divided_label[0][0], divided_label[0][1]], [empty_block, divided_label[1][1]]])
true_centroids = get_coords(test_label)

def _detection_scores(tp, fp, fn):
    """Return (precision, recall, F1) computed from detection counts.

    A ratio whose denominator is zero (no predictions, no reference points,
    or no true positives at all) is reported as 0.0 rather than raising
    ZeroDivisionError or yielding NaN, so a degenerate configuration cannot
    abort or poison the sweep.
    """
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return prec, rec, f1_score


def _plot_detections(title, image, true_hit, true_miss, pred_hit, pred_miss):
    """Show ``image`` with matched (green) and unmatched (red) points overlaid.

    Reference points are drawn as crosses and predictions as dots. Column 1
    of each coordinate array is plotted as x and column 0 as y.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    fig.suptitle(title, fontsize=20)
    plt.subplots_adjust(bottom=0.7)  # Add bottom space
    ax.imshow(image, cmap='gray')
    ax.plot(true_hit[:,1], true_hit[:,0], 'gx', alpha=0.8, label='Matched True labels')
    ax.plot(true_miss[:,1], true_miss[:,0], 'rx', alpha=0.8, label='Not Matched True labels')
    ax.plot(pred_miss[:,1], pred_miss[:,0], 'r.', alpha=0.8, label='Not Matched Predicted labels')
    ax.plot(pred_hit[:,1], pred_hit[:,0], 'g.', alpha=0.8, label='Matched Predicted labels')
    plt.tight_layout()
    plt.legend()
    plt.show()
    # Release the figure so the sweep does not accumulate open figures.
    plt.close(fig)


# Sweep over every (undersampling factor, noise level) pair: segment the noisy
# image with the network trained for that pair, run the Gaussian blob baseline
# on the same image, and score both against true_centroids.
for i in range(len(undersample_factor)):
    for j in range(len(noise_level)):
        output_dir = './final_output_detect/Mock_noisy_uf' + str(undersample_factor[i]) + '_n' + str(noise_level[j]) + '/'
        trainedNet_dir = output_dir + 'checkpoint/'
        im_UD = noisy_Mock + '_uf' + str(undersample_factor[i]) + '_n' + str(noise_level[j]) + '.tiff'
        dataset_file = fs.join(dataset_folder1, im_UD)
        origData = np.array(fs.imread(dataset_file))
        origData = fs.normalize_8_bit(origData)
        # Not used further in this cell; kept for reference.
        data_info = dict(dataset_file=dataset_file,
                 label_file=label_file)

        # Same layout as the reference mask: 2x2 grid with the bottom-left block zeroed.
        divided_img = divided_img_blocks(origData, n_blocks=(2,2))
        empty_block = np.zeros_like(np.array(divided_img[1][0]))
        test_data = np.block([[divided_img[0][0], divided_img[0][1]], [empty_block, divided_img[1][1]]])

        dataset            = test_data
        net_var            = 'UnetID'
        dataset_name       = im_UD # Default is segm_results_2D
        checkpoint_id      = 'last_id' # or checkpoint_id=100
        crop_input_shape   = (64,64,1)
        save_orig_results  = True
        # NOTE(review): hardware_acltr is assigned but not passed to
        # segm_2d_data below — confirm whether that is intended.
        hardware_acltr     = 'CPU' # Default is 'GPU'
        postproc_param     =  {"method": "open", "kernel": {"kernel_shape":"disk", "kernel_radius":4},
                           "iteration": 1, "save_postproc_results":True} # Remove small artifacts or set it as None if don't want to do any post-processing
        segm_img = fs.apis.segm_2d_data(dataset=dataset,
                                net_var=net_var,
                                output_dir=output_dir,
                                trainedNet_dir=trainedNet_dir,
                                dataset_name=dataset_name,
                                checkpoint_id=checkpoint_id,
                                crop_input_shape=crop_input_shape,
                                save_orig_results=save_orig_results,
                                **postproc_param)

        # --- U-Net: centroids of the segmented components vs. reference ---
        # Indexing [:,:,0] assumes the segmentation result is 3-D — TODO confirm.
        pred_centroids = get_coords(segm_img[:,:,0])
        true_matched, true_not_matched, pred_matched, pred_not_matched = detect_and_match(true_centroids, pred_centroids, dist_thres=6)
        tp = len(true_matched)
        fn = len(true_not_matched)
        fp = len(pred_not_matched)

        TP[i,j] = tp
        precision[i,j], recall[i,j], f1[i,j] = _detection_scores(tp, fp, fn)

        # Rescale integer images to [0, 1) so the absolute peak threshold below
        # is relative to full scale. Other dtypes are passed through unchanged.
        if test_data.dtype == np.uint8:
            test_data = test_data.astype(float)/2**8
        elif test_data.dtype == np.uint16:
            test_data = test_data.astype(float)/2**16

        # --- Baseline: Gaussian-smoothed local maxima vs. reference ---
        pred_blobcentroids = get_pred_coords_gauss(test_data, s=2.5, min_distance=3, threshold_abs=0.5)
        true_blobmatched, true_not_blobmatched, pred_blobmatched, pred_not_blobmatched = detect_and_match(true_centroids, pred_blobcentroids, dist_thres=6)
        tp_blob = len(true_blobmatched)
        fn_blob = len(true_not_blobmatched)
        fp_blob = len(pred_not_blobmatched)

        TP_blob[i,j] = tp_blob
        precision_blob[i,j], recall_blob[i,j], f1_blob[i,j] = _detection_scores(tp_blob, fp_blob, fn_blob)

        _plot_detections(
            'Blob Fiber Detection (Undersampling Factor: {} | Noise Level: {}) - TP: {}, FP: {}, FN: {}'.format(undersample_factor[i], noise_level[j], tp_blob, fp_blob, fn_blob),
            test_data,
            true_blobmatched, true_not_blobmatched, pred_blobmatched, pred_not_blobmatched,
        )

        _plot_detections(
            'UNET Fiber Detection (Undersampling Factor: {} | Noise Level: {}) - TP: {}, FP: {}, FN: {}'.format(undersample_factor[i], noise_level[j], tp, fp, fn),
            test_data,
            true_matched, true_not_matched, pred_matched, pred_not_matched,
        )
        

In [None]:
# U-Net precision (y) against blob-baseline precision (x), one curve per
# undersampling factor; points along a curve follow noise_level order.
# Both axes are limited to [0.5, 1.0], so lower values fall outside the view.
fig_prec, ax_prec = plt.subplots(figsize=(8, 8))
for i, uf in enumerate(undersample_factor):
    ax_prec.plot(precision_blob[i,:], precision[i,:], 'o-', label='UF=' + str(uf))
ax_prec.set_xlim(0.5, 1.0)
ax_prec.set_ylim(0.5, 1.0)
ax_prec.plot([0.5, 1.0], [0.5, 1.0], 'k--')  # y = x: both methods equal
ax_prec.set_xlabel('Blob Precision')
ax_prec.set_ylabel('UNET Precision')
ax_prec.legend(loc='lower right')
plt.show()

print('precision_blob:', precision_blob)
print('precision_UNET:', precision)

In [None]:
# U-Net recall (y) against blob-baseline recall (x), one curve per
# undersampling factor; points along a curve follow noise_level order.
# Both axes are limited to [0.5, 1.0], so lower values fall outside the view.
fig_rec, ax_rec = plt.subplots(figsize=(8, 8))
for i, uf in enumerate(undersample_factor):
    ax_rec.plot(recall_blob[i,:], recall[i,:], 'o-', label='UF=' + str(uf))
ax_rec.set_xlim(0.5, 1.0)
ax_rec.set_ylim(0.5, 1.0)
ax_rec.plot([0.5, 1.0], [0.5, 1.0], 'k--')  # y = x: both methods equal
ax_rec.set_xlabel('Blob Recall')
ax_rec.set_ylabel('UNET Recall')
ax_rec.legend(loc='lower right')
plt.show()

print('recall_blob:', recall_blob)
print('recall_UNET:', recall)

In [None]:
# U-Net F1 (y) against blob-baseline F1 (x), one curve per undersampling
# factor; points along a curve follow noise_level order.
# Both axes are limited to [0.5, 1.0], so lower values fall outside the view.
fig_f1, ax_f1 = plt.subplots(figsize=(8, 8))
for i, uf in enumerate(undersample_factor):
    ax_f1.plot(f1_blob[i,:], f1[i,:], 'o-', label='UF=' + str(uf))
ax_f1.set_xlim(0.5, 1.0)
ax_f1.set_ylim(0.5, 1.0)
ax_f1.plot([0.5, 1.0], [0.5, 1.0], 'k--')  # y = x: both methods equal
ax_f1.set_xlabel('Blob F1 Score')
ax_f1.set_ylabel('UNET F1 Score')
ax_f1.legend(loc='lower right')
plt.show()

print('f1_blob:', f1_blob)
print('f1_UNET:', f1)