In [None]:
import sys
sys.path.append('/home/nick/projects')
# from PointGPT import segmentation
from fle_3d.fle_3d import FLEBasis3D
import numpy as np
import pandas as pd
import os
from glob2 import glob
import plotly.express as px
import plotly.graph_objects as go
import torch
from src.utilities.functions import path_leaf
from src.utilities.fin_class_def import FinData
from tqdm import tqdm
from sklearn.neighbors import KernelDensity

### Notebook to test how many components we need to "accurately" reconstruct fin volumes
There are multiple variables at play here, including the nature of the input data that we seek to reconstruct. 

Should the target be:
1) A 3D density derived from the point cloud?
2) A sparse 3D histogram of discrete counts?
3) Fin masks?
4) The raw pixel probabilities?

The hope is to settle on both the optimal kind of input data for this problem and an approximate bound on the number of components required for reconstruction.

### Load test dataset

In [None]:
# Root of the pecfin_dynamics data tree (machine-specific absolute path).
root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/pecfin_dynamics/"

# Directory of pickled fin objects; the empty final component keeps a trailing separator.
fin_object_path = os.path.join(root, "point_cloud_data", "fin_objects", "")

# Sorted list of every fin-object pickle found in that directory.
pkl_pattern = os.path.join(fin_object_path, "*.pkl")
fin_object_list = sorted(glob(pkl_pattern))

In [None]:
# Pick one fin object from the sorted file list and the segmentation model to load it with.
file_ind01 = 46
seg_type = "tissue_only_best_model_tissue"

# Derive the dataset name by stripping the fixed file suffix from the pickle's basename.
fp01 = fin_object_list[file_ind01]
point_prefix01 = path_leaf(fp01).replace("_fin_object.pkl", "")
print(point_prefix01)

fin_data = FinData(data_root=root, name=point_prefix01, tissue_seg_model=seg_type)

# Keep only the points labelled as fin (fin_label_curr == 1).
# The explicit .copy() gives an independent frame, so adding the XP/YP/ZP columns below
# neither triggers SettingWithCopyWarning nor risks writing into fin_data.full_point_data.
fin_df = fin_data.full_point_data
fin_df = fin_df.loc[fin_df["fin_label_curr"] == 1, :].copy()
fin_points = fin_df[["X", "Y", "Z"]].to_numpy()

# Centre the points on their mean and project them onto the fin axes.
# NOTE(review): assumes calculate_axis_array returns one axis per row — confirm in FinData.
fin_axis_df = fin_data.axis_fin
fin_axes = fin_data.calculate_axis_array(fin_axis_df)
fin_points_pca = np.matmul(fin_points - np.mean(fin_points, axis=0), fin_axes.T)
fin_df.loc[:, ["XP", "YP", "ZP"]] = fin_points_pca

### (1) point-based fin representations
What's the best we can do operating with simple point cloud-based representations?

In [None]:
#############
# calculate density-based representation
# Note the column order: coordinate 0 of X01 is ZP, coordinate 1 is YP, coordinate 2 is XP.
X01 = fin_df[["ZP", "YP", "XP"]].to_numpy()
X01 = X01 - np.mean(X01, axis=0)
# NOTE(review): the bandwidth comments say "pixels" while `res` below says "um" —
# confirm which units the point coordinates are actually in.
kde_lr = KernelDensity(bandwidth=5, kernel="gaussian").fit(X01) # Gaussian sampling kernel with sigma=5 pixels
kde_hr = KernelDensity(bandwidth=2, kernel="gaussian").fit(X01) # Gaussian sampling kernel with sigma=2 pixels

# Half-width of a cube (rounded up to a multiple of 5) that contains every centred point.
max_dim = int(np.ceil(np.max(np.abs(X01)) / 5) * 5)
res = 4 # in um
# Number of samples per axis; the resulting linspace step is <= res.
N = int(np.ceil(2*max_dim / res)) + 1
# print(N)
x_axis = np.linspace(-max_dim, max_dim, N)
y_axis = np.linspace(-max_dim, max_dim, N)
z_axis = np.linspace(-max_dim, max_dim, N)

# indexing="ij" makes array axis 0 follow x_axis, axis 1 y_axis and axis 2 z_axis.
# The default ("xy") swaps the first two axes, which disagrees with np.histogramdd output
# and with any later reshape to (N, N, N) that treats axis 0 as the first coordinate.
x_grid, y_grid, z_grid = np.meshgrid(x_axis, y_axis, z_axis, indexing="ij")
xyz_array = np.c_[x_grid.ravel(), y_grid.ravel(), z_grid.ravel()]

# score_samples returns log-density, so exponentiate to get the density at each grid point.
probs_lr = np.exp(kde_lr.score_samples(xyz_array))
probs_hr = np.exp(kde_hr.score_samples(xyz_array))

In [None]:
# Go even simpler and just calculate 3D histogram
# The density grid samples each axis at linspace(-max_dim, max_dim, N). Put the N+1 bin
# edges half a grid step either side of those samples so that every bin is centred on a
# grid coordinate and the (N, N, N) counts line up with the grid they are plotted against.
bin_step = 2 * max_dim / (N - 1)
bin_edges = np.linspace(-max_dim - bin_step / 2, max_dim + bin_step / 2, N + 1)
x_bins = bin_edges
y_bins = bin_edges
z_bins = bin_edges
point_hist, _ = np.histogramdd(X01, (x_bins, y_bins, z_bins))

#### Histogram representation

In [None]:
# Show the 3D histogram counts as a semi-transparent volume over the sampling grid.
hist_volume = go.Volume(
    x=x_grid.ravel(),
    y=y_grid.ravel(),
    z=z_grid.ravel(),
    value=point_hist.ravel(),
    opacity=0.1,
    isomin=1e-6,  # lower iso bound set just above zero
    surface_count=25,
)
fig = go.Figure(data=hist_volume)

# Hide tick labels on all three scene axes.
no_tick_labels = dict(showticklabels=False)
fig.update_layout(scene=dict(xaxis=no_tick_labels, yaxis=no_tick_labels, zaxis=no_tick_labels))
fig.show()

#### Low-res kernel density

In [None]:
# Show the low-resolution kernel density as a semi-transparent volume over the sampling grid.
kde_volume = go.Volume(
    x=x_grid.ravel(),
    y=y_grid.ravel(),
    z=z_grid.ravel(),
    value=probs_lr.ravel(),
    opacity=0.1,
    isomin=1e-6,  # lower iso bound set just above zero
    surface_count=25,
)
fig = go.Figure(data=kde_volume)

# Hide tick labels on all three scene axes.
no_tick_labels = dict(showticklabels=False)
fig.update_layout(scene=dict(xaxis=no_tick_labels, yaxis=no_tick_labels, zaxis=no_tick_labels))
fig.show()

In [None]:
import pyvista as pv

# Normalise the low-res density by its maximum, reshape it onto the (N, N, N) sampling grid,
# and zero out everything below 5% of the peak.
probs_test = np.reshape(probs_lr / np.max(probs_lr), (N, N, N))
probs_test[probs_test < 0.05] = 0

# Build a uniform image grid with one point per density sample.
grid = pv.ImageData()
grid.dimensions = probs_test.shape
# Fix: this was previously assigned to an unused local (`grid_spacing`), so the grid
# silently kept its default spacing.
# NOTE(review): the sampled step is 2*max_dim/(N-1), which equals `res` only when
# 2*max_dim is a multiple of `res`; use that value instead if exact scale matters.
grid.spacing = (res, res, res)
grid.origin = (0, 0, 0)
# order="F" makes array axis 0 the fastest-varying index in the flattened scalars.
# NOTE(review): verify that axis 0 of probs_test is the intended first (x) coordinate,
# given how the evaluation meshgrid was indexed.
grid["density"] = probs_test.ravel(order="F")

# The volume itself is rendered in the next cell.


In [None]:
# Volume-render the "density" scalars of the grid with a viridis colormap and sigmoid opacity.
plotter = pv.Plotter()
volume_options = dict(scalars="density", cmap="viridis", opacity="sigmoid")
plotter.add_volume(grid, **volume_options)
plotter.show()

In [None]:
# Scratch cell: display an empty grid object.
# Uses pv.ImageData for consistency with the grid built earlier in this notebook;
# pv.UniformGrid is the older name, deprecated in favour of ImageData in recent PyVista releases.
pv.ImageData()