In [None]:
import flowshape as fs
import igl
import numpy as np
import meshplot as mp
import os
from src.utilities.fin_shape_utils import fit_fin_hull, upsample_fin_point_cloud, plot_mesh
from src.utilities.fin_class_def import FinData
from src.utilities.functions import path_leaf
import glob2 as glob

## Idea: use the bottom half of an ellipsoid and smooth the fin–ellipsoid interface

### Load fin data

In [None]:
# Data root. Set the PECFIN_ROOT environment variable to use another copy of the
# data (e.g. the macOS Dropbox copy at
# "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/pecfin_dynamics/");
# when it is unset, the workstation path below is used.
root = os.environ.get(
    "PECFIN_ROOT",
    "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/pecfin_dynamics/",
)
fin_object_path = os.path.join(root, "point_cloud_data", "fin_objects", "")
fin_object_list = sorted(glob.glob(fin_object_path + "*.pkl"))

# Which fin to analyse: an index into the sorted list above (the printed
# prefix below is the actual identifier of the selected fin).
file_ind01 = 146
seg_type = "tissue_only_best_model_tissue"
if not -len(fin_object_list) <= file_ind01 < len(fin_object_list):
    # fail with a useful message instead of a bare "list index out of range"
    # (e.g. when the data root is wrong and no files were found)
    raise IndexError(
        f"file_ind01={file_ind01} is out of range: found {len(fin_object_list)} "
        f"fin objects in {fin_object_path!r}"
    )
fp01 = fin_object_list[file_ind01]
point_prefix01 = path_leaf(fp01).replace("_fin_object.pkl", "")
print(point_prefix01)

fin_object = FinData(data_root=root, name=point_prefix01, tissue_seg_model=seg_type)

### Calculate distance from each fin nucleus to the yolk surface

In [None]:
from sklearn.metrics import pairwise_distances

# --- keep only nuclei labelled as fin (fin_label_curr == 1) ---
full_df = fin_object.full_point_data
is_fin_nucleus = full_df["fin_label_curr"] == 1
fin_df = full_df.loc[is_fin_nucleus, :].reset_index(drop=True)

# --- express fin points in the fin's biological axis frame ---
fin_axis_df = fin_object.axis_fin
fin_axes = fin_object.calculate_axis_array(fin_axis_df)

fin_points = fin_df[["X", "Y", "Z"]].to_numpy()
shift_ref_vec = fin_points.mean(axis=0)  # centroid, used as the new origin
fin_points_pca = (fin_points - shift_ref_vec) @ fin_axes.T
fin_df.loc[:, ["XP", "YP", "ZP"]] = fin_points_pca

# --- evaluate the yolk-surface fit on a 100 x 100 grid spanning the fin's X/Y extent ---
params = fin_object.yolk_surf_params
x_min, y_min = fin_points[:, :2].min(axis=0)
x_max, y_max = fin_points[:, :2].max(axis=0)

x_vals = np.linspace(x_min, x_max, 100)
y_vals = np.linspace(y_min, y_max, 100)
X, Y = np.meshgrid(x_vals, y_vals)
grid_xy = np.c_[X.ravel(), Y.ravel()]
yolk_xyz = fin_object.polyval2d(grid_xy, params).ravel().reshape(-1, 3)

# --- signed distance from each fin nucleus to its nearest yolk sample ---
dist_array = pairwise_distances(fin_points, yolk_xyz)
min_i = dist_array.argmin(axis=1)
yolk_dist = dist_array[np.arange(dist_array.shape[0]), min_i]
# Sign comes from the Z offset to that nearest sample: a nucleus with larger Z
# than its yolk sample gets a negative distance.
yolk_signs = np.sign(fin_points[:, 2] - yolk_xyz[min_i, 2])
yolk_dist = -(yolk_dist * yolk_signs)

fin_df["yolk_dist"] = yolk_dist

### Remove nuclei that are too far below the yolk surface

In [None]:
# Keep nuclei whose signed yolk distance (same units as the X/Y/Z columns) is
# at or above dist_thresh; the rest are treated as lying too far below the yolk.
dist_thresh = -10
# Boolean mask over fin_df rows. A NumPy mask is applied positionally, so
# fin_df's index is irrelevant and fin_df is left unmodified.
dist_filter = (fin_df["yolk_dist"] >= dist_thresh).to_numpy()
nuclei_to_keep = fin_df.loc[dist_filter, "nucleus_id"].to_numpy()

### Use multivariate gaussians to upsample nuclei

In [None]:
# Upsample the fin point cloud nucleus by nucleus (sampling scheme lives in
# upsample_fin_point_cloud): 0.4 um sample resolution, 100 points per nucleus.
fin_df_upsamp = upsample_fin_point_cloud(
    fin_object,
    root=root,
    sample_res_um=0.4,
    points_per_nucleus=100,
)

In [None]:
import open3d as o3d

# Upsampled points that belong to nuclei passing the yolk-distance filter.
# IDs are compared as int64 on both sides: the former uint16 cast wraps IDs
# above 65535 and could silently match the wrong nuclei.
nc_vec_us = fin_df_upsamp.loc[:, "nucleus_id"].to_numpy().astype(np.int64)
keep_filter = np.isin(nc_vec_us, np.asarray(nuclei_to_keep).astype(np.int64))
# separate name so the raw fin_points array from the yolk-distance cell is not overwritten
fin_points_us = fin_df_upsamp.loc[keep_filter, ["XP", "YP", "ZP"]].to_numpy()

# convert to an Open3D point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(fin_points_us)

# Make the sampling more spatially uniform: voxel_down_sample keeps one
# representative point per occupied voxel of edge length min_distance.
min_distance = 0.5
sampled_points = pcd.voxel_down_sample(voxel_size=min_distance)

### Fit high- and low-resolution meshes

In [None]:
# High-resolution hull fit to the down-sampled points (alpha=23, 5000 target faces).
fin_points_u = np.asarray(sampled_points.points)
fin_hull_h, raw_hull_h, wt_flag = fit_fin_hull(
    fin_points_u,
    alpha=23,
    n_faces=5000,
)
print(wt_flag)  # flag returned by fit_fin_hull (presumably watertightness — confirm)

In [None]:
# Low-resolution hull from the same down-sampled points (alpha=1, 5000 target faces).
fin_hull_l, raw_hull_l, wt_flag_l = fit_fin_hull(fin_points_u, alpha=1, n_faces=5000)
# report this fit's flag (previously the high-res wt_flag was printed by mistake)
print(wt_flag_l)

In [None]:
import plotly.graph_objects as go

# Vertex/face arrays of both hulls (copies, so edits to them do not touch the meshes).
vh, fh = fin_hull_h.vertices.copy(), fin_hull_h.faces.copy()
vl, fl = fin_hull_l.vertices.copy(), fin_hull_l.faces.copy()

# Overlay both hulls semi-transparently to compare their level of detail.
_, lines_h, mesh_h = plot_mesh(fin_hull_h, surf_alpha=0.5)
_, lines_l, mesh_l = plot_mesh(fin_hull_l, surf_alpha=0.5)

fig = go.Figure()
fig.add_trace(mesh_h)
fig.add_trace(mesh_l)
# equal axis scaling, as in the other 3-D mesh figures, so the shape is not distorted
fig.update_scenes(aspectmode='data')
fig.show()

### Use fit-surface to generate an oriented ellipsoid cap

In [None]:
# Map the raw high-res hull back into the coordinate frame of fin_points_u.
# `raw_hull_rs` is required by the mesh-combination step later in the notebook.
# NOTE(review): assumes fit_fin_hull returns raw_hull_h in coordinates normalised as
#   (xyz - min(xyz)) / max(xyz - min(xyz))   (scalar min/max over all axes),
# which this rescaling inverts — confirm against fit_fin_hull.
# The offset/scale names deliberately avoid `mp`, which would shadow the
# meshplot import used by the plotting cells.
raw_offset = np.min(fin_points_u)
raw_scale = np.max(fin_points_u - raw_offset)

raw_hull_rs = raw_hull_h.copy()
raw_hull_rs.vertices = np.asarray(raw_hull_h.vertices) * raw_scale + raw_offset

#### Idea: use depth relative to the yolk surface to set how much weight each point gives to the ellipsoid vs. the fin

In [None]:
from scipy.spatial import KDTree

fin_points_h = np.asarray(fin_hull_h.vertices)
fin_points_l = np.asarray(fin_hull_l.vertices)

# For every high-res hull vertex, find the nearest low-res hull vertex:
# indices[i] is the row of fin_points_l closest to fin_points_h[i], and
# distances[i] is the Euclidean distance between the two.
low_res_tree = KDTree(fin_points_l)
distances, indices = low_res_tree.query(fin_points_h)
distances.shape

In [None]:
# Per-vertex "depth" that drives the blend weights below: the third coordinate
# of each high-res hull vertex.
# NOTE(review): a true signed distance to the fitted yolk surface was planned
# here but is not computed; this coordinate is a stand-in — confirm that the
# third axis of the hull's frame points away from the yolk.
yolk_dist_e = fin_points_h[..., 2]

In [None]:
# import plotly.express as px

# fig = px.scatter_3d(x=ellipsoid_points[:, 0], y=ellipsoid_points[:, 1], z=ellipsoid_points[:, 2], color=yolk_dist_e)
# fig.show()

In [None]:
from scipy.special import expit

# Logistic weight given to the low-res surface at each high-res vertex:
#   w = exp(-(z - kd)/t) / (1 + exp(-(z - kd)/t)) = expit(-(z - kd)/t)
# w -> 1 for z well below kd, w = 0.5 at z == kd, w -> 0 for z well above kd;
# t sets the width of the transition (same units as z).
# expit evaluates this without the inf/inf = NaN the explicit exp form
# produces once (z - kd)/t is large and negative.
kd = 0
t = 10
e_weight_vec = expit(-(yolk_dist_e - kd) / t)

In [None]:
import plotly.express as px

# Colour each high-res hull vertex by its low-res blend weight.
fig = px.scatter_3d(
    x=fin_points_h[:, 0],
    y=fin_points_h[:, 1],
    z=fin_points_h[:, 2],
    color=e_weight_vec,
)
# equal axis scaling, consistent with the mesh figures
fig.update_scenes(aspectmode='data')
fig.show()

In [None]:
# Blend each high-res vertex with its nearest low-res vertex (from the KDTree query):
#   new = w * nearest_low_res + (1 - w) * high_res
# fin_points_e is row-aligned with fin_points_h. The previous version blended
# with fin_points_l itself, whose vertex count and order do not match.
f_weight_vec = 1 - e_weight_vec
fin_points_e = fin_points_l[indices]
new_points_e = (
    e_weight_vec[:, np.newaxis] * fin_points_e
    + f_weight_vec[:, np.newaxis] * fin_points_h
)
assert new_points_e.shape == fin_points_h.shape

In [None]:
# Blended mesh: high-res connectivity with the blended vertex positions.
# (This assignment was previously commented out, so e_mesh_new was undefined.)
e_mesh_new = fin_hull_h.copy()
e_mesh_new.vertices = new_points_e

_, lines, mesh = plot_mesh(e_mesh_new, surf_alpha=1)

fig = go.Figure()
fig.add_traces(mesh)
fig.update_scenes(aspectmode='data')
fig.show()

In [None]:
import trimesh  # used below for boolean union and smoothing; was never imported

smoothing_strength = 5  # Laplacian smoothing iterations
n_faces = 5000          # upper bound on the face count after decimation

# Keep only the largest connected component of the rescaled raw hull, then close holes.
hull02_cc = raw_hull_rs.copy().split(only_watertight=False)
hull02_sm = max(hull02_cc, key=lambda m: m.area)
hull02_sm.fill_holes()

# Merge with the blended mesh, then smooth the union.
combined_mesh = trimesh.boolean.union([hull02_sm, e_mesh_new])
combined_mesh = trimesh.smoothing.filter_laplacian(combined_mesh, iterations=smoothing_strength)

# Decimate (never ask for more faces than the mesh has) without overwriting
# n_faces, then keep the largest component and clean it up again.
n_faces_target = int(min(n_faces, combined_mesh.faces.shape[0] - 1))
combined_mesh = combined_mesh.simplify_quadric_decimation(face_count=n_faces_target)
combined_parts = combined_mesh.split(only_watertight=False)
combined_mesh = max(combined_parts, key=lambda m: m.area)
combined_mesh.fill_holes()
combined_mesh.fix_normals()

In [None]:
import plotly.graph_objects as go

# Final combined mesh as an opaque surface with equal axis scaling.
_, lines, mesh = plot_mesh(combined_mesh, surf_alpha=1)
fig = go.Figure()
fig.add_traces(mesh)
fig.update_scenes(aspectmode='data')
fig.show()

### Run spherical mapping

In [None]:
# Mesh to map onto the sphere. No earlier cell defines v / f, so derive them
# here from the combined mesh; swap in fin_hull_h or fin_hull_l to map those
# instead. Re-deriving them also makes this cell safe to re-run
# (v is replaced by its normalized version below).
mesh_to_map = combined_mesh
v = np.asarray(mesh_to_map.vertices, dtype=np.float64).copy()
f = np.asarray(mesh_to_map.faces).copy()

# normalize the scaling of the mesh
v = fs.normalize(v)

# run the spherical mapping flow and Mobius centering
sv = fs.sphere_map(v, f)

# Now we have a spherical mesh
mp.plot(sv, f, shading={"wireframe": True})

### Calculate the mean curvature

In [None]:
# Curvature function of the mesh, computed with the help of its spherical map sv,
# drawn as vertex colours on the original mesh.
rho = fs.curvature_function(v, sv, f)
mp.plot(v, f, c=rho)

## SH decomposition

In [None]:
# fs.do_mapping bundles the steps above with a spherical-harmonic
# decomposition, here up to maximum degree l_max = 24.
weights, Y_mat, vs = fs.do_mapping(
    v,
    f,
    l_max=24,
)

In [None]:
# Show the array of SH weights (summarised when it exceeds 100 elements).
# np.printoptions limits the threshold change to this cell instead of
# altering NumPy's print settings for the rest of the session.
with np.printoptions(threshold=100):
    print(weights)

In [None]:
# Y_mat converts SH weights into per-vertex function values; draw the
# reconstructed function on the spherical mesh.
rho2 = Y_mat @ weights
mp.plot(sv, f, c=rho2)

### Mesh reconstruction

In [None]:
# Rebuild the 3-D shape from the spherical mesh and the SH-reconstructed function.
rec2 = fs.reconstruct_shape(sv, f, rho2)
mp.plot(rec2, f, c=rho2)

### Test with lower frequencies only

In [None]:
# Same decomposition truncated at l_max = 4. Separate names keep the degree-24
# results (weights, Y_mat, vs) intact, so re-running the cells above still
# shows the full-resolution version.
weights_l4, Y_mat_l4, vs_l4 = fs.do_mapping(v, f, l_max=4)
rho_l4 = Y_mat_l4.dot(weights_l4)
rec_l4 = fs.reconstruct_shape(sv, f, rho_l4)
# coloured by the degree-24 function for comparison with the previous reconstruction
mp.plot(rec_l4, f, c=rho2)