In [1]:
from argparse import ArgumentParser
import yaml
import numpy as np
from skimage.io import imread
from skimage.morphology import binary_dilation, disk
from skimage.measure import regionprops, label
import os
import copy
from glob import glob
import pickle
from tqdm import tqdm

In [2]:
def getid(samp_path):
    """Return the sample id encoded in a file path.

    The id is the last "/"-separated component of ``samp_path``, cut at its
    first "." (e.g. "dir/OHSU_TMA1_004-A1.ome.tif" -> "OHSU_TMA1_004-A1").
    """
    _, _, filename = samp_path.rpartition("/")
    stem, _, _ = filename.partition(".")
    return stem

In [3]:
def create_pmask(segpath, r):
    """Build a binary point mask with one dilated disk per segmented object.

    Reads the label image at ``segpath``, marks a single pixel at the centroid
    of every labelled region (coordinates truncated with ``astype(int)``), then
    dilates those pixels with a disk of radius ``r``.

    Parameters
    ----------
    segpath : str
        Path to a segmentation label image readable by ``imread``
        (0 = background; assumed 2-D -- TODO confirm for multi-channel files).
    r : int
        Radius in pixels of the disk used for dilation.

    Returns
    -------
    numpy.ndarray of bool
        Same shape as the label image; all False when it contains no objects.
    """
    seg = imread(segpath)
    mask = np.zeros(seg.shape, dtype=bool)
    rps = regionprops(seg)
    if not rps:
        # No objects: the centroid array below would be 1-D and indexing
        # it with [:, 0] would raise IndexError.
        return mask
    # astype(int) truncates toward zero, keeping the original pixel placement.
    centroids = np.array([np.array(rp.centroid).astype(int) for rp in rps])
    mask[centroids[:, 0], centroids[:, 1]] = True
    return binary_dilation(mask, disk(r))

In [4]:
def _find_segmentation(method_dir, sid):
    """Return the segmentation file in ``method_dir`` that belongs to ``sid``."""
    from glob import escape  # local: the notebook only imports glob.glob

    candidates = []
    for path in glob(os.path.join(escape(method_dir), escape(sid) + "*")):
        rest = os.path.basename(path)[len(sid):]
        # A bare prefix match is ambiguous for TMA core names ("...-A1*" also
        # matches "...-A10.tif"), so the id must end at a non-alphanumeric char.
        if not rest or not rest[0].isalnum():
            candidates.append(path)
    if not candidates:
        raise FileNotFoundError(f"No segmentation for sample {sid!r} in {method_dir}")
    # glob order is filesystem-dependent; sort so the choice is reproducible.
    return sorted(candidates)[0]


def compute_masks(sid, radii, methods, datapath):
    """Compute per-method point masks and their per-pixel mean for each radius.

    Parameters
    ----------
    sid : str
        Sample id; its segmentation is looked up as ``<datapath>/<method>/<sid>*``.
    radii : iterable of int
        Dilation radii passed to ``create_pmask``.
    methods : sequence of str
        Method sub-folder names under ``datapath``.
    datapath : str
        Root folder holding one sub-folder per method.

    Returns
    -------
    dict
        ``{r: {method: bool mask, ..., "mean": float array}}``, where "mean" is
        the fraction of methods whose mask covers each pixel.

    Raises
    ------
    ValueError
        If ``methods`` is empty.
    FileNotFoundError
        If a method folder has no file for ``sid``.
    """
    if not methods:
        raise ValueError("At least one method is required")
    # Resolve each method's file once instead of re-globbing for every radius.
    segpaths = {m: _find_segmentation(os.path.join(datapath, m), sid) for m in methods}
    masks = {}
    for r in radii:
        rmasks = {m: create_pmask(segpaths[m], r) for m in methods}
        rmasks["mean"] = np.stack(list(rmasks.values())).mean(0)
        masks[r] = rmasks
    return masks

In [5]:
def write_pmasks(sample_ids, radii, methods, datapath, save_dir):
    """Compute and pickle probability masks per sample, skipping cached ones.

    For each id writes ``<save_dir>/<sid>.pkl`` holding the dict returned by
    ``compute_masks`` (``{r: {method: mask, ..., "mean": avg}}``).

    NOTE: an existing pickle is reused as-is even if ``radii``/``methods`` have
    changed since it was written; delete ``save_dir`` to force a recompute.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Writing probability masks for {len(sample_ids)} samples")
    for sid in tqdm(sample_ids):
        save_path = os.path.join(save_dir, f"{sid}.pkl")
        if os.path.exists(save_path):
            continue

        sid_masks = compute_masks(sid, radii, methods, datapath)
        # Write to a temp file then rename atomically, so an interrupted run
        # never leaves a truncated pickle that the check above would trust.
        tmp_path = save_path + ".tmp"
        with open(tmp_path, "wb") as handle:
            pickle.dump(sid_masks, handle)
        os.replace(tmp_path, save_path)
    print(f"Probability masks saved to {save_dir}")



In [6]:
def filter_mask(mask, avg_labs):
    """Trim each connected component of ``mask`` to a single consensus object.

    ``avg_labs`` is a label image of consensus regions (0 = background). A
    component of ``mask`` that overlaps two or more distinct nonzero labels
    keeps only the pixels lying on the label it overlaps most (ties go to the
    smaller label); all its other pixels, background included, become False.
    Components overlapping at most one nonzero label are left untouched.
    The input ``mask`` is not modified; a copy is returned.
    """
    trimmed = copy.deepcopy(mask)
    for region in regionprops(label(mask)):
        pix = region.coords
        under = avg_labs[pix[:, 0], pix[:, 1]]
        labels_seen, overlap = np.unique(under, return_counts=True)
        # np.unique sorts, so background (0) is first when present: drop it.
        start = 1 if labels_seen[0] == 0 else 0
        labels_seen, overlap = labels_seen[start:], overlap[start:]
        if len(labels_seen) <= 1:
            continue
        winner = labels_seen[np.argmax(overlap)]
        losers = pix[under != winner]
        trimmed[losers[:, 0], losers[:, 1]] = False
    return trimmed

In [7]:
def filter_pmasks(sample_ids, pmask_save_dir, filtered_save_dir, min_num_agree, methods):
    """Run ``filter_mask`` over every cached probability mask and re-pickle.

    For each sample and radius, the consensus is the set of pixels whose
    ``"mean"`` is at least ``min_num_agree / len(methods)``. Each method's mask
    is filtered against the labelled consensus, and a fresh ``"mean"`` is
    computed from the filtered masks. Results are written to
    ``<filtered_save_dir>/<sid>.pkl`` with the same ``{r: {...}}`` layout.
    """
    print(f"Filtering probability masks for {len(sample_ids)} samples")
    for sid in tqdm(sample_ids):
        src_path = os.path.join(pmask_save_dir, f"{sid}.pkl")
        with open(src_path, "rb") as handle:
            per_radius = pickle.load(handle)

        result = {}
        for radius, masks in per_radius.items():
            consensus = masks["mean"] >= (min_num_agree / len(methods))
            consensus_labels = label(consensus)
            filtered = {m: filter_mask(masks[m], consensus_labels) for m in methods}
            filtered["mean"] = np.stack(list(filtered.values())).mean(0)
            result[radius] = filtered

        dst_path = os.path.join(filtered_save_dir, f"{sid}.pkl")
        with open(dst_path, "wb") as handle:
            pickle.dump(result, handle)

    print(f"Filtered probability masks saved to {filtered_save_dir}")


In [8]:
def eval_mask(gt, m, empty_value=float("nan")):
    """Object-level precision and recall of a labelled mask against a consensus.

    Parameters
    ----------
    gt : numpy.ndarray of bool
        Consensus mask; its connected components (via ``label``) are the
        reference objects.
    m : numpy.ndarray of int
        Label image of the method's objects (0 = background), e.g. ``label(mask)``.
    empty_value : float, optional
        Returned for a score whose denominator is zero (``m`` has no objects
        for precision, ``gt`` has no objects for recall). Defaults to NaN so
        empty images can be skipped with ``np.nanmean`` instead of crashing
        with ZeroDivisionError or biasing averages.

    Returns
    -------
    (precision, recall) : tuple of float
        precision: fraction of objects in ``m`` that overlap any ``gt`` pixel.
        recall: fraction of ``gt`` components that overlap any nonzero ``m`` pixel.
    """
    pred_regions = regionprops(m)
    if pred_regions:
        hits = sum(
            bool(np.any(gt[rp.coords[:, 0], rp.coords[:, 1]])) for rp in pred_regions
        )
        precision = hits / len(pred_regions)
        assert 0 <= precision <= 1
    else:
        precision = empty_value

    gt_regions = regionprops(label(gt))
    if gt_regions:
        found = sum(
            bool(np.any(m[rp.coords[:, 0], rp.coords[:, 1]] > 0)) for rp in gt_regions
        )
        recall = found / len(gt_regions)
        assert 0 <= recall <= 1
    else:
        recall = empty_value

    return precision, recall

In [9]:
def evaluate_masks(sample_ids, filtered_pmask_save_dir, radii, min_num_agree, num_methods):
    """Score each method's filtered mask against the consensus, per sample and radius.

    The consensus for a radius is ``mean >= min_num_agree / num_methods`` taken
    from the filtered pickle. Every entry other than ``"mean"`` is labelled and
    scored with ``eval_mask``.

    Returns
    -------
    (precision, recall) : tuple of dict
        Both shaped ``{sid: {radius: {method: score}}}``. Every radius in
        ``radii`` appears (possibly empty); a radius present in a pickle but
        absent from ``radii`` raises KeyError.
    """
    precision, recall = {}, {}
    print(f"Computing precision and recall for {len(sample_ids)} samples")
    for sid in tqdm(sample_ids):
        pkl_path = os.path.join(filtered_pmask_save_dir, f"{sid}.pkl")
        with open(pkl_path, "rb") as handle:
            per_radius = pickle.load(handle)

        prec_by_r = {r: {} for r in radii}
        rec_by_r = {r: {} for r in radii}
        for r, masks in per_radius.items():
            consensus = masks["mean"] >= (min_num_agree / num_methods)
            method_masks = ((n, mk) for n, mk in masks.items() if n != "mean")
            for name, mask in method_masks:
                prec_by_r[r][name], rec_by_r[r][name] = eval_mask(consensus, label(mask))

        precision[sid] = prec_by_r
        recall[sid] = rec_by_r

    return precision, recall

In [None]:
def main(config=None):
    """Run the consensus pipeline: probability masks -> filtering -> scores.

    Parameters
    ----------
    config : dict, optional
        Overrides for the defaults below, e.g. the mapping returned by
        ``yaml.safe_load`` on a config file. Recognised keys: ``methods``,
        ``radii``, ``num_agree``, ``datapath``, ``resultsdir``,
        ``compute_pmasks``, ``filter_pmasks``, ``compute_scores``.

    Raises
    ------
    ValueError
        If ``config`` has unknown keys (so typos are not silently ignored), or
        ``num_agree`` is not between 1 and the number of methods.
    """
    settings = {
        "methods": ["mesmer", "stardist", "maskrcnn", "unet"],
        "radii": [2, 4, 6, 8, 10, 12, 14, 16],
        "num_agree": 3,
        "datapath": "/home/groups/ChangLab/dataset/HMS-TMA-TNP/DATA-03292022",
        "resultsdir": "/home/groups/ChangLab/dharani/HMS-TMA-TNP_results_msur",
        "compute_pmasks": True,
        "filter_pmasks": True,
        "compute_scores": True,
    }
    overrides = dict(config or {})
    unknown = sorted(set(overrides) - set(settings))
    if unknown:
        raise ValueError(f"Unknown config keys: {unknown}")
    settings.update(overrides)

    methods = settings["methods"]
    radii = settings["radii"]
    num_agree = settings["num_agree"]
    datapath = settings["datapath"]
    results_dir = settings["resultsdir"]
    if not 1 <= num_agree <= len(methods):
        # A threshold above 1 leaves the consensus empty; 0 makes it everything.
        raise ValueError(
            f"num_agree={num_agree} must be between 1 and the number of methods ({len(methods)})"
        )

    pmask_save_dir = os.path.join(results_dir, "pmasks")
    filtered_pmask_save_dir = os.path.join(results_dir, "filtered_pmasks")
    for directory in (pmask_save_dir, filtered_pmask_save_dir):
        os.makedirs(directory, exist_ok=True)

    # Sample ids come from the first method's folder. Hidden entries (e.g.
    # ".ipynb_checkpoints") are skipped: their empty id would glob-match any
    # file. Sorting makes the processing order reproducible.
    sample_ids = sorted({
        getid(name)
        for name in os.listdir(os.path.join(datapath, methods[0]))
        if not name.startswith(".")
    })

    if settings["compute_pmasks"]:
        write_pmasks(sample_ids, radii, methods, datapath, pmask_save_dir)

    if settings["filter_pmasks"]:
        filter_pmasks(sample_ids, pmask_save_dir, filtered_pmask_save_dir, num_agree, methods)

    precision_scores_path = os.path.join(results_dir, "precision_scores.pkl")
    recall_scores_path = os.path.join(results_dir, "recall_scores.pkl")

    if settings["compute_scores"]:
        precision, recall = evaluate_masks(sample_ids, filtered_pmask_save_dir, radii, num_agree, len(methods))
        with open(precision_scores_path, "wb") as handle:
            pickle.dump(precision, handle)
            print(f"Saved precision scores to {precision_scores_path}")
        with open(recall_scores_path, "wb") as handle:
            pickle.dump(recall, handle)
            print(f"Saved recall scores to {recall_scores_path}")


if __name__ == '__main__':
    main()

Writing probability masks for 88 samples


100%|██████████| 88/88 [2:32:45<00:00, 104.15s/it]  


Proability masks saved to /home/groups/ChangLab/dharani/HMS-TMA-TNP_results_msur/pmasks
Filtering probability masks for 88 samples


100%|██████████| 88/88 [46:29<00:00, 31.70s/it]


Filtered probability masks saved to /home/groups/ChangLab/dharani/HMS-TMA-TNP_results_msur/filtered_pmasks
Computing precision and recall for 88 samples


 39%|███▊      | 34/88 [09:09<15:00, 16.67s/it]

### Code understanding

Scratch cells for inspecting saved mask pickles and stepping through `filter_mask` by hand.

In [None]:
import pandas as pd

object1 = pd.read_pickle("/home/groups/ChangLab/dharani/HMS-TMA-TNP_results_all5/filtered_pmasks/OHSU_TMA1_004-A1.pkl")
#object2 = pd.read_pickle("/home/groups/ChangLab/dataset/HMS-TMA-TNP/DATA-03292022/results/filtered_pmasks/OHSU_TMA1_004-A1.pkl")
#object3 = pd.read_pickle("/home/groups/ChangLab/dataset/HMS-TMA-TNP/DATA-03292022/results/precision_scores.pkl")

In [None]:
# Summarise instead of printing the whole dict, which dumps every mask array.
for radius, masks in object1.items():
    print(radius, {name: (arr.shape, arr.dtype) for name, arr in masks.items()})

In [None]:
# Print the per-radius entry names (each method plus "mean").
for masks in object1.values():
    for name in masks:
        print(name)

In [None]:
# Step through filter_mask by hand for one method of an unfiltered pickle, and
# check the hand-rolled result matches filter_mask exactly.
DEBUG_PMASK_PATH = "/home/groups/ChangLab/dataset/HMS-TMA-TNP/DATA-03292022/results/pmasks/OHSU_TMA1_004-A1.pkl"
DEBUG_METHOD = "cellpose"
DEBUG_NUM_AGREE = 3

with open(DEBUG_PMASK_PATH, "rb") as handle:
    data = pickle.load(handle)

for r, masks in data.items():
    n_methods = len(masks) - 1  # every entry except "mean" is one method's mask
    avg_labs = label(masks["mean"] >= DEBUG_NUM_AGREE / n_methods)
    filtered = copy.deepcopy(masks[DEBUG_METHOD])
    for rp in regionprops(label(masks[DEBUG_METHOD])):
        coords = rp.coords
        vals = avg_labs[coords[:, 0], coords[:, 1]]
        uniq, counts = np.unique(vals, return_counts=True)
        # Background (0) must not compete for the majority; otherwise a mostly
        # background region would have its consensus pixels erased.
        if uniq[0] == 0:
            uniq, counts = uniq[1:], counts[1:]
        if len(uniq) <= 1:
            continue  # overlaps at most one consensus object: keep it whole
        top_val = uniq[np.argmax(counts)]
        to_zero = coords[vals != top_val]
        filtered[to_zero[:, 0], to_zero[:, 1]] = False
    assert np.array_equal(filtered, filter_mask(masks[DEBUG_METHOD], avg_labs)), r
    print(r, avg_labs.shape, "pixels removed:", int(masks[DEBUG_METHOD].sum() - filtered.sum()))

In [None]:
a = np.arange(1, 10).reshape(3, 3)  # 3x3 toy array holding 1..9 row by row
print(a)

In [None]:
label(np.greater_equal(a, 5 / 3))

In [None]:
np.greater_equal(a, 5 / 3)

In [None]:
vals = [0, 0, 1] + [0] * 10  # 13 values, a single nonzero at index 2

uniq, counts = np.unique(vals, return_counts=True)

In [None]:
print(" ".join(map(str, (uniq, counts))))

In [None]:
counts.argmax()