In [None]:
import jax
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import os
import utils
import optax
import time
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from tueplots import bundles
# bundle = bundles.icml2024()
plt.rcParams.update(bundles.icml2024(usetex=False, family='sans-serif'))
# Hyperparameters and output locations used throughout the notebook.

BATCH_SIZE = 64
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-3
STEPS = 676
PRINT_EVERY = 25
SEED = 12345678  # seeds the JAX PRNG key that initialises the model
USE_RESIDUAL = True
STRIDE_FIRST_LAYER = 2  # any value other than 1 enables the initial 2x2 max-pool in CNN

# Directory that receives the figures saved by this notebook.
FIGURE_SAVEDIR = 'Experiment_figures/'
# exist_ok=True removes the check-then-create race and keeps the cell idempotent.
os.makedirs(FIGURE_SAVEDIR, exist_ok=True)

In [None]:
def load_data_batch_for_classifier(bs, data_dir, t=-1, replace=True):
    """Load a randomly drawn batch of simulations stored as ``.npz`` files.

    Every file must provide a ``data`` array and a ``label`` entry.

    Args:
        bs: Number of files to draw from ``data_dir``.
        data_dir: Directory whose entries are the ``.npz`` files to sample from.
        t: Index applied to the leading axis of each ``data`` array (the
            default ``-1`` keeps its last entry). Pass ``None`` to keep the
            complete array.
        replace: Whether file names are drawn with replacement. The default
            ``True`` is the historical behaviour, so requesting as many files
            as the directory holds does NOT return each file exactly once;
            pass ``False`` for a duplicate-free draw.

    Returns:
        A tuple ``(data, labels)`` of NumPy arrays, both with a leading batch
        axis of length ``bs``.
    """
    files = os.listdir(data_dir)
    # File names are drawn through NumPy's global random state.
    selected_files = np.random.choice(files, size=bs, replace=replace)
    data = []
    labels = []
    for f in selected_files:
        # The context manager closes the archive once both entries are read;
        # a bare np.load would keep one file handle open per sampled file.
        with np.load(os.path.join(data_dir, f)) as npz:
            arr = npz['data']
            label = npz['label']
        if t is not None:
            arr = arr[t]
        data.append(arr[None, ...])  # prepend the batch axis
        labels.append(np.array([label]))
    return np.concatenate(data), np.concatenate(labels)

In [None]:
# test_arr, test_label = load_data_batch_for_classifier(1, '../data/MNIST_jaxCPM_data', t=-1)

In [None]:
# Load the last stored frame (t=-1) of as many simulations as the directory has files.
# NOTE(review): the loader draws file names with replacement, so this array can
# contain duplicated simulations and miss others — confirm this is intended,
# because the train/val/test split further down slices this very array.
path = '../data/MNIST_jaxCPM_data'
all_data, all_labels = load_data_batch_for_classifier(len(os.listdir(path)),
        path, t=-1)

In [None]:
# Cast the loaded data to float32 (rebinds `all_data`; re-running is harmless).
all_data = all_data.astype(np.float32)

In [None]:
# Preview the first 15 loaded samples on a 3x5 grid and save the figure.
fig, axs = plt.subplots(nrows=3, ncols=5)
for i, ax in enumerate(axs.flat):
    plotted = utils.plot_cell_image(all_data[i], ax)
    ax.axis('off')
plt.tight_layout()
plt.savefig(f"{FIGURE_SAVEDIR}some_real_cpm_mnist.png")
plt.show()

In [None]:
# Plot four example Cellular-MNIST images side by side.
# Colour table handed to the plotting helper: one (1, 3) RGB row per entry.
colors = [
    np.array([[0., 0., 0.]]),              # black
    np.array([[0., 0., 0.25]]),            # dark blue
    np.array([[1., 0., 0.]]),              # red
    np.array([[204., 255., 11.]]) / 255.,  # light green
]

# NOTE(review): these are hard-coded positions in `all_data`, whose order comes
# from an unseeded random draw — confirm the intended images are still selected.
idx_to_plot = [14, 5, 8, 13]
fig, axs = plt.subplots(
    1, 4, figsize=(4, 1), gridspec_kw={'wspace': 0.05, 'hspace': 0.05}
)
for i, ax in enumerate(axs.flat):
    idx = idx_to_plot[i]
    utils.plot_cell_image(all_data[idx], ax, colors=colors)
    ax.set_xticks([])
    ax.set_yticks([])
    if i == 0:
        # Only the left-most panel carries the row label.
        ax.set_ylabel(
            'Cellular\nMNIST',
            fontsize=8,
            labelpad=10,
            rotation=90,
            va="center",
            ha="center",
        )

# Captions naming the two groups of examples ("open" and "enclosed" structures).
fig.text(0.325, 0.05, 'Open structures', fontsize=8, ha='center', va='top')
fig.text(0.72, 0.05, 'Enclosed structures', fontsize=8, ha='center', va='top')

plt.tight_layout()
plt.savefig(f"{FIGURE_SAVEDIR}real_cellular_mnist.png", dpi=400, transparent=True)
plt.show()

In [None]:
# Split the samples 80% / 10% / 10% into training, validation and test sets,
# in the order in which they appear in `all_data`.
train_idx = int(0.8 * len(all_data))
val_idx = int(0.9 * len(all_data))
test_idx = len(all_data)

# Keep only channel 1 (the cell-type channel); slicing with 1:2 preserves the axis.
train_data, train_labels = all_data[:train_idx, 1:2], all_labels[:train_idx]
val_data, val_labels = all_data[train_idx:val_idx, 1:2], all_labels[train_idx:val_idx]
test_data, test_labels = all_data[val_idx:test_idx, 1:2], all_labels[val_idx:test_idx]

In [None]:
def dataloader(data, labels, bs, infinite=False):
    """Yield consecutive ``(data, labels)`` mini-batches of at most ``bs`` items.

    Batches are contiguous slices taken in order, without shuffling; the last
    batch of a pass may be shorter than ``bs``. With ``infinite=True`` the pass
    over the data restarts from the beginning indefinitely.
    """
    while True:
        for start in range(0, len(data), bs):
            stop = start + bs
            yield data[start:stop], labels[start:stop]
        if not infinite:
            return

In [None]:
# Plot the first 8 samples of the first training batch together with their labels.
first_pass = dataloader(train_data, train_labels, 16)
for x, y in first_pass:
    fig, axs = plt.subplots(nrows=2, ncols=4)
    for i, ax in enumerate(axs.flat):
        ax.imshow(x[i, 0])
        ax.set_title(f"Label: {y[i]}")
        ax.axis('off')
    plt.show()
    break  # only the first batch is visualised

In [None]:
class CNN(eqx.Module):
    """Convolutional classifier producing 10 log-probabilities per input image.

    The forward pass one-hot encodes channel 0 of the input into 3 planes,
    optionally max-pools it, runs a stack of convolutional stages, averages
    over the spatial axes, runs an MLP, and ends with a linear layer followed
    by ``log_softmax``. The model handles a single (unbatched) sample; callers
    in this notebook batch it with ``jax.vmap``.
    """
    # Flat list of conv-stage callables; each stage contributes
    # Conv2d, silu, Conv2d, silu, MaxPool2d (five entries).
    conv_layers: list
    # Stride-2 projection convolutions for skip connections, keyed by an index
    # into `conv_layers` (empty when use_residual is False).
    residual_conv_layers: dict
    # Whether the projection layers for skip connections are created.
    use_residual: bool
    # Flat list of MLP callables; each hidden layer contributes Linear, silu.
    mlp_layers: list
    # Linear projections for MLP skip connections, keyed by an index into `mlp_layers`.
    residual_mlp_layers: dict
    # Identity, or a 2x2 max-pool applied before the first conv stage.
    first_strided_layer: callable
    # Final Linear(num_ch, 10) followed by log_softmax.
    map_to_lsm: callable

    def __init__(self, key, conv_layer_channels=(32,64,64,64), mlp_layer_channels=(64,32), use_residual=False, stride_first_layer=1, conv_kwargs={}):
        """Build all layers, drawing each layer's parameters from ``key``.

        Args:
            key: JAX PRNG key; it is split once per parameterised layer, so the
                order of the statements below determines the initial weights.
            conv_layer_channels: Output channels of each convolutional stage.
            mlp_layer_channels: Output widths of each hidden MLP layer.
            use_residual: Whether to create the skip-connection projections.
            stride_first_layer: ``1`` keeps the input resolution; any other
                value inserts a 2x2 max-pool with stride 2 up front.
            conv_kwargs: Not used by this constructor.
        """
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.

        self.conv_layers=[]
        self.mlp_layers = []

        self.residual_conv_layers = {}
        self.residual_mlp_layers = {}
        self.use_residual = use_residual
        # The input is one-hot encoded into 3 planes in get_final_emb.
        num_ch = 3

        # `use_key` from this split is not consumed, but the split still
        # advances `key` and therefore affects every later initialisation.
        key, use_key = jax.random.split(key)
        self.first_strided_layer = eqx.nn.Identity() if stride_first_layer == 1 else eqx.nn.MaxPool2d(kernel_size=2, stride=2)


        # Counts entries appended to conv_layers so far; used as the residual key.
        layer_counter = 0
        for i, ch_next in enumerate(conv_layer_channels):
            key, use_key = jax.random.split(key)
            self.conv_layers.append(eqx.nn.Conv2d(num_ch, ch_next, kernel_size=3, key=use_key, padding='SAME'))
            self.conv_layers.append(jax.nn.silu)
            key, use_key = jax.random.split(key)
            self.conv_layers.append(eqx.nn.Conv2d(ch_next, ch_next, kernel_size=3, key=use_key, padding='SAME'))
            self.conv_layers.append(jax.nn.silu)
            self.conv_layers.append(eqx.nn.MaxPool2d(kernel_size=2, stride=2))
            layer_counter += 5

            if self.use_residual:
                key, use_key = jax.random.split(key)
                # Keyed by the index just past this stage (5, 10, ...). The
                # kernel-2/stride-2 conv halves the resolution like the max-pool.
                self.residual_conv_layers[layer_counter] = eqx.nn.Conv2d(num_ch, ch_next, kernel_size=2, stride=2, key=use_key)

            num_ch = ch_next

        # Restart the counter: MLP residual keys index into mlp_layers.
        layer_counter = 0
        for i, ch_next in enumerate(mlp_layer_channels):
            key, use_key = jax.random.split(key)
            self.mlp_layers.append(eqx.nn.Linear(num_ch, ch_next, key=use_key))
            self.mlp_layers.append(jax.nn.silu)
            layer_counter += 2
            if self.use_residual:
                key, use_key = jax.random.split(key)
                # Keyed by the index just past this hidden layer (2, 4, ...).
                self.residual_mlp_layers[layer_counter] = eqx.nn.Linear(num_ch, ch_next, key=use_key)
            num_ch = ch_next

        key, use_key = jax.random.split(key)
        self.map_to_lsm = eqx.nn.Sequential([eqx.nn.Linear(num_ch, 10, key=use_key),
                                            eqx.nn.Lambda(jax.nn.log_softmax)])


    def __call__(self, x):
        """Return the 10 log-softmax class scores for a single input ``x``."""
        x = self.get_final_emb(x)
        return self.map_to_lsm(x)

    def get_final_emb(self, x):
        """Return the output of the last MLP layer, i.e. the input to ``map_to_lsm``.

        Only ``x[0]`` is read. It is one-hot encoded with 3 classes, so this
        assumes ``x[0]`` holds integer-valued ids in {0, 1, 2} — TODO confirm.
        """
        x = jax.nn.one_hot(x[0], 3)
        # one_hot appends the class axis last; move it to the front (channels first).
        x = jnp.permute_dims(x, (2,0,1))
        x = self.first_strided_layer(x)
        last_x_for_residual = x
        for i, layer in enumerate(self.conv_layers):
            # A projection stored under key i is added right before
            # conv_layers[i] runs, i.e. at the start of the following stage.
            # NOTE(review): the key written for the final stage equals
            # len(conv_layers), which enumerate never reaches, so that
            # projection is never applied — confirm this is intended.
            if i in self.residual_conv_layers:
                x = x + self.residual_conv_layers[i](last_x_for_residual)
                last_x_for_residual = x
            x = layer(x)

        # Average over the two spatial axes, leaving one value per channel.
        x = jnp.mean(x, axis=(-2, -1))
        last_x_for_residual = x

        for i, layer in enumerate(self.mlp_layers):
            # Same scheme as above; the key of the last hidden layer equals
            # len(mlp_layers) and is likewise never reached.
            if i in self.residual_mlp_layers:
                x = x + self.residual_mlp_layers[i](last_x_for_residual)
                last_x_for_residual = x
            x = layer(x)

        return x

        

In [None]:
# Build the classifier: two conv stages (32 and 64 channels) and an MLP with
# hidden widths 256 and 32, initialised from the fixed SEED.
model = CNN(jax.random.PRNGKey(SEED), conv_layer_channels=(32,64), mlp_layer_channels=(256,32),
            use_residual=USE_RESIDUAL, stride_first_layer=STRIDE_FIRST_LAYER)

In [None]:
# Show the layer structure of the freshly built model.
print(model)

In [None]:
def cross_entropy(y, pred_y):
    """Mean negative log-likelihood of the targets ``y`` under ``pred_y``.

    Args:
        y: Integer class indices, one per sample.
        pred_y: Log-softmax'd predictions with one row per sample and one
            column per class.
    """
    # Select, for every row, the log-probability of that row's true class.
    target_column = y[:, None]
    picked_log_probs = jnp.take_along_axis(pred_y, target_column, axis=1)
    return -jnp.mean(picked_log_probs)

def loss(model: CNN, x, y):
    """Cross-entropy of ``model`` on one batch ``(x, y)``.

    The model handles a single sample, so it is mapped over the leading
    (batch) axis of ``x`` with ``jax.vmap`` before the loss is computed.
    """
    batched_model = jax.vmap(model)
    return cross_entropy(y, batched_model(x))

In [None]:
# Replace the plain loss function by its Equinox-jitted version.
loss = eqx.filter_jit(loss)

@eqx.filter_jit
def compute_accuracy(model: CNN, x, y):
    """Return the fraction of samples in the batch ``(x, y)`` that the
    model classifies correctly.
    """
    batch_scores = jax.vmap(model)(x)
    predicted_class = jnp.argmax(batch_scores, axis=1)
    return jnp.mean(y == predicted_class)


In [None]:
def evaluate(model: CNN, test_data_iterable):
    """Evaluate the model on a dataset, returning its mean loss and accuracy.

    Args:
        model: The classifier to evaluate.
        test_data_iterable: Iterable yielding ``(x, y)`` batches.

    Returns:
        A tuple ``(avg_loss, avg_acc)``. Both are averaged per sample: each
        batch mean is weighted by its batch length, so a shorter final batch
        does not count as much as a full one.

    Raises:
        ValueError: If the iterable yields no samples.
    """
    total_loss = 0.0
    total_acc = 0.0
    num_seen = 0
    for x, y in test_data_iterable:
        batch_len = len(x)
        # `loss` and `compute_accuracy` are JIT-wrapped and return batch means;
        # scaling by the batch length turns them back into per-batch sums.
        total_loss += loss(model, x, y) * batch_len
        total_acc += compute_accuracy(model, x, y) * batch_len
        num_seen += batch_len
    if num_seen == 0:
        raise ValueError("evaluate() received no samples to evaluate on.")
    return total_loss / num_seen, total_acc / num_seen

In [None]:
# Sanity check: loss and accuracy of the freshly initialised model on the validation set.
evaluate(model, dataloader(val_data, val_labels, 32))

In [None]:
# AdamW optimiser configured from the hyperparameter cell.
optim = optax.adamw(LEARNING_RATE, weight_decay=WEIGHT_DECAY)


In [None]:
def train(
    model: CNN,
    train_data_iterable,
    train_labels_iterable,
    test_data_iterable,
    test_labels_iterable,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
):
    """Train ``model`` for ``steps`` optimisation steps and return it.

    Args:
        model: The classifier to optimise.
        train_data_iterable: Training inputs. Despite the name they are sliced
            into batches by ``dataloader``, so they must support ``len`` and
            slicing.
        train_labels_iterable: Training labels aligned with the inputs.
        test_data_iterable: Held-out inputs used for the periodic evaluation.
        test_labels_iterable: Held-out labels aligned with those inputs.
        optim: Optax gradient transformation used for the updates.
        steps: Number of optimisation steps; batches of ``BATCH_SIZE`` are
            drawn by cycling through the training data.
        print_every: Evaluate and print every this many steps (the last step
            is always reported).

    Returns:
        The updated model.
    """
    # Only the array leaves of the model are trainable, so filter out the rest.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Gradient computation, optimiser update and parameter update share one
    # JIT region.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state,
        x,
        y,
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    trainloader = dataloader(train_data_iterable, train_labels_iterable, BATCH_SIZE, infinite=True)

    for step, (x, y) in zip(range(steps), trainloader):
        tic = time.perf_counter()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        # JAX dispatches work asynchronously: without waiting for the result the
        # interval below would only cover the dispatch, not the computation.
        train_loss.block_until_ready()
        toc = time.perf_counter()
        if (step % print_every) == 0 or (step == steps - 1):
            testloader = dataloader(test_data_iterable, test_labels_iterable, BATCH_SIZE)
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
                f" time={toc-tic} s"
            )
    return model

In [None]:
# model = train(model, train_data, train_labels, val_data, val_labels, optim, STEPS, PRINT_EVERY)


In [None]:
# eqx.tree_serialise_leaves('CPM_MNIST_classifier.eqx', model)

In [None]:
# Test-set loss and accuracy of the model currently in memory.
# NOTE(review): the training call above is commented out, so on a fresh run
# this evaluates the untrained model.
evaluate(model, dataloader(test_data, test_labels, 32))


In [None]:
  ### load the model:

In [None]:
# Load saved weights into the existing model structure.
# NOTE(review): requires 'CPM_MNIST_classifier.eqx' in the working directory;
# the cell that would write it is commented out above.
model = eqx.tree_deserialise_leaves('CPM_MNIST_classifier.eqx', model)

In [None]:
# Test-set loss and accuracy of the model loaded from disk.
evaluate(model, dataloader(test_data, test_labels, 32))

In [None]:
# now calculate the inception score with the model on the training data itself:

def inception_score(model, data, labels, bs, num_samples=1000):
    """Compute the inception score of ``data`` under the classifier ``model``.

    The score is ``exp(mean_x KL(p(y|x) || p(y)))`` where ``p(y|x)`` are the
    class probabilities predicted for each sample and ``p(y)`` is their mean
    over all evaluated samples.

    Args:
        model: Classifier returning log-probabilities for a single sample.
        data: Array of samples; indexed along its first axis.
        labels: Array aligned with ``data``. It is only passed through the
            batching helper and does not influence the score.
        bs: Batch size used when running the model.
        num_samples: Number of samples drawn without replacement. Values
            larger than ``len(data)`` are clamped to ``len(data)``.

    Returns:
        The inception score as a scalar JAX array.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("inception_score() needs at least one sample.")
    # Drawing more samples than exist is impossible without replacement.
    num_samples = min(num_samples, len(data))
    idx = np.random.choice(len(data), num_samples, replace=False)
    data = data[idx]
    labels = labels[idx]

    log_probs = [jax.vmap(model)(x) for x, _ in dataloader(data, labels, bs)]
    preds = jnp.exp(jnp.concatenate(log_probs))  # from log_softmax to softmax
    preds = preds / jnp.sum(preds, axis=1, keepdims=True)

    # The same epsilon guards both logarithms: a class whose marginal underflows
    # to zero would otherwise give 0 * (finite - (-inf)) = NaN.
    eps = 1e-6
    marginal = jnp.mean(preds, axis=0)
    kl_divs = jnp.sum(preds * (jnp.log(preds + eps) - jnp.log(marginal + eps)), axis=1)
    return jnp.exp(jnp.mean(kl_divs))

In [None]:
# Inception score of the real training data, evaluated on every training sample.
IS = inception_score(model, train_data, train_labels, 32, len(train_data))

In [None]:
# Report the score together with the shape of the data it was computed on.
print(IS)
print(train_data.shape)

In [None]:
# load samples:
def shuffle(arr):
    """Return ``arr`` with its first axis randomly permuted.

    The permutation is drawn from NumPy's global random state; the input
    array itself is left untouched.
    """
    num_rows = arr.shape[0]
    order = np.random.choice(np.arange(num_rows), size=num_rows, replace=False)
    return arr[order]


# Collect every .npz file found inside the '*.eqx' entries of 'Exp1/cpm/'.
paths = []
for p in os.listdir('Exp1/cpm/'):
    if not p.endswith('.eqx'):
        continue
    files = os.listdir(os.path.join('Exp1/cpm/', p))
    for f in files:
        if not f.endswith('.npz'):
            continue
        paths.append(os.path.join('Exp1/cpm/', p, f))


# Index the loaded archives and their two arrays by file path.
datas, samples, energies = {}, {}, {}
for path in paths:
    data = np.load(path)
    datas[path] = data
    samples[path] = data['all_samples']
    energies[path] = data['all_energies']
    print(path, samples[path].shape)
    

In [None]:
# List the file paths for which samples were loaded.
samples.keys()

In [None]:
# Inception score of the generated samples of every model.
print('INCEPTION SCORES - higher is better\n')
for k, v in samples.items():
    sample = v
    # Last stored frame, channel 1 only (kept as a singleton axis).
    samples_type = sample[:, -1, 1:2]
    num_generated = len(samples_type)
    # The score ignores labels, so zeros serve as placeholders.
    placeholder_labels = np.zeros(num_generated)
    IS_samples = inception_score(model, samples_type, placeholder_labels, 16, num_generated)
    model_dir = k.split('/')[-2]
    print(model_dir, '\t\t', float(np.round(IS_samples, 2)))



In [None]:
# For every model: how often each class is predicted on its final frames.
for k, v in samples.items():
    sample = v
    samples_type = sample[:, -1, 1:2]
    class_scores = jax.vmap(model)(samples_type)
    preds = class_scores.argmax(axis=1)
    print(k, np.unique(preds, return_counts=True))

In [None]:
print('plotting real data trajectories')
num_samples = 10
ts_to_plot = [1, 4, 7, 10, 13, 16, 19]
# t=None keeps the full time axis of each loaded simulation.
some_data_incl_time, _ = load_data_batch_for_classifier(
    num_samples, '../data/MNIST_jaxCPM_data', t=None
)
# One row per trajectory, one column per plotted time index.
fig, axs = plt.subplots(num_samples, len(ts_to_plot), figsize=(10, num_samples))

utils.plot_cell_trajectory_data(some_data_incl_time, num_samples, ts_to_plot, axs, colors=colors)
for ax in axs.flat:
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Qualitative plots: for every model, a grid of randomly chosen generated
# trajectories (rows) shown at regularly spaced frames (columns).
stepsize = 20
max_steps = 1000
num_samples_to_plot = 10
# 'real_data' is added to `samples` by a later cell; drop it so that re-running
# this cell afterwards only plots generated samples.
samples.pop('real_data', None)
for k, v in samples.items():

    name = k.split('/')[-2]
    sample = v
    print(name)
    data_to_plot = shuffle(sample)[0:num_samples_to_plot]

    ts_to_plot = [i * stepsize for i in range(min(data_to_plot.shape[1] // stepsize, max_steps))]
    # Size the grid from the trajectories actually selected instead of relying
    # on `num_samples` from another cell or on the number of loaded models.
    num_rows = data_to_plot.shape[0]
    fig, axs = plt.subplots(num_rows, len(ts_to_plot), figsize=(len(ts_to_plot)*2, num_rows), squeeze=False)

    utils.plot_cell_trajectory_data(data_to_plot, num_rows, ts_to_plot, axs)

    for ax in axs.flat:
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    # Release the figure; one is created per model.
    plt.close(fig)
    print('\n')



In [None]:
# A single figure with one qualitative example trajectory for each model.
stepsize = 20
max_steps = 1000
num_samples_to_plot = 10
# 'real_data' is not a model; drop it if a later cell has already added it.
samples.pop('real_data', None)

# Substrings identifying each model's sample file, and the matching row labels.
names_of_models = ['experiment_1_cellsort_4990.eqx', 'experiment_1_extpot_4990.eqx', 'experiment_1_conv_ham_850.eqx',
                   'experiment_1_shallow_nh_6250.eqx', 'experiment_1_nh_5650.eqx', 'experiment_1_nch_agg_downsample_8660.eqx']
names_to_plot = ['Cellsort\nHamiltonian', 'Cellsort\nHamiltonian\n+External\nPotential', 'CNN', '1 NH layer\n+CNN',
                 'Neural\nHamiltonian', 'Neural\nHamiltonian\n+closure']

keys = list(samples.keys())

data_to_plot = []
for name in names_of_models:
    matches = [key for key in keys if name in key]
    if not matches:
        raise KeyError(f"No loaded samples found for model '{name}'")
    k = matches[0]
    print(name, k)
    v = samples[k]
    sample = v
    data_this = shuffle(sample)[0:1]  # one random trajectory per model
    data_to_plot.append(data_this)

data_to_plot = np.concatenate(data_to_plot, axis=0)

ts_to_plot = [0, 24, 48, 72, 96]         # time indices handed to the plotting helper
ts_to_display = [0, 200, 400, 600, 800]  # text printed below each column

mcs_correction_factor = 2/3 * 13.1
#2/3: we saved at each 2/3 *h*w spin flip attempts -- 13.1: correction factor since we only sampled on the boundary of the pixels

# One row per plotted model. Sizing the grid by len(samples) breaks as soon as
# more models are loaded than are listed here: the row-label lookup below then
# raises IndexError.
num_rows = len(names_of_models)
fig, axs = plt.subplots(num_rows, len(ts_to_plot), figsize=(len(ts_to_plot)*1, num_rows), squeeze=False, gridspec_kw={'wspace': 0.05, 'hspace': 0.05})

# NOTE(review): `num_samples` is defined by an earlier cell and is not the
# number of rows plotted here — confirm what the helper expects.
utils.plot_cell_trajectory_data(data_to_plot, num_samples, ts_to_plot, axs, colors=colors)


# Adjust axis labels and formatting
for i, ax_row in enumerate(axs):
    for j, ax in enumerate(ax_row):
        ax.axis("on")  # Explicitly enable axes for adding labels
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:  # Add row labels to the left of the subplots
            ax.set_ylabel(
                names_to_plot[i], fontsize=8, labelpad=20,
                rotation=90, va="center", ha="center"
            )
        if i == len(axs) - 1:  # Add x-axis labels below the bottom row
            ax.set_xlabel(f"{int(ts_to_display[j])}", fontsize=8, labelpad=10)

# NOTE(review): rect is (left, bottom, right, top) in figure fractions; a right
# and top of 0.05 leave no room for the axes — confirm whether
# (0.05, 0.05, 1, 1) was intended. Kept unchanged to preserve the saved figure.
plt.tight_layout(rect=[0.05, 0.05, 0.05, 0.05])

# Add a global x-axis label (time in Monte Carlo steps)
fig.text(0.5, 0.04, 'Time (MCS)', ha='center', va='center', fontsize=9)
plt.savefig(FIGURE_SAVEDIR+'traj_exp1.pdf', dpi=400, transparent=True)
plt.show()

print('\n')


In [None]:
# Rebuttal plots: one figure with two qualitative example trajectories for
# each of the three additional models.
stepsize = 20
max_steps = 1000
num_samples_to_plot = 10
# 'real_data' is not a model; drop it if a later cell has already added it.
samples.pop('real_data', None)

# Every model appears twice in a row, so two random trajectories are drawn from it.
names_of_models = [
    model_file
    for model_file in (
        'experiment_1_closure_gnn_seed42_9950.eqx',
        'experiment_1_nch_no_interactions_9950.eqx',
        'experiment_1_nch_no_pooling_3750.eqx',
    )
    for _ in range(2)
]
names_to_plot = [
    row_label
    for row_label in (
        'GNN\n+ closure',
        'NH + closure\nno interactions',
        'NH + closure\nno pooling',
    )
    for _ in range(2)
]
keys = list(samples.keys())

data_to_plot = []
for name in names_of_models:
    k = [key for key in keys if name in key][0]
    print(name, k)
    v = samples[k]
    sample = v
    data_this = shuffle(sample)[0:1]  # one random trajectory
    data_to_plot.append(data_this)

data_to_plot = np.concatenate(data_to_plot, axis=0)

ts_to_plot = [0, 24, 48, 72, 96]         # time indices handed to the plotting helper
ts_to_display = [0, 200, 400, 600, 800]  # text printed below each column

mcs_correction_factor = 2/3 * 13.1
#2/3: we saved at each 2/3 *h*w spin flip attempts -- 13.1: correction factor since we only sampled on the boundary of the pixels

fig, axs = plt.subplots(
    len(names_of_models),
    len(ts_to_plot),
    figsize=(len(ts_to_plot)*1, len(names_of_models)+0.75),
    squeeze=False,
    gridspec_kw={'wspace': 0.05, 'hspace': 0.05},
)

# NOTE(review): `num_samples` comes from an earlier cell — confirm it is the
# value the helper should receive here.
utils.plot_cell_trajectory_data(data_to_plot, num_samples, ts_to_plot, axs, colors=colors)


# Keep the axes visible (needed for labels) but hide all ticks.
for i, ax_row in enumerate(axs):
    for j, ax in enumerate(ax_row):
        ax.axis("on")
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            # Row label on the left-most panel.
            ax.set_ylabel(
                names_to_plot[i],
                fontsize=8,
                labelpad=20,
                rotation=90,
                va="center",
                ha="center",
            )
        if i == len(axs) - 1:
            # Time label below the bottom row.
            ax.set_xlabel(f"{int(ts_to_display[j])}", fontsize=8, labelpad=10)

# NOTE(review): rect is (left, bottom, right, top) in figure fractions; confirm
# that a right and top of 0.05 are intended.
plt.tight_layout(rect=[0.05, 0.05, 0.05, 0.05])

# Global x-axis label (time in Monte Carlo steps).
fig.text(0.5, 0.04, 'Time (MCS)', ha='center', va='center', fontsize=9)
plt.savefig(FIGURE_SAVEDIR+'traj_exp1_rebuttal.jpg', dpi=400, transparent=True)
plt.show()

print('\n')


In [None]:
# Four additional figures, each with one freshly drawn example trajectory per model.
os.makedirs(os.path.join(FIGURE_SAVEDIR, 'add_qual_exp1'), exist_ok=True)

# Settings that do not change between the figures.
stepsize = 20
max_steps = 1000
num_samples_to_plot = 10

# Substrings identifying each model's sample file, and the matching row labels.
names_of_models = ['experiment_1_cellsort_4990.eqx', 'experiment_1_extpot_4990.eqx', 'experiment_1_conv_ham_850.eqx',
                   'experiment_1_shallow_nh_6250.eqx', 'experiment_1_nh_5650.eqx', 'experiment_1_nch_agg_downsample_8660.eqx']
names_to_plot = ['Cellsort\nHamiltonian', 'Cellsort\nHamiltonian\n+External\nPotential', 'CNN', '1 NH layer\n+CNN',
                 'Neural\nHamiltonian', 'Neural\nHamiltonian\n+closure']

ts_to_plot = [0, 24, 48, 72, 96]         # time indices handed to the plotting helper
ts_to_display = [0, 200, 400, 600, 800]  # text printed below each column

mcs_correction_factor = 2/3 * 13.1
#2/3: we saved at each 2/3 *h*w spin flip attempts -- 13.1: correction factor since we only sampled on the boundary of the pixels

# 'real_data' is not a model; drop it if a later cell has already added it.
samples.pop('real_data', None)
keys = list(samples.keys())

for plot_id in range(4):
    data_to_plot = []
    for name in names_of_models:
        matches = [key for key in keys if name in key]
        if not matches:
            raise KeyError(f"No loaded samples found for model '{name}'")
        k = matches[0]
        print(name, k)
        v = samples[k]
        sample = v
        data_this = shuffle(sample)[0:1]  # one random trajectory per model
        data_to_plot.append(data_this)

    data_to_plot = np.concatenate(data_to_plot, axis=0)

    # One row per plotted model. Sizing the grid by len(samples) breaks as soon
    # as more models are loaded than are listed above: the row-label lookup
    # below then raises IndexError.
    num_rows = len(names_of_models)
    fig, axs = plt.subplots(num_rows, len(ts_to_plot), figsize=(len(ts_to_plot)*1, num_rows), squeeze=False, gridspec_kw={'wspace': 0.05, 'hspace': 0.05})

    # NOTE(review): `num_samples` is defined by an earlier cell and is not the
    # number of rows plotted here — confirm what the helper expects.
    utils.plot_cell_trajectory_data(data_to_plot, num_samples, ts_to_plot, axs, colors=colors)

    # Adjust axis labels and formatting
    for i, ax_row in enumerate(axs):
        for j, ax in enumerate(ax_row):
            ax.axis("on")  # Explicitly enable axes for adding labels
            ax.set_xticks([])
            ax.set_yticks([])
            if j == 0:  # Add row labels to the left of the subplots
                ax.set_ylabel(
                    names_to_plot[i], fontsize=8, labelpad=20,
                    rotation=90, va="center", ha="center"
                )
            if i == len(axs) - 1:  # Add x-axis labels below the bottom row
                ax.set_xlabel(f"{int(ts_to_display[j])}", fontsize=8, labelpad=10)

    # NOTE(review): rect is (left, bottom, right, top) in figure fractions; a
    # right and top of 0.05 leave no room for the axes — confirm whether
    # (0.05, 0.05, 1, 1) was intended. Kept unchanged to preserve the figures.
    plt.tight_layout(rect=[0.05, 0.05, 0.05, 0.05])

    # Add a global x-axis label (time in Monte Carlo steps)
    fig.text(0.5, 0.04, 'Time (MCS)', ha='center', va='center', fontsize=9)
    plt.savefig(os.path.join(FIGURE_SAVEDIR, 'add_qual_exp1', f'traj_exp1_{plot_id}.pdf'), dpi=400, transparent=True)
    plt.show()
    # Release the figure; several are created by this loop.
    plt.close(fig)

    print('\n')


In [None]:
# Range of cell volumes observed in the real data; it serves as the reference
# for the stability analysis in the following cells.

# Number of distinct values in channel 0 across all real samples.
num_cells = len(np.unique(all_data[:,0]))
batched_cell_volumes = jax.vmap(utils.calculate_all_cell_volumes, in_axes=(0, None))
volumes = batched_cell_volumes(all_data, num_cells)[:, 1:] # ignore medium

vol_low = np.min(volumes)
vol_high = np.max(volumes)

In [None]:
# Smallest and largest cell volume found in the real data.
print(vol_low, vol_high)

In [None]:
# Register the real data next to the generated samples. The inserted axis at
# position 1 lets it be indexed as sample[:, t] like the generated trajectories.
# This mutates `samples`; the plotting cells above pop the entry again.
samples['real_data'] = all_data[:, None]

In [None]:
# Per-model statistics over time, keyed like `samples`.
stable_dict = {}
frag_dict = {}
vol_dict = {}

from scipy.ndimage import label
neighborhood_order = 4


def calc_frag_batched(samples, num_cells, neighborhood_order): # this is not transformable under jax unfortunately
    """Apply ``utils.count_num_fragmented`` to every entry along the first axis.

    Returns a float NumPy array holding one result per entry of ``samples``.
    """
    counts = np.zeros(samples.shape[0])
    for position, single_sample in enumerate(samples):
        counts[position] = utils.count_num_fragmented(single_sample, num_cells, neighborhood_order)
    return counts


# Unlike the fragmentation count, the volume computation is vmapped and jitted.
calc_vol_batched = eqx.filter_jit(
    jax.vmap(utils.calculate_all_cell_volumes, in_axes=(0, None))
)

for k, v in samples.items():
    print(k)
    sample = v
    stable_times = []
    frag_times = []
    vol_times = []
    # Evaluate every fifth stored time step.
    for t in tqdm(range(0, sample.shape[1], 5)):
        frame_batch = sample[:, t]
        frag = calc_frag_batched(np.array(frame_batch), num_cells, neighborhood_order)
        vol = calc_vol_batched(frame_batch, num_cells)[:, 1:]  # drop the first column
        # Insert a time axis at position 1 so the results concatenate over time.
        frag_times.append(frag[:, None, ...])
        vol_times.append(vol[:, None, ...])
    frag_dict[k] = np.concatenate(frag_times, axis=1)
    vol_dict[k] = np.concatenate(vol_times, axis=1)

In [None]:
# Stability criteria: a sample counts as stable at a time step when no cell
# volume leaves [min_vol, max_vol] and the fragmentation count stays at or
# below max_num_fragmented.
min_vol = vol_low - 10
max_vol = vol_high + 10
max_num_vol_violate = 0
max_num_fragmented = 3

fig, ax = plt.subplots(1, 3)

for k, v in frag_dict.items():
    num_frag = v
    vol = vol_dict[k]
    # Count the cells whose volume is out of range and compare with the allowance.
    vol_good = jnp.logical_or((vol < min_vol), (vol > max_vol)).sum(-1) <= max_num_vol_violate
    frag_good = num_frag <= max_num_fragmented
    stable = jnp.logical_and(frag_good, vol_good)
    # Fraction of samples satisfying each criterion at every evaluated time step.
    ax[0].plot(stable.mean(0), label=k)
    ax[1].plot(vol_good.mean(0))
    ax[2].plot(frag_good.mean(0))
    name = k.split('/')[-2] if '/' in k else k
    print(name, 'vol: ', vol_good.mean(), 'frag: ',frag_good.mean(), 'stable: ',stable.mean())

ax[0].set_title('stable')
ax[1].set_title('volume in range')
ax[2].set_title('fragmentation in range')
# Build the legend once, after all curves exist; it is anchored below the first panel.
ax[0].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), shadow=True, ncol=2)
plt.show()