In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib ipympl

## 3D meshing example

This notebook shows how to

1. Load and visualize a volume
1. Apply image filters to, and segment, the image
1. Generate a 3D surface mesh from the binary image
1. Decimate/simplify a 3D mesh
1. Visualize the mesh and save it in gmsh22 format

### Todo:

- Generate 3D volume mesh
- Surface mesh must be 'watertight'
- Constrained Delaunay triangulation to fix the surface (tetrahedra must not extend outside the surface), see e.g.
https://notebook.community/daniel-koehn/Theory-of-seismic-waves-II/02_Mesh_generation/4_Tri_mesh_delaunay_yigma_tepe,
https://wias-berlin.de/software/tetgen/1.5/doc/manual/manual002.html#sec7

In [None]:
from nanomesh.volume import Volume
import pyvista as pv
from skimage import filters

In [None]:
# Load the raw volume, smooth it with a Gaussian (sigma=5) and binarize it
# using Li's threshold computed on the smoothed image.
vol = Volume.load('sample_data.npy')
vol_gauss = vol.apply(filters.gaussian, sigma=5)
thresh = filters.threshold_li(vol_gauss.image)

# Foreground voxels (>= threshold) become 1.0, everything else 0.0.
seg_vol = Volume((vol_gauss.image >= thresh).astype(float))
seg_vol.show_slice()

### Generate 3D tetrahedral mesh

In [None]:
import numpy as np

def simplify_mesh_trimesh(vertices, faces, n_faces):
    """Simplify a triangle surface mesh using trimesh.

    Parameters
    ----------
    vertices : array-like
        Vertex coordinates passed to ``trimesh.Trimesh``.
    faces : array-like
        Triangle vertex indices into ``vertices``.
    n_faces : int
        Target number of faces after decimation.

    Returns
    -------
    trimesh.Trimesh
        The decimated mesh (note: a trimesh object, not a meshio mesh).
    """
    import trimesh

    # Build the mesh from the function arguments, not from notebook globals.
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    # NOTE(review): newer trimesh releases may rename this method to
    # `simplify_quadric_decimation`; confirm against the installed version.
    decimated = mesh.simplify_quadratic_decimation(n_faces)
    return decimated

def simplify_mesh_open3d(vertices: np.ndarray, faces: np.ndarray, n_faces):
    """Simplify a triangle surface mesh using Open3D quadric decimation.

    Parameters
    ----------
    vertices : np.ndarray
        Vertex coordinates, one row of 3 values per vertex.
    faces : np.ndarray
        Triangle vertex indices into ``vertices``, one row of 3 per face.
    n_faces : int
        Target number of triangles after decimation.

    Returns
    -------
    meshio.Mesh
        Decimated mesh with a single ``'triangle'`` cell block.
    """
    # Import locally: `meshio` is otherwise only imported in a later cell.
    import meshio
    import open3d

    # Open3D's Vector3dVector / Vector3iVector are documented for float64 /
    # int32 arrays; convert explicitly (marching_cubes yields float32 verts).
    o3d_verts = open3d.utility.Vector3dVector(
        np.asarray(vertices, dtype=np.float64))
    o3d_faces = open3d.utility.Vector3iVector(
        np.asarray(faces, dtype=np.int32))
    o3d_mesh = open3d.geometry.TriangleMesh(o3d_verts, o3d_faces)

    o3d_new_mesh = o3d_mesh.simplify_quadric_decimation(n_faces)

    # Copy the results out of the Open3D containers into plain numpy arrays.
    new_verts = np.array(o3d_new_mesh.vertices)
    new_faces = np.array(o3d_new_mesh.triangles)

    decimated = meshio.Mesh(points=new_verts, cells=[('triangle', new_faces)])
    return decimated

In [None]:
import pyvista as pv

def meshio2polydata(mesh):
    """Convert a meshio mesh into a pyvista mesh via ``pv.from_meshio``."""
    converted = pv.from_meshio(mesh)
    return converted

def meshio2trimesh():
    """Convert a meshio mesh to a trimesh mesh (placeholder, not implemented)."""
    raise NotImplementedError()

def trimesh2meshio():
    """Convert a trimesh mesh to a meshio mesh (placeholder, not implemented)."""
    raise NotImplementedError()

def trimesh2polydata():
    """Convert a trimesh mesh to a pyvista mesh (placeholder, not implemented)."""
    raise NotImplementedError()

def polydata2meshio():
    """Convert a pyvista mesh to a meshio mesh (placeholder, not implemented)."""
    raise NotImplementedError()

def polydata2trimesh():
    """Convert a pyvista mesh to a trimesh mesh (placeholder, not implemented)."""
    raise NotImplementedError()

In [None]:
from skimage import measure
from scipy.spatial import Delaunay
import meshio

# Binary (0.0 / 1.0) segmentation produced by the segmentation cell.
image = seg_vol.image

# Optional extra point seeding, disabled by default. `point_density` scales
# how many points are seeded inside (image == 1) and outside (image == 0) the
# object; `pad` additionally seeds points along the image edges.
point_density = False
pad = False

plot = False  # currently unused

# Points added to the Delaunay input on top of the surface-mesh vertices.
# Initialised up front so the optional branches below can append to it.
extra_points = []

if point_density:
    # NOTE(review): `add_points_kmeans` / `add_edge_points` are not defined or
    # imported in this notebook (presumably nanomesh helpers); import them
    # before enabling these options, and confirm they return (n, 3) arrays.
    # grid_points = add_points_grid(image, border=5)
    n_points1 = int(np.sum(image == 1) * point_density)
    grid_points = add_points_kmeans(image, iters=20, n_points=n_points1)
    extra_points.append(grid_points)

    # adding points to holes helps to get a cleaner result
    n_points0 = int(np.sum(image == 0) * point_density)
    grid_points = add_points_kmeans(1 - image,
                                    iters=20,
                                    n_points=n_points0)
    extra_points.append(grid_points)

if pad:
    # With `point_density` falsy this evaluates to zero edge points per axis.
    n_points_edge = (np.array(image.shape) *
                     (point_density**0.5)).astype(int)
    pad_points = add_edge_points(image, n_points=n_points_edge)
    extra_points.append(pad_points)
    # image = np.pad(image, 1, constant_values=0)

# Extract the object surface; `step_size=5` trades resolution for speed.
verts, faces, _normals, _values = measure.marching_cubes(
    image,
    allow_degenerate=False,
    step_size=5,
)

# Decimate the surface so the triangulation has a manageable point count.
surface_mesh = simplify_mesh_trimesh(vertices=verts, faces=faces, n_faces=5000)

# Delaunay input: decimated surface vertices plus any seeded points.
points = np.vstack([np.asarray(surface_mesh.vertices), *extra_points])

tetrahedra = Delaunay(points, incremental=False).simplices

# Classify each tetrahedron by the image value at its centroid (coordinates
# truncated to integer voxel indices) and drop those lying in background.
centers = points[tetrahedra].mean(axis=1)
is_background = image[tuple(centers.astype(int).T)] == 0

cells = [
    ('tetra', tetrahedra[~is_background]),
]

mesh = meshio.Mesh(points, cells)
mesh.remove_orphaned_nodes()

pv.plot_itk(mesh)

In [None]:
# Slice settings: show cells below `index` along axis `along`.
along = 'x'
index = 100

# Convert the meshio tetra mesh to a pyvista grid for plotting.
grid = meshio2polydata(mesh)

def show_submesh(grid, *, index=100, along='x'):
    """Show the part of a tetrahedral mesh lying below a cutting plane.

    Cells whose centroid coordinate along ``along`` is less than ``index``
    are extracted and displayed in an ITK plotter.

    Parameters
    ----------
    grid : pyvista grid
        Grid containing only tetrahedra (presumably from ``meshio2polydata``).
    index : float, optional
        Position of the cutting plane along the chosen axis.
    along : {'z', 'y', 'x'}, optional
        Axis to cut along; mapped to point-coordinate columns 0, 1, 2.

    Raises
    ------
    ValueError
        If ``along`` is not one of ``'z'``, ``'y'``, ``'x'``.
    """
    axes = ('z', 'y', 'x')
    # Look up in a tuple, not the string 'zyx': str.index also matches
    # substrings such as 'yx' or '', which would silently pick an axis.
    if along not in axes:
        raise ValueError(f'`along` must be one of {axes}, got {along!r}')
    axis = axes.index(along)

    # Cell array is flat [4, i0, i1, i2, i3, 4, ...]; reshaping to rows of 5
    # assumes every cell is a tetrahedron. Drop the leading count column.
    cells = grid.cells.reshape(-1, 5)[:, 1:]
    cell_center = grid.points[cells].mean(1)

    # extract cells below index
    mask = cell_center[:, axis] < index
    cell_ind = mask.nonzero()[0]
    subgrid = grid.extract_cells(cell_ind)

    plotter = pv.PlotterITK()
    plotter.add_mesh(subgrid)
    plotter.show()

# Use the slice settings defined above so editing them takes effect.
show_submesh(grid, index=index, along=along)