Skip to content

Commit

Permalink
simplify test_nngraph (PR #21)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeff authored and nperraud committed Dec 17, 2020
1 parent b01fd31 commit 7b35e88
Showing 1 changed file with 21 additions and 36 deletions.
57 changes: 21 additions & 36 deletions pygsp/tests/test_graphs.py
Expand Up @@ -491,6 +491,7 @@ def test_subgraph(self, n_vertices=100):
self.assertEqual(graph.plotting, self._G.plotting)

def test_nngraph(self, n_vertices=30):
"""Test all the combinations of metric, kind, backend."""
features = np.random.RandomState(42).normal(size=(n_vertices, 3))
metrics = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
backends = ['scipy-kdtree', 'scipy-ckdtree', 'scipy-pdist', 'nmslib',
Expand All @@ -499,46 +500,30 @@ def test_nngraph(self, n_vertices=30):

for backend in backends:
for metric in metrics:
if ((backend == 'flann' and metric == 'max_dist') or
(backend == 'nmslib' and metric == 'minkowski')):
self.assertRaises(ValueError, graphs.NNGraph, features,
kind='knn', backend=backend,
metric=metric)
self.assertRaises(ValueError, graphs.NNGraph, features,
kind='radius', backend=backend,
metric=metric)
else:
if backend == 'nmslib':
self.assertRaises(ValueError, graphs.NNGraph, features,
kind='radius', backend=backend,
metric=metric, order=order)
for kind in ['knn', 'radius']:
params = dict(features=features, metric=metric,
order=order, kind=kind, backend=backend)
# Unsupported combinations.
if backend == 'flann' and metric == 'max_dist':
self.assertRaises(ValueError, graphs.NNGraph, **params)
elif backend == 'nmslib' and metric == 'minkowski':
self.assertRaises(ValueError, graphs.NNGraph, **params)
elif backend == 'nmslib' and kind == 'radius':
self.assertRaises(ValueError, graphs.NNGraph, **params)
else:
graphs.NNGraph(features, kind='radius',
backend=backend,
metric=metric, order=order)
graphs.NNGraph(features, kind='knn',
backend=backend,
metric=metric, order=order)
graphs.NNGraph(features, kind='knn',
backend=backend,
metric=metric, order=order,
center=False)
graphs.NNGraph(features, kind='knn',
backend=backend,
metric=metric, order=order,
rescale=False)
graphs.NNGraph(features, kind='knn',
backend=backend,
metric=metric, order=order,
rescale=False, center=False)
graphs.NNGraph(**params, center=False)
graphs.NNGraph(**params, rescale=False)
graphs.NNGraph(**params, center=False, rescale=False)

# Invalid parameters.
self.assertRaises(ValueError, graphs.NNGraph, features,
metric='invalid')
self.assertRaises(ValueError, graphs.NNGraph, features,
kind='invalid', backend=backend,
metric=metric)
kind='invalid')
self.assertRaises(ValueError, graphs.NNGraph, features,
kind='knn', backend='invalid',
metric=metric)
backend='invalid')
self.assertRaises(ValueError, graphs.NNGraph, features,
kind='knn', k=n_vertices+1)
kind='knn', k=n_vertices+1)

def test_nngraph_consistency(self):
features = np.arange(90).reshape(30, 3)
Expand Down

0 comments on commit 7b35e88

Please sign in to comment.