#### ⚠️ The sample data is not yet applicable to this notebook. If you wish to use the sample data, please begin on notebooks [1.1a](./1.1a_infer_masks_from-composite_single_cell.ipynb) or [1.1b](./1.1b_infer_masks_from-composite_multiple-cells.ipynb). ⚠️

# Infer ***nucleus***, ***cellmask***, and ***cytoplasm*** from a composite image of cytoplasmic organelles
### Default workflow: ***"1"*** (for images with cytoplasmic organelles, a **nuclei** marker, no **cell membrane** markers, and more than one cell per field of view)
--------------
## **OBJECTIVES:**



### <input type="checkbox"/> Infer sub-cellular component #1: ***nuclei***
Segment the ***nuclei*** from a single channel (nuclei marker). This is necessary to determine the other subcellular compartments, like the ***cytoplasm***. Nuclei will also be used to seed the instance segmentation of the ***cell*** area (***cellmask***).

> ###### **Convention: *"nuclei"* for the segmentation of ALL nuclei in the image. *"nucleus"* for the single nucleus associated with the single cell being analyzed after the cell with the most signal is determined.**

### <input type="checkbox"/> Infer sub-cellular component #2: ***cellmask***
Segment the cell area (the ***cellmask***) from a composite image of multiple organelle markers combined. The **cellmask** will be necessary for determining which organelles are in which cell.

### <input type="checkbox"/> Infer sub-cellular component #3: ***cytoplasm***
Segment the ***cytoplasm*** from the cellmask and nuclei outputs. We will first select the single nucleus that is within our cellmask. Then, a logical **XOR** of the nucleus and the cellmask will be applied to produce the cytoplasmic area.

> ###### **📝 this workflow is optimized for images with multiple fluorescent cells in the field of view**

---------
## **masks workflow** ***"1"***
### summary of steps


➡️ **EXTRACTION**
- **`STEP 1`** - Segment nuclei

    - select single channel containing the nuclei marker (channel number = user input)
    - rescale intensity of the nuclei channel (min=0, max=1)
    - median filter (median size = user input)
    - gaussian filter (sigma = user input)
    - log transform image
    - calculate Li's minimum cross entropy threshold value
    - apply threshold to image (thresholding options = user input)
    - fill holes (hole size = user input)
    - remove small objects (object size = user input)

- **`STEP 2`** - Create composite image

    - determine weight to apply to each channel of the intensity image (w# = user input)
    - rescale summed image intensities (rescale = user input)

**PRE-PROCESSING**
- **`STEP 3`** - Rescale and smooth image

    - rescale intensity of composite image (min=0, max=1)
    - median filter (median size = user input)
    - gaussian filter (sigma = user input)

- **`STEP 4`** Log transform + Scharr edge detection

    - log transform image
    - apply scharr edge detection filter 
    - combine log image + scharr edge filtered intensity

**CORE PROCESSING**
- **`STEP 5`** Global + local thresholding (AICSSeg – MO)

    - apply MO thresholding method from the Allen Cell [aicssegmentation](https://github.com/AllenCell/aics-segmentation) package (threshold options = user input)

**POST-PROCESSING**
- **`STEP 6`** Remove small holes and objects

    - fill holes (hole size = user input)
    - remove small objects (object size = user input)
    - filter method (method = user input)

**POST-POST-PROCESSING**
- **`STEP 7`** Select one cellmask/nuclei based on signal

    - label unique cell objects based on watershed seeded from the nuclei objects
    - select the single cell with the highest combined fluorescence

- **`STEP 8`** Segment cytoplasm

    - mask nuclei with ***cellmask*** to select single ***nucleus***
    - erode ***nucleus*** (shrink; *optional*)
    - Segment cytoplasm from logical **XOR** of ***nucleus*** and ***cellmask***

**EXPORT** ➡️
- **`STEP 9`** - Stack masks

    - stack masks in order of nucleus, cellmask and cytoplasm mask

    > ###### ***Note:* this pipeline will eventually include a selection step to identify the cellmasks that are properly labeled with all fluorescent markers. This could be one single cell per image, or more if applicable data is available.**

---------
## **IMPORTS**

#### &#x1F3C3; **Run code; no user input required**

&#x1F453; **FYI:** This code block loads all of the necessary python packages and functions you will need for this notebook. 

In [None]:
from pathlib import Path
import os
from typing import Union

import numpy as np
import napari
from napari.utils.notebook_display import nbscreenshot

from skimage.morphology import binary_erosion

from infer_subc.core.file_io import (read_czi_image,
                                     export_inferred_organelle,
                                     list_image_files)
from infer_subc.core.img import *
from infer_subc.organelles import (choose_max_label_cellmask_union_nucleus,
                                   non_linear_cellmask_transform)



%load_ext autoreload
%autoreload 2

## **LOAD AND READ IN IMAGE FOR PROCESSING**

#### &#x1F6D1; &#x270D; **User Input Required:**

In [None]:
# Specify the file type of your raw data that will be analyzed. Ex) ".czi" or ".tiff"
im_type = ".czi"

## Define the path to the directory that contains the input image folder.
data_root_path = Path(os.path.expanduser("~")) / "Documents/Python_Scripts/Infer-subc"

## Specify the subfolder that contains the input data
in_data_path = data_root_path / "raw"

## Specify the output folder
out_data_path = data_root_path / "out"

# Specify which file you'd like to segment from the img_file_list
test_img_n = 5

#### &#x1F3C3; **Run code; no user input required**

In [None]:
# create the output folder (and any missing parent folders) if it does not exist yet
if not out_data_path.exists():
    out_data_path.mkdir(parents=True)
    print(f"making {out_data_path}")

img_file_list = list_image_files(in_data_path,im_type)
# pd.set_option('display.max_colwidth', None)
# pd.DataFrame({"Image Name":img_file_list})

In [None]:
test_img_name = img_file_list[test_img_n]

img_data,meta_dict = read_czi_image(test_img_name)

channel_names = meta_dict['name']
img = meta_dict['metadata']['aicsimage']
scale = meta_dict['scale']
channel_axis = meta_dict['channel_axis']

# ***EXTRACTION prototype - masks***


## **`STEP 1` - Segment nuclei**

- select single channel containing the nuclei marker (channel number = user input)

In [None]:
###################
# INPUT
###################
NUC_CH = 0
raw_nuclei = select_channel_from_raw(img_data, NUC_CH)

- rescale intensity of the nuclei channel (min=0, max=1)
- median filter (median size = user input)
- gaussian filter (sigma = user input)

In [None]:
med_filter_size = 4   
gaussian_smoothing_sigma = 1.34

nuclei =  scale_and_smooth(raw_nuclei,
                           median_size = med_filter_size, 
                           gauss_sigma = gaussian_smoothing_sigma)
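
The rescale, median, and gaussian steps listed above amount to a min-max rescale followed by two smoothing filters. The following is a minimal sketch of that idea, not the `infer_subc` implementation: the helper name `rescale_and_smooth_sketch` is hypothetical, SciPy's `ndimage` filters are assumed as stand-ins, and filtering each Z slice independently is an assumption.

```python
import numpy as np
from scipy import ndimage as ndi


def rescale_and_smooth_sketch(img, median_size=0, gauss_sigma=0.0):
    """Min-max rescale to [0, 1], then median and Gaussian smoothing in YX only."""
    img = img.astype(np.float32)
    lo, hi = img.min(), img.max()
    # Guard against a constant image to avoid dividing by zero.
    out = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    if median_size > 0:
        # size=(1, k, k) filters every Z slice independently (assumption).
        out = ndi.median_filter(out, size=(1, median_size, median_size))
    if gauss_sigma > 0:
        # sigma=0 along Z means no smoothing across slices.
        out = ndi.gaussian_filter(out, sigma=(0, gauss_sigma, gauss_sigma))
    return out


# Toy ZYX volume standing in for the nuclei channel.
rng = np.random.default_rng(0)
demo = rng.integers(0, 4096, size=(3, 32, 32)).astype(np.uint16)
smoothed = rescale_and_smooth_sketch(demo, median_size=4, gauss_sigma=1.34)
```

Larger `median_size` and `gauss_sigma` values remove more noise but also blur object boundaries, which is why both are left as user inputs.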

- log transform image
- calculate Li's minimum cross entropy threshold value
- apply threshold to image (thresholding options = user input)


> #### ASIDE: Thresholding
> ###### [Thresholding](https://en.wikipedia.org/wiki/Thresholding_%28image_processing%29) is used to create binary images. A threshold value determines the intensity value separating foreground pixels from background pixels. Foreground pixels are pixels brighter than the threshold value; background pixels are darker. In many cases, images can be adequately segmented by thresholding followed by labelling of *connected components*, which is a fancy way of saying "groups of pixels that touch each other".
> 
> ###### Different thresholding algorithms produce different results. [Otsu's method](https://en.wikipedia.org/wiki/Otsu%27s_method) and [Li's minimum cross entropy threshold](https://scikit-image.org/docs/dev/auto_examples/developers/plot_threshold_li.html) are two common algorithms. Below, we use Li. You can use `skimage.filters.threshold_<TAB>` to find different thresholding methods.
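
As a toy illustration of the aside above, the sketch below thresholds a small synthetic image and labels its connected components. The image and the hand-picked threshold are made up for the example; `scipy.ndimage.label` is used here only as a generic labelling routine.

```python
import numpy as np
from scipy import ndimage as ndi

# Toy 2D image with two bright blobs on a dark background.
img = np.zeros((8, 8), dtype=float)
img[1:3, 1:3] = 0.9   # blob 1 (4 pixels)
img[5:7, 4:7] = 0.8   # blob 2 (6 pixels)

threshold = 0.5                 # hand-picked here; Li or Otsu would estimate it
binary = img > threshold        # foreground = pixels brighter than the threshold

# Connected components: each group of touching foreground pixels gets its own label.
labels, n_objects = ndi.label(binary)
```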


In [None]:
# log transform the image, calculate the threshold value using Li minimum cross entropy method, inverse log transform the value
# apply the threshold value taking into account the user determined min, max, and adjustment values
threshold_factor = 0.9
thresh_min = .1
thresh_max = 1.

li_thresholded = apply_log_li_threshold(nuclei, 
                                        thresh_factor=threshold_factor, 
                                        thresh_min=thresh_min, 
                                        thresh_max=thresh_max)
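
The three user inputs above adjust the automatically computed threshold. One plausible way they could combine is sketched here, under the assumption that the factor scales the raw threshold and the min/max values then clamp it; `adjust_threshold` is a hypothetical helper, not a function from `infer_subc`.

```python
def adjust_threshold(raw_threshold, thresh_factor=1.0, thresh_min=None, thresh_max=None):
    """Scale an automatic threshold, then clamp it to optional bounds (assumed behavior)."""
    threshold = raw_threshold * thresh_factor
    if thresh_min is not None:
        threshold = max(threshold, thresh_min)   # never go below the floor
    if thresh_max is not None:
        threshold = min(threshold, thresh_max)   # never go above the ceiling
    return threshold


# A raw threshold of 0.5 is lowered by the 0.9 factor.
scaled = adjust_threshold(0.5, thresh_factor=0.9, thresh_min=0.1, thresh_max=1.0)
# A very low raw threshold is lifted to the 0.1 floor.
floored = adjust_threshold(0.05, thresh_factor=0.9, thresh_min=0.1, thresh_max=1.0)
```

A factor below 1 makes the segmentation more permissive, while the floor protects against near-empty images where the automatic threshold would sit in the noise.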

- fill holes (hole size = user input)
- remove small objects (object size = user input)

> ###### 📝 **the size parameters are by convention defined as one dimensional "width", so the inputs to the functions are _squared_ i.e. raised to the power of 2: `**2` for 2D analysis. For volumetric (3D) analysis this would be _cubed_: `**3`.**

In [None]:
# fill small holes then exclude small objects
hole_min_width = 0
hole_max_width = 25  

small_object_width = 15

# combine the above functions into one for downstream use in plugin
cleaned_img = fill_and_filter_linear_size(li_thresholded, 
                                           hole_min=hole_min_width, 
                                           hole_max=hole_max_width, 
                                           min_size= small_object_width,
                                           method='3D')
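
To make the width convention from the note above concrete, the sketch below converts a one-dimensional width into an area or volume cutoff and uses it to drop small objects. Both helper names are hypothetical, and the filtering is a simplified stand-in for what `fill_and_filter_linear_size` might do, not its actual code.

```python
import numpy as np
from scipy import ndimage as ndi


def width_to_size(width, ndim):
    """Convert a linear width into a pixel area (ndim=2) or voxel volume (ndim=3)."""
    return width ** ndim


def remove_small_objects_sketch(mask, min_width):
    """Drop connected components smaller than min_width ** mask.ndim."""
    labels, _ = ndi.label(mask)
    counts = np.bincount(labels.ravel())          # size of every label, 0 = background
    keep = counts >= width_to_size(min_width, mask.ndim)
    keep[0] = False                               # background is never kept
    return keep[labels]                           # look up keep/drop for every voxel


mask = np.zeros((4, 10, 10), dtype=bool)
mask[0:3, 0:3, 0:3] = True   # 27 voxels: meets the 3**3 cutoff
mask[0, 8, 8] = True         # 1 voxel: removed
cleaned = remove_small_objects_sketch(mask, min_width=3)
```

With the values used in this notebook, a `small_object_width` of 15 with `method='3D'` would correspond to a cutoff of `15**3 = 3375` voxels under this convention.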

> ###### 📝 **Create labels for the nuclei seeds that will be used during the watershedding algorithm**

In [None]:
###################
# LABELING
###################
# create instance segmentation based on connectivity
nuclei_labels = label_uint16(cleaned_img)

In [None]:
nuclei_labels.dtype

Define `_infer_nuclei_fromlabel` function

> ###### 📝 **these functions mainly serve for downstream prototyping in the notebooks. Each step above has an independent function that is implemented in the plugin for ease of use.**

In [None]:
##########################
#  _infer_nuclei
##########################
def _infer_nuclei_fromlabel(in_img: np.ndarray, 
                            nuc_ch: Union[int,None],
                            median_sz: int, 
                            gauss_sig: float,
                            thresh_factor: float,
                            thresh_min: float,
                            thresh_max: float,
                            min_hole_w: int,
                            max_hole_w: int,
                            small_obj_w: int,
                            fill_filter_method: str
                            ) -> np.ndarray:
    """
    Procedure to infer nuclei from linearly unmixed input.

    Parameters
    ------------
    in_img: np.ndarray
        a 3d image containing all the channels
    nuc_ch: Union[int, None]
        index of the channel containing the nuclei marker
    median_sz: int
        width of median filter for signal
    gauss_sig: float
        sigma for gaussian smoothing of signal
    thresh_factor: float
        adjustment factor for log Li thresholding
    thresh_min: float
        absolute minimum threshold for log Li thresholding
    thresh_max: float
        absolute maximum threshold for log Li thresholding
    min_hole_w: int
        minimum hole width to fill during nuclei post-processing
    max_hole_w: int
        hole filling cutoff for nuclei post-processing
    small_obj_w: int
        minimum object size cutoff for nuclei post-processing
    fill_filter_method: str
        whether to fill holes and filter objects "slice_by_slice" or in "3D"

    Returns
    -------------
    nuclei_labels
        uint16 label image with one label per nucleus

    """
    ###################
    # EXTRACT
    ###################                
    nuclei = select_channel_from_raw(in_img, nuc_ch)

    ###################
    # PRE_PROCESSING
    ###################                
    nuclei =  scale_and_smooth(nuclei,
                        median_size = median_sz, 
                        gauss_sigma = gauss_sig)

    ###################
    # CORE_PROCESSING
    ###################
    nuclei_object = apply_log_li_threshold(nuclei, 
                                           thresh_factor=thresh_factor, 
                                           thresh_min=thresh_min, 
                                           thresh_max=thresh_max)

    ###################
    # POST_PROCESSING
    ###################
    nuclei_object = fill_and_filter_linear_size(nuclei_object, 
                                                hole_min=min_hole_w, 
                                                hole_max=max_hole_w, 
                                                min_size=small_obj_w,
                                                method=fill_filter_method)

    nuclei_labels = label_uint16(nuclei_object)

    return nuclei_labels

Run `_infer_nuclei_fromlabel` function

> ###### 📝 **Uses the same parameters as earlier in the notebook**

In [None]:
###################
# PARAMETERS
###################
nuc_ch = NUC_CH
median_sz = 4   
gauss_sig = 1.34
threshold_factor = 0.9
thresh_min = 0.1
thresh_max = 1.0
min_hole_w = 0
max_hole_w = 25
small_obj_w = 15
fill_filter_method = "3D"

_NU_object = _infer_nuclei_fromlabel(img_data,
                                nuc_ch,
                                median_sz,
                                gauss_sig,
                                threshold_factor,
                                thresh_min,
                                thresh_max,
                                min_hole_w,
                                max_hole_w,
                                small_obj_w,
                                fill_filter_method)

_NU_object.dtype

Confirm that the `_infer_nuclei_fromlabel` output matches the step-by-step result

In [None]:
np.all(nuclei_labels == _NU_object)

## **`STEP 2` - Create composite image**

- determine weight to apply to each channel of the intensity image (w# = user input)
- rescale summed image intensities (rescale = user input)

In [None]:
# Creating a composite image

w0 = 0
w1 = 0
w2 = 0
w3 = 3
w4 = 2
w5 = 2
w6 = 0
w7 = 0
w8 = 0
w9 = 0

rescale = True

struct_img_raw = make_aggregate(img_data,
               weight_ch0= w0,
               weight_ch1= w1,
               weight_ch2= w2,
               weight_ch3= w3,
               weight_ch4= w4,
               weight_ch5= w5,
               weight_ch6= w6,
               weight_ch7= w7,
               weight_ch8= w8,
               weight_ch9= w9,
               rescale = rescale)

# # Creating a function to create composite image:
# weights =  [0,0,0,3,3,2]
# struct_img_raw2 = weighted_aggregate(img_data, *weights)

# # use splat so we can also break out the arguments for our napari widget later
# struct_img_raw3 = weighted_aggregate(img_data, 0,0,0,3,3,2)


# # Confirming the results are the same:
# struct_img_raw[0,0:10,0], struct_img_raw2[0,0:10,0], struct_img_raw3[0,0:10,0]
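
The composite step above is a weighted sum over the channel axis, optionally rescaled. A minimal NumPy sketch of that idea follows; `weighted_composite_sketch` is a hypothetical name, and the channel-first `(C, Z, Y, X)` layout and the min-max rescale are assumptions rather than the documented behavior of `make_aggregate`.

```python
import numpy as np


def weighted_composite_sketch(img, weights, rescale=True):
    """Weighted sum of channels for a (C, Z, Y, X) image; a weight of 0 skips a channel."""
    w = np.asarray(weights, dtype=np.float32)
    # Contract the channel axis: sum over c of w[c] * img[c].
    composite = np.tensordot(w, img[: len(w)].astype(np.float32), axes=1)
    if rescale:
        lo, hi = composite.min(), composite.max()
        composite = (composite - lo) / (hi - lo) if hi > lo else np.zeros_like(composite)
    return composite


rng = np.random.default_rng(1)
channels = rng.random((6, 2, 16, 16)).astype(np.float32)
composite = weighted_composite_sketch(channels, [0, 0, 0, 3, 2, 2], rescale=False)
```

Giving a higher weight to a dim but cell-filling marker is one way to keep it from being dominated by brighter channels in the summed image.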

# ***PRE-PROCESSING prototype - masks***

## **`STEP 3` - Rescale and smooth image**

- rescale intensity of composite image (min=0, max=1)
- median filter (median size = user input)
- gaussian filter (sigma = user input)

In [None]:
med_filter_size = 10
gaussian_smoothing_sigma = 1.34

structure_img_smooth = scale_and_smooth(struct_img_raw,
                                        median_size = med_filter_size, 
                                        gauss_sigma = gaussian_smoothing_sigma)

## **`STEP 4` - Log transform + Scharr edge detection**

- log transform image
- apply scharr edge detection filter 
- combine log image + scharr edge filtered intensity

In [None]:
# log scale the image, apply the scharr edge detection filter to logged image, add the two images together
composite_cellmask = non_linear_cellmask_transform(structure_img_smooth)
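
The transform above adds an edge map to the log-scaled image, which boosts the signal at cell boundaries. A 2D sketch of that combination is shown here. It is not the `non_linear_cellmask_transform` code: the helper name is hypothetical, `log1p` is an assumed choice of log transform, and the Scharr filter is written out with `scipy.ndimage.correlate`.

```python
import numpy as np
from scipy import ndimage as ndi

# 3x3 Scharr kernel responding to intensity changes along axis 0 (rows);
# its transpose responds to changes along axis 1 (columns).
SCHARR_ROWS = np.array([[3, 10, 3], [0, 0, 0], [-3, -10, -3]], dtype=float) / 16.0


def log_plus_edges_sketch(img2d):
    """Log-transform a 2D image and add its Scharr gradient magnitude."""
    log_img = np.log1p(img2d)  # assumption: log(1 + x), which keeps zeros at zero
    d_rows = ndi.correlate(log_img, SCHARR_ROWS, mode="reflect")
    d_cols = ndi.correlate(log_img, SCHARR_ROWS.T, mode="reflect")
    edges = np.sqrt((d_rows ** 2 + d_cols ** 2) / 2.0)  # gradient magnitude
    return log_img + edges


flat = np.full((8, 8), 0.5)      # no edges anywhere
step = np.zeros((8, 8))
step[:, 4:] = 1.0                # vertical edge between columns 3 and 4

flat_out = log_plus_edges_sketch(flat)
step_out = log_plus_edges_sketch(step)
```

In the flat image the edge term vanishes, so the output equals the log image; in the step image the pixels next to the boundary are brighter than the log image alone.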

# ***CORE PROCESSING prototype - masks***

## **`STEP 5` - Global + local thresholding (AICSSeg – MO)**

- apply MO thresholding method from the Allen Cell [aicssegmentation](https://github.com/AllenCell/aics-segmentation) package (threshold options = user input)

In [None]:
# threshold the composite image after log/edge detection using the MO filter function from aicssegmentation - this applies a global threshold, then a local threshold to produce a semantic segmentation
thresh_method = 'med'
cutoff_size =  150
thresh_adj = 0.3

bw = masked_object_thresh(composite_cellmask, 
                          global_method=thresh_method, 
                          cutoff_size=cutoff_size, 
                          local_adjust=thresh_adj)
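
The global-then-local idea behind this step can be sketched in a few lines. This is a simplified illustration, not the aicssegmentation MO implementation: it assumes a median global threshold, a size cutoff on the candidate objects, and a per-object Otsu threshold scaled by `local_adjust`. Both function names are hypothetical, and the NumPy Otsu routine is a stand-in written for this example.

```python
import numpy as np
from scipy import ndimage as ndi


def otsu_sketch(values, nbins=256):
    """Histogram-based Otsu threshold: pick the split maximizing between-class variance."""
    hist, edges = np.histogram(values, bins=nbins)
    hist = hist.astype(float)
    centers = (edges[:-1] + edges[1:]) / 2
    w_low = np.cumsum(hist)                    # counts at or below each bin
    w_high = np.cumsum(hist[::-1])[::-1]       # counts at or above each bin
    m_low = np.cumsum(hist * centers) / np.maximum(w_low, 1)
    m_high = (np.cumsum((hist * centers)[::-1]) / np.maximum(w_high[::-1], 1))[::-1]
    between = w_low[:-1] * w_high[1:] * (m_low[:-1] - m_high[1:]) ** 2
    return centers[np.argmax(between)]


def masked_object_thresh_sketch(img, cutoff_size, local_adjust=1.0):
    """Global median threshold, then a local threshold inside each large object."""
    low = img > np.median(img)                 # global step
    labels, n = ndi.label(low)
    sizes = np.bincount(labels.ravel(), minlength=n + 1)
    out = np.zeros(img.shape, dtype=bool)
    for lab in range(1, n + 1):
        if sizes[lab] < cutoff_size:
            continue                           # ignore small candidate objects
        obj = labels == lab
        local = otsu_sketch(img[obj]) * local_adjust   # local step for this object
        out |= obj & (img > local)
    return out


img = np.full((40, 40), 0.05)    # uniform background
img[5:25, 5:25] = 0.4            # dim halo (400 pixels in total with the core)
img[10:20, 10:20] = 0.9          # bright core (100 pixels)
img[35, 35] = 0.5                # isolated speck, below the size cutoff

tight = masked_object_thresh_sketch(img, cutoff_size=50, local_adjust=1.0)
loose = masked_object_thresh_sketch(img, cutoff_size=50, local_adjust=0.3)
```

Lowering `local_adjust` relaxes the local threshold, so more of each object's dim periphery is kept; here the `loose` mask keeps the halo while the `tight` mask keeps only the core.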

# ***POST-PROCESSING prototype - masks***

## **`STEP 6` - Remove small holes and objects**

- fill holes (hole size = user input)
- remove small objects (object size = user input)
- filter method (method = user input)

> ###### 📝 **the size parameters are by convention defined as one dimensional "width", so the inputs to the functions are _squared_ i.e. raised to the power of 2: `**2` for 2D analysis. For volumetric (3D) analysis this would be _cubed_: `**3`.**

In [None]:
hole_min_width = 0
hole_max_width = 50
small_object_width = 45
method = 'slice_by_slice'

cleaned_img2 = fill_and_filter_linear_size(bw, 
                                           hole_min=hole_min_width, 
                                           hole_max=hole_max_width, 
                                           min_size= small_object_width,
                                           method = method)
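
The `method` input decides whether an operation treats the volume as a stack of independent 2D slices or as one 3D object. The sketch below shows the difference for hole filling only, ignoring the hole-size limits; `fill_holes_sketch` is a hypothetical helper built on `scipy.ndimage.binary_fill_holes`, not the `infer_subc` function.

```python
import numpy as np
from scipy import ndimage as ndi


def fill_holes_sketch(mask, method="slice_by_slice"):
    """Fill holes either per Z slice or in the full 3D volume."""
    if method == "3D":
        return ndi.binary_fill_holes(mask)
    # Otherwise treat every Z slice as an independent 2D image.
    return np.stack([ndi.binary_fill_holes(z) for z in mask], axis=0)


# A hollow tube that is open at both Z ends: a hole in every 2D slice,
# but not an enclosed cavity in 3D.
tube = np.zeros((3, 7, 7), dtype=bool)
tube[:, 1:6, 1:6] = True
tube[:, 3, 3] = False

filled_2d = fill_holes_sketch(tube, "slice_by_slice")
filled_3d = fill_holes_sketch(tube, "3D")
```

Slice-by-slice filling closes the channel in every slice, while the 3D version leaves it open because it connects to the outside through the top and bottom of the stack. This is one reason a slice-by-slice method can be preferable for cell masks imaged with few Z planes.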

# ***POST-POST-PROCESSING prototype - masks***

## **`STEP 7` - Select one cellmask/nuclei based on signal**

- label unique cell objects based on watershed seeded from the nuclei objects
- select the single cell with the highest combined fluorescence

In [None]:
# apply a watershed to the inverted image using the nuclei as a seed for each cell
watershed_method = '3D'
cellmask_labels = masked_inverted_watershed(structure_img_smooth, 
                                            nuclei_labels, 
                                            cleaned_img2,
                                            method=watershed_method)

# find the cell with the highest total fluorescence after combining all channels together
keep_label = get_max_label(composite_cellmask, 
                           cellmask_labels)


# combine the above and find the nucleus associated to the highest fluorescence cell
cellmask_out = choose_max_label_cellmask_union_nucleus(structure_img_smooth,
                                                       cleaned_img2, 
                                                       nuclei_labels,
                                                       watershed_method=watershed_method)

cellmask = label_bool_as_uint16(cellmask_out)
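
Selecting the cell with the highest total fluorescence comes down to summing the intensity under each label and taking the largest sum. A NumPy sketch of that selection is shown here; `max_signal_label_sketch` is a hypothetical helper and may differ from how `get_max_label` is written.

```python
import numpy as np


def max_signal_label_sketch(intensity, labels):
    """Return the label whose pixels have the highest summed intensity."""
    # bincount with weights sums the intensity separately for every label value.
    totals = np.bincount(labels.ravel(), weights=intensity.ravel())
    totals[0] = -np.inf          # never pick the background (label 0)
    return int(np.argmax(totals))


labels = np.zeros((1, 6, 6), dtype=np.uint16)
labels[0, 0:3, 0:3] = 1          # dim cell
labels[0, 3:6, 3:6] = 2          # bright cell

intensity = np.zeros((1, 6, 6))
intensity[labels == 1] = 0.2
intensity[labels == 2] = 0.7

keep = max_signal_label_sketch(intensity, labels)
single_cell = labels == keep     # boolean mask of the selected cell
```

Because the sum is taken over all pixels of a label, a large dim cell can outscore a small bright one; using the mean instead would change that trade-off.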

In [None]:
cellmask.dtype

Define `_infer_cellmask_fromcomposite` function

> ###### 📝 **these functions mainly serve for downstream prototyping in the notebooks. Each step above has an independent function that is implemented in the plugin for ease of use.**

In [None]:
##########################
# _infer_cellmask_fromcomposite
##########################
def _infer_cellmask_fromcomposite(in_img: np.ndarray,
                                  weights: list[int],
                                  nuclei_labels: np.ndarray,
                                  median_sz: int,
                                  gauss_sig: float,
                                  mo_method: str,
                                  mo_adjust: float,
                                  mo_cutoff_size: int,
                                  min_hole_w: int,
                                  max_hole_w: int,
                                  small_obj_w: int,
                                  watershed_method: str
                                  ) -> np.ndarray:
    """
    Procedure to infer cellmask from linearly unmixed input.

    Parameters
    ------------
    in_img: 
        a 3d image containing all the channels
    weights:
        a list of int that correspond to the weights for each channel in the composite; use 0 if a channel should not be included in the composite image
    nuclei_labels: 
        a 3d image containing the inferred nuclei labels
    median_sz: 
        width of median filter for _cellmask_ signal
    gauss_sig: 
        sigma for gaussian smoothing of _cellmask_ signal
    mo_method: 
         which method to use for calculating global threshold. Options include:
         "triangle" (or "tri"), "median" (or "med"), and "ave_tri_med" (or "ave").
         "ave" refers to the average of the "triangle" threshold and the "median" threshold.
    mo_adjust: 
        Masked Object threshold `local_adjust`
    mo_cutoff_size: 
        Masked Object threshold `cutoff_size`
    min_hole_w: 
        minimum hole width to fill during cellmask signal post-processing
    max_hole_w: 
        hole filling cutoff for cellmask signal post-processing
    small_obj_w: 
        minimum object size cutoff for cellmask signal post-processing
    watershed_method:
        determines if the watershed should be run 'slice_by_slice' or in '3D' 

    Returns
    -------------
    cellmask_mask:
        a logical/labels object defining boundaries of cellmask

    """
    ###################
    # EXTRACT
    ###################
    struct_img = weighted_aggregate(in_img, *weights)

    ###################
    # PRE_PROCESSING
    ###################                         
    struct_img = scale_and_smooth(struct_img,
                                   median_size = median_sz, 
                                   gauss_sigma = gauss_sig)
    

    struct_img_non_lin = non_linear_cellmask_transform(struct_img)

    ###################
    # CORE_PROCESSING
    ###################
    struct_obj = masked_object_thresh(struct_img_non_lin, 
                                      global_method=mo_method, 
                                      cutoff_size=mo_cutoff_size, 
                                      local_adjust=mo_adjust)               

    ###################
    # POST_PROCESSING
    ###################
    struct_obj = fill_and_filter_linear_size(struct_obj, 
                                             hole_min=min_hole_w, 
                                             hole_max=max_hole_w, 
                                             min_size= small_obj_w)

    ###################
    # POST- POST_PROCESSING
    ###################
    cellmask_out = choose_max_label_cellmask_union_nucleus(struct_img, 
                                                           struct_obj, 
                                                           nuclei_labels, 
                                                           watershed_method=watershed_method) 

    return label_bool_as_uint16(cellmask_out)

Run `_infer_cellmask_fromcomposite` function

In [None]:
###################
# PARAMETERS
###################   
weights = [0,0,0,3,3,2]
median_sz = 10
gauss_sig = 1.34
mo_method = "med"
mo_adjust = 0.3
mo_cutoff_size = 150
hole_min_width = 0
hole_max_width = 50
small_obj_w = 45
watershed_method = '3D'

_CM_object = _infer_cellmask_fromcomposite(img_data,
                                            weights,
                                            nuclei_labels,
                                            median_sz,
                                            gauss_sig,
                                            mo_method,
                                            mo_adjust,
                                            mo_cutoff_size,
                                            hole_min_width,
                                            hole_max_width,
                                            small_obj_w,
                                            watershed_method)

_CM_object.dtype

In [None]:
np.all(cellmask == _CM_object)

## **`STEP 8` - Segment cytoplasm**

- mask nuclei with ***cellmask*** to select single ***nucleus***
- erode ***nucleus*** (shrink; *optional*)
- Segment cytoplasm from logical **XOR** of ***nucleus*** and ***cellmask***

In [None]:
# mask the nuclei segmentation (the label image, not the smoothed intensity image) with the cellmask to select the single nucleus
nucleus_obj =  apply_mask(nuclei_labels, cellmask) 

# erode nucleus if desired (this likely depends on the type of label used)
nucleus_eroded = binary_erosion(nucleus_obj)

# select the cytoplasmic area (two ways shown here)
cyto_object = np.logical_and(cellmask,~nucleus_eroded)
cyto_object_xor = np.logical_xor(cellmask,nucleus_eroded)

nucleus_out = label_bool_as_uint16(nucleus_obj)
cytoplasm_out = label_bool_as_uint16(cyto_object_xor)

In [None]:
nucleus_out.dtype, cytoplasm_out.dtype

Define `_infer_cytoplasm` function

> ###### 📝 **these functions mainly serve for downstream prototyping in the notebooks. Each step above has an independent function that is implemented in the plugin for ease of use.**

In [None]:
def _infer_cytoplasm(nuclei_object, cellmask,  erode_nuclei = True):
    """
    Procedure to infer cytoplasm from linearly unmixed input.

    Parameters
    ------------
    nuclei_object: 
        a 3d image containing the segmented nuclei (mask or labels)
    cellmask: 
        a 3d image containing the cellmask object
    erode_nuclei: 
        whether to erode the nucleus before removing it from the cellmask. Default True

    Returns
    -------------
    cytoplasm_mask 
        uint16 np.ndarray marking the cytoplasm
      
    """
    # keep only the nucleus that falls inside the selected cell
    nucleus_obj =  apply_mask(nuclei_object, cellmask) 

    if erode_nuclei:
        cytoplasm_mask = np.logical_xor(cellmask, binary_erosion(nucleus_obj))
    else:
        cytoplasm_mask = np.logical_xor(cellmask, nucleus_obj)

    return label_bool_as_uint16(cytoplasm_mask)

Run `_infer_cytoplasm` function

In [None]:
_CY_object = _infer_cytoplasm(nuclei_labels, cellmask, erode_nuclei=True)

_CY_object.dtype

In [None]:
np.all(cytoplasm_out == _CY_object)
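
The XOR used for the cytoplasm gives the same result as "cellmask AND NOT nucleus" only when the nucleus lies entirely inside the cellmask, which is what masking the nuclei with the cellmask is meant to guarantee. The toy arrays below are made up to illustrate both the agreement and the case where the two operations differ.

```python
import numpy as np

cell = np.zeros((1, 9, 9), dtype=bool)
cell[0, 1:8, 1:8] = True         # 49-pixel cell
nucleus = np.zeros_like(cell)
nucleus[0, 3:6, 3:6] = True      # 9-pixel nucleus, fully inside the cell

# With the nucleus inside the cell, both formulations agree.
cyto_xor = np.logical_xor(cell, nucleus)
cyto_and = np.logical_and(cell, ~nucleus)

# If part of the nucleus falls outside the cell, XOR marks that part as cytoplasm,
# whereas AND NOT does not.
stray = nucleus.copy()
stray[0, 0, 0] = True
stray_xor = np.logical_xor(cell, stray)
stray_and = np.logical_and(cell, ~stray)
```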

# ***EXPORT prototype - masks***

## **`STEP 9` - Stack masks**

- stack masks in order of nucleus, cellmask and cytoplasm mask

In [None]:
stack = stack_masks(nuc_mask=nucleus_out, cellmask=cellmask, cyto_mask=cytoplasm_out)
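
Conceptually, stacking places the three masks along a new leading axis in the order nucleus, cellmask, cytoplasm, which matches how `stack[0]`, `stack[1]`, and `stack[2]` are indexed in the visualization cells. The sketch below assumes that layout; `stack_masks_sketch` is a hypothetical helper and not the `stack_masks` source.

```python
import numpy as np


def stack_masks_sketch(nuc_mask, cellmask, cyto_mask):
    """Stack three same-shaped masks along a new first axis: nucleus, cellmask, cytoplasm."""
    return np.stack([nuc_mask, cellmask, cyto_mask], axis=0)


shape = (2, 4, 4)
nuc = np.zeros(shape, dtype=np.uint16)
nuc[:, 1:3, 1:3] = 1
cell = np.ones(shape, dtype=np.uint16)
cyto = cell - nuc                # cell area without the nucleus

stacked = stack_masks_sketch(nuc, cell, cyto)
```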

Export `masks` file to output folder

In [None]:
out_file_n = export_inferred_organelle(stack, "masks", meta_dict, out_data_path)

## **Visualize `nucleus`, `cellmask` and `cytoplasm`**

In [None]:
viewer_masks = napari.Viewer(title = "masks",
                           ndisplay=3)
viewer_masks.grid.enabled = True

In [None]:
viewer_masks.add_image(stack[2].astype(bool),
                      scale = scale,
                      name = 'Cytoplasm')

viewer_masks.add_image(stack[1].astype(bool),
                      scale = scale,
                      name = 'Cellmask')

viewer_masks.add_image(stack[0].astype(bool),
                      scale = scale,
                      name = 'Nucleus')

viewer_masks.reset_view()

nbscreenshot(viewer_masks,
             canvas_only = True)

-------------
### NEXT: INFER LYSOSOME

proceed to [1.2_infer_lysosome.ipynb](./1.2_infer_lysosome.ipynb)