In [1]:
import dash
import plotly.express as px
from ome_zarr.io import parse_url
from ome_zarr.reader import Reader
import plotly.graph_objs as go
import pandas as pd
import numpy as np
import json
import glob2 as glob
from skimage.measure import regionprops
import itertools
import os
from scipy.interpolate import LinearNDInterpolator
import alphashape


def load_nucleus_dataset(filename, data_root=None):
    """Load one dataset's nucleus property table and max-normalize its gene-expression columns.

    Reads ``<data_root><filename>_nucleus_props.csv``, using its first column as the index.
    Gene names are taken from the columns ending in ``_cell_mean_nn``. Every column whose
    name contains one of those gene names is divided by its own maximum, so that its largest
    value becomes 1.

    Parameters
    ----------
    filename : str
        Dataset name, i.e. the CSV file name without the ``_nucleus_props.csv`` suffix.
    data_root : str, optional
        Folder prefix (including the trailing separator) that holds the CSV. Defaults to the
        notebook-level ``dataRoot`` global, so existing one-argument calls keep working.

    Returns
    -------
    pandas.DataFrame
        The loaded table with its gene-related columns normalized.

    Raises
    ------
    FileNotFoundError
        If the property CSV does not exist. This is a subclass of ``Exception``, so handlers
        written for the previous generic exception still match.
    """
    if data_root is None:
        data_root = dataRoot  # notebook-level global set in the path-configuration cell

    prop_path = data_root + filename + '_nucleus_props.csv'
    if not os.path.isfile(prop_path):
        raise FileNotFoundError(
            f"Selected dataset( {filename} ) dataset has no nucleus data. Have you run extract_nucleus_stats?")
    df = pd.read_csv(prop_path, index_col=0)

    # normalize gene expression levels
    gene_names = [col.replace("_cell_mean_nn", "") for col in df.columns if "_cell_mean_nn" in col]
    # Matching is by substring, so one column can match several gene names (e.g. "fgf" and
    # "fgf10"). Collecting the columns first guarantees each one is normalized exactly once.
    gene_cols = [col for col in df.columns if any(g in col for g in gene_names)]

    for col in gene_cols:
        c_max = np.max(df[col])
        # An all-zero (or all-NaN) column has no usable maximum. Leave it unchanged rather
        # than filling it with NaN/inf via a division by zero.
        if c_max != 0 and not pd.isna(c_max):
            df[col] = df[col] / c_max

    return df



In [2]:
# Input folder holding the per-dataset *_nucleus_props.csv tables, and output folder for figures.
# NOTE(review): machine-specific absolute paths; adjust before running on another machine.
dataRoot = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_Data/nucleus_props/"
figureRoot = "/Users/nick/Dropbox (Cole Trapnell's Lab)/Nick/pecFin/HCR_figures/"
os.makedirs(figureRoot, exist_ok=True)

# get list of filepaths; each dataset's name is its file path minus the root folder and the suffix
fileList = sorted(glob.glob(dataRoot + '*_nucleus_props.csv'))
pdNameList = [path.replace(dataRoot, '', 1).replace('_nucleus_props.csv', '') for path in fileList]

# compile dictionary of gene names that correspond to each dataset
gene_name_dict = {}
for dataset_name in pdNameList:
    prop_path = dataRoot + dataset_name + '_nucleus_props.csv'
    if not os.path.isfile(prop_path):
        raise Exception(
            f"Selected dataset( {dataset_name} ) dataset has no nucleus data. Have you run extract_nucleus_stats?")

    # Only the column names are needed here, so read just the header row (nrows=0)
    # instead of loading every table in full.
    header_cols = pd.read_csv(prop_path, index_col=0, nrows=0).columns
    gene_name_dict[dataset_name] = [
        col.replace("_cell_mean_nn", "") for col in header_cols if "_cell_mean_nn" in col
    ]

In [25]:
import math
import shutil

# Dataset to visualize: its position in the sorted file list built in the path-configuration cell
df_ind = 11

fileName = fileList[df_ind]
imageName = pdNameList[df_ind]
plot_type = "Volume Plot"

df = load_nucleus_dataset(imageName)

# Import the necessary libraries
import plotly.offline as pyo
import plotly.graph_objs as go
from sklearn.decomposition import PCA
from sklearn.neighbors import KDTree
import open3d as o3d
import pickle
import alphashape
import scipy
import scipy.optimize  # `import scipy` alone does not guarantee that scipy.optimize is loaded

# Fin nuclei are those with pec_fin_flag == 2. Compute the mask once and derive every selection from it.
fin_mask = (df["pec_fin_flag"] == 2).to_numpy()
fin_indices = np.where(fin_mask)
not_fin_indices = np.where(~fin_mask)
df_fin = df.iloc[fin_indices]

def sphereFit_fixed_r(spX, spY, spZ, r0):
    """Fit the centre of a sphere with a fixed radius ``r0`` to a set of 3-D points.

    The centre minimizes, in the least-squares sense, the radial residuals ``|p_i - c| - r0``.
    Its z coordinate is constrained to be <= 0. The initial guess uses the points' mean x/y
    with z = -r0.

    Parameters
    ----------
    spX, spY, spZ : array-like
        x, y and z coordinates of the points (equal lengths).
    r0 : float
        Fixed sphere radius, in the same units as the coordinates.

    Returns
    -------
    tuple
        ``(r0, cx, cy, cz)``: the radius that was passed in, followed by the fitted centre.

    Raises
    ------
    ValueError
        If no points are given.
    """
    # Local import: `import scipy` by itself does not guarantee the optimize submodule is loaded.
    from scipy.optimize import least_squares

    xyz_array = np.column_stack((np.asarray(spX, dtype=float),
                                 np.asarray(spY, dtype=float),
                                 np.asarray(spZ, dtype=float)))
    if xyz_array.shape[0] == 0:
        raise ValueError("sphereFit_fixed_r needs at least one point to fit.")

    # Initial guess: the points' mean x/y, with z set to -r0.
    c_init = np.mean(xyz_array, axis=0)
    c_init[2] = -r0

    def radial_residuals(center, xyz=xyz_array, r=r0):
        # distance of every point from `center`, minus the fixed radius
        return np.linalg.norm(xyz - center, axis=1) - r

    # x and y of the centre are unconstrained; z is bounded above by 0.
    fit = least_squares(radial_residuals, c_init,
                        bounds=([-np.inf, -np.inf, -np.inf], [np.inf, np.inf, 0]))

    return r0, fit.x[0], fit.x[1], fit.x[2]


# Fit a sphere of fixed radius to the base nuclei (pec_fin_flag == 1). Its centre c0 and radius r
# are used further down to pick out the fin nuclei that lie close to the sphere.
base_sphere_radius = 250  # fixed radius handed to sphereFit_fixed_r, in the same units as X/Y/Z
base_nuclei = np.where(df["pec_fin_flag"] == 1)[0]
if base_nuclei.size == 0:
    raise ValueError(
        f"Dataset {imageName} has no base nuclei (pec_fin_flag == 1); cannot fit the base sphere.")
xyz_base = df[["X", "Y", "Z"]].iloc[base_nuclei].to_numpy()
r, x0, y0, z0 = sphereFit_fixed_r(xyz_base[:, 0], xyz_base[:, 1], xyz_base[:, 2], base_sphere_radius)
c0 = np.array([x0, y0, z0])


# Render plotly figures inline without a network connection
pyo.init_notebook_mode()

# k_nn: which nearest neighbour (1 = closest) sets the distance scale nn_thresh further down
k_nn = 7

#################
# load surface model and pca info
# Each .pkl file holds pickled JSON text, which is decoded into a numpy array.
# NOTE(review): pickle.load can execute arbitrary code; only open curation files produced by this pipeline.
curation_path = "".join([dataRoot, imageName, "_curation_info/"])

with open(curation_path + "fin_surf_model.pkl", 'rb') as fh:
    fin_surf = pickle.load(fh)
surface_model = np.asarray(json.loads(fin_surf))

with open(curation_path + "pca_fin.pkl", 'rb') as fh:
    pca_components_json = pickle.load(fh)
pca_components = np.asarray(json.loads(pca_components_json))


###################
# Estimate the typical nucleus spacing. The tree is queried with its own points, so column 0 of
# nearest_dist is each point itself (distance 0) and column j is the distance to the j-th neighbour.
xyz_array = df[["X", "Y", "Z"]].to_numpy()
tree = KDTree(xyz_array)
nearest_dist, nearest_ind = tree.query(xyz_array, k=k_nn + 1)

# average distance to the 1st and the k_nn-th closest neighbour, across all nuclei
mean_nn_dist_vec = np.mean(nearest_dist, axis=0)
nn_thresh = mean_nn_dist_vec[k_nn]
nn_thresh_small = mean_nn_dist_vec[1]

# Fin nuclei whose signed distance to the base sphere is <= nn_thresh (this includes every
# nucleus inside the sphere) form the near-sphere subset xyz_surf.
xyz_fin = df[["X", "Y", "Z"]].iloc[fin_indices].to_numpy()
dist_array = np.linalg.norm(xyz_fin - c0, axis=1) - r
surf_indices = np.flatnonzero(dist_array <= nn_thresh)
xyz_surf = xyz_fin[surf_indices]
if xyz_surf.shape[0] < 3:
    raise ValueError(
        f"Only {xyz_surf.shape[0]} fin nuclei lie within {nn_thresh:.2f} of the base sphere; "
        "at least 3 are needed to fit the surface plane.")

# calculate centroid
surf_cm = np.mean(xyz_surf, axis=0)

# Plane through surf_cm containing the main PCA axis of the near-sphere points (vec1) and the
# sphere-centre -> centroid direction (vec2).
pca_surf = PCA(n_components=3)
pca_surf.fit(xyz_surf)
vec1 = pca_surf.components_[0]
c_vec = surf_cm - c0
c_vec_u = c_vec / np.linalg.norm(c_vec)
vec2 = c_vec_u
plane_normal = np.cross(vec1, vec2)
plane_normal_u = plane_normal / np.linalg.norm(plane_normal)
# The offset must be computed from the *unit* normal so that |plane_normal_u . x + D_plane| is a
# true point-to-plane distance. vec1 and vec2 are generally not orthogonal, so plane_normal is
# shorter than 1.
D_plane = -np.dot(plane_normal_u, surf_cm)

# generate additional points to constrain the axis fit such that it is approximately normal to the sphere when
# it intersects the sphere surface
# In future, we could consider adding a constraint directly to the objective function that enforces this
lower_point = surf_cm - 2 * c_vec_u * nn_thresh_small
n_reps = 6
# Row 0 is lower_point. Rows 1, 3, ..., 2*n_reps-1 step +k*nn_thresh along vec1, and rows
# 2, 4, ..., 2*n_reps step -k*nn_thresh, for k = 1..n_reps (2*n_reps + 1 rows in total).
# NOTE(review): add_array is not referenced anywhere else in the code shown here.
steps = nn_thresh * np.arange(1, n_reps + 1)[:, np.newaxis] * vec1  # shape (n_reps, 3)
add_array = np.empty((2 * n_reps + 1, 3))
add_array[0] = lower_point
add_array[1::2] = lower_point + steps
add_array[2::2] = lower_point - steps

#######################
# fit surface to fin points
#######################
xyz_fin_raw = df[["X", "Y", "Z"]].iloc[fin_indices]
xyz_surf = xyz_fin_raw.iloc[surf_indices].to_numpy()

# exclude base points that are too far from our plane prior. If included, these will tend to cause the
# surface fit to perform poorly. The distance is taken relative to surf_cm, a point on the plane,
# so only the unit normal is needed.
d_surf = np.abs((xyz_surf - surf_cm) @ plane_normal_u)
ind_rm = surf_indices[d_surf > nn_thresh]
keep_mask = np.ones(xyz_fin_raw.shape[0], dtype=bool)
keep_mask[ind_rm] = False
keep_indices = np.flatnonzero(keep_mask)
xyz_fin = xyz_fin_raw.iloc[keep_indices].copy()
xyz_fin_np = xyz_fin.to_numpy()

# Typical spacing between neighbouring fin nuclei. Column 0 of the query result is the point itself.
tree = KDTree(xyz_fin_np)
nearest_dist, nearest_ind = tree.query(xyz_fin_np, k=2)
mean_nn_dist_vec = np.mean(nearest_dist, axis=0)
nn_thresh1 = mean_nn_dist_vec[1]  # sets the scale for downsampling within the fin

# convert to point cloud and downsample to achieve a uniform distribution of fit points
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(xyz_fin_np)
pcd_down = pcd.voxel_down_sample(voxel_size=nn_thresh1 * 1.5)
xyz_fin_down = np.asarray(pcd_down.points)

# Fit the PCA frame on the downsampled points
PCAFIN = PCA(n_components=3)
PCAFIN.fit(xyz_fin_down)

# Project *all* fin nuclei (including those excluded from the fit above) into the fitted PCA frame
pca_fin = PCAFIN.transform(xyz_fin_raw.to_numpy())

# #################
# obtain surface prediction: evaluate the fitted surface on a grid_res x grid_res lattice spanning
# the fin nuclei's extent along the first two PCA axes
grid_res = 100
pc0_values = np.linspace(np.min(pca_fin[:, 0]), np.max(pca_fin[:, 0]), grid_res)
pc1_values = np.linspace(np.min(pca_fin[:, 1]), np.max(pca_fin[:, 1]), grid_res)
P0, P1 = np.meshgrid(pc0_values, pc1_values)

PP0 = P0.flatten()
PP1 = P1.flatten()

# third PCA coordinate = [1, p0, p1, p0*p1, p0^2, p1^2] . surface_model (coefficients in this column order)
quad_terms = np.column_stack((np.ones(PP0.shape), PP0, PP1, PP0 * PP1, PP0 ** 2, PP1 ** 2))
P2_curve = quad_terms.dot(surface_model).reshape(P0.shape)

####################
# Transform back to xyz space
grid_pca = np.column_stack((P0.ravel(), P1.ravel(), P2_curve.ravel()))
xyz_fit_curve = PCAFIN.inverse_transform(grid_pca)

# make plots: per-axis grids for the Surface trace
XS, YS, ZS = (xyz_fit_curve[:, axis].reshape(P0.shape) for axis in range(3))

# Keep only grid points that fall inside an alpha shape (alpha = 0.1) of the fin nuclei. Both point
# sets are first divided by the nuclei's per-axis maximum, presumably so that alpha does not depend
# on the absolute scale of the coordinates.
xyz_plot = xyz_fin_raw.to_numpy()
fin_axis_max = np.max(xyz_plot, axis=0)
xyz_fit_curve_norm = xyz_fit_curve / fin_axis_max
xyz_fin_norm = xyz_plot / fin_axis_max
alphafin = alphashape.alphashape(xyz_fin_norm, 0.1)
inside_flags = np.reshape(alphafin.contains(xyz_fit_curve_norm), (grid_res, grid_res))

# blank out (NaN) the surface outside the alpha shape so those cells are not drawn
ZS[inside_flags != 1] = np.nan



# Fitted fin surface (clipped to the alpha shape) drawn over a translucent hull of the fin nuclei
fin_xyz_plot = df[["X", "Y", "Z"]].iloc[fin_indices]

fig = go.Figure()
fig.add_trace(go.Surface(x=XS, y=YS, z=ZS, opacity=0.75, showscale=False, name="fitted surface"))
fig.add_trace(go.Mesh3d(x=fin_xyz_plot["X"],
                        y=fin_xyz_plot["Y"],
                        z=fin_xyz_plot["Z"],
                        opacity=0.25, alphahull=9, color="gray", name="fin nuclei hull"))
fig.update_layout(title=f"{imageName}: fitted fin surface",
                  scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"))

fig.show()

In [18]:
# Interactive 3-D scatter of the fin nuclei positions
px.scatter_3d(data_frame=xyz_fin_raw, x="X", y="Y", z="Z")

In [15]:
# Sanity check: True if any fin nucleus has a missing (NaN) X coordinate.
# xyz_fin_raw is a DataFrame, so the first column must be selected positionally with .iloc.
np.any(np.isnan(xyz_fin_raw.iloc[:, 0].to_numpy()))