In [15]:
from itertools import product 
from collections import deque

import numpy as np
import jax.numpy as jnp
import scipy.sparse as sparse
from scipy.sparse.csgraph import shortest_path, connected_components

## Finding islands in map

### Ground truth computation using BFS, not good for parallization...

In [33]:
directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]

def n_islands(int_map: np.ndarray):
    h, w = int_map.shape
    visited = np.zeros_like(int_map, dtype=bool)
    
    def in_bounds(pos, bounds):
        return 0 <= pos[0] < h and 0 <= pos[1] < w
    
    def visit(i, j):
        to_visit = deque()

        visited[i, j] = True
        to_visit.append((i, j))

        while len(to_visit) > 0:
            i, j = to_visit.popleft()
            for di, dj in directions:
                nb = i + di, j + dj
                if in_bounds(nb, int_map.shape) and int_map[nb] and not visited[nb]:
                    to_visit.append(nb)
                    visited[nb] = True
    
    n_components = 0
    for i, j in product(range(h), range(w)):
        if int_map[i, j] and not visited[i, j]:
            visit(i, j)
            n_components += 1
            
    return n_components

In [35]:
# int_map = jnp.asarray([
#     [ 0, 0, 1, 1, 0 ],
#     [ 1, 1, 0, 1, 0 ],
#     [ 1, 1, 0, 0, 1 ],
#     [ 0, 0, 0, 1, 1 ],
#     [ 1, 1, 1, 0, 0 ],§
# ])
int_map = jnp.asarray([
    [ 0, 0, 1 ],
    [ 1, 1, 0 ],
    [ 1, 1, 0 ]
])

n_islands(int_map)

2

### Using jax a diffusing the labels

In [68]:
from functools import partial
import jax
import jax.numpy as jnp
import jax.lax as lax

from equinox.nn import Pool

In [76]:
# _conv2d = partial(jax.scipy.signal.convolve2d, mode='same', boundary='fill', fillvalue=0)
manhattan_kernel = jnp.array([
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 1.0],
            [0.0, 1.0, 0.0]
        ])


def manhattan_max(x, y):
    print(x.shape)
    print(y.shape)
    jax.debug.print("{x}, {y}", x=x, y=y)
    

pool = Pool(0, manhattan_max, 2, (3, 3), (1, 1), (1, 1))

def n_components_diffusion(int_map):
    nodes = jnp.arange(1, jnp.prod(jnp.asarray(int_map.shape)) + 1, dtype=jnp.float32).reshape(int_map.shape)
    nodes = (nodes * int_map)[None]

    
    def step(diff, _):
        return pool(diff), None
    
    diffused, _ = lax.scan(step, nodes, jnp.arange(jnp.sum(jnp.asarray(int_map.shape))))
    
    print(diffused)
    
n_components_diffusion(int_map)

()
()


ValueError: reduce_window output must have the same tree structure as the operands PyTreeDef(*) vs. PyTreeDef(None)