Skip to content

Commit

Permalink
Use own broadcasted mgrid for massive speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
jni committed Mar 9, 2017
1 parent 7afb957 commit 641d0a0
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion skan/vendored/thresholding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
import numba


def broadcast_mgrid(arrays):
    """Build a dense coordinate grid from 1D arrays using broadcasting.

    For each input array, return a view of it broadcast to the full grid
    shape, varying along its own axis only. Because the results are
    broadcast views rather than freshly allocated arrays, no grid data is
    copied.

    Parameters
    ----------
    arrays : sequence of 1D array-like
        One coordinate vector per grid dimension. The length of the i-th
        vector is the size of the grid along axis i.

    Returns
    -------
    result : list of numpy.ndarray
        One array per input, each of shape ``tuple(len(a) for a in arrays)``.
        These are read-only views (see ``numpy.broadcast_to``); copy them
        before writing to them.
    """
    # Materialise once: the inputs are traversed twice (lengths, then
    # reshaping), and np.asarray lets plain sequences be indexed below.
    arrays = [np.asarray(arr) for arr in arrays]
    shape = tuple(len(arr) for arr in arrays)
    ndim = len(shape)
    result = []
    for i, arr in enumerate(arrays, start=1):
        # The index must be a *tuple*: NumPy no longer interprets a list
        # containing Ellipsis/newaxis as a multidimensional index.
        # Appending (ndim - i) trailing axes places this array's values
        # along axis i - 1 of the grid.
        index = (Ellipsis,) + (np.newaxis,) * (ndim - i)
        result.append(np.broadcast_to(arr[index], shape))
    return result


@numba.jit(nopython=True, cache=True, nogil=True)
def _correlate_nonzeros_offset(input, indices, offsets, values, output):
for off, val in zip(offsets, values):
Expand Down Expand Up @@ -47,7 +58,8 @@ def correlate_nonzeros(padded_array, kernel):
result = np.zeros(np.array(padded_array.shape) - np.array(kernel.shape)
+ 1)
# note: np.mgrid takes up a lot of time. Prioritise finding alternative
corner_multi_indices = np.mgrid[[slice(None, i) for i in result.shape]]
corner_multi_indices = broadcast_mgrid([np.arange(i)
for i in result.shape])
corner_indices = np.ravel_multi_index(corner_multi_indices,
padded_array.shape).ravel()
_correlate_nonzeros_offset(padded_array.ravel(), corner_indices,
Expand Down

0 comments on commit 641d0a0

Please sign in to comment.