In [None]:
import tifffile as tif
import scipy.ndimage as ndi
import skimage.measure as skim
import skimage.feature as skif
import skimage.morphology as skimo
import skimage.transform as skit
import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Load the segmentation volume from disk (presumably an integer label image,
# since it is later passed to skimage.measure.regionprops after downscaling).
im = tif.imread('labeled_img_post.tif')

In [None]:
im.shape

In [None]:
im.max()

In [None]:
# Voxel size of `im`, one entry per array axis in (z, y, x) order; it is used
# below to scale both the mesh vertices and the dot coordinates.
# NOTE(review): physical units are not stated anywhere in this notebook - confirm.
px_size = (0.5, 0.11, 0.11)

In [None]:
# Load the table of detected dots; later cells read its 'geneID', 'z', 'y'
# and 'x' columns.
dots = pd.read_csv('pre_seg_diff_1_minseeds_3_filtered.csv')
dots.head()

In [None]:
# Downscale in-plane only: z keeps its size, y and x shrink by a factor of 16.
scale_factor = (1., 1./16, 1./16)

# Voxel size of the downscaled volume (each small voxel covers 16x more in y/x).
px_scaled = tuple(a/b for a, b in zip(px_size, scale_factor))

# order=0 (nearest neighbour) with anti-aliasing disabled keeps the original
# label values instead of blending neighbouring labels together.
im_rescaled = skit.rescale(im, 
                           scale_factor, 
                           order=0, 
                           mode='reflect', 
                           cval=0, 
                           clip=True, 
                           preserve_range=True, 
                           anti_aliasing=False, 
                           anti_aliasing_sigma=None
                          )

# Choose the smallest unsigned integer dtype that can hold the largest label.
# A hard-coded uint8 cast would silently wrap label ids above 255 and merge
# unrelated regions; for volumes with <= 255 labels the result is still uint8.
im_small = im_rescaled.astype(np.min_scalar_type(int(im_rescaled.max())))

# The float intermediate is no longer needed once the integer copy exists.
del im_rescaled

In [None]:
px_scaled

In [None]:
im_small.dtype

In [None]:
# Write the downscaled label volume to 'labeled_small.tif'.
with tif.TiffWriter('labeled_small.tif') as w:
    w.write(im_small)

## Triangulate each labeled region separately

Use `skimage.measure.marching_cubes` to generate a triangulated mesh for each labeled region. 

Also add the bounding box origin of each region so that when they are combined their relative
positions are correct.

In [None]:
region_meshes = []

# every region mask gets one voxel of empty padding on all six sides
one_voxel_pad = ((1, 1), (1, 1), (1, 1))
# physical size of that padding, i.e. how far the padded frame is shifted
pad_shift = np.array(px_scaled)

for r in tqdm(skim.regionprops(im_small)):
    # isosurface of the padded binary mask of this region
    tris = skim.marching_cubes(
        np.pad(r.image, one_voxel_pad),
        level=0.5,
        spacing=px_scaled,
        step_size=1,
    )
    verts, faces = tris[0], tris[1]

    # undo the padding shift, then move the vertices to the region's
    # bounding-box corner so all regions share one coordinate frame
    bbox_corner = np.multiply(px_scaled, r.bbox[:3])
    new_pts = np.add(verts - pad_shift, bbox_corner)

    # only points and faces are stored; the remaining outputs are dropped for now
    region_meshes.append((new_pts, faces))
    

In [None]:
(tris[0] - np.array(px_scaled)).min(axis=0)

In [None]:
sum([a[0].nbytes+a[1].nbytes for a in region_meshes])/1e6

In [None]:
def combine_meshes(mesharr, indices=None):
    """
    combine_meshes
    --------------
    Concatenate any selection of individual meshes into one mesh, shifting
    each mesh's face indices so they keep pointing at the right rows of the
    combined point array.

    Parameters
    ----------
    mesharr : sequence of (points, faces)
        Each entry holds a point array and a face array whose values are row
        indices into that entry's own point array.
    indices : int, sequence of int or None, optional
        Which entries of ``mesharr`` to combine, in that order.  ``None``
        (the default) combines all of them.

    Returns
    -------
    (points, faces) : tuple of numpy.ndarray
        The stacked points and the re-indexed faces.  When nothing is
        selected, two empty arrays of shape (0, 3) are returned so that
        callers can still unpack columns with ``z, y, x = points.T``.
    """
    if indices is None:
        indices = range(len(mesharr))

    indices = np.atleast_1d(indices)

    if indices.size == 0:
        return np.empty((0, 3)), np.empty((0, 3), dtype=int)

    comb_pts = []
    comb_tris = []
    maxpt = 0  # number of points already placed in front of the current mesh

    for ind in indices:
        pts = mesharr[ind][0]
        tris = mesharr[ind][1]
        # collect whole arrays and stack once at the end; extending a Python
        # list row by row is far slower for large meshes
        comb_pts.append(pts)
        comb_tris.append(tris + maxpt)
        maxpt += len(pts)

    return np.concatenate(comb_pts, axis=0), np.concatenate(comb_tris, axis=0)

In [None]:
comb_test = combine_meshes(region_meshes, None)

In [None]:
len(comb_test[0]), len(comb_test[1])

## Entire field triangulation version

In [None]:
# Triangulate the whole foreground (all labels merged into one mask) in a
# single marching-cubes pass.
#
# `region_meshes` is intentionally NOT reset here: the decimation section
# further down still consumes the per-region meshes, and clearing the list
# made that section run on nothing when the notebook is executed top to bottom.
#
# NOTE(review): the vertices are in the frame of the padded volume, i.e.
# shifted by one voxel (px_scaled) per axis relative to im_small. The
# per-region version above subtracts that shift; confirm whether it should
# be removed here as well.
tris = skim.marching_cubes(
    np.pad(im_small>0, ((1, 1), (1, 1), (1, 1))),
    level=0.5, 
    spacing=px_scaled,
    step_size=1
)

# keep only vertices and faces
comb_test = (tris[0], tris[1])

In [None]:
len(comb_test[0]), len(comb_test[1])

In [None]:
comb_test[0].min(axis=0)

In [None]:
z,y,x = comb_test[0].T
i,j,k = comb_test[1].T

## Triangulated mesh

In [None]:
# Interactive render of the marching-cubes mesh. x/y/z and i/j/k are the
# vertex coordinate columns and face index columns unpacked from comb_test
# in the previous cell.
import plotly.graph_objects as go

fig = go.Figure(data=[go.Mesh3d(x=x, y=y, z=z, 
                                i=i, j=j, k=k,
                                color='lightpink', opacity=1)
                     ])
fig.show()

## Delaunay of edge pixels?

## Decimation with pyvista

In [None]:
import pyvista as pv

In [None]:
def tris2pyvista(tris):
    """Return the (n, 3) triangle index array as an (n, 4) array whose first
    column is the constant 3 (the per-face vertex count)."""
    return np.insert(tris, 0, 3, axis=1)

def pyvista2tris(faces):
    """
    Convert a flat face buffer of the form ``[3, a, b, c, 3, d, e, f, ...]``
    into an (n, 3) array of triangle vertex indices.

    Parameters
    ----------
    faces : array-like of int
        Face buffer in which every face is stored as a vertex count followed
        by that many vertex indices.

    Returns
    -------
    numpy.ndarray
        New array of shape (n, 3) holding the indices of each triangle.

    Raises
    ------
    AssertionError
        If any face is not a triangle, or if the buffer length is not a
        multiple of 4 (previously such trailing entries were silently dropped).
    """
    faces = np.asarray(faces).ravel()

    assert faces.size % 4 == 0, 'At least one face is not a triangle'
    # vectorised check; the builtin all() would walk the array element by element
    assert np.all(faces[::4] == 3), 'At least one face is not a triangle'

    ntris = faces.size // 4

    # drop the leading count column; copy so the result does not alias `faces`
    return faces.reshape((ntris, 4))[:, 1:].copy()
    

In [None]:
len(region_meshes)

In [None]:
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED

In [None]:
comb_all = combine_meshes(region_meshes, None)

In [None]:
def decimate_mapper(mesh, target=0.8):
    """
    Decimate one (points, faces) mesh with pyvista.

    ``mesh[0]`` is the point array and ``mesh[1]`` the triangle index array;
    ``target`` is passed unchanged to ``PolyData.decimate``.  The decimated
    mesh is returned in the same (points, triangles) layout.
    """
    points, triangles = mesh[0], mesh[1]
    decimated = pv.PolyData(points, tris2pyvista(triangles)).decimate(target)

    return (decimated.points, pyvista2tris(decimated.faces))

In [None]:
# Leftover IPython help lookup (the `?` suffix shows the docstring of
# list.pop); it has no effect on notebook state and can be deleted.
list.pop?

In [None]:
workers = 8

# Decimate every per-region mesh on a thread pool.
# - Executor.map returns results in input order, so decimated_meshes[i]
#   belongs to region_meshes[i].
# - region_meshes is only read, never emptied, so this cell can be re-run.
# - An exception raised inside decimate_mapper propagates here instead of
#   being swallowed.
with ThreadPoolExecutor(max_workers=workers) as tpe:
    decimated_meshes = list(
        tqdm(tpe.map(decimate_mapper, region_meshes), total=len(region_meshes))
    )

In [None]:
decimated_comb = combine_meshes(decimated_meshes, None)

In [None]:
len(decimated_comb[0]), len(decimated_comb[1])

## Decimated triangulated mesh

In [None]:
import plotly.graph_objects as go

# Unpack the vertex columns (stored in z, y, x order) and the three face
# index columns of the decimated mesh.
# NOTE: this rebinds the notebook-global names x/y/z and i/j/k.
z,y,x = decimated_comb[0].T
i,j,k = decimated_comb[1].T

fig = go.Figure(data=[go.Mesh3d(x=x, y=y, z=z, 
                                i=i, j=j, k=k,
                                color='lightpink', opacity=0.9)])
fig.show()

# 2D slicewise contours

In [None]:
def im2slicecontours(im, approx=2):
    """
    Trace simplified outlines of the foreground (``im > 0``) slice by slice.

    Parameters
    ----------
    im : array-like
        Iterated along its first axis; every item is treated as one 2D slice.
    approx : float, optional
        Tolerance handed to ``skimage.measure.approximate_polygon``.

    Returns
    -------
    list of numpy.ndarray
        One (n_points, 3) array per contour.  Column 0 is the slice index,
        columns 1 and 2 are the contour coordinates in the frame of the
        unpadded input slice.  Contours with 3 or fewer points are discarded.
    """
    slicecontours = []

    for z_index, plane in enumerate(im):
        # A one-pixel background border means outlines of objects touching
        # the slice edge are closed rather than cut open at the edge.
        contours = skim.find_contours(
            np.pad(plane > 0, ((1, 1), (1, 1))),
            level=0.5,
            fully_connected='high',
        )

        for c in contours:
            if len(c) <= 3:
                continue

            # find_contours reports positions in the padded frame; subtract
            # the one-pixel border so coordinates refer to the input slice.
            outline = skim.approximate_polygon(c, approx) - 1

            # prepend the slice index as the first column
            slicecontours.append(
                np.pad(outline, ((0, 0), (1, 0)), constant_values=z_index)
            )

    return slicecontours
    

In [None]:
test = im2slicecontours(im)

In [None]:
sum([len(s) for s in test])

In [None]:
# Fraction of contour points kept by polygon approximation on the first slice.
# `cont` and `appcont` were not defined anywhere in the notebook (NameError on
# a fresh kernel), so they are rebuilt here for slice 0: raw contours of the
# padded foreground mask, shifted back by the one-pixel padding, and their
# simplified versions.
# NOTE(review): tolerance 2 mirrors the default `approx` of im2slicecontours -
# confirm this matches the originally intended comparison.
cont = [
    c - 1
    for c in skim.find_contours(
        np.pad(im[0] > 0, ((1, 1), (1, 1))),
        level=0.5,
        fully_connected='high',
    )
    if len(c) > 3
]
appcont = [skim.approximate_polygon(c, 2) for c in cont]

sum([len(a) for a in appcont])/sum([len(c) for c in cont])

In [None]:
# Overlay the simplified contours on the foreground mask of the first slice.
fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(im[0]>0)
for c in appcont:
    # contour columns are (row, col); matplotlib wants x=col, y=row
    ax.plot(c[:, 1], c[:, 0])

plt.show()

In [None]:
import plotly.express as px

z,y,x = test[0].T

fig = px.line_3d(x=x, y=y, z=z)
fig.show()

In [None]:
import plotly.graph_objects as go


# One Scatter3d trace per contour for the first three contours; column 0 of
# each contour is the slice index (z), columns 1 and 2 are the in-plane
# coordinates.
fig = go.Figure(data=[
    go.Scatter3d(x=c.T[2], y=c.T[1], z=c.T[0],
                marker=dict(size=0),
                line=dict(
                color='darkblue',
                width=0.5))
    for c in test[:3]
])
fig.show()

In [None]:
test_concat = np.vstack(test)

In [None]:
test_concat.shape

In [None]:
import plotly.graph_objects as go

# All contour points of all slices as z/y/x columns.
# NOTE: this rebinds the notebook-global x/y/z used by other plotting cells.
z,y,x = test_concat.T

# No i/j/k face indices are given, so the triangulation is left to plotly
# (presumably its default point-cloud triangulation - see the Mesh3d docs).
fig = go.Figure(data=[go.Mesh3d(x=x, y=y, z=z, 
                                color='lightpink', opacity=0.50)])
fig.show()

# Open3d

In [None]:
import open3d as o3d

In [None]:
im.shape

In [None]:
# Wrap the vertices and faces held in comb_test in an Open3D TriangleMesh.
mc_mesh = o3d.geometry.TriangleMesh()
mc_mesh.vertices = o3d.utility.Vector3dVector(comb_test[0])
mc_mesh.triangles = o3d.utility.Vector3iVector(comb_test[1])

In [None]:
np.asarray(mc_mesh.vertices)

In [None]:
px_scaled

In [None]:
# Give every distinct geneID an integer code and wrap it into 0..19 so it can
# index the 20-entry 'tab20' colormap used in the next cell; genes whose codes
# differ by a multiple of 20 therefore share a colour.
dots['geneInd'] = dots['geneID'].factorize()[0] % 20


In [None]:
# Look up a colour for every dot from 'tab20', keeping only the first three
# components of what the colormap returns (the [:3] slice).
cm = plt.get_cmap('tab20')
dots['geneColor'] = [ cm(c)[:3] for c in dots['geneInd'] ]

In [None]:
dots_pcd = o3d.geometry.PointCloud()

# Dot positions in (z, y, x) order, scaled by the full-resolution voxel size.
# NOTE(review): the z index is shifted by -1 before scaling; the reason is not
# visible in this notebook (possibly 1-based z in the CSV) - confirm.
dot_zyx = dots[['z', 'y', 'x']].values - np.array([1, 0, 0])
dots_pcd.points = o3d.utility.Vector3dVector(np.multiply(dot_zyx, px_size))

# 'geneColor' stores one RGB tuple per row, so `.values` is a 1-D object
# array rather than numeric data. Build the (n, 3) float64 array that
# Vector3dVector is documented to take; reshape keeps an empty table valid.
dot_rgb = np.array(dots['geneColor'].tolist(), dtype=float).reshape(-1, 3)
dots_pcd.colors = o3d.utility.Vector3dVector(dot_rgb)

In [None]:
dots_pcd

In [None]:
dots_pcd.colors

In [None]:
# Open the Open3D viewer with the mesh and the coloured dots.
# (`o3d.o3.visualization` is not a valid attribute path; the visualization
# module is `o3d.visualization`.)
o3d.visualization.draw_geometries([mc_mesh, dots_pcd])

In [None]:
# Export the mesh and the coloured point cloud with Open3D's own writers.
o3d.io.write_triangle_mesh('labeled_16x_tris.obj', mc_mesh)
o3d.io.write_point_cloud('dots_colored.pcd', dots_pcd)

In [None]:
import json

# Plain-list copy of the mesh (vertex coordinates and triangle indices) so it
# can be serialised with json.
mesh_json = {
    'verts': np.asarray(mc_mesh.vertices).tolist(),
    'faces': np.asarray(mc_mesh.triangles).tolist()
}

# Plain-list copy of the point cloud; 'genes' carries the geneID of each dot,
# taken from the `dots` table the point cloud was built from.
pcd_json = {
    'points': np.asarray(dots_pcd.points).tolist(),
    'colors': np.asarray(dots_pcd.colors).tolist(),
    'genes': dots['geneID'].values.tolist()
}

In [None]:
# Write the mesh and the point cloud dictionaries to JSON files.
with open('labeled_16x_mesh.json', 'w') as fp:
    json.dump(mesh_json, fp)
    
with open('dots_colored.json', 'w') as fp:
    json.dump(pcd_json, fp)

In [None]:
# Leftover IPython help lookup (the `?` suffix shows the docstring of
# json.load); it has no effect on notebook state and can be deleted.
json.load?

## Plotly with points

In [None]:
# Split the point-cloud coordinates into separate z / y / x arrays for plotly.
# NOTE: this rebinds `px`, which an earlier cell imported as plotly.express.
pz,py,px = np.asarray(dots_pcd.points).T

In [None]:
import plotly.graph_objects as go

# Take the mesh geometry directly from comb_test (the mesh exported above as
# mc_mesh) instead of the notebook-global x/y/z and i/j/k. Those globals are
# rebound by several earlier plotting cells, so by the time this cell runs
# the vertex coordinates and the face indices came from different objects.
# NOTE(review): comb_test is assumed to be the mesh meant to accompany the
# dots here, since both are exported together above - confirm.
mesh_z, mesh_y, mesh_x = comb_test[0].T
mesh_i, mesh_j, mesh_k = comb_test[1].T

figdata = [
    # segmentation surface
    go.Mesh3d(x=mesh_x, y=mesh_y, z=mesh_z,
              i=mesh_i, j=mesh_j, k=mesh_k,
        color='lightpink',
        opacity=1,
        hoverinfo='skip',
    ),
    # dots, coloured per gene; pz/py/px come from the point-cloud cell above
    go.Scatter3d(x=px, y=py, z=pz,
        mode='markers',
        marker=dict(
            size=1,
            color=dots['geneColor'].values,
            opacity=1,
            symbol='circle',
        ),
        hoverinfo='skip',
    )
]

# aspectmode='data' scales the axes according to the data ranges
figscene = go.layout.Scene(
    aspectmode='data'
)

figlayout= go.Layout(
    height=800,
    width=800,
    margin=dict(b=10, l=10, r=10, t=10),
    scene=figscene
)


fig = go.Figure(data=figdata, layout=figlayout)
fig.show(renderer='jupyterlab')

In [None]:
import plotly.io

In [None]:
plotly.io.renderers