In [10]:
import numpy as np

import os

import pandas as pd
from scipy import ndimage
import mcubes
import trimesh
from scipy.ndimage import zoom
from plyfile import PlyData, PlyElement 
from auxiliary.data import imaging


import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [11]:
def median_3d_array_tissue(img, disk_size=3):
    """
    Fill small gaps in a 3D binary mask with a morphological closing.

    Despite the function name, no median filter is applied: the mask is
    closed (dilation followed by erosion) with a ball-shaped structuring
    element over the whole volume at once, not slice by slice.

    :param img: Input 3D image; a 4D input is reduced to its first channel
        (``img[:, :, :, 0]``)
    :param disk_size: Radius, in voxels, of the ball structuring element
    :return: Boolean array with the closed mask
    """
    from skimage import morphology

    if len(img.shape) == 4:
        # Keep only the first channel of a multi-channel stack.
        img = img[:, :, :, 0]
    # NOTE(review): the ball is isotropic in voxel units; with anisotropic
    # voxels (z_res != x_res) it is an ellipsoid in physical space.
    return morphology.binary_closing(img, morphology.ball(disk_size))
    

def post_process_tissue(mesh):
    """
    Clean a mesh in place and try to make it watertight.

    Degenerate and duplicate faces, non-finite values and unreferenced
    vertices are removed first. ``fill_holes`` is then called repeatedly,
    at most 10 times, until ``mesh.is_watertight`` reports True.

    :param mesh: ``trimesh.Trimesh`` to repair (modified in place)
    :return: The same mesh object
    """
    # NOTE(review): recent trimesh releases deprecate remove_degenerate_faces /
    # remove_duplicate_faces in favour of update_faces(...); the module-level
    # DeprecationWarning filter hides that -- confirm against the installed version.
    cleanup_steps = (
        mesh.remove_degenerate_faces,
        mesh.remove_duplicate_faces,
        mesh.remove_infinite_values,
        mesh.remove_unreferenced_vertices,
    )
    for step in cleanup_steps:
        step()

    for _attempt in range(10):
        if mesh.is_watertight:
            break
        mesh.fill_holes()
    return mesh


def marching_cubes_tissue(img, metadata, n_faces=10000):
    """
    Build a smoothed, decimated surface mesh from a binary tissue mask.

    Steps: binary closing (``median_3d_array_tissue`` with a radius-5 ball),
    zero padding, ``mcubes.smooth`` + marching cubes at level 0, scaling by
    the voxel size, quadric decimation, Laplacian smoothing, and finally
    ``post_process_tissue``.

    :param img: Input binary 3D image
    :param metadata: Dict with the voxel sizes ``'x_res'``, ``'y_res'`` and
        ``'z_res'``, applied to vertex columns 0, 1 and 2 respectively
    :param n_faces: Target face count for the decimation step
    :return: Dict with ``'vertices'``, ``'faces'`` and ``'normals'``
        (per-vertex normals) of the final mesh
    :raises AssertionError: if marching cubes produced no geometry
    """
    pad = 10 // 2  # zero voxels added before and after every axis
    closed = median_3d_array_tissue(img, disk_size=5)

    # Zero border so the surface can close where tissue touches the image edge.
    padded = np.zeros(tuple(dim + 2 * pad for dim in closed.shape), dtype=np.uint8)
    padded[pad:-pad, pad:-pad, pad:-pad] = closed

    verts, tris = mcubes.marching_cubes(mcubes.smooth(padded), 0)
    assert len(verts) > 0 and len(tris) > 0, 'No mesh was generated'

    # Voxel indices -> physical units, one scale factor per axis.
    # NOTE(review): the `pad`-voxel offset added above is never subtracted, so the
    # vertices are shifted by `pad` voxels per axis relative to the input image;
    # confirm the cell meshes use the same convention before comparing them.
    verts *= np.array([metadata['x_res'], metadata['y_res'], metadata['z_res']])

    surface = trimesh.Trimesh(verts, tris, process=False)
    surface = surface.simplify_quadratic_decimation(n_faces)
    # Smooths `surface` in place; the return value is not needed.
    trimesh.smoothing.filter_laplacian(
        surface, lamb=0.6, iterations=15, volume_constraint=False
    )
    surface = post_process_tissue(surface)

    return {
        'vertices': surface.vertices,
        'faces': surface.faces,
        'normals': surface.vertex_normals,
    }
    
    
def run(img_path, path_out, metadata, n_faces=10000):
    """
    Mesh a tissue segmentation and save it as an ASCII PLY file.

    Every voxel > 0 is treated as tissue. The mesh comes from
    ``marching_cubes_tissue``. Each vertex is stored with its position and
    normal as float32, and each face as three int32 vertex indices.

    :param img_path: Path to the segmentation image (read with axes 'XYZ')
    :param path_out: Destination ``.ply`` path
    :param metadata: Dict with ``'x_res'``, ``'y_res'``, ``'z_res'`` voxel sizes
    :param n_faces: Target face count for mesh decimation
    """
    mask = imaging.read_image(img_path, axes='XYZ', verbose=1) > 0

    mesh_data = marching_cubes_tissue(mask, metadata, n_faces)
    assert mesh_data is not None, 'No mesh was generated'

    verts = mesh_data['vertices']
    tris = mesh_data['faces']
    norms = mesh_data['normals']

    vertex_fields = [(name, 'f4') for name in ('x', 'y', 'z', 'nx', 'ny', 'nz')]
    vertex_table = np.empty(len(verts), dtype=vertex_fields)
    for col, axis_name in enumerate(('x', 'y', 'z')):
        vertex_table[axis_name] = verts[:, col]
        vertex_table['n' + axis_name] = norms[:, col]

    face_table = np.empty(len(tris), dtype=[('vertex_indices', 'i4', (3,))])
    face_table['vertex_indices'] = tris

    elements = [
        PlyElement.describe(vertex_table, 'vertex'),
        PlyElement.describe(face_table, 'face'),
    ]
    PlyData(elements, text=True).write(path_out)
    print('Mesh saved to:', path_out)

In [12]:
img_path = "/run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/Segmentation/Segmentation/E1.tif"
# Swap only the extension: str.replace('.tif', ...) would also rewrite any
# '.tif' elsewhere in the path and turn a '.tiff' file into '..._mesh.plyf'.
path_out = os.path.splitext(img_path)[0] + '_mesh.ply'

# Voxel size per axis, used to scale the mesh vertices.
metadata = {
    'x_res' : 0.3786029, 
    'y_res' : 0.3786029, 
    'z_res' : 0.9999284
}

run(img_path, path_out, metadata)

[94mReading TIFF[0m: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/Segmentation/Segmentation/E1.tif
Mesh saved to: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/Segmentation/Segmentation/E1_mesh.ply


In [13]:
from meshes.utils.mesh_reconstruction import marching_cubes

def run(seg_path, path_out, metadata):
    """
    Mesh a labelled cell segmentation and save it as an ASCII PLY file.

    Every vertex and face carries an int32 ``cell_id`` property taken from
    the ``vertex_cell_ids`` / ``face_cell_ids`` returned by ``marching_cubes``.

    :param seg_path: Path to the labelled segmentation image (read with axes 'XYZ')
    :param path_out: Destination ``.ply`` path
    :param metadata: Metadata dict forwarded unchanged to ``marching_cubes``

    NOTE(review): this redefines ``run`` from an earlier cell; after this cell
    runs, the tissue-mask version is no longer reachable under that name.
    """
    labels = imaging.read_image(seg_path, axes='XYZ', verbose=1)
    result = marching_cubes(labels, metadata)

    coords = result['vertices']
    tris = result['faces']
    vnormals = result['normals']
    per_vertex_ids = result['vertex_cell_ids']
    per_face_ids = result['face_cell_ids']

    vertex_fields = [(name, 'f4') for name in ('x', 'y', 'z', 'nx', 'ny', 'nz')]
    vertex_fields.append(('cell_id', 'i4'))
    vertex_table = np.empty(len(coords), dtype=vertex_fields)
    for col, axis_name in enumerate(('x', 'y', 'z')):
        vertex_table[axis_name] = coords[:, col]
        vertex_table['n' + axis_name] = vnormals[:, col]
    vertex_table['cell_id'] = per_vertex_ids

    face_table = np.empty(
        len(tris), dtype=[('vertex_indices', 'i4', (3,)), ('cell_id', 'i4')]
    )
    face_table['vertex_indices'] = tris
    face_table['cell_id'] = per_face_ids

    elements = [
        PlyElement.describe(vertex_table, 'vertex'),
        PlyElement.describe(face_table, 'face'),
    ]
    PlyData(elements, text=True).write(path_out)

    print('Mesh saved to:', path_out)
    
seg_path = "/run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/CellposeDenoise/Results/Filtered(tissue)/E1_MeisDKO.tif"

# Swap only the extension: str.replace('.tif', ...) would also rewrite any
# '.tif' elsewhere in the path and turn a '.tiff' file into '..._mesh.plyf'.
path_out = os.path.splitext(seg_path)[0] + '_mesh.ply'

# Voxel size per axis, forwarded to the cell marching-cubes routine.
metadata = {
    'x_res' : 0.3786029, 
    'y_res' : 0.3786029, 
    'z_res' : 0.9999284
}

run(seg_path, path_out, metadata)

[94mReading TIFF[0m: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/CellposeDenoise/Results/Filtered(tissue)/E1_MeisDKO.tif
Mesh saved to: /run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/CellposeDenoise/Results/Filtered(tissue)/E1_MeisDKO_mesh.ply


In [21]:
from meshes.utils.features.extractor import MeshFeatureExtractor

# Sample C3: the tissue surface mesh and the matching per-cell mesh.
tissue_path = "/run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/Segmentation/Segmentation/C3_mesh.ply"
cells_path = "/run/user/1003/gvfs/smb-share:server=tierra.cnic.es,share=sc/LAB_MT/LAB/Oscar/PROJECT_transition_zone/Paper/Data/Colum_FGFinh_MeisDKO/MeisdKO_WT_BCat_Columnarity/test_plantseg/CellposeDenoise/Results/Filtered(tissue)/C3_EZ_mesh.ply"
features_out_path = cells_path.replace('_mesh.ply', '_features.csv')

tissue_mesh = trimesh.load_mesh(tissue_path)
cells_mesh = trimesh.load_mesh(cells_path)

# Per-cell features measured against the tissue surface; saved next to the cell mesh.
extractor = MeshFeatureExtractor(cells_mesh, tissue_mesh)
features = extractor.extract()
features.to_csv(features_out_path, index=False)

In [22]:
from matplotlib.colors import LinearSegmentedColormap, BoundaryNorm
from scipy.spatial import cKDTree
from meshes.utils.features.operator import build_face_adjacency_csr_matrix, get_centroid, dijkstra


class CellMap:
    """
    Project per-cell features onto the faces of a tissue surface mesh.

    Each tissue face is assigned to the cell whose centroid is nearest to the
    face centre. Feature values are then averaged over all tissue faces within
    a graph distance ``radius``, and the result is written out as a coloured
    mesh plus CSV files.

    :param tissue_mesh_path: Path to the tissue surface mesh (PLY)
    :param cells_mesh_path: Path to the cell mesh; its PLY must carry a
        ``cell_id`` property on both the ``vertex`` and ``face`` elements
    :param features: DataFrame with a ``cell_id`` column plus one column per
        feature usable by :meth:`color_mesh`
    """

    def __init__(self, tissue_mesh_path, cells_mesh_path, features):
        self.tissue_path = tissue_mesh_path
        self.cells_path = cells_mesh_path
        self.tissue_mesh = trimesh.load_mesh(tissue_mesh_path)
        self.cells_mesh = trimesh.load_mesh(cells_mesh_path)
        self.features = features

        # Per-element cell ids come from the raw PLY data kept by the loader.
        ply_raw = self.cells_mesh.metadata['_ply_raw']
        self.face_cell_ids = ply_raw['face']['data']['cell_id']
        self.vertex_cell_ids = ply_raw['vertex']['data']['cell_id']
        self.cell_ids = features['cell_id'].unique()

        # Use the mesh loaded from `tissue_mesh_path`, never a notebook global.
        self.tissue_mesh_centroids = self.tissue_mesh.triangles_center
        self.tissue_face_tree = cKDTree(self.tissue_mesh_centroids)
        self.tissue_vertices_tree = cKDTree(self.tissue_mesh.vertices)

        self.tissue_graph = build_face_adjacency_csr_matrix(self.tissue_mesh)

        self.cell_centroids = {}
        self.failed_cell_ids = []  # cells whose centroid could not be computed
        for cell_id in self.cell_ids:
            try:
                self.cell_centroids[cell_id], _ = get_centroid(
                    self.cells_mesh, cell_id,
                    self.face_cell_ids, self.vertex_cell_ids
                )
            except Exception as e:
                # Best effort: report and skip cells missing from the mesh.
                self.failed_cell_ids.append(cell_id)
                print(f'Error in cell: {cell_id} - {e}')

    @staticmethod
    def _parse_neighbors(raw):
        """
        Normalise a neighbour entry to an int array, or None if it is missing.

        Accepts an array/list of face indices or its string form such as
        ``'[1 2 3]'`` (as read back from a ``_map.csv`` file).
        """
        if raw is None or (isinstance(raw, float) and np.isnan(raw)):
            return None
        if isinstance(raw, str):
            tokens = raw.replace('[', '').replace(']', '').split()
            return np.array([int(t) for t in tokens], dtype=int)
        return np.asarray(raw, dtype=int)

    @staticmethod
    def _format_neighbors(neighbors):
        """Render a neighbour list as ``'[i j k]'`` without numpy's '...' elision."""
        if neighbors is None or (isinstance(neighbors, float) and np.isnan(neighbors)):
            return neighbors
        return '[' + ' '.join(str(int(n)) for n in neighbors) + ']'

    @staticmethod
    def _average_over_neighbours(values, neighbour_lists):
        """
        Replace each value by the NaN-ignoring mean over its neighbour faces.

        Means are always computed from the unsmoothed ``values``, so results do
        not cascade between faces.

        :param values: 1D float array, one value per face
        :param neighbour_lists: Iterable aligned with ``values``; each entry is
            anything ``_parse_neighbors`` accepts
        :return: New array; entries with no neighbour list, or whose neighbours
            are all NaN, keep their original value
        """
        smoothed = values.copy()
        for i, raw in enumerate(neighbour_lists):
            neighbours = CellMap._parse_neighbors(raw)
            if neighbours is None or neighbours.size == 0:
                continue
            neighbour_values = values[neighbours]
            neighbour_values = neighbour_values[~np.isnan(neighbour_values)]
            if neighbour_values.size:
                smoothed[i] = neighbour_values.mean()
        return smoothed

    def map_cells(self):
        """
        Assign every tissue face to the cell with the nearest centroid.

        Stores ``self.mapping``: a DataFrame with ``tissue_face_id`` (0..F-1)
        and ``cell_id``.

        :raises ValueError: if no cell centroid could be computed
        """
        if not self.cell_centroids:
            raise ValueError('No cell centroids available to map onto the tissue mesh')

        # Keys and values come from the same dict, so positions line up.
        # Indexing self.cell_ids instead would be shifted by every cell whose
        # centroid failed in __init__.
        centroid_ids = np.array(list(self.cell_centroids.keys()))
        centroid_array = np.array(list(self.cell_centroids.values()))
        cell_tree = cKDTree(centroid_array)

        _, closest = cell_tree.query(self.tissue_mesh_centroids, k=1)

        self.mapping = pd.DataFrame({
            'tissue_face_id': np.arange(len(self.tissue_mesh.faces)),
            'cell_id': centroid_ids[closest],
        })

    def get_neighbours(self, radius=50):
        """
        Collect, for every tissue face, the faces within ``radius`` on the tissue graph.

        Adds a ``tissue_neighbors`` column (one int array per face) to
        ``self.mapping`` and writes it to ``<cells_path without .ply>_map.csv``.
        Must be called after :meth:`map_cells`.

        :param radius: Maximum graph distance, in the units of the edge
            weights produced by ``build_face_adjacency_csr_matrix``
        """
        n_faces = len(self.tissue_mesh.faces)
        # Filled element-wise so numpy never stacks equal-length lists into 2D.
        neighbours = np.empty(n_faces, dtype=object)
        for face_idx in range(n_faces):
            # NOTE(review): one unbounded single-source search per face; if
            # `dijkstra` is scipy.sparse.csgraph.dijkstra, passing limit=radius
            # would prune the search without changing the result -- confirm.
            dist = dijkstra(
                csgraph=self.tissue_graph, directed=False,
                indices=face_idx, return_predecessors=True,
                min_only=True
            )[0]
            neighbours[face_idx] = np.where(dist <= radius)[0]

        new_cols = pd.DataFrame({
            'tissue_face_id': np.arange(n_faces, dtype=int),
            'tissue_neighbors': neighbours,
        })

        # Drop a previous result so repeated calls don't create _x/_y columns.
        mapping = (
            self.mapping
            .drop(columns='tissue_neighbors', errors='ignore')
            .astype({'tissue_face_id': int})
        )
        self.mapping = mapping.merge(new_cols, on='tissue_face_id', how='left')

        # Serialise explicitly: str(ndarray) elides long arrays with '...',
        # which would silently truncate neighbour lists in the CSV.
        csv_mapping = self.mapping.assign(
            tissue_neighbors=self.mapping['tissue_neighbors'].map(self._format_neighbors)
        )
        csv_mapping.to_csv(self.cells_path.replace('.ply', '_map.csv'), index=False)

    def color_mesh(self, feature_name):
        """
        Colour the tissue mesh by a neighbourhood-averaged cell feature.

        Writes ``<tissue_path without .ply>_<feature_name>.ply`` (coloured mesh)
        and ``<cells_path without .ply>_values_<feature_name>.csv`` (per-face values).
        Requires :meth:`map_cells` and :meth:`get_neighbours` to have run.

        :param feature_name: Column of ``self.features`` to display
        :return: The coloured tissue mesh
        """
        feature_map = self.features.set_index('cell_id')[feature_name].to_dict()
        raw_values = self.mapping['cell_id'].map(feature_map).to_numpy(dtype=float)

        face_values = self._average_over_neighbours(
            raw_values, self.mapping['tissue_neighbors']
        )

        colors = [
            (0, 0, 1),  # Pure blue
            (0, 0.5, 1),  # Cyan-like
            (0, 1, 0),  # Green
            (1, 1, 0),  # Yellow
            (1, 0, 0),  # Red
        ]
        cmap = LinearSegmentedColormap.from_list('custom_jet', colors, N=2048)

        # nanmin/nanmax: faces whose cell has no feature value stay NaN.
        norm = BoundaryNorm(
            boundaries=np.linspace(
                np.nanmin(face_values), np.nanmax(face_values),
                cmap.N
            ), ncolors=cmap.N
        )
        face_colors = cmap(norm(face_values))

        self.tissue_mesh.visual.face_colors = face_colors
        self.tissue_mesh.export(self.tissue_path.replace('.ply', f'_{feature_name}.ply'))

        values_table = pd.DataFrame({
            'tissue_face_id': np.arange(len(face_values)),
            'value': face_values,
        })
        values_table.to_csv(
            self.cells_path.replace('.ply', f'_values_{feature_name}.csv'), index=False
        )

        return self.tissue_mesh

    def run(self, feature_name, radius=50):
        """Map cells, collect neighbours within ``radius`` and colour by ``feature_name``."""
        self.map_cells()
        self.get_neighbours(radius)
        return self.color_mesh(feature_name)
    
FEATURE_NAME = 'columnarity'
NEIGHBOUR_RADIUS = 50

# Map tissue faces to their nearest cells, then colour the surface by the
# neighbourhood-averaged feature (writes the coloured PLY and CSVs to disk).
cell_map = CellMap(tissue_path, cells_path, features)
colored = cell_map.run(FEATURE_NAME, radius=NEIGHBOUR_RADIUS)

Error in cell: 140 - Cell ID 140 not found in the mesh
Error in cell: 234 - Cell ID 234 not found in the mesh
Error in cell: 316 - Cell ID 316 not found in the mesh
Error in cell: 407 - Cell ID 407 not found in the mesh
Error in cell: 473 - Cell ID 473 not found in the mesh
Error in cell: 487 - Cell ID 487 not found in the mesh
Error in cell: 557 - Cell ID 557 not found in the mesh
Error in cell: 632 - Cell ID 632 not found in the mesh
Error in cell: 698 - Cell ID 698 not found in the mesh
Error in cell: 706 - Cell ID 706 not found in the mesh
Error in cell: 821 - Cell ID 821 not found in the mesh
Error in cell: 933 - Cell ID 933 not found in the mesh
Error in cell: 957 - Cell ID 957 not found in the mesh
Error in cell: 1046 - Cell ID 1046 not found in the mesh
Error in cell: 1181 - Cell ID 1181 not found in the mesh
Error in cell: 1282 - Cell ID 1282 not found in the mesh
Error in cell: 1430 - Cell ID 1430 not found in the mesh
Error in cell: 1453 - Cell ID 1453 not found in the mesh
