In [None]:
# default_exp clone_counters

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# export
from functools import partial
from glob import glob
from typing import Callable

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import xarray as xr
from skimage import measure

from py_clone_detective.utils import (
    add_scale_regionprops_table_area_measurements,
    calculate_corresponding_labels,
    calculate_overlap,
    check_channels_input_suitable_and_return_channels,
    determine_labels_across_other_images_using_centroids,
    extend_region_properties_list,
    get_all_labeled_clones_unmerged_and_merged,
    img_path_to_xarr,
    last2dims,
    lazy_props,
    reorder_df_to_put_ch_info_first,
    update_1st_coord_and_dim_of_xarr,
)

# CloneCounter Classes

## Parent Class

In [None]:
# export
class CloneCounter:
    """Load image/segmentation channels and measure labelled regions ("clones").

    This base class holds experiment metadata and the analysis logic. The
    subclasses defined in this module store the xarray objects returned by
    ``add_images`` / ``add_segmentations`` in ``self.image_data`` (an
    ``xr.Dataset`` with ``"images"`` and ``"segmentations"`` variables), and
    most methods below read from that dataset.

    Attributes set by the analysis methods:

    - ``seg_img_channel_pairs`` (``determine_seg_img_channel_pairs``)
    - ``results_measurements`` (``make_measurements``)
    - ``tot_seg_ch_max_labels`` (``_determine_max_seg_label_levels``)
    - ``results_overlaps`` (``measure_overlap``)
    - ``results_clones_and_neighbour_counts``
      (``measure_clones_and_neighbouring_labels``)
    """

    def __init__(
        self,
        exp_name: str,
        img_name_regex: str,
        pixel_size: float,
        tot_seg_ch: str = "C0",
    ):
        """Store experiment metadata.

        Parameters
        ----------
        exp_name : str
            Name of the experiment (stored only; not read by the methods here).
        img_name_regex : str
            Regex forwarded to ``img_path_to_xarr``; presumably used to extract
            each image's name from its file path.
        pixel_size : float
            Forwarded to ``img_path_to_xarr`` and to
            ``add_scale_regionprops_table_area_measurements``.
        tot_seg_ch : str, default "C0"
            Segmentation channel whose labels form the reference ("total")
            label set used for label counting and centroid lookup.
        """
        self.exp_name = exp_name
        self.img_name_regex = img_name_regex
        self.pixel_size = pixel_size
        self.tot_seg_ch = tot_seg_ch

    def add_images(self, **channel_path_globs):
        """Load intensity images, one keyword argument per channel.

        Parameters
        ----------
        **channel_path_globs
            ``channel_name=glob_pattern`` pairs, e.g. ``C0="C0_imgs/*.tif*"``.

        Returns
        -------
        The object built by ``img_path_to_xarr`` with its first dimension
        named ``"img_channels"``.
        """
        return img_path_to_xarr(
            self.img_name_regex,
            self.pixel_size,
            ch_name_for_first_dim="img_channels",
            **channel_path_globs,
        )

    def add_segmentations(
        self,
        additional_func_to_map: Callable = None,
        ad_func_kwargs: dict = None,
        **channel_path_globs,
    ):
        """Load segmentation images and relabel them with ``measure.label``.

        Parameters
        ----------
        additional_func_to_map : callable, optional
            If given, mapped over every block of the segmentation data (via
            ``map_blocks``) before labelling, e.g.
            ``morphology.remove_small_objects``.
        ad_func_kwargs : dict, optional
            Keyword arguments for ``additional_func_to_map``. ``None`` is
            treated as "no extra arguments".
        **channel_path_globs
            ``channel_name=glob_pattern`` pairs, one per segmentation channel.

        Returns
        -------
        The object built by ``img_path_to_xarr`` (first dimension
        ``"seg_channels"``) whose data has been relabelled as ``np.uint16``.
        """
        segmentations = img_path_to_xarr(
            self.img_name_regex,
            self.pixel_size,
            ch_name_for_first_dim="seg_channels",
            **channel_path_globs,
        )

        if additional_func_to_map is not None:
            # Unpacking ``**None`` raises TypeError, so a missing kwargs dict
            # is treated as empty.
            func_kwargs = ad_func_kwargs if ad_func_kwargs is not None else {}
            segmentations.data = segmentations.data.map_blocks(
                additional_func_to_map, **func_kwargs, dtype=np.uint16
            )

        # Label the last two (spatial) dims of every block.
        segmentations.data = segmentations.data.map_blocks(
            last2dims(partial(measure.label)), dtype=np.uint16
        )
        return segmentations

    def combine_C0_overlaps_and_measurements(self):
        """Join per-label overlap values with the C0 region measurements.

        Requires ``measure_overlap`` and ``make_measurements`` to have run.

        Returns
        -------
        pandas.DataFrame
            Indexed by ``(img_name, C0_labels)`` with label 0 excluded. It has
            one column per colocalisation channel holding ``is_in_label``,
            followed by the measurement columns of the ``seg_ch == 'C0'`` rows.
        """
        ov_df = (
            self.results_overlaps.pivot(
                index=["img_name", "C0_labels"],
                columns=["colocalisation_ch"],
                values="is_in_label",
            )
            .query("C0_labels != 0")
            .copy()
        )
        # Re-key the measurements so both frames share the same index names.
        sk_df = (
            self.results_measurements.query("seg_ch == 'C0'")
            .set_index(["seg_img", "label"])
            .rename_axis(["img_name", "C0_labels"])
        )
        return pd.merge(ov_df, sk_df, left_index=True, right_index=True)

    def determine_seg_img_channel_pairs(
        self, seg_channels: list = None, img_channels: list = None
    ):
        """Pair segmentation channels with image channels for measurement.

        Image channel ``i`` is paired with segmentation channel ``i``. When
        fewer segmentation channels than image channels are given, the last
        segmentation channel is forward-filled onto the remaining image
        channels. NOTE(review): the frame's index comes from the image
        channels, so surplus segmentation channels are silently dropped.

        The result is stored in ``self.seg_img_channel_pairs`` with columns
        ``["segmentation_channel", "image_channel"]``.
        """
        seg_channels = check_channels_input_suitable_and_return_channels(
            channels=seg_channels,
            available_channels=self.image_data.seg_channels.values.tolist(),
        )

        img_channels = check_channels_input_suitable_and_return_channels(
            channels=img_channels,
            available_channels=self.image_data.img_channels.values.tolist(),
        )

        seg_img_channel_pairs = pd.DataFrame()
        seg_img_channel_pairs["image_channel"] = pd.Series(img_channels)
        seg_img_channel_pairs["segmentation_channel"] = pd.Series(seg_channels)
        # ``fillna(method="ffill")`` is deprecated; ``ffill()`` is equivalent.
        self.seg_img_channel_pairs = seg_img_channel_pairs.ffill()[
            ["segmentation_channel", "image_channel"]
        ]

    def make_measurements(
        self,
        seg_channels: list = None,
        img_channels: list = None,
        extra_properties: list = None,
        **kwargs,
    ):
        """Measure region properties for every segmentation/image channel pair.

        For each pair from ``determine_seg_img_channel_pairs``, and each image
        along the next dimension, one ``lazy_props`` result is collected. All
        results are computed together through ``dd.from_delayed``. Scaled area
        columns are then added using ``self.pixel_size``, and the frame is
        stored in ``self.results_measurements``.
        ``self.tot_seg_ch_max_labels`` is refreshed as well.

        Parameters
        ----------
        seg_channels, img_channels : list, optional
            Channels to measure. Validated by
            ``check_channels_input_suitable_and_return_channels``; ``None`` is
            forwarded to it unchanged.
        extra_properties : list, optional
            Extra property names (e.g. ``["convex_area"]``), merged into the
            defaults by ``extend_region_properties_list``.
        **kwargs
            Forwarded to ``lazy_props``.
        """
        self.determine_seg_img_channel_pairs(seg_channels, img_channels)

        properties = extend_region_properties_list(extra_properties)

        results = list()
        for _, seg_ch, img_ch in self.seg_img_channel_pairs.itertuples():
            for seg, img in zip(
                self.image_data["segmentations"].loc[seg_ch],
                self.image_data["images"].loc[img_ch],
            ):
                results.append(
                    lazy_props(
                        seg.data,
                        img.data,
                        seg.seg_channels.item(),
                        img.img_channels.item(),
                        seg.img_name.item(),
                        img.img_name.item(),
                        properties,
                        **kwargs,
                    )
                )

        df = dd.from_delayed(results).compute()
        df = add_scale_regionprops_table_area_measurements(df, self.pixel_size)
        self.results_measurements = reorder_df_to_put_ch_info_first(df)
        self._determine_max_seg_label_levels()

    def _determine_max_seg_label_levels(self):
        """Store the largest per-block count of distinct label values.

        ``np.unique`` counts every distinct value, including 0 if present, in
        each block of the ``tot_seg_ch`` segmentation. The maximum over blocks
        is stored in ``self.tot_seg_ch_max_labels``.
        NOTE(review): this assumes one image per dask chunk along the first
        axis.
        """
        self.tot_seg_ch_max_labels = (
            self.image_data["segmentations"]
            .loc[self.tot_seg_ch]
            .data.map_blocks(
                lambda x: np.unique(x).shape[0], drop_axis=(1, 2), dtype=np.uint16,
            )
            .compute()
            .max()
        )

    def _create_df_from_arr(self, arr):
        """Convert the overlap array from ``measure_overlap`` to a long DataFrame.

        Axis 1 of ``arr`` is moved to the front and labelled with every
        segmentation channel except the first. Axis 0 is labelled with
        ``img_name`` and the last axis with label ids
        ``0 .. tot_seg_ch_max_labels - 1``. NaN entries are dropped.
        """
        return (
            xr.DataArray(
                np.moveaxis(arr, 1, 0),
                coords=(
                    self.image_data["segmentations"].coords["seg_channels"][1:],
                    self.image_data["segmentations"].coords["img_name"],
                    np.arange(self.tot_seg_ch_max_labels),
                ),
                dims=("colocalisation_ch", "img_name", "C0_labels"),
            )
            .to_dataframe("is_in_label")
            .reset_index()
            .dropna()
        )

    def measure_overlap(self):
        """Compute per-label overlap values across the segmentation channels.

        ``calculate_overlap`` is mapped over the segmentation data and the
        result is stored in ``self.results_overlaps``. That DataFrame has the
        columns ``img_name``, ``C0_labels``, ``colocalisation_ch`` and
        ``is_in_label`` (``uint16``).
        """
        self._determine_max_seg_label_levels()
        arr = (
            self.image_data["segmentations"]
            .data.map_blocks(
                calculate_overlap,
                drop_axis=[0],
                dtype=np.float64,
                num_of_segs=self.image_data["segmentations"].shape[0],
                preallocate_value=self.tot_seg_ch_max_labels,
            )
            .compute()
        )

        df = self._create_df_from_arr(arr)
        df["is_in_label"] = df["is_in_label"].astype(np.uint16)
        self.results_overlaps = df[
            ["img_name", "C0_labels", "colocalisation_ch", "is_in_label"]
        ]

    def clones_to_keep_as_dict(self, query_for_pd: str):
        """Return ``{int_img: [label, ...]}`` for measurement rows matching a query.

        Parameters
        ----------
        query_for_pd : str
            A ``DataFrame.query`` expression evaluated on
            ``self.results_measurements``.
        """
        return (
            self.results_measurements.query(query_for_pd)
            .groupby("int_img")["label"]
            .agg(list)
            .to_dict()
        )

    def get_centroids_list(self):
        """Return one integer ``(n_labels, 2)`` centroid array per image.

        Only rows whose ``int_img_ch`` equals ``self.tot_seg_ch`` are used.
        Columns are ``centroid-0``/``centroid-1``, truncated by
        ``astype(int)``. Images appear in order of first occurrence in
        ``self.results_measurements``.
        """
        measurements = self.results_measurements
        tot_rows = measurements[measurements["int_img_ch"] == self.tot_seg_ch]
        # One pass with groupby (sort=False keeps first-appearance order)
        # instead of re-querying the frame once per image.
        return [
            group.loc[:, ["centroid-0", "centroid-1"]].values.astype(int)
            for _, group in tot_rows.groupby("int_img", sort=False)
        ]

    def add_clones_and_neighbouring_labels(
        self,
        query_for_pd: str = 'int_img_ch == "C1" & mean_intensity > 1000',
        name_for_query: str = "filt_C1_intensity",
        calc_clones: bool = True,
    ):
        """Build label images for the ``tot_seg_ch`` labels selected by a query.

        Parameters
        ----------
        query_for_pd : str
            ``DataFrame.query`` expression selecting which labels to keep
            (see ``clones_to_keep_as_dict``).
        name_for_query : str
            Prefix for the new first dimension
            (``f"{name_for_query}_neighbours"``).
        calc_clones : bool
            If true, an extra ``"clone"`` layer is requested.

        Returns
        -------
        xr.DataArray
            Output of ``get_all_labeled_clones_unmerged_and_merged``. The query
            is recorded in ``attrs``.
        """
        new_coord = [
            "extended_tot_seg_labels",
            "total_neighbour_counts",
            "inside_clone_neighbour_counts",
            "outside_clone_neighbour_counts",
        ]
        if calc_clones:
            new_coord.append("clone")

        clone_coords, clone_dims = update_1st_coord_and_dim_of_xarr(
            self.image_data["images"],
            new_coord=new_coord,
            new_dim=f"{name_for_query}_neighbours",
        )

        clones_to_keep = self.clones_to_keep_as_dict(query_for_pd)

        new_label_imgs = get_all_labeled_clones_unmerged_and_merged(
            self.image_data["segmentations"].loc[self.tot_seg_ch],
            clones_to_keep,
            calc_clones,
        )

        return xr.DataArray(
            data=new_label_imgs,
            coords=clone_coords,
            dims=clone_dims,
            attrs={f"{self.tot_seg_ch}_labels_kept_query": query_for_pd},
        )

    def colabels_to_df(self, colabels, name_for_query):
        """Convert a ``(layer, img_name, label)`` colabel array to a wide DataFrame.

        Returns a ``uint16`` DataFrame indexed by ``(img_name, labels)`` with
        one column per layer of ``self.image_data[name_for_query]``'s first
        dimension.
        """
        label_imgs = self.image_data[name_for_query]
        # Look the first dimension up by position. add_clones_and_neighbouring_labels
        # passes new_dim=f"{name_for_query}_neighbours", so a hard-coded
        # coordinate name may not exist on this array.
        layer_dim = label_imgs.dims[0]
        return (
            xr.DataArray(
                colabels,
                coords=(
                    label_imgs.coords[layer_dim],
                    label_imgs.coords["img_name"],
                    range(colabels.shape[2]),
                ),
                dims=("extended_labels_neighbour_counts", "img_name", "labels"),
            )
            .to_dataframe("colabel")
            .reset_index()
            .dropna()
            .pivot(
                index=["img_name", "labels"],
                columns=["extended_labels_neighbour_counts"],
                values="colabel",
            )
            .astype(np.uint16)
        )

    def measure_clones_and_neighbouring_labels(self, name_for_query):
        """Look up, at each ``tot_seg_ch`` centroid, the values of the query layers.

        Requires ``make_measurements`` (for centroids and
        ``tot_seg_ch_max_labels``) and ``add_clones_and_neighbouring_labels``
        to have run. The result is stored in
        ``self.results_clones_and_neighbour_counts``.
        """
        label_imgs = self.image_data[name_for_query]
        colabels = calculate_corresponding_labels(
            label_imgs.data,
            self.get_centroids_list(),
            label_imgs.shape[0],
            self.tot_seg_ch_max_labels,
        )

        self.results_clones_and_neighbour_counts = self.colabels_to_df(
            colabels, name_for_query
        )

## CloneCounter subclasses

In [None]:
# export
class LazyCloneCounter(CloneCounter):
    """CloneCounter that stores loaded data without calling ``.persist()``.

    The objects returned by the parent's loading methods are kept as-is in
    ``self.image_data`` (an ``xr.Dataset``). Any lazy computation in them is
    therefore re-run whenever it is needed.
    """

    def __init__(
        self,
        exp_name: str,
        img_name_regex: str,
        pixel_size: float,
        tot_seg_ch: str = "C0",
    ):
        """Forward all metadata, including ``tot_seg_ch``, to ``CloneCounter``."""
        super().__init__(exp_name, img_name_regex, pixel_size, tot_seg_ch)

    def add_images(self, **channel_path_globs):
        """Load images and start a new ``self.image_data`` dataset.

        This replaces any existing ``self.image_data``, so call it before
        ``add_segmentations``.
        """
        self.image_data = xr.Dataset(
            {"images": super().add_images(**channel_path_globs)}
        )

    def add_segmentations(
        self,
        additional_func_to_map: Callable = None,
        ad_func_kwargs: dict = None,
        **channel_path_globs,
    ):
        """Load segmentations into ``self.image_data["segmentations"]``."""
        self.image_data["segmentations"] = super().add_segmentations(
            additional_func_to_map, ad_func_kwargs, **channel_path_globs
        )

    def add_clones_and_neighbouring_labels(
        self,
        query_for_pd: str = 'int_img_ch == "C1" & mean_intensity > 1000',
        name_for_query: str = "filt_C1_intensity",
        calc_clones: bool = True,
    ):
        """Store the query-filtered label images in ``self.image_data[name_for_query]``."""
        self.image_data[name_for_query] = super().add_clones_and_neighbouring_labels(
            query_for_pd, name_for_query, calc_clones
        )

In [None]:
# export
class PersistentCloneCounter(CloneCounter):
    """CloneCounter that calls ``.persist()`` on everything it stores.

    Each object added to ``self.image_data`` is persisted as soon as it is
    created. Later methods therefore work from the materialised data instead
    of recomputing the whole task graph each time.
    """

    def __init__(
        self,
        exp_name: str,
        img_name_regex: str,
        pixel_size: float,
        tot_seg_ch: str = "C0",
    ):
        """Forward all metadata, including ``tot_seg_ch``, to ``CloneCounter``."""
        super().__init__(exp_name, img_name_regex, pixel_size, tot_seg_ch)

    def add_images(self, **channel_path_globs):
        """Load images, persist them, and start a new ``self.image_data`` dataset.

        This replaces any existing ``self.image_data``, so call it before
        ``add_segmentations``.
        """
        self.image_data = xr.Dataset(
            {"images": super().add_images(**channel_path_globs)}
        ).persist()

    def add_segmentations(
        self,
        additional_func_to_map: Callable = None,
        ad_func_kwargs: dict = None,
        **channel_path_globs,
    ):
        """Load and persist segmentations into ``self.image_data["segmentations"]``."""
        self.image_data["segmentations"] = (
            super()
            .add_segmentations(
                additional_func_to_map, ad_func_kwargs, **channel_path_globs
            )
            .persist()
        )

    def add_clones_and_neighbouring_labels(
        self,
        query_for_pd: str = 'int_img_ch == "C1" & mean_intensity > 1000',
        name_for_query: str = "filt_C1_intensity",
        calc_clones: bool = True,
    ):
        """Store the persisted query-filtered label images in ``self.image_data[name_for_query]``."""
        self.image_data[name_for_query] = (
            super()
            .add_clones_and_neighbouring_labels(
                query_for_pd, name_for_query, calc_clones
            )
            .persist()
        )

In [None]:
# hide
from dask.distributed import Client

# Start a dask distributed client with default arguments. In the recorded run
# this created a LocalCluster. The bare `c` on the last line displays the
# cluster summary shown in the cell output.
c = Client()
c

0,1
Connection method: Cluster object,Cluster type: LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Status: running,Using processes: True
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads:  8,Total memory:  8.00 GiB

0,1
Comm: tcp://127.0.0.1:58655,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads:  8
Started:  Just now,Total memory:  8.00 GiB

0,1
Comm: tcp://127.0.0.1:58673,Total threads: 2
Dashboard: http://127.0.0.1:58675/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:58660,
Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-0bf2qze1,Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-0bf2qze1

0,1
Comm: tcp://127.0.0.1:58668,Total threads: 2
Dashboard: http://127.0.0.1:58669/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:58658,
Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-iinl4z12,Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-iinl4z12

0,1
Comm: tcp://127.0.0.1:58665,Total threads: 2
Dashboard: http://127.0.0.1:58666/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:58657,
Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-s60auxkg,Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-s60auxkg

0,1
Comm: tcp://127.0.0.1:58671,Total threads: 2
Dashboard: http://127.0.0.1:58672/status,Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:58659,
Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-s4tfgf97,Local directory: /Users/ottomorris/Documents/py_clone_detective/dask-worker-space/worker-s4tfgf97


## Example using LazyCloneCounter with measure_overlap

In [None]:
# NOTE(review): `foo` is first assigned in a later cell, so this cell raises
# NameError on a fresh kernel (Restart & Run All). Move it below the cell that
# constructs `foo`. make_measurements() and measure_overlap() already call
# _determine_max_seg_label_levels() themselves.
foo._determine_max_seg_label_levels()

In [None]:
# Inspect the value stored by _determine_max_seg_label_levels().
# NOTE(review): same hidden-state dependency on `foo` as the cell above.
foo.tot_seg_ch_max_labels

In [None]:
foo = LazyCloneCounter("Marcm2a_E7F1", r"a\dg\d\dp\d", 0.275)

# Shared root of every image / segmentation glob used below.
data_root = "../current_imaging_analysis/MARCM2A_E7F1_refactoring"

# Raw images follow the same <ch>/<ch>_imgs layout for every channel.
foo.add_images(
    **{
        channel: f"{data_root}/{channel}/{channel}_imgs/*.tif*"
        for channel in ("C0", "C1", "C2", "C3")
    }
)

# Segmentation folders differ per channel, so they are listed explicitly.
foo.add_segmentations(
    C0=f"{data_root}/C0/C0_label_imgs_combined_C3/*.tif*",
    C1=f"{data_root}/C1/C1_binaries/*.tif*",
    C2=f"{data_root}/C2/C2_label_imgs_v2/*.tif*",
    C3=f"{data_root}/C3/C3_label_imgs/*.tif*",
)
foo.make_measurements(extra_properties=["convex_area"])
foo.measure_overlap()
foo.combine_C0_overlaps_and_measurements()

Unnamed: 0_level_0,Unnamed: 1_level_0,C1,C2,C3,seg_ch,int_img_ch,int_img,area,mean_intensity,centroid-0,centroid-1,convex_area,area_um2,convex_area_um2
img_name,C0_labels,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
a1g01p1,1,0,0,0,C0,C0,a1g01p1,353,1143.484419,5.025496,97.269122,366,26.695625,27.678750
a1g01p1,2,0,0,0,C0,C0,a1g01p1,613,2521.355628,8.014682,222.654160,640,46.358125,48.400000
a1g01p1,3,0,0,0,C0,C0,a1g01p1,275,2570.625455,5.941818,383.338182,288,20.796875,21.780000
a1g01p1,4,0,3,0,C0,C0,a1g01p1,700,2873.265714,14.550000,519.455714,724,52.937500,54.752500
a1g01p1,5,0,0,0,C0,C0,a1g01p1,33,2838.606061,0.848485,579.636364,35,2.495625,2.646875
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
a2g13p3,249,0,150,0,C0,C0,a2g13p3,712,1801.213483,792.443820,551.751404,739,53.845000,55.886875
a2g13p3,250,0,146,0,C0,C0,a2g13p3,308,1980.500000,793.714286,459.652597,318,23.292500,24.048750
a2g13p3,251,0,148,0,C0,C0,a2g13p3,678,2612.631268,799.069322,500.057522,696,51.273750,52.635000
a2g13p3,252,0,0,0,C0,C0,a2g13p3,672,1393.748512,798.055060,786.084821,689,50.820000,52.105625


## Example using LazyCloneCounter with add_clones_and_neighbouring_labels

In [None]:
from skimage import morphology

In [None]:
foo = LazyCloneCounter("Marcm2a_E7F1", r"a\dg\d\dp\d", 0.275)

# Every channel's raw images live under <root>/<ch>/<ch>_imgs/.
data_root = "../current_imaging_analysis/MARCM2A_E7F1_refactoring"
foo.add_images(
    **{
        channel: f"{data_root}/{channel}/{channel}_imgs/*.tif*"
        for channel in ("C0", "C1", "C2", "C3")
    }
)

In [None]:
# Drop objects smaller than 49 pixels before labelling; only C0 is segmented here.
foo.add_segmentations(
    additional_func_to_map=morphology.remove_small_objects,
    ad_func_kwargs=dict(min_size=49),
    C0="../current_imaging_analysis/MARCM2A_E7F1_refactoring/C0/C0_label_imgs_combined_C3/*.tif*",
)

In [None]:
foo.make_measurements(extra_properties=["convex_area"])  # default props + convex_area

In [None]:
# Keep C0 labels whose C1 mean intensity exceeds 1000 and also build clone labels.
foo.add_clones_and_neighbouring_labels(
    name_for_query="filt_C1_int",
    query_for_pd='int_img_ch == "C1" & mean_intensity > 1000',
    calc_clones=True,
)

In [None]:
# Keep C0 labels whose C2 mean intensity exceeds 500; no clone layer this time.
foo.add_clones_and_neighbouring_labels(
    name_for_query="filt_C2_int",
    query_for_pd='int_img_ch == "C2" & mean_intensity > 500',
    calc_clones=False,
)

In [None]:
#hide
import napari

view = napari.Viewer()
# Split the image stack into one layer per channel (axis 0), then overlay
# each filtered label set.
view.add_image(foo.image_data["images"].data, channel_axis=0)
view.add_labels(foo.image_data["filt_C1_int"].data)
view.add_labels(foo.image_data["filt_C2_int"].data)

In [None]:
foo.measure_clones_and_neighbouring_labels("filt_C1_int")  # fills results_clones_and_neighbour_counts

In [None]:
df = foo.results_clones_and_neighbour_counts.reset_index().copy()

In [None]:
df = df.query("extended_tot_seg_labels != 0")[
    [
        "extended_tot_seg_labels",
        "img_name",
        "clone",
        "total_neighbour_counts",
        "inside_clone_neighbour_counts",
        "outside_clone_neighbour_counts",
    ]
]

In [None]:
bar = foo.results_measurements.copy()

In [None]:
pd.merge(
    bar,
    df,
    how="inner",
    left_on=["int_img", "label"],
    right_on=["img_name", "extended_tot_seg_labels"],
).drop(columns=["img_name", "extended_tot_seg_labels"])