In [2]:
import micropip
# Install kaibu_utils at runtime; fetch_image and mask_to_features are imported
# from it further down. The top-level `await` requires a kernel that supports it
# (presumably a Pyodide/JupyterLite kernel, since micropip is used — confirm).
await micropip.install(["kaibu_utils"])

# Connect to the BioEngine

In [3]:
from imjoy_rpc.hypha import connect_to_server
# Connect to the server at https://ai.imjoy.io/ under the client name "test client".
# NOTE(review): the unit of "method_timeout" is not visible here — confirm against
# the imjoy_rpc documentation before changing it.
server = await connect_to_server(
    {"name": "test client", "server_url": "https://ai.imjoy.io/", "method_timeout": 3000}
)
# Service handle used by segment() and classify() below to execute models remotely.
model_runner = await server.get_service("triton-client")

In [34]:
from imjoy import api

import io
from PIL import Image
import numpy as np
import base64
import pyodide
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
from skimage import transform, util
from skimage import filters, measure, segmentation
from skimage.morphology import (binary_erosion, closing, disk,
                                remove_small_holes, remove_small_objects)

import numpy as np
from kaibu_utils import fetch_image, mask_to_features

import numpy as np
from skimage.transform import resize

# Model identifiers (Zenodo DOI strings). Later cells of this notebook reassign
# these names right before each model is called.
nucleus_segmentation_model = "10.5281/zenodo.6200999"
cell_segmentation_model = "10.5281/zenodo.6200635"
classification_model = "10.5281/zenodo.5911832"

# Channel colours, in the order fetch_hpa_image() downloads and stacks them.
# Defined exactly once: a second identical assignment used to follow LABELS.
COLORS =  ["red", "green", "blue", "yellow"] # microtubule, protein, nuclei, ER
# Per-channel mean/std applied by run_model_torch() before inference.
NORMALIZE = {"mean": [124 / 255, 117 / 255, 104 / 255], "std": [1 / (0.0167 * 255)] * 3}
# Thresholds used by label_nuclei(): HIGH for the nuclei mask, LOW for the seeds.
HIGH_THRESHOLD = 0.4
LOW_THRESHOLD = HIGH_THRESHOLD - 0.25


# Output index of the classification model -> human readable class name,
# looked up by classify().
LABELS = {
    0: 'Nucleoplasm',
    1: 'Nuclear membrane',
    2: 'Nucleoli',
    3: 'Nucleoli fibrillar center',
    4: 'Nuclear speckles',
    5: 'Nuclear bodies',
    6: 'Endoplasmic reticulum',
    7: 'Golgi apparatus',
    8: 'Intermediate filaments',
    9: 'Actin filaments',
    10: 'Microtubules',
    11: 'Mitotic spindle',
    12: 'Centrosome',
    13: 'Plasma membrane',
    14: 'Mitochondria',
    15: 'Aggresome',
    16: 'Cytosol',
    17: 'Vesicles and punctate cytosolic patterns',
    18: 'Negative',
}


async def fetch_hpa_image(image_id):
    """Download every colour channel of one HPA image and stack them.

    One grayscale JPEG per entry of COLORS is fetched, one after the other,
    from images.proteinatlas.org.

    Returns:
        A numpy array with the channels stacked along axis 0, in COLORS order.
    """
    channels = [
        await fetch_image(
            f'https://images.proteinatlas.org/{image_id}_{color}.jpg', grayscale=True
        )
        for color in COLORS
    ]
    return np.stack(channels, axis=0)

def __fill_holes(image):
    """Fill the holes of a labelled image and relabel the filled regions.

    The boundaries between labels are zeroed first so that touching regions
    stay separate once the mask is binarised, filled and labelled again.
    """
    without_boundaries = np.multiply(
        image, np.invert(segmentation.find_boundaries(image))
    )
    filled = ndi.binary_fill_holes(without_boundaries > 0)
    relabelled, _ = ndi.label(filled)
    return relabelled


def label_nuclei(nuclei_pred):
    """Return the labeled nuclei mask data array.

    This function works best for Human Protein Atlas cell images with
    predictions from the CellSegmentator class.

    Channel index 1 of the prediction is thresholded at 0.05 to obtain the
    border pixels; channel index 2 is used as the nuclei signal.

    NOTE(review): the thresholds LOW_THRESHOLD / HIGH_THRESHOLD are fractions
    and are compared against the raw channel values here, whereas label_cell
    divides the same channel by 255 first — confirm the value range expected
    for `nuclei_pred`.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.

    Returns:
    nuclei-label -- An array with unique numbers for each found nuclei
                    in the nuclei_pred. A value of 0 in the array is
                    considered background, and the values 1-n is the
                    areas of the cells 1-n.
    """
    # --- Seeds: nuclei signal with the border pixels removed -------------
    img_copy = np.copy(nuclei_pred[..., 2])
    borders = (nuclei_pred[..., 1] > 0.05).astype(np.uint8)
    m = img_copy * (1 - borders)

    # Binarise in place with the low threshold, then erode once.
    img_copy[m <= LOW_THRESHOLD] = 0
    img_copy[m > LOW_THRESHOLD] = 1
    img_copy = img_copy.astype(np.bool_)
    img_copy = binary_erosion(img_copy)
    # TODO: Add parameter for remove small object size for
    #       differently scaled images.
    # img_copy = remove_small_objects(img_copy, 500)
    img_copy = img_copy.astype(np.uint8)
    # Each connected seed region becomes one watershed marker.
    markers = measure.label(img_copy).astype(np.uint32)

    # --- Mask: nuclei signal binarised with the high threshold ------------
    mask_img = np.copy(nuclei_pred[..., 2])
    mask_img[mask_img <= HIGH_THRESHOLD] = 0
    mask_img[mask_img > HIGH_THRESHOLD] = 1
    mask_img = mask_img.astype(np.bool_)
    mask_img = remove_small_holes(mask_img, 1000)
    # TODO: Figure out good value for remove small objects.
    # mask_img = remove_small_objects(mask_img, 8)
    mask_img = mask_img.astype(np.uint8)
    # Watershed restricted to the mask; watershed_line=True keeps a separating
    # line between neighbouring labels.
    nuclei_label = segmentation.watershed(
        mask_img, markers, mask=mask_img, watershed_line=True
    )
    # Drop small labelled objects (size argument 2500), then relabel.
    nuclei_label = remove_small_objects(nuclei_label, 2500)
    nuclei_label = measure.label(nuclei_label)
    return nuclei_label


def label_cell(nuclei_pred, cell_pred):
    """Label the cells and the nuclei.

    Both predictions are divided by 255 before thresholding, so they are
    treated as 0-255 valued arrays. Channel index 1 is used as the border
    signal and channel index 2 as the body signal of each prediction.

    Keyword arguments:
    nuclei_pred -- a 3D numpy array of a prediction from a nuclei image.
    cell_pred -- a 3D numpy array of a prediction from a cell image.

    Returns:
    A tuple containing:
    nuclei-label -- A nuclei mask data array.
    cell-label  -- A cell mask data array.

    0's in the data arrays indicate background while a continuous
    stretch of a specific number indicates the area for a specific
    cell.
    The same value in cell mask and nuclei mask refers to the identical cell.

    NOTE: The nuclei labeling from this function will be slightly
    different from the values in :func:`label_nuclei` as this version
    will use information from the cell-predictions to make better
    estimates.
    """
    def __wsh(
        mask_img,
        threshold,
        border_img,
        seeds,
        threshold_adjustment=0.35,
        small_object_size_cutoff=10,
    ):
        # Marker-based watershed helper.
        # Markers come from `seeds * border_img` thresholded at
        # `threshold + threshold_adjustment`; the mask comes from `mask_img`
        # thresholded at `threshold`. `mask_img` is modified in place, so the
        # caller must pass an array it does not need afterwards.
        img_copy = np.copy(mask_img)
        m = seeds * border_img  # * dt
        img_copy[m <= threshold + threshold_adjustment] = 0
        img_copy[m > threshold + threshold_adjustment] = 1
        img_copy = img_copy.astype(np.bool_)
        img_copy = remove_small_objects(img_copy, small_object_size_cutoff).astype(
            np.uint8
        )

        mask_img[mask_img <= threshold] = 0
        mask_img[mask_img > threshold] = 1
        mask_img = mask_img.astype(np.bool_)
        mask_img = remove_small_holes(mask_img, 1000)
        mask_img = remove_small_objects(mask_img, 8).astype(np.uint8)
        markers = ndi.label(img_copy, output=np.uint32)[0]
        labeled_array = segmentation.watershed(
            mask_img, markers, mask=mask_img, watershed_line=True
        )
        return labeled_array

    # Nuclei labelling: mask threshold 0.4, seed threshold 0.4 - 0.25 = 0.15
    # (the same numbers as HIGH_THRESHOLD / LOW_THRESHOLD). The border image is
    # True where neither prediction shows a border signal.
    # NOTE(review): if both predictions are uint8 (segment() returns the result
    # of util.img_as_ubyte), the sum of the two border channels wraps around
    # past 255 before the division — confirm this is acceptable.
    nuclei_label = __wsh(
        nuclei_pred[..., 2] / 255.0,
        0.4,
        1 - (nuclei_pred[..., 1] + cell_pred[..., 1]) / 255.0 > 0.05,
        nuclei_pred[..., 2] / 255,
        threshold_adjustment=-0.25,
        small_object_size_cutoff=500,
    )

    # for hpa_image, to remove the small pseudo nuclei
    nuclei_label = remove_small_objects(nuclei_label, 2500)
    nuclei_label = measure.label(nuclei_label)
    # this is to remove the cell borders' signal from cell mask.
    # could use np.logical_and with some revision, to replace this func.
    # Tuned for segmentation hpa images
    # The threshold is half the Otsu threshold of the cell body channel, but
    # never below 0.22.
    threshold_value = max(0.22, filters.threshold_otsu(cell_pred[..., 2] / 255) * 0.5)
    # exclude the green area first
    # NOTE(review): np.invert on an int8 array of 0/1 is a bitwise NOT and
    # yields -1/-2, both non-zero, so this product may not actually zero out
    # the border ("green") pixels — confirm the intent.
    cell_region = np.multiply(
        cell_pred[..., 2] / 255 > threshold_value,
        np.invert(np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8)),
    )
    sk = np.asarray(cell_region, dtype=np.int8)
    # Clip the body channel from below at 255 * threshold_value (the upper
    # bound is the array itself), then flood its negative from the nuclei.
    distance = np.clip(cell_pred[..., 2], 255 * threshold_value, cell_pred[..., 2])
    cell_label = segmentation.watershed(-distance, nuclei_label, mask=sk)
    # NOTE(review): the cast to uint8 cannot represent label ids above 255.
    cell_label = remove_small_objects(cell_label, 5500).astype(np.uint8)
    selem = disk(6)
    cell_label = closing(cell_label, selem)
    cell_label = __fill_holes(cell_label)
    # this part is to use green channel, and extend cell label to green channel
    # benefit is to exclude cells clear on border but without nucleus
    sk = np.asarray(
        np.add(
            np.asarray(cell_label > 0, dtype=np.int8),
            np.asarray(cell_pred[..., 1] / 255 > 0.05, dtype=np.int8),
        )
        > 0,
        dtype=np.int8,
    )
    cell_label = segmentation.watershed(-distance, cell_label, mask=sk)
    cell_label = __fill_holes(cell_label)
    # Binarise, relabel, drop small cells (size argument 5500), relabel again.
    cell_label = np.asarray(cell_label > 0, dtype=np.uint8)
    cell_label = measure.label(cell_label)
    cell_label = remove_small_objects(cell_label, 5500)
    cell_label = measure.label(cell_label)
    cell_label = np.asarray(cell_label, dtype=np.uint16)
    # Keep only nuclei that lie inside a cell, drop the small ones, and give
    # every remaining nucleus pixel the id of the cell that contains it.
    nuclei_label = np.multiply(cell_label > 0, nuclei_label) > 0
    nuclei_label = measure.label(nuclei_label)
    nuclei_label = remove_small_objects(nuclei_label, 2500)
    nuclei_label = np.multiply(cell_label, nuclei_label > 0)

    return nuclei_label, cell_label



def preprocess(image, scale_factor=0.25):
    """Rescale a channels-last image and return it channels-first.

    Args:
        image: 3-dimensional array whose last axis holds the channels.
        scale_factor: factor passed to skimage's ``transform.rescale``.

    Returns:
        The rescaled image with its axes reordered to (channel, row, column).
    """
    assert len(image.shape) == 3
    rescaled = transform.rescale(image, scale_factor, channel_axis=2)
    return rescaled.transpose([2, 0, 1])

def run_model_torch(model, imgs, scale_factor=0.25):
    """Run a torch model on a batch and return the first prediction as uint8.

    The batch is normalised per channel with NORMALIZE, passed through
    `model`, and the first output is moved to channels-last, rescaled by
    ``1 / scale_factor`` and converted with ``util.img_as_ubyte``.
    """
    import torch

    channel_mean = torch.as_tensor(NORMALIZE["mean"])[:, None, None]
    channel_std = torch.as_tensor(NORMALIZE["std"])[:, None, None]

    batch = torch.tensor(imgs).float()
    batch = batch.sub_(channel_mean).div_(channel_std)

    prediction = model(batch).to("cpu").detach().numpy()
    first_channels_last = prediction[0].transpose([1, 2, 0])
    upscaled = transform.rescale(first_channels_last, 1.0/scale_factor, channel_axis=2)
    return util.img_as_ubyte(upscaled)

async def segment(model, imgs, scale_factor=0.25):
    """Run a segmentation model remotely and return its prediction as uint8.

    Args:
        model: model id forwarded to the "bioengine-model-runner" service.
        imgs: channels-last image (3-dimensional array).
        scale_factor: factor the image is rescaled by before inference; the
            prediction is rescaled by ``1 / scale_factor`` afterwards. The same
            value is now used in both directions (the downscaling step used to
            be hard-coded to 0.25 regardless of this argument).

    Returns:
        The first model output, channels-last, converted with
        ``util.img_as_ubyte``.

    Raises:
        Exception: if the service reports that the run did not succeed.
    """
    downscaled = preprocess(imgs, scale_factor=scale_factor)
    # Add the leading batch axis before sending the image.
    image = downscaled[None, :, :, :].astype('float32')
    kwargs = {"inputs": [image], "model_id": model}
    ret = await model_runner.execute(
        inputs=[kwargs],
        model_name="bioengine-model-runner",
        serialization="imjoy",
    )
    result = ret["result"]
    if not result['success']:
        raise Exception(f"Failed to run model: {model}, error: {result['error']}")

    prediction = result['outputs'][0]
    # First batch item, channels-first -> channels-last, back to the input scale.
    upscaled = transform.rescale(
        prediction[0].transpose([1, 2, 0]), 1.0/scale_factor, channel_axis=2
    )
    return util.img_as_ubyte(upscaled)

async def classify(model, image, threshold=0.3, expected_shape=[4, 128, 128]):
    """Classify one cell crop remotely and return the labels above `threshold`.

    Args:
        model: model id forwarded to the "bioengine-model-runner" service.
        image: channels-last (HxWxC) crop of a single cell.
        threshold: only classes whose probability exceeds it are returned.
        expected_shape: channels-first shape the crop is resized to.

    Returns:
        A list of ``(label name, probability)`` tuples, in class-index order.

    Raises:
        Exception: if the service reports that the run did not succeed.
    """
    channels_first = image.transpose(2, 0, 1)  # HxWxC to CxHxW
    resized = resize(channels_first, expected_shape)
    # after resizing, the pixel range will become 0~1
    resized *= 255
    batch = resized[None, :, :, :].astype('float32')

    response = await model_runner.execute(
        inputs=[{"inputs": [batch], "model_id": model}],
        model_name="bioengine-model-runner",
        serialization="imjoy",
    )
    result = response["result"]
    if not result['success']:
        raise Exception(f"Failed to run model: {model}, error: {result['error']}")

    features, classes = result['outputs']
    probabilities = classes[0].tolist()
    return [
        (LABELS[index], prob)
        for index, prob in enumerate(probabilities)
        if prob>threshold
    ]


def cell_crop_augment(image, mask, paddings=(20, 20, 20, 20)):
    """Crop `image` and `mask` to the padded bounding box of the largest region.

    Args:
        image: array whose first two axes match `mask`.
        mask: 2D mask; its connected regions are found with ``measure.label``.
        paddings: ``(top, bottom, left, right)`` margins, in pixels, added
            around the bounding box and clamped to the mask extent.

    Returns:
        ``(image, mask)`` cropped to the same window.

    Raises:
        ValueError: if `mask` contains no foreground region. This case used to
            fail with an UnboundLocalError because no bounding box was set.
    """
    top, bottom, left, right = paddings
    regions = measure.regionprops(measure.label(mask))
    if not regions:
        raise ValueError("mask contains no foreground region to crop around")

    # max() returns the first region among equally large ones, which matches a
    # scan that only replaces the current best on a strictly larger area.
    largest = max(regions, key=lambda region: region.area)
    min_row, min_col, max_row, max_col = largest.bbox

    min_row, min_col = max(min_row - top, 0), max(min_col - left, 0)
    max_row, max_col = min(max_row + bottom, mask.shape[0]), min(max_col + right, mask.shape[1])

    image = image[min_row:max_row, min_col:max_col]
    mask = mask[min_row:max_row, min_col:max_col]
    return image, mask

def generate_cell_indices(cell_mask):
    """Return the sorted list of distinct non-zero values found in `cell_mask`."""
    # np.unique already returns its values sorted in ascending order.
    distinct_values = np.unique(cell_mask).tolist()
    return [value for value in distinct_values if value != 0]

def crop_image(image_raw, mask_raw):
    """Cut one crop per labelled cell out of `image_raw`.

    For every non-zero id in `mask_raw`, a copy of the image is blanked
    outside that cell and cropped around it with cell_crop_augment().

    Returns:
        A zip iterator of ``(cell id, cropped image)`` pairs. Being a zip
        object, it can be consumed only once.
    """
    cell_ids = generate_cell_indices(mask_raw)

    def _isolate(cell_id):
        # Work on a copy so the source image is left untouched.
        isolated = image_raw.copy()
        isolated[mask_raw != cell_id] = 0
        cropped, _ = cell_crop_augment(isolated, (mask_raw == cell_id).astype('uint8'))
        return cropped

    crops = [_isolate(cell_id) for cell_id in cell_ids]
    return zip(cell_ids, crops)

def read_image(bytes, name=None, grayscale=False, size=None):
    """Decode encoded image data into a numpy array.

    Args:
        bytes: the encoded image content. The parameter name shadows the
            builtin and is kept only so existing keyword callers still work.
        name: optional file name attached to the in-memory buffer.
        grayscale: convert the image to single-channel mode 'L' when true.
            The conversion used to happen unconditionally, which made this
            flag a no-op.
        size: optional size passed to ``Image.resize``.

    Returns:
        The decoded (and optionally converted / resized) image as an array.
    """
    buffer = io.BytesIO(bytes)
    # The fallback name used to be derived from an undefined `url` variable and
    # raised NameError whenever `name` was omitted; only set it when provided.
    if name:
        buffer.name = name
    image = Image.open(buffer)
    if grayscale:
        image = image.convert('L')
    if size:
        image = image.resize(size=size)
    return np.array(image)


def encode_image(image):
    """Encode an image array as a base64 PNG data URL string."""
    png_buffer = BytesIO()
    Image.fromarray(image).save(png_buffer, format="PNG")
    payload = base64.b64encode(png_buffer.getvalue()).decode('ascii')
    return 'data:image/png;base64,' + payload


def watershed(image, seed=None):
    """Return the cell mask for a cell prediction seeded by a nuclei prediction.

    Args:
        image: the cell prediction, passed to label_cell() as `cell_pred`.
        seed: the nuclei prediction, passed to label_cell() as `nuclei_pred`.
            Required despite its default value.

    Returns:
        The cell label array produced by label_cell().

    Raises:
        AssertionError: if `seed` is None.
    """
    assert seed is not None, "seed (the nuclei prediction) is required"
    # label_cell() derives its own nuclei labels, so the separate
    # label_nuclei(seed) call whose result was thrown away has been dropped.
    _, cell_mask = label_cell(seed, image)
    return cell_mask

## Load image

In [5]:
# Identifier inserted into the images.proteinatlas.org URLs by fetch_hpa_image().
image_id = "115/672_E2_1"
# Channels are stacked on axis 0 in COLORS order: red, green, blue, yellow.
image_raw = await fetch_hpa_image(image_id)
# Nuclei model input: the blue channel (index 2) repeated three times, channels-last.
nuc_image = image_raw[2, :, :]
nuc_image = np.stack((nuc_image, nuc_image, nuc_image), axis=-1)
# Cell model input: red (0), yellow (3) and blue (2) channels, channels-last.
cell_image = np.stack((image_raw[0, :, :], image_raw[3, :, :], image_raw[2, :, :]), axis=-1)

## Segment Nuclei

In [10]:
# Replaces the DOI-style id assigned in the configuration cell
# (presumably with the model's nickname — confirm).
nucleus_segmentation_model = "conscientious-seashell"
# Remote inference; segment() returns the prediction converted with img_as_ubyte.
nuclei_prediction = await segment(nucleus_segmentation_model, nuc_image)

## Segment Cells

In [16]:
# Replaces the DOI-style id assigned in the configuration cell
# (presumably with the model's nickname — confirm).
cell_segmentation_model = "loyal-parrot"
cell_prediction = await segment(cell_segmentation_model, cell_image)
# Combine both predictions into one labelled cell mask (0 = background).
cell_mask = watershed(cell_prediction, seed=nuclei_prediction)

## Classify Protein Pattern in each Cell

In [None]:
# Replaces the DOI-style id assigned in the configuration cell
# (presumably with the model's nickname — confirm).
classification_model = "straightforward-crocodile"
# Maps each cell id of cell_mask to the list of (label, probability) predictions.
protein_predictions = {}
# crop_image() expects a channels-last image, hence the transpose of image_raw.
cell_images = crop_image(image_raw.transpose(1,2,0), cell_mask)
# One remote classification call per cell, awaited one after the other.
for mask_id, image_crop in cell_images:
    preds = await classify(classification_model, image_crop)
    protein_predictions[mask_id] = preds
    print(f"Prediction for cell {mask_id}: {preds}")

## Visualize Results in Kaibu

In [45]:
class ImJoyPlugin():
    """ImJoy plugin that displays the image and the per-cell predictions in Kaibu."""

    async def setup(self):
        """Open a Kaibu window, show the image and overlay the labelled cells.

        Reads the notebook globals `image_id`, `cell_mask` and
        `protein_predictions`, so the cells that define them must run first.
        """
        viewer = await api.createWindow(
            src="https://kaibu.org"
        )
        # Awaited like the other remote calls so the image is in place (and any
        # failure is raised here) before the shapes are added.
        await viewer.view_image(f"https://images.proteinatlas.org/{image_id}_blue_red_green.jpg")
        shapes = mask_to_features(np.flipud(cell_mask))
        names = []
        # Shape i is labelled with the predictions of cell id i + 1.
        # NOTE(review): this assumes mask_to_features returns one shape per cell
        # in ascending id order — confirm against kaibu_utils.
        for mask_id in range(1, len(shapes) + 1):
            # A shape without a prediction gets an empty label instead of
            # aborting the whole visualisation with a KeyError.
            preds = protein_predictions.get(mask_id, [])
            names.append(",".join(f'{p[0]}({p[1]:.2})' for p in preds))
        await viewer.add_shapes(shapes, label=names, name="Prediction", edge_color='#4FED10', text_placement='point')

api.export(ImJoyPlugin())

<IPython.core.display.Javascript object>

<Future finished result=[]>

## Constantin's napari code

In [None]:
import bioimageio.core
import napari

def segment_cells(image):
    """Segment the cells of `image` with the nucleus and membrane models.

    Nuclei are predicted and labelled first, then used as seeds for a
    watershed on the membrane prediction.

    NOTE(review): this is a sketch. The `watershed` defined earlier in this
    notebook takes the raw nuclei prediction as `seed`, not a label image —
    confirm which watershed is meant here before running it.

    Returns:
        The result of the watershed call.
    """
    nucleus_model = bioimageio.core.load_resource_description(
        "conscientious-seashell"
    )
    # Use the same `bioimageio.core` namespace as load_resource_description.
    with bioimageio.core.create_prediction_pipeline(nucleus_model) as pp:
        nucleus_predictions = pp(image)
    # `label` was an undefined name; skimage's measure.label is already
    # imported by this notebook.
    nuclei = measure.label(nucleus_predictions > 0.5)

    membrane_model = bioimageio.core.load_resource_description(
        "loyal-parrot"
    )
    # The pipeline used to be created from `membranes_model`, a misspelling of
    # the variable defined just above (NameError).
    with bioimageio.core.create_prediction_pipeline(membrane_model) as pp:
        membranes = pp(image)

    # Keyword renamed from `seeds` to match this notebook's watershed(image, seed=None).
    cells = watershed(membranes, seed=nuclei)
    return cells

def classify_cells(image, cells):
    """Classify every cell of `image` with the classification model.

    Args:
        image: the image the cells were segmented from.
        cells: iterable of cell objects.
            NOTE(review): each one is expected to expose `bounding_box`,
            usable as an index into `image` — nothing in this notebook
            produces such objects, confirm.

    Returns:
        One prediction per cell, in iteration order.
    """
    classification_model = bioimageio.core.load_resource_description(
        "straightforward-crocodile"
    )

    # Use the same `bioimageio.core` namespace as load_resource_description.
    with bioimageio.core.create_prediction_pipeline(classification_model) as pp:
        classes = [pp(image[cell.bounding_box]) for cell in cells]
    return classes

def main(image):
    """Segment the cells of `image`, classify each one and visualise the result.

    NOTE(review): `visualize_in_napari` is not defined in the visible part of
    this notebook — confirm where it comes from.
    """
    segmented_cells = segment_cells(image)
    cell_classes = classify_cells(image, segmented_cells)
    visualize_in_napari(image, segmented_cells, cell_classes)