## SAM segmentation, meristem delineation, size and curvature estimation

### Necessary imports

In [None]:
import os
from time import time as current_time

import numpy as np
import scipy.ndimage as nd
from scipy.spatial import cKDTree
import pandas as pd
import matplotlib.pyplot as plt

import vtk

from timagetk.components import SpatialImage, LabelledImage
from timagetk.io import imread, imsave
from timagetk.plugins import h_transform, region_labeling, segmentation
from timagetk.plugins.resampling import isometric_resampling
from timagetk.plugins import labels_post_processing

from cellcomplex.utils.array_dict import array_dict

from cellcomplex.property_topomesh.creation import triangle_topomesh, vertex_topomesh
from cellcomplex.property_topomesh.analysis import compute_topomesh_property, compute_topomesh_vertex_property_from_faces, topomesh_property_median_filtering, topomesh_property_gaussian_filtering
from cellcomplex.property_topomesh.optimization import property_topomesh_vertices_deformation
from cellcomplex.property_topomesh.extraction import property_filtering_sub_topomesh, topomesh_connected_components
from cellcomplex.property_topomesh.io import read_ply_property_topomesh, save_ply_property_topomesh

from cellcomplex.property_topomesh.visualization.vtk_actor_topomesh import VtkActorTopomesh

from tissue_nukem_3d.microscopy_images.read_microscopy_image import read_czi_image, read_tiff_image
from tissue_nukem_3d.nuclei_mesh_tools import nuclei_image_surface_topomesh, up_facing_surface_topomesh

from tissue_analysis.property_spatial_image import PropertySpatialImage

from tissue_paredes.utils.tissue_image_tools import pseudo_gradient_norm

from visu_core.matplotlib import glasbey
from visu_core.vtk.display import vtk_display_notebook, vtk_save_screenshot_actors

### Function to project segmented images in 2D for visualization

In [None]:
def labelled_image_projection(seg_img, axis=2, direction=1, background_label=1, return_coords=False):
    """Compute the 2D projection of a labelled image along the specified axis.

    For every line of voxels parallel to ``axis``, the label of the outermost
    non-background voxel on the side given by ``direction`` is kept. Lines that
    contain only background fall back to the last (``direction=1``) or first
    (``direction=-1``) voxel of the line.

    Parameters
    ----------
    seg_img : LabelledImage
        Labelled image to project in 2D. Only its ``get_array()`` method is used.
    axis : int
        The axis along which to project the labelled image (0, 1 or 2).
    direction : int
        On which side of the image to project: 1 keeps the foreground voxel with
        the highest index along ``axis``, -1 the one with the lowest index.
    background_label : int
        Label value ignored for the projection.
    return_coords : bool
        If True, also return the 3D voxel coordinates that were sampled.

    Returns
    -------
    np.ndarray
        2D projected labelled image, with the shape of ``seg_img`` minus ``axis``.
    tuple of np.ndarray
        Only if ``return_coords``: three flat integer index arrays (one per
        image axis) that sample any image of the same shape at the projected
        voxels, e.g. ``img[coords].reshape(projected_img.shape)``.

    Raises
    ------
    ValueError
        If ``axis`` is not 0, 1 or 2, or ``direction`` is not 1 or -1.
    """
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0, 1 or 2, got {}".format(axis))
    if direction not in (1, -1):
        raise ValueError("direction must be 1 or -1, got {}".format(direction))

    seg_array = seg_img.get_array()
    shape = seg_array.shape
    depth = shape[axis]
    foreground = seg_array != background_label

    # Index of each voxel along the projection axis, broadcastable to the image
    index_shape = [1, 1, 1]
    index_shape[axis] = depth
    depth_index = np.arange(depth).reshape(index_shape)

    # Use the foreground mask directly (rather than index 0 as a "no foreground"
    # sentinel) so that foreground voxels lying at index 0 are not discarded.
    if direction == 1:
        proj = np.where(foreground, depth_index, -1).max(axis=axis)
        proj[proj < 0] = depth - 1
    else:
        proj = np.where(foreground, depth_index, depth).min(axis=axis)
        proj[proj >= depth] = 0

    # Full index grids for the two remaining axes ('ij' ordering matches proj)
    other_axes = [a for a in range(3) if a != axis]
    grid_a, grid_b = np.meshgrid(np.arange(shape[other_axes[0]]),
                                 np.arange(shape[other_axes[1]]),
                                 indexing='ij')
    coord_grids = [None, None, None]
    coord_grids[other_axes[0]] = grid_a
    coord_grids[other_axes[1]] = grid_b
    coord_grids[axis] = proj

    coords = tuple(grid.ravel().astype(int) for grid in coord_grids)
    projected_img = seg_array[coords].reshape(proj.shape)

    if return_coords:
        return projected_img, coords
    return projected_img

### Parameters

In [None]:
 dirname = '/Users/gcerutti/Projects/RDP/LayerHeight_Marketa/test/'
#dirname = '/home/carlos/Documents/WORK/Image analysis/20160625 MS-E3 CLV-DR5/LD/RAW'
filename = 'E_set_blue_SAM02'
#filename = 'CLV3-CH-DR5-3VE-MS-E3-LD-SAM3'
#filename = 'CLV3-CH-DR5-3VE-MS-E3-LD-SAM5'

# image parameters
channel_names = ['PI']
#channel_names = ['CLV3','DR5','PI']
membrane_channel = 'PI'
#signal_channels = ['CLV3','DR5']
signal_channels = []
microscope_orientation = 1

# segmentation parameters
recompute_segmentation = False
gaussian_sigma = 0.75
h_min = 200
segmentation_gaussian_sigma = 0.5
volume_threshold = 100000

# SAM surface extraction parameters
recompute_surface = False
resampling_voxelsize = 1.
gaussian_curvature_threshold = -2.5e-4
min_curvature_threshold = -1e-3

# Signal quantification parameters
quantification_erosion_radius = 1.

# Visualization parameters
subsampling = 2
figure_size = 8

### Preparing the output folder

In [None]:
# Create the per-image output folder; exist_ok avoids a check-then-create race
# and makes re-running the cell harmless.
os.makedirs(dirname + "/" + filename + "/", exist_ok=True)

### Reading the microscopy image

In [None]:
# Read the raw microscopy stack, naming its channels after channel_names
microscopy_filename = dirname + '/' + filename + '.czi'
img_dict = read_czi_image(microscopy_filename, channel_names)

# NOTE(review): with a single channel, read_czi_image presumably returns the
# image itself rather than a dict; wrap it so that images can always be
# looked up by channel name.
if len(channel_names) == 1:
    only_channel = channel_names[0]
    img_dict = {only_channel: img_dict}

In [None]:
# Maximum-intensity projection of the membrane channel along the last image axis
figure = plt.figure(0)
figure.clf()
figure.set_size_inches(figure_size, figure_size)

ax = figure.gca()
membrane_mip = img_dict[membrane_channel].get_array().max(axis=2)
ax.imshow(membrane_mip, cmap='gray', interpolation='none')
ax.axis('off')

figure.tight_layout()

### Auto-seeded watershed segmentation

In [None]:
segmentation_filename = dirname + '/' + filename + '/' + filename + '_' + membrane_channel + '_seg.tif'

if recompute_segmentation or not os.path.exists(segmentation_filename):
    img = img_dict[membrane_channel]
    voxelsize = np.array(img.voxelsize)

    # Seeds: h-minima of the smoothed membrane image (sigma given in physical units)
    smooth_image = nd.gaussian_filter(img, sigma=gaussian_sigma / voxelsize).astype(img.dtype)
    smooth_img = SpatialImage(smooth_image, voxelsize=voxelsize)

    ext_img = h_transform(smooth_img, h=h_min, method='min')

    seed_img = region_labeling(ext_img, low_threshold=1, high_threshold=h_min, method='connected_components')

    # Seeded watershed on a separately smoothed membrane image
    seg_smooth_image = nd.gaussian_filter(img, sigma=segmentation_gaussian_sigma / voxelsize).astype(img.dtype)
    seg_smooth_img = SpatialImage(seg_smooth_image, voxelsize=voxelsize)

    seg_img = segmentation(seg_smooth_img, seed_img, control='first', method='seeded_watershed')

    # Physical volume of every label in 1..max
    label_ids = np.arange(seg_img.max()) + 1
    label_volumes = nd.sum(np.prod(voxelsize) * np.ones_like(seg_img), seg_img, index=label_ids)

    labels_to_remove = label_ids[np.asarray(label_volumes) > volume_threshold]
    print("--> Removing too large labels :", labels_to_remove)
    # Merge all too-large regions into the background label 1 in a single pass
    # over the volume (instead of one full-volume comparison per label).
    seg_img[np.isin(seg_img, labels_to_remove)] = 1

    imsave(segmentation_filename, seg_img)
else:
    seg_img = read_tiff_image(segmentation_filename)


In [None]:
figure = plt.figure(1)
figure.clf()
figure.set_size_inches(13.9, 13.9)

# Project the subsampled segmentation on the side given by microscope_orientation,
# keeping the sampled voxel coordinates to project other images identically later
subsampled_seg_img = seg_img[::subsampling, ::subsampling, ::subsampling]
projected_seg_img, projection_coords = labelled_image_projection(subsampled_seg_img, return_coords=True, direction=microscope_orientation)

# Cell contours of the projected label map, drawn in black over the labels.
# NOTE(review): the voxelsize passed here is not multiplied by `subsampling`;
# confirm it does not matter for pseudo_gradient_norm with wall_sigma=0.
projected_label_img = LabelledImage(projected_seg_img, voxelsize=seg_img.voxelsize[:2], no_label_id=0)
projected_cell_contours = pseudo_gradient_norm(projected_label_img, wall_sigma=0)

ax = figure.gca()
ax.imshow(projected_seg_img % 256, cmap='glasbey', vmin=0, vmax=255, interpolation='none')
ax.contourf(projected_cell_contours, [254, 256], colors=['k'])
ax.axis('off')

figure.tight_layout()
figure.savefig(dirname + '/' + filename + '/' + filename + '_' + membrane_channel + '_segmented_projection.png')

figure.set_size_inches(figure_size, figure_size)

### Surface extraction

In [None]:
# Smoothed binary tissue mask (all labels except the background label 1),
# scaled to [0, 255]; the smoothing sigma is 1 in physical units.
binary_img = SpatialImage(nd.gaussian_filter(255*(seg_img>1),sigma=1./np.array(seg_img.voxelsize)).astype(np.uint8),voxelsize=seg_img.voxelsize)

if resampling_voxelsize is not None:
    # Resample the mask to an isotropic grid of the requested voxel size
    binary_img = isometric_resampling(binary_img, method=resampling_voxelsize, option='linear')
else:
    # resampling_voxelsize is used as a scalar later on (mesh sigma and lengths,
    # `distance < 5.*resampling_voxelsize and ...` cell selection); an array
    # there makes the boolean test ambiguous and raises. Use the finest voxel size.
    resampling_voxelsize = float(np.min(binary_img.voxelsize))

In [None]:
surface_filename = dirname + '/' + filename + '/' + filename + '_' + membrane_channel + '_surface.ply'

curvature_properties = ['mean_curvature', 'gaussian_curvature', 'principal_curvature_min', 'principal_curvature_max']

if not recompute_surface and os.path.exists(surface_filename):
    # Reuse the previously saved surface mesh
    topomesh = read_ply_property_topomesh(surface_filename)
else:
    # Triangle mesh of the (resampled) binary tissue mask
    topomesh = nuclei_image_surface_topomesh(binary_img,
                                             nuclei_sigma=resampling_voxelsize,
                                             density_voxelsize=resampling_voxelsize,
                                             maximal_length=10 * resampling_voxelsize,
                                             intensity_threshold=64,
                                             decimation=100)

    # NOTE(review): presumably keeps the part of the surface facing the side
    # given by microscope_orientation — confirm against tissue_nukem_3d.
    topomesh = up_facing_surface_topomesh(topomesh, normal_method='orientation', upwards=microscope_orientation == 1)

    start_time = current_time()
    print("--> Computing mesh curvature")
    compute_topomesh_vertex_property_from_faces(topomesh, 'normal', neighborhood=3, adjacency_sigma=1.2)
    compute_topomesh_property(topomesh, 'mean_curvature', 2)
    # Smooth face curvatures onto the vertices
    for curvature_name in curvature_properties:
        compute_topomesh_vertex_property_from_faces(topomesh, curvature_name, neighborhood=3, adjacency_sigma=1.2)
    print("<-- Computing mesh curvature [", current_time() - start_time, "s]")

    # Store the vertex z coordinate as a scalar property so it can be saved in the PLY
    vertex_ids = list(topomesh.wisps(0))
    vertex_z = topomesh.wisp_property('barycenter', 0).values(vertex_ids)[:, 2]
    topomesh.update_wisp_property('barycenter_z', 0, dict(zip(vertex_ids, vertex_z)))

    properties_to_save = {
        0: curvature_properties + ['normal', 'barycenter_z'],
        1: [],
        2: curvature_properties + ['normal', 'area'],
        3: [],
    }
    save_ply_property_topomesh(topomesh, surface_filename, properties_to_save)

In [None]:
actors = []

# Whole surface, colored by the minimal principal curvature at the vertices
surface_actor = VtkActorTopomesh(topomesh, 2, property_name='principal_curvature_min', property_degree=0)
surface_actor.update(colormap='RdBu_r', value_range=(-0.1, 0.1), opacity=1)
# Pass the underlying VTK actor, consistent with the meristem surface rendering cell
actors += [surface_actor.actor]

# Render the actors to an image array shown with matplotlib.
# NOTE(review): vtk_image_actors is not imported in the import cell (only
# vtk_display_notebook and vtk_save_screenshot_actors are) — confirm where it comes from.
actor_image = vtk_image_actors(actors, focal_point=(0, 0, -1), view_up=(-1, 0, 0))

figure = plt.figure(2)
figure.clf()

figure.set_size_inches(13.9, 13.9)

# Rows are flipped before display (presumably the render has a bottom-left origin)
figure.gca().imshow(actor_image[::-1])
figure.gca().axis('off')

figure.tight_layout()
figure.savefig(dirname + '/' + filename + '/' + filename + '_surface_min_curvature.png')

figure.set_size_inches(figure_size, figure_size)

### Curvature-based SAM delineation

In [None]:
# Convex part of the surface: faces whose minimal principal curvature lies in
# the range (min_curvature_threshold, 10). An alternative criterion filters on
# 'gaussian_curvature' with (gaussian_curvature_threshold, 10).
convex_topomesh = property_filtering_sub_topomesh(topomesh, 'principal_curvature_min', 2, (min_curvature_threshold, 10))
compute_topomesh_property(convex_topomesh, 'area', 2)

convex_topomeshes = topomesh_connected_components(convex_topomesh, degree=2)

image_center = np.array(seg_img.extent) / 2.


def _sam_candidate_score(mesh):
    """Score a convex component: sqrt(area) over the XY distance of its vertex centroid to the image center."""
    component_area = mesh.wisp_property('area', 2).values(list(mesh.wisps(2))).sum()
    component_centroid = mesh.wisp_property('barycenter', 0).values(list(mesh.wisps(0))).mean(axis=0)
    return np.sqrt(component_area) / np.linalg.norm((component_centroid - image_center)[:2])


# The SAM is the large convex component closest to the image center
mesh_scores = [_sam_candidate_score(mesh) for mesh in convex_topomeshes]
meristem_topomesh = convex_topomeshes[int(np.argmax(mesh_scores))]

meristem_area = meristem_topomesh.wisp_property('area', 2).values(list(meristem_topomesh.wisps(2))).sum()
# Radius of the disk with the same area as the meristem surface
meristem_radius = np.sqrt(meristem_area / np.pi)

meristem_face_curvatures = meristem_topomesh.wisp_property('mean_curvature', 2).values(list(meristem_topomesh.wisps(2)))
meristem_curvature = np.nanmedian(meristem_face_curvatures)

In [None]:
actors = []

# Whole surface, nearly transparent, colored by the minimal principal curvature
surface_actor = VtkActorTopomesh(topomesh, 2, property_name='principal_curvature_min', property_degree=0)
surface_actor.update(colormap='RdBu_r', value_range=(-0.1, 0.1), opacity=0.1)
actors.append(surface_actor.actor)

# Selected meristem region, opaque, same coloring
meristem_actor = VtkActorTopomesh(meristem_topomesh, 2, property_name='principal_curvature_min', property_degree=0)
meristem_actor.update(colormap='RdBu_r', value_range=(-0.1, 0.1))
actors.append(meristem_actor.actor)

# NOTE(review): vtk_image_actors is not imported in the import cell (only
# vtk_display_notebook and vtk_save_screenshot_actors are) — confirm where it comes from.
actor_image = vtk_image_actors(actors, focal_point=(0, 0, -1), view_up=(-1, 0, 0))

figure = plt.figure(3)
figure.clf()
figure.set_size_inches(13.9, 13.9)

ax = figure.gca()
ax.imshow(actor_image[::-1])
ax.axis('off')

figure.tight_layout()
figure.savefig(dirname + '/' + filename + '/' + filename + '_meristem_surface.png')

figure.set_size_inches(figure_size, figure_size)

### Meristem projection on the segmented image

In [None]:
# Wrap the segmentation so that per-cell properties can be computed and stored on it
p_img = PropertySpatialImage(seg_img)
# Per-cell 'layer' property; layer value 1 is used later to select the L1 cells
p_img.compute_image_property('layer')

In [None]:
cell_center = p_img.image_property('barycenter')
cell_layer = p_img.image_property('layer')

cell_points = cell_center.values(p_img.labels)
meristem_points = meristem_topomesh.wisp_property('barycenter',0).values(list(meristem_topomesh.wisps(0)))

# Distance from each cell barycenter to the closest meristem surface vertex.
# A KD-tree query avoids materializing the full (n_cells x n_vertices x 3)
# difference array, which can take gigabytes on large tissues.
meristem_distances, _ = cKDTree(meristem_points).query(cell_points)
cell_meristem_distances = array_dict(meristem_distances, keys=p_img.labels)

# Meristem cells: L1 cells closer than 5 * resampling_voxelsize to the meristem surface
cell_meristem = {c:(cell_meristem_distances[c]<5.*resampling_voxelsize and cell_layer[c]==1) for c in p_img.labels}
p_img.update_image_property('meristem',cell_meristem)

def image_property_morphological_operation(p_img, property_name, method='erosion', iterations=1, layer_restriction=1):
    """Apply a binary morphological operation to a cell property on the cell adjacency graph.

    The property is treated as binary (non-zero is True). Each cell's new value
    is the minimum (erosion) or maximum (dilation) of its own value and of the
    values of its neighbors in the image adjacency graph. Opening is an erosion
    followed by a dilation; closing is a dilation followed by an erosion. The
    result is written back with ``p_img.update_image_property``.

    Parameters
    ----------
    p_img : PropertySpatialImage
        Image whose ``property_name`` property is modified in place. Its private
        ``_image_graph`` is used to find the neighbors of each cell.
    property_name : str
        Name of the per-cell property to process.
    method : str
        One of 'erosion', 'dilation', 'opening' or 'closing'.
    iterations : int
        Number of times the operation is applied. For 'opening' and 'closing',
        each iteration is one full opening/closing (unlike
        ``scipy.ndimage.binary_closing``, which applies ``iterations`` dilations
        followed by ``iterations`` erosions).
    layer_restriction : int or None
        If not None, only neighbors whose 'layer' property equals this value are
        taken into account, and cells of other layers are set to False.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported operations.
    """
    # Sequence of label-wise reductions making up each operation
    steps_by_method = {
        'erosion': (nd.minimum,),
        'dilation': (nd.maximum,),
        'opening': (nd.minimum, nd.maximum),
        'closing': (nd.maximum, nd.minimum),
    }
    if method not in steps_by_method:
        raise ValueError("Unknown morphological method '{}', expected one of {}".format(method, sorted(steps_by_method)))
    steps = steps_by_method[method]

    labels = list(p_img.labels)
    if not labels:
        return

    if layer_restriction is not None:
        layer_property = p_img.image_property('layer')
        layer_labels = {l for l in labels if layer_property[l] == layer_restriction}
        layer_mask = np.asarray(layer_property.values(labels)) == layer_restriction
    else:
        layer_labels = set(labels)
        layer_mask = None

    # Flattened (cell, neighbor) pairs, restricted to neighbors in the allowed layer
    label_neighbors = []
    label_neighbor_labels = []
    for c in labels:
        for n in p_img._image_graph.neighbors(c):
            if n in layer_labels:
                label_neighbors.append(n)
                label_neighbor_labels.append(c)

    # Each cell's own value is pooled with its neighbors' values under the cell's label
    pooled_keys = labels + label_neighbors
    pooled_regions = labels + label_neighbor_labels
    # Position of every pooled key within `labels`, to re-gather intermediate results
    label_position = {l: i for i, l in enumerate(labels)}
    pooled_positions = np.array([label_position[k] for k in pooled_keys], dtype=int)

    for _ in range(iterations):
        pooled_values = np.asarray(p_img.image_property(property_name).values(pooled_keys))
        morpho_property = None
        for reduce_op in steps:
            if morpho_property is not None:
                pooled_values = morpho_property[pooled_positions]
            morpho_property = np.asarray(reduce_op(pooled_values != 0, pooled_regions, index=labels))

        if layer_mask is not None:
            morpho_property = morpho_property * layer_mask
        p_img.update_image_property(property_name, dict(zip(labels, morpho_property)))

# Close the 'meristem' selection over L1 neighbors (default layer_restriction=1) to fill small gaps
image_property_morphological_operation(p_img, property_name='meristem', method='closing', iterations=3)

In [None]:
# Voxel map of the meristem property shifted by 1 so the background becomes 0
# (presumably 1 = non-meristem cell, 2 = meristem cell)
meristem_img = p_img.create_property_image('meristem', background_value=-1) + 1

figure = plt.figure(4)
figure.clf()
figure.set_size_inches(13.9, 13.9)

# Sample the meristem map at the voxels used for the segmentation projection
subsampled_meristem_img = meristem_img[::subsampling, ::subsampling, ::subsampling]
projected_meristem_img = subsampled_meristem_img[projection_coords].reshape(projected_seg_img.shape)

ax = figure.gca()
ax.imshow(projected_meristem_img, cmap='glasbey', vmin=0, vmax=255, interpolation='none')
ax.contourf(projected_cell_contours, [254, 256], colors=['k'])
ax.axis('off')

figure.tight_layout()
figure.savefig(dirname + '/' + filename + '/' + filename + '_meristem_cells.png')

figure.set_size_inches(figure_size, figure_size)

### Meristem cell property computation

In [None]:
p_img.compute_image_property('volume')

# Average volume of the cells flagged as meristem
meristem_property = p_img.image_property('meristem')
volume_property = p_img.image_property('volume')
meristem_cells = list(filter(lambda c: meristem_property[c], p_img.labels))
meristem_cell_average_volume = np.nanmean([volume_property[c] for c in meristem_cells])

In [None]:
# Match every L1 cell to its closest surface vertex and copy the vertex curvatures
layer_property = p_img.image_property('layer')
barycenter_property = p_img.image_property('barycenter')
l1_labels = [c for c in p_img.labels if layer_property[c] == 1]
l1_cell_points = np.array([barycenter_property[c] for c in l1_labels])

surface_vertex_ids = list(topomesh.wisps(0))
surface_points = topomesh.wisp_property('barycenter',0).values(surface_vertex_ids)
if l1_labels:
    # KD-tree nearest-vertex query instead of a dense (n_L1 x n_vertices x 3) distance array
    _, closest_vertex_index = cKDTree(surface_points).query(l1_cell_points)
    l1_cell_surface_vertex = np.array(surface_vertex_ids)[closest_vertex_index]
else:
    l1_cell_surface_vertex = np.array([], dtype=int)

for property_name in curvature_properties:
    if l1_labels:
        l1_cell_curvature = dict(zip(l1_labels,topomesh.wisp_property(property_name,0).values(l1_cell_surface_vertex)))
    else:
        l1_cell_curvature = {}
    # Only L1 cells are matched to a surface vertex; all other cells get NaN
    p_img.update_image_property(property_name,{c:l1_cell_curvature.get(c, np.nan) for c in p_img.labels})

In [None]:
# Voxel map of the per-cell mean curvature (background set to 0)
curvature_img = p_img.create_property_image('mean_curvature', dtype=float, background_value=0)

figure = plt.figure(4)
figure.clf()
figure.set_size_inches(13.9, 13.9)

subsampled_curvature_img = curvature_img[::subsampling, ::subsampling, ::subsampling]
projected_curvature_img = subsampled_curvature_img[projection_coords].reshape(projected_seg_img.shape)

ax = figure.gca()
ax.imshow(projected_curvature_img, cmap='RdBu_r', vmin=-0.1, vmax=0.1, interpolation='none')
ax.contourf(projected_cell_contours, [254, 256], colors=['k'])
# Meristem region outline in red
ax.contour(projected_meristem_img, [1], colors=['r'], linewidths=[3])
ax.axis('off')

figure.tight_layout()
figure.savefig(dirname + '/' + filename + '/' + filename + '_meristem_cell_curvature.png')

figure.set_size_inches(figure_size, figure_size)

### Cell signal quantification

In [None]:
# Mean signal intensity per cell, measured inside eroded cells to limit the
# contribution of voxels near the cell walls.
if signal_channels:
    eroded_seg_img = labels_post_processing(seg_img, method='erosion', radius=quantification_erosion_radius, iterations=1)
    eroded_labels = eroded_seg_img.get_array()

    # Voxel count of each eroded cell: identical for every channel, so computed once
    cell_volumes = nd.sum(np.ones_like(eroded_labels), eroded_labels, index=p_img.labels)

    for signal_name in signal_channels:
        signal_img = img_dict[signal_name]

        cell_total_signals = nd.sum(signal_img.get_array().astype(float), eroded_labels, index=p_img.labels)
        # Cells with no voxel left after erosion get NaN (0/0), without runtime warnings
        with np.errstate(invalid='ignore', divide='ignore'):
            cell_signals = dict(zip(p_img.labels, cell_total_signals / cell_volumes))
        p_img.update_image_property(signal_name, cell_signals)

In [None]:
# One row per signal channel: raw maximum-intensity projection, then quantified
# per-cell signal. Skipped when there is no signal channel, which would
# otherwise produce a zero-height figure.
if signal_channels:
    n_signals = len(signal_channels)

    figure = plt.figure(5)
    figure.clf()
    figure.set_size_inches(2*13.9, n_signals*13.9)

    for i_s, signal_name in enumerate(signal_channels):
        ax = figure.add_subplot(n_signals, 2, 2*i_s + 1)
        ax.imshow(img_dict[signal_name].get_array().max(axis=2), cmap='gray', interpolation='none')
        ax.set_title(signal_name+" Image", size=24)
        ax.axis('off')

        signal_img = p_img.create_property_image(signal_name, dtype=img_dict[signal_name].dtype, background_value=0)
        projected_signal_img = signal_img[::subsampling,::subsampling,::subsampling][projection_coords].reshape(projected_seg_img.shape)

        ax = figure.add_subplot(n_signals, 2, 2*i_s + 2)
        ax.imshow(projected_signal_img, cmap='gray', vmin=img_dict[signal_name].min(), vmax=img_dict[signal_name].max()/2, interpolation='none')
        ax.contourf(projected_cell_contours, [254,256], colors=['k'])
        ax.contour(projected_meristem_img, [1], colors=['r'], linewidths=[3])
        ax.set_title("Quantified "+signal_name, size=24)
        ax.axis('off')

    figure.tight_layout()
    figure.savefig(dirname + '/' + filename + '/' + filename + '_meristem_cell_signals.png')

    figure.set_size_inches(2*figure_size, n_signals*figure_size)
    figure.tight_layout()

### Cell data export

In [None]:
def property_image_to_dataframe(p_img, labels=None):
    """Export the scalar per-cell properties of a property image as a DataFrame.

    Parameters
    ----------
    p_img : PropertySpatialImage
        Image providing ``labels``, ``image_property_names()`` and
        ``image_property(name)`` (a label -> value mapping).
    labels : iterable of int, optional
        Subset of cell labels to export; labels absent from ``p_img`` are
        ignored. Defaults to all the labels of ``p_img``.

    Returns
    -------
    pandas.DataFrame
        One row per cell, in the label order of ``p_img``, indexed by label
        (unnamed index). It has a 'label' column, one column per scalar
        property and 'barycenter_x/y/z' columns. Other non-scalar properties
        are skipped.
    """
    cell_labels = list(p_img.labels)
    if labels is not None:
        # Keep the image's label order (a set intersection would scramble rows)
        requested = set(labels)
        labels = [c for c in cell_labels if c in requested]
    else:
        labels = cell_labels

    dataframe = pd.DataFrame()
    dataframe['id'] = np.array(labels)
    dataframe['label'] = np.array(labels)

    if labels:
        for property_name in p_img.image_property_names():
            cell_property = p_img.image_property(property_name)
            # Scalar-ness is decided from the first exported cell; works for
            # plain dicts as well as array-backed mappings
            if np.ndim(cell_property[labels[0]]) == 0:
                dataframe[property_name] = np.array([cell_property[v] for v in labels])
            elif property_name == 'barycenter':
                for i, axis in enumerate(['x', 'y', 'z']):
                    dataframe[property_name + "_" + axis] = np.array([cell_property[v][i] for v in labels])

    dataframe = dataframe.set_index('id')
    dataframe.index.name = None

    return dataframe

# One row per cell; the label is a regular column, so the index is not written
cell_data_filename = dirname + '/' + filename + '/' + filename + '_cell_data.csv'
cell_df = property_image_to_dataframe(p_img)
cell_df.to_csv(path_or_buf=cell_data_filename, index=False)

### SAM data export

In [None]:
# Single-row summary of the SAM: surface area, median mean curvature, number of
# meristem cells and their average volume
meristem_data = {
    'filename': [filename],
    'area': [meristem_area],
    'mean_curvature': [meristem_curvature],
    'L1_cells': [len(meristem_cells)],
    'L1_cell_volume': [meristem_cell_average_volume],
}
meristem_df = pd.DataFrame.from_dict(meristem_data)
print(meristem_df)

meristem_data_filename = dirname + '/' + filename + '/' + filename + '_meristem_data.csv'
meristem_df.to_csv(meristem_data_filename, index=False)