In [1]:
import inspect
import itertools

import numpy as np

import skimage as ski
import skimage.segmentation as ss

In [2]:
cell_4d = ski.data.cells3d()
imgs = (cell_4d[:, 0],  # Membranes
        cell_4d[:, 1],  # Nuclei
        ski.data.brain())
img = imgs[0]

In [3]:
img.shape

(60, 256, 256)

In [11]:
# Add random noise to remove ties.
rng = np.random.default_rng()
fimg = ski.util.img_as_float(img) + rng.normal(0, 0.01, size=img.shape)
uvals = np.unique(fimg.ravel())
assert len(uvals) == fimg.size

In [12]:
ws_orig = ss.watershed(fimg)

In [13]:
def rolled_ws(img, axes):
    r_img = np.transpose(img, axes)
    f_r_img = ss.watershed(r_img)
    return np.transpose(f_r_img, np.argsort(axes))

In [14]:
def assert_labels_equivalent(label_1, label_2):
    uq_labels_1 = np.unique(label_1)
    uq_labels_2 = np.unique(label_2)
    assert np.all(uq_labels_1 == uq_labels_2)
    unclaimed = list(uq_labels_2)
    for label in uq_labels_1:
        mask = label_1 == label
        in_mask = label_2[mask]
        label_other = in_mask[0]
        assert np.all(in_mask == label_other)
        unclaimed.remove(label_other)
    assert len(unclaimed) == 0

In [15]:
ws_rolled = rolled_ws(img, (2, 1, 0))
assert assert_labels_equivalent(ws_rolled, ws_orig)

ValueError: operands could not be broadcast together with shapes (485374,) (508036,) 

In [None]:
orderings = set(itertools.permutations(range(3), 3))
orderings.remove((0, 1, 2))

In [None]:
for i, img in enumerate(imgs):
    ws_start = ss.watershed(img)
    print(f'Image {i}')
    for order in orderings:
        print(f'Ordering {order}')
        ws_s_r = rolled_ws(img, order)
        assert assert_labels_equavalent(ws_s_r, ws_start)