Skip to content

Commit

Permalink
Added matplotlib unit tests for ChordPlot
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Feb 9, 2018
1 parent 7ce4ef0 commit 5fe93c8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
Expand Up @@ -5,30 +5,24 @@
import numpy as np
from holoviews.core.data import Dataset
from holoviews.core.options import Store, Cycle
from holoviews.element import Graph, TriMesh, circular_layout
from holoviews.element import Graph, TriMesh, Chord, circular_layout
from holoviews.element.comparison import ComparisonTestCase
from holoviews.plotting import comms

# Standardize backend due to random inconsistencies
try:
from matplotlib import pyplot
pyplot.switch_backend('agg')
from holoviews.plotting.mpl import OverlayPlot
from matplotlib.collections import LineCollection, PolyCollection
mpl_renderer = Store.renderers['matplotlib']
except:
mpl_renderer = None
pass

from .testplot import TestMPLPlot, mpl_renderer

class MplGraphPlotTests(ComparisonTestCase):

class TestMplGraphPlot(TestMPLPlot):

def setUp(self):
if not mpl_renderer:
raise SkipTest('Matplotlib tests require matplotlib to be available')
self.previous_backend = Store.current_backend
Store.current_backend = 'matplotlib'
self.default_comm = mpl_renderer.comms['default']
mpl_renderer.comms['default'] = (comms.Comm, '')
super(TestMplGraphPlot, self).setUp()

N = 8
self.nodes = circular_layout(np.arange(N, dtype=np.int32))
Expand All @@ -42,11 +36,6 @@ def setUp(self):
self.graph3 = Graph(((self.source, self.target), self.node_info2))
self.graph4 = Graph(((self.source, self.target, self.weights),), vdims='Weight')


def tearDown(self):
mpl_renderer.comms['default'] = self.default_comm
Store.current_backend = self.previous_backend

def test_plot_simple_graph(self):
plot = mpl_renderer.get_plot(self.graph)
nodes = plot.handles['nodes']
Expand Down Expand Up @@ -101,25 +90,16 @@ def test_plot_graph_numerically_colored_edges(self):



class TestMplTriMeshPlots(ComparisonTestCase):
class TestMplTriMeshPlot(TestMPLPlot):

def setUp(self):
if not mpl_renderer:
raise SkipTest('Matplotlib tests require matplotlib to be available')
self.previous_backend = Store.current_backend
Store.current_backend = 'matplotlib'
self.default_comm = mpl_renderer.comms['default']
mpl_renderer.comms['default'] = (comms.Comm, '')
super(TestMplTriMeshPlot, self).setUp()

self.nodes = [(0, 0, 0), (0.5, 1, 1), (1., 0, 2), (1.5, 1, 3)]
self.simplices = [(0, 1, 2, 0), (1, 2, 3, 1)]
self.trimesh = TriMesh((self.simplices, self.nodes))
self.trimesh_weighted = TriMesh((self.simplices, self.nodes), vdims='weight')

def tearDown(self):
mpl_renderer.comms['default'] = self.default_comm
Store.current_backend = self.previous_backend

def test_plot_simple_trimesh(self):
plot = mpl_renderer.get_plot(self.trimesh)
nodes = plot.handles['nodes']
Expand Down Expand Up @@ -164,3 +144,39 @@ def test_plot_trimesh_categorically_colored_edges_filled(self):
[0.215686, 0.494118, 0.721569, 1.]])
self.assertEqual(edges.get_facecolors(), colors)


class TestMplChordPlot(TestMPLPlot):

def setUp(self):
super(TestMplChordPlot, self).setUp()
self.edges = [(0, 1, 1), (0, 2, 2), (1, 2, 3)]
self.nodes = Dataset([(0, 'A'), (1, 'B'), (2, 'C')], 'index', 'Label')
self.chord = Chord((self.edges, self.nodes))

def test_chord_nodes_label_text(self):
g = self.chord.opts(plot=dict(label_index='Label'))
plot = mpl_renderer.get_plot(g)
labels = plot.handles['labels']
self.assertEqual([l.get_text() for l in labels], ['A', 'B', 'C'])

def test_chord_nodes_categorically_colormapped(self):
g = self.chord.opts(plot=dict(color_index='Label'),
style=dict(cmap=['#FFFFFF', '#CCCCCC', '#000000']))
plot = mpl_renderer.get_plot(g)
arcs = plot.handles['arcs']
nodes = plot.handles['nodes']
colors = np.array([[ 1., 1., 1., 1. ],
[ 0.8, 0.8, 0.8, 1. ],
[ 0., 0., 0., 1. ]])
self.assertEqual(arcs.get_colors(), colors)
self.assertEqual(nodes.get_facecolors(), colors)

def test_chord_edges_categorically_colormapped(self):
g = self.chord.opts(plot=dict(edge_color_index='start'),
style=dict(edge_cmap=['#FFFFFF', '#000000']))
plot = mpl_renderer.get_plot(g)
edges = plot.handles['edges']
colors = np.array([[ 1., 1., 1., 1. ],
[ 1., 1., 1., 1. ],
[ 0., 0., 0., 1. ]])
self.assertEqual(edges.get_edgecolors(), colors)
2 changes: 2 additions & 0 deletions tests/plotting/matplotlib/testplot.py
Expand Up @@ -9,6 +9,8 @@
from holoviews.plotting import comms

try:
from matplotlib import pyplot
pyplot.switch_backend('agg')
import holoviews.plotting.mpl
mpl_renderer = Store.renderers['matplotlib']
except:
Expand Down

0 comments on commit 5fe93c8

Please sign in to comment.