<a href="https://colab.research.google.com/github/fogg-lab/tissue-model-analysis-tools/blob/main/notebooks/cell_elongation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Install tmat in the runtime and import packages

In [None]:
# Install the fl_tissue_model_tools package (tmat) from the `src` subdirectory of the GitHub repo.
# NOTE(review): `-I` (--ignore-installed) also reinstalls dependencies, which may overwrite
# versions preinstalled in the Colab runtime — confirm this is intended.
!pip install -I fl_tissue_model_tools@git+https://github.com/fogg-lab/tissue-model-analysis-tools.git#subdirectory=src
# Run tmat's `configure` command with the given directory.
!tmat configure /content/fl_tissue_model_tools

In [None]:
import json
import csv
from collections import defaultdict
from ipywidgets import FileUpload, FloatSlider, Layout, interactive
from IPython.display import display
from PIL import Image
import numpy as np
import io
from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.colors import ListedColormap
from skimage.morphology import disk, remove_small_objects
from skimage.exposure import equalize_adapthist, rescale_intensity
from skimage.filters import median, threshold_otsu
from skimage.measure import label, regionprops

from fl_tissue_model_tools.analysis import pixels_to_microns
from fl_tissue_model_tools import script_util as su
from fl_tissue_model_tools.well_mask_generation import generate_well_mask
from fl_tissue_model_tools import helper, defs

## Configuration Constants

## 2. Utilities

In [None]:
# Well width in microns. Passed to pixels_to_microns together with the image width in
# pixels to convert pixel lengths (e.g. cell major-axis lengths) to microns.
# NOTE(review): assumes the image's horizontal extent corresponds to this physical width — confirm.
WELL_WIDTH = 1000

def get_well_mask(img: np.ndarray):
    """Return a mask of the well region in ``img``.

    The image is contrast-equalized with CLAHE (``equalize_adapthist``, clip limit
    0.03), rescaled to 8-bit, and passed to ``generate_well_mask``.
    NOTE(review): the meaning of the 0.05 argument is defined by ``generate_well_mask``
    and is not visible here — confirm before changing it.
    """
    equalized_u8 = rescale_intensity(
        equalize_adapthist(img, clip_limit=0.03),
        out_range=(0, 255),
    ).astype(np.uint8)
    mask_result = generate_well_mask(equalized_u8, 0.05, return_superellipse_params=True)
    # When a tuple comes back, its first element is the mask; the rest is discarded.
    return mask_result[0] if isinstance(mask_result, tuple) else mask_result

def get_thresholded_image(img: np.ndarray):
    """Binarize ``img`` with a global Otsu threshold after light median smoothing.

    Returns:
        np.ndarray: Boolean mask, True where the median-filtered image (disk of
            radius 2) is strictly greater than its Otsu threshold.
    """
    smoothed = median(img, disk(2))  # small median filter to suppress isolated noisy pixels
    return smoothed > threshold_otsu(smoothed)

def label_elongated_cells(img_thresh: np.ndarray, max_circularity: float, min_solidity: float,
                          min_particle_size: int, min_length: int, max_length: int):
    """Identify elongated cells given the binary mask and particle shape parameters.

    Objects smaller than ``min_particle_size`` pixels are removed first. Every
    remaining connected component is then rejected unless it passes all of:

    - circularity <= ``max_circularity``, where circularity is the region's convex
      area divided by the area of a circle whose diameter is the region's major
      axis length (elongated shapes score low, round blobs score high);
    - Euler number == 1 (a single object with no holes);
    - solidity >= ``min_solidity``;
    - ``min_length`` <= major axis length in microns <= ``max_length``.

    Args:
        img_thresh: Foreground mask; cast to bool before filtering.
        max_circularity: Upper bound on circularity (defined above).
        min_solidity: Lower bound on region solidity.
        min_particle_size: Minimum object size in pixels for ``remove_small_objects``.
        min_length: Minimum major axis length, in microns.
        max_length: Maximum major axis length, in microns.

    Returns:
        np.ndarray: A label image of the elongated cell regions; rejected regions are 0.
        dict: Properties of each elongated cell keyed by label: ``length`` and ``width``
            (major/minor axis lengths in pixels), ``centroid`` and ``orientation``
            as reported by ``skimage.measure.regionprops``.
    """
    img_thresh_filt = remove_small_objects(img_thresh.astype(np.bool_), min_size=min_particle_size)
    labeled_components = label(img_thresh_filt)

    cell_properties = defaultdict(dict)
    rejected_labels = []
    for region in regionprops(labeled_components):
        major_axis = region.axis_major_length
        if major_axis > 0:
            circularity = region.area_convex / (np.pi * (major_axis / 2) ** 2)
        else:
            # Degenerate region (e.g. a single pixel): treat as maximally round so it
            # is rejected, without dividing by zero.
            circularity = np.inf
        region_length_microns = pixels_to_microns(major_axis, img_thresh.shape[1], WELL_WIDTH)
        if (
            circularity > max_circularity
            or region.euler_number != 1
            or region.solidity < min_solidity
            or region_length_microns < min_length
            or region_length_microns > max_length
        ):
            rejected_labels.append(region.label)
            continue
        cell_properties[region.label] = {
            'length': major_axis,
            'width': region.axis_minor_length,
            'centroid': region.centroid,
            'orientation': region.orientation,
        }

    # Erase all rejected regions in one vectorized pass rather than one full-image
    # comparison per rejected region. All region properties were computed above,
    # before the label image is modified.
    if rejected_labels:
        labeled_components[np.isin(labeled_components, rejected_labels)] = 0

    return labeled_components, cell_properties

def process_img_and_display(img_rgb: np.ndarray, img_thresh: np.ndarray, max_circularity: float,
                            min_solidity: float, min_particle_size: int, min_length: int,
                            max_length: int):
    """Detect elongated cells in ``img_thresh`` and plot them as a red overlay on ``img_rgb``.

    The figure title reports the number of detected cells and their mean major-axis
    length in microns.

    Args:
        img_rgb: Background image for display; ``img_rgb.shape[1]`` is used as the
            image width in pixels for the micron conversion.
        img_thresh: Boolean foreground mask passed to ``label_elongated_cells``.
        max_circularity, min_solidity, min_particle_size, min_length, max_length:
            Shape filters forwarded unchanged to ``label_elongated_cells``.
    """
    labeled_components, cell_properties = label_elongated_cells(img_thresh, max_circularity,
                                                                min_solidity, min_particle_size,
                                                                min_length, max_length)
    cells_mask = img_thresh & (labeled_components != 0)

    # RGBA overlay: half-transparent red on cell pixels, fully transparent elsewhere.
    # An opaque black RGB overlay drawn at alpha=0.5 would dim the whole image.
    red_overlay = np.zeros(cells_mask.shape + (4,), dtype=np.uint8)
    red_overlay[..., 0] = 255
    red_overlay[..., 3] = np.where(cells_mask, 128, 0)

    num_cells = len(cell_properties)
    # max(..., 1) avoids dividing by zero when no cells pass the filters.
    avg_cell_len_px = sum(cell['length'] for cell in cell_properties.values()) / max(num_cells, 1)
    # Use the same well width as label_elongated_cells so both conversions agree.
    avg_cell_len = pixels_to_microns(avg_cell_len_px, img_rgb.shape[1], WELL_WIDTH)

    plt.figure(figsize=(8, 8))
    plt.imshow(img_rgb)
    plt.imshow(red_overlay)
    plt.title(f"Elongated cells: {num_cells}. Avg length: {avg_cell_len:.2f}µm")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

## 3. Upload an image
**Run the next cell, press the button to upload an image, then run the cell after it.**

In [None]:
# Single-file image picker; the uploaded file is read by the next cell via `upload.value`.
upload = FileUpload(multiple=False, accept='image/*')
display(upload)

In [None]:
# Load the uploaded file into `image` (a 2-D grayscale array) and preview it.
if upload.value:
    uploaded_value = upload.value
    # ipywidgets 7: {filename: {'metadata': ..., 'content': bytes}}
    # ipywidgets 8: tuple of dicts, each with a 'content' memoryview.
    if isinstance(uploaded_value, dict):
        uploaded_file = next(iter(uploaded_value.values()))
    else:
        uploaded_file = uploaded_value[0]
    pil_image = Image.open(io.BytesIO(bytes(uploaded_file['content'])))
    image = np.array(pil_image)
    if image.ndim == 3:
        # Colour upload (RGB/RGBA/...): the processing below expects a 2-D image.
        image = np.array(pil_image.convert('L'))
    print("Uploaded image:")
    plt.figure(figsize=(4, 4))
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print("Upload an image with the button under the previous cell. "
          "Then re-run this cell.")

## Create Well Mask and Threshold

In [None]:
# 8-bit, 3-channel copy of the image, used later as the display background.
image_255 = rescale_intensity(image, out_range=(0, 255)).round().astype(np.uint8)
image_rgb = np.stack([image_255, image_255, image_255], axis=-1)

# Foreground mask restricted to the detected well region.
img_thresh = get_thresholded_image(image)
well_mask = get_well_mask(image)
img_thresh_masked = img_thresh & well_mask

# Show the intermediate results in a 2x2 grid (row-major order).
panels = [
    (image, 'Original Image'),
    (well_mask, 'Well Mask'),
    (img_thresh, 'Thresholded'),
    (img_thresh_masked, 'Masked & Thresholded'),
]
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
for ax, (panel_img, panel_title) in zip(axs.flat, panels):
    ax.axis('off')
    ax.imshow(panel_img, cmap='gray')
    ax.set_title(panel_title, fontsize=20)

plt.tight_layout()
plt.show()

## Interactive Widget

1. Run the cell
2. Wait for the image to load
3. Drag and release the sliders to change the display. After you release a slider, it may take about 5 seconds for the display to update.

### Settings
- **Max circularity**: The most important parameter. It controls how round ("blob-like") an object may be and still count as an elongated cell; lower values keep only more elongated shapes.
- **Min solidity**: Less important. Higher values filter out shapes with gaps, crevices, or folds.
- **Min particle size**: Objects smaller than this size (in pixels) are removed before analysis.
- **Min length**: Minimum length (μm)
- **Max length**: Maximum length (μm)

In [None]:
def update_figure(max_circularity: float, min_solidity: float,
                  min_particle_size: int, min_length: int, max_length: int):
    """Slider callback: redraw the cell overlay on the current image with these settings."""
    process_img_and_display(
        image_rgb,
        img_thresh_masked,
        max_circularity=max_circularity,
        min_solidity=min_solidity,
        min_particle_size=min_particle_size,
        min_length=min_length,
        max_length=max_length,
    )

slider_style = {'description_width': 'initial'}
slider_layout = Layout(width='600px')


def _make_slider(value, min_value, max_value, step, description):
    """Build a FloatSlider that only updates on release, using the shared style and layout."""
    return FloatSlider(value=value, min=min_value, max=max_value, step=step,
                       description=description, continuous_update=False,
                       style=slider_style, layout=slider_layout, readout=True)


particle_size_slider = _make_slider(130, 30, 300, 10, 'Min particle size:')
max_circularity_slider = _make_slider(0.4, 0.0, 1.0, 0.025, 'Max circularity:')
min_solidity_slider = _make_slider(0.5, 0.0, 1.0, 0.025, 'Min solidity:')
min_length_slider = _make_slider(0, 0, 200, 1, 'Min length:')
max_length_slider = _make_slider(200, 0, 200, 1, 'Max length:')

# Last expression of the cell, so Jupyter displays the interactive widget.
interactive(update_figure,
            max_circularity=max_circularity_slider,
            min_solidity=min_solidity_slider,
            min_particle_size=particle_size_slider,
            min_length=min_length_slider,
            max_length=max_length_slider)