Skip to content

Commit

Permalink
Merge pull request #6496 from ales-erjavec/scatter-tests-teardown
Browse files Browse the repository at this point in the history
Scatter plot tests teardown
  • Loading branch information
markotoplak committed Jul 11, 2023
2 parents ff152a9 + 3669fb1 commit 711a863
Showing 1 changed file with 73 additions and 54 deletions.
127 changes: 73 additions & 54 deletions Orange/widgets/visualize/tests/test_owscatterplotbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,49 @@
class MockWidget(OWWidget):
name = "Mock"

get_coordinates_data = Mock(return_value=(None, None))
get_size_data = Mock(return_value=None)
get_shape_data = Mock(return_value=None)
get_color_data = Mock(return_value=None)
get_label_data = Mock(return_value=None)
get_color_labels = Mock(return_value=None)
get_shape_labels = Mock(return_value=None)
get_subset_mask = Mock(return_value=None)
get_tooltip = Mock(return_value="")

is_continuous_color = Mock(return_value=False)
can_draw_density = Mock(return_value=True)
combined_legend = Mock(return_value=False)
selection_changed = Mock(return_value=None)
def __init__(self):
super().__init__()
self.graph = OWScatterPlotBase(self)
self.xy = None, None

def get_coordinates_data(self):
return self.xy

def get_size_data(self):
return None

def get_shape_data(self):
return None

def get_color_data(self):
return None

def get_label_data(self):
return None

def get_color_labels(self):
return None

def get_shape_labels(self):
return None

def get_subset_mask(self):
return None

def get_tooltip(self):
return ""

def is_continuous_color(self):
return False

def can_draw_density(self):
return True

def combined_legend(self):
return False

def selection_changed(self):
return None

GRAPH_CLASS = OWScatterPlotBase
graph = SettingProvider(OWScatterPlotBase)
Expand All @@ -46,27 +75,19 @@ def get_palette(self):
else:
return colorpalettes.DefaultDiscretePalette

@staticmethod
def reset_mocks():
for m in MockWidget.__dict__.values():
if isinstance(m, Mock):
m.reset_mock()
def onDeleteWidget(self):
self.graph.clear()
super().onDeleteWidget()


class TestOWScatterPlotBase(WidgetTest):
def setUp(self):
super().setUp()
self.master = MockWidget()
self.graph = OWScatterPlotBase(self.master)

self.xy = (np.arange(10, dtype=float), np.arange(10, dtype=float))
self.master.get_coordinates_data = lambda: self.xy
self.master = self.create_widget(MockWidget)
self.graph = self.master.graph
self.master.xy = (np.arange(10, dtype=float), np.arange(10, dtype=float))

def tearDown(self):
self.master.onDeleteWidget()
self.master.deleteLater()
# Clear mocks as they keep ref to widget instance when called
MockWidget.reset_mocks()
del self.master
del self.graph
super().tearDown()
Expand All @@ -79,19 +100,19 @@ def setRange(self, rect=None, *_, **__):
[rect.top(), rect.bottom()]]

def test_update_coordinates_no_data(self):
self.xy = None, None
self.master.xy = None, None
self.graph.reset_graph()
self.assertIsNone(self.graph.scatterplot_item)
self.assertIsNone(self.graph.scatterplot_item_sel)

self.xy = [], []
self.master.xy = [], []
self.graph.reset_graph()
self.assertIsNone(self.graph.scatterplot_item)
self.assertIsNone(self.graph.scatterplot_item_sel)

def test_update_coordinates(self):
graph = self.graph
xy = self.xy = (np.array([1, 2]), np.array([3, 4]))
xy = self.master.xy = (np.array([1, 2]), np.array([3, 4]))
graph.reset_graph()

scatterplot_item = graph.scatterplot_item
Expand Down Expand Up @@ -124,7 +145,7 @@ def test_update_coordinates(self):

def test_update_coordinates_and_labels(self):
graph = self.graph
xy = self.xy = (np.array([1., 2]), np.array([3, 4]))
xy = self.master.xy = (np.array([1., 2]), np.array([3, 4]))
self.master.get_label_data = lambda: np.array(["a", "b"])
graph.reset_graph()
self.assertEqual(graph.labels[0].pos().x(), 1)
Expand All @@ -137,7 +158,7 @@ def test_update_coordinates_and_labels(self):

def test_update_coordinates_and_density(self):
graph = self.graph
xy = self.xy = (np.array([1, 2]), np.array([3, 4]))
xy = self.master.xy = (np.array([1, 2]), np.array([3, 4]))
self.master.get_label_data = lambda: np.array(["a", "b"])
graph.reset_graph()
self.assertEqual(graph.labels[0].pos().x(), 1)
Expand All @@ -149,7 +170,7 @@ def test_update_coordinates_and_density(self):
def test_update_coordinates_reset_view(self):
graph = self.graph
graph.view_box.setRange = self.setRange
xy = self.xy = (np.array([2, 1]), np.array([3, 10]))
xy = self.master.xy = (np.array([2, 1]), np.array([3, 10]))
self.master.get_label_data = lambda: np.array(["a", "b"])
graph.reset_graph()
self.assertEqual(self.last_setRange, [[1, 2], [3, 10]])
Expand All @@ -159,15 +180,15 @@ def test_update_coordinates_reset_view(self):
self.assertEqual(self.last_setRange, [[0, 2], [3, 10]])

def test_reset_graph_no_data(self):
self.xy = (None, None)
self.master.xy = (None, None)
self.graph.scatterplot_item = ScatterPlotItem([1, 2], [3, 4])
self.graph.reset_graph()
self.assertIsNone(self.graph.scatterplot_item)
self.assertIsNone(self.graph.scatterplot_item_sel)

def test_update_coordinates_indices(self):
graph = self.graph
self.xy = (np.array([2, 1]), np.array([3, 10]))
self.master.xy = (np.array([2, 1]), np.array([3, 10]))
graph.reset_graph()
np.testing.assert_almost_equal(
graph.scatterplot_item.data["data"], [0, 1])
Expand All @@ -178,8 +199,8 @@ def test_sampling(self):

# Enable sampling before getting the data
graph.set_sample_size(3)
xy = self.xy = (np.arange(10, dtype=float),
np.arange(0, 30, 3, dtype=float))
xy = self.master.xy = (np.arange(10, dtype=float),
np.arange(0, 30, 3, dtype=float))
d = np.arange(10, dtype=float)
master.get_size_data = lambda: d
master.get_shape_data = lambda: d % 5 if d is not None else None
Expand Down Expand Up @@ -286,9 +307,9 @@ def test_sampling(self):
(x[2] - x[1]) / (x[1] - x[0]))

# Reset graph when data is present and sampling is enabled
self.xy = (np.arange(100, 105, dtype=float),
np.arange(100, 105, dtype=float))
d = self.xy[0] - 100
self.master.xy = (np.arange(100, 105, dtype=float),
np.arange(100, 105, dtype=float))
d = self.master.xy[0] - 100
graph.reset_graph()
self.process_events(until=lambda: not (
self.graph.timer is not None and self.graph.timer.isActive()))
Expand All @@ -303,7 +324,7 @@ def test_sampling(self):
(x[2] - x[1]) / (x[1] - x[0]))

# Don't sample when unnecessary
self.xy = (np.arange(100, dtype=float), ) * 2
self.master.xy = (np.arange(100, dtype=float), ) * 2
d = None
delattr(master, "get_label_data")
graph.reset_graph()
Expand All @@ -315,8 +336,8 @@ def test_sampling(self):
def test_sampling_keeps_selection(self):
graph = self.graph

self.xy = (np.arange(100, dtype=float),
np.arange(100, dtype=float))
self.master.xy = (np.arange(100, dtype=float),
np.arange(100, dtype=float))
graph.reset_graph()
graph.select_by_indices(np.arange(1, 100, 2))
graph.set_sample_size(30)
Expand All @@ -326,14 +347,13 @@ def test_sampling_keeps_selection(self):

base = "Orange.widgets.visualize.owscatterplotgraph.OWScatterPlotBase."

@staticmethod
@patch(base + "update_sizes")
@patch(base + "update_colors")
@patch(base + "update_selection_colors")
@patch(base + "update_shapes")
@patch(base + "update_labels")
def test_reset_calls_all_updates_and_update_doesnt(*mocks):
master = MockWidget()
def test_reset_calls_all_updates_and_update_doesnt(self, *mocks):
master = self.create_widget(MockWidget)
graph = OWScatterPlotBase(master)
for mock in mocks:
mock.assert_not_called()
Expand Down Expand Up @@ -586,8 +606,8 @@ def test_colors_continuous_reused(self):
self.master.is_continuous_color = lambda: True
graph = self.graph

self.xy = (np.arange(100, dtype=float),
np.arange(100, dtype=float))
self.master.xy = (np.arange(100, dtype=float),
np.arange(100, dtype=float))

d = np.arange(100, dtype=float)
self.master.get_color_data = lambda: d
Expand Down Expand Up @@ -1211,15 +1231,14 @@ def test_label_mask_with_invisible_and_view(self):

def test_labels_observes_mask(self):
graph = self.graph
get_label_data = graph.master.get_label_data
graph.reset_graph()

self.assertEqual(graph.labels, [])

get_label_data.reset_mock()
graph._label_mask = lambda *_: None
graph.update_labels()
get_label_data.assert_not_called()
with patch.object(graph.master, "get_label_data") as m:
graph._label_mask = lambda *_: None
graph.update_labels()
m.assert_not_called()

self.master.get_label_data = lambda: \
np.array([str(x) for x in range(10)], dtype=object)
Expand Down

0 comments on commit 711a863

Please sign in to comment.