Skip to content

Commit

Permalink
Merge pull request #22967 from greglucas/tst-quadmesh-cursor
Browse files Browse the repository at this point in the history
TST: Add some tests for QuadMesh contains function
  • Loading branch information
timhoffm committed May 5, 2022
2 parents 2e70254 + 887eb62 commit 078a9cb
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 12 deletions.
12 changes: 3 additions & 9 deletions lib/matplotlib/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,12 +2197,6 @@ def draw(self, renderer):

def get_cursor_data(self, event):
contained, info = self.contains(event)
if len(info["ind"]) == 1:
ind, = info["ind"]
array = self.get_array()
if array is not None:
return array[ind]
else:
return None
else:
return None
if contained and self.get_array() is not None:
return self.get_array().ravel()[info["ind"]]
return None
80 changes: 77 additions & 3 deletions lib/matplotlib/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import matplotlib.path as mpath
import matplotlib.transforms as mtransforms
from matplotlib.collections import (Collection, LineCollection,
EventCollection, PolyCollection)
EventCollection, PolyCollection,
QuadMesh)
from matplotlib.testing.decorators import check_figures_equal, image_comparison
from matplotlib._api.deprecation import MatplotlibDeprecationWarning

Expand Down Expand Up @@ -483,6 +484,81 @@ def test_picking():
assert_array_equal(indices['ind'], [0])


def test_quadmesh_contains():
x = np.arange(4)
X = x[:, None] * x[None, :]

fig, ax = plt.subplots()
mesh = ax.pcolormesh(X)
fig.draw_without_rendering()
xdata, ydata = 0.5, 0.5
x, y = mesh.get_transform().transform((xdata, ydata))
mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
found, indices = mesh.contains(mouse_event)
assert found
assert_array_equal(indices['ind'], [0])

xdata, ydata = 1.5, 1.5
x, y = mesh.get_transform().transform((xdata, ydata))
mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
found, indices = mesh.contains(mouse_event)
assert found
assert_array_equal(indices['ind'], [5])


def test_quadmesh_contains_concave():
# Test a concave polygon, V-like shape
x = [[0, -1], [1, 0]]
y = [[0, 1], [1, -1]]
fig, ax = plt.subplots()
mesh = ax.pcolormesh(x, y, [[0]])
fig.draw_without_rendering()
# xdata, ydata, expected
points = [(-0.5, 0.25, True), # left wing
(0, 0.25, False), # between the two wings
(0.5, 0.25, True), # right wing
(0, -0.25, True), # main body
]
for point in points:
xdata, ydata, expected = point
x, y = mesh.get_transform().transform((xdata, ydata))
mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
found, indices = mesh.contains(mouse_event)
assert found is expected


def test_quadmesh_cursor_data():
x = np.arange(4)
X = x[:, None] * x[None, :]

fig, ax = plt.subplots()
mesh = ax.pcolormesh(X)
# Empty array data
mesh._A = None
fig.draw_without_rendering()
xdata, ydata = 0.5, 0.5
x, y = mesh.get_transform().transform((xdata, ydata))
mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
# Empty collection should return None
assert mesh.get_cursor_data(mouse_event) is None

# Now test adding the array data, to make sure we do get a value
mesh.set_array(np.ones((X.shape)))
assert_array_equal(mesh.get_cursor_data(mouse_event), [1])


def test_quadmesh_cursor_data_multiple_points():
x = [1, 2, 1, 2]
fig, ax = plt.subplots()
mesh = ax.pcolormesh(x, x, np.ones((3, 3)))
fig.draw_without_rendering()
xdata, ydata = 1.5, 1.5
x, y = mesh.get_transform().transform((xdata, ydata))
mouse_event = SimpleNamespace(xdata=xdata, ydata=ydata, x=x, y=y)
# All quads are covering the same square
assert_array_equal(mesh.get_cursor_data(mouse_event), np.ones(9))


def test_linestyle_single_dashes():
plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.]))
plt.draw()
Expand Down Expand Up @@ -749,8 +825,6 @@ def test_quadmesh_deprecated_signature(
fig_test, fig_ref, flat_ref, kwargs):
# test that the new and old quadmesh signature produce the same results
# remove when the old QuadMesh.__init__ signature expires (v3.5+2)
from matplotlib.collections import QuadMesh

x = [0, 1, 2, 3.]
y = [1, 2, 3.]
X, Y = np.meshgrid(x, y)
Expand Down

0 comments on commit 078a9cb

Please sign in to comment.