Merge pull request #4908 from PrimozGodec/rank-concurent

[ENH] Rank widget computation in a separate thread

markotoplak committed Oct 6, 2020
2 parents 7d04ab1 + 3836625, commit 3a6e272
Showing 2 changed files with 224 additions and 97 deletions.
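The commit moves the Rank widget's feature scoring off the GUI thread onto Orange's ConcurrentWidgetMixin worker machinery. Stripped of Rank-specific detail, the pattern the diff below adopts looks roughly like this (a minimal sketch, not code from the commit; ExampleWidget, heavy_computation and _recompute are illustrative names):

from Orange.widgets.unsupervised.owdistances import InterruptException
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.widget import OWWidget


def heavy_computation(data, state: TaskState):
    # Runs on a worker thread; the mixin passes `state` as the last argument.
    for step in range(10):
        if state.is_interruption_requested():
            raise InterruptException  # cooperative cancellation point
        state.set_progress_value(100 * (step + 1) / 10)  # 0-100 scale
    return data


class ExampleWidget(OWWidget, ConcurrentWidgetMixin):
    name = "Example"

    def __init__(self):
        OWWidget.__init__(self)
        ConcurrentWidgetMixin.__init__(self)

    def _recompute(self, data):
        self.cancel()                        # drop any in-flight task first
        self.start(heavy_computation, data)  # `state` is appended automatically

    def on_done(self, result):
        pass  # main thread again: update models and send outputs here

    def on_partial_result(self, result):
        pass  # this widget produces no incremental results

    def on_exception(self, ex):
        raise ex  # mirror the widget below: let worker errors propagate

The diff below instantiates exactly this shape: a module-level run() as the worker entry point, dict caches so already-computed scores survive re-runs, and on_done rebuilding the score table.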
Orange/widgets/data/owrank.py: 170 additions & 97 deletions
@@ -1,39 +1,39 @@
import warnings
from collections import namedtuple, OrderedDict
import logging
import warnings
from collections import OrderedDict, namedtuple
from functools import partial
from itertools import chain
from types import SimpleNamespace
from typing import Any, Callable, List, Tuple

import numpy as np
from scipy.sparse import issparse

from AnyQt.QtCore import (
QItemSelection, QItemSelectionModel, QItemSelectionRange, Qt
)
from AnyQt.QtGui import QFontMetrics
from AnyQt.QtWidgets import (
QTableView, QRadioButton, QButtonGroup, QGridLayout,
QStackedWidget, QHeaderView, QCheckBox, QItemDelegate,
QButtonGroup, QCheckBox, QGridLayout, QHeaderView, QItemDelegate,
QRadioButton, QStackedWidget, QTableView
)
from AnyQt.QtCore import (
Qt, QItemSelection, QItemSelectionRange, QItemSelectionModel,
)

from orangewidget.settings import IncompatibleContext
from Orange.data import (Table, Domain, ContinuousVariable, DiscreteVariable,
StringVariable)
from scipy.sparse import issparse

from Orange.data import (
ContinuousVariable, DiscreteVariable, Domain, StringVariable, Table
)
from Orange.data.util import get_unique_names_duplicates
from Orange.misc.cache import memoize_method
from Orange.preprocess import score
from Orange.widgets import report
from Orange.widgets import gui
from Orange.widgets.settings import (DomainContextHandler, Setting,
ContextSetting)
from Orange.widgets import gui, report
from Orange.widgets.settings import (
ContextSetting, DomainContextHandler, Setting
)
from Orange.widgets.unsupervised.owdistances import InterruptException
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.itemmodels import PyTableModel
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.utils.state_summary import format_summary_details
from Orange.widgets.widget import (
OWWidget, Msg, Input, Output, AttributeList
)

from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import AttributeList, Input, Msg, Output, OWWidget

log = logging.getLogger(__name__)

@@ -167,7 +167,79 @@ def _argsortData(self, data, order):
return indices


class OWRank(OWWidget):
class Results(SimpleNamespace):
    method_scores: Tuple[Tuple[ScoreMeta, np.ndarray], ...] = None
    scorer_scores: Tuple[Tuple[ScoreMeta, Tuple[np.ndarray, Tuple[str, ...]]], ...] = None


def get_method_scores(data: Table, method: ScoreMeta) -> np.ndarray:
estimator = method.scorer()
# The widget handles infs and nans.
# Any errors in scorers need to be detected elsewhere.
with np.errstate(all="ignore"):
try:
scores = np.asarray(estimator(data))
except ValueError:
try:
scores = np.array(
[estimator(data, attr) for attr in data.domain.attributes]
)
except ValueError:
log.error("%s doesn't work on this data", method.name)
scores = np.full(len(data.domain.attributes), np.nan)
else:
                log.warning(
                    "%s had to be computed separately for each variable",
                    method.name,
                )
return scores


def get_scorer_scores(
data: Table, scorer: ScoreMeta
) -> Tuple[np.ndarray, Tuple[str, ...]]:
try:
scores = scorer.scorer.score_data(data).T
except (ValueError, TypeError):
log.error("%s doesn't work on this data", scorer.name)
scores = np.full((len(data.domain.attributes), 1), np.nan)

labels = (
(scorer.shortname,)
if scores.shape[1] == 1
else tuple(
scorer.shortname + "_" + str(i)
for i in range(1, 1 + scores.shape[1])
)
)
return scores, labels


def run(
data: Table,
methods: List[ScoreMeta],
scorers: List[ScoreMeta],
state: TaskState,
) -> Results:
progress_steps = iter(np.linspace(0, 100, len(methods) + len(scorers)))

def call_with_cb(get_scores: Callable, method: ScoreMeta):
scores = get_scores(data, method)
state.set_progress_value(next(progress_steps))
if state.is_interruption_requested():
raise InterruptException
return scores

method_scores = tuple(
(method, call_with_cb(get_method_scores, method)) for method in methods
)
scorer_scores = tuple(
(scorer, call_with_cb(get_scorer_scores, scorer)) for scorer in scorers
)
return Results(method_scores=method_scores, scorer_scores=scorer_scores)


class OWRank(OWWidget, ConcurrentWidgetMixin):
name = "Rank"
description = "Rank and filter data features by their relevance."
icon = "icons/Rank.svg"
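Because the new module-level run() above only ever touches two methods of its state argument, set_progress_value and is_interruption_requested, it can be driven synchronously outside the widget, e.g. for testing. A hypothetical harness (DummyState is an illustrative stand-in, not part of Orange):

from Orange.data import Table
from Orange.widgets.data.owrank import ProblemType, SCORES, run


class DummyState:
    # Minimal stand-in for TaskState: just the two methods run() calls.
    def set_progress_value(self, value):
        print(f"progress: {value:.0f}%")

    def is_interruption_requested(self):
        return False


data = Table("heart_disease")
methods = [m for m in SCORES
           if m.problem_type == ProblemType.CLASSIFICATION]
results = run(data, methods, [], DummyState())
for method, scores in results.method_scores:
    print(method.name, scores[:3])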
@@ -211,20 +283,23 @@ class Warning(OWWidget.Warning):
renamed_variables = Msg(
"Variables with duplicated names have been renamed.")


def __init__(self):
super().__init__()
OWWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)
self.scorers = OrderedDict()
self.out_domain_desc = None
self.data = None
self.problem_type_mode = ProblemType.CLASSIFICATION

# results caches
self.scorers_results = {}
self.methods_results = {}

if not self.selected_methods:
self.selected_methods = {method.name for method in SCORES
if method.is_default}

# GUI

self.ranksModel = model = TableModel(parent=self) # type: TableModel
self.ranksView = view = TableView(self) # type: TableView
self.mainArea.layout().addWidget(view)
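One detail worth noting in the constructor above: super().__init__() became two explicit base-class calls. With multiple bases, a lone super().__init__() reaches every base only if each class in the MRO cooperatively chains with super(); calling OWWidget.__init__ and ConcurrentWidgetMixin.__init__ directly avoids depending on that. A toy illustration of the same shape:

class Base:
    def __init__(self):
        print("Base initialised")


class Mixin:
    def __init__(self):
        print("Mixin initialised")


class Widget(Base, Mixin):
    def __init__(self):
        # Explicit calls, as in OWRank.__init__ above: both bases are
        # initialised even if neither chains with super().
        Base.__init__(self)
        Mixin.__init__(self)


Widget()  # prints "Base initialised" then "Mixin initialised"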
@@ -312,8 +387,9 @@ def set_data(self, data):
self.ranksModel.clear()
self.ranksModel.resetSorting(True)

self.get_method_scores.cache_clear() # pylint: disable=no-member
self.get_scorer_scores.cache_clear() # pylint: disable=no-member
self.scorers_results = {}
self.methods_results = {}
self.cancel()

self.Error.clear()
self.Information.clear()
@@ -358,7 +434,7 @@ def set_data(self, data):

def handleNewSignals(self):
self.setStatusMessage('Running')
self.updateScores()
self.update_scores()
self.setStatusMessage('')
self.on_select()

@@ -370,86 +446,75 @@ def set_learner(self, scorer, id): # pylint: disable=redefined-builtin
# Avoid caching a (possibly stale) previous instance of the same
# Scorer passed via the same signal
if id in self.scorers:
# pylint: disable=no-member
self.get_scorer_scores.cache_clear()
self.scorers_results = {}

self.scorers[id] = ScoreMeta(scorer.name, scorer.name, scorer,
ProblemType.from_variable(scorer.class_type),
False)

@memoize_method()
def get_method_scores(self, method):
# These errors often happen, but they result in nans, which
# are handled correctly by the widget
estimator = method.scorer()
data = self.data
# The widget handles infs and nans.
# Any errors in scorers need to be detected elsewhere.
with np.errstate(all="ignore"):
try:
scores = np.asarray(estimator(data))
except ValueError:
try:
scores = np.array([estimator(data, attr)
for attr in data.domain.attributes])
except ValueError:
log.error("%s doesn't work on this data", method.name)
scores = np.full(len(data.domain.attributes), np.nan)
else:
log.warning("%s had to be computed separately for each "
"variable", method.name)
return scores

@memoize_method()
def get_scorer_scores(self, scorer):
try:
scores = scorer.scorer.score_data(self.data).T
except (ValueError, TypeError):
log.error("%s doesn't work on this data", scorer.name)
scores = np.full((len(self.data.domain.attributes), 1), np.nan)

labels = ((scorer.shortname,)
if scores.shape[1] == 1 else
tuple(scorer.shortname + '_' + str(i)
for i in range(1, 1 + scores.shape[1])))
return scores, labels

def updateScores(self):
def _get_methods(self):
return [
method
for method in SCORES
if (
method.name in self.selected_methods
and method.problem_type == self.problem_type_mode
and (
not issparse(self.data.X)
or method.scorer.supports_sparse_data
)
)
]

def _get_scorers(self):
scorers = []
for scorer in self.scorers.values():
if scorer.problem_type in (
self.problem_type_mode,
ProblemType.UNSUPERVISED,
):
scorers.append(scorer)
else:
self.Error.inadequate_learner(
scorer.name, scorer.learner_adequacy_err_msg
)
return scorers

def update_scores(self):
if self.data is None:
self.ranksModel.clear()
self.Outputs.scores.send(None)
return

methods = [method
for method in SCORES
if (method.name in self.selected_methods and
method.problem_type == self.problem_type_mode and
(not issparse(self.data.X) or
method.scorer.supports_sparse_data))]

scorers = []
self.Error.inadequate_learner.clear()
for scorer in self.scorers.values():
if scorer.problem_type in (self.problem_type_mode, ProblemType.UNSUPERVISED):
scorers.append(scorer)
else:
self.Error.inadequate_learner(scorer.name, scorer.learner_adequacy_err_msg)

method_scores = tuple(self.get_method_scores(method)
for method in methods)
scorers = [
s for s in self._get_scorers() if s not in self.scorers_results
]
methods = [
m for m in self._get_methods() if m not in self.methods_results
]
self.start(run, self.data, methods, scorers)

scorer_scores, scorer_labels = (), ()
if scorers:
scorer_scores, scorer_labels = zip(*(self.get_scorer_scores(scorer)
for scorer in scorers))
scorer_labels = tuple(chain.from_iterable(scorer_labels))
def on_done(self, result: Results) -> None:
self.methods_results.update(result.method_scores)
self.scorers_results.update(result.scorer_scores)

labels = tuple(method.shortname for method in methods) + scorer_labels
methods = self._get_methods()
method_labels = tuple(m.shortname for m in methods)
method_scores = tuple(self.methods_results[m] for m in methods)

scores = [self.scorers_results[s] for s in self._get_scorers()]
scorer_scores, scorer_labels = zip(*scores) if scores else ((), ())

labels = method_labels + tuple(chain.from_iterable(scorer_labels))
model_array = np.column_stack(
([len(a.values) if a.is_discrete else np.nan
for a in self.data.domain.attributes],) +
(method_scores if method_scores else ()) +
(scorer_scores if scorer_scores else ())
(
[len(a.values) if a.is_discrete else np.nan
for a in self.data.domain.attributes],
)
+ method_scores
+ scorer_scores
)
for column, values in enumerate(model_array.T):
self.ranksModel.setExtremesFrom(column, values)
@@ -464,13 +529,21 @@ def updateScores(self):
if sort_column < len(labels):
# adds 1 for '#' (discrete count) column
self.ranksModel.sort(sort_column + 1, sort_order)
self.ranksView.horizontalHeader().setSortIndicator(sort_column + 1, sort_order)
self.ranksView.horizontalHeader().setSortIndicator(
sort_column + 1, sort_order
)
except ValueError:
pass

self.autoSelection()
self.Outputs.scores.send(self.create_scores_table(labels))
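
The on_done handler above leans on the two dict caches initialised in the constructor: update_scores submits only the ScoreMeta entries missing from a cache, and on_done merges the worker's (meta, scores) pairs back in before rebuilding every column from the caches. This works because ScoreMeta is a namedtuple (defined earlier in owrank.py, outside this diff) and therefore hashable. The idea in miniature (field values made up):

from collections import namedtuple

import numpy as np

# Mirrors the ScoreMeta record used above; the concrete values are made up.
ScoreMeta = namedtuple(
    "score_meta", ["name", "shortname", "scorer", "problem_type", "is_default"]
)

gain = ScoreMeta("Info. gain", "Inf. gain", None, 0, True)
gini = ScoreMeta("Gini", "Gini", None, 0, True)

cache = {gain: np.array([0.12, 0.07])}                   # one score already known
wanted = [gain, gini]
todo = [m for m in wanted if m not in cache]             # -> [gini]; goes to run()
computed = [(m, np.array([0.30, 0.10])) for m in todo]   # stand-in for worker output
cache.update(computed)                                   # as in on_done above
columns = [cache[m] for m in wanted]                     # every column rebuilt from cache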

def on_exception(self, ex: Exception) -> None:
raise ex

def on_partial_result(self, result: Any) -> None:
pass
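
Re-raising in on_exception above means a scorer failure on the worker thread surfaces as an ordinary traceback on the main thread. A widget preferring a softer failure mode could route the error to its message bar instead; a hypothetical variant (Error.unknown is an assumed Msg slot, not one OWRank defines):

    def on_exception(self, ex: Exception) -> None:
        self.Error.unknown(str(ex))  # hypothetical Msg slot, for illustration only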

def on_select(self):
# Save indices of attributes in the original, unsorted domain
selected_rows = self.ranksView.selectionModel().selectedRows(0)
@@ -530,7 +603,7 @@ def methodSelectionChanged(self, state, method_name):
elif method_name in self.selected_methods:
self.selected_methods.remove(method_name)

self.updateScores()
self.update_scores()

def send_report(self):
if not self.data:
@@ -621,4 +694,4 @@ def migrate_context(cls, context, version):
WidgetPreview(OWRank).run(
set_learner=(RandomForestLearner(), (3, 'Learner', None)),
set_data=Table("heart_disease.tab"))
"""
"""
