In [1]:
import numpy as np
import numba

In [2]:
@numba.jit(nopython=True)
def find_root(x, pntr_tree):
    """ Returns the root node-ID of the connected component containing `x`
        
        Performs path compression. I.e. redirects all pointer values 
        along the recursive path to point directly to the root node, compressing 
        future root finding paths.
        
        Parameters
        ----------
        x : int
            A valid node-ID.
        pntr_tree : Sequence[int, ...]
            A pointer-tree, indicating the connected component membership of nodes in 
            the graph.
        
        Returns
        -------
        int
            The root node-ID of the connected component `x` """
    if pntr_tree[x] < 0:  # x is the root of a connected component
        return x
    pntr_tree[x] = find_root(pntr_tree[x], pntr_tree)  # find the root that x points to, and update tree
    return pntr_tree[x]

In [3]:
@numba.jit(nopython=True)
def init_pntr_tree(nrows, ncols):
     return -1*np.ones(nrows*ncols, dtype=np.int64)

In [4]:
init_pntr_tree(3,2)

array([-1, -1, -1, -1, -1, -1], dtype=int64)

In [5]:
@numba.jit(nopython=True)
def union(x, y, pntr_tree):
    """ Joins the connected components containing `x` and `y`, respectively.
        
        Performs union by rank: the root of the smaller component is pointed
        to the root of the larger one.
        
        Parameters
        ----------
        x : int
            A valid node-ID
        y : int
            A valid node-ID.
        pntr_tree : Sequence[int, ...]
            A pointer-tree, indicating the connected component membership of nodes in 
            the graph."""
    r_x = find_root(x, pntr_tree)
    r_y = find_root(y, pntr_tree)
    
    if r_x != r_y:
        if pntr_tree[r_x] <= pntr_tree[r_y]:  # subgraph containing x is bigger (in magnitude!!)
            pntr_tree[r_x] += pntr_tree[r_y]  # add the smaller subgraph to the larger
            pntr_tree[r_y] = r_x  # point root of cluster y to x
        else:
            pntr_tree[r_y] += pntr_tree[r_x]
            pntr_tree[r_x] = r_y
    return None
    

In [273]:
img = (np.random.rand(200, 200) < .4).astype(np.int64)
row, col = np.where(img > 0)
ptr = -1*np.ones(len(row), dtype=np.int64)

In [274]:
import numpy as np
import numba

print(numba.__version__)

def f():
    x = np.arange(6).reshape(2, 3)
    for item in np.ndindex(3, 2, 1):
        return item

jitted_f = numba.jit(f)

print(f())
print(jitted_f())

0.29.0
(0, 0, 0)
(0, 0, 0)


In [275]:
f()

(0, 0, 0)

In [276]:
@numba.jit(nopython=True)
def f():
    neighbor_links = [(1, (0, 1)), (num_col, (1, 0)), (num_col + 1, (1, 1))]
    for dn, (drow, dcol) in neighbor_links:
        return dn

In [277]:
f()

1

In [278]:
@numba.jit(nopython=True)
def crawl_img(img, thresh):
    num_row, num_col = img.shape
    ptr = init_pntr_tree(num_row, num_col)
    bkgrnd_id = -1
    n = -1
    neighbor_links = [(1, (0, 1)), (num_col, (1, 0)), (num_col + 1, (1, 1))]
    for row in range(num_row):
        for col in range(num_col):
            n += 1
            value = img[row, col]  # ndenumerate isn't working with numba...
            if value < thresh:
                if bkgrnd_id == -1:
                    bkgrnd_id = n
                else:
                    union(bkgrnd_id, n, ptr)
            else:
                for dn, (drow, dcol) in neighbor_links:
                    n_n = n + dn
                    n_row = row + drow
                    n_col = col + dcol
                    if n_row > num_row - 1 or n_col > num_col -1:
                        continue
                    n_value = img[n_row, n_col]
                    if n_value < thresh:
                        continue
                    union(n, n_n, ptr)
    return ptr
            

def crawl_img2(img, thresh):
    num_row, num_col = img.shape
    ptr = init_pntr_tree(num_row, num_col)
    bkgrnd_id = -1
    n = -1
    neighbor_links = [(1, (0, 1)), (num_col, (1, 0)), (num_col + 1, (1, 1))]
    for row in range(num_row):
        for col in range(num_col):
            n += 1
            value = img[row, col]  # ndenumerate isn't working with numba...
            if value < thresh:
                if bkgrnd_id == -1:
                    bkgrnd_id = n
                else:
                    union(bkgrnd_id, n, ptr)
            else:
                for dn, (drow, dcol) in neighbor_links:
                    n_n = n + dn
                    n_row = row + drow
                    n_col = col + dcol
                    if n_row > num_row - 1 or n_col > num_col -1:
                        continue
                    n_value = img[n_row, n_col]
                    if n_value < thresh:
                        continue
                    union(n, n_n, ptr)
    return ptr       
            
            
                
            

In [279]:
img

array([[0, 0, 1, ..., 1, 1, 1],
       [0, 0, 0, ..., 0, 1, 0],
       [1, 0, 0, ..., 0, 1, 1],
       ..., 
       [0, 1, 1, ..., 0, 1, 1],
       [1, 0, 0, ..., 0, 1, 1],
       [0, 1, 0, ..., 1, 1, 0]], dtype=int64)

In [280]:
%%timeit
out = crawl_img(img, .5)

The slowest run took 53.63 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 3: 2.61 ms per loop


In [281]:
%%timeit
out = crawl_img2(img, .5)

10 loops, best of 3: 113 ms per loop


In [253]:
out = crawl_img(img, .5)

In [261]:
np.where(out < 0)

(array([  0,   6,  10,  22,  37,  53,  62,  65,  89, 102, 125, 131, 142,
        158, 170, 172, 174, 180, 198, 202, 211, 213, 215, 226, 240, 243,
        245, 251, 270, 275, 299, 314, 316, 320, 323, 332, 347, 354, 360,
        377, 394], dtype=int64),)

In [259]:
len(np.unique(out[np.where(out < 0)]))

6

In [258]:
np.unique(out[np.where(out > -1)])

array([  0,  10,  37,  53,  65, 158, 180, 198, 215, 270, 316, 323, 347,
       354, 360, 377], dtype=int64)

In [236]:
out[72]

9

In [257]:
for i in range(len(out)):
    find_root(i, out)

In [238]:
np.sum(out[np.where(out < 0)])

-400

In [237]:
print(out.reshape(img.shape))

[[-183 -137    1    0    0    0    0   25    0  -44    9    9    9    9
     9    9    9    9    9    0]
 [   0    1    1    1    0   -7   25   25    0    9    0    0    9    9
     0    0    9    0    0    9]
 [   1    1    1    1    0    0   25   25    0    9    0    9    0    0
     9    0    9    0    9    0]
 [   0    1    0    1    1    1    0   25    0    9    0    9    9    0
     0    0    9    9    9    9]
 [   1    1    0    0    0    1    1    0    0    0    9    9    9    0
     1    1    0    0    0    0]
 [   1    0    0    1    1    0    0    1    1    1    0    0    9    0
     0    1    1    1    0    1]
 [   0    0    1    1    1    1    1    1    1    0    0    0    9    0
     0    1    1    0    0    1]
 [   0    0    0    1    1    1    1    1    1    0    0    9    9    9
     0    1    0    1    1    1]
 [   0    0    0    1    0    1    0    0    0    1    1    0    9    0
     9    0    0    1    0    0]
 [   0    0    0    1    1    0    1    0    1    1    

In [16]:
[(r,c) for r,c in np.ndindex(3, 3)]

[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]