diff --git a/pygsp/_nearest_neighbor.py b/pygsp/_nearest_neighbor.py index 15dd97ce..c3cb1ae5 100644 --- a/pygsp/_nearest_neighbor.py +++ b/pygsp/_nearest_neighbor.py @@ -46,7 +46,7 @@ def _scipy_kdtree(features, _, order, kind, k, radius, params): return neighbors, distances -def _scipy_ckdtree(features, _, order, kind, k, radius, params): +def _scipy_ckdtree(features, metric, order, kind, k, radius, params): if order is None: raise ValueError('invalid metric for scipy-kdtree') eps = params.pop('eps', 0) @@ -54,20 +54,23 @@ def _scipy_ckdtree(features, _, order, kind, k, radius, params): params = dict(p=order, eps=eps, n_jobs=-1) if kind == 'knn': params['k'] = k + 1 - elif kind == 'radius': - params['k'] = features.shape[0] # number of vertices - params['distance_upper_bound'] = radius - distances, neighbors = tree.query(features, **params) - if kind == 'knn': + distances, neighbors = tree.query(features, **params) return neighbors, distances elif kind == 'radius': + neighbors = tree.query_ball_point(features, + radius * np.ones((features.shape[0],)), + p=order) dist = [] - neigh = [] - for distance, neighbor in zip(distances, neighbors): - mask = (distance != np.inf) - dist.append(distance[mask]) - neigh.append(neighbor[mask]) - return neigh, dist + metric = 'cityblock' if metric == 'manhattan' else metric + metric = 'chebyshev' if metric == 'max_dist' else metric + params = dict(metric=metric) + if metric == 'minkowski': + params['p'] = order + for i, neighbor in enumerate(neighbors): + dist.append(spatial.distance.cdist([features[i]], + features[neighbor], + **params).flatten()) + return neighbors, dist def _flann(features, metric, order, kind, k, radius, params):