Skip to content

Commit

Permalink
better test of save and load
Browse files Browse the repository at this point in the history
  • Loading branch information
mdeff committed Mar 14, 2019
1 parent e46656a commit 7525506
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions pygsp/tests/test_graphs.py
Expand Up @@ -733,35 +733,34 @@ def test_networkx_signal_import(self):
@unittest.skipIf(sys.version_info < (3, 6), 'need ordered dicts')
def test_save_load(self):

graph = graphs.Sensor(seed=42)
signal = np.random.RandomState(42).uniform(size=graph.N)
graph.set_signal(signal, "signal")

# save
nx_gt = ['gml', 'graphml']
all_files = []
for fmt in nx_gt:
all_files += ["graph_gt.{}".format(fmt), "graph_nx.{}".format(fmt)]
graph.save("graph_gt.{}".format(fmt), backend='graph-tool')
graph.save("graph_nx.{}".format(fmt), backend='networkx')
graph.save("graph_nx.{}".format('gexf'), backend='networkx')
all_files += ["graph_nx.{}".format('gexf')]

# load
for filename in all_files:
if not "_gt" in filename:
graph_loaded_nx = graphs.Graph.load(filename, backend='networkx')
np.testing.assert_array_equal(graph.W.todense(), graph_loaded_nx.W.todense())
np.testing.assert_array_equal(signal, graph_loaded_nx.signals['signal'])
if not ".gexf" in filename:
graph_loaded_gt = graphs.Graph.load(filename, backend='graph-tool')
np.testing.assert_allclose(graph.W.todense(), graph_loaded_gt.W.todense(), atol=0.000001)
np.testing.assert_allclose(signal, graph_loaded_gt.signals['signal'], atol=0.000001)

# clean
for filename in all_files:
os.remove(filename)

G1 = graphs.Sensor(seed=42)
W = G1.W.toarray()
sig = np.random.RandomState(42).normal(size=G1.N)
G1.set_signal(sig, 's')

for fmt in ['graphml', 'gml', 'gexf']:
for backend in ['networkx', 'graph-tool']:

if fmt == 'gexf' and backend == 'graph-tool':
self.assertRaises(ValueError, G1.save, 'g', fmt, backend)
self.assertRaises(ValueError, graphs.Graph.load, 'g', fmt,
backend)
continue

atol = 1e-5 if fmt == 'gml' and backend == 'graph-tool' else 0

for filename, fmt in [('graph.' + fmt, None), ('graph', fmt)]:
G1.save(filename, fmt, backend)
G2 = graphs.Graph.load(filename, fmt, backend)
np.testing.assert_allclose(G2.W.toarray(), W, atol=atol)
np.testing.assert_allclose(G2.signals['s'], sig, atol=atol)
os.remove(filename)

self.assertRaises(ValueError, graphs.Graph.load, 'g', fmt='unk')
self.assertRaises(ValueError, graphs.Graph.load, 'g', backend='unk')
self.assertRaises(ValueError, G1.save, 'g', fmt='unk')
self.assertRaises(ValueError, G1.save, 'g', backend='unk')
os.remove('g')

@unittest.skipIf(sys.version_info < (3, 3), 'need unittest.mock')
def test_import_errors(self):
Expand Down

0 comments on commit 7525506

Please sign in to comment.