# EVALUATION SCRIPT FOR SVMs AND CNNs
Place all of your prediction files in one folder (or collect them in a list), whether they are for synthetic or real holograms.
Ground-truth data is provided as a .pkl file for synthetic holograms and as .txt files for real holograms; the script handles both formats.
Make sure your SVM/CNN predictions are all in the same folder. If they are not, combine them into a single list.



In [None]:
import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
from utils import f1, recall, precision, dice_score
import mat73
import glob
from patchify import patchify, unpatchify
import pickle

In [13]:
def gkern(l=5, sig=1.):
    """Return an ``l`` x ``l`` unnormalised 2-D Gaussian kernel with std ``sig``.

    The kernel is the outer product of a 1-D Gaussian sampled at ``l`` evenly
    spaced offsets from ``-(l-1)/2`` to ``(l-1)/2``. It is not normalised to sum
    to 1; for odd ``l`` the central value is exactly 1.
    """
    half_width = (l - 1) / 2.
    offsets = np.linspace(-half_width, half_width, l)
    profile = np.exp(-0.5 * np.square(offsets) / np.square(sig))
    # Separable Gaussian: k[i, j] = profile[i] * profile[j]
    return profile[:, None] * profile[None, :]

def peak_local_max(input, threshold_abs=1, min_distance=1):
    '''
    Return a boolean map marking the local maxima of ``input``.

        Parameters:
            input (torch.Tensor): image-like tensor, (N, C, H, W) or unbatched (C, H, W);
                the max filter runs over the last two dimensions of every image independently.
            threshold_abs (float): maxima whose value is below this are discarded
                (values equal to it are kept).
            min_distance (int): half-width of the square max-filter window, which is
                (2*min_distance + 1) pixels on a side.
        Returns
            boolean torch.Tensor with the same shape as ``input``. Pixels on a flat
            plateau all count as maxima, since only equality with the window max is tested.
    '''
    window = 2 * min_distance + 1
    # A pixel is a local maximum iff it equals the maximum over its window.
    neighbourhood_max = nn.functional.max_pool2d(
        input, kernel_size=window, stride=1, padding=min_distance
    )
    is_peak = neighbourhood_max == input
    strong_enough = input >= threshold_abs
    return is_peak & strong_enough

def mask_maker(X, Y, Z, R, N_x = 512, N_y = 512, kernel_size = 5):
    """Rasterise a list of particles into a 3-channel (3, N_y, N_x) mask.

    For every particle a ``kernel_size`` x ``kernel_size`` footprint is written:
    channel 0 gets the Gaussian kernel ``gkern(kernel_size, 1)``, channel 1 the
    particle's ``Z`` value and channel 2 its ``R`` value. Where footprints
    overlap, later particles overwrite earlier ones.

    Parameters:
        X, Y: particle column / row positions in output pixels; rounded to int.
        Z, R: per-particle values written into channels 1 and 2.
        N_x, N_y: width and height of the returned mask.
        kernel_size: side length of the footprint (odd sizes centre the kernel).

    Returns:
        float64 array of shape (3, N_y, N_x).

    Notes:
        Drawing happens on a padded canvas which is then cropped by
        ``1 + kernel_size//2`` on the top/left, so for odd ``kernel_size`` a
        particle at (X, Y) ends up centred at row ``Y-1``, column ``X-1`` of the
        output. NOTE(review): presumably this converts 1-based input coordinates
        to 0-based — confirm.
        Footprints that fall partly off the padded canvas are clipped, and those
        entirely off it are skipped. Without clipping they would raise a
        broadcast error or silently wrap around through negative indexing.
    """
    X = np.round(np.asarray(X)).astype(int)
    Y = np.round(np.asarray(Y)).astype(int)
    Z = np.asarray(Z)
    R = np.asarray(R)

    gk = gkern(l=kernel_size, sig=1)
    half = kernel_size // 2
    canvas_h = N_y + kernel_size + 2
    canvas_w = N_x + kernel_size + 2
    mask = np.zeros((3, canvas_h, canvas_w))

    for x, y, z, rad in zip(X, Y, Z, R):
        # Footprint spans rows [y, y + kernel_size) and cols [x, x + kernel_size)
        # on the padded canvas; intersect it with the canvas bounds.
        row0, row1 = max(y, 0), min(y + kernel_size, canvas_h)
        col0, col1 = max(x, 0), min(x + kernel_size, canvas_w)
        if row0 >= row1 or col0 >= col1:
            continue  # entirely off-canvas
        rows, cols = slice(row0, row1), slice(col0, col1)
        mask[0, rows, cols] = gk[row0 - y:row1 - y, col0 - x:col1 - x]
        mask[1, rows, cols] = z
        mask[2, rows, cols] = rad

    return mask[:, 1 + half:N_y + 1 + half, 1 + half:N_x + 1 + half]



def make_real_gt_mask(pth_to_real_gt, Z_locs, units, file_index, hologram_size, edge_crop_dist, ds_factor, hitbox):
    """Build ground-truth masks for real holograms from per-hologram ``.txt`` files.

    Each ``.txt`` under ``pth_to_real_gt`` (matched with ``pth_to_real_gt + "*.txt"``
    and processed in sorted order) holds one particle per row: column 0 = x,
    column 1 = y, column 2 = size, where the size is multiplied by ``dr``. Every
    particle of file ``j`` gets the depth ``Z_locs[j]``.

    Parameters:
        pth_to_real_gt: path prefix (e.g. a directory ending in a separator).
        Z_locs: one depth value per GT file, indexed in sorted-file order.
        units: (dx, dz, dr); only ``dr`` is used here.
        file_index: if not None, only the file at this (sorted) index is processed,
            with depth ``Z_locs[file_index]``.
        hologram_size: full hologram side length in GT pixel units.
        edge_crop_dist: border (full-resolution pixels) removed on every side.
        ds_factor: downsampling factor applied to x/y before rasterising.
        hitbox: kernel size passed to ``mask_maker``.

    Returns:
        float64 array (num_files, 3, S, S), where
        S = (hologram_size - edge_crop_dist)//ds_factor - edge_crop_dist//ds_factor.
    """
    # Sort so the j-th file lines up with Z_locs[j] (glob order is arbitrary).
    gt_files = sorted(glob.glob(pth_to_real_gt + "*.txt"))
    _, _, dr = units

    if file_index is not None:
        gt_files = [gt_files[file_index]]
        z_per_file = [Z_locs[file_index]]
    else:
        z_per_file = Z_locs

    # Crop bounds are in downsampled pixels, so the output size must be too.
    lo = edge_crop_dist // ds_factor
    hi = (hologram_size - edge_crop_dist) // ds_factor
    xy_slices = (slice(3), slice(lo, hi), slice(lo, hi))
    masks = np.zeros((len(gt_files), 3, hi - lo, hi - lo))
    full_ds = hologram_size // ds_factor

    for j, gt_file in enumerate(gt_files):
        data = np.genfromtxt(gt_file)
        if data.size == 0:
            continue  # no particles in this file: its mask stays zero
        data = np.atleast_2d(data)  # a single-row file loads as 1-D

        x = data[:, 0]
        y = data[:, 1]
        d = data[:, 2] * dr
        z = np.full(x.shape[0], z_per_file[j], dtype=float)

        # Keep particles strictly inside (0, hologram_size) on both axes.
        inside = (x > 0) & (y > 0) & (x < hologram_size) & (y < hologram_size)

        tgt = mask_maker(x[inside] / ds_factor, y[inside] / ds_factor, z[inside], d[inside],
                         full_ds, full_ds, kernel_size=hitbox)

        # Flip every channel vertically (np.flipud on each channel).
        # NOTE(review): presumably the GT y-axis points the opposite way to the image rows — confirm.
        true_mask = tgt[:, ::-1, :]

        masks[j] = true_mask[xy_slices]

    return masks

def make_synthetic_gt_mask(pkl_file, num_holos, units, holo_size, crop_size, step_size, ds_factor, kernel_size):
    """Build patched ground-truth masks for synthetic holograms.

    For each hologram ``i`` the particles ``pkl_file['x'|'y'|'z'|'r'][i]`` are
    rasterised with ``mask_maker`` and the mask is cut into patches with
    ``patchify``. x/y are divided by ``dx*ds_factor`` and shifted by half the
    downsampled hologram size; z is divided by ``dz`` and r by ``dr``.

    Parameters:
        pkl_file: mapping with per-hologram sequences under 'x', 'y', 'z', 'r'.
        num_holos: number of holograms to process.
        units: (dx, dz, dr).
        holo_size, crop_size, step_size: full-resolution hologram side, patch side
            and patch stride; all are divided by ``ds_factor``.
        ds_factor: downsampling factor.
        kernel_size: footprint size passed to ``mask_maker``.

    Returns:
        float16 array (num_holos * patches_per_holo, 3, patch, patch). Channel 2
        is stored as 2 * r; the original comment says GT holds the radius, so this
        is presumably a diameter (TODO confirm). float16 limits z/size precision.
    """
    dx, dz, dr = units
    dx = dx * ds_factor
    holo_size = holo_size // ds_factor
    patch_size = crop_size // ds_factor
    step = step_size // ds_factor
    if patch_size > holo_size:
        raise ValueError("crop_size must not exceed holo_size after downsampling")

    # patchify yields (holo_size - patch_size)//step + 1 windows per axis; this
    # equals holo_size//patch_size only in the non-overlapping case step == patch.
    patches_per_axis = (holo_size - patch_size) // step + 1
    num_patches_per_holo = patches_per_axis ** 2
    store_masks = np.zeros((num_holos * num_patches_per_holo, 3, patch_size, patch_size), dtype=np.float16)

    msk_count = 0
    for i in trange(num_holos):
        mask = mask_maker(pkl_file['x'][i] / dx + holo_size // 2, pkl_file['y'][i] / dx + holo_size // 2,
                          pkl_file['z'][i] / dz, pkl_file['r'][i] / dr, holo_size, holo_size, kernel_size)
        for ch in range(3):
            patches = patchify(mask[ch], patch_size=patch_size, step=step)
            patches = patches.reshape(-1, patches.shape[2], patches.shape[3])
            if ch == 2:
                patches = 2 * patches  # radius in the ground truth
            store_masks[msk_count:msk_count + patches.shape[0], ch] = patches

        msk_count += patches.shape[0]

    return store_masks

def make_cnn_prediction_masks(pth_to_preds, cutoff, hologram_size, units, ds_factor, kernel_size):
    """Rasterise CNN predictions stored in ``.txt`` files into masks.

    Files matching ``pth_to_preds + "*.txt"`` are processed in sorted order. Each
    row is one detection: columns 0-3 are x, y, z, d and the last column is a
    score. Rows with score < ``cutoff`` are dropped. x/y are divided by
    ``dx*ds_factor`` and shifted by half the downsampled hologram size; z is
    divided by ``dz`` and d by ``dr``.

    Parameters:
        pth_to_preds: path prefix of the prediction files.
        cutoff: minimum score (last column) for a detection to be kept.
        hologram_size: full-resolution hologram side length.
        units: (dx, dz, dr).
        ds_factor: downsampling factor.
        kernel_size: footprint size passed to ``mask_maker``.

    Returns:
        float16 array (num_files, 3, H, H) with H = hologram_size // ds_factor.
        Files with no rows give an all-zero mask.
    """
    dx, dz, dr = units
    dx = dx * ds_factor
    holo_size = hologram_size // ds_factor

    predictions_list = sorted(glob.glob(pth_to_preds + "*.txt"))

    masks = np.zeros((len(predictions_list), 3, holo_size, holo_size), dtype=np.float16)

    for i, pred_file in enumerate(predictions_list):
        data = np.genfromtxt(pred_file)
        if data.size == 0:
            continue  # empty file: no detections, mask stays zero
        data = np.atleast_2d(data)  # a single-detection file loads as 1-D
        data = data[data[:, -1] >= cutoff]
        x = data[:, 0] / dx + holo_size // 2
        y = data[:, 1] / dx + holo_size // 2
        z = data[:, 2] / dz
        d = data[:, 3] / dr

        masks[i] = mask_maker(x, y, z, d, holo_size, holo_size, kernel_size)

    return masks

def make_svm_prediction_masks(pth_to_preds, cutoff, hologram_size, units, ds_factor, kernel_size):
    """Rasterise SVM predictions stored in MATLAB v7.3 ``.mat`` files into masks.

    Files matching ``pth_to_preds + "*.mat"`` are loaded with ``mat73`` in sorted
    order. Each row of ``data['metrics']`` is one detection:
    column 100 = x, column 101 = y, column 105 = z, column 1 = area. The size
    is the diameter of a circle with that area, sqrt(4*area/pi). x/y are divided
    by ``dx*ds_factor`` and shifted by half the downsampled hologram size; z is
    divided by ``dz`` and the diameter by ``dr``.

    Parameters:
        pth_to_preds: path prefix of the prediction files.
        cutoff: NOTE(review): accepted but never used in this function — no
            score column is read, so no detections are filtered out.
        hologram_size: full-resolution hologram side length.
        units: (dx, dz, dr).
        ds_factor: downsampling factor.
        kernel_size: footprint size passed to ``mask_maker``.

    Returns:
        float16 array (num_files, 3, H, H) with H = hologram_size // ds_factor.
    """
    dx, dz, dr = units
    dx = dx * ds_factor
    holo_size = hologram_size // ds_factor

    predictions_list = sorted(glob.glob(pth_to_preds + "*.mat"))

    masks = np.zeros((len(predictions_list), 3, holo_size, holo_size), dtype=np.float16)

    for i, pred_file in enumerate(predictions_list):
        metrics = mat73.loadmat(pred_file)['metrics']
        # NOTE(review): presumably an empty matrix may load as None/empty — treat as no detections.
        if metrics is None or np.size(metrics) == 0:
            continue
        metrics = np.atleast_2d(metrics)  # guard against a single row loading as 1-D
        x = metrics[:, 100] / dx + holo_size // 2
        y = metrics[:, 101] / dx + holo_size // 2
        z = metrics[:, 105] / dz
        d = np.sqrt(4 / (np.pi) * metrics[:, 1]) / dr

        masks[i] = mask_maker(x, y, z, d, holo_size, holo_size, kernel_size)

    return masks

def get_bad_particles_counts(pred_data, true_data, hits, constraints):
    """Count hits to discard: GT size out of range, and depth outliers.

    Parameters:
        pred_data, true_data: torch tensors whose channel axis is dim 1
            (channel 1 = z, channel 2 = size); presumably (B, 3, H, W) — TODO confirm.
        hits: boolean tensor selecting hit pixels from ``X[:, c].unsqueeze(1)``,
            so shaped like (B, 1, H, W).
        constraints: (min_z, max_z, min_r, max_r, allowed_ez). The z range is
            not applied here. ``allowed_ez`` is honoured; it used to be
            overwritten with a hard-coded 10.

    Returns:
        (out_of_scope, outliers) as ints:
        out_of_scope -- hits whose GT size is < min_r or > max_r.
        outliers -- remaining (in-scope) hits with |z_pred - z_gt| > allowed_ez.
    """
    _, _, min_r, max_r, allowed_ez = constraints

    gt_z = true_data[:, 1].unsqueeze(1)[hits]
    gt_size = true_data[:, 2].unsqueeze(1)[hits]
    pred_z = pred_data[:, 1].unsqueeze(1)[hits]

    out_of_scope = (gt_size < min_r) | (gt_size > max_r)
    # Out-of-scope hits are excluded first, so outliers are counted among in-scope hits only.
    outliers = ~out_of_scope & ((pred_z - gt_z).abs() > allowed_ez)
    return int(out_of_scope.sum()), int(outliers.sum())

def get_z_and_size(pred_masks, true_masks, best_cutoff, min_distance, hit_box_size_param, constraints):
    """Collect paired ground-truth / predicted depth and size for accepted hits.

    For each (pred_data, true_data) pair:
    * predicted particles are peaks of ``pred_data[:, 0]`` found with
      ``peak_local_max(threshold_abs=best_cutoff, min_distance)``;
    * a hit is a predicted particle lying where ``true_data[0, 0] >= hit_box_size_param``;
    * hits whose GT size is outside [min_r, max_r] are dropped, then hits whose
      |z_pred - z_gt| > allowed_ez are dropped.

    Parameters:
        pred_masks, true_masks: iterables of torch tensors with channels on dim 1
            (0 = hitbox, 1 = z, 2 = size); presumably each item is (B, 3, H, W) — TODO confirm.
        best_cutoff: pixel threshold used for peak detection on the prediction channel 0.
        min_distance: peak-detection half-window.
        hit_box_size_param: GT channel-0 level that defines the hit region.
        constraints: (min_z, max_z, min_r, max_r, allowed_ez); the z range is unused.

    Returns:
        (z_det, z_pred, d_det, d_pred): lists of floats holding GT depth, predicted
        depth, GT size and predicted size of each accepted hit.
    """
    z_det, z_pred, d_det, d_pred = [], [], [], []
    _, _, min_r, max_r, allowed_ez = constraints
    for pred_data, true_data in zip(pred_masks, true_masks):
        predicted_particles = peak_local_max(pred_data[:, 0].unsqueeze(1), threshold_abs=best_cutoff,
                                             min_distance=min_distance).float()
        gt_hit_region = (true_data[0, 0] >= hit_box_size_param).float().unsqueeze(0)
        hits = peak_local_max(predicted_particles * gt_hit_region, threshold_abs=1, min_distance=min_distance)

        gt_z = true_data[:, 1].unsqueeze(1)[hits]
        gt_size = true_data[:, 2].unsqueeze(1)[hits]
        pr_z = pred_data[:, 1].unsqueeze(1)[hits]
        pr_size = pred_data[:, 2].unsqueeze(1)[hits]

        out_of_scope = (gt_size < min_r) | (gt_size > max_r)
        outlier = (pr_z - gt_z).abs() > allowed_ez
        keep = ~out_of_scope & ~outlier

        z_det.extend(gt_z[keep].tolist())
        z_pred.extend(pr_z[keep].tolist())
        d_det.extend(gt_size[keep].tolist())
        d_pred.extend(pr_size[keep].tolist())

    return z_det, z_pred, d_det, d_pred

def get_hists(true_masks, gt_thresh, min_distance, constraints):
    """Collect depth and size of every in-range ground-truth particle.

    Particles are the peaks of ``true_data[0, 0]`` (threshold ``gt_thresh``,
    half-window ``min_distance``). Their depth comes from ``true_data[0, 1]``
    and their size from ``true_data[0, 2]``. Particles with size < min_r or
    > max_r are skipped; the old ``and`` test could never be true, so it
    filtered nothing.

    Parameters:
        true_masks: iterable of torch tensors; presumably each is (B, 3, H, W)
            and only batch element 0 is read — TODO confirm.
        gt_thresh: peak threshold on channel 0.
        min_distance: peak-detection half-window.
        constraints: (min_z, max_z, min_r, max_r, allowed_ez); only min_r/max_r are used.

    Returns:
        (hist_true_z, hist_true_r): lists of floats (depths, sizes).
    """
    hist_true_r = []
    hist_true_z = []
    _, _, min_r, max_r, _ = constraints
    for true_data in true_masks:
        peaks_in_true_data = peak_local_max(true_data[0, 0].unsqueeze(0), gt_thresh, min_distance).squeeze(0)
        depths_in_true_data = true_data[0, 1][peaks_in_true_data]
        sizes_in_true_data = true_data[0, 2][peaks_in_true_data]
        for depth, size in zip(depths_in_true_data, sizes_in_true_data):
            if size < min_r or size > max_r:
                continue
            hist_true_r.append(float(size))
            hist_true_z.append(float(depth))
    return hist_true_z, hist_true_r


In [14]:
def get_precision_recall_f1(pred_masks, true_masks, constraints, hitbox, cutoff_vals=None):
    """Average per-sample precision, recall and F1 over a set of masks.

    Per (pred_data, true_data) pair:
    * predicted particles = peaks of ``pred_data[:, 0]`` at ``gt_thresh``;
    * hits = predicted particles inside the GT hit region
      (``true_data[0, 0] >= gkern(hitbox, 1).min()``);
    * tp = hits minus out-of-scope (GT size outside [min_r, max_r]) and depth outliers,
      both counted by ``get_bad_particles_counts``;
    * fp = (predicted - out_of_scope) - tp;
    * fn = (GT peaks with size in [min_r, max_r]) - tp. If that comes out negative,
      the surplus hits are moved from tp to fp and fn becomes 0.

    Parameters:
        pred_masks, true_masks: torch tensors / iterables of per-sample tensors with
            channels on dim 1 (0 = hitbox, 1 = z, 2 = size); presumably each item is
            (B, 3, H, W) — TODO confirm. ``pred_masks.shape[0]`` is the sample count.
        constraints: (min_z, max_z, min_r, max_r, allowed_ez).
        hitbox: hit-box kernel size; also sets min_distance = hitbox // 2.
        cutoff_vals: accepted for backward compatibility and ignored. The score
            cutoff is applied when the prediction masks are built, and looping
            over it here only recomputed the same result.

    Returns:
        (P, R, F1) as floats, each the mean over samples of the per-sample value.
    """
    _, _, min_r, max_r, _ = constraints
    hit_box_size_param = gkern(hitbox, 1).min()
    gt_thresh = 0.8
    min_distance = hitbox // 2
    num_samples = pred_masks.shape[0]

    pr_sum = 0
    rc_sum = 0
    f1_sum = 0
    for pred_data, true_data in zip(pred_masks, true_masks):
        predicted_particles = peak_local_max(pred_data[:, 0].unsqueeze(1), threshold_abs=gt_thresh,
                                             min_distance=min_distance).float()
        gt_hit_region = (true_data[0, 0] >= hit_box_size_param).float().unsqueeze(0)
        hits = peak_local_max(predicted_particles * gt_hit_region, threshold_abs=1, min_distance=min_distance)

        out_of_scope, outliers = get_bad_particles_counts(pred_data, true_data, hits, constraints)
        tp = (hits.sum() - out_of_scope) - outliers
        fp = (predicted_particles.sum() - out_of_scope) - tp

        # GT particles, excluding those whose size is outside [min_r, max_r].
        gt_peaks = peak_local_max(true_data[:, 0].unsqueeze(1), gt_thresh, 1)
        gt_sizes = true_data[:, 2].unsqueeze(1)[gt_peaks]
        out_of_scope_for_gt = ((gt_sizes < min_r) | (gt_sizes > max_r)).float().sum()
        fn = (gt_peaks.sum() - out_of_scope_for_gt) - tp

        # More accepted hits than GT particles: the surplus are false positives.
        if fn < 0:
            extra_hits = -fn
            tp = tp - extra_hits
            fp = fp + extra_hits
            fn = fn + extra_hits  # == 0

        p_i = precision(tp, fp)
        r_i = recall(tp, fn)
        pr_sum += p_i
        rc_sum += r_i
        f1_sum += f1(p_i, r_i)  # per-sample F1, not F1 of the running sums

    P = float(pr_sum / num_samples)
    R = float(rc_sum / num_samples)
    F1 = float(f1_sum / num_samples)
    return P, R, F1
    


In [None]:
# ---- Synthetic ground truth (.pkl file) ----
# NOTE: the synthetic and real sections below both run; the real section
# overwrites tgt_masks / hitbox / unit variables. Comment out the one you don't need.
pth_to_syn_gt = ''
num_synthetic_test_holos = 1000
dr = 1  # in the gt only xy are in m, size and z are in µm and mm.
dz = 1
dx = 3e-6
hologram_size = 4096
crop_size = 1536
ds_factor = 4
hitbox = 21 // ds_factor
# SECURITY: allow_pickle=True unpickles arbitrary objects — only load trusted files.
tgt_pkl = np.load(pth_to_syn_gt, allow_pickle=True)
tgt_masks = make_synthetic_gt_mask(tgt_pkl, num_synthetic_test_holos, (dx, dz, dr), holo_size=hologram_size,
                                   crop_size=crop_size, step_size=crop_size, ds_factor=ds_factor, kernel_size=hitbox)

# ---- Real ground truth (.txt files) ----
pth_to_real_gt = ''
dr = 3  # size is given in pixels for some reason in CloudTarget
dz = 1
dx = 1  # (x,y) also in pixels
units = (dx, dz, dr)
hologram_size = 5120
ds_factor = 4
hitbox = 21 // ds_factor
edge_crop_dist = 512  # 512, #256
Zs = [50, 99, 167, 192, 75]  # one depth per GT file, indexed in file order
tgt_masks = make_real_gt_mask(pth_to_real_gt, Zs, units, None, hologram_size, edge_crop_dist, ds_factor, hitbox)
tgt = torch.from_numpy(tgt_masks)

# ---- Predictions from CNNs (.txt files) ----
# cnn preds are in m
pth_to_cnn_preds = ''
dr = 1e-6
dz = 1e-3
dx = 3e-6
hologram_size = 1536  # 5120
ds_factor = 4
pred_kernel_size = 3

cutoff = 0.69  # random value for svm
cutoff_vals = [cutoff]
cutoff_vals = np.arange(100) / 100  # for cnn (overrides the svm line above)

# Constraints
min_z = 5
max_z = 200
min_r = 6.0  # this is the real parameter setting precision recall with respect to r
max_r = 100
ez_allowed = 10
# Every evaluation helper unpacks five values: (min_z, max_z, min_r, max_r, allowed_ez).
constraints = (min_z, max_z, min_r, max_r, ez_allowed)
hit_box_size_param = gkern(hitbox, 1).min()
gt_thresh = 0.8
min_distance = hitbox // 2


xy_slices = (slice(None), slice(None), slice(None), slice(None))  # for synthetic
# NOTE(review): this crop uses the prediction hologram_size set above; it must match the
# hologram size used for the ground truth for the two masks to line up — confirm.
xy_slices = (slice(None), slice(3), slice(edge_crop_dist // ds_factor, (hologram_size - edge_crop_dist) // ds_factor),
             slice(edge_crop_dist // ds_factor, (hologram_size - edge_crop_dist) // ds_factor))  # for real





In [None]:
# calculate stats
prec, rec, f1_score = [], [], []
for cutoff in cutoff_vals:
    pred_masks = make_cnn_prediction_masks(pth_to_cnn_preds, cutoff, hologram_size, (dx, dz, dr), ds_factor, pred_kernel_size)
    # pred_masks = make_svm_prediction_masks(pth_to_cnn_preds, cutoff, hologram_size, (dx, dz, dr), ds_factor, pred_kernel_size)

    # convert to torch and crop to the ground-truth region
    preds = torch.from_numpy(pred_masks)[xy_slices]

    # Predictions first, ground truth second, matching
    # get_precision_recall_f1(pred_masks, true_masks, constraints, hitbox, cutoff_vals).
    P, R, F1 = get_precision_recall_f1(preds, tgt, constraints, hitbox, [cutoff])
    prec.append(P)
    rec.append(R)
    f1_score.append(F1)

# Select from the per-cutoff lists; P, R and F1 only hold the last iteration's scalars.
best_cutoff_idx = int(np.argmax(f1_score))
best_cutoff = cutoff_vals[best_cutoff_idx]
best_f1 = f1_score[best_cutoff_idx]
rec_at_best_f1 = rec[best_cutoff_idx]
prec_at_best_f1 = prec[best_cutoff_idx]

# `preds` currently holds the masks of the *last* cutoff; rebuild them at the best one.
pred_masks = make_cnn_prediction_masks(pth_to_cnn_preds, best_cutoff, hologram_size, (dx, dz, dr), ds_factor, pred_kernel_size)
# pred_masks = make_svm_prediction_masks(pth_to_cnn_preds, best_cutoff, hologram_size, (dx, dz, dr), ds_factor, pred_kernel_size)
preds = torch.from_numpy(pred_masks)[xy_slices]

# The helpers below use torch methods (.unsqueeze), so pass the torch GT `tgt`, not numpy `tgt_masks`.
# NOTE(review): best_cutoff is a score cutoff, but get_z_and_size uses it as the pixel threshold for
# peak detection (get_precision_recall_f1 uses gt_thresh for that) — confirm this is intended.
z_det, z_pred, d_det, d_pred = get_z_and_size(preds, tgt, best_cutoff, min_distance, hit_box_size_param, constraints)

hist_true_z, hist_true_r = get_hists(tgt, gt_thresh, min_distance, constraints)

prediction_dict = {
        "precision": prec,
        "recall": rec,
        "F1": f1_score,
        "best_F1":  float(best_f1),
        "precision_at_best_f1": float(prec_at_best_f1),
        "recall_at_best_f1": float(rec_at_best_f1),
        "best_cutoff": float(best_cutoff),
        "best_cutoff_idx": best_cutoff_idx,
        "z_detected": np.array(z_det),
        "z_predicted": np.array(z_pred),
        "d_detected": np.array(d_det),
        "d_predicted": np.array(d_pred),
        "z_all_in_gt": np.array(hist_true_z),
        "d_all_in_gt": np.array(hist_true_r)
    }

# qualitatively check for overlap (both masks cropped to the same region)
plt.imshow(tgt_masks[0, 0])
plt.imshow(preds[0, 0].float().numpy(), alpha=0.5)



In [None]:
# save predictions (prefix the file name with an output directory if needed)
output_path = "" + "CNN_preds.pkl"
with open(output_path, mode='wb+') as out_file:
    pickle.dump(prediction_dict, out_file)