In [1]:
from pathlib import Path
import os
import hpacellseg.cellsegmentator as cellsegmentator
from hpacellseg.utils import label_cell, label_nuclei
from skimage import measure, segmentation, morphology
import cv2
from glob import glob
import numpy as np

## User Interface

### The dataset directory used below must be given as an absolute path

In [2]:
# Recursively list every directory under the dataset root into data-folder.txt (one path per line).
!find /home/ishang/HPA-embedding/dev-dataset -type d > data-folder.txt
# Drop the first line of the listing in place -- presumably the dataset root itself,
# which `find` prints before its sub-directories (verify if the listing order matters).
!sed -i '1d' data-folder.txt
# Show the resulting folder list so it can be checked before segmentation.
!cat data-folder.txt

/home/ishang/HPA-embedding/dev-dataset/11_TileScan 6--Stage00
/home/ishang/HPA-embedding/dev-dataset/10_R1--Stage01
/home/ishang/HPA-embedding/dev-dataset/field--X01--Y00_image--L0000--S00--U08--V04--J17--E02--O02--X01--Y00--T0000--Z00--C0


In [3]:
# Root folder of the dataset (the same path is hardcoded in the `find` cell above).
DATA_DIR = Path("/home/ishang/HPA-embedding/dev-dataset") # NEEDS TO BE ABSOLUTE PATH
# File-name suffixes of the per-channel PNGs; get_masks globs for "*<name>.png".
CHANNEL_NAMES = ["cyclinb1", "microtubule", "nuclei"]
# Indices into CHANNEL_NAMES selecting the nuclear (DAPI) and microtubule (TUBL) channels.
DAPI = 2
TUBL = 1
# Index into CHANNEL_NAMES of an optional third segmentation channel. None disables it:
# no third channel is loaded and the segmentator is built with multi_channel_model=False.
ANLN = None

## Mask Creation

### DO NOT RERUN THE CELL BELOW!!! IF YOU DO, YOU WILL HAVE TO KILL THE PROCESS AND RESTART THE KERNEL

In [4]:
# The model weights are looked up in ./HPA-Cell-Segmentation, relative to the
# directory the kernel was started in.
pwd = Path.cwd()
NUC_MODEL = pwd.joinpath("HPA-Cell-Segmentation", "nuclei-model.pth")
CELL_MODEL = pwd.joinpath("HPA-Cell-Segmentation", "cell-model.pth")

# Configure PyTorch's CUDA caching allocator before the models are placed on the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

# The multi-channel cell model is only requested when a third channel (ANLN) is configured.
segmentator = cellsegmentator.CellSegmentator(
    os.fspath(NUC_MODEL),
    os.fspath(CELL_MODEL),
    device="cuda",
    padding=True,
    multi_channel_model=ANLN is not None,
)

please compile abn




### DO NOT RERUN THE CELL ABOVE ^^^^^ IF YOU DO, YOU WILL HAVE TO KILL THE PROCESS AND RESTART THE KERNEL

In [5]:
# Read the folder listing written to data-folder.txt: one directory per line.
# The `with` block closes the file handle, and blank lines are skipped so they
# do not turn into Path('.') entries.
with open("data-folder.txt", "r") as folder_list:
    image_paths = [Path(line.strip()) for line in folder_list if line.strip()]
print(image_paths)

[PosixPath('/home/ishang/HPA-embedding/dev-dataset/11_TileScan 6--Stage00'), PosixPath('/home/ishang/HPA-embedding/dev-dataset/10_R1--Stage01'), PosixPath('/home/ishang/HPA-embedding/dev-dataset/field--X01--Y00_image--L0000--S00--U08--V04--J17--E02--O02--X01--Y00--T0000--Z00--C0')]


In [None]:
def get_masks(segmentator, image_paths):
    """Run nuclei and cell segmentation for every image folder in ``image_paths``.

    For each folder the DAPI and tubulin PNGs (plus the third channel when
    ``ANLN`` is not None) are found recursively by their ``CHANNEL_NAMES``
    suffix, loaded with OpenCV, passed to ``segmentator`` and post-processed
    with ``label_cell``. Folders that already contain a ``cellmask.png`` are
    skipped.

    Args:
        segmentator: object providing ``pred_nuclei`` and ``pred_cells``.
        image_paths: iterable of ``pathlib.Path`` folders to process.

    Raises:
        AssertionError: if the channel file lists of a folder differ in length
            or do not pair up by file-name prefix, if the nuclei and cell masks
            do not contain the same label set, or if a mask is empty.

    Note:
        The masks are validated but neither written to disk nor returned.
        TODO: persist them (the skip check above expects ``cellmask.png``).
    """

    def _channel_files(folder, channel):
        # Sorted so that the per-channel lists line up index by index.
        return sorted(glob(f"{str(folder)}/**/*{CHANNEL_NAMES[channel]}.png", recursive=True))

    def _load(paths):
        return [cv2.imread(str(p), cv2.IMREAD_UNCHANGED) for p in paths]

    def _check_pairing(folder, dapi_files, other_files, other_channel):
        # zip() would silently drop unpaired trailing files, so compare lengths first.
        assert len(dapi_files) == len(other_files), (
            f"Channel count mismatch in {folder}: {len(dapi_files)} {CHANNEL_NAMES[DAPI]} files "
            f"vs {len(other_files)} {CHANNEL_NAMES[other_channel]} files"
        )
        for dapi, other in zip(dapi_files, other_files):
            assert str(dapi).split(CHANNEL_NAMES[DAPI])[0] == str(other).split(CHANNEL_NAMES[other_channel])[0], f"File mismatch for {dapi} and {other}"

    for image_path in image_paths:
        # Skip folders that were already segmented. This must test the current
        # folder: the original code used the whole list here and raised TypeError.
        if (image_path / "cellmask.png").exists():
            continue

        dapi_paths = _channel_files(image_path, DAPI)
        tubl_paths = _channel_files(image_path, TUBL)
        anln_paths = _channel_files(image_path, ANLN) if ANLN is not None else None

        _check_pairing(image_path, dapi_paths, tubl_paths, TUBL)
        if ANLN is not None and anln_paths is not None:
            _check_pairing(image_path, dapi_paths, anln_paths, ANLN)

        dapi_images = _load(dapi_paths)
        tubl_images = _load(tubl_paths)
        anln_images = _load(anln_paths) if anln_paths is not None else None

        # pred_cells receives the channels as [tubulin, third channel or None, nuclei].
        images = [tubl_images, anln_images, dapi_images]
        nuc_segmentation = segmentator.pred_nuclei(images[2])
        cell_segmentation = segmentator.pred_cells(images)

        # post-processing
        # label_cell returns both masks, so the separate label_nuclei call whose
        # result was immediately overwritten has been dropped.
        # NOTE(review): only the first prediction of each folder is labelled -- confirm
        # that every folder holds exactly one field of view.
        nuclei_mask, cell_mask = label_cell(
            nuc_segmentation[0], cell_segmentation[0]
        )

        # apply preprocessing mask if the user want to merge nuclei
        # in preprocess_masks get_single_cell_mask

        assert set(np.unique(nuclei_mask)) == set(np.unique(cell_mask)), f"Mask mismatch for {image_path}, nuclei: {np.unique(nuclei_mask)}, cell: {np.unique(cell_mask)}"
        assert np.max(nuclei_mask) > 0 and np.max(cell_mask) > 0, f"No nuclei or cell mask found for {image_path}"

: 

In [None]:
def get_single_cell_mask(
    cell_mask,
    nuclei_mask,
    # final_size=None, # resize the mask to final_size
    rm_border=True, # removes cells touching the border
    remove_size=1000,
    dialation_radius=20,
):
    """Clean up paired cell/nuclei label masks and compute per-cell boxes.

    Steps:
      1. Optionally remove nuclei touching the image border, together with the
         cells carrying the same labels.
      2. Merge nuclei that become connected after a morphological closing with
         a disk of radius ``dialation_radius``; the matching cells are merged
         under the same new label.
      3. Compute bounding boxes and centroids for merged cells whose area is
         larger than ``remove_size``.

    Args:
        cell_mask: 2-D integer label image of cells (0 is background).
        nuclei_mask: 2-D integer label image of nuclei, using the same labels
            as ``cell_mask``.
        rm_border: if True, drop nuclei touching the border and their cells.
        remove_size: cells with ``area <= remove_size`` are excluded from the
            returned boxes and centroids (they remain in the returned masks).
        dialation_radius: radius of the disk used for the closing in step 2.

    Returns:
        ``(new_cell_mask, new_nuc_mask, bbox_array, com_array, num_removed)``.
        ``bbox_array`` has one row ``(min_row, min_col, height, width)`` per
        kept cell and ``com_array`` one centroid ``(row, col)`` per kept cell.
        Both are ``None`` when no cell passes the size filter. ``num_removed``
        is the number of nuclei labels removed by border clearing.

    Raises:
        AssertionError: if nuclei and cell masks do not share the same label
            set after border clearing or after merging.
    """
    if rm_border:
        labels_before = np.unique(nuclei_mask)
        nuclei_mask = segmentation.clear_border(nuclei_mask)
        keep_value = np.unique(nuclei_mask)
        # Vectorised membership test; keeps the uint8 multiplier so the dtype
        # of cell_mask is promoted exactly as before.
        borderedcellmask = np.isin(cell_mask, keep_value).astype("uint8")
        cell_mask = cell_mask * borderedcellmask
        # Count labels that existed before clearing and are gone afterwards.
        # The previous formula compared two post-clearing label sets and was
        # therefore always 0 whenever the assertion below held.
        num_removed = len(np.setdiff1d(labels_before, keep_value))
    else:
        num_removed = 0

    assert set(np.unique(nuclei_mask)) == set(np.unique(cell_mask))

    # needs to be after clear border otherwise you get a boundary of nuclei that are still touching the edge
    # maybe that is due to the interpolation method
    # ideally this happens outside
    # if final_size is not None:
    #     nuclei_mask = cv2.resize(nuclei_mask, final_size, interpolation=cv2.INTER_NEAREST)
    #     cell_mask = cv2.resize(cell_mask, final_size, interpolation=cv2.INTER_NEAREST)

    ### see if nuclei are touching and merge them
    bin_nuc_mask = (nuclei_mask > 0).astype(np.int8)
    cls_nuc = morphology.closing(bin_nuc_mask, morphology.disk(dialation_radius))
    # connected components of the closed mask: nuclei bridged by the closing share a label
    new_label_map = morphology.label(cls_nuc)
    new_label_idx = np.unique(new_label_map)
    # drop background explicitly rather than assuming it is the first unique value
    new_label_idx = new_label_idx[new_label_idx != 0]

    new_cell_mask = np.zeros_like(cell_mask)
    new_nuc_mask = np.zeros_like(nuclei_mask)
    for new_label in new_label_idx:
        # original nuclei labels that fall inside this merged component
        old_labels = np.unique(nuclei_mask[new_label_map == new_label])
        old_labels = old_labels[old_labels != 0]

        new_nuc_mask[np.isin(nuclei_mask, old_labels)] = new_label
        new_cell_mask[np.isin(cell_mask, old_labels)] = new_label

    assert set(np.unique(new_nuc_mask)) == set(np.unique(new_cell_mask))

    region_props = measure.regionprops(new_cell_mask, (new_cell_mask > 0).astype(np.uint8))
    kept_regions = [region for region in region_props if region.area > remove_size]
    if not kept_regions:
        # Covers both "no regions at all" and "every region is too small"; the
        # latter used to raise IndexError when slicing an empty 1-D array.
        # The tuple now always has five elements.
        return new_cell_mask, new_nuc_mask, None, None, num_removed

    bbox_array = np.array([region.bbox for region in kept_regions])
    ## convert (min_row, min_col, max_row, max_col) to (min_row, min_col, height, width)
    bbox_array[:, 2] -= bbox_array[:, 0]
    bbox_array[:, 3] -= bbox_array[:, 1]

    com_array = np.array([region.weighted_centroid for region in kept_regions])
    return new_cell_mask, new_nuc_mask, bbox_array, com_array, num_removed