In [1]:
def select_coexpressing_cells(image, labels, threshold):
    """
    Selects and returns the subset of cell segments from labels (HxW)
    that co expresses the total intensity of all channels above a
    certain threshold given an input image (HxWxC)
    """
    import numpy as np
    import pandas as pd
    from skimage.measure import regionprops_table

    df = pd.DataFrame(
        regionprops_table(labels, image, properties=("label", "area", "intensity_mean"))
    )
    mean_cols = [c for c in df.columns if c.startswith("intensity_mean")]
    sum_cols = [f"intensity_sum-{i}" for i in range(len(mean_cols))]
    
    df[sum_cols] = df[mean_cols] * df["area"].to_numpy()
    
    selected_ids = df.loc[(df[sum_cols] > threshold).all(axis=1), "label"].to_numpy()
    
    below_threshold_mask = np.isin(labels, selected_ids, invert=False)
    filtered_labels = labels * below_threshold_mask

    return filtered_labels

In [2]:
def check(candidate):
    """
    Sanity-check a ``select_coexpressing_cells`` implementation on a tiny
    3x3 scene with three 2-pixel cells and three channels.

    Per-channel intensity sums are cell 1 -> (10, 4, 10),
    cell 2 -> (10, 10, 10), cell 3 -> (5, 8, 8). With a threshold of 7 only
    cell 2 clears every channel, so it is the only cell left in the output.
    Raises AssertionError if ``candidate`` returns anything else.
    """
    import numpy as np

    cell_labels = np.array([
        [0, 1, 1],
        [2, 2, 3],
        [0, 0, 3],
    ])

    # Fill in each labelled pixel; unlabelled (0) pixels stay at zero intensity.
    intensity = np.zeros((3, 3, 3), dtype=int)
    intensity[0, 1] = [5, 2, 6]    # cell 1
    intensity[0, 2] = [5, 2, 4]    # cell 1
    intensity[1, 0] = [10, 5, 3]   # cell 2
    intensity[1, 1] = [0, 5, 7]    # cell 2
    intensity[1, 2] = [1, 4, 4]    # cell 3
    intensity[2, 2] = [4, 4, 4]    # cell 3

    only_cell_2 = np.array([
        [0, 0, 0],
        [2, 2, 0],
        [0, 0, 0],
    ])

    np.testing.assert_equal(only_cell_2, candidate(intensity, cell_labels, 7))

In [3]:
# Run the sanity check; it raises AssertionError if the implementation is wrong.
check(candidate=select_coexpressing_cells)