Skip to content

Commit

Permalink
Chunked pairwise_special_metric for faster computation
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Feb 16, 2021
1 parent a20a8cd commit 05840ef
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion umap/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,49 @@ def parallel_special_metric(X, Y=None, metric=hellinger):

return result

# We can gain efficiency by chunking the matrix into blocks;
# this keeps data vectors in cache better
@numba.njit(parallel=True, nogil=True)
def chunked_parallel_special_metric(X, Y=None, metric=hellinger, chunk_size=16):
if Y is None:
size = X.shape[0]
result = np.zeros((size, size), dtype=np.float32)
n_row_chunks = (size // chunk_size) + 1
for chunk_idx in numba.prange(n_row_chunks):
n = chunk_idx * chunk_size
chunk_end_n = min(n + chunk_size, size)
for m in range(n, size, chunk_size):
chunk_end_m = min(m + chunk_size, size)
if n == m:
for i in range(n, chunk_end_n):
for j in range(m, chunk_end_m):
if j > i:
d = metric(X[i], X[j])
result[i, j] = d
result[j, i] = d
else:
for i in range(n, chunk_end_n):
for j in range(m, chunk_end_m):
d = metric(X[i],X[j])
result[i, j] = d
result[j, i] = d
else:
row_size = X.shape[0]
col_size = Y.shape[0]
result = np.zeros((row_size, col_size), dtype=np.float32)
n_row_chunks = (row_size // chunk_size) + 1
for chunk_idx in numba.prange(n_row_chunks):
n = chunk_idx * chunk_size
chunk_end_n = min(n + chunk_size, row_size)
for m in range(0, col_size, chunk_size):
chunk_end_m = min(m + chunk_size, col_size)
for i in range(n, chunk_end_n):
for j in range(m, chunk_end_m):
d = metric(X[i], Y[j])
result[i, j] = d

return result


def pairwise_special_metric(X, Y=None, metric="hellinger", kwds=None):
if callable(metric):
Expand All @@ -1262,4 +1305,4 @@ def _partial_metric(_X, _Y=None):
return pairwise_distances(X, Y, metric=_partial_metric)
else:
special_metric_func = named_distances[metric]
return parallel_special_metric(X, Y, metric=special_metric_func)
return chunked_parallel_special_metric(X, Y, metric=special_metric_func)

0 comments on commit 05840ef

Please sign in to comment.