In [1]:
import json
import pathlib
from typing import List, NamedTuple, Tuple, Union
from matplotlib import pyplot as plt
from PIL import Image

import numpy as np

from msi_zarr_analysis.ml.dataset.translate_annotation import (
    TemplateTransform,
    colorize_data,
    load_ms_template,
    load_tif_file,
    match_template_multiscale_scipy,
    rgb_to_grayscale,
    scale_image,
    threshold_ms,
    threshold_overlay,
)
from msi_zarr_analysis.utils.check import open_group_ro
from msi_zarr_analysis.utils.cytomine_utils import (
    get_lipid_dataframe,
    get_page_bin_indices,
)

In [2]:
class DSConfig(NamedTuple):
    """Settings describing one dataset handed to ``run``.

    Fields without a default are required. The field order is part of the
    interface (a NamedTuple can be built and unpacked positionally), so new
    fields should only be appended with a default value.
    """

    image_id_overlay: int  # Cytomine ID for the overlay image
    local_overlay_path: str  # local path of the (downloaded) overlay
    lipid_tm: str  # name of the lipid to base the template matching on

    project_id: int  # project id
    annotated_image_id: int  # image with the annotations

    zarr_path: str  # path to the non-binned zarr image

    term_list: List[str]  # force order on the classes

    # NOTE(review): meaning of the str form is not visible here - presumably a
    # destination path; confirm against the consumer of this field.
    save_image: Union[bool, str] = False

    # orientation fix-ups for the MS template (rotation count and flips)
    transform_rot90: int = 0
    transform_flip_ud: bool = False
    transform_flip_lr: bool = False

    # variable-length tuples of ids; an empty tuple is the default
    annotation_users_id: Tuple[int, ...] = ()  # select these users only
    annotation_terms_id: Tuple[int, ...] = ()  # select these terms only

    # None means that no separate group was configured
    zarr_template_path: Union[str, None] = None  # use another group for the template matching

In [3]:
def run(
    config_path: str,
    bin_csv_path: str,
    *ds_config: DSConfig,
):
    """Template-match every dataset and save the intermediate images as PNG.

    For each configuration, the overlay page and the MS template selected by
    ``get_page_bin_indices`` are loaded, thresholded, converted to grayscale
    and matched with ``match_template_multiscale_scipy``. Four files are then
    written to the current working directory, named after the stem of
    ``zarr_path``: ``<stem>_overlay.png``, ``<stem>_overlay_th.png``,
    ``<stem>_template.png`` and ``<stem>_template_th.png``.

    Args:
        config_path: JSON file providing "HOST_URL", "PUB_KEY" and "PRIV_KEY".
        bin_csv_path: CSV file handed to ``get_lipid_dataframe`` and
            ``get_page_bin_indices``.
        *ds_config: one or more dataset configurations.

    Raises:
        ValueError: if no dataset configuration is given.
        AssertionError: if the rescaled template does not have the size
            reported by the template matching.
    """
    if not ds_config:
        raise ValueError("a list of dataset configuration is required")

    # imported lazily: only needed once a run is actually requested
    from cytomine import Cytomine

    with open(config_path) as config_file:
        config_data = json.load(config_file)
    host_url = config_data["HOST_URL"]
    pub_key = config_data["PUB_KEY"]
    priv_key = config_data["PRIV_KEY"]

    # NOTE(review): the returned frame is not used below. The call is kept so
    # that a problem with the CSV presumably still surfaces before connecting
    # to Cytomine - drop it if the helper has no such role.
    get_lipid_dataframe(bin_csv_path)

    with Cytomine(host_url, pub_key, priv_key):

        for ds_config_itm in ds_config:

            page_idx, bin_idx, *_ = get_page_bin_indices(
                ds_config_itm.image_id_overlay, ds_config_itm.lipid_tm, bin_csv_path
            )

            transform_template = TemplateTransform(
                ds_config_itm.transform_rot90,
                ds_config_itm.transform_flip_ud,
                ds_config_itm.transform_flip_lr,
            )

            # output files are always named after the main zarr archive
            stem = pathlib.Path(ds_config_itm.zarr_path).stem

            # honour the dedicated template group when one is configured,
            # otherwise fall back to the main zarr archive
            template_path = (
                ds_config_itm.zarr_template_path or ds_config_itm.zarr_path
            )
            ms_template_group = open_group_ro(template_path)

            overlay = load_tif_file(
                page_idx=page_idx, disk_path=ds_config_itm.local_overlay_path
            )

            _, ms_template = load_ms_template(ms_template_group, bin_idx=bin_idx)

            ms_template = transform_template.transform_template(ms_template)
            ms_template_color = colorize_data(ms_template)

            image_th = threshold_overlay(overlay)
            image_th_gs = rgb_to_grayscale(image_th)

            template_th = threshold_ms(ms_template_color)
            template_th_gs = rgb_to_grayscale(template_th)

            # largest factor by which the template still fits in the overlay
            # along every axis
            max_scale = min(
                o_s / t_s for o_s, t_s in zip(image_th_gs.shape, template_th_gs.shape)
            )

            # only the scale and matched size are needed for the check below
            _, scale, (height, width), _ = match_template_multiscale_scipy(
                image_th_gs, template_th_gs, max_scale
            )

            # sanity check: rescaling the template with the matched scale must
            # give the size reported by the matcher
            scaled_th_template = scale_image(template_th, scale)
            assert scaled_th_template.shape[:2] == (height, width)

            # save the raw and thresholded images for visual inspection
            for suffix, pixels in (
                ("_overlay.png", overlay),
                ("_overlay_th.png", image_th),
                ("_template.png", ms_template_color),
                ("_template_th.png", template_th),
            ):
                Image.fromarray(pixels, 'RGB').save(stem + suffix)

In [4]:
# One entry per region: "args" is forwarded verbatim to DSConfig and "base" is
# the path prefix from which the zarr archive name is derived.
data_sources = [
    {
        "name": "region_13",
        "args": {
            "image_id_overlay": 545025763,
            "local_overlay_path": "../datasets/Adjusted_Cytomine_MSI_3103_Region013-Viridis-stacked.ome.tif",
            "lipid_tm": "LysoPPC",
            "project_id": 542576374,
            "annotated_image_id": 545025783,
            "transform_rot90": 1,
            "transform_flip_ud": True,
            "transform_flip_lr": False,
            "annotation_users_id": (),
            "zarr_template_path": "../datasets/comulis13_binned.zarr",
        },
        "base": "../datasets/comulis13",
    },
    {
        "name": "region_14",
        "args": {
            "image_id_overlay": 548365416,
            "local_overlay_path": "../datasets/Region014-Viridis-stacked.ome.tif",
            "lipid_tm": "LysoPPC",
            "project_id": 542576374,
            "annotated_image_id": 548365416,
            "transform_rot90": 1,
            "transform_flip_ud": True,
            "transform_flip_lr": False,
            "annotation_users_id": (),
            "zarr_template_path": "../datasets/comulis14_binned.zarr",
        },
        "base": "../datasets/comulis14",
    },
    {
        "name": "region_15",
        "args": {
            "image_id_overlay": 548365463,
            "local_overlay_path": "../datasets/Region015-Viridis-stacked.ome.tif",
            "lipid_tm": "LysoPPC",
            "project_id": 542576374,
            "annotated_image_id": 548365463,
            "transform_rot90": 1,
            "transform_flip_ud": True,
            "transform_flip_lr": False,
            "annotation_users_id": (),
            "zarr_template_path": "../datasets/comulis15_binned.zarr",
        },
        "base": "../datasets/comulis15",
    },
]

# Suffix inserted between a source's "base" and "_binned.zarr"; the empty
# string selects "<base>_binned.zarr".
normalizations = [
    "",
    # others removed
]

# Extra DSConfig fields for each classification problem.
classification_problems = {
    "SC_n_SC_p": {
        "term_list": ["SC negative AREA", "SC positive AREA"],
        "annotation_terms_id": (544926052, 544924846),
    },
    # others removed
}


# One call to run() per (normalization, classification problem) pair, each
# receiving one DSConfig per data source.
for normalization in normalizations:

    for class_problem in classification_problems.values():

        ds_lst = [
            DSConfig(
                **source["args"],
                **class_problem,
                save_image=False,
                zarr_path=source["base"] + normalization + "_binned.zarr",
            )
            for source in data_sources
        ]
        run(
            "../config_cytomine.json",
            "../mz value + lipid name.csv",
            *ds_lst,
        )

[2022-08-17 15:34:30,298][INFO] [GET] [currentuser] CURRENT USER - 534530561 : mamodei | 200 OK
[2022-08-17 15:34:30,737][INFO] [GET] [sliceinstance collection] 12 objects | 200 OK
[2022-08-17 15:34:38,485][INFO] [GET] [sliceinstance collection] 14 objects | 200 OK
[2022-08-17 15:34:46,023][INFO] [GET] [sliceinstance collection] 15 objects | 200 OK
