In [1]:
import numpy as np
import json
import pickle
import matplotlib.pyplot as plt

from pathlib import Path

from scipy.spatial.distance import cdist, squareform
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance, ttest_ind

from tqdm.notebook import tqdm
import trimesh
import seaborn as sns

from cellpack_analysis.utilities.data_tools import get_positions_dictionary_from_file, combine_multiple_seeds_to_dictionary

### Load in data

In [6]:
# Structure identifier; interpolated into the observed-positions path and the mesh directory below.
STRUCTURE_ID = "SLC25A17"
# Root directory for all input data read by this notebook.
# NOTE(review): hardcoded absolute path — only resolves where this filesystem is mounted.
base_datadir = Path("/allen/aics/animated-cell/Saurabh/cellpack-analysis/data/")

In [14]:
# Packing modes to load. Other modes that have been used with this notebook
# (currently disabled): "mean_count_and_size", "variable_size",
# "variable_count", "variable_count_and_size".
packing_modes = ["observed", "shape"]

# mode -> positions dictionary as returned by get_positions_dictionary_from_file
all_positions = {}
for mode in packing_modes:
    print(mode)
    data_folder = base_datadir / f"packing_outputs/stochastic_variation_analysis/{mode}/peroxisome/spheresSST/"

    if mode == "shape":
        # presumably merges the per-seed outputs in data_folder into the
        # combined JSON read below — TODO confirm against the helper
        combine_multiple_seeds_to_dictionary(data_folder)

    if mode == "observed":
        # Observed positions live with the structure data, not the packing outputs.
        file_path = base_datadir / f"structure_data/{STRUCTURE_ID}/sample_8d/positions_{STRUCTURE_ID}.json"
    else:
        file_path = data_folder / "positions_peroxisome_analyze_random_mean.json"

    positions = get_positions_dictionary_from_file(file_path, drop_random_seed=True)
    all_positions[mode] = positions
    

observed
shape


In [16]:
# Display the keys of the positions dictionary loaded for the "shape" mode.
all_positions["shape"].keys()

dict_keys(['743916', '743920', '743921', '745991', '746203', '746983', '747001', '747683', '747697', '747936', '747945', '747950', '747959', '748161', '748163', '748172', '748950', '748961', '748963', '748964', '749467', '749488', '750698', '750712', '750905', '750909', '750911', '751018', '751026', '751031', '751043', '751045', '751301', '752225', '752227', '752228', '752237', '752240', '752442', '752458', '752499', '752503', '752512', '752519', '752637', '752652', '752950', '752962', '753057', '753476', '753480', '753555', '753556', '753564', '753569', '753816', '753978', '753988', '754316', '754330', '755501', '755503', '755514', '755517', '755759', '756080', '756092', '756819', '756824', '756841', '756842', '758300', '758308', '759000', '759002', '759007', '759011', '759514', '759524', '760284', '760286', '760554', '760558', '760599', '760600', '762191', '762213', '762456', '762476', '762489', '762501', '763249', '763251', '763252', '763253', '763254', '763512', '763517', '763519',

In [21]:
# Sanity check: both modes must expose the same keys, in the same order.
seed_list_shape = [*all_positions["shape"]]
seed_list_observed = [*all_positions["observed"]]
seed_list_shape == seed_list_observed


True

### Calculate distance measures

In [15]:
# Each dictionary below is keyed as [mode][seed] and holds a 1-D array of
# distances that has been divided by its own maximum.
all_pairwise_distances = {}  # pairwise distance between particles
all_nuc_distances = {}  # distance to nucleus surface
all_nearest_distances = {}  # distance to nearest neighbor
base_mesh_path = base_datadir / f"structure_data/{STRUCTURE_ID}/meshes"
mesh_dict = {}  # seed -> loaded nucleus mesh, so each mesh is read from disk only once
for mode, position_dict in all_positions.items():
    print(mode)
    all_pairwise_distances[mode] = {}
    all_nuc_distances[mode] = {}
    all_nearest_distances[mode] = {}
    for seed, positions in tqdm(position_dict.items()):
        # Pairwise and nearest-neighbor distances are undefined for fewer than
        # two particles; fail early with a message that identifies the input
        # instead of an opaque "zero-size array" error further down.
        num_particles = len(positions)
        if num_particles < 2:
            raise ValueError(
                "Need at least 2 positions to compute pairwise and nearest-neighbor "
                f"distances; got {num_particles} for mode {mode!r}, seed {seed!r}"
            )

        if seed in mesh_dict:
            nuc_mesh = mesh_dict[seed]
        else:
            nuc_mesh_path = base_mesh_path / f"nuc_mesh_{seed}.obj"
            nuc_mesh = trimesh.load_mesh(nuc_mesh_path)
            mesh_dict[seed] = nuc_mesh

        # trimesh reports points outside the mesh with negative signed distance;
        # negate so that distances measured outward from the nucleus are positive.
        nuc_distances = -trimesh.proximity.signed_distance(nuc_mesh, positions)
        all_nuc_distances[mode][seed] = nuc_distances / np.max(nuc_distances)

        # Full symmetric particle-to-particle distance matrix.
        all_distances = cdist(positions, positions, metric="euclidean")

        # Exclude each particle's zero self-distance by setting the diagonal to
        # +inf (on a copy, since the unmasked matrix is still needed below).
        masked_distances = all_distances.copy()
        np.fill_diagonal(masked_distances, np.inf)
        nearest_distances = np.min(masked_distances, axis=1)
        all_nearest_distances[mode][seed] = nearest_distances / np.max(
            nearest_distances
        )

        # Condensed form: one entry per unordered particle pair.
        pairwise_distances = squareform(all_distances)
        all_pairwise_distances[mode][seed] = pairwise_distances / np.max(
            pairwise_distances
        )

mean_count_and_size


  0%|          | 0/305 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Grid of KDE plots: one row per packing mode, one column per distance measure.
num_rows = len(all_positions)
num_cols = 3

# squeeze=False keeps `axs` two-dimensional even when only one packing mode is
# loaded; otherwise plt.subplots returns a 1-D array and axs[i, j] fails.
fig, axs = plt.subplots(
    num_rows,
    num_cols,
    figsize=(num_cols * 3, num_rows * 3),
    dpi=300,
    squeeze=False,
)

for i, (mode, position_dict) in enumerate(all_positions.items()):
    # One color per seed, sampled from the "jet" colormap.
    cmap = plt.get_cmap("jet", len(position_dict))
    for j, (distance_measure, distance_dict) in enumerate([
            ("pairwise", all_pairwise_distances[mode]),
            ("nucleus", all_nuc_distances[mode]),
            ("nearest", all_nearest_distances[mode]),
        ]):
        ax = axs[i, j]
        # Overlay one translucent density curve per seed.
        for k, (seed, distances) in enumerate(distance_dict.items()):
            sns.kdeplot(distances, ax=ax, color=cmap(k), alpha=0.2)

        ax.set_title(f"{mode} {distance_measure}")
        # Distances were divided by their per-seed maximum upstream.
        ax.set_xlim([0, 1])
        ax.set_xlabel("Distance")
        ax.set_ylabel("PDF")

fig.tight_layout()

plt.show()


### Get pairwise earth mover's distances (EMD)

In [None]:
from cellpack_analysis.utilities.data_tools import get_pairwise_wasserstein_distance_dict

In [None]:
# Nested result: all_pairwise_emd[distance measure][mode A][mode B].
# Each (mode A, mode B) pair is computed once and stored under both orderings.
all_pairwise_emd = {}
distributions_by_measure = {
    "pairwise": all_pairwise_distances,
    "nucleus": all_nuc_distances,
    "nearest": all_nearest_distances,
}
for distance_measure, distribution_dict in distributions_by_measure.items():
    print(distance_measure)
    measure_pairwise_emd = {}
    for mode_1 in packing_modes:
        measure_pairwise_emd.setdefault(mode_1, {})
        distribution_dict_1 = distribution_dict[mode_1]
        for mode_2 in packing_modes:
            # Skip pairs already filled in from the mirrored (mode_2, mode_1) pass.
            mirrored = measure_pairwise_emd.get(mode_2)
            if mirrored is not None and mirrored.get(mode_1) is not None:
                continue
            print(mode_1, mode_2)

            # A mode compared with itself passes None as the second dictionary
            # (presumably the helper then compares entries within the first
            # dictionary — TODO confirm against the helper).
            distribution_dict_2 = (
                None if mode_1 == mode_2 else distribution_dict[mode_2]
            )
            emd = get_pairwise_wasserstein_distance_dict(
                distribution_dict_1, distribution_dict_2,
            )
            measure_pairwise_emd[mode_1][mode_2] = emd
            # Mirror the same result object under the swapped key order.
            measure_pairwise_emd.setdefault(mode_2, {})[mode_1] = emd
    all_pairwise_emd[distance_measure] = measure_pairwise_emd


Save EMD dict

In [None]:
# Root directory for analysis outputs.
# NOTE(review): hardcoded absolute path — only resolves where this filesystem is mounted.
base_results_dir = Path("/allen/aics/animated-cell/Saurabh/cellpack-analysis/results/")
# Output folder for this analysis; created (with any missing parents) if it does not exist yet.
results_dir = base_results_dir / "stochastic_variation_analysis/mean_shape"
results_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Destination for the pickled EMD dictionary
file_path = results_dir / "packing_modes_pairwise_emd.dat"

# Serialize the nested EMD dictionary in binary mode
with file_path.open("wb") as f:
    pickle.dump(all_pairwise_emd, f)

In [None]:
# Compare the EMD distributions of the "RV" and "RVC" samples on one axes.
# NOTE(review): rv_values and rvc_values are not defined in any cell visible in
# this notebook — confirm where they come from before relying on this cell.
# `distance_measure` is whatever value the most recently run loop left behind.
fig, ax = plt.subplots()
# Draw on the explicit axes rather than relying on the implicit current axes.
sns.kdeplot(rv_values, ax=ax, color='blue', label="RV")
sns.kdeplot(rvc_values, ax=ax, color='red', label="RVC")
ax.set_xlabel('EMD')
ax.set_ylabel('pairwise EMD density')
ax.set_title(f"distance measure: {distance_measure}")
ax.legend()
plt.show()



In [None]:
# Mean and standard deviation of each sample, computed once and reused for
# both the printed summary and the bar plot.
rv_mean, rv_std = np.mean(rv_values), np.std(rv_values)
rvc_mean, rvc_std = np.mean(rvc_values), np.std(rvc_values)
print(f"RV: {rv_mean:0.3e} +/- {rv_std:0.3e}")
print(f"RVC: {rvc_mean:0.3e} +/- {rvc_std:0.3e}")

# Two-sample t-test with scipy's default settings.
t_stat, p_value = ttest_ind(rv_values, rvc_values)
print(f"t-statistic: {t_stat:0.3e}, p-value: {p_value:0.3e}")

# Bar heights are the means; error bars are the standard deviations.
x = ['RV', 'RVC']
means = [rv_mean, rvc_mean]
stds = [rv_std, rvc_std]
plt.bar(x, means, yerr=stds, capsize=5)

# Title carries the distance measure and the test result.
plt.title(f'Mean and Standard Deviation\nDistance Measure: {distance_measure}\nt-statistic: {t_stat:0.3e}, p-value: {p_value:0.3e}')
plt.xlabel('Packing Modes')
plt.ylabel('Values')

plt.show()
