# Notebook to experiment with ways to quantify pec fin morphology

### Import packages

In [29]:
import numpy as np
from napari_animation import Animation
import pandas as pd

# set parameters
filename = "2022_12_15 HCR Hand2 Tbx5a Fgf10a_3"
# NOTE(review): absolute, machine-specific location; kept verbatim so the load below is unchanged.
readPath = (
    "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/built_zarr_files/"
    f"{filename}_nucleus_props.csv"
)

# table of nucleus properties; the first CSV column is used as the index
image_df = pd.read_csv(readPath, index_col=0)


### Now, load an image dataset along with nucleus masks inferred using cellpose.

In [None]:
#import open3d as o3d
import plotly.express as px
import plotly.graph_objects as go
from sklearn.cluster import KMeans, DBSCAN, OPTICS
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(np.asarray(image_df.iloc[:, 0:3]))
# Get the points flagged as pectoral fin (first three columns) and transform them to a numpy array.
# np.where returns a 1-tuple of index arrays; it must be unpacked with [0] before being combined
# with a column slice, because .iloc does not accept a tuple nested inside a two-axis key.
points = np.asarray(image_df.iloc[np.where(image_df["pec_fin_flag"] == 1)[0], 0:3]).copy()

# convert to point cloud

## Let's see what the raw data look like

In [None]:
# distance of every nucleus from the coordinate origin
sq_norm = image_df["X"]**2 + image_df["Y"]**2 + image_df["Z"]**2
r_vec = np.sqrt(sq_norm)

# 3D scatter of all nuclei, coloured by Z
px.scatter_3d(
    image_df,
    x="X",
    y="Y",
    z="Z",
    color="Z",
    color_continuous_scale="ice",
    opacity=0.5,
    template="ggplot2",
)


#### Use Z to filter for points that lie on yolk surface (crude, but we will do better in future)

In [None]:
# Candidate yolk-surface nuclei: NOT flagged as pectoral fin and at or below the Z cutoff.
# The flag is cast to bool before negation: on an integer 0/1 column `~` is a bitwise NOT
# (yielding -1/-2) and on a float column it raises, so the cast makes this a true logical NOT
# whatever dtype the CSV reader inferred.
z_filter = np.where((~image_df["pec_fin_flag"].astype(bool)) & (image_df["Z"]<=35))

image_df_z = image_df.iloc[z_filter].copy()

# NOTE(review): no `color` column is passed here, so `color_continuous_scale` presumably has
# no visible effect -- confirm whether color="Z" was intended.
px.scatter_3d(image_df_z, x="X", y="Y", z="Z",opacity=0.5,color_continuous_scale="ice",
              template="ggplot2")

### Fit sphere

In [None]:
# define a function to fit a sphere to points
import math
def sphereFit(spX,spY,spZ):
    """Least-squares fit of a sphere to a set of 3D points.

    Solves the linear system ``2*x0*x + 2*y0*y + 2*z0*z + c = x^2 + y^2 + z^2``
    for ``(x0, y0, z0, c)``; the radius then follows from
    ``r^2 = x0^2 + y0^2 + z0^2 + c``.

    Parameters
    ----------
    spX, spY, spZ : array-like
        Equal-length 1D sequences holding the point coordinates.

    Returns
    -------
    radius : float
        Fitted sphere radius.
    x0, y0, z0 : numpy.ndarray
        Fitted centre coordinates, each a length-1 array (one row of the
        (4, 1) least-squares solution).

    Raises
    ------
    ValueError
        Raised by ``math.sqrt`` if the fitted squared radius is negative.
    """
    #   Assemble the A matrix
    spX = np.array(spX)
    spY = np.array(spY)
    spZ = np.array(spZ)
    A = np.zeros((len(spX),4))
    A[:,0] = spX*2
    A[:,1] = spY*2
    A[:,2] = spZ*2
    A[:,3] = 1

    #   Assemble the f matrix
    f = np.zeros((len(spX),1))
    f[:,0] = (spX*spX) + (spY*spY) + (spZ*spZ)
    # rcond=None selects the machine-precision default explicitly and avoids the FutureWarning
    # that older NumPy versions emit when the argument is omitted
    C, _residuals, _rank, _singval = np.linalg.lstsq(A, f, rcond=None)

    #   solve for the radius
    t = (C[0]*C[0])+(C[1]*C[1])+(C[2]*C[2])+C[3]
    # t is a length-1 array; take its element explicitly instead of relying on the deprecated
    # implicit array-to-scalar conversion inside math.sqrt
    radius = math.sqrt(float(t[0]))

    return radius, C[0], C[1], C[2]

In [None]:
# fit a sphere to the Z-filtered (putative yolk-surface) nuclei
r, x0, y0, z0 = sphereFit(image_df_z["X"], image_df_z["Y"], image_df_z["Z"])

# sample the fitted sphere on a regular angular grid and shift it to the fitted centre
u, v = np.mgrid[0:2*np.pi:200j, 0:np.pi:100j]
x = np.cos(u)*np.sin(v)*r + x0
y = np.sin(u)*np.sin(v)*r + y0
z = np.cos(v)*r + z0

# positional indices of the rows flagged as pectoral fin
p_filter = np.where(image_df["pec_fin_flag"])

# plot result
fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, opacity=0.8)])

# fig = px.scatter_3d(x=XX, y=YY, z=ZZ,opacity=0.5)

fig.update_layout(
    scene=dict(
        xaxis=dict(nticks=10, range=[0, 200]),
        yaxis=dict(nticks=10, range=[40, 160]),
        zaxis=dict(nticks=10, range=[0, 100]),
    )
)

# overlay the fin-flagged nuclei on the fitted sphere
fig.add_trace(
    go.Scatter3d(
        x=image_df["X"].iloc[p_filter],
        y=image_df["Y"].iloc[p_filter],
        z=image_df["Z"].iloc[p_filter],
        mode='markers',
        opacity=0.25,
    )
)

### Calculate distance from sphere surface

In [None]:
from scipy.spatial import distance_matrix
from sklearn.neighbors import KDTree

# calculate approximate distance to surface:
# stack the sampled sphere grid into an (n_grid_points, 3) array of xyz coordinates
sphere_surf_array = np.column_stack((x.flatten(), y.flatten(), z.flatten()))

# distance from every fin-flagged nucleus to every sampled sphere point;
# keep the smallest distance and the index of the sphere point that achieves it
image_df_fin = image_df.iloc[p_filter]
dist_array = distance_matrix(image_df_fin[["X", "Y", "Z"]], sphere_surf_array)
min_distance = dist_array.min(axis=1)
min_indices = dist_array.argmin(axis=1)

# use NN distances between nuclei to set the scale for points that are "close" to surface
dist_ref_k = 2 
tree = KDTree(image_df_fin[["X", "Y", "Z"]], leaf_size=2)
# k = dist_ref_k + 1 because each point is queried against the tree it was built from
nearest_dist, nearest_ind = tree.query(image_df_fin[["X", "Y", "Z"]], k=dist_ref_k+1)
mean_nn_dist_vec = nearest_dist.mean(axis=0)
nn_thresh = mean_nn_dist_vec[dist_ref_k]

# plot the results: nuclei within nn_thresh of the sampled sphere surface
surf_filter = min_distance <= nn_thresh
surf_nuclei = image_df_fin.iloc[surf_filter]

fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, opacity=0.8)])

# fig = px.scatter_3d(x=XX, y=YY, z=ZZ,opacity=0.5)

fig.update_layout(
    scene=dict(
        xaxis=dict(nticks=10, range=[0, 200]),
        yaxis=dict(nticks=10, range=[40, 160]),
        zaxis=dict(nticks=10, range=[0, 100]),
    )
)

fig.add_trace(
    go.Scatter3d(
        x=surf_nuclei["X"],
        y=surf_nuclei["Y"],
        z=surf_nuclei["Z"],
        mode='markers',
        opacity=0.25,
    )
)

**That looks reasonable!** Let's see if I can iteratively search for the centroid of these surface points. Following closely along with the approach outlined in this post: https://skeptric.com/calculate-centroid-on-sphere/


In [None]:
import scipy

def cart2sph(x,y,z):
    """Convert one Cartesian point to spherical coordinates.

    Parameters
    ----------
    x, y, z : float
        Cartesian coordinates of a single point.

    Returns
    -------
    tuple of float
        ``(r, elev, az)``: radius, elevation angle measured from the xy-plane,
        and azimuth angle in the xy-plane (both in radians).
    """
    # the module is imported as `math`; the previous `m.` prefix was an undefined name
    XsqPlusYsq = x**2 + y**2
    r = math.sqrt(XsqPlusYsq + z**2)                # r
    elev = math.atan2(z, math.sqrt(XsqPlusYsq))     # theta
    az = math.atan2(y, x)                           # phi
    return r, elev, az

def cart2sphA(pts):
    """Apply cart2sph to every (x, y, z) triple in `pts`.

    Returns an array with one (r, elev, az) row per input point.
    """
    spherical = []
    for px_val, py_val, pz_val in pts:
        spherical.append(cart2sph(px_val, py_val, pz_val))
    return np.array(spherical)

# define some useful functions
def geodist(x, y, eps=1e-6):
    """Angle (radians) whose cosine is the dot product ``y.T @ x``.

    The dot product must lie within ``eps`` of the interval [-1, 1]
    (otherwise an AssertionError is raised); it is then clipped to
    [-1, 1] before the arccosine is taken.
    """
    cos_angle = y.T @ x
    within_tolerance = (cos_angle >= (-1 - eps)).all() and (cos_angle <= (1 + eps)).all()
    assert within_tolerance
    return np.arccos(cos_angle.clip(-1, 1))

def coord_to_latlon(x, y, z, r):
    """Convert points on an origin-centred sphere of radius `r` to latitude/longitude.

    Asserts that every point satisfies ``|x^2 + y^2 + z^2 - r^2| < 1e-5``.
    Returns the transpose of ``[lat, lon]`` (radians), i.e. one (lat, lon)
    row per point when x, y, z are 1D arrays.
    """
    sq_radius_error = np.abs((x*x + y*y + z*z) - r*r)
    assert np.all(sq_radius_error < 1e-5)
    lon = np.arctan2(y/r, x/r)
    lat = np.arcsin(z/r)
    return np.array([lat, lon]).T

# get list of each nucleus's nearest point on the sphere's surface
# (only for nuclei that passed the near-surface filter)
sphere_nn_xyz = sphere_surf_array[min_indices[surf_filter], :]

# note that we must center and normalize array for this calculation to work:
# subtract the fitted centre here; the radius is handed to the converter below
sphere_center_row = np.asarray([x0, y0, z0]).reshape(1, 3)
sphere_nn_xyz_centered = sphere_nn_xyz - sphere_center_row

# convert to lat/lon notation (columns of the centred array are x, y, z)
lat_lon_array = coord_to_latlon(*sphere_nn_xyz_centered.T, r)

def latlon_geodist(llarray1, llarray2, eps = 1e-6):
    """Great-circle angle (radians) between two (lat, lon) points given in radians.

    Element 0 of each argument is the latitude and element 1 the longitude.
    The cosine of the angle must lie within ``eps`` of [-1, 1] (AssertionError
    otherwise) and is clipped to that interval before the arccosine.
    """
    lat1, lon1 = llarray1[0], llarray1[1]
    lat2, lon2 = llarray2[0], llarray2[1]

    cos_angle = np.sin(lat1) * np.sin(lat2) + np.cos(lat1) * np.cos(lat2) * np.cos(lon2 - lon1)
    within_tolerance = (cos_angle >= (-1 - eps)).all() and (cos_angle <= (1 + eps)).all()
    assert within_tolerance
    return np.arccos(cos_angle.clip(-1, 1))

def latlon_total_dist(point, ll_array=lat_lon_array):
    """Squared great-circle angle from `point` to every row of `ll_array`.

    `point` is a (lat, lon) pair and each row of `ll_array` is a (lat, lon)
    pair, all in radians. Returns a 1D array with one squared angle per row.

    NOTE(review): this notebook passes the function to
    scipy.optimize.least_squares as the residual function, and that routine
    squares the residuals itself -- so the minimised quantity is presumably
    the sum of fourth powers of the angles. Confirm that is intended.
    """
    angles = []
    for lat_lon in ll_array:
        angles.append(latlon_geodist(point, lat_lon))
    return np.array(angles)**2

# now let's find the point that minimizes latlon distances
# The optimisation starts from a randomly chosen surface point. A fixed seed keeps that start
# point the same on every Restart & Run All, so the fit is reproducible.
start_rng = np.random.default_rng(0)
start_i = start_rng.integers(0, lat_lon_array.shape[0])
start_point = lat_lon_array[start_i, :]

# plain average of the lat/lon values, printed for comparison with the optimised solution
print(np.mean(lat_lon_array, axis=0))

test = scipy.optimize.least_squares(latlon_total_dist, np.array(start_point))

print(test.x)
# print(np.mean(lat_lon_array, axis=0))

# convert solution back to xyz
def latlon_to_coord(lat, lon):
    """Cartesian (x, y, z) point on the unit sphere for a latitude/longitude in radians."""
    cos_lat = np.cos(lat)
    x_unit = cos_lat * np.cos(lon)
    y_unit = cos_lat * np.sin(lon)
    z_unit = np.sin(lat)
    return np.array([x_unit, y_unit, z_unit])

# unit-sphere solution, then scaled by the fitted radius and shifted to the fitted centre
sol_lat, sol_lon = test.x
xyz_sol_norm = latlon_to_coord(sol_lat, sol_lon)
center_vec = np.transpose(np.asarray([x0, y0, z0]))
xyz_sol = r * xyz_sol_norm.T + center_vec
print(xyz_sol_norm)

fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, opacity=0.5)])

# fig = px.scatter_3d(x=XX, y=YY, z=ZZ, opacity=0.5)

# fig.update_layout(
#     scene = dict(zaxis = dict(nticks=10, range=[0,100]), xaxis = dict(nticks=10, range=[0,200]), 
#                  yaxis = dict(nticks=10, range=[40,160])))

# near-surface nuclei
surf_nuclei = image_df_fin.iloc[surf_filter]
fig.add_trace(
    go.Scatter3d(
        x=surf_nuclei["X"],
        y=surf_nuclei["Y"],
        z=surf_nuclei["Z"],
        mode='markers',
        opacity=0.25,
    )
)

# fitted centroid on the sphere surface
fig.add_trace(
    go.Scatter3d(
        x=[xyz_sol[0,0]],
        y=[xyz_sol[0,1]],
        z=[xyz_sol[0,2]],
        mode='markers',
        opacity=1,
    )
)


### Nice! Now plot normal vector implied by this fit alongside full pec fin

In [None]:
# two end points of the normal segment: the centroid on the sphere, and the same direction
# from the sphere centre extended to 1.5x the centre-to-centroid offset
normal_tip = 1.5*(xyz_sol-center_vec)+center_vec
norm_array = np.vstack((xyz_sol, normal_tip))

fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, opacity=0.3)])


# fig.update_layout(
#     scene = dict(zaxis = dict(nticks=10, range=[0,100]), xaxis = dict(nticks=10, range=[0,200]), 
#                  yaxis = dict(nticks=10, range=[40,160])))

# all fin-flagged nuclei
fig.add_trace(
    go.Scatter3d(
        x=image_df_fin["X"],
        y=image_df_fin["Y"],
        z=image_df_fin["Z"],
        mode='markers',
        opacity=0.05,
    )
)

# normal segment
fig.add_trace(
    go.Scatter3d(
        x=norm_array[:, 0],
        y=norm_array[:, 1],
        z=norm_array[:, 2],
        mode='lines',
        opacity=1,
    )
)

# centroid marker
fig.add_trace(
    go.Scatter3d(
        x=[xyz_sol[0,0]],
        y=[xyz_sol[0,1]],
        z=[xyz_sol[0,2]],
        mode='markers',
        opacity=1,
    )
)

## Can we build off of this procedure to infer an approximate PD axis?
First, let's use the normal axis inferred in the previous step to divide the fin up into N cross-sections

In [None]:
# start by finding the point below the centroid on the surface of the sphere
normal_vector = norm_array[1, :] - norm_array[0, :]
normal_vector_u = normal_vector / np.sqrt(np.sum(normal_vector**2))
surf_point = xyz_sol

# generate numpy array 
xyz_array = np.asarray(image_df_fin[["X", "Y", "Z"]])

# solve for initial D value (plane through the base centroid with the unit normal above)
d_surf = -(np.dot(normal_vector_u, surf_point.T))

# signed offset of every nucleus from that base plane
ab_signs = np.matmul(normal_vector_u, xyz_array.T) + d_surf

# plane offset D for the plane through each nucleus; the extremes bound the fin along the normal
d_vec = -np.matmul(normal_vector_u, xyz_array.T)
d_top = np.min(d_vec)
d_bottom = np.max(d_vec)

# check that everything is contained
ab_signs2 = np.matmul(normal_vector_u, xyz_array.T) + d_top
print(np.mean(ab_signs2<=0))

# divide nuclei into N slices
# Boundaries run d_bottom, d_surf, ..., d_top: slice 0 holds nuclei between d_bottom and d_surf,
# the remaining slices split the range d_surf..d_top evenly.
n_slices = 8
d_increments = np.linspace(d_surf, d_top, n_slices)
# np.insert without an axis flattens its input, so the result is 1D with n_slices + 1 entries
d_increments = np.insert(d_increments, 0, d_bottom)
slice_ids = []
for n in range(n_slices):
    d_start = d_increments[n]
    d_stop = d_increments[n+1]
    if n == n_slices - 1:
        # close the last interval at d_top: with a strict ">" the nucleus that defines d_top
        # (d_vec == d_top) would not be assigned to any slice
        above_stop = d_vec >= d_stop
    else:
        above_stop = d_vec > d_stop
    slice_ids.append(np.where(above_stop & (d_vec<=(d_start))))
    

# fitted sphere plus one marker trace per slice
fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, opacity=0.3)])

fig.update_layout(
    scene=dict(
        xaxis=dict(nticks=10, range=[0, 200]),
        yaxis=dict(nticks=10, range=[40, 160]),
        zaxis=dict(nticks=10, range=[0, 100]),
    )
)

for n in range(n_slices):
    slice_nuclei = image_df_fin.iloc[slice_ids[n]]
    fig.add_trace(
        go.Scatter3d(
            x=slice_nuclei["X"],
            y=slice_nuclei["Y"],
            z=slice_nuclei["Z"],
            mode='markers',
            opacity=1,
        )
    )

fig.update_layout(showlegend=False)

fig.show()

**This looks reasonably promising.**
Let's calculate centroids for each slice. These can then serve as the basis for defining an approximate axis. From there, the hope is to iteratively update slices and centroids 

In [None]:
# one centroid per slice, plus a final row reserved for the tip point
centroid_array = np.empty((n_slices+1, 3))
for n in range(n_slices):
    centroid_array[n, :] = xyz_array[slice_ids[n]].mean(axis=0)

# find point that is farthest from base and call this the tip (for now)
dist_vec = np.sqrt(np.sum((xyz_array - xyz_sol)**2, axis=1))
# nuclei with d_vec <= 0.75*d_top are excluded from the search by zeroing their distance
low_indices = np.where(d_vec<=0.75*d_top)
dist_vec[low_indices] = 0
d_top_i = np.argmax(dist_vec)
centroid_array[-1, :] = xyz_array[d_top_i, :]

# plot centroids
fig = go.Figure(data=[go.Surface(x=x, y=y, z=z, opacity=0.3)])

fig.update_layout(
    scene=dict(
        xaxis=dict(nticks=10, range=[-20, 200]),
        yaxis=dict(nticks=10, range=[40, 200]),
        zaxis=dict(nticks=10, range=[0, 100]),
    )
)

# centroid path drawn twice: once as a line, once as markers
for trace_mode in ('lines', 'markers'):
    fig.add_trace(
        go.Scatter3d(
            x=centroid_array[:, 0],
            y=centroid_array[:, 1],
            z=centroid_array[:, 2],
            mode=trace_mode,
            opacity=1,
        )
    )

# translucent hull around all fin nuclei
fig.add_trace(
    go.Mesh3d(
        x=xyz_array[:, 0],
        y=xyz_array[:, 1],
        z=xyz_array[:, 2],
        alphahull=9,
        opacity=0.1,
        color='gray',
    )
)

fig.show()

**Calculate the first principal component (PCA direction) for each slice.** I will take this to be the approximate AP direction. Along each of these AP axes, perform a procedure similar to the one above: find the centroid of each little "chunk".

In [None]:
n_slices_ap = 5
ap_slice_master = []
slice_centroid_array = np.empty((n_slices, 3, n_slices_ap))

# calculate approximate A-P axis
ap_array = np.empty((n_slices, 3))

# The sign of a principal component is arbitrary. Each slice's component is flipped, if needed,
# to point the same way as the first finite one; otherwise opposite-signed components would
# partly cancel when averaged below.
ap_reference = None

for n in range(n_slices):
    datapoints = xyz_array[slice_ids[n]]

    # Do an SVD on the mean-centered data.
    datamean = np.mean(datapoints, axis=0)
    uu, dd, vv = np.linalg.svd(datapoints - datamean)

    # vv[0] is the first principal component of this slice
    slice_axis = vv[0]
    if ap_reference is None:
        if np.all(np.isfinite(slice_axis)):
            ap_reference = slice_axis
    elif np.dot(slice_axis, ap_reference) < 0:
        slice_axis = -slice_axis

    ap_array[n, :] = slice_axis
    
    
ap_axis_mean = np.mean(ap_array, axis=0)
ap_axis_mean = ap_axis_mean / np.linalg.norm(ap_axis_mean)

for n in range(n_slices):

    # Now vv[0] contains the first principal component, i.e. the direction
    # vector of the 'best fit' line in the least squares sense.
    # print(vv[0])
    datapoints = xyz_array[slice_ids[n]]
    
    # calculate the min and maximum d value along the mean AP axis
    d_vec_ap = -np.matmul(ap_axis_mean, datapoints.T)
    
    d_a = np.min(d_vec_ap)
    d_p = np.max(d_vec_ap)

    d_inc_ap = np.linspace(d_a, d_p, n_slices_ap+1)

    ap_slice_ids = []
    bins = np.digitize(d_vec_ap, d_inc_ap)
    # np.digitize uses half-open bins, so points sitting exactly on the last edge (d_p) get
    # index n_slices_ap+1; fold them into the last chunk so that no nucleus is dropped
    bins = np.minimum(bins, n_slices_ap)
    for m in range(n_slices_ap):
        # add points
        ap_slice_ids.append(np.where(bins == m+1))

        # calculate centroid
        centroid_temp = np.mean(datapoints[ap_slice_ids[m]], axis=0)
        slice_centroid_array[n, :, m] = centroid_temp
        
    ap_slice_master.append(ap_slice_ids)

  

# faint scatter of all fin nuclei as a backdrop
fig = px.scatter_3d(
    image_df_fin,
    x="X",
    y="Y",
    z="Z",
    opacity=0.1,
    template="ggplot2",
)

# fig.add_trace(go.Scatter3d(x=datapoints[:, 0], 
#                                y=datapoints[:, 1], 
#                                z=datapoints[:, 2], 
#                                mode='markers', 
#                                opacity=1))

# flattened centroid coordinates of every slice except the last one, with the tip point appended
x_vec = np.append(slice_centroid_array[:-1, 0, :].flatten(), xyz_array[d_top_i, 0])
y_vec = np.append(slice_centroid_array[:-1, 1, :].flatten(), xyz_array[d_top_i, 1])
z_vec = np.append(slice_centroid_array[:-1, 2, :].flatten(), xyz_array[d_top_i, 2])

# one marker trace per slice showing its AP-chunk centroids
for n in range(n_slices):
    fig.add_trace(
        go.Scatter3d(
            x=slice_centroid_array[n, 0, :].flatten(),
            y=slice_centroid_array[n, 1, :].flatten(),
            z=slice_centroid_array[n, 2, :].flatten(),
            mode='markers',
            opacity=1,
        )
    )

# tip point (last entry of the coordinate vectors)
fig.add_trace(
    go.Scatter3d(
        x=[x_vec[-1]],
        y=[y_vec[-1]],
        z=[z_vec[-1]],
        mode='markers',
        opacity=1,
    )
)

# fig.add_trace(go.Scatter3d(x=linepts[:, 0], 
#                            y=linepts[:, 1], 
#                            z=linepts[:, 2], 
#                            mode='lines', 
#                            opacity=1))

fig.show()

In [None]:
# first right-singular vector left over from the most recent SVD
print(vv[0])
# x-coordinates of the AP-chunk centroids for slice 0 and for slice `n`
# (`n` is whatever value the last executed loop left behind)
for slice_idx in (0, n):
    print(slice_centroid_array[slice_idx, 0, :].flatten())

**Not amazing, but it's a good start.** What does a fit to this look like?

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
import alphashape
import scipy as sp
import scipy.interpolate

from sklearn.decomposition import PCA

# Run The PCA on the centroid coordinates collected in x_vec / y_vec / z_vec
pca = PCA(n_components=3)
xyz_grid = np.column_stack((x_vec, y_vec, z_vec))
pca_grid = pca.fit(xyz_grid).transform(xyz_grid)

# transform all fin points
pca_array_full = pca.transform(xyz_array)

# fit spline in pca space: third component as a function of the first two
pca_spline = sp.interpolate.Rbf(*pca_grid.T, function='linear', smooth=1)

# regular grid spanning the range of all fin points in the first two components
grid_res = 100
pc0_axis = np.linspace(pca_array_full[:, 0].min(), pca_array_full[:, 0].max(), grid_res)
pc1_axis = np.linspace(pca_array_full[:, 1].min(), pca_array_full[:, 1].max(), grid_res)
P0, P1 = np.meshgrid(pc0_axis, pc1_axis)

P2_spline = pca_spline(P0,P1)

# fit quadratic surface, just for kicks
# design-matrix columns: 1, p0, p1, p0*p1, p0^2, p1^2
pc0, pc1 = pca_grid[:, 0], pca_grid[:, 1]
A = np.column_stack((np.ones(pca_grid.shape[0]), pc0, pc1, pc0*pc1, pc0**2, pc1**2))
C = scipy.linalg.lstsq(A, pca_grid[:,2])[0]

# evaluate it on a grid
PP0 = P0.flatten()
PP1 = P1.flatten()
quad_terms = np.column_stack((np.ones(PP0.shape), PP0, PP1, PP0*PP1, PP0**2, PP1**2))
P2_curve = np.dot(quad_terms, C).reshape(P0.shape)

# transform the result back to xyz
xyz_fit_spline = pca.inverse_transform(np.column_stack((PP0, PP1, P2_spline.flatten())))
xyz_fit_curve = pca.inverse_transform(np.column_stack((PP0, PP1, P2_curve.flatten())))

# grid-shaped coordinate arrays; X and Y are taken from the quadratic fit
X_fit = xyz_fit_curve[:, 0].reshape(P0.shape)
Y_fit = xyz_fit_curve[:, 1].reshape(P1.shape)
Z_fit_spline = xyz_fit_spline[:, 2].reshape(P2_spline.shape)
Z_fit_curve = xyz_fit_curve[:, 2].reshape(P2_curve.shape)


############
# keep only points that are inside the fin
# every axis is divided by its own maximum so that the alpha shape and the query points
# are expressed in the same normalized coordinates
xyz_axis_max = np.asarray([np.max(xyz_array[:, 0]),
                           np.max(xyz_array[:, 1]),
                           np.max(xyz_array[:, 2])])
xyz_array_norm = np.divide(xyz_array, xyz_axis_max)

# generate normalized arrays
X_norm = X_fit / xyz_axis_max[0]
Y_norm = Y_fit / xyz_axis_max[1]
Z_norm = Z_fit_curve / xyz_axis_max[2]

# alpha shape of the normalized fin nuclei (alpha parameter 4)
alpha_fin = alphashape.alphashape(xyz_array_norm, 4)

# fitted-surface grid as an (n_grid_points, 3) array of normalized coordinates
xyz_long = np.concatenate((np.reshape(X_norm, (X_norm.size, 1)),
                           np.reshape(Y_norm, (X_norm.size, 1)),
                           np.reshape(Z_norm, (X_norm.size, 1))),
                          axis=1)

# NOTE(review): assumes the object returned by alphashape exposes a vectorised
# `contains(points)` that returns one boolean per row -- confirm for the installed version
inside_flags = alpha_fin.contains(xyz_long)
inside_mat = np.reshape(inside_flags,(P0.shape))

# Mask grid points outside the fin. Work on a copy: assigning NaN through a plain alias would
# also overwrite Z_fit_curve (and any array it is a view of).
Z_fit_curve_filt = Z_fit_curve.copy()
Z_fit_curve_filt[np.where(~inside_mat)] = np.nan


fig = go.Figure(data=[go.Surface(x=X_fit, y=Y_fit, z=Z_fit_curve_filt, opacity=0.75)])
# fig.add_trace(go.Scatter3d(x=x_vec, y=y_vec, z=z_vec, 
#                            mode='markers', 
#                            opacity=1))

fig.add_trace(go.Mesh3d(x=xyz_array[:, 0], y=xyz_array[:, 1], z=xyz_array[:, 2],
                                    alphahull=9,
                                    opacity=0.25,
                                    color='gray'))

fig.show()

fig.write_html("surface_fit.html")

**Now let's see if I can define a principal P-D axis.**

In [None]:
# Use the sphere to define a plane that can serve as the center PD "spine" that defines the A-P coordinates

# # get just points inside fin
# xyz_fit_fin = xyz_fit_curve[np.where(inside_flags==1)]

# # define function to calculate point displacement from proposed plane
# def calculate_ap_plane_diffs(vec_prop, sphere_norm_vec=normal_vector_u, base_point=xyz_sol[0], xyz_points=xyz_array):
    
#     normal_vec = np.cross(vec_prop, sphere_norm_vec)
#     normal_vec = normal_vec/np.sqrt(np.sum(normal_vec**2))
    
#     D = -np.dot(base_point, normal_vec.T) 
    
#     # calculate the min and maximum d value
#     d_vec_ap = np.matmul(normal_vec, xyz_points.T) + D
    
#     return d_vec_ap


# def loss_fun(vec_prop):
    
#     ap_diffs = calculate_ap_plane_diffs(vec_prop)
    
#     return ap_diffs

# # find the best-fitting AP normal vector
# init_guess = np.cross(ap_axis_mean, normal_vector_u)
# out = scipy.optimize.least_squares(loss_fun, init_guess)
# d_vec_ap_centered = calculate_ap_plane_diffs(out.x)

# # val_vec = np.empty((xyz_fit_curve.shape[0]))
# # val_vec[np.where(inside_flags==1)] = d_vec_ap_centered
# # val_vec[np.where(inside_flags==0)] = np.nan
# # val_array = np.reshape(val_vec, (X_fit.shape))

# # generate plane to check solution
# grid_res = 100
# X, Y = np.meshgrid(np.linspace(np.min(xyz_array[:,0]), np.max(xyz_array[:,0]), grid_res), 
#                    np.linspace(np.min(xyz_array[:,1]), np.max(xyz_array[:,1]), grid_res))

# # calculate normal
# normal_vec = np.cross(out.x, normal_vector_u)
# normal_vec = normal_vec/np.sqrt(np.sum(normal_vec**2))

# # calculate D
# D = -np.dot(xyz_sol[0], normal_vec.T) 

# # predict z points
# Z_pd = -(normal_vec[0]*X + normal_vec[1]*Y + D)/normal_vec[2]

# fig = go.Figure(data=[go.Surface(x=X_fit, y=Y_fit, z=Z_fit_curve_filt, 
#                                  opacity=0.85)])

# # fig = px.scatter_3d(x=xyz_fit_fin[:, 0], y=xyz_fit_fin[:, 1], z=xyz_fit_fin[:, 2],
# #                     opacity=0.5,template="ggplot2") 

# fig.add_trace(go.Surface(x=x, y=y, z=z, opacity=0.1))

# # fig.add_trace(go.Surface(x=X, y=Y, z=Z_pd, opacity=1))

# fig.add_trace(go.Scatter3d(x=[xyz_sol[0][0]], 
#                            y=[xyz_sol[0][1]], 
#                            z=[xyz_sol[0][2]], 
#                            mode='markers', 
#                            opacity=1))


# # fig.update_layout(
# #     scene = dict(zaxis = dict(nticks=10, range=[0,100]), xaxis = dict(nticks=10, range=[50,160]), 
# #                  yaxis = dict(nticks=10, range=[50,160])))
          
# fig.show()

**Least squares fitting did not work.** So, instead, I will define the AP plane as the plane containing (i) the sphere normal and (ii) the vector normal to the fin surface at the base centroid point

In [None]:
# get just points inside fin
xyz_fit_fin = xyz_fit_curve[inside_flags == 1]

# Find a normal vector from the surface that meets the centroid point (closest point)
distance_list = np.sqrt(np.sum((xyz_sol[0]-xyz_fit_fin)**2, axis=1))
min_i = np.argmin(distance_list)
min_dist = distance_list[min_i]
print(min_dist)

# the vector from the surface point to the centroid defines a normal vector from surface to point
closest_surf_point = xyz_fit_fin[min_i, :]
slice_vec = xyz_sol[0] - closest_surf_point
print(slice_vec)

# we can combine this with the normal vec from the sphere to define an AP plane
plane_normal = np.cross(slice_vec, normal_vector_u)
plane_normal_u = plane_normal / np.sqrt(np.sum(plane_normal**2))

# plane offset chosen so that the plane passes through the closest surface point
D_plane = -np.dot(closest_surf_point, plane_normal_u.T) 

# generate plane to check solution
grid_res = 100
x_axis = np.linspace(xyz_array[:, 0].min(), xyz_array[:, 0].max(), grid_res)
y_axis = np.linspace(xyz_array[:, 1].min(), xyz_array[:, 1].max(), grid_res)
X, Y = np.meshgrid(x_axis, y_axis)

# predict z points by solving the plane equation for z
Z_pd = -(plane_normal_u[0]*X + plane_normal_u[1]*Y + D_plane)/plane_normal_u[2]

fig = go.Figure(
    data=[go.Surface(x=X_fit, y=Y_fit, z=Z_fit_curve_filt, opacity=0.85)]
)

# fig = px.scatter_3d(x=xyz_fit_fin[:, 0], y=xyz_fit_fin[:, 1], z=xyz_fit_fin[:, 2],
#                     opacity=0.5,template="ggplot2") 

# fitted sphere (faint) and the candidate AP plane
fig.add_trace(go.Surface(x=x, y=y, z=z, opacity=0.1))
fig.add_trace(go.Surface(x=X, y=Y, z=Z_pd, opacity=1))

# base centroid marker
fig.add_trace(
    go.Scatter3d(
        x=[xyz_sol[0][0]],
        y=[xyz_sol[0][1]],
        z=[xyz_sol[0][2]],
        mode='markers',
        opacity=1,
    )
)

# fig.update_layout(
#     scene = dict(zaxis = dict(nticks=10, range=[0,100]), xaxis = dict(nticks=10, range=[50,160]), 
#                  yaxis = dict(nticks=10, range=[50,160])))
          
fig.show()

**This looks good.** All AP positions can now be defined as the distance from each nucleus's projection onto the surface to this central plane. Importantly, we want the distance ALONG the surface, rather than the simple euclidean distance. For the PD reference, we will use the intersection of the surface with the sphere

In [None]:
# use average NN distance as a reference
dist_ref_k = 2 
tree = KDTree(xyz_fit_fin, leaf_size=2)
nearest_dist, nearest_ind = tree.query(xyz_fit_fin, k=3)
mean_nn_dist_vec = nearest_dist.mean(axis=0)
# column 1 of the query result is used as the distance scale
max_dist = mean_nn_dist_vec[1]

# identify points that fall along the AP center and PD base
# signed offset of each fitted-surface point from the AP plane
dist_vec_ap = plane_normal_u @ xyz_fit_fin.T + D_plane
ap_ref_indices = np.where(np.abs(dist_vec_ap)<=max_dist)
ap_signs = np.sign(dist_vec_ap)

# identify points near base of fin
# signed radial offset of each fitted-surface point from the fitted sphere
sphere_dist = np.sqrt(np.sum((xyz_fit_fin - center_vec[0])**2, axis=1)) - r
pd_ref_indices = np.where(np.abs(sphere_dist)<=max_dist)
pd_signs = np.sign(sphere_dist)

fig = go.Figure(
    data=[go.Surface(x=X_fit, y=Y_fit, z=Z_fit_curve_filt, opacity=0.85)]
)

# fig = px.scatter_3d(x=xyz_fit_fin[:, 0], y=xyz_fit_fin[:, 1], z=xyz_fit_fin[:, 2],
#                     opacity=0.5,template="ggplot2") 

fig.add_trace(go.Surface(x=x, y=y, z=z, opacity=0.4))

#fig.add_trace(go.Surface(x=X, y=Y, z=Z_pd, opacity=1))

# AP reference points first, then PD reference points, each as its own marker trace
for ref_rows in (ap_ref_indices[0], pd_ref_indices[0]):
    ref_points = xyz_fit_fin[ref_rows]
    fig.add_trace(
        go.Scatter3d(
            x=ref_points[:, 0],
            y=ref_points[:, 1],
            z=ref_points[:, 2],
            mode='markers',
            opacity=1,
        )
    )

fig.show()

**Now, calculate the 2D projection for each fin point.** The distance from the point to the surface gives DV position. From there, we can estimate the AP and PD distance using the reference axes calculated above

In [None]:
# pyvista supplies the mesh object used below for triangulation and geodesic queries
import pyvista as pv
# Wrap the fitted fin-surface points as a pyvista dataset.
# NOTE(review): `points` is also assigned near the top of the notebook (pec-fin
# nucleus coordinates); this line rebinds it — confirm nothing later expects the old value.
points = pv.wrap(xyz_fit_fin)
# Triangulate the point cloud with a 2D Delaunay triangulation to get a surface mesh.
fin_surf = points.delaunay_2d()
# fin_surf.plot(show_edges=True)

# Optional side-by-side view of the point cloud and the reconstructed surface (disabled):
# pl = pv.Plotter(shape=(1, 2))
# pl.add_mesh(points)
# pl.add_title('Point Cloud of 3D Surface')
# pl.subplot(0, 1)
# pl.add_mesh(surf, color=True, show_edges=True)
# pl.add_title('Reconstructed Surface')
# pl.show()
# print(grid)

In [None]:
# Spot check: for the surface vertices closest to rows 100 and 200 of `xyz_array`,
# compare the along-surface (geodesic) distance with the straight-line distance.
p1, p2 = (fin_surf.find_closest_point(xyz_array[row, :]) for row in (100, 200))

geo_distance = fin_surf.geodesic_distance(p1, p2)
print(geo_distance)

vertex_offset = fin_surf.points[p1, :] - fin_surf.points[p2, :]
linear_distance = np.sqrt(np.sum(vertex_offset**2))
print(linear_distance)


**Better, but there are still exceptions.** I think the only way to truly nail this is to perform the assignments using a greedy sweep approach that starts at the bottom and proceeds up to the tip. Not clear how best to implement this algorithm. Definitely not work for the plane.

Ok, so I figured it out:
1. Assign P-D positions to each point on the surface (negative values for things below the sphere)
2. Divide these points into relatively fine P-D groups
3. Starting from the proximal-most group, use the points to define top and bottom planes. Consider only nuclei that are between these boundaries
4. Assign subgroup of nuclei to nearest surface point that is a member of the P-D group
5. Repeat for each group, disallowing reassignments of previously assigned nuclei

This should be reasonably fast still, and should avoid the pitfalls uncovered above

In [None]:
# Assign a signed PD and a signed AP position to every vertex of the fin surface.
# Both are measured as path lengths ALONG the surface to the nearest reference vertex.

def _signed_geodesic_to_reference(surf, point_id, ref_ids, ref_points, sign):
    """Return the signed along-surface distance from one vertex to a reference set.

    The target is the reference vertex closest to ``point_id`` in Euclidean space.
    This assumes the Euclidean-nearest reference vertex is also the nearest one
    along the surface, which has not been verified.

    Parameters
    ----------
    surf : mesh exposing ``points`` and ``geodesic(start_vertex, end_vertex)``
        The triangulated fin surface.
    point_id : int
        Index of the starting vertex in ``surf.points``.
    ref_ids : 1-D integer array
        Vertex indices that make up the reference set.
    ref_points : array of shape (len(ref_ids), 3)
        Coordinates of the reference vertices, i.e. ``surf.points[ref_ids, :]``.
        Passed in so the caller can gather them once instead of once per vertex.
    sign : float
        Multiplier applied to the path length (which side of the reference
        the vertex lies on).

    Returns
    -------
    float
        Summed length of the path segments, multiplied by ``sign``.

    Raises
    ------
    ValueError
        If the reference set is empty.
    """
    if len(ref_ids) == 0:
        raise ValueError("reference set is empty; cannot measure a geodesic distance to it")

    start_point = surf.points[point_id, :]
    euclid_to_refs = np.sqrt(np.sum((start_point - ref_points)**2, axis=1))
    nearest_ref = ref_ids[np.argmin(euclid_to_refs)]

    # A reference vertex is at distance zero from the reference set; skip the path search.
    if nearest_ref == point_id:
        return 0.0 * sign

    geo_path = surf.geodesic(point_id, nearest_ref)
    # Bracket the path with both endpoints so the polyline always spans start -> end.
    # If the path already contains them this only adds zero-length segments.
    # NOTE(review): assumes the returned path is ordered start -> end — confirm for
    # the installed pyvista version.
    polyline = np.concatenate((np.reshape(start_point, (1, 3)),
                               geo_path.points,
                               np.reshape(surf.points[nearest_ref, :], (1, 3))), axis=0)
    segment_lengths = np.sqrt(np.sum((polyline[:-1, :] - polyline[1:, :])**2, axis=1))
    return np.sum(segment_lengths) * sign


# Gather the reference vertices once; they do not change inside the loop.
pd_ref_ids = pd_ref_indices[0]
ap_ref_ids = ap_ref_indices[0]
pd_ref_points = fin_surf.points[pd_ref_ids, :]
ap_ref_points = fin_surf.points[ap_ref_ids, :]

surf_point_pd_list = []
surf_point_ap_list = []

for surf_i in range(fin_surf.points.shape[0]):
    # PD: distance to the surface/sphere intersection, signed by side of the sphere
    surf_point_pd_list.append(
        _signed_geodesic_to_reference(fin_surf, surf_i, pd_ref_ids, pd_ref_points, pd_signs[surf_i]))
    # AP: distance to the central plane, signed by side of the plane
    surf_point_ap_list.append(
        _signed_geodesic_to_reference(fin_surf, surf_i, ap_ref_ids, ap_ref_points, ap_signs[surf_i]))

surf_point_pd_list = np.asarray(surf_point_pd_list)
surf_point_ap_list = np.asarray(surf_point_ap_list)

In [None]:
# Sanity check: colour every fin-surface vertex by its signed PD position.
surf_xyz = fin_surf.points
fig = px.scatter_3d(
    x=surf_xyz[:, 0],
    y=surf_xyz[:, 1],
    z=surf_xyz[:, 2],
    color=surf_point_pd_list,
    opacity=0.2,
).update_coloraxes(showscale=True)

fig.show()

In [None]:
# number of distinct PD values found across the surface vertices
print(np.unique(surf_point_pd_list).size)

**Now, discretize the PD axis into bins,** and use this to iteratively assign PD positions to all fin nuclei

In [None]:
# Sweep along the PD axis in bins, from the most proximal to the most distal.
# For each bin, a plane through the bin's upper boundary separates the nuclei and
# surface vertices handled now from those left for later bins. Each selected nucleus
# takes the PD/AP position of its nearest selected surface vertex, plus a signed
# DV distance to that vertex.

# compute per-vertex surface normals (stored on the mesh under "Normals")
fin_surf.compute_normals(cell_normals=False, point_normals=True, inplace=True)
normal_array = fin_surf["Normals"]
surf_point_pd_arr = np.asarray(surf_point_pd_list)
surf_point_ap_arr = np.asarray(surf_point_ap_list)
surf_xyz = np.asarray(fin_surf.points)

# group surface points into equal-width PD bins labelled 1..n_bins
n_bins = 10
pd_bins = np.linspace(np.min(surf_point_pd_arr), np.max(surf_point_pd_arr), n_bins+1)
# np.digitize uses half-open intervals, so the maximum PD value would get label
# n_bins+1 and belong to no bin. Clip so that it joins the most distal bin.
surf_pd_indices = np.clip(np.digitize(surf_point_pd_arr, pd_bins), 1, n_bins)

# -1 marks "not yet assigned to a PD group"
group_vec_surf = np.full((surf_xyz.shape[0],), -1.0)
group_vec_nuclei = np.full((xyz_array.shape[0],), -1.0)

# Columns: 0 = PD, 1 = AP, 2 = signed DV distance.
# NaN-initialised so a nucleus that is never assigned is visibly missing
# instead of holding uninitialised memory.
bio_pos_array = np.full(xyz_array.shape, np.nan)

for n in range(n_bins):
    is_last_bin = (n == n_bins - 1)

    # find AP-centered surface points closest to the upper and lower PD boundaries
    pd_curr = pd_bins[n+1]
    pd_prev = pd_bins[n]
    curr_ind = np.argmin(np.sqrt(surf_point_ap_arr**2 + (surf_point_pd_arr - pd_curr)**2))
    prev_ind = np.argmin(np.sqrt(surf_point_ap_arr**2 + (surf_point_pd_arr - pd_prev)**2))
    curr_point = surf_xyz[curr_ind, :]
    prev_point = surf_xyz[prev_ind, :]

    # surface points whose PD value falls in this bin
    pd_points_surf = surf_xyz[surf_pd_indices == n+1]
    # a 3-component PCA needs at least 3 samples; a sparser bin cannot define a plane
    have_plane = pd_points_surf.shape[0] >= 3

    if have_plane:
        pca_temp = PCA(n_components=3)
        pca_temp.fit(pd_points_surf)

        # plane normal = direction orthogonal to the first two principal axes of the bin
        vec2 = pca_temp.components_[0]
        vec1 = pca_temp.components_[1]
        plane_vec = np.cross(vec1, vec2)
        plane_vec_u = plane_vec / np.sqrt(np.sum(plane_vec**2))

        # Enforce that the plane normal points "up" along the PD axis. Only the sign
        # of the dot product matters, so no normalisation or arccos is needed; this
        # avoids NaN when rounding pushes a cosine just outside [-1, 1] or when the
        # two boundary points coincide.
        pd_vec = curr_point - prev_point
        if np.dot(plane_vec_u, pd_vec) < 0:
            plane_vec_u = -plane_vec_u

        # offset so that the plane passes through the upper-boundary point
        D_p = -np.sum(np.multiply(plane_vec_u, curr_point))

    if is_last_bin:
        # the last bin takes everything that is still unassigned
        surf_indices = np.where(group_vec_surf == -1)[0]
        nc_indices = np.where(group_vec_nuclei == -1)[0]
    elif have_plane:
        # take unassigned surface points and nuclei lying strictly below the plane
        p_signs_surf = np.sign(np.matmul(plane_vec_u, surf_xyz.T) + D_p)
        p_signs_nuclei = np.sign(np.matmul(plane_vec_u, xyz_array.T) + D_p)
        surf_indices = np.where((p_signs_surf == -1) & (group_vec_surf == -1))[0]
        nc_indices = np.where((p_signs_nuclei == -1) & (group_vec_nuclei == -1))[0]
    else:
        # no plane for this bin: leave its points and nuclei for a later bin
        continue

    if surf_indices.size == 0:
        # Nothing to attach nuclei to. Do not mark them as assigned, so a later bin
        # can still pick them up. If this is the last bin they stay NaN.
        continue

    group_vec_surf[surf_indices] = n
    group_vec_nuclei[nc_indices] = n

    if nc_indices.size == 0:
        continue

    # find the nearest selected surface point for each selected nucleus
    dist_array = distance_matrix(xyz_array[nc_indices], surf_xyz[surf_indices])
    dv_distances = np.min(dist_array, axis=1)
    dv_indices = np.argmin(dist_array, axis=1)
    nearest_surf = surf_indices[dv_indices]

    # PD and AP coordinates are inherited from the nearest surface point
    bio_pos_array[nc_indices, 0] = surf_point_pd_arr[nearest_surf]
    bio_pos_array[nc_indices, 1] = surf_point_ap_arr[nearest_surf]

    # DV is the distance to that surface point, positive when the nucleus lies on the
    # side the reference normal points to.
    # NOTE(review): the reference normal is taken at the bin's upper-boundary point
    # (curr_ind) for every nucleus in the bin, not at each nucleus's own nearest
    # surface point — confirm this is intended.
    norm_vec = normal_array[curr_ind, :]
    dv_vecs = xyz_array[nc_indices] - surf_xyz[nearest_surf]
    on_normal_side = np.matmul(dv_vecs, norm_vec) > 0
    bio_pos_array[nc_indices, 2] = np.where(on_normal_side, dv_distances, -dv_distances)


# what do these planes look like? (last fitted plane, evaluated at X, Y from an earlier cell)
Z_p = -(plane_vec_u[0]*X + plane_vec_u[1]*Y + D_p)/plane_vec_u[2]

# discard nuclei that sit implausibly far from the surface
dv_max = 30
bio_pos_array[np.abs(bio_pos_array[:, 2]) >= dv_max, :] = np.nan

fig = go.Figure()



In [None]:
# Colour each nucleus by its signed DV distance (column 2 of `bio_pos_array`).
# Start from a fresh figure: adding to a `fig` shared with other cells stacks this
# trace on top of whatever was drawn before, and again on every re-run of the cell.
fig = go.Figure()

fig.add_trace(go.Scatter3d(x=xyz_array[:, 0],
                           y=xyz_array[:, 1],
                           z=xyz_array[:, 2],
                           mode='markers',
                           marker=dict(
                           color=bio_pos_array[:, 2],
                           showscale=True,  # colour bar, so the values can be read off
                           opacity=0.2)))

fig.update_layout(title="DV Position")

fig.show()

In [None]:
# Colour each nucleus by its AP position (column 1 of `bio_pos_array`).
# Start from a fresh figure: adding to a `fig` shared with other cells leaves the
# traces of earlier plots underneath this one at the same coordinates, and re-running
# the cell would stack duplicates.
fig = go.Figure()

fig.add_trace(go.Scatter3d(x=xyz_array[:, 0],
                           y=xyz_array[:, 1],
                           z=xyz_array[:, 2],
                           mode='markers',
                           marker=dict(
                           color=bio_pos_array[:, 1],
                           showscale=True,  # colour bar, so the values can be read off
                           opacity=0.2)))

fig.update_layout(title="AP Position")

fig.show()

In [None]:
# Colour each nucleus by its PD position (column 0 of `bio_pos_array`).
# Start from a fresh figure: adding to a `fig` shared with other cells leaves the
# traces of earlier plots underneath this one at the same coordinates, and re-running
# the cell would stack duplicates.
fig = go.Figure()

fig.add_trace(go.Scatter3d(x=xyz_array[:, 0],
                           y=xyz_array[:, 1],
                           z=xyz_array[:, 2],
                           mode='markers',
                           marker=dict(
                           color=bio_pos_array[:, 0],
                           showscale=True,  # colour bar, so the values can be read off
                           opacity=0.2)))

fig.update_layout(title="PD Position")

fig.show()

In [None]:
# Scatter the nuclei in biological coordinates: x = AP (column 1),
# y = DV (column 2), z = PD (column 0) of `bio_pos_array`.
fig = go.Figure()

bio_coord_trace = go.Scatter3d(
    x=bio_pos_array[:, 1],
    y=bio_pos_array[:, 2],
    z=bio_pos_array[:, 0],
    mode='markers',
    marker={"opacity": 0.4},
)
fig.add_trace(bio_coord_trace)