Skip to content

Commit

Permalink
Merge pull request #4997 from borondics/OWNeighbors_output_all_distances
Browse files Browse the repository at this point in the history
[ENH] Neighbors: improve exclusion of references, checkbox to (un)limit output data
  • Loading branch information
markotoplak committed Oct 6, 2020
2 parents d588d69 + ee6591e commit 7d04ab1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 90 deletions.
51 changes: 28 additions & 23 deletions Orange/widgets/data/owneighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class Inputs:
class Outputs:
data = Output("Neighbors", Table)

class Info(OWWidget.Warning):
removed_references = \
Msg("Input data includes reference instance(s).\n"
"Reference instances are excluded from the output.")

class Warning(OWWidget.Warning):
all_data_as_reference = \
Msg("Every data instance is same as some reference")
Expand All @@ -52,8 +57,8 @@ class Error(OWWidget.Error):
distance_index: int

n_neighbors = Setting(10)
limit_neighbors = Setting(True)
distance_index = Setting(0)
exclude_reference = Setting(True)
auto_apply = Setting(True)

want_main_area = False
Expand All @@ -70,17 +75,13 @@ def __init__(self):
box = gui.vBox(self.controlArea, box=True)
gui.comboBox(
box, self, "distance_index", orientation=Qt.Horizontal,
label="Distance: ", items=[d[0] for d in METRICS],
label="Distance metric: ", items=[d[0] for d in METRICS],
callback=self.recompute)
gui.spin(
box, self, "n_neighbors", label="Number of neighbors:",
step=1, spinType=int, minv=0, maxv=100,
# call apply by gui.auto_commit, pylint: disable=unnecessary-lambda
callback=lambda: self.apply())
gui.checkBox(
box, self, "exclude_reference",
label="Exclude rows (equal to) references",
box, self, "n_neighbors", label="Limit number of neighbors to:",
step=1, spinType=int, minv=0, maxv=100, checked='limit_neighbors',
# call apply by gui.auto_commit, pylint: disable=unnecessary-lambda
checkCallback=lambda: self.apply(),
callback=lambda: self.apply())

self.apply_button = gui.auto_apply(self.controlArea, self, commit=self.apply)
Expand All @@ -104,6 +105,7 @@ def _set_input_summary(self):

@Inputs.data
def set_data(self, data):
self.controls.n_neighbors.setMaximum(len(data) if data else 100)
self.data = data

@Inputs.reference
Expand Down Expand Up @@ -157,21 +159,24 @@ def apply(self):

def _compute_indices(self):
self.Warning.all_data_as_reference.clear()
dist = self.distances
if dist is None:
self.Info.removed_references.clear()

if self.distances is None:
return None
if self.exclude_reference:
non_ref = dist > 1e-5
skip = len(dist) - non_ref.sum()
up_to = min(self.n_neighbors + skip, len(dist))
if skip >= up_to:
self.Warning.all_data_as_reference()
return None
indices = np.argpartition(dist, up_to - 1)[:up_to]
return indices[non_ref[indices]]
else:
up_to = min(self.n_neighbors, len(dist))
return np.argpartition(dist, up_to - 1)[:up_to]

inrefs = np.isin(self.data.ids, self.reference.ids)
if np.all(inrefs):
self.Warning.all_data_as_reference()
return None
if np.any(inrefs):
self.Info.removed_references()

dist = np.copy(self.distances)
dist[inrefs] = np.max(dist) + 1
up_to = len(dist) - np.sum(inrefs)
if self.limit_neighbors and self.n_neighbors < up_to:
up_to = self.n_neighbors
return np.argpartition(dist, up_to - 1)[:up_to]

def _data_with_similarity(self, indices):
data = self.data
Expand Down
100 changes: 33 additions & 67 deletions Orange/widgets/data/tests/test_owneighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,6 @@ def test_exclude_reference(self):
for inst in reference:
self.assertNotIn(inst, neighbors)

def test_include_reference(self):
"""Check neighbors when reference is included"""
widget = self.widget
widget.controls.exclude_reference.setChecked(False)
reference = self.iris[:5]
self.send_signal(widget.Inputs.data, self.iris)
self.send_signal(widget.Inputs.reference, reference)
widget.apply_button.button.click()
neighbors = self.get_output("Neighbors")
for inst in reference:
self.assertIn(inst, neighbors)

def test_similarity(self):
widget = self.widget
reference = self.iris[:10]
Expand Down Expand Up @@ -192,7 +180,7 @@ def test_compute_distances_apply_called(self):
def test_compute_distances_calls_distance(self):
widget = self.widget
widget.distance_index = 2
dists = np.random.random((5, 10))
dists = np.random.random((10, 5))
distance = Mock(return_value=dists)
try:
orig_metrics = METRICS[widget.distance_index]
Expand Down Expand Up @@ -242,75 +230,44 @@ def test_compute_distances_distance_no_data(self):
finally:
METRICS[widget.distance_index] = orig_metrics

def test_compute_indices_with_reference(self):
def test_compute_indices_without_reference(self):
widget = self.widget
# Indices for easier reading: 0 1 2 3 4 5 6 7 8 9 10 11 12
widget.distances = np.array([4., 1, 7, 0, 5, 2, 4, 0, 2, 2, 2, 9, 8])

widget.exclude_reference = False
widget.n_neighbors = 3
self.assertEqual(sorted(widget._compute_indices()), [1, 3, 7])

widget.n_neighbors = 1
self.assertIn(list(widget._compute_indices()), ([3], [7]))

widget.n_neighbors = 5
ind = set(widget._compute_indices())
self.assertEqual(len(ind), 5)
self.assertTrue({1, 3, 7} < ind)
self.assertTrue(len({5, 8, 9, 10} & ind) == 2)

widget.n_neighbors = 100
self.assertEqual(sorted(widget._compute_indices()), list(range(13)))
widget.limit_neighbours = True

widget.n_neighbors = 13
self.assertEqual(sorted(widget._compute_indices()), list(range(13)))
# Indices for easier reading: 0 1 2 3 4 5 6 7 8 9 10 11 12
widget.distances = np.array([4, 1, 7, 0, 5, 2, 4, 0, 2, 2, 2, 9, 8])

widget.n_neighbors = 14
self.assertEqual(sorted(widget._compute_indices()), list(range(13)))
widget.data = Mock()
widget.data.ids = np.arange(13)
widget.reference = Mock()
widget.reference.ids = np.array([1, 3])

widget.n_neighbors = 12
self.assertEqual(
sorted(widget._compute_indices()),
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12])

def test_compute_indices_without_reference(self):
widget = self.widget
# Indices for easier reading: 0 1 2 3 4 5 6 7 8 9 10 11 12
widget.distances = np.array([4., 1, 7, 0, 5, 2, 4, 0, 2, 2, 2, 9, 8])

widget.exclude_reference = True
widget.n_neighbors = 5
self.assertEqual(sorted(widget._compute_indices()), [1, 5, 8, 9, 10])
self.assertEqual(sorted(widget._compute_indices()), [5, 7, 8, 9, 10])

widget.n_neighbors = 1
self.assertEqual(list(widget._compute_indices()), [1])
self.assertEqual(list(widget._compute_indices()), [7])

widget.n_neighbors = 3
ind = set(widget._compute_indices())
self.assertEqual(len(ind), 3)
self.assertIn(1, ind)
self.assertIn(7, ind)
self.assertTrue(len({5, 8, 9, 10} & ind) == 2)

widget.n_neighbors = 100
self.assertEqual(
sorted(widget._compute_indices()),
[0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12])
[0, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12])

widget.n_neighbors = 11
widget.n_neighbors = 10
self.assertEqual(
sorted(widget._compute_indices()),
[0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12])
[0, 2, 4, 5, 6, 7, 8, 9, 10, 12])

widget.n_neighbors = 12
widget.limit_neighbours = False
self.assertEqual(
sorted(widget._compute_indices()),
[0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12])

widget.n_neighbors = 10
self.assertEqual(
sorted(widget._compute_indices()),
[0, 1, 2, 4, 5, 6, 8, 9, 10, 12])
[0, 2, 4, 5, 6, 7, 8, 9, 10, 12])

def test_data_with_similarity(self):
widget = self.widget
Expand Down Expand Up @@ -359,18 +316,16 @@ def test_all_equal_ref(self):
widget.auto_apply = True

data = Table("iris")
self.send_signal(widget.Inputs.data, data)
self.send_signal(widget.Inputs.reference, data[42:43])

orig_distances = widget.distances
widget.distances = np.zeros(len(data), dtype=float)
widget.apply()
self.send_signal(widget.Inputs.data, data[:10])
self.send_signal(widget.Inputs.reference, data[:10])
self.assertTrue(widget.Warning.all_data_as_reference.is_shown())
self.assertFalse(widget.Info.removed_references.is_shown())
self.assertIsNone(self.get_output(widget.Outputs.data))

widget.distances = orig_distances
self.send_signal(widget.Inputs.data, data[:15])
widget.apply()
self.assertFalse(widget.Warning.all_data_as_reference.is_shown())
self.assertTrue(widget.Info.removed_references.is_shown())
self.assertIsNotNone(self.get_output(widget.Outputs.data))

def test_different_domains(self):
Expand Down Expand Up @@ -516,6 +471,17 @@ def test_different_domains_same_names(self):
output = self.get_output(w.Outputs.data)
self.assertEqual(10, len(output))

def test_n_neighbours_spin_max(self):
w = self.widget
sb = w.controls.n_neighbors
default = sb.maximum()
self.send_signal(w.Inputs.data, self.iris)
self.assertEqual(sb.maximum(), len(self.iris))
self.send_signal(w.Inputs.data, self.iris[:20])
self.assertEqual(sb.maximum(), 20)
self.send_signal(w.Inputs.data, None)
self.assertEqual(sb.maximum(), default)


if __name__ == "__main__":
unittest.main()

0 comments on commit 7d04ab1

Please sign in to comment.