diff --git a/doc/index.rst b/doc/index.rst index bac76c9..06f6a02 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -12,6 +12,7 @@ Widgets widgets/explain-model widgets/explain-prediction widgets/explain-predictions + widgets/ice Indices and tables ================== diff --git a/doc/widgets/ice.md b/doc/widgets/ice.md new file mode 100644 index 0000000..5cc0b7c --- /dev/null +++ b/doc/widgets/ice.md @@ -0,0 +1,34 @@ +ICE +=== + +Displays one line per instance that shows how the instance’s prediction changes when a feature changes. + +**Inputs** + +- Model: model +- Data: dataset + +The **ICE** (Individual Conditional Expectation) widget visualizes the dependence of the prediction on a feature for each instance separately, resulting in one line per instance, compared to one line overall in partial dependence plots. + + +![](images/ICE.png) + +1. Select a target class. +2. Select a feature. +3. Order features by importance (partial dependence averaged across all the samples). +4. Apply the color of a discrete feature. +5. If **Centered** is ticked, the plot lines will start at the origin of the y-axis. +5. If **Show mean** is ticked, the average across all the samples in the dataset is shown. +6. If **Send Automatically** is ticked, the output is sent automatically after any change. + Alternatively, click **Send**. +7. Get help, save the plot, make the report, set plot properties, or observe the size of input and output data. +8. Plot shows a line for each instance in the input dataset. + +Example +------- + +In the flowing example, we use the ICE widget to explain Random Forest model. In the File widget, we open the Housing dataset. We connect it to the Random Forest widget, which trains the model. The ICE widget accepts the model and data which are used to explain the model. + +By selecting some arbitrary lines, the selected instances of the input dataset appear on the output of the ICE widget. + +![](images/ICE-Example.png) diff --git a/doc/widgets/images/ICE-example.png b/doc/widgets/images/ICE-example.png new file mode 100644 index 0000000..e9e2efe Binary files /dev/null and b/doc/widgets/images/ICE-example.png differ diff --git a/doc/widgets/images/ICE.png b/doc/widgets/images/ICE.png new file mode 100644 index 0000000..defc4d7 Binary files /dev/null and b/doc/widgets/images/ICE.png differ diff --git a/orangecontrib/explain/inspection.py b/orangecontrib/explain/inspection.py index aa89d1a..36475e2 100644 --- a/orangecontrib/explain/inspection.py +++ b/orangecontrib/explain/inspection.py @@ -1,12 +1,13 @@ """ Permutation feature importance for models. """ -from typing import Callable +from typing import Callable, Tuple, Optional, Dict import numpy as np import scipy.sparse as sp +from sklearn.inspection import partial_dependence -from Orange.base import Model +from Orange.base import Model, SklModel from Orange.classification import Model as ClsModel -from Orange.data import Table +from Orange.data import Table, Variable, DiscreteVariable from Orange.evaluation import Results from Orange.evaluation.scoring import Score, TargetScore, RegressionScore, R2 from Orange.regression import Model as RegModel @@ -19,7 +20,7 @@ def permutation_feature_importance( score: Score, n_repeats: int = 5, progress_callback: Callable = None -): +) -> np.ndarray: """ Function calculates feature importance of a model for a given data. @@ -174,3 +175,46 @@ def _calculate_permutation_scores( progress_callback(1) return scores + + +def individual_condition_expectation( + model: SklModel, + data: Table, + feature: Variable, + grid_resolution: int = 1000, + kind: str = "both", + progress_callback: Callable = dummy_callback +) -> Dict[str, np.ndarray]: + progress_callback(0) + _check_data(data) + needs_pp = _check_model(model, data) + if needs_pp: + data = model.data_to_model_domain(data) + + assert feature.name in [a.name for a in data.domain.attributes] + feature_index = data.domain.index(feature.name) + + assert isinstance(model, SklModel), f"Model ({model}) is not supported." + progress_callback(0.1) + + dep = partial_dependence(model.skl_model, + data.X, + [feature_index], + grid_resolution=grid_resolution, + kind=kind) + + results = {"average": dep["average"], "values": dep["values"][0]} + if kind == "both": + results["individual"] = dep["individual"] + + if data.domain.has_discrete_class and \ + len(data.domain.class_var.values) == 2: + results = {"average": np.vstack([1 - dep["average"], dep["average"]]), + "values": dep["values"][0]} + if kind == "both": + results["individual"] = \ + np.vstack([1 - dep["individual"], dep["individual"]]) + + progress_callback(1) + + return results diff --git a/orangecontrib/explain/tests/test_inspection.py b/orangecontrib/explain/tests/test_inspection.py index 7a1f4b9..d371213 100644 --- a/orangecontrib/explain/tests/test_inspection.py +++ b/orangecontrib/explain/tests/test_inspection.py @@ -3,19 +3,19 @@ import pkg_resources import numpy as np -from sklearn.inspection import permutation_importance +from sklearn.inspection import permutation_importance, partial_dependence from Orange.base import Model from Orange.classification import NaiveBayesLearner, RandomForestLearner, \ LogisticRegressionLearner, TreeLearner -from Orange.data import Table, Domain +from Orange.data import Table, Domain, DiscreteVariable from Orange.data.table import DomainTransformationError from Orange.evaluation import CA, MSE, AUC from Orange.regression import RandomForestRegressionLearner, \ TreeLearner as TreeRegressionLearner from orangecontrib.explain.inspection import permutation_feature_importance, \ - _wrap_score, _check_model + _wrap_score, _check_model, individual_condition_expectation def _permutation_feature_importance_skl( @@ -284,5 +284,101 @@ def test_sparse_data(self): ) +class TestIndividualConditionalExpectation(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.iris = Table.from_file("iris") + cls.heart = Table.from_file("heart_disease") + cls.housing = Table.from_file("housing") + + def test_discrete_class(self): + data = self.iris[:100] + class_var = DiscreteVariable("iris", data.domain.class_var.values[:2]) + data = data.transform(Domain(data.domain.attributes, class_var)) + model = RandomForestLearner(n_estimators=10, random_state=0)(data) + res = individual_condition_expectation(model, data, data.domain[0]) + self.assertIsInstance(res, dict) + self.assertEqual(res["average"].shape, (2, 28)) + self.assertEqual(res["individual"].shape, (2, 100, 28)) + self.assertEqual(res["values"].shape, (28,)) + + def test_discrete_class_result_values(self): + data = self.iris[:100] + class_var = DiscreteVariable("iris", data.domain.class_var.values[:2]) + data = data.transform(Domain(data.domain.attributes, class_var)) + model1 = RandomForestLearner(n_estimators=10, random_state=0)(data) + + data.Y = np.abs(data.Y - 1) + model2 = RandomForestLearner(n_estimators=10, random_state=0)(data) + + res = individual_condition_expectation(model1, data, data.domain[0]) + dep1 = partial_dependence(model1.skl_model, data.X, [0], kind="both") + dep2 = partial_dependence(model2.skl_model, data.X, [0], kind="both") + np.testing.assert_array_almost_equal( + res["average"][:1], dep2["average"]) + np.testing.assert_array_almost_equal( + res["average"][1:], dep1["average"]) + np.testing.assert_array_almost_equal( + res["individual"][:1], dep2["individual"]) + np.testing.assert_array_almost_equal( + res["individual"][1:], dep1["individual"]) + + def test_continuous_class(self): + data = self.housing + model = RandomForestRegressionLearner(n_estimators=10, random_state=0)(data) + res = individual_condition_expectation(model, data, data.domain[0]) + self.assertIsInstance(res, dict) + self.assertEqual(res["average"].shape, (1, 504)) + self.assertEqual(res["individual"].shape, (1, 506, 504)) + self.assertEqual(res["values"].shape, (504,)) + + def test_multi_class(self): + data = self.iris + model = RandomForestLearner(n_estimators=10, random_state=0)(data) + res = individual_condition_expectation(model, data, data.domain[0]) + self.assertIsInstance(res, dict) + self.assertEqual(res["average"].shape, (3, 35)) + self.assertEqual(res["individual"].shape, (3, 150, 35)) + self.assertEqual(res["values"].shape, (35,)) + + def test_mixed_features(self): + data = self.heart + model = RandomForestLearner(n_estimators=10, random_state=0)(data) + res = individual_condition_expectation(model, data, data.domain[0]) + self.assertIsInstance(res, dict) + self.assertEqual(res["average"].shape, (2, 41)) + self.assertEqual(res["individual"].shape, (2, 303, 41)) + self.assertEqual(res["values"].shape, (41,)) + + def _test_sklearn(self): + from matplotlib import pyplot as plt + from sklearn.ensemble import RandomForestClassifier, \ + RandomForestRegressor + from sklearn.inspection import PartialDependenceDisplay + + X = self.housing.X + y = self.housing.Y + model = RandomForestRegressor(random_state=0) + + # X = self.iris.X[:100] + # y = self.iris.Y[:100] + # y = np.abs(y - 1) + # model = RandomForestClassifier(random_state=0) + model.fit(X, y) + display = PartialDependenceDisplay.from_estimator( + model, + X, + [X.shape[1] - 1], + target=0, + kind="both", + centered=True, + subsample=1000, + # grid_resolution=100, + random_state=0, + ) + + plt.show() + + if __name__ == "__main__": unittest.main() diff --git a/orangecontrib/explain/widgets/icons/ICE.svg b/orangecontrib/explain/widgets/icons/ICE.svg new file mode 100644 index 0000000..c4243e2 --- /dev/null +++ b/orangecontrib/explain/widgets/icons/ICE.svg @@ -0,0 +1,23 @@ + + + + + + + + + + diff --git a/orangecontrib/explain/widgets/owice.py b/orangecontrib/explain/widgets/owice.py new file mode 100644 index 0000000..4b91569 --- /dev/null +++ b/orangecontrib/explain/widgets/owice.py @@ -0,0 +1,862 @@ +import bisect +from types import SimpleNamespace +from typing import Optional, Dict, List, Tuple, Any +from xml.sax.saxutils import escape + +import numpy as np +from AnyQt.QtCore import Qt, QSortFilterProxyModel, QSize, QModelIndex, \ + QItemSelection, QPointF, Signal, QLineF +from AnyQt.QtGui import QColor +from AnyQt.QtWidgets import QComboBox, QSizePolicy, QGraphicsSceneHelpEvent, \ + QToolTip, QGraphicsLineItem, QApplication + +import pyqtgraph as pg + +from orangecanvas.gui.utils import disconnected +from orangewidget.utils.listview import ListViewSearch + +from Orange.base import Model, SklModel, RandomForestModel +from Orange.data import Table, ContinuousVariable, Variable, \ + DiscreteVariable +from Orange.data.table import DomainTransformationError +from Orange.widgets import gui +from Orange.widgets.settings import ContextSetting, Setting, \ + PerfectDomainContextHandler +from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_SIGNAL_NAME, \ + create_annotated_table +from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin +from Orange.widgets.utils.itemmodels import VariableListModel, DomainModel +from Orange.widgets.utils.sql import check_sql_input +from Orange.widgets.utils.widgetpreview import WidgetPreview +from Orange.widgets.visualize.owdistributions import LegendItem +from Orange.widgets.visualize.utils.customizableplot import Updater, \ + CommonParameterSetter +from Orange.widgets.visualize.utils.plotutils import PlotWidget, \ + HelpEventDelegate +from Orange.widgets.widget import Input, OWWidget, Msg, Output + +from orangecontrib.explain.inspection import individual_condition_expectation +from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog + + +class RunnerResults(SimpleNamespace): + x_data: Optional[np.ndarray] = None + y_average: Optional[np.ndarray] = None + y_individual: Optional[np.ndarray] = None + + +def run( + data: Table, + feature: Variable, + model: Model, + state: TaskState +) -> Optional[RunnerResults]: + if not data or not model or not feature: + return None + + def callback(i: float, status=""): + state.set_progress_value(i * 100) + if status: + state.set_status(status) + if state.is_interruption_requested(): + raise Exception + + result = individual_condition_expectation( + model, data, feature, progress_callback=callback + ) + return RunnerResults(x_data=result["values"], + y_average=result["average"], + y_individual=result["individual"]) + + +def ccw(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: + """ + Checks whether three points are listed in a counterclockwise order. + """ + ax, ay = (a[:, 0], a[:, 1]) if a.ndim == 2 else (a[0], a[1]) + bx, by = (b[:, 0], b[:, 1]) if b.ndim == 2 else (b[0], b[1]) + cx, cy = (c[:, 0], c[:, 1]) if c.ndim == 2 else (c[0], c[1]) + return (cy - ay) * (bx - ax) > (by - ay) * (cx - ax) + + +def intersects( + a: np.ndarray, + b: np.ndarray, + c: np.ndarray, + d: np.ndarray +) -> np.ndarray: + """ + Checks whether line segment a (given points a and b) intersects with line + segment b (given points c and d). + """ + return np.logical_and(ccw(a, c, d) != ccw(b, c, d), + ccw(a, b, c) != ccw(a, b, d)) + + +def line_intersects_profiles( + p1: np.ndarray, + p2: np.ndarray, + lines: np.ndarray +) -> np.ndarray: + """ + Checks if a line intersects any line segments. + """ + res = np.zeros(len(lines[0]), dtype=bool) + for i in range(len(lines) - 1): + res = np.logical_or(res, intersects(p1, p2, lines[i], lines[i + 1])) + return res + + +class ICEPlotViewBox(pg.ViewBox): + sigSelectionChanged = Signal(object) + + def __init__(self): + super().__init__(enableMenu=False) + self.__lines = None + self.__selection_line = QGraphicsLineItem() + self.__selection_line.setPen(pg.mkPen(QColor(Qt.black), width=2)) + self.__selection_line.setZValue(1e9) + self.addItem(self.__selection_line, ignoreBounds=True) + + def __update_selection_line(self, button_down_pos, current_pos): + p1 = self.childGroup.mapFromParent(button_down_pos) + p2 = self.childGroup.mapFromParent(current_pos) + self.__selection_line.setLine(QLineF(p1, p2)) + self.__selection_line.resetTransform() + self.__selection_line.show() + + def __get_selected(self, p1, p2): + if self.__lines is None: + return np.array(False) + return line_intersects_profiles(np.array([p1.x(), p1.y()]), + np.array([p2.x(), p2.y()]), + self.__lines) + + def set_lines(self, x_data: np.ndarray, y_data: np.ndarray): + if x_data is None or y_data is None: + self.__lines = None + return + self.__lines = np.array([np.vstack((np.full((1, y_data.shape[0]), x), + y_data[:, i].flatten())).T + for i, x in enumerate(x_data)]) + + def mouseDragEvent(self, ev, axis=None): + if axis is None: + ev.accept() + if ev.button() == Qt.LeftButton: + self.__update_selection_line(ev.buttonDownPos(), ev.pos()) + if ev.isFinish(): + self.__selection_line.hide() + p1 = self.childGroup.mapFromParent( + ev.buttonDownPos(ev.button())) + p2 = self.childGroup.mapFromParent(ev.pos()) + indices = np.flatnonzero(self.__get_selected(p1, p2)) + selection = list(indices) if len(indices) else None + self.sigSelectionChanged.emit(selection) + + def mouseClickEvent(self, ev): + ev.accept() + self.sigSelectionChanged.emit(None) + + +class SortProxyModel(QSortFilterProxyModel): + def lessThan(self, left: QModelIndex, right: QModelIndex) -> bool: + role = self.sortRole() + l_score = left.data(role) + r_score = right.data(role) + return r_score is not None and (l_score is None or l_score < r_score) + + +class ParameterSetter(CommonParameterSetter): + def __init__(self, master: "ICEPlot"): + super().__init__() + self.master: ICEPlot = master + + def update_setters(self): + self.initial_settings = { + self.LABELS_BOX: { + self.FONT_FAMILY_LABEL: self.FONT_FAMILY_SETTING, + self.TITLE_LABEL: self.FONT_SETTING, + self.AXIS_TITLE_LABEL: self.FONT_SETTING, + self.AXIS_TICKS_LABEL: self.FONT_SETTING, + self.LEGEND_LABEL: self.FONT_SETTING, + }, + self.ANNOT_BOX: { + self.TITLE_LABEL: {self.TITLE_LABEL: ("", "")}, + }, + } + + @property + def title_item(self): + return self.master.getPlotItem().titleLabel + + @property + def axis_items(self): + return [value["item"] for value in + self.master.getPlotItem().axes.values()] + + @property + def legend_items(self): + return self.master.legend.items + + +class ICEPlot(PlotWidget): + DEFAULT_COLOR = np.array([100, 100, 100]) + MAX_POINTS_IN_TOOLTIP = 5 + + def __init__(self, parent: OWWidget): + super().__init__(parent, enableMenu=False, viewBox=ICEPlotViewBox()) + self.legend = self._create_legend(((1, 0), (1, 0))) + self.setAntialiasing(True) + self.setMouseEnabled(False, False) + self.getPlotItem().setContentsMargins(10, 10, 10, 10) + self.getPlotItem().buttonsHidden = True + self.getPlotItem().scene().sigMouseMoved.connect(self.__on_mouse_moved) + + self.__data: Table = None + self.__feature: ContinuousVariable = None + self.__x_data: Optional[np.ndarray] = None + self.__y_individual: Optional[np.ndarray] = None + self.__lines_items: List[pg.PlotCurveItem] = [] + self.__sel_lines_item: Optional[pg.PlotCurveItem] = None + self.__mean_line_item: Optional[pg.PlotCurveItem] = None + self.__hovered_lines_item: Optional[pg.PlotCurveItem] = None + self.__hovered_scatter_item: Optional[pg.ScatterPlotItem] = None + + self._help_delegate = HelpEventDelegate(self._help_event) + self.scene().installEventFilter(self._help_delegate) + + self.parameter_setter = ParameterSetter(self) + + def _create_legend(self, anchor: Tuple) -> LegendItem: + legend = LegendItem() + legend.setParentItem(self.getViewBox()) + legend.restoreAnchor(anchor) + legend.hide() + return legend + + def __on_mouse_moved(self, point: QPointF): + if self.__hovered_lines_item is None: + return + + self.__hovered_lines_item.setData(None, None, connect=None) + self.__hovered_scatter_item.setData(None, None) + + if QApplication.mouseButtons() != Qt.NoButton: + return + + view_pos: QPointF = self.getViewBox().mapSceneToView(point) + indices = self._indices_at(view_pos) + if not indices: + return + + # lines + y_individual = self.__y_individual[indices] + connect = np.ones(y_individual.shape) + connect[:, -1] = 0 + self.__hovered_lines_item.setData( + np.tile(self.__x_data, len(y_individual)), + y_individual.flatten(), + connect=connect.flatten() + ) + + # points + x_data = self.__data.get_column_view(self.__feature)[0][indices] + n_dec = self.__feature.number_of_decimals + y_data = [] + for i, x in zip(indices, x_data): + mask = np.round(self.__x_data, n_dec) == round(x, n_dec) + idx = np.flatnonzero(mask) + y = self.__y_individual[i, idx[0]] if len(idx) > 0 else np.nan + y_data.append(y) + + y_data = np.array(y_data) + mask = ~np.isnan(y_data) + self.__hovered_scatter_item.setData(x_data[mask], y_data[mask]) + + def set_data( + self, + data: Table, + feature: ContinuousVariable, + x_data: np.ndarray, + y_average: np.ndarray, + y_individual: np.ndarray, + y_label: str, + colors: Optional[np.ndarray], + color_col: Optional[np.ndarray], + color_labels: Optional[Tuple[str, str, str]], + show_mean: bool, + ): + self.__data = data + self.__feature = feature + self.__x_data = x_data + self.__y_individual = y_individual + self._add_lines(y_average, show_mean, colors, color_col) + self._set_axes(feature.name, y_label) + self._set_legend(color_labels, colors) + self.getViewBox().set_lines(x_data, y_individual) + + def set_selection(self, selection: Optional[List[int]]): + if self.__sel_lines_item is None: + return + self.__sel_lines_item.setData(None, None, connect=None) + if selection is not None: + y_individual = self.__y_individual[selection] + connect = np.ones(y_individual.shape) + connect[:, -1] = 0 + self.__sel_lines_item.setData( + np.tile(self.__x_data, len(y_individual)), + y_individual.flatten(), + connect=connect.flatten() + ) + + def _set_legend(self, labels: Optional[Tuple], + colors: Optional[np.ndarray]): + self.legend.clear() + self.legend.hide() + if labels is not None: + for name, color in zip(labels, colors): + c = QColor(*color) + dots = pg.ScatterPlotItem(pen=c, brush=c, size=10, shape="s") + self.legend.addItem(dots, escape(name)) + self.legend.show() + Updater.update_legend_font(self.parameter_setter.legend_items, + **self.parameter_setter.legend_settings) + + def _set_axes(self, x_label: str, y_label: str): + self.getAxis("bottom").setLabel(x_label) + self.getAxis("left").setLabel(y_label) + + def _add_lines( + self, + y_average: np.ndarray, + show_mean: bool, + colors: np.ndarray, + color_col: np.ndarray, + ): + if colors is None: + colors = self.DEFAULT_COLOR[None, :] + color_col = np.zeros(len(self.__y_individual)) + + x_data = self.__x_data + for i, color in enumerate(colors): + y_data = self.__y_individual[color_col == i] + self.__add_curve_item(x_data, y_data, color) + + mask = np.isnan(color_col) + if any(mask): + y_data = self.__y_individual[mask] + self.__add_curve_item(x_data, y_data, self.DEFAULT_COLOR) + + self.__sel_lines_item = pg.PlotCurveItem( + pen=pg.mkPen(QColor("#555"), width=2), antialias=True + ) + self.addItem(self.__sel_lines_item) + + color = QColor("#1f77b4") + self.__hovered_lines_item = pg.PlotCurveItem( + pen=pg.mkPen(color, width=2), antialias=True + ) + self.addItem(self.__hovered_lines_item) + + size = 8 + self.__hovered_scatter_item = pg.ScatterPlotItem( + pen=color, brush=color, size=size, shape="o" + ) + self.addItem(self.__hovered_scatter_item) + + self.__mean_line_item = pg.PlotCurveItem( + x_data, y_average, + pen=pg.mkPen(color=QColor("#ffbe00"), width=5), + antialias=True + ) + self.addItem(self.__mean_line_item) + self.set_show_mean(show_mean) + + color = QColor(0, 0, 0, 0) + dummy = pg.ScatterPlotItem( + [np.min(x_data), np.max(x_data)], + [np.min(self.__y_individual), np.max(self.__y_individual)], + pen=color, brush=color, size=size, shape="o" + ) + self.addItem(dummy) + + def set_show_mean(self, show: bool): + if self.__mean_line_item is not None: + self.__mean_line_item.setVisible(show) + + def clear_all(self): + self.__data = None + self.__feature = None + self.__x_data = None + self.__y_individual = None + if self.__mean_line_item is not None: + self.removeItem(self.__mean_line_item) + self.__mean_line_item = None + for lines in self.__lines_items: + self.removeItem(lines) + self.__lines_items.clear() + if self.__sel_lines_item is not None: + self.removeItem(self.__sel_lines_item) + self.__sel_lines_item = None + if self.__hovered_lines_item is not None: + self.removeItem(self.__hovered_lines_item) + self.__hovered_lines_item = None + if self.__hovered_scatter_item is not None: + self.removeItem(self.__hovered_scatter_item) + self.__hovered_scatter_item = None + self.clear() + self.legend.hide() + self._set_axes(None, None) + self.getViewBox().set_lines(None, None) + + def __add_curve_item(self, x_data, y_data, color): + connect = np.ones(y_data.shape) + connect[:, -1] = 0 + lines = pg.PlotCurveItem( + np.tile(x_data, len(y_data)), y_data.flatten(), + connect=connect.flatten(), antialias=True, + pen=pg.mkPen(color=QColor(*color, 100), width=1) + ) + self.addItem(lines) + self.__lines_items.append(lines) + + def _indices_at(self, pos: QPointF) -> List[int]: + if not self.__x_data[0] <= pos.x() <= self.__x_data[-1]: + return [] + + index = bisect.bisect(self.__x_data, round(pos.x(), 2)) - 1 + assert 0 <= index < len(self.__x_data) + if index < len(self.__x_data) - 1: + x = pos.x() + x_left = self.__x_data[index] + x_right = self.__x_data[index + 1] + y_left = self.__y_individual[:, index] + y_right = self.__y_individual[:, index + 1] + y = (x - x_left) * (y_right - y_left) / (x_right - x_left) + y_left + + else: + y = self.__y_individual[:, -1] + + # eps is pixel size dependent + vb: ICEPlotViewBox = self.getViewBox() + _, px_height = vb.viewPixelSize() + mask = np.abs(y - pos.y()) < px_height * 5 # 5 px + return np.flatnonzero(mask).tolist() + + def _help_event(self, event: QGraphicsSceneHelpEvent): + if self.__mean_line_item is None: + return False + + pos = self.__mean_line_item.mapFromScene(event.scenePos()) + indices = self._indices_at(pos) + text = self._get_tooltip(indices) + if text: + QToolTip.showText(event.screenPos(), text, widget=self) + return True + return False + + def _get_tooltip(self, indices: List[int]) -> str: + text = "
".join(self.__instance_tooltip(self.__data, idx) for idx + in indices[:self.MAX_POINTS_IN_TOOLTIP]) + if len(indices) > self.MAX_POINTS_IN_TOOLTIP: + text = f"{len(indices)} instances
{text}
..." + return text + + @staticmethod + def __instance_tooltip(data: Table, idx: int) -> str: + def show_part(_point_data, singular, plural, max_shown, _vars): + cols = [escape('{} = {}'.format(var.name, _point_data[var])) + for var in _vars[:max_shown + 2]][:max_shown] + if not cols: + return "" + n_vars = len(_vars) + if n_vars > max_shown: + cols[-1] = "... and {} others".format(n_vars - max_shown + 1) + return \ + "{}:
".format(singular if n_vars < 2 else plural) \ + + "
".join(cols) + + parts = (("Class", "Classes", 4, data.domain.class_vars), + ("Meta", "Metas", 4, data.domain.metas), + ("Feature", "Features", 10, data.domain.attributes)) + return "
".join(show_part(data[idx], *cols) for cols in parts) + + +class OWICE(OWWidget, ConcurrentWidgetMixin): + name = "ICE" + description = "Dependence between a target and a feature of interest." + keywords = ["ICE", "PDP", "partial", "dependence"] + icon = "icons/ICE.svg" + priority = 130 + + class Inputs: + model = Input("Model", (SklModel, RandomForestModel)) + data = Input("Data", Table) + + class Outputs: + selected_data = Output("Selected Data", Table, default=True) + annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table) + + class Error(OWWidget.Error): + domain_transform_err = Msg("{}") + unknown_err = Msg("{}") + not_enough_data = Msg("At least two instances are needed.") + no_cont_features = Msg("At least one numeric feature is required.") + + class Information(OWWidget.Information): + data_sampled = Msg("Data has been sampled.") + + buttons_area_orientation = Qt.Vertical + + settingsHandler = PerfectDomainContextHandler() + target_index = ContextSetting(0) + feature = ContextSetting(None) + order_by_importance = Setting(False) + color_var = ContextSetting(None) + centered = Setting(True) + show_mean = Setting(True) + auto_send = Setting(True) + selection = Setting(None, schema_only=True) + visual_settings = Setting({}, schema_only=True) + + graph_name = "graph.plotItem" + MIN_INSTANCES = 2 + MAX_INSTANCES = 300 + + def __init__(self): + OWWidget.__init__(self) + ConcurrentWidgetMixin.__init__(self) + + self.__results: Optional[RunnerResults] = None + self.__results_avgs: Optional[Dict[ContinuousVariable, float]] = None + self.__sampled_mask: Optional[np.ndarray] = None + self.__pending_selection = self.selection + self.model: Optional[Model] = None + self.data: Optional[Table] = None + self.graph: ICEPlot = None + self._target_combo: QComboBox = None + self._features_view: ListViewSearch = None + self._features_model: VariableListModel = None + self._color_model: DomainModel = None + + self.setup_gui() + + VisualSettingsDialog(self, self.graph.parameter_setter.initial_settings) + + def setup_gui(self): + self._add_plot() + self._add_controls() + self._add_buttons() + + def _add_plot(self): + box = gui.vBox(self.mainArea) + self.graph = ICEPlot(self) + view_box = self.graph.getViewBox() + view_box.sigSelectionChanged.connect(self.__on_selection_changed) + box.layout().addWidget(self.graph) + + def __on_selection_changed(self, selection: Optional[List[int]]): + self.select_instances(selection) + self.commit.deferred() + + def _add_controls(self): + box = gui.vBox(self.controlArea, "Target class") + self._target_combo = gui.comboBox( + box, self, "target_index", contentsLength=12, + callback=self.__on_target_changed + ) + + box = gui.vBox(self.controlArea, "Feature") + self._features_model = VariableListModel() + sorted_model = SortProxyModel(sortRole=Qt.UserRole) + sorted_model.setSourceModel(self._features_model) + sorted_model.sort(0) + self._features_view = ListViewSearch() + self._features_view.setModel(sorted_model) + self._features_view.setMinimumSize(QSize(30, 100)) + self._features_view.setSizePolicy(QSizePolicy.Expanding, + QSizePolicy.Expanding) + self._features_view.selectionModel().selectionChanged.connect( + self.__on_feature_changed + ) + box.layout().addWidget(self._features_view) + gui.checkBox(box, self, "order_by_importance", "Order by importance", + callback=self.__on_order_changed) + + box = gui.vBox(self.controlArea, "Display") + self._color_model = DomainModel(placeholder="None", separators=False, + valid_types=DiscreteVariable) + gui.comboBox(box, self, "color_var", label="Color:", searchable=True, + model=self._color_model, orientation=Qt.Horizontal, + contentsLength=12, callback=self.__on_parameter_changed) + gui.checkBox(box, self, "centered", "Centered", + callback=self.__on_parameter_changed) + gui.checkBox(box, self, "show_mean", "Show mean", + callback=self.__on_show_mean_changed) + + def __on_target_changed(self): + self._apply_feature_sorting() + self.__on_parameter_changed() + + def __on_feature_changed(self, selection: QItemSelection): + if not selection: + return + + self.feature = selection.indexes()[0].data(gui.TableVariable) + self._apply_feature_sorting() + self._run() + + def __on_order_changed(self): + self._apply_feature_sorting() + + def __on_parameter_changed(self): + self.__pending_selection = self.selection + self.setup_plot() + self.apply_selection() + + def __on_show_mean_changed(self): + self.graph.set_show_mean(self.show_mean) + + def _add_buttons(self): + gui.auto_send(self.buttonsArea, self, "auto_send") + + @Inputs.data + @check_sql_input + def set_data(self, data: Optional[Table]): + self.closeContext() + self.data = data + self.__sampled_mask = None + self._check_data() + self._setup_controls() + self.openContext(self.data.domain if self.data else None) + self.set_list_view_selection() + + @Inputs.model + def set_model(self, model: Optional[Model]): + self.model = model + + def _check_data(self): + self.Error.no_cont_features.clear() + self.Error.not_enough_data.clear() + self.Information.data_sampled.clear() + if self.data is None: + return + + self.__sampled_mask = np.ones(len(self.data), dtype=bool) + + if len(self.data) < self.MIN_INSTANCES: + self.data = None + self.Error.not_enough_data() + + if self.data and not self.data.domain.has_continuous_attributes(): + self.data = None + self.Error.no_cont_features() + + if self.data and len(self.data) > self.MAX_INSTANCES: + self.__sampled_mask[:] = False + np.random.seed(0) + kws = {"size": self.MAX_INSTANCES, "replace": False} + self.__sampled_mask[np.random.choice(len(self.data), **kws)] = True + self.Information.data_sampled() + + def _setup_controls(self): + domain = self.data.domain if self.data else None + + self._target_combo.clear() + self._target_combo.setEnabled(True) + self._features_model.clear() + self._color_model.set_domain(domain) + self.color_var = None + + if domain is not None: + features = [var for var in domain.attributes if var.is_continuous + and not var.attributes.get("hidden", False)] + self._features_model[:] = features + if domain.has_discrete_class: + self._target_combo.addItems(domain.class_var.values) + self.target_index = 0 + elif domain.has_continuous_class: + self._target_combo.setEnabled(False) + self.target_index = -1 + if len(self._features_model) > 0: + self.feature = self._features_model[0] + + def set_list_view_selection(self): + model = self._features_view.model() + sel_model = self._features_view.selectionModel() + src_model = model.sourceModel() + if self.feature not in src_model: + return + + with disconnected(sel_model.selectionChanged, + self.__on_feature_changed): + row = src_model.indexOf(self.feature) + sel_model.select(model.index(row, 0), sel_model.ClearAndSelect) + + self._ensure_selection_visible(self._features_view) + + @staticmethod + def _ensure_selection_visible(view): + selection = view.selectedIndexes() + if len(selection) == 1: + view.scrollTo(selection[0]) + + def handleNewSignals(self): + self.__results_avgs = None + self._apply_feature_sorting() + self._run() + self.selection = None + self.commit.now() + + def _apply_feature_sorting(self): + if self.data is None or self.model is None: + return + + order = list(range(len(self._features_model))) + if self.order_by_importance: + def compute_score(feature): + values = self.__results_avgs[feature][self.target_index] + return -np.sum(np.abs(values - np.mean(values))) + + try: + if self.__results_avgs is None: + msk = self.__sampled_mask + self.__results_avgs = { + feature: individual_condition_expectation( + self.model, self.data[msk], feature, kind="average" + )["average"] for feature in self._features_model + } + order = [compute_score(f) for f in self._features_model] + except Exception: + pass + + for i in range(self._features_model.rowCount()): + self._features_model.setData(self._features_model.index(i), + order[i], Qt.UserRole) + + self._ensure_selection_visible(self._features_view) + + def _run(self): + self.clear() + data = self.data[self.__sampled_mask] if self.data else None + self.start(run, data, self.feature, self.model) + + def clear(self): + self.__results = None + self.selection = None + self.cancel() + self.Error.domain_transform_err.clear() + self.Error.unknown_err.clear() + self.graph.clear_all() + + def setup_plot(self): + self.graph.clear_all() + if not self.__results or not self.data: + return + + x_data = self.__results.x_data + y_average = self.__results.y_average[self.target_index] + y_individual = self.__results.y_individual[self.target_index] + + class_var: Variable = self.model.original_domain.class_var + if class_var.is_discrete: + cls_val = class_var.values[self.target_index] + y_label = f"P({class_var.name}={cls_val})" + else: + y_label = f"{class_var.name}" + + if self.centered: + y_average = y_average - y_average[0, None] + y_individual = y_individual - y_individual[:, 0, None] + y_label = "Δ " + y_label + + mask = self.__sampled_mask + colors = None + color_col = None + color_labels = None + if self.color_var and self.color_var.is_discrete: + colors = self.color_var.colors + color_col = self.data[mask].get_column_view(self.color_var)[0] + color_labels = self.color_var.values + + self.graph.set_data(self.data[mask], self.feature, + x_data, y_average, y_individual, y_label, colors, + color_col, color_labels, self.show_mean) + + def on_partial_result(self, _): + pass + + def on_done(self, results: Optional[RunnerResults]): + self.__results = results + self.setup_plot() + self.apply_selection() + self.commit.deferred() + + def apply_selection(self): + if self.__pending_selection is not None: + n_inst = len(self.data) + self.__pending_selection = \ + [i for i in self.__pending_selection if i < n_inst] + + mask = self.__sampled_mask + if not all(mask): + selection = np.zeros(len(mask), dtype=int) + selection[mask] = np.arange(sum(mask)) + self.__pending_selection = selection[self.__pending_selection] + + self.select_instances(self.__pending_selection) + self.__pending_selection = None + + def select_instances(self, selection: Optional[List[int]]): + self.graph.set_selection(selection) + if selection is not None: + indices = np.arange(len(self.__sampled_mask))[self.__sampled_mask] + self.selection = list(indices[selection]) + else: + self.selection = None + + def on_exception(self, ex: Exception): + if isinstance(ex, DomainTransformationError): + self.Error.domain_transform_err(ex) + else: + self.Error.unknown_err(ex) + + @gui.deferred + def commit(self): + selected = self.data[self.selection] \ + if self.data is not None and self.selection is not None else None + annotated = create_annotated_table(self.data, self.selection) + self.Outputs.selected_data.send(selected) + self.Outputs.annotated_data.send(annotated) + + def onDeleteWidget(self): + self.shutdown() + super().onDeleteWidget() + + def send_report(self): + if not self.data or not self.model: + return + items = {"Target class": "None"} + if self.model.domain.has_discrete_class: + class_var = self.model.domain.class_var + items["Target class"] = class_var.values[self.target_index] + self.report_items(items) + self.report_plot() + + def set_visual_settings(self, key: Tuple[str, str, str], value: Any): + self.graph.parameter_setter.set_parameter(key, value) + self.visual_settings[key] = value + + +if __name__ == "__main__": + from Orange.classification import RandomForestLearner + from Orange.regression import RandomForestRegressionLearner + + table = Table("iris") + kwargs_ = {"n_estimators": 100, "random_state": 0} + if table.domain.has_continuous_class: + model_ = RandomForestRegressionLearner(**kwargs_)(table) + else: + model_ = RandomForestLearner(**kwargs_)(table) + WidgetPreview(OWICE).run(set_data=table, set_model=model_) diff --git a/orangecontrib/explain/widgets/tests/test_owice.py b/orangecontrib/explain/widgets/tests/test_owice.py new file mode 100644 index 0000000..e67ff9e --- /dev/null +++ b/orangecontrib/explain/widgets/tests/test_owice.py @@ -0,0 +1,145 @@ +# pylint: disable=missing-docstring +import unittest +from unittest.mock import Mock + +from AnyQt.QtCore import Qt, QPointF + +from Orange.classification import RandomForestLearner +from Orange.data import Table +from Orange.regression import RandomForestRegressionLearner +from Orange.widgets.tests.base import WidgetTest +from orangecontrib.explain.widgets.owice import OWICE + + +class TestOWICE(WidgetTest): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.iris = Table("iris") + cls.heart = Table("heart_disease") + cls.housing = Table("housing") + cls.titanic = Table("titanic") + kwargs = {"random_state": 0} + cls.rf_cls = RandomForestLearner(**kwargs)(cls.heart) + cls.rf_reg = RandomForestRegressionLearner(**kwargs)(cls.housing) + + def setUp(self): + self.widget = self.create_widget(OWICE) + + def test_input_cls(self): + self.send_signal(self.widget.Inputs.data, self.heart) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + self.wait_until_finished() + self.assertFalse(self.widget.Error.unknown_err.is_shown()) + + self.send_signal(self.widget.Inputs.model, self.rf_reg) + self.wait_until_finished() + self.assertTrue(self.widget.Error.unknown_err.is_shown()) + + self.send_signal(self.widget.Inputs.model, None) + self.assertFalse(self.widget.Error.unknown_err.is_shown()) + + self.send_signal(self.widget.Inputs.data, self.iris) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + self.wait_until_finished() + self.assertTrue(self.widget.Error.domain_transform_err.is_shown()) + + def test_output(self): + self.send_signal(self.widget.Inputs.data, self.heart) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + self.assertIsNone(self.get_output(self.widget.Outputs.selected_data)) + annotated = self.get_output(self.widget.Outputs.annotated_data) + self.assertEqual(len(annotated), len(self.heart)) + + def test_discrete_features(self): + self.send_signal(self.widget.Inputs.data, self.titanic) + self.assertTrue(self.widget.Error.no_cont_features.is_shown()) + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertFalse(self.widget.Error.no_cont_features.is_shown()) + + def test_order_features(self): + self.send_signal(self.widget.Inputs.data, self.heart) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + + model = self.widget._features_view.model() + model_data = [model.data(model.index(i, 0)) + for i in range(model.rowCount())] + attrs = self.heart.domain.attributes + cont_var_names = [a.name for a in attrs if a.is_continuous] + self.assertEqual(model_data, cont_var_names) + + self.widget.controls.order_by_importance.setChecked(True) + model_data = [model.data(model.index(i, 0)) + for i in range(model.rowCount())] + cont_var_names = ["max HR", "ST by exercise", "cholesterol", + "age", "rest SBP", "major vessels colored"] + self.assertEqual(model_data, cont_var_names) + + def test_sample_data(self): + self.send_signal(self.widget.Inputs.data, self.heart[:1]) + self.assertTrue(self.widget.Error.not_enough_data.is_shown()) + self.send_signal(self.widget.Inputs.data, self.heart) + self.assertTrue(self.widget.Information.data_sampled.is_shown()) + self.assertFalse(self.widget.Error.not_enough_data.is_shown()) + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Information.data_sampled.is_shown()) + + def test_selection(self): + self.send_signal(self.widget.Inputs.data, self.heart) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + self.wait_until_finished() + + event = Mock() + event.button.return_value = Qt.LeftButton + event.buttonDownPos.return_value = QPointF(30, -0.2) + event.pos.return_value = QPointF(50, -0.3) + event.isFinish.return_value = True + + self.widget.graph.getViewBox().mouseDragEvent(event) + self.assertIsInstance(self.widget.selection, list) + self.assertListEqual(self.widget.selection, [52, 214]) + selected = self.get_output(self.widget.Outputs.selected_data) + self.assertEqual(len(selected), 2) + + self.widget.graph.getViewBox().mouseClickEvent(event) + self.assertIsNone(self.widget.selection) + self.assertIsNone(self.get_output(self.widget.Outputs.selected_data)) + + self.widget.graph.getViewBox().mouseDragEvent(event) + self.assertIsNotNone(self.get_output(self.widget.Outputs.selected_data)) + + self.send_signal(self.widget.Inputs.model, None) + self.assertIsNone(self.get_output(self.widget.Outputs.selected_data)) + + def test_saved_selection(self): + self.send_signal(self.widget.Inputs.data, self.heart) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + self.wait_until_finished() + event = Mock() + event.button.return_value = Qt.LeftButton + event.buttonDownPos.return_value = QPointF(30, -0.2) + event.pos.return_value = QPointF(50, -0.3) + event.isFinish.return_value = True + self.widget.graph.getViewBox().mouseDragEvent(event) + output1 = self.get_output(self.widget.Outputs.selected_data) + + settings = self.widget.settingsHandler.pack_data(self.widget) + widget = self.create_widget(OWICE, stored_settings=settings) + self.send_signal(widget.Inputs.data, self.heart, widget=widget) + self.send_signal(widget.Inputs.model, self.rf_cls, widget=widget) + self.wait_until_finished(widget=widget) + output2 = self.get_output(widget.Outputs.selected_data, widget=widget) + self.assert_table_equal(output1, output2) + + def test_send_report(self): + self.widget.send_report() + self.send_signal(self.widget.Inputs.data, self.heart[:10]) + self.send_signal(self.widget.Inputs.model, self.rf_cls) + self.widget.send_report() + self.send_signal(self.widget.Inputs.data, self.housing[:10]) + self.send_signal(self.widget.Inputs.model, self.rf_reg) + self.widget.send_report() + + +if __name__ == "__main__": + unittest.main() diff --git a/setup.py b/setup.py index 050ca35..f536a1f 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ NAME = "Orange3-Explain" -VERSION = "0.5.3" +VERSION = "0.5.4" AUTHOR = "Bioinformatics Laboratory, FRI UL" AUTHOR_EMAIL = "contact@orange.biolab.si" @@ -45,6 +45,7 @@ "pyqtgraph", "scipy", "shap ==0.40.*", # shap makes significant changes between versions + "scikit-learn>=1.0.1", ] EXTRAS_REQUIRE = {