In [None]:
import jax
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import os
import utils
import optax
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from tueplots import bundles, figsizes
# bundle = bundles.icml2024()
plt.rcParams.update(bundles.icml2024(usetex=False, family='sans-serif'))
from scipy.ndimage import rotate as rotation_fn
from scipy.ndimage import label, binary_dilation, generate_binary_structure
def shuffle(arr):
    """Return `arr` re-indexed along its first axis by a random permutation.

    The permutation is drawn from NumPy's global random state, so results are
    only reproducible if that state has been seeded by the caller.
    """
    n_items = arr.shape[0]
    # Sampling all indices without replacement is a full permutation.
    permuted_idx = np.random.choice(np.arange(n_items), size=n_items, replace=False)
    return arr[permuted_idx]
# Hyperparameters
# NOTE(review): most of these constants are not referenced by the cells visible
# in this part of the notebook; they are kept as-is for cells/modules that may use them.

BATCH_SIZE = 64
LEARNING_RATE = 1e-3
WEIGHT_DECAY=1e-3
STEPS = 676
PRINT_EVERY = 25
SEED = 12345678
USE_RESIDUAL = True
STRIDE_FIRST_LAYER=2

# Directory that receives the figures saved by this notebook.
FIGURE_SAVEDIR = 'Experiment_figures/'
# exist_ok avoids the check-then-create race and makes the cell safely re-runnable.
os.makedirs(FIGURE_SAVEDIR, exist_ok=True)

In [None]:
def load_data_batched_real(bs, data_dir, t=-1):
    """Load `bs` randomly chosen simulations from `data_dir`, each randomly rotated.

    Args:
        bs: number of files to draw (without replacement) from `data_dir`.
        data_dir: directory whose entries are passed to `utils.load_data_from_file`.
        t: index applied to the first (time) axis of each loaded array; `None`
            keeps the full time axis. The default `-1` selects the final state.

    Returns:
        Integer array with one leading entry per selected file. According to the
        original author's notes each loaded array is (t, 2, grid_size, grid_size),
        so the result is (bs, 2, grid, grid) for an integer `t`.

    Raises:
        ValueError: propagated from `np.random.choice` when `bs` exceeds the
            number of entries in `data_dir`.
    """
    files = os.listdir(data_dir)
    # select some random files to load:
    selected_files = np.random.choice(files, size=bs, replace=False)  # list of .npz files
    data = []
    for f in tqdm(selected_files):
        arr = utils.load_data_from_file(os.path.join(data_dir, f))
        # One random angle per simulation, shared by all channels (and time steps).
        angle = np.random.uniform(0, 360)
        # Pick the requested time step BEFORE rotating: the rotation acts on each
        # 2-D plane independently, so the result is unchanged but we no longer
        # rotate time steps that are thrown away immediately afterwards.
        if t is not None:
            arr = arr[t]
        # Rotate in the plane of the last two (spatial) axes; order=0 keeps the
        # integer labels intact (nearest neighbour, no interpolation).
        arr = rotation_fn(arr, angle, axes=(-2, -1), reshape=False, order=0)
        data.append(arr[None, ...])  # now shape (1, [t], 2, grid_size, grid_size)
    return np.concatenate(data).astype(int)

In [None]:
path = '../data/Exp_2_toda_padded/'
all_data = load_data_batched_real(len(os.listdir(path)),
        path, t=-1)

In [None]:
all_data[0].shape

In [None]:
# One 1x3 RGB row per colour, handed to utils.plot_cell_image.
colors = [
    np.array([[0., 0., 0.]]),  # black
    np.array([[0., 0., 0.25]]),  # dark blue
    np.array([[1., 0., 0.]]),  # red
    np.array([[204., 255., 11.]]) / 255.,  # light green
]

# Preview the first 15 loaded configurations on a 3 x 5 grid.
fig, axs = plt.subplots(3, 5)
for i, ax in enumerate(axs.ravel()):
    plotted = utils.plot_cell_image(all_data[i], ax, colors=colors)
    ax.axis('off')
plt.tight_layout()
# plt.savefig(FIGURE_SAVEDIR+'some_real_data_exp2.png')
plt.show()

In [None]:
all_data[0][0].max()

In [None]:
# Collect every '*.npz' archive that sits inside a '*.eqx' entry of the model directory.
base_dir = 'Exp2/neuralcpm_base/'
paths = [
    os.path.join(base_dir, p, f)
    for p in os.listdir(base_dir) if p.endswith('.eqx')
    for f in os.listdir(os.path.join(base_dir, p)) if f.endswith('.npz')
]


# Index the opened archives and their 'all_samples' / 'all_energies' arrays by archive path.
datas, samples, energies = {}, {}, {}
for path in paths:
    data = np.load(path)
    datas[path] = data
    samples[path] = data['all_samples']
    energies[path] = data['all_energies']
    print(path, samples[path].shape)

In [None]:
def calc_moment(datapoint, moment=1, types=(1,2), num_cells=41, rotate_alignment_type=None):
    """Compute a moment of the cell centres of mass, separately per cell type.

    Args:
        datapoint: array of shape (2, h, w); per the original author's note the
            first channel holds the cell id and the second the cell type.
        moment: 0 counts the non-NaN centre-of-mass entries, 1 gives their mean,
            any other value gives the central moment of that order.
        types: cell-type values to evaluate (one output row each).
        num_cells: forwarded to `utils.get_cells_com`.
        rotate_alignment_type: if this value is one of `types`, all centres of
            mass are first expressed in the principal-axis frame of that type's
            centres of mass (largest-variance axis first).

    Returns:
        Array of shape (len(types), 2) with the requested moment per type and
        coordinate.
    """
    # Centres of mass per type; all pixels of other types are zeroed out first.
    # The first row of the helper's output is dropped. Per the original note the
    # helper yields NaN for non-existing cells, in (x, y) format.
    coms_types = {}
    for cell_type in types:
        only_this_type = jnp.where((datapoint[1] == cell_type)[None, ...], datapoint, 0)
        coms_types[cell_type] = utils.get_cells_com(only_this_type, num_cells)[1:]

    if rotate_alignment_type in types:
        # Principal axes of the reference type define the new coordinate frame.
        ref_all = coms_types[rotate_alignment_type]
        ref_valid = ref_all[~jnp.isnan(ref_all).any(axis=1)]
        if len(ref_valid) <= 1:
            # Too few cells for a covariance: make everything NaN on purpose.
            ref_valid = ref_all.at[:].set(jnp.nan)
        mean = jnp.nanmean(ref_valid, axis=0)
        cov = jnp.cov(ref_valid - mean, rowvar=False)
        eigvals, eigvecs = jnp.linalg.eigh(cov)
        # Reorder eigenvectors so the largest eigenvalue comes first.
        eigvecs = eigvecs[:, jnp.argsort(eigvals)[::-1]]
        coms_types = {k: jnp.dot(v - mean, eigvecs) for k, v in coms_types.items()}

    # Moments in the (possibly rotated) coordinate system, one row per type.
    moments = jnp.zeros((len(types), 2))
    for row, cell_type in enumerate(types):
        type_coms = coms_types[cell_type]
        if moment == 0:
            value = jnp.sum(~jnp.isnan(type_coms), axis=0)
        elif moment == 1:
            value = jnp.nanmean(type_coms, axis=0)
        else:
            centre = jnp.nanmean(type_coms, axis=0)
            value = jnp.nanmean((type_coms - centre) ** moment, axis=0)
        moments = moments.at[row].set(value)
    return moments  # shape types, coords


In [None]:
# print('zeroth moment training data:')
# m0_real = jax.vmap(calc_moment, in_axes=(0, None))(all_data, 0).mean(0)

In [None]:
# print('first moment training data:')
# m1_real = jax.vmap(calc_moment, in_axes=(0, None))(all_data, 1).mean(0)

In [None]:
print('second moment training data:')
# Second central moment of every training configuration, aligned on cell type 2.
m2s_real = [calc_moment(d, 2, rotate_alignment_type=2) for d in tqdm(all_data)]
# m2_real = jax.vmap(calc_moment, in_axes=(0, None, None, None, None))(all_data, 2,(1,2),41,1).mean(0)

In [None]:
m2_real = jnp.nanmean(jnp.stack(m2s_real, axis=0), axis=0)

In [None]:
m2_samples = {}   # model name -> NaN-mean of the per-sample second moments
m2s_samples = {}  # model name -> stacked per-sample second moments
# The reference data may have been added to `samples` by a later cell; drop it here.
samples.pop('data_real', None)
for k, v in samples.items():
    sample_final = v[:, -1]  # last time step of every sampled trajectory
    print(k)
    m2s_sample_this = [calc_moment(d, 2, rotate_alignment_type=2) for d in tqdm(sample_final)]
    # Use the parent directory name as model name when the key is a path.
    model_name = k.split("/")[-2] if '/' in k else k
    stacked = jnp.stack(m2s_sample_this, axis=0)
    m2_samples[model_name] = jnp.nanmean(stacked, axis=0)
    m2s_samples[model_name] = stacked

In [None]:
list(m2_samples.values())[0]

In [None]:
m2s_real_arr = jnp.stack(m2s_real, axis=0)
m2s_real_arr.shape

In [None]:
# rebuttal: test if mean alignment is identical to training data:
from scipy.stats import ttest_ind, kstest, false_discovery_control, shapiro

# types = (0,1)
# axes = (0,1)
#
# for k in m2_samples.keys():
#     print(f'\n {k} \n')
#     for type in types:
#         for ax in axes:
#             m2s_real_this = m2s_real_arr[:, type, ax]
#             m2s_sample_this = m2s_samples[k][:, type, ax]
#             # stat, p = ttest_ind(m2s_real_this, m2s_sample_this, equal_var=False, nan_policy='omit')
#             stat, p = kstest(m2s_real_this, m2s_sample_this, nan_policy='omit')
#             print(f'type {type}, axis {ax}: p-value: {p}')


# calculate RMSE based on sub-sampled data:
# every model's samples are split into consecutive chunks of `subsamp_size`
# (the last chunk may be shorter); each chunk yields one RMSE against m2_real.
subsamp_size=10
all_rmses = {}
for k in m2_samples.keys():
    rmses = []
    m2s_samples_key = m2s_samples[k]
    for i_start in range(0, m2s_samples_key.shape[0], subsamp_size):
        m2s_sample_this = m2s_samples_key[i_start:i_start+subsamp_size]
        m2_mean_this = jnp.nanmean(m2s_sample_this, axis=0)
        rmse = jnp.sqrt(jnp.mean((m2_mean_this - m2_real) ** 2))
        rmses.append(rmse)
    all_rmses[k] = np.array(rmses)


# Welch t-test of every other model against the baseline model.
base_key = 'experiment_2_nch3_2950.eqx'
keys = []
p_values = []
base = all_rmses.pop(base_key)  # all_rmses now only holds the models compared to the baseline
num_comparisons = len(all_rmses)


for k, v in all_rmses.items():
    # print(jnp.mean(v), jnp.std(v))
    stat, p = ttest_ind(base, v, equal_var=False, nan_policy='omit')
    # bonferroni correction: multiply by the number of comparisons, but cap at 1
    # because the corrected value is still a probability (np.minimum keeps NaN).
    p = np.minimum(p * num_comparisons, 1.0)
    p_values.append(p)
    keys.append(k)
    print(k, stat, p)

print('\n\n')
# test normality:
stat, p = shapiro(base)
print('normality test:', stat, p)
for k, v in all_rmses.items():
    # test normality:
    stat, p = shapiro(v)
    print(k, 'normality test:', stat, p)


In [None]:
all_rmses.keys()

In [None]:
# RMSE between each model's mean second moment and the training-data one.
print('RMSE')
for k in m2_samples:
    v = m2_samples[k]
    squared_error = (v - m2_real) ** 2
    rmse = jnp.sqrt(jnp.mean(squared_error))
    print(k, rmse)


# Same comparison after scaling each type's moments so they sum to one.
print('normalized RMSE')
for k in m2_samples:
    v = m2_samples[k]
    m2_real_norm = m2_real / jnp.sum(m2_real, axis=1, keepdims=True)
    v_norm = v / jnp.sum(v, axis=1, keepdims=True)
    squared_error = (v_norm - m2_real_norm) ** 2
    rmse = jnp.sqrt(jnp.mean(squared_error))
    print(k, rmse)

In [None]:
# Qualitative figures: each of the four figures shows one randomly drawn
# trajectory per model (one row per model, one column per plotted time step).
stepsize = 20
max_steps = 1000

samples.pop('real_data', None)



names_of_models = ['experiment_2_cellsort_9950.eqx', 'experiment_2_ext_pot_9950.eqx', 'experiment_2_conv_ham_500.eqx',
                   'experiment_2_shallow_nh_400.eqx', 'experiment_2_nh_3200.eqx', 'experiment_2_nch3_2950.eqx']
names_to_plot = ['Cellsort\nHamiltonian', 'Cellsort\nHamiltonian\n+External\nPotential', 'CNN', '1-layer\nNeural\nHamiltonian\n+CNN',
                 'Neural\nHamiltonian', 'Neural\nHamiltonian\n+closure']

qual_dir = os.path.join(FIGURE_SAVEDIR, 'add_qual_exp2')
os.makedirs(qual_dir, exist_ok=True)

# The reference data is stored in `samples` under 'data_real' (not 'real_data'),
# so it is excluded from the model lookup here without mutating `samples`.
keys = [key for key in samples.keys() if key != 'data_real']

for plot_id in range(4):
    data_to_plot = []

    for name in names_of_models:
        matches = [key for key in keys if name in key]
        if not matches:
            raise KeyError(f"no samples loaded for model '{name}'")
        k = matches[0]
        print(name, k)
        sample = samples[k]
        data_this = shuffle(sample)[0:1]  # one random trajectory, batch axis kept
        data_to_plot.append(data_this[:, :100])  # at most the first 100 time steps

    data_to_plot = np.concatenate(data_to_plot, axis=0)
    ts_to_plot = [i * stepsize for i in range(min(data_to_plot.shape[1] // stepsize, max_steps))]
    mcs_per_t = 134 * 50 / (data_to_plot.shape[-1] * data_to_plot.shape[-2])

    # One label per plotted column, so the labelling cannot run out of entries.
    time_labels = [i * float(np.round(750 / 4/60, 1)) for i in range(len(ts_to_plot))]
    fig, axs = plt.subplots(len(data_to_plot), len(ts_to_plot), figsize=(len(ts_to_plot)*1, len(data_to_plot)), squeeze=False, gridspec_kw={'wspace': 0.05, 'hspace': 0.05})


    print(data_to_plot.shape)
    print(axs.shape)
    utils.plot_cell_trajectory_data(data_to_plot, len(data_to_plot), ts_to_plot, axs, colors=colors)


    # Adjust axis labels and formatting
    for i, ax_row in enumerate(axs):
        for j, ax in enumerate(ax_row):
            ax.axis("on")  # Explicitly enable axes for adding labels
            ax.set_xticks([])
            ax.set_yticks([])
            if j == 0:  # Add row labels to the left of the subplots
                ax.set_ylabel(
                    names_to_plot[i], fontsize=8, labelpad=20,
                    rotation=90, va="center", ha="center"
                )
            if i == len(axs) - 1:  # Add x-axis labels below the bottom row
                ax.set_xlabel(f"{time_labels[j]}", fontsize=8, labelpad=10)

    # rect is (left, bottom, right, top) in figure coordinates. The previous value
    # [0.05, 0.05, 0.05, 0.05] described an empty rectangle, so matplotlib skipped
    # the layout with a warning; keep a 0.05 margin on every side instead.
    plt.tight_layout(rect=(0.05, 0.05, 0.95, 0.95))

    # Add a global x-axis label for time in hours
    fig.text(0.5, 0.04, 'Time (hours)', ha='center', va='center', fontsize=9)
    plt.savefig(os.path.join(qual_dir, f'traj_2_qual_{plot_id}.pdf'), dpi=400)
    plt.show()

print('\n')

In [None]:
axs.shape

In [None]:
ts_to_plot

In [None]:
# calculate the stable states:
# the number of distinct values in the first channel (cell ids per the notes in
# calc_moment) is used as the number of cells.
cell_id_maps = all_data[:, 0]
num_cells = len(np.unique(cell_id_maps))

# Per-sample cell volumes; the first column is dropped (ignore medium).
calc_volumes_all = jax.vmap(utils.calculate_all_cell_volumes, in_axes=(0, None))
volumes = calc_volumes_all(all_data, num_cells)[:, 1:]

# Smallest and largest cell volume found in the reference data.
vol_low = np.min(volumes)
vol_high = np.max(volumes)

In [None]:
vol_low, vol_high

In [None]:
list(samples.values())[0].shape

In [None]:
samples['data_real'] = all_data[:, None]

In [None]:
stable_dict = {}
frag_dict = {}  # samples key -> fragmentation counts, shape (n_samples, n_evaluated_steps)
vol_dict = {}   # samples key -> cell volumes, shape (n_samples, n_evaluated_steps, ...)

neighborhood_order=4
from scipy.ndimage import label



def calc_frag_batched(samples, num_cells, neighborhood_order): # this is not transformable under jax unfortunately
    """Apply `utils.count_num_fragmented` to every entry of `samples`.

    Returns a float array with one value per entry along the first axis.
    """
    counts = np.zeros(samples.shape[0])
    for idx, single in enumerate(samples):
        counts[idx] = utils.count_num_fragmented(single, num_cells, neighborhood_order)
    return counts

# Batched (vmapped + jitted) version of the volume computation.
calc_vol_batched = eqx.filter_jit(jax.vmap(
            utils.calculate_all_cell_volumes, in_axes=(0, None)
        ))
for k, v in samples.items():
    print(k)
    sample = v
    stable_times = []
    frag_times = []
    vol_times = []
    # Evaluate every 5th time step only.
    for t in tqdm(range(0, sample.shape[1], 5)):
        frame = sample[:, t]
        frag = calc_frag_batched(np.array(frame), num_cells, neighborhood_order)
        vol = calc_vol_batched(frame, num_cells)[:, 1:]  # first column dropped
        # Insert a time axis so the per-step results can be concatenated along it.
        frag_times.append(frag[:, None, ...])
        vol_times.append(vol[:, None, ...])
    frag_dict[k] = np.concatenate(frag_times, axis=1)
    vol_dict[k] = np.concatenate(vol_times, axis=1)

In [None]:
# Thresholds: a sample counts as "stable" at a time step when no more than
# max_num_vol_violate cell volumes fall outside [min_vol, max_vol] and no more
# than max_num_fragmented fragments are counted.
min_vol = vol_low - 15
max_vol = vol_high + 15
max_num_vol_violate = 0
max_num_fragmented = 3

fig, ax = plt.subplots(1, 3)

for k, v in frag_dict.items():
    num_frag = v
    vol = vol_dict[k]
    vol_good = jnp.logical_or((vol < min_vol), (vol > max_vol)).sum(-1) <= max_num_vol_violate
    frag_good = num_frag <= max_num_fragmented
    stable = jnp.logical_and(frag_good, vol_good)
    # stable = np.cumprod(stable, axis=1)
    # Fraction of samples satisfying each criterion, per evaluated time step.
    ax[0].plot(stable.mean(0), label=k)
    ax[1].plot(vol_good.mean(0))
    ax[2].plot(frag_good.mean(0))
    name = k.split('/')[-2] if '/' in k else k
    print(name, 'vol: ', vol_good.mean(), 'frag: ',frag_good.mean(), 'stable: ',stable.mean())

# Build the legend once, after all curves exist; the anchor places it below the
# first panel.
if frag_dict:
    ax[0].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), shadow=True, ncol=2)
# plt.legend()
plt.show()

In [None]:
# load the data provided by Toda, preprocessed by Lutz:
from PIL import Image
data_toda_dir = 'toda_experiment_data/Toda2018_Fig.S6B_5samples_2channels_frames_equilib'

# first get the prefix to match red and green channel.
# Only '<prefix>_red.png' / '<prefix>_green.png' files are considered, and the
# prefixes are sorted: later cells pick images by position (imgs[1], imgs[3],
# imgs[-1]), which was not reproducible with the arbitrary iteration order of a set.
channel_suffixes = ('_red.png', '_green.png')
prefixes = sorted({
    fname.rsplit('_', 1)[0]
    for fname in os.listdir(data_toda_dir)
    if fname.endswith(channel_suffixes)
})
# load the images
imgs = []
for p in prefixes:
    print(p)
    # Read each channel inside a context manager so the file handle is released.
    with Image.open(os.path.join(data_toda_dir, p+'_red.png')) as red_file:
        red = np.array(red_file)
    with Image.open(os.path.join(data_toda_dir, p+'_green.png')) as green_file:
        green = np.array(green_file)
    blue = np.zeros_like(green)  # empty third channel to obtain an RGB image
    im = np.stack([red, green, blue], axis=-1)
    imgs.append(im)




In [None]:
def align_img_green_ch(im):
    """Rotate `im` according to the principal axis of its green channel.

    The coordinates of all non-zero green pixels (channel index 1) are collected,
    their largest-variance direction is determined, and the whole image is rotated
    by the angle of that direction; the intent is to align green horizontally.
    Rotation uses nearest-neighbour sampling and keeps the image shape.

    Args:
        im: array of shape (h, w, channels) with at least two channels.

    Returns:
        The rotated image. If fewer than two green pixels exist, no orientation
        can be estimated and an unrotated copy of `im` is returned.
    """
    green = im[..., 1]
    green_coords = np.stack(np.nonzero(green), axis=-1)  # (n_pixels, 2) as (row, col)
    if green_coords.shape[0] < 2:
        # A covariance needs at least two points; previously this produced NaNs.
        return im.copy()
    centred = green_coords - np.mean(green_coords, axis=0)
    cov = np.cov(centred, rowvar=False)
    eigvals, eigvecs = np.linalg.eigh(cov)
    # sort eigenvectors by eigenvalues, largest first
    eigvecs = eigvecs[:, np.argsort(eigvals)[::-1]]
    # find the rotation corresponding to the eigenvectors -- align green horizontally:
    angle = np.arctan2(eigvecs[0, 0], eigvecs[1, 0]) * 180 / np.pi
    # rotate the image:
    return rotation_fn(im, angle, reshape=False, order=0)




def calc_moment_image(im, moment=1):
    """Intensity-weighted moments of every channel of the green-aligned image.

    Pixel coordinates are scaled by the image size (index / size). For
    `moment <= 1` the weighted mean position is returned, otherwise the weighted
    central moment of order `moment`.

    Args:
        im: array of shape (h, w, channels); it is aligned with
            `align_img_green_ch` before the moments are evaluated.
        moment: order of the moment.

    Returns:
        Float array of shape (channels, 2) holding the (x, y) moment per channel.
        Channels whose intensities sum to zero yield NaN.
    """
    im = align_img_green_ch(im)
    moments = np.zeros((im.shape[-1], 2))  # shape: [num channels/cell types, 2(x,y)]
    h_range = np.arange(im.shape[0]) / im.shape[0]
    w_range = np.arange(im.shape[1]) / im.shape[1]
    for c in range(im.shape[-1]):
        ch_this = im[..., c]
        total = np.sum(ch_this)
        if total == 0:
            # Empty channel: same NaN result as the former 0/0 division, but
            # without the RuntimeWarning.
            moments[c] = np.nan
            continue
        E_h = (ch_this * h_range[:, None]).sum() / total
        E_w = (ch_this * w_range[None, :]).sum() / total
        if moment > 1:
            m_h = (ch_this * (h_range[:, None] - E_h)**moment).sum() / total
            m_w = (ch_this * (w_range[None, :] - E_w)**moment).sum() / total
        else:
            m_h, m_w = E_h, E_w
        moments[c] = m_w, m_h  # x, y moments
    return moments


In [None]:
# Visual check of thresholding + alignment on the first three images:
# left the thresholded image, right the same image after alignment.
thresh=165
for i in (0, 1, 2):
    fig, axs = plt.subplots(1,2)
    im = imgs[i].copy()
    keep_mask = im > thresh
    im = im * keep_mask  # zero every value that is not above the threshold
    axs[0].imshow(im)
    axs[1].imshow(align_img_green_ch(im))
    plt.show()
    # m2 = calc_moment_image(im, moment=2)
    # m2_norm = m2 / np.sum(m2, axis=1, keepdims=True)
    # print(m2, '\n', m2_norm)

In [None]:
# plot a synthetic and real-world example side to side:
# panels 0-1 show two of the loaded lab images, panels 2-3 two synthetic configurations.
colors = [
    np.array([[0., 0., 0.]]),  # black
    np.array([[0., 0., 0.25]]),  # dark blue
    np.array([[1., 0., 0.]]),  # red
    np.array([[204., 255., 11.]]) / 255.,  # light green
]
fig, axs = plt.subplots(1, 4, figsize=(4,1.1),  gridspec_kw={'wspace': 0.05, 'hspace': 0.05})
im = imgs[1].copy()


# First lab image keeps its axes (ticks removed) so it can carry the row label.
axs[0].imshow(imgs[1].copy())
axs[0].set_xticks([])
axs[0].set_yticks([])
axs[0].set_ylabel(
    '\nBipolar sorting', fontsize=7.8, labelpad=10,
    rotation=90, va="center", ha="center"
)

axs[1].imshow(imgs[3].copy())
axs[1].axis('off')

# Two hand-picked synthetic configurations.
for data_idx, synth_ax in ((4, axs[2]), (18, axs[3])):
    im_synth = all_data[data_idx].copy()
    utils.plot_cell_image(im_synth, synth_ax, colors=colors)
    synth_ax.axis('off')


# Group captions below the two pairs of panels
fig.text(0.325, 0.05, 'Lab observations', fontsize=8, ha='center', va='top')
fig.text(0.72, 0.05, 'Synthetic training data', fontsize=8, ha='center', va='top')


plt.tight_layout()
plt.savefig(FIGURE_SAVEDIR+'exp_2_real_and_synth_example.png', dpi=400, transparent=True)
plt.show()




In [None]:
calc_moment_image(imgs[-1], moment=2)

In [None]:
# Second moments of the thresholded lab images; the normalised variant divides
# each channel's (x, y) moments by their sum.
moments, moments_norm = [], []
thresh=165
for i, im in enumerate(imgs):
    keep_mask = im > thresh
    im_thresh = im * keep_mask
    m2 = calc_moment_image(im_thresh, moment=2)
    m2_norm = m2 / m2.sum(axis=1, keepdims=True)
    moments.append(m2)
    moments_norm.append(m2_norm)

In [None]:
np.mean(np.stack(moments_norm, axis=0), axis=0)

In [None]:
np.std(np.stack(moments_norm, axis=0), axis=0)

In [None]:
# calculate moments for our training data:
# every configuration is converted to an RGB image (type 1 -> red, type 2 -> green)
# and then processed exactly like the lab images.
moments_synth, moments_synth_norm = [], []
for d in tqdm(all_data):
    # shape (2, 149, 149)
    type_map = d[1]
    im = np.zeros((d.shape[1], d.shape[2], 3)).astype(int)
    for channel, cell_type in enumerate((1, 2)):
        im[..., channel] = 255 * (type_map == cell_type)
    im[..., 2] = 0
    assert 0 < thresh < 255
    im_thresh = im * (im > thresh)
    m2 = calc_moment_image(im_thresh, moment=2)
    m2_norm = m2 / m2.sum(axis=1, keepdims=True)
    moments_synth.append(m2)
    moments_synth_norm.append(m2_norm)

In [None]:
np.mean(np.stack(moments_synth_norm, axis=0), axis=0)

In [None]:
np.std(np.stack(moments_synth_norm, axis=0), axis=0)

In [None]:
# calculate moments for our sampled data from nch3 model:
# the final time step of every sampled trajectory is converted to an RGB image
# (type 1 -> red, type 2 -> green) and processed like the lab images.
moments_model, moments_model_norm = [], []
nch3_samples = samples['Exp2/neuralcpm_base/experiment_2_nch3_2950.eqx/samples_PermuteTypeInitializer_134_100_50.npz']
for d_all_t in tqdm(nch3_samples):
    d = d_all_t[-1]
    # shape (2, 149, 149)
    type_map = d[1]
    im = np.zeros((d.shape[1], d.shape[2], 3)).astype(int)
    for channel, cell_type in enumerate((1, 2)):
        im[..., channel] = 255 * (type_map == cell_type)
    im[..., 2] = 0
    assert 0 < thresh < 255
    im_thresh = im * (im > thresh)
    m2 = calc_moment_image(im_thresh, moment=2)
    m2_norm = m2 / m2.sum(axis=1, keepdims=True)
    moments_model.append(m2)
    moments_model_norm.append(m2_norm)

In [None]:
np.mean(np.stack(moments_model_norm, axis=0), axis=0)


In [None]:
# plot the moments of toda, morpheus, and our model as mean + error bar of 1 std, for both
# type 1 and type 2. Column 0 of the moment arrays is the x moment of the image after
# calc_moment_image has aligned it on the green channel (type 2 in the synthetic images).
from matplotlib.patches import Patch

mean_toda = np.mean(np.stack(moments_norm, axis=0), axis=0)[:2, 0] # shape (2) for type 1 and 2
std_toda = np.std(np.stack(moments_norm, axis=0), axis=0)[:2, 0]
mean_synth = np.mean(np.stack(moments_synth_norm, axis=0), axis=0)[:2, 0]
std_synth = np.std(np.stack(moments_synth_norm, axis=0), axis=0)[:2, 0]
mean_model = np.mean(np.stack(moments_model_norm, axis=0), axis=0)[:2, 0]
std_model = np.std(np.stack(moments_model_norm, axis=0), axis=0)[:2, 0]


# Data preparation
labels = ['Type 1', 'Type 2']
x = np.arange(len(labels))  # [0, 1]
width = 0.25  # Bar width

# Means and standard deviations for each group
means = [mean_toda, mean_synth, mean_model]
stds = [std_toda, std_synth, std_model]
# Named bar_colors (not `colors`) so the RGB list that other cells pass to the
# utils plotting helpers is not replaced by hex strings.
bar_colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
group_labels = ['Toda et al., 2018 (real data)', 'Synthetic configurations', 'NeuralCPM (simulated)']
hatches = ['\\', '/', '|']  # More subtle hatching

#Plotting:

fsize = figsizes.icml2024_half(ncols=1, nrows=1)
fsize['figure.figsize'] = (fsize['figure.figsize'][0] * 1, fsize['figure.figsize'][1] * 0.8)
with plt.rc_context(fsize):
    fig, ax = plt.subplots()

    # The loop variable is called group_label so that it does not overwrite the
    # `label` function imported from scipy.ndimage.
    for i, (mean, std, color, group_label, hatch) in enumerate(zip(means, stds, bar_colors, group_labels, hatches)):
        ax.bar(
            x + i * width - width,  # Shift bars for each group
            mean,
            yerr=std,  # Add error bars
            width=width,
            label=group_label,
            color=color,
            alpha=0.75,  # Slight transparency
            capsize=4,  # Error bar caps
            edgecolor='black',  # Add edge for contrast
            linewidth=0.7,  # Thin border,
            hatch=hatch
        )

    # Axes labels and title
    ax.set_ylabel('Variance fraction\nalong polar axis')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)

    ax.legend(frameon=False)
    # ax.grid(axis='y', linestyle='--', alpha=0.6)

    # Display the plot
    plt.tight_layout()
    plt.savefig(FIGURE_SAVEDIR+'moments_comparison_T_final.pdf', dpi=400)
    plt.show()



In [None]:
# now calculate the moments' development over time in the movie (shared by Toda, preprocessed by Lutz),
# our synthetic data, and our samples

In [None]:
data_toda_dir = 'toda_experiment_data/Toda2018_separate_channels_dynamics'

frames = range(1, 51)
# load the images: one RGB array per movie frame, with the 'red' file in channel 0,
# the 'yellow' file in channel 1 and an empty channel 2.
imgs_movie = []
for f in frames:
    # print(f)
    this_frame = []
    for channel in ['red', 'yellow']:
        frame_path = os.path.join(data_toda_dir, f'Toda2018_{channel}_channel_frame_{f:02d}.png')
        # The context manager releases the file handle once the pixels are copied.
        with Image.open(frame_path) as frame_file:
            im = np.array(frame_file)
        this_frame.append(im)
    this_frame = np.stack(this_frame, axis=-1)
    blue = np.zeros_like(this_frame[...,:1])
    im = np.concatenate([this_frame, blue], axis=-1)
    imgs_movie.append(im)

In [None]:
len(imgs_movie)

In [None]:
# Visual check of thresholding + alignment on movie frames 1, 21 and 41.
thresh=25 # 20
for i in (0, 20, 40):
    fig, axs = plt.subplots(1,2)
    im = imgs_movie[i].copy()
    keep_mask = im > thresh
    im = im * keep_mask  # zero every value that is not above the threshold
    axs[0].imshow(im)
    axs[0].set_ylabel(f'{i+1}')  # 1-based frame number
    axs[1].imshow(align_img_green_ch(im))
    # axs[2].hist(im.flatten())
    # axs[2].set_yscale('log')
    plt.show()

In [None]:
# a single plot with one qualitative example for each model, with the Toda et al.
# movie frames in the top row:
colors = [
            np.array([[0.,0.,0.]]),# black
            np.array([[0.,0.,0.25]]),# dark blue
            np.array([[1.,0.,0.]]), #  red
            np.array([[204.,255.,11.]]) / 255. / 1.15 #  light green
        ]

# qualitative plots for true and generated samples:
stepsize = 20
max_steps = 1000

samples.pop('real_data', None)


data_to_plot = []
names_of_models = ['experiment_2_cellsort_9950.eqx', 'experiment_2_ext_pot_9950.eqx', 'experiment_2_conv_ham_500.eqx',
                   'experiment_2_shallow_nh_400.eqx', 'experiment_2_nh_3200.eqx', 'experiment_2_nch3_2950.eqx']
names_to_plot = ['Cellsort\nHamiltonian', 'Cellsort\nHamiltonian\n+External\nPotential', 'CNN', '1 NH layer\n+CNN',
                 'Neural\nHamiltonian', 'Neural\nHamiltonian\n+closure']

# The reference data is stored in `samples` under 'data_real' (not 'real_data'),
# so it is excluded from the model lookup here without mutating `samples`.
keys = [key for key in samples.keys() if key != 'data_real']


for name in names_of_models:
    matches = [key for key in keys if name in key]
    if not matches:
        raise KeyError(f"no samples loaded for model '{name}'")
    k = matches[0]
    print(name, k)
    sample = samples[k]
    data_this = shuffle(sample)[0:1]  # one random trajectory, batch axis kept
    data_this = data_this[:, :100]    # at most the first 100 time steps
    data_to_plot.append(data_this)

data_to_plot = np.concatenate(data_to_plot, axis=0)
ts_to_plot = [i * stepsize for i in range(min(data_to_plot.shape[1] // stepsize, max_steps))]
# One label per plotted column, so the labelling cannot run out of entries.
time_labels = [i * float(np.round(750 / 4/60, 1)) for i in range(len(ts_to_plot))]

# One extra row at the top for the real movie frames.
fig, axs = plt.subplots(len(data_to_plot)+1, len(ts_to_plot), figsize=(len(ts_to_plot)*1, len(data_to_plot)+1), squeeze=False, gridspec_kw={'wspace': 0.05, 'hspace': 0.05})

num_mcs_per_t = 134 * 50 / (data_to_plot.shape[-1] * data_to_plot.shape[-2])

# plot toda et al movie: every 10th frame, one per column

for i in range(len(ts_to_plot)):
    im = imgs_movie[i*10].copy()
    axs[0, i].imshow(im)
    if i == 0:
        axs[0, i].set_ylabel('Toda et al.,\n2018\n(real data)', fontsize=8, labelpad=20,
                    rotation=90, va="center", ha="center")
    axs[0, i].axis("on")  # Explicitly enable axes for adding labels
    axs[0, i].set_xticks([])
    axs[0, i].set_yticks([])


model_axs = axs[1:]
utils.plot_cell_trajectory_data(data_to_plot, len(data_to_plot), ts_to_plot, model_axs, colors=colors)


# Adjust axis labels and formatting
for i, ax_row in enumerate(model_axs):
    for j, ax in enumerate(ax_row):
        ax.axis("on")  # Explicitly enable axes for adding labels
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:  # Add row labels to the left of the subplots
            ax.set_ylabel(
                names_to_plot[i], fontsize=8, labelpad=20,
                rotation=90, va="center", ha="center"
            )
        if i == len(model_axs) - 1:  # Add x-axis labels below the bottom row
            ax.set_xlabel(f"{time_labels[j]}", fontsize=8, labelpad=10)

# rect is (left, bottom, right, top) in figure coordinates. The previous value
# [0.05, 0.05, 0.05, 0.05] described an empty rectangle, so matplotlib skipped
# the layout with a warning; keep a 0.05 margin on every side instead.
plt.tight_layout(rect=(0.05, 0.05, 0.95, 0.95))

# Add a global x-axis label for time in hours
fig.text(0.5, 0.04, 'Time (hours)', ha='center', va='center', fontsize=9)
plt.savefig(FIGURE_SAVEDIR+'traj_exp2-incl-toda.pdf', dpi=400)
plt.show()

print('\n')

In [None]:
#REBUTTAL:
# Qualitative figure for the rebuttal: one sampled trajectory per row (three
# rows per model) and one column per plotted time step.

# a single plot with one qualitative example for each model:
colors = [
            np.array([[0.,0.,0.]]),# black
            np.array([[0.,0.,0.25]]),# dark blue
            np.array([[1.,0.,0.]]), #  red
            np.array([[204.,255.,11.]]) / 255. / 1.15 #  light green
        ]

# qualitative plots for true and generated samples:
stepsize = 20  # plot every 20th time step
max_steps = 1000  # upper bound on the number of plotted columns

# Remove the 'real_data' entry (if present) before looking models up by name.
if 'real_data' in samples:
    samples.pop('real_data')


data_to_plot = []
names_of_models = ['experiment_2_closure_gnn_6100.eqx'] * 3 + ['experiment_2_nch_no_interactions_3700.eqx'] * 3
names_to_plot = 3*['GNN\n+ closure'] + 3*['NH + closure\nno interactions']

keys = list(samples.keys())

# Shuffle each model's samples once and hand out consecutive entries, so rows
# that repeat a model do not show the same trajectory twice (independent draws
# per row could pick the same sample again).
shuffled_samples = {}
rows_used = {}
for name in names_of_models:
    matching_keys = [key for key in keys if name in key]
    if not matching_keys:
        raise KeyError(f"no entry in `samples` matches model '{name}'")
    k = matching_keys[0]
    print(name, k)
    if k not in shuffled_samples:
        shuffled_samples[k] = shuffle(samples[k])
        rows_used[k] = 0
    pool = shuffled_samples[k]
    # wrap around only if a model is listed more often than it has samples
    pick = rows_used[k] % len(pool)
    rows_used[k] += 1
    data_this = pool[pick:pick + 1, :100]  # one trajectory, first 100 time steps
    data_to_plot.append(data_this)

data_to_plot = np.concatenate(data_to_plot, axis=0)
ts_to_plot = [i * stepsize for i in range(min(data_to_plot.shape[1] // stepsize, max_steps))]
# One label per plotted column. Scale first and round last: rounding the
# 750 / 4 / 60 = 3.125 h step to 3.1 before multiplying makes the labels drift
# (9.3 and 12.4 instead of 9.4 and 12.5).
time_labels = [round(i * 750 / 4 / 60, 1) for i in range(len(ts_to_plot))]

fig, axs = plt.subplots(len(data_to_plot), len(ts_to_plot), figsize=(len(ts_to_plot)*1, len(data_to_plot)+1), squeeze=False, gridspec_kw={'wspace': 0.05, 'hspace': 0.05})

# NOTE(review): not used anywhere in this cell — confirm whether another cell reads it.
num_mcs_per_t = 134 * 50 / (data_to_plot.shape[-1] * data_to_plot.shape[-2])



utils.plot_cell_trajectory_data(data_to_plot, len(data_to_plot), ts_to_plot, axs, colors=colors)


# Adjust axis labels and formatting
for i, ax_row in enumerate(axs):
    for j, ax in enumerate(ax_row):
        ax.axis("on")  # Explicitly enable axes for adding labels
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:  # Add row labels to the left of the subplots
            ax.set_ylabel(
                names_to_plot[i], fontsize=8, labelpad=20,
                rotation=90, va="center", ha="center"
            )
        if i == len(axs)-1:  # Add x-axis labels below the bottom row
            ax.set_xlabel(f"{time_labels[j]}", fontsize=8, labelpad=10)

# Adjust layout and show the plot.
# `rect` is (left, bottom, right, top) in figure fractions. The previous value
# [0.05, 0.05, 0.05, 0.05] described an empty rectangle, so matplotlib skipped
# the layout with a warning. Keep the left and bottom 5% of the figure free
# (the global x-axis label below is placed at y=0.04).
plt.tight_layout(rect=[0.05, 0.05, 1.0, 1.0])
# tight_layout recomputes the gaps between axes; restore the compact grid
# spacing that was requested through gridspec_kw.
plt.subplots_adjust(wspace=0.05, hspace=0.05)

# Add a global x-axis label for time in hours
fig.text(0.5, 0.04, 'Time (hours)', ha='center', va='center', fontsize=9)
plt.savefig(FIGURE_SAVEDIR+'traj_exp2_rebuttal.pdf', dpi=400)
plt.show()

print('\n')



In [None]:
# Second-order moments of every thresholded movie frame, raw and normalised.
moments_movie = []
moments_movie_norm = []
for movie_frame in imgs_movie:
    # zero out every pixel that does not exceed `thresh`
    frame_thresh = movie_frame * (movie_frame > thresh)
    frame_m2 = calc_moment_image(frame_thresh, moment=2)
    moments_movie.append(frame_m2)
    # divide by the sum along axis 1 so those entries add up to one
    moments_movie_norm.append(frame_m2 / np.sum(frame_m2, axis=1, keepdims=True))

In [None]:
# moments_movie_norm[0]

In [None]:
# plot the normalized moment in the x direction for both types
to_plot_movie = []
for m in moments_movie_norm:
    to_plot_movie.append(m[:2, 0])
to_plot_movie = np.stack(to_plot_movie, axis=0)

In [None]:
# Quick look at the movie moments: one line per column of to_plot_movie.
movie_ax = plt.gca()
movie_ax.plot(to_plot_movie)
plt.show()

In [None]:
# Now get the morpheus and sampled data and make a similar plot.
# t=None keeps the time axis of every trajectory; only 100 randomly chosen
# files are loaded (rather than len(os.listdir(path))).

path = '../data/Exp_2_toda_padded/'
all_data_incl_time = load_data_batched_real(bs=100, data_dir=path, t=None)


In [None]:
# Normalised second moments for every time step of every loaded trajectory.
moments_dynamic_morpheus_norm = []
for trajectory in tqdm(all_data_incl_time):
    trajectory_moments = []
    for frame in trajectory:
        label_map = frame[1]  # second channel; compared against the values 1 and 2 below
        # Three-channel integer image: channel 0 marks value 1, channel 1 marks
        # value 2, channel 2 stays zero.
        rgb = np.zeros((frame.shape[1], frame.shape[2], 3), dtype=int)
        rgb[..., 0] = np.where(label_map == 1, 255, 0)
        rgb[..., 1] = np.where(label_map == 2, 255, 0)
        assert 0 < thresh < 255
        rgb_thresh = rgb * (rgb > thresh)
        frame_m2 = calc_moment_image(rgb_thresh, moment=2)
        frame_m2_norm = frame_m2 / np.sum(frame_m2, axis=1, keepdims=True)
        # keep column 0 of the first two rows, as done for the movie frames
        trajectory_moments.append(frame_m2_norm[:2, 0])
    moments_dynamic_morpheus_norm.append(trajectory_moments)
moments_dynamic_morpheus_norm = np.array(moments_dynamic_morpheus_norm)

In [None]:
# Mean over the leading (trajectory) axis of the normalised moments.
morpheus_mean = moments_dynamic_morpheus_norm.mean(0)
morpheus_ax = plt.gca()
morpheus_ax.plot(morpheus_mean)
# plt.fill_between(range(moments_dynamic_morpheus_norm.shape[1]),
#                     moments_dynamic_morpheus_norm.mean(0) - moments_dynamic_morpheus_norm.std(0),
#                     moments_dynamic_morpheus_norm.mean(0) + moments_dynamic_morpheus_norm.std(0),
#                     alpha=0.2)
plt.show()

In [None]:
# Normalised second moments for every time step of every NeuralCPM sample.
neuralcpm_key = 'Exp2/neuralcpm_base/experiment_2_nch3_2950.eqx/samples_PermuteTypeInitializer_134_100_50.npz'
moments_dynamic_samples_norm = []
for trajectory in tqdm(samples[neuralcpm_key]):
    trajectory_moments = []
    for frame in trajectory:
        label_map = frame[1]  # second channel; compared against the values 1 and 2 below
        # Three-channel integer image: channel 0 marks value 1, channel 1 marks
        # value 2, channel 2 stays zero.
        rgb = np.zeros((frame.shape[1], frame.shape[2], 3), dtype=int)
        rgb[..., 0] = np.where(label_map == 1, 255, 0)
        rgb[..., 1] = np.where(label_map == 2, 255, 0)
        assert 0 < thresh < 255
        rgb_thresh = rgb * (rgb > thresh)
        frame_m2 = calc_moment_image(rgb_thresh, moment=2)
        frame_m2_norm = frame_m2 / np.sum(frame_m2, axis=1, keepdims=True)
        # keep column 0 of the first two rows, as done for the movie frames
        trajectory_moments.append(frame_m2_norm[:2, 0])
    moments_dynamic_samples_norm.append(trajectory_moments)
moments_dynamic_samples_norm = np.array(moments_dynamic_samples_norm)

In [None]:
# Average over axis 0 of the last entry along axis 1 (the final time step).
np.mean(moments_dynamic_samples_norm[:, -1], axis=0)

In [None]:

# Moment dynamics of the NeuralCPM samples (mean with a +/- one std band)
# against the values measured on the movie (star markers).
fsize = figsizes.icml2024_half(ncols=1, nrows=1)
base_width, base_height = fsize['figure.figsize']
fsize['figure.figsize'] = (base_width * 1, base_height * 0.8)
with plt.rc_context(fsize):

    colors = [
                np.array([[0.,0.,0.]]),# black
                np.array([[0.,0.,0.25]]),# dark blue
                np.array([[1.,0.,0.]]), #  red
                np.array([[204.,255.,11.]]) / 255. / 1.15 #  light green
            ]

    moment_ax = plt.gca()
    # one time stamp per movie frame, 15 units apart (labelled as minutes below)
    times = [15 * frame_idx for frame_idx in range(to_plot_movie.shape[0])]
    for i in range(2):
        # red for the first type; the green of the second type is darkened once more
        c = colors[i+2] / 1.15 if i == 1 else colors[i+2]
        # every second time step, last one excluded — presumably so the samples
        # have one value per movie frame
        m_this_type_sample = moments_dynamic_samples_norm[:, :-1:2, i]
        m_this_type_movie = to_plot_movie[:, i]
        sample_mean = m_this_type_sample.mean(0)
        sample_std = m_this_type_sample.std(0)
        moment_ax.plot(times, sample_mean, color=c, label=f'type {str(i+1)} - NeuralCPM')
        moment_ax.fill_between(times,
                               sample_mean - sample_std,
                               sample_mean + sample_std,
                               alpha=0.2, color=c)
        moment_ax.plot(times, m_this_type_movie, color=c, linestyle='None', marker='*', label=f'type {str(i+1)} - real data\n(Toda et al., 2018)', markersize=6)
    moment_ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.55), ncol=2, frameon=False)
    moment_ax.set_xlabel('Time (minutes)')
    moment_ax.set_ylabel('Variance fraction\nalong polar axis')
    # plt.ylim(0.1,0.8)
    plt.savefig(FIGURE_SAVEDIR+'moment_dynamics-neuralcpm-toda.pdf')
    plt.show()

In [None]:
# same but for cellsort simulator:
# NOTE(review): this reuses the name `moments_dynamic_samples_norm` that the
# NeuralCPM cell also assigns; re-running the NeuralCPM plotting cell after
# this one would plot the Cellsort values — confirm the intended run order.

cellsort_key = 'Exp2/neuralcpm_base/experiment_2_cellsort_9950.eqx/samples_PermuteTypeInitializer_134_100_50.npz'
moments_dynamic_samples_norm = []
for trajectory in tqdm(samples[cellsort_key]):
    trajectory_moments = []
    for frame in trajectory:
        label_map = frame[1]  # second channel; compared against the values 1 and 2 below
        # Three-channel integer image: channel 0 marks value 1, channel 1 marks
        # value 2, channel 2 stays zero.
        rgb = np.zeros((frame.shape[1], frame.shape[2], 3), dtype=int)
        rgb[..., 0] = np.where(label_map == 1, 255, 0)
        rgb[..., 1] = np.where(label_map == 2, 255, 0)
        assert 0 < thresh < 255
        rgb_thresh = rgb * (rgb > thresh)
        frame_m2 = calc_moment_image(rgb_thresh, moment=2)
        frame_m2_norm = frame_m2 / np.sum(frame_m2, axis=1, keepdims=True)
        # keep column 0 of the first two rows, as done for the movie frames
        trajectory_moments.append(frame_m2_norm[:2, 0])
    moments_dynamic_samples_norm.append(trajectory_moments)
moments_dynamic_samples_norm = np.array(moments_dynamic_samples_norm)

fsize = figsizes.icml2024_half(ncols=1, nrows=1)
base_width, base_height = fsize['figure.figsize']
fsize['figure.figsize'] = (base_width * 1, base_height * 1)
with plt.rc_context(fsize):

    colors = [
                np.array([[0.,0.,0.]]),# black
                np.array([[0.,0.,0.25]]),# dark blue
                np.array([[1.,0.,0.]]), #  red
                np.array([[204.,255.,11.]]) / 255. / 1.15 #  light green
            ]

    moment_ax = plt.gca()
    # one time stamp per movie frame, 15 units apart (labelled as minutes below)
    times = [15 * frame_idx for frame_idx in range(to_plot_movie.shape[0])]
    for i in range(2):
        # red for the first type; the green of the second type is darkened once more
        c = colors[i+2] / 1.15 if i == 1 else colors[i+2]
        # every second time step, last one excluded — presumably so the samples
        # have one value per movie frame
        m_this_type_sample = moments_dynamic_samples_norm[:, :-1:2, i]
        m_this_type_movie = to_plot_movie[:, i]
        sample_mean = m_this_type_sample.mean(0)
        sample_std = m_this_type_sample.std(0)
        moment_ax.plot(times, sample_mean, color=c, label=f'type {str(i+1)} - Cellsort')
        moment_ax.fill_between(times,
                               sample_mean - sample_std,
                               sample_mean + sample_std,
                               alpha=0.2, color=c)
        moment_ax.plot(times, m_this_type_movie, color=c, linestyle='None', marker='*', label=f'type {str(i+1)} - real data\n(Toda et al., 2018)', markersize=6)
    moment_ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.55), ncol=2, frameon=False)
    moment_ax.set_xlabel('Time (minutes)')
    moment_ax.set_ylabel('Variance fraction\nalong polar axis')
    moment_ax.set_ylim(0.1, 0.8)
    plt.savefig(FIGURE_SAVEDIR+'moment_dynamics-cellsort-toda.pdf', dpi=400)
    plt.show()



In [None]:
# Shape of the second row of data_to_plot, kept as a length-one slice.
# NOTE(review): data_to_plot comes from whichever plotting cell ran last.
np.shape(data_to_plot[1:2])

In [None]:
# Intro figure: movie frames in the top row, followed by one sampled
# trajectory per model; one column per plotted time step.

# a single plot with one qualitative example for each model:
colors = [
            np.array([[0.,0.,0.]]),# black
            np.array([[0.,0.,0.25]]),# dark blue
            np.array([[1.,0.,0.]]), #  red
            np.array([[204.,255.,11.]]) / 255. / 1.15 #  light green
        ]

# qualitative plots for true and generated samples:
stepsize = 20  # plot every 20th time step
max_steps = 1000  # upper bound on the number of plotted columns

# Remove the 'real_data' entry (if present) before looking models up by name.
if 'real_data' in samples:
    samples.pop('real_data')


data_to_plot = []
names_of_models = ['experiment_2_cellsort_9950.eqx', 'experiment_2_nch3_2950.eqx']
names_to_plot = ['CPM', 'NeuralCPM']

keys = list(samples.keys())


for name in names_of_models:
    matching_keys = [key for key in keys if name in key]
    if not matching_keys:
        raise KeyError(f"no entry in `samples` matches model '{name}'")
    k = matching_keys[0]
    print(name, k)
    sample = samples[k]
    data_this = shuffle(sample)[0:1]  # one randomly chosen trajectory
    data_this = data_this[:, :100]  # first 100 time steps
    data_to_plot.append(data_this)

data_to_plot = np.concatenate(data_to_plot, axis=0)
ts_to_plot = [i * stepsize for i in range(min(data_to_plot.shape[1] // stepsize, max_steps))]
fig, axs = plt.subplots(len(data_to_plot)+1, len(ts_to_plot), figsize=(len(ts_to_plot)*1, len(data_to_plot)+1), squeeze=False, gridspec_kw={'wspace': 0.05, 'hspace': 0.05})


# plot toda et al movie:
# One movie frame per column of the grid (the grid has len(ts_to_plot) columns,
# which is 5 for 100-step trajectories), so the loop follows the grid width
# instead of a hard-coded 5.
for i in range(len(ts_to_plot)):
    im = imgs_movie[i*10 + 5].copy()
    axs[0, i].imshow(im)
    if i == 0:
        axs[0, i].set_ylabel('Laboratory\nobservations', fontsize=8, labelpad=20,
                    rotation=90, va="center", ha="center")
    axs[0, i].axis("on")  # Explicitly enable axes for adding labels
    axs[0, i].set_xticks([])
    axs[0, i].set_yticks([])


utils.plot_cell_trajectory_data(data_to_plot, len(data_to_plot), ts_to_plot, axs[1:], colors=colors)


# Adjust axis labels and formatting
for i, ax_row in enumerate(axs[1:]):
    for j, ax in enumerate(ax_row):
        ax.axis("on")  # Explicitly enable axes for adding labels
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:  # Add row labels to the left of the subplots
            ax.set_ylabel(
                names_to_plot[i], fontsize=8, labelpad=20,
                rotation=90, va="center", ha="center"
            )


# Adjust layout and show the plot.
# `rect` is (left, bottom, right, top) in figure fractions. The previous value
# [0.05, 0.05, 0.05, 0.05] described an empty rectangle, so matplotlib skipped
# the layout with a warning. Keep a 5% margin on the left and bottom instead.
plt.tight_layout(rect=[0.05, 0.05, 1.0, 1.0])
# tight_layout recomputes the gaps between axes; restore the compact grid
# spacing that was requested through gridspec_kw.
plt.subplots_adjust(wspace=0.05, hspace=0.05)

# No global x-axis label in this figure:
# fig.text(0.5, 0.04, 'Time (MCS)', ha='center', va='center', fontsize=9)
plt.savefig(FIGURE_SAVEDIR+'intro_figure.pdf', dpi=400)
plt.show()

print('\n')