Skip to content

Commit

Permalink
check nn graphs building against pdist reference
Browse files Browse the repository at this point in the history
  • Loading branch information
naspert committed Mar 20, 2018
1 parent b298600 commit 77d7a0a
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions pygsp/tests/test_graphs.py
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import scipy.linalg
import scipy.sparse.linalg
from skimage import data, img_as_float

from pygsp import graphs
Expand Down Expand Up @@ -199,20 +200,45 @@ def test_nngraph(self):
graphs.NNGraph(Xin, NNtype='knn',
backend=cur_backend,
dist_type=dist_type, order=order)
self.assertRaises(ValueError, graphs.NNGraph, Xin,
NNtype='badtype', backend=cur_backend,
dist_type=dist_type)
self.assertRaises(ValueError, graphs.NNGraph, Xin,
NNtype='knn', backend='badtype',
dist_type=dist_type)

def test_nngraph_consistency(self):
#Xin = np.arange(180).reshape(60, 3)
Xin = np.random.uniform(-5, 5, (60, 3))
dist_types = ['euclidean', 'manhattan', 'max_dist', 'minkowski']
backends = ['scipy-kdtree', 'flann']
num_neighbors=5
num_neighbors=4
epsilon=0.1

# use pdist as ground truth
G = graphs.NNGraph(Xin, NNtype='knn',
backend='scipy-pdist', k=num_neighbors)
for cur_backend in backends:
for cur_backend in backends:
for dist_type in dist_types:
if cur_backend == 'flann' and dist_type == 'max_dist':
continue
#print("backend={} dist={}".format(cur_backend, dist_type))
Gt = graphs.NNGraph(Xin, NNtype='knn',
backend=cur_backend, k=num_neighbors)
d = scipy.sparse.linalg.norm(G.W - Gt.W)
self.assertTrue(d < 0.01, 'Graphs (knn) are not identical error='.format(d))

G = graphs.NNGraph(Xin, NNtype='radius',
backend='scipy-pdist', epsilon=epsilon)
for cur_backend in backends:
for dist_type in dist_types:
if cur_backend == 'flann' and dist_type == 'max_dist':
continue
#print("backend={} dist={}".format(cur_backend, dist_type))
Gt = graphs.NNGraph(Xin, NNtype='radius',
backend=cur_backend, epsilon=epsilon)
d = scipy.sparse.linalg.norm(G.W - Gt.W, ord=1)
self.assertTrue(d < 0.01,
'Graphs (radius) are not identical error='.format(d))

def test_bunny(self):
graphs.Bunny()
Expand Down

0 comments on commit 77d7a0a

Please sign in to comment.