Merge pull request #6618 from janezd/roc-output
ROC: Output a model with a new operating threshold
VesnaT committed Nov 3, 2023
2 parents 46a9b28 + 410d76a commit 5a5ebcf
Showing 9 changed files with 243 additions and 102 deletions.
25 changes: 7 additions & 18 deletions Orange/widgets/evaluate/owcalibrationplot.py
@@ -16,7 +16,8 @@
from Orange.widgets import widget, gui, settings
from Orange.widgets.evaluate.contexthandlers import \
EvaluationResultsContextHandler
from Orange.widgets.evaluate.utils import results_for_preview
from Orange.widgets.evaluate.utils import results_for_preview, \
check_can_calibrate
from Orange.widgets.utils import colorpalettes
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.utils.customizableplot import \
@@ -486,23 +487,11 @@ def commit(self):
wrapped = None
results = self.results
if results is not None:
problems = [
msg for condition, msg in (
(results.folds is not None and len(results.folds) > 1,
"each training data sample produces a different model"),
(results.models is None,
"test results do not contain stored models - try testing "
"on separate data or on training data"),
(len(self.selected_classifiers) != 1,
"select a single model - the widget can output only one"),
(self.score != 0 and len(results.domain.class_var.values) != 2,
"cannot calibrate non-binary classes"))
if condition]
if len(problems) == 1:
self.Information.no_output(problems[0])
elif problems:
self.Information.no_output(
"".join(f"\n - {problem}" for problem in problems))
problems = check_can_calibrate(
self.results, self.selected_classifiers,
require_binary=self.score != 0)
if problems:
self.Information.no_output(problems)
else:
clsf_idx = self.selected_classifiers[0]
model = results.models[0, clsf_idx]
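For context: this widget and the Lift Curve widget previously duplicated the inline `problems` checks removed above; the commit factors them into a shared `check_can_calibrate` helper in `Orange/widgets/evaluate/utils.py`. A minimal sketch of that helper, reconstructed from the removed inline code and the new call sites — the exact signature, the `require_binary` default, and the message wording are assumptions, not the verified implementation:

```python
# Sketch only: assumes the helper mirrors the inline checks it replaces.
# The real implementation lives in Orange/widgets/evaluate/utils.py and
# may differ in wording and defaults.
def check_can_calibrate(results, selected_classifiers, require_binary=True):
    """Return a problem description, or an empty string if a model can be output."""
    problems = [
        msg for condition, msg in (
            (results.folds is not None and len(results.folds) > 1,
             "each training data sample produces a different model"),
            (results.models is None,
             "test results do not contain stored models - try testing "
             "on separate data or on training data"),
            (len(selected_classifiers) != 1,
             "select a single model - the widget can output only one"),
            (require_binary
             and len(results.domain.class_var.values) != 2,
             "cannot calibrate non-binary models"))
        if condition]
    if len(problems) == 1:
        return problems[0]
    # Empty string (falsy) when there are no problems; bulleted list otherwise.
    return "".join(f"\n - {problem}" for problem in problems)
```

Callers keep the same pattern as before: `if problems: self.Information.no_output(problems)`.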
26 changes: 7 additions & 19 deletions Orange/widgets/evaluate/owliftcurve.py
@@ -20,7 +20,8 @@
from Orange.widgets import widget, gui, settings
from Orange.widgets.evaluate.contexthandlers import \
EvaluationResultsContextHandler
from Orange.widgets.evaluate.utils import check_results_adequacy
from Orange.widgets.evaluate.utils import check_results_adequacy, \
check_can_calibrate
from Orange.widgets.utils import colorpalettes
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.utils.customizableplot import Updater, \
@@ -267,7 +268,7 @@ def _initialize(self, results):
item = self.classifiers_list_box.item(i)
item.setIcon(colorpalettes.ColorIcon(color))

class_values = results.data.domain.class_var.values
class_values = results.domain.class_var.values
self.target_cb.addItems(class_values)
if class_values:
self.target_index = 0
@@ -493,23 +494,10 @@ def commit(self):
wrapped = None
results = self.results
if results is not None:
problems = [
msg for condition, msg in (
(results.folds is not None and len(results.folds) > 1,
"each training data sample produces a different model"),
(results.models is None,
"test results do not contain stored models - try testing "
"on separate data or on training data"),
(len(self.selected_classifiers) != 1,
"select a single model - the widget can output only one"),
(len(results.domain.class_var.values) != 2,
"cannot calibrate non-binary classes"))
if condition]
if len(problems) == 1:
self.Information.no_output(problems[0])
elif problems:
self.Information.no_output(
"".join(f"\n - {problem}" for problem in problems))
problems = check_can_calibrate(
self.results, self.selected_classifiers)
if problems:
self.Information.no_output(problems)
else:
clsf_idx = self.selected_classifiers[0]
model = results.models[0, clsf_idx]
48 changes: 39 additions & 9 deletions Orange/widgets/evaluate/owrocanalysis.py
@@ -13,18 +13,21 @@
import pyqtgraph as pg

import Orange
from Orange.base import Model
from Orange.classification import ThresholdClassifier
from Orange.evaluation.testing import Results
from Orange.widgets import widget, gui, settings
from Orange.widgets.evaluate.contexthandlers import \
EvaluationResultsContextHandler
from Orange.widgets.evaluate.utils import check_results_adequacy
from Orange.widgets.evaluate.utils import check_results_adequacy, \
check_can_calibrate
from Orange.widgets.utils import colorpalettes
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.utils.plotutils import GraphicsView, PlotItem
from Orange.widgets.widget import Input
from Orange.widgets.widget import Input, Output, Msg
from Orange.widgets import report

from Orange.widgets.evaluate.utils import results_for_preview
from Orange.evaluation.testing import Results


#: Points on a ROC curve
@@ -305,6 +308,12 @@ class OWROCAnalysis(widget.OWWidget):
class Inputs:
evaluation_results = Input("Evaluation Results", Orange.evaluation.Results)

class Outputs:
calibrated_model = Output("Calibrated Model", Model)

class Information(widget.OWWidget.Information):
no_output = Msg("Can't output a model: {}")

buttons_area_orientation = None
settingsHandler = EvaluationResultsContextHandler()
target_index = settings.ContextSetting(0)
@@ -466,7 +475,7 @@ def _initialize(self, results):
listitem = self.classifiers_list_box.item(i)
listitem.setIcon(colorpalettes.ColorIcon(self.colors[i]))

class_var = results.data.domain.class_var
class_var = results.domain.class_var
self.target_cb.addItems(class_var.values)
self.target_index = 0
self._set_target_prior()
@@ -620,8 +629,7 @@ def no_averaging():
pen.setCosmetic(True)
self.plot.plot([0, 1], [0, 1], pen=pen, antialias=True)

if self.roc_averaging == OWROCAnalysis.Merge:
self._update_perf_line()
self._update_perf_line()

self._update_axes_ticks()

@@ -730,8 +738,7 @@ def _on_target_prior_changed(self):
self._on_display_perf_line_changed()

def _on_display_perf_line_changed(self):
if self.roc_averaging == OWROCAnalysis.Merge:
self._update_perf_line()
self._update_perf_line()

if self.perf_line is not None:
self.perf_line.setVisible(self.display_perf_line)
@@ -745,9 +752,12 @@ def _replot(self):
self._setup_plot()

def _update_perf_line(self):
if self._perf_line is None:

if self._perf_line is None or self.roc_averaging != OWROCAnalysis.Merge:
self._update_output(None)
return

ind = None
self._perf_line.setVisible(self.display_perf_line)
if self.display_perf_line:
m = roc_iso_performance_slope(
@@ -762,6 +772,26 @@
else:
self._perf_line.setVisible(False)

self._update_output(None if ind is None else hull.thresholds[ind[0]])

def _update_output(self, threshold):
self.Information.no_output.clear()

if threshold is None:
self.Outputs.calibrated_model.send(None)
return

problems = check_can_calibrate(self.results, self.selected_classifiers)
if problems:
self.Information.no_output(problems)
self.Outputs.calibrated_model.send(None)
return

model = ThresholdClassifier(
self.results.models[0][self.selected_classifiers[0]],
threshold)
self.Outputs.calibrated_model.send(model)

def onDeleteWidget(self):
self.clear()

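For context: the new `Calibrated Model` output wraps the single selected model in a `ThresholdClassifier` at the threshold of the point chosen on the ROC convex hull. A toy, Orange-free illustration of what moving the operating threshold means for predictions (values are illustrative only, not taken from the commit):

```python
# Toy example: same probabilistic model, different decision threshold.
import numpy as np

probs = np.array([0.8, 0.7, 0.6, 0.55, 0.51, 0.4, 0.3, 0.1])  # P(target class)

default_pred = (probs > 0.5).astype(int)   # ordinary 0.5 cut-off
threshold = 0.34                           # e.g. picked from the ROC hull
shifted_pred = (probs >= threshold).astype(int)

print(default_pred)   # [1 1 1 1 1 0 0 0]
print(shifted_pred)   # [1 1 1 1 1 1 0 0]  -> one more positive prediction
```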
54 changes: 52 additions & 2 deletions Orange/widgets/evaluate/tests/base.py
@@ -1,9 +1,59 @@
from unittest.mock import Mock

import numpy as np

from Orange import classification, evaluation
from Orange.data import Table
from Orange.data import Table, Domain, DiscreteVariable
from Orange.evaluation import Results
from Orange.evaluation.performance_curves import Curves
from Orange.tests import test_filename

from Orange.widgets.tests.base import WidgetTest


class EvaluateTest(WidgetTest):
def setUp(self):
super().setUp()

n, p = (0, 1)
actual, probs = np.array([
(p, .8), (n, .7), (p, .6), (p, .55), (p, .54), (n, .53), (n, .52),
(p, .51), (n, .505), (p, .4), (n, .39), (p, .38), (n, .37),
(n, .36), (n, .35), (p, .34), (n, .33), (p, .30), (n, .1)]).T
self.curves = Curves(actual, probs)
probs2 = (probs + 1) / 2
self.curves2 = Curves(actual, probs2)
pred = probs > 0.5
pred2 = probs2 > 0.5
probs = np.vstack((1 - probs, probs)).T
probs2 = np.vstack((1 - probs2, probs2)).T
domain = Domain([], DiscreteVariable("y", values=("a", "b")))
self.results = Results(
domain=domain,
actual=actual,
folds=np.array([Ellipsis]),
models=np.array([[Mock(), Mock()]]),
row_indices=np.arange(19),
predicted=np.array((pred, pred2)),
probabilities=np.array([probs, probs2]))

self.lenses = data = Table(test_filename("datasets/lenses.tab"))
majority = classification.MajorityLearner()
majority.name = "majority"
knn3 = classification.KNNLearner(n_neighbors=3)
knn3.name = "knn-3"
knn1 = classification.KNNLearner(n_neighbors=1)
knn1.name = "knn-1"
self.lenses_results = evaluation.TestOnTestData(
store_data=True, store_models=True)(
data=data[::2], test_data=data[1::2],
learners=[majority, knn3, knn1])
self.lenses_results.learner_names = ["majority", "knn-3", "knn-1"]

class EvaluateTest:
def test_many_evaluation_results(self):
if not hasattr(self, "widget"):
return

data = Table("iris")
learners = [
classification.MajorityLearner(),
49 changes: 5 additions & 44 deletions Orange/widgets/evaluate/tests/test_owcalibrationplot.py
@@ -11,56 +11,15 @@

from orangewidget.utils.combobox import qcombobox_emit_activated

from Orange.data import Table, DiscreteVariable, Domain, ContinuousVariable
import Orange.evaluation
import Orange.classification
from Orange.evaluation import Results
from Orange.data import Domain, ContinuousVariable
from Orange.evaluation.performance_curves import Curves
from Orange.widgets.evaluate.tests.base import EvaluateTest
from Orange.widgets.evaluate.owcalibrationplot import OWCalibrationPlot
from Orange.widgets.tests.base import WidgetTest
from Orange.tests import test_filename


class TestOWCalibrationPlot(WidgetTest, EvaluateTest):
class TestOWCalibrationPlot(EvaluateTest):
def setUp(self):
super().setUp()

n, p = (0, 1)
actual, probs = np.array([
(p, .8), (n, .7), (p, .6), (p, .55), (p, .54), (n, .53), (n, .52),
(p, .51), (n, .505), (p, .4), (n, .39), (p, .38), (n, .37),
(n, .36), (n, .35), (p, .34), (n, .33), (p, .30), (n, .1)]).T
self.curves = Curves(actual, probs)
probs2 = (probs + 0.5) / 2 + 1
self.curves2 = Curves(actual, probs2)
pred = probs > 0.5
pred2 = probs2 > 0.5
probs = np.vstack((1 - probs, probs)).T
probs2 = np.vstack((1 - probs2, probs2)).T
domain = Domain([], DiscreteVariable("y", values=("a", "b")))
self.results = Results(
domain=domain,
actual=actual,
folds=np.array([Ellipsis]),
models=np.array([[Mock(), Mock()]]),
row_indices=np.arange(19),
predicted=np.array((pred, pred2)),
probabilities=np.array([probs, probs2]))

self.lenses = data = Table(test_filename("datasets/lenses.tab"))
majority = Orange.classification.MajorityLearner()
majority.name = "majority"
knn3 = Orange.classification.KNNLearner(n_neighbors=3)
knn3.name = "knn-3"
knn1 = Orange.classification.KNNLearner(n_neighbors=1)
knn1.name = "knn-1"
self.lenses_results = Orange.evaluation.TestOnTestData(
store_data=True, store_models=True)(
data=data[::2], test_data=data[1::2],
learners=[majority, knn3, knn1])
self.lenses_results.learner_names = ["majority", "knn-3", "knn-1"]

self.widget = self.create_widget(OWCalibrationPlot) # type: OWCalibrationPlot
warnings.filterwarnings("ignore", ".*", ConvergenceWarning)

@@ -382,6 +341,8 @@ def test_threshold_flips_on_two_classes(self):
@patch("Orange.widgets.evaluate.owcalibrationplot.CalibratedLearner")
def test_apply_no_output(self, *_):
"""Test no output warnings"""
# Similar to test_owcalibrationplot, but just a little different, hence
# pylint: disable=duplicate-code
widget = self.widget
model_list = widget.controls.selected_classifiers

Expand All @@ -395,7 +356,7 @@ def test_apply_no_output(self, *_):
multiple_selected:
"select a single model - the widget can output only one",
non_binary_class:
"cannot calibrate non-binary classes"}
"cannot calibrate non-binary models"}

def test_shown(shown):
widget_msg = widget.Information.no_output
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/tests/test_owliftcurve.py
@@ -31,7 +31,7 @@
SKIP_REASON = "Only test precision-recall with scikit-learn>=1.1.1"


class TestOWLiftCurve(WidgetTest, EvaluateTest):
class TestOWLiftCurve(EvaluateTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
