In [1]:
import numpy as np
import json
import pickle
import matplotlib.pyplot as plt

from pathlib import Path

from scipy.spatial.distance import cdist, squareform
import matplotlib.pyplot as plt
from scipy.stats import wasserstein_distance, ttest_ind

from tqdm.notebook import tqdm
import trimesh
import seaborn as sns

from cellpack_analysis.utilities.data_tools import get_positions_dictionary_from_file, combine_multiple_seeds_to_dictionary

### Load in data

In [2]:
# Root of all cellPACK input data used by this notebook (site-specific absolute path).
base_datadir = Path("/allen/aics/animated-cell/Saurabh/cellpack-analysis") / "data"

In [11]:
packing_modes = ["mean_count_and_size", "variable_size", "variable_count", "variable_count_and_size", "shape"]

# all_positions[mode][seed] -> particle positions for that packing seed.
all_positions = {}
for mode in packing_modes:
    print(mode)
    data_folder = (
        base_datadir
        / "packing_outputs/stochastic_variation_analysis"
        / mode
        / "peroxisome/spheresSST"
    )
    if mode == "shape":
        # Return value is unused, so this is called for its side effect.
        # NOTE(review): presumably it merges per-seed outputs into the combined
        # positions JSON read below — confirm against data_tools.
        combine_multiple_seeds_to_dictionary(data_folder)

    file_path = data_folder / "positions_peroxisome_analyze_random_mean.json"
    all_positions[mode] = get_positions_dictionary_from_file(file_path)
    

mean_count_and_size
variable_size
variable_count
variable_count_and_size
shape


In [13]:
# Sanity check: number of particle positions recorded for each seed of one mode.
mode = "mean_count_and_size"
mode_positions = all_positions[mode]
for seed_key in mode_positions:
    print(seed_key, len(mode_positions[seed_key]))

0 121
1 121
2 121
3 121
4 121
5 121
6 121
7 121
8 121
9 121
10 121
11 121
12 121
13 121
14 121
15 121
16 121
17 121
18 121
19 121
20 121
21 121
22 121
23 121
24 121
25 121
26 121
27 121
28 121
29 121
30 121
31 121
32 121
33 121
34 121
35 121
36 121
37 121
38 121
39 121
40 121
41 121
42 121
43 121
44 121
45 121
46 121
47 121
48 121
49 121
50 121
51 121
52 121
53 121
54 121
55 121
56 121
57 121
58 121
59 121
60 121
61 121
62 121
63 121
64 121
65 121
66 121
67 121
68 121
69 121
70 121
71 121
72 121
73 121
74 121
75 121
76 121
77 121
78 121
79 121
80 121
81 121
82 121
83 121
84 121
85 121
86 121
87 121
88 121
89 121
90 121
91 121
92 121
93 121
94 121
95 121
96 121
97 121
98 121
99 121
100 121
101 121
102 121
103 121
104 121
105 121
106 121
107 121
108 121
109 121
110 121
111 121
112 121
113 121
114 121
115 121
116 121
117 121
118 121
119 121
120 121
121 121
122 121
123 121
124 121
125 121
126 121
127 121
128 121
129 121
130 121
131 121
132 121
133 121
134 121
135 121
136 121
137 121
138 12

### Calculate distance measures

In [14]:
# Average nuclear shape; used below as the reference surface for nucleus distances.
nuc_mesh_path = base_datadir.joinpath("average_shape_meshes", "nuc_mesh_mean.obj")
nuc_mesh = trimesh.load_mesh(str(nuc_mesh_path))

In [15]:
def _interparticle_distances(positions):
    """Return ``(pairwise, nearest)`` Euclidean distances for one packing.

    ``pairwise`` is the condensed vector of all i < j distances, in the same
    row-major upper-triangle order that ``scipy.spatial.distance.squareform``
    produces for a square distance matrix. ``nearest`` holds, for each particle,
    the distance to its closest *other* particle.
    """
    dist_matrix = cdist(positions, positions, metric="euclidean")
    rows, cols = np.triu_indices(len(dist_matrix), k=1)
    pairwise = dist_matrix[rows, cols]
    # Mask self-distances with inf rather than adding a large finite offset,
    # so the minimum is correct regardless of coordinate scale.
    np.fill_diagonal(dist_matrix, np.inf)
    nearest = dist_matrix.min(axis=1)
    return pairwise, nearest


def _normalize_by_max(values):
    """Scale ``values`` so that its largest entry becomes 1."""
    return values / np.max(values)


# Each dict is keyed [mode][seed]; every stored array is normalized by its own max.
# NOTE: signed_distance dominates the runtime here (a previous run of this cell
# was interrupted at 0/305 seeds); consider caching these results to disk.
all_pairwise_distances = {}  # pairwise distance between particles
all_nuc_distances = {}  # distance to nucleus surface
all_nearest_distances = {}  # distance to nearest neighbor
for mode, position_dict in all_positions.items():
    all_pairwise_distances[mode] = {}
    all_nuc_distances[mode] = {}
    all_nearest_distances[mode] = {}
    for seed, positions in tqdm(position_dict.items(), desc=mode):
        # trimesh documents points outside the mesh as having negative signed
        # distance, so negating gives a positive distance from the nuclear surface
        # for particles outside the nucleus.
        nuc_distances = -trimesh.proximity.signed_distance(nuc_mesh, positions)
        pairwise_distances, nearest_distances = _interparticle_distances(positions)
        all_nuc_distances[mode][seed] = _normalize_by_max(nuc_distances)
        all_nearest_distances[mode][seed] = _normalize_by_max(nearest_distances)
        all_pairwise_distances[mode][seed] = _normalize_by_max(pairwise_distances)


mean_count_and_size


  0%|          | 0/305 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# One row per packing mode, one column per distance measure; each panel overlays
# one KDE per seed so the seed-to-seed spread is visible.
distance_measures = {
    "pairwise": all_pairwise_distances,
    "nucleus": all_nuc_distances,
    "nearest": all_nearest_distances,
}
num_rows = len(all_positions)
num_cols = len(distance_measures)

# squeeze=False keeps `axs` 2-D even with a single packing mode, so `axs[i, j]`
# indexing always works.
fig, axs = plt.subplots(
    num_rows,
    num_cols,
    figsize=(num_cols * 3, num_rows * 3),
    dpi=300,
    squeeze=False,
)

for i, (mode, position_dict) in enumerate(all_positions.items()):
    cmap = plt.get_cmap("jet", len(position_dict))
    for j, (distance_measure, measure_dict) in enumerate(distance_measures.items()):
        ax = axs[i, j]
        for k, distances in enumerate(measure_dict[mode].values()):
            sns.kdeplot(distances, ax=ax, color=cmap(k), alpha=0.2)

        ax.set_title(f"{mode} {distance_measure}")
        ax.set_xlim([0, 1])
        ax.set_xlabel("Distance")
        ax.set_ylabel("PDF")

fig.tight_layout()

plt.show()


Get pairwise earth mover's distances (EMD)

In [None]:
from cellpack_analysis.utilities.data_tools import get_pairwise_wasserstein_distance_dict

In [None]:
# all_pairwise_emd[distance_measure][mode_a][mode_b] holds the result of
# get_pairwise_wasserstein_distance_dict for that pair of packing modes.
# Each unordered pair (including a mode with itself) is computed once, then
# stored under both [mode_a][mode_b] and [mode_b][mode_a].
# NOTE(review): the mirrored entry is the same object; confirm the helper's
# output does not depend on argument order (e.g. seed-pair key ordering).
all_pairwise_emd = {}
for distance_measure, distribution_dict in zip(
    ["pairwise", "nucleus", "nearest"],
    [all_pairwise_distances, all_nuc_distances, all_nearest_distances],
):
    print(distance_measure)
    measure_pairwise_emd = {}
    for idx, mode_a in enumerate(packing_modes):
        measure_pairwise_emd.setdefault(mode_a, {})
        distributions_a = distribution_dict[mode_a]
        # Upper triangle only (mode_b at or after mode_a); the lower half is mirrored.
        for mode_b in packing_modes[idx:]:
            print(mode_a, mode_b)
            # Passing None for the same mode asks the helper to compare within one dict.
            distributions_b = None if mode_a == mode_b else distribution_dict[mode_b]
            emd_result = get_pairwise_wasserstein_distance_dict(
                distributions_a, distributions_b,
            )
            measure_pairwise_emd[mode_a][mode_b] = emd_result
            measure_pairwise_emd.setdefault(mode_b, {})[mode_a] = emd_result
    all_pairwise_emd[distance_measure] = measure_pairwise_emd


Save EMD dictionary

In [None]:
# Output location for the stochastic-variation EMD results (created if absent).
base_results_dir = Path("/allen/aics/animated-cell/Saurabh/cellpack-analysis/results/")
results_dir = base_results_dir.joinpath("stochastic_variation_analysis", "mean_shape")
results_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Specify the file path
file_path = results_dir / "packing_modes_pairwise_emd.dat"

# Open the file in write mode and dump the dictionary
with open(file_path, "wb") as f:
    pickle.dump(all_pairwise_emd, f)

In [None]:
# Compare the density of EMD values for the "RV" and "RVC" groups.
# FIXME: `rv_values` and `rvc_values` are not defined anywhere in this notebook
# (hidden state from an earlier/deleted cell), and `distance_measure` is left over
# from an earlier loop (its last value is "nearest"). Define all three explicitly
# before this cell so it survives Restart & Run All.
fig, ax = plt.subplots()
sns.kdeplot(rv_values, ax=ax, color='blue', label="RV")
sns.kdeplot(rvc_values, ax=ax, color='red', label="RVC")
ax.set_xlabel('EMD')
ax.set_ylabel('pairwise EMD density')
ax.set_title(f"distance measure: {distance_measure}")
ax.legend()
plt.show()



In [None]:
# Summarize RV vs RVC EMD values: mean +/- std, a two-sample t-test, and a bar chart.
# FIXME: relies on `rv_values`, `rvc_values`, and `distance_measure`, which are not
# defined in this notebook's visible cells.
labels = ['RV', 'RVC']
groups = [rv_values, rvc_values]
means = [np.mean(values) for values in groups]
stds = [np.std(values) for values in groups]  # population std (numpy default ddof=0)

for label, mean, std in zip(labels, means, stds):
    print(f"{label}: {mean:0.3e} +/- {std:0.3e}")

# scipy default: Student's t-test assuming equal variances.
# NOTE(review): if these are pairwise EMDs between seeds, the values share seeds
# and are not independent samples, so treat the p-value as optimistic.
t_stat, p_value = ttest_ind(rv_values, rvc_values)
print(f"t-statistic: {t_stat:0.3e}, p-value: {p_value:0.3e}")

fig, ax = plt.subplots()
ax.bar(labels, means, yerr=stds, capsize=5)
ax.set_title(
    f'Mean and Standard Deviation\nDistance Measure: {distance_measure}\n'
    f't-statistic: {t_stat:0.3e}, p-value: {p_value:0.3e}'
)
ax.set_xlabel('Packing Modes')
ax.set_ylabel('EMD')

plt.show()
