In [5]:
import numpy as np
import os
import pandas as pd
from src.utilities.fin_shape_utils import plot_mesh
from src.utilities.fin_class_def import FinData
from src.utilities.functions import path_leaf
import glob2 as glob
import trimesh
from tqdm import tqdm
import torch
import geomloss
from numpy.linalg import norm

In [6]:
# get list of refined fin mesh objects
root = "/media/nick/hdd02/Cole Trapnell's Lab Dropbox/Nick Lammers/Nick/pecfin_dynamics/"
fin_mesh_list = sorted(glob.glob(os.path.join(root, "point_cloud_data", "processed_fin_data", "*smoothed_fin_mesh*")))

# load metadata
metadata_df = pd.read_csv(os.path.join(root, "metadata", "master_metadata.csv"))
metadata_df["experiment_date"] = metadata_df["experiment_date"].astype(str)
metadata_df.head()

# set write directory
write_dir = os.path.join(root, "point_cloud", "fin_shape_analyses")
if not os.path.isdir(write_dir):
    os.makedirs(write_dir) 

### Calculate approximate Wassersten distance between:
1) Mesh surface points 
2) Fin centroids
3) Mesh surface points (normalized case)
4) Mesh surface points (normalized case

### Preload data structures for distance calculations

In [60]:
df_list = []
surf_list = []
surf_norm_list = []
centroid_list = []
centroid_norm_list = []

n_points = 250
np.random.seed(125)
# preload meshes and generate metadata table

for file_ind0, fp0 in enumerate(tqdm(fin_mesh_list)):

    fp0_centroid = fp0.replace("smoothed_fin_mesh.obj", "fin_data_upsampled.csv")
    # extract relevant metadata
    fname0 = os.path.basename(fp0)
    well_ind0 = fname0.find("well")
    date_string0 = fname0[:well_ind0-1]
    well_num0 = int(fname0[well_ind0+4:well_ind0+8])
    time_ind0 = fname0.find("time")
    time_num0 = int(fname0[time_ind0+4:time_ind0+8])
    
    # match this to a row in the metadata df
    date_ft0 = metadata_df["experiment_date"] == date_string0
    well_ft0 = metadata_df["well_index"] == well_num0
    time_ft0 = metadata_df["time_index"] == time_num0
    
    meta_temp = metadata_df.loc[date_ft0 & well_ft0 & time_ft0, :].reset_index(drop=True)
    
    cv = meta_temp.loc[0, "chem_i"]
    if isinstance(cv, str):
        cvs = cv.split("_")
        chem_id = cvs[0]
        chem_time = int(cvs[1])
    else:
        chem_id = "WT"
        chem_time = np.nan
    meta_temp.loc[0, "chem_id"] = chem_id
    meta_temp.loc[0, "chem_time"] = chem_time

    # load mesh
    fin_mesh = trimesh.load(fp0)
    fin_df = pd.read_csv(fp0_centroid)

    fin_df_c = fin_df.loc[:, ["nucleus_id", "XB", "YB", "ZB"]].groupby("nucleus_id").mean().reset_index(drop=False)
    options = np.arange(fin_df_c.shape[0])
    c_to_samp = np.random.choice(options, n_points, replace=True)

    c = torch.tensor(fin_df_c[["XB", "YB", "ZB"]].to_numpy())
    c = c[c_to_samp, :]
    cn = c.clone() - torch.mean(c, axis=0)
    mn = np.amax(norm(c, axis=1))
    cn /= mn
    
    # convert to tensor format
    s = torch.tensor(fin_mesh.sample(n_points))
    sn = s.clone() - torch.mean(s, axis=0)
    mn = np.amax(norm(s, axis=1))
    sn /= mn

    if sn.shape[0] > 0:
        df_list.append(meta_temp)
        
        # surf points
        surf_list.append(s)
        surf_norm_list.append(sn)

        # centroids
        centroid_list.append(c)
        centroid_norm_list.append(cn)

dist_df = pd.concat(df_list, axis=0, ignore_index=True)

100%|██████████| 224/224 [00:20<00:00, 11.15it/s]


In [62]:
cn.shape

torch.Size([250, 3])

In [65]:
# calculate distances

# Create a Sinkhorn loss instance
loss = geomloss.SamplesLoss("sinkhorn", p=1, blur=0.5)

# initialize array to store distance
dist_arr0 = np.empty((len(fin_mesh_list), len(fin_mesh_list)))
distance_array[:] = np.nan
dist_arr1 = dist_arr0.copy()
dist_arr2 = dist_arr0.copy()
dist_arr3 = dist_arr0.copy()

# iterate
for i in tqdm(range(len(surf_list))):
    si = surf_list[i]
    sni = surf_norm_list[i]
    ci = centroid_list[i]
    cni = centroid_norm_list[i]
    
    for j in (range(i+1, len(surf_list))):
        sj = surf_list[j]
        snj = surf_norm_list[j]
        cj = centroid_list[j]
        cnj = centroid_norm_list[j]
        
        # (1) Compute distance between surf points
        dist0 = loss(si, sj)
    
        dist_arr0[i, j] = dist0
        dist_arr0[j, i] = dist0
        
        # (2) Compute distance between centroid points
        dist1 = loss(ci, cj)
    
        dist_arr1[i, j] = dist1
        dist_arr1[j, i] = dist1
        
        # (3) Compute distance between NORMALIZED surf points
        dist2 = loss(sni, snj)
    
        dist_arr2[i, j] = dist2
        dist_arr2[j, i] = dist2

        # (4) Compute distance between NORMALIZED centroid points
        dist3 = loss(cni, cnj)
    
        dist_arr3[i, j] = dist3
        dist_arr3[j, i] = dist3

        if (dist0 < 0) |  (dist1 < 0) | (dist2 < 0) | (dist3 < 0):
            print(f"why_{i}_{j}")
            break

100%|██████████| 224/224 [11:48<00:00,  3.16s/it]


In [57]:
cj.shape

torch.Size([53400, 3])

In [66]:
# sort array and DF for clarity
indices = np.lexsort((dist_df['chem_id'], dist_df['chem_time']))
dist_df_s = dist_df.iloc[indices]

dist_arr0_s = dist_arr0.copy()
dist_arr0_s = dist_arr0_s[indices, :]
dist_arr0_s = dist_arr0_s[:, indices]

dist_arr1_s = dist_arr1.copy()
dist_arr1_s = dist_arr1_s[indices, :]
dist_arr1_s = dist_arr1_s[:, indices]

dist_arr2_s = dist_arr2.copy()
dist_arr2_s = dist_arr2_s[indices, :]
dist_arr2_s = dist_arr2_s[:, indices]

dist_arr3_s = dist_arr3.copy()
dist_arr3_s = dist_arr3_s[indices, :]
dist_arr3_s = dist_arr3_s[:, indices]

In [67]:
# save 
dist_df_s.to_csv(os.path.join(write_dir, "emd_dist_df.csv"), index=False)

np.save(os.path.join(write_dir, "surf_dist_arr.npy"), dist_arr0_s)
np.save(os.path.join(write_dir, "centroid_dist_arr.npy"), dist_arr1_s)
np.save(os.path.join(write_dir, "surf_dist_norm_arr.npy"), dist_arr2_s)
np.save(os.path.join(write_dir, "centroid_dist_norm_arr.npy"), dist_arr3_s)