# Merge nuclear and membrane segmentation (P51)

#### This notebook merges the nuclear and membrane segmentation results for each field of view (FOV) from P51, assigning a single ID to each cell.

### Import Libraries

In [1]:
from os import listdir
from os.path import isfile, join
import cv2
import tifffile
import numpy as np
from skimage.segmentation import expand_labels

In [2]:
def get_iou(inference, gt):
    """
    Returns the pairwise intersection over union (IoU) ratio between ground truth and inference labels.
    
    Arguments:
        inference: 2D numpy array (uint16)
            The inference labels
        
        gt: 2D numpy array (uint16)
            The ground truth labels
            
    Returns: 2D numpy (float)
        The IoU between all pairs of ground truth and inference labels, with one row per
        ground truth label and one column per inference label, each in ascending label order.
        The lowest label value of each image (assumed to be the background, 0) is excluded.
    """
    true_objects = np.unique(gt)
    pred_objects = np.unique(inference)
    print("ground truth nuclei:", len(true_objects)-1)
    print("Inference nuclei:", len(pred_objects)-1)

    # The sorted unique labels serve as histogram bin edges, so each bin holds exactly one
    # label; a final edge at (largest label + 1) closes the last bin.
    # The edges are built as float64 so that "+ 1" cannot wrap around to 0 when the largest
    # label equals the maximum of the input dtype (e.g. 65535 for uint16).
    true_bins = np.append(true_objects.astype(np.float64), float(true_objects[-1]) + 1)
    pred_bins = np.append(pred_objects.astype(np.float64), float(pred_objects[-1]) + 1)

    # intersection[i, j] = number of pixels carrying ground truth label i AND inference label j
    intersection = np.histogram2d(gt.ravel(), inference.ravel(), bins=(true_bins, pred_bins))[0]

    # per-label pixel counts, shaped (n_true, 1) and (1, n_pred) so they broadcast to a matrix
    area_true = np.expand_dims(np.histogram(gt, bins=true_bins)[0], -1)
    area_pred = np.expand_dims(np.histogram(inference, bins=pred_bins)[0], 0)
    union = area_true + area_pred - intersection

    # drop the first row/column (lowest label of each image, i.e. the background)
    intersection = intersection[1:, 1:]
    union = union[1:, 1:]

    # guard against division by zero
    union[union == 0] = 1e-9
    return intersection / union

### List membrane mask files

In [7]:
# directory containing the membrane segmentation masks
mem_mask_dir = '/data/Zhaolab/1_AMLCosMx/Final_scripts/2_Segmentation/1_MembraneSegmentation/P51_v7_output_membrane/prediction_model_1_11_23/'

# names of the regular files in that directory (sub-directories excluded), in sorted order
files_mem = sorted(
    entry for entry in listdir(mem_mask_dir)
    if isfile(join(mem_mask_dir, entry))
)

### List nuclear mask files

In [8]:
# directory containing the nuclear segmentation masks
nuc_mask_dir = '/data/Zhaolab/1_AMLCosMx/Final_scripts/2_Segmentation/0_NuclearSegmentation/P51_v7_output/labels_predicted/'

# names of the regular files in that directory (sub-directories excluded), in sorted order
files_nuc = sorted(
    entry for entry in listdir(nuc_mask_dir)
    if isfile(join(nuc_mask_dir, entry))
)

### Define output location

In [9]:
# Directory that receives the merged label images; each output file is written under the
# same file name as the membrane mask it was built from.
# NOTE(review): no visible cell creates this directory - confirm it exists before running.
hybrid_dir = '/data/Zhaolab/1_AMLCosMx/Final_scripts/2_Segmentation/3_NucMemMerging/P51_hybrid/labels_predicted_2_15_24/'

### Merge mask files

In [10]:
# minimum IoU between a nucleus and a membrane mask for the nucleus to adopt the membrane id
IOU_THRESHOLD = 0.1
# distance (in pixels) by which nuclei without a matching membrane mask are expanded
NUC_EXPANSION_DISTANCE = 5

# nuclear mask names as a set, for direct lookup of each membrane file's counterpart
nuc_names = set(files_nuc)

for mem_label in files_mem:

    # A membrane mask is merged with the nuclear mask of the same file name. Fail loudly if it
    # is missing instead of reusing the nuclear mask left over from the previous iteration.
    if mem_label not in nuc_names:
        raise FileNotFoundError(
            "No nuclear mask named '{}' found in {}".format(mem_label, nuc_mask_dir))

    # read in membrane file and corresponding nuclear file (IMREAD_ANYDEPTH keeps 16-bit labels)
    mem = cv2.imread(join(mem_mask_dir, mem_label), cv2.IMREAD_ANYDEPTH)
    nuc = cv2.imread(join(nuc_mask_dir, mem_label), cv2.IMREAD_ANYDEPTH)
    # cv2.imread returns None instead of raising when a file cannot be decoded
    if mem is None:
        raise OSError("Could not read membrane mask '{}'".format(join(mem_mask_dir, mem_label)))
    if nuc is None:
        raise OSError("Could not read nuclear mask '{}'".format(join(nuc_mask_dir, mem_label)))
    mem = mem.astype('uint16')
    nuc = nuc.astype('uint16')

    # make nuclear ids unique from membrane ids by shifting them above the largest membrane id
    mem_max = mem.max()
    # the shift is done in uint16, so check first that it cannot wrap around and collide ids
    if int(nuc.max()) + int(mem_max) > np.iinfo(np.uint16).max:
        raise ValueError(
            "Cannot offset nuclear ids for '{}': {} + {} exceeds the uint16 range".format(
                mem_label, int(nuc.max()), int(mem_max)))
    nuc[nuc > 0] += mem_max

    # find iou of each membrane/nucleus combination (rows: nuclei, columns: membranes)
    iou = get_iou(mem, nuc)

    nuc_overlay = np.copy(nuc)

    # vectors of nuclear and membrane ids without the first (background) entry, so that
    # nuc_ids[i] / mem_ids[j] correspond to row i / column j of the iou matrix
    nuc_ids = np.unique(nuc)[1:]
    mem_ids = np.unique(mem)[1:]

    # for each nucleus overlapping a membrane mask with IoU above the threshold, assign the
    # nucleus the id of that membrane. Pairs are visited in row-major order; once a nucleus
    # has been relabelled its original id no longer occurs in nuc_overlay, so the first
    # matching membrane wins.
    for i, j in np.argwhere(iou > IOU_THRESHOLD):
        nuc_overlay[nuc_overlay == nuc_ids[i]] = mem_ids[j]

    # create array of nuclear masks that were not matched to a membrane mask (their ids are
    # still above mem_max), restricted to pixels outside any membrane mask, then expand them
    nuc_only = np.copy(nuc_overlay)
    nuc_only[nuc_only <= mem_max] = 0
    nuc_only[mem > 0] = 0

    nuc_only = expand_labels(nuc_only, distance=NUC_EXPANSION_DISTANCE) # expand labels

    # create array of the remaining nuclear pixels: those of matched nuclei (now carrying a
    # membrane id) that lie outside the membrane masks and outside the expanded lone nuclei
    nuc_overlap = np.copy(nuc_overlay)
    nuc_overlap[nuc_only > 0] = 0
    nuc_overlap[mem > 0] = 0

    # create array of membrane masks, subtracting areas covered by the two nuclear arrays
    mem_only = np.copy(mem)
    mem_only[nuc_overlap > 0] = 0
    mem_only[nuc_only > 0] = 0

    # the three arrays are non-zero on disjoint pixels, so their sum is the final label image
    output = np.sum([nuc_only, mem_only, nuc_overlap], axis=0)

    tifffile.imwrite(join(hybrid_dir, mem_label), output.astype('uint16'), photometric='minisblack')

ground truth nuclei: 2272
Inference nuclei: 1271
ground truth nuclei: 2778
Inference nuclei: 1510
ground truth nuclei: 1428
Inference nuclei: 814
ground truth nuclei: 1959
Inference nuclei: 1096
ground truth nuclei: 2742
Inference nuclei: 1723
ground truth nuclei: 2881
Inference nuclei: 1240
ground truth nuclei: 1267
Inference nuclei: 1115
ground truth nuclei: 3615
Inference nuclei: 2335
ground truth nuclei: 2528
Inference nuclei: 2209
ground truth nuclei: 3971
Inference nuclei: 2968
ground truth nuclei: 4313
Inference nuclei: 2436
ground truth nuclei: 2240
Inference nuclei: 2023
ground truth nuclei: 3705
Inference nuclei: 2624
ground truth nuclei: 3050
Inference nuclei: 2219
ground truth nuclei: 3110
Inference nuclei: 2493
ground truth nuclei: 4097
Inference nuclei: 3577
ground truth nuclei: 3893
Inference nuclei: 3200
ground truth nuclei: 4061
Inference nuclei: 2517
ground truth nuclei: 4171
Inference nuclei: 3302
ground truth nuclei: 3835
Inference nuclei: 3066
ground truth nuclei: 