In [1]:
import flowshape as fs
import igl
import numpy as np
import meshplot as mp
import os
from src.utilities.fin_shape_utils import fit_fin_hull, upsample_fin_point_cloud, plot_mesh
from src.utilities.fin_class_def import FinData
from src.utilities.functions import path_leaf
import glob2 as glob

### Load fin data

In [2]:
# root = "/Users/nick/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/pecfin_dynamics/"
root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/pecfin_dynamics/"
fin_object_path = os.path.join(root, "point_cloud_data", "fin_objects", "")
fin_object_list = sorted(glob.glob(fin_object_path + "*.pkl"))

file_ind01 = 146
seg_type = "tissue_only_best_model_tissue"
fp01 = fin_object_list[file_ind01]
point_prefix01 = path_leaf(fp01).replace("_fin_object.pkl", "")
print(point_prefix01)

fin_object = FinData(data_root=root, name=point_prefix01, tissue_seg_model=seg_type)

20240711_02_well0053_time0000


### Calculate distance from each fin nucleus to the yolk surface

In [3]:
from sklearn.metrics import pairwise_distances

full_df = fin_object.full_point_data
fin_df = full_df.loc[full_df["fin_label_curr"] == 1, :].reset_index(drop=True)

# orient to biological axes
fin_axis_df = fin_object.axis_fin
fin_axes = fin_object.calculate_axis_array(fin_axis_df)

# Use simple numerical procedure to calculate distance of each fin point to the yolk
fin_points = fin_df[["X", "Y", "Z"]].to_numpy()
shift_ref_vec = np.mean(fin_points, axis=0)

fin_points_pca = np.matmul(fin_points - shift_ref_vec, fin_axes.T)
fin_df.loc[:, ["XP", "YP", "ZP"]] = fin_points_pca

params = fin_object.yolk_surf_params

x_min, y_min = fin_points[:, 0].min(), fin_points[:, 1].min()
x_max, y_max = fin_points[:, 0].max(), fin_points[:, 1].max()

# Create a mesh grid for x and y values
x_vals = np.linspace(x_min, x_max, 100)
y_vals = np.linspace(y_min, y_max, 100)
X, Y = np.meshgrid(x_vals, y_vals)

yolk_xyz = np.reshape(fin_object.polyval2d(np.c_[X.ravel(), Y.ravel()], params).ravel(), (-1, 3))

dist_array = pairwise_distances(fin_points, yolk_xyz)
yolk_dist = np.min(dist_array, axis=1)
min_i = np.argmin(dist_array, axis=1)
yolk_signs = np.sign(fin_points[:, 2] - yolk_xyz[min_i, 2])
yolk_dist = -np.multiply(yolk_dist, yolk_signs)

fin_df["yolk_dist"] = yolk_dist

### Calculate fin dimensions at the base and find centerpoint 

In [4]:
# get points near surface
yolk_thresh = 5
base_fin_points = fin_df.loc[np.abs(fin_df["yolk_dist"])<=yolk_thresh, ["XP", "YP", "ZP"]].to_numpy()
base_fin_points_raw = fin_df.loc[np.abs(fin_df["yolk_dist"])<=yolk_thresh, ["X", "Y", "Z"]].to_numpy()

# calculate axis dims. Main one we care about is the AP axis ("YP")
axis_len_vec = np.max(base_fin_points, axis=0) - np.min(base_fin_points, axis=0)

# find centroid
point_center = np.mean(base_fin_points_raw, axis=0)
surf_center_i = np.argmin(np.sqrt(np.sum((yolk_xyz-point_center)**2, axis=1)))
surf_center = yolk_xyz[surf_center_i, :] # this is the one we will use

# define a local DV direction that is the cross product of the surface normal and the AP axis
surf_normal_raw, _ = fin_object.calculate_tangent_plane(fin_object.yolk_surf_params, surf_center)
if surf_normal_raw[2] > 0:
    surf_normal_raw = -surf_normal_raw

# convert the normal vector to the biological axis space
surf_normal = np.matmul(np.reshape(surf_normal_raw, (1, 3)) , fin_axes.T)[0]
surf_normal = surf_normal / np.linalg.norm(surf_normal)

# calculate local DV
dv_vec_base = np.cross(surf_normal, np.asarray([0, 1, 0]))
dv_vec_base = dv_vec_base / np.linalg.norm(dv_vec_base)

# finally, calculate local DV axis dims
dv_vec_loc = np.sum(np.multiply(dv_vec_base[np.newaxis, :], base_fin_points), axis=1)

# get axis lengths
ap_axis_len = axis_len_vec[1]
dv_axis_len = np.max(dv_vec_loc) - np.min(dv_vec_loc)
print(ap_axis_len)
print(dv_axis_len)

103.55529279846348
80.57091076030613


In [5]:
# import plotly.express as px 
# test = np.sum(np.multiply(dv_vec_base[np.newaxis, :], fin_points_pca), axis=1)
# fig = px.scatter_3d(x=fin_points_pca[:, 0], y=fin_points_pca[:, 1], z=fin_points_pca[:, 2], color=test)
# fig.show()

### Load and filter fin+yolk nuclei

In [6]:
# shift centerpoint into the oriented frame of reference
surf_center_o = np.matmul(surf_center - shift_ref_vec, fin_axes.T)

# shift fin+yolk dataset to oriented frame of reference
fin_yolk_df = full_df.loc[np.isin(full_df["fin_label_curr"], [1, 2]), :].reset_index(drop=True)
fin_yolk_points = fin_yolk_df[["X", "Y", "Z"]].to_numpy()
fin_yolk_points_o = np.matmul(fin_yolk_points - shift_ref_vec, fin_axes.T)
fin_yolk_df[["XP", "YP", "ZP"]] = fin_yolk_points_o

### Use AP and DV dims to capture ellipsoidal "cap" at fin base

In [7]:
# calculate yolk distances (again) and use to filter for
params = fin_object.yolk_surf_params

x_min, y_min = fin_yolk_points[:, 0].min(), fin_yolk_points[:, 1].min()
x_max, y_max = fin_yolk_points[:, 0].max(), fin_yolk_points[:, 1].max()

# Create a mesh grid for x and y values
x_vals = np.linspace(x_min, x_max, 250)
y_vals = np.linspace(y_min, y_max, 250)
X, Y = np.meshgrid(x_vals, y_vals)

yolk_xyz2 = np.reshape(fin_object.polyval2d(np.c_[X.ravel(), Y.ravel()], params).ravel(), (-1, 3))

# get nearest neighbor distances
dist_array2 = pairwise_distances(fin_yolk_points, yolk_xyz2)
yolk_dist2 = np.min(dist_array2, axis=1)
min_i2 = np.argmin(dist_array2, axis=1)
yolk_signs2 = np.sign(fin_yolk_points[:, 2] - yolk_xyz2[min_i2, 2])
yolk_dist2 = -np.multiply(yolk_dist2, yolk_signs2)

fin_yolk_df["yolk_dist"] = yolk_dist2

In [8]:
# get points within ellipsoidal boundary
depth_semi_axis = 25
ap_min = 50
dv_min = 30
ap_semi_axis = np.max([ap_axis_len, ap_min]) / 2
dv_semi_axis = np.max([dv_axis_len, dv_min]) / 2

# below the surface
fy_candidate_points = fin_yolk_points_o[yolk_dist2<0, :]
fy_indices = np.where(yolk_dist2<0)[0]
fy_candidate_ids = fin_yolk_df.loc[yolk_dist2<0, "nucleus_id"].to_numpy()
# (x/a)^2 + (y/b)^2 + (z/c)^2 <= 1

# AP piece is trivial
ap = ((fy_candidate_points[:, 1]-surf_center_o[1]) / ap_semi_axis)**2

# DV piece 
# dv_center = np.sum(np.dot(surf_center_o, dv_vec_base))
dv_dist = np.sum(np.multiply(fy_candidate_points-surf_center_o, dv_vec_base[np.newaxis,:]), axis=1)
dv = ((dv_dist) / dv_semi_axis)**2

# depth piece 
# depth_center = np.sum(np.dot(surf_center_o, surf_normal))
depth_dist = np.sum(np.multiply(fy_candidate_points-surf_center_o, surf_normal[np.newaxis,:]), axis=1)
dd = ((depth_dist) / depth_semi_axis)**2

# put it all together to get cap points
cap_flag = ((ap + dv + dd) <= 1)# & (depth_dist<0)
cap_ids = fy_candidate_ids[cap_flag]

fin_yolk_df["fin_cap_flag"] = False
fin_yolk_df.loc[fy_indices[cap_flag], "fin_cap_flag"] = True
print(np.max(depth_dist))
print(np.min(depth_dist))

-0.6485893196675994
-108.6241265893925


In [9]:
import plotly.graph_objects as go 

fin_filter = ((fin_yolk_df["fin_label_curr"]==1) & (fin_yolk_df["yolk_dist"] >= -5)).to_numpy()
cap_filter = (fin_yolk_df["fin_cap_flag"]==1).to_numpy()

fig = go.Figure()
fig.add_trace(go.Scatter3d(x=fin_yolk_df.loc[fin_filter, "XP"], y=fin_yolk_df.loc[fin_filter, "YP"], 
                           z=fin_yolk_df.loc[fin_filter, "ZP"], mode="markers"))
# fig.add_trace(go.Scatter3d(x=fin_yolk_df.loc[cap_filter, "XP"], y=fin_yolk_df.loc[cap_filter, "YP"],
#                            z=fin_yolk_df.loc[cap_filter, "ZP"], mode="markers"))

fig.update_traces(
    marker=dict(color=fin_yolk_df.loc[fin_filter, "yolk_dist"], cmin=-20, cmax=20,
        colorbar=dict(
            title="Color Scale",
            tickformat=".2f",
            len=0.7  # Adjust the length of the colorbar
        ),
        size=6,  # Marker size
        opacity=0.8,  # Marker opacity
    ),
    hovertemplate=(
            "X: %{x}<br>"
            "Y: %{y}<br>"
            "Z: %{z}<br>"
            "Additional Data: %{text}<br>"
            "Color Value: %{marker.color}<extra></extra>"  # Suppresses default trace info
        )
)

fig.show()

In [10]:
# import plotly.graph_objects as go
# # test = np.sum(np.multiply(dv_vec_base[np.newaxis, :], fin_points_pca), axis=1)
# fig = go.Figure() #px.scatter_3d
# fig.add_trace(go.Scatter3d(x=fy_candidate_points[:, 0], y=fy_candidate_points[:, 1], z=fy_candidate_points[:, 2], mode="markers",
#                            opacity=0.1))
# fig.add_trace(go.Scatter3d(x=fy_candidate_points[cap_flag, 0], y=fy_candidate_points[cap_flag, 1],
#                            z=fy_candidate_points[cap_flag, 2], opacity=0.7, mode="markers"))
# # fig.add_trace(go.Scatter3d(x=yolk_xyz2[:, 0], y=yolk_xyz2[:, 1], z=yolk_xyz2[:, 2], mode="markers"))
# # fig.add_trace(go.Scatter3d(x=fin_yolk_points[:, 0], y=fin_yolk_points[:, 1], z=fin_yolk_points[:, 2], mode="markers"))
# fig.show()

### Load, filter, and orient nucleus centroid point cloud

In [11]:
full_df = fin_object.full_point_data

dist_thresh = -5
# fin_df["yolk_dist"] = -fin_df["yolk_dist"]=
fin_df.reset_index(inplace=True, drop=True)
dist_filter = (fin_df["yolk_dist"]>=dist_thresh).to_numpy()
# fin_df_ft = fin_df.loc[dist_filter, :]
# fin_points = fin_df_ft[["X", "Y", "Z"]].to_numpy()
nuclei_to_keep = fin_df.loc[dist_filter, "nucleus_id"].to_numpy()
# # orient to biological axes
# fin_axis_df = fin_object.axis_fin
# fin_axes = fin_object.calculate_axis_array(fin_axis_df)
# fin_points_pca = np.matmul(fin_points - np.mean(fin_points, axis=0), fin_axes.T)
# fin_df_ft.loc[:, ["XP", "YP", "ZP"]] = fin_points_pca

In [12]:
# import plotly.express as px
# fig = px.scatter_3d(fin_df_ft, x="XP", y="YP", z="ZP", color="yolk_dist")
# fig.show()

In [13]:
# sample nucleus boundary points from nucleus masks
fin_df_upsamp = upsample_fin_point_cloud(fin_object, sample_res_um=0.4, root=root, points_per_nucleus=100)
# fin_df_upsamp[["XP", "YP", "ZP"]] = fin_df_upsamp[["XP", "YP", "ZP"]]*0.4



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

100%|██████████| 597/597 [00:20<00:00, 29.69it/s]
100%|██████████| 597/597 [00:01<00:00, 373.02it/s]


In [14]:
import open3d as o3d

# define a third vector that is orthognal to AP and local (base) DV)
surf_vec_rel = np.cross([0, 1, 0], dv_vec_base)
surf_vec_rel = surf_vec_rel / np.linalg.norm(surf_vec_rel)


# get raw points
nc_vec_us = fin_df_upsamp.loc[:, "nucleus_id"].to_numpy().astype(np.uint16)
keep_filter = np.isin(nc_vec_us, nuclei_to_keep)
fin_points = fin_df_upsamp.loc[keep_filter, ["XP", "YP", "ZP"]].to_numpy()

# re-orient points to yolk frrame of reference
rotation_matrix = np.stack([dv_vec_base, np.asarray([0, 1, 0]), surf_vec_rel], axis=1)
# fin_points_ro = np.matmul(fin_points - surf_center_o, new_axes.T)
fin_points_ro = np.dot(fin_points - surf_center_o, rotation_matrix)
fin_points_ro[:, 1] = fin_points_ro[:, 1] - np.mean(fin_points_ro[:, 1])
# convert to point cloud format
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(fin_points_ro)

# resample points to be more spatially uniform
min_distance = 0.5
sampled_points = pcd.voxel_down_sample(voxel_size=min_distance) 

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [15]:
# import plotly.express as px
# fig = px.scatter_3d(fin_df_upsamp, x="XP", y="YP", z="ZP")
# fig.show()

In [16]:
# fit a mesh 
fin_points_u = np.asarray(sampled_points.points)
fin_hull, raw_hull, wt_flag = fit_fin_hull(fin_points_u, alpha=24, n_faces=5000)
print(wt_flag)

True


In [17]:
#igl.read_triangle_mesh("/home/nick/projects/flowshape/demo/ABal.obj")
v, f = v, f = fin_hull.vertices.copy(), fin_hull.faces.copy()
# mp.plot(v, f, shading = {"wireframe":True})
_, lines, mesh = plot_mesh(fin_hull, surf_alpha=1)
fig = go.Figure()
fig.add_trace(mesh)
fig.show()

### Experiment with adding an ellipsoidal cap

In [18]:
import trimesh

# Create a unit sphere (radius = 1.0)
sphere = trimesh.creation.icosphere(radius=1.0)

depth_semi_axis = 10
# Define the semi-axis lengths for the ellipsoid
semi_axes = np.asarray([dv_semi_axis, ap_semi_axis*1.5, depth_semi_axis]) # Example lengths for x, y, z axes
print(semi_axes)
# Apply scaling transformation to create an ellipsoid
scale_matrix = np.diag(semi_axes.tolist() + [1])  # Diagonal scaling matrix for x, y, z (and 1 for homogeneous coordinates)
ellipsoid = sphere.copy()
ellipsoid.apply_transform(scale_matrix)
ellipsoid.vertices = ellipsoid.vertices - 5

[40.28545538 77.6664696  10.        ]


In [19]:
fig, _, _ = plot_mesh(ellipsoid)
fig.update_scenes(aspectmode='data')
fig.show()

In [20]:
combined_mesh = trimesh.boolean.union([fin_hull, ellipsoid])

In [21]:
_, lines, mesh = plot_mesh(combined_mesh, surf_alpha=1)

fig = go.Figure()
fig.add_trace(mesh)
fig.update_scenes(aspectmode='data')
fig.show()

In [45]:
import alphashape
alpha = 24 
smoothing_strength = 5
xyz_fin = fin_points_ro
n_faces = 5000

# normalize for alphshape fitting
mp = np.min(xyz_fin)
points = xyz_fin - mp
mmp = np.max(points)
points = points / mmp

raw_hull = alphashape.alphashape(points, alpha)

raw_hull = trimesh.boolean.union([raw_hull, b_hull])
# copy
hull02_cc = raw_hull.copy()

# keep only largest component
hull02_cc = hull02_cc.split(only_watertight=False)
hull02_sm = max(hull02_cc, key=lambda m: m.area)

# fill holes
hull02_sm.fill_holes()

# smooth
hull02_sm = trimesh.smoothing.filter_laplacian(hull02_sm, iterations=smoothing_strength)

# resample
n_faces = np.min([n_faces, hull02_sm.faces.shape[0]-1])
hull02_rs = hull02_sm.simplify_quadric_decimation(face_count=n_faces)
hull02_rs = hull02_rs.split(only_watertight=False)
hull02_rs = max(hull02_rs, key=lambda m: m.area)
hull02_rs.fill_holes()
hull02_rs.fix_normals()

vt = hull02_rs.vertices
vt = vt * mmp
vt = vt + mp
hull02_rs.vertices = vt

# check
wt_flag = hull02_rs.is_watertight


In [46]:
_, lines, mesh = plot_mesh(hull02_rs, surf_alpha=1)

fig = go.Figure()
fig.add_trace(mesh)
fig.update_scenes(aspectmode='data')
fig.show()

### Alternative idea: take convex full of points below the yolk and use these

In [32]:
# below_ids = fin_df.loc[fin_df["yolk_dist"]<0, "nucleus_id"].to_numpy()
depth_dist = np.sum(np.multiply(fin_df[["XP", "YP", "ZP"]].to_numpy()-surf_center_o, surf_normal[np.newaxis,:]), axis=1)
below_ids = fin_df.loc[depth_dist<=0, "nucleus_id"].to_numpy()
below_filter = np.isin(nc_vec_us, below_ids)
b_points = fin_df_upsamp.loc[below_filter, ["XP", "YP", "ZP"]].to_numpy()

# re-orient points to yolk frrame of reference
rotation_matrix = np.stack([dv_vec_base, np.asarray([0, 1, 0]), surf_vec_rel], axis=1)
# fin_points_ro = np.matmul(fin_points - surf_center_o, new_axes.T)
b_points_ro = np.dot(b_points - surf_center_o, rotation_matrix)
b_points_ro[:, 1] = b_points_ro[:, 1] - np.mean(fin_points_ro[:, 1])
# convert to point cloud format
pcd_b = o3d.geometry.PointCloud()
pcd_b.points = o3d.utility.Vector3dVector(b_points_ro)

b_sampled_points = pcd_b.voxel_down_sample(voxel_size=min_distance) 
b_points_u = np.asarray(b_sampled_points.points)
b_points_u

array([[ 32.23355256,  13.66454742,  -9.49374207],
       [ 25.6585721 ,  10.79448794,  -9.9301423 ],
       [ 29.28935039,  14.00929018, -11.92460741],
       ...,
       [ 10.71889683, -27.67228925,  -0.34156538],
       [  1.68470554, -32.58650128,  -1.25814373],
       [ -5.68975825, -29.64908354,  -1.12727254]])

In [37]:
from scipy.spatial import ConvexHull

# Compute the convex hull
hull = ConvexHull(b_points_u)

# Create a Trimesh mesh from the convex hull
vertices = b_points_u #[hull.vertices]
faces = hull.simplices  # Faces of the convex hull
b_hull = trimesh.Trimesh(vertices=vertices, faces=faces)

b_hull.fix_normals()   
b_hull = b_hull.convex_hull   # Recompute as a watertight convex hull

In [38]:
fig, _, _ = plot_mesh(b_hull)
fig.update_scenes(aspectmode='data')
fig.show()

In [40]:


combined_mesh2 = trimesh.boolean.union([fin_hull, b_hull])

_, lines, mesh0 = plot_mesh(combined_mesh2, surf_alpha=1)
# _, lines, mesh0 = plot_mesh(fin_hull, surf_alpha=1)
# _, lines, mesh1 = plot_mesh(b_hull, surf_alpha=1)

fig = go.Figure()
fig.add_trace(mesh0)
# fig.add_trace(mesh1)
fig.update_scenes(aspectmode='data')
fig.show()

In [49]:
cb_sm = trimesh.smoothing.filter_laplacian(combined_mesh2, iterations=4)

In [50]:
_, lines, mesh0 = plot_mesh(cb_sm, surf_alpha=1)
# _, lines, mesh0 = plot_mesh(fin_hull, surf_alpha=1)
# _, lines, mesh1 = plot_mesh(b_hull, surf_alpha=1)

fig = go.Figure()
fig.add_trace(mesh0)
# fig.add_trace(mesh1)
fig.update_scenes(aspectmode='data')
fig.show()

### Run spherical mapping

In [9]:
# normalize the scaling of the mesh
v = fs.normalize(v)

# run the spherical mapping flow and mobius centering
sv = fs.sphere_map(v, f)

# Now we have a spherical mesh
mp.plot(sv, f, shading = {"wireframe":True})

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.000453…

<meshplot.Viewer.Viewer at 0x79ab9a54a830>

### Calculate the mean curvature

In [10]:
rho = fs.curvature_function(v, sv, f)

mp.plot(v,f, rho )

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0033084…

<meshplot.Viewer.Viewer at 0x79ab9a51d480>

## SH decomposition

In [11]:
# this utility does the above steps + SH decomposition
# Here, using maximum degree 24
weights, Y_mat, vs = fs.do_mapping(v,f, l_max = 24)

In [12]:
# This is the array of SH weights
np.set_printoptions(threshold = 100)
print(weights)

[ 3.1985965  -0.14189246 -0.09633128 ...  0.18525062 -0.10537151
  0.13012568]


In [13]:
# Y_mat is a matrix used to convert between weights and mesh function
rho2 = Y_mat.dot(weights)
mp.plot(sv,f, c = rho2)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.000453…

<meshplot.Viewer.Viewer at 0x79ab9a65e110>

### Mesh reconstruction

In [14]:
rec2 = fs.reconstruct_shape(sv, f, rho2 )
mp.plot(rec2,f, c = rho2)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.050321…

<meshplot.Viewer.Viewer at 0x79ab9a4f3400>

### Test with lower frequencies only

In [18]:
weights, Y_mat, vs = fs.do_mapping(v,f, l_max=4)
rec_8 = fs.reconstruct_shape(sv, f, Y_mat.dot(weights) )
mp.plot(rec_8, f, c = rho2)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.078880…

<meshplot.Viewer.Viewer at 0x79ab9a3d5810>