Skip to content

Commit

Permalink
Limit permutation time (#512)
Browse files Browse the repository at this point in the history
Prevent permutation importance from running if it is projected to take too long.
  • Loading branch information
noamzbr committed Jan 5, 2022
1 parent 44d69b9 commit 8076354
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
6 changes: 6 additions & 0 deletions deepchecks/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ class NumberOfFeaturesLimitError(DeepchecksBaseError):
"""Represents a situation when a dataset contains too many features to be used for calculation."""

pass


class DeepchecksTimeoutError(DeepchecksBaseError):
    """Signal that a computation was stopped because it takes too long.

    Raised, for example, when a calculation is projected to exceed its
    time budget and is therefore not carried out.
    """
17 changes: 15 additions & 2 deletions deepchecks/utils/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#
# pylint: disable=inconsistent-quotes
"""Utils module containing feature importance calculations."""
import time
import typing as t
from warnings import warn
from functools import lru_cache
Expand All @@ -23,6 +24,7 @@
from deepchecks import base
from deepchecks import errors
from deepchecks.utils import validation
from deepchecks.utils.metrics import get_scorer_single
from deepchecks.utils.typing import Hashable
from deepchecks.utils.model import get_model_of_pipeline

Expand All @@ -39,6 +41,7 @@


# Upper bound on the number of features; presumably the threshold behind
# NumberOfFeaturesLimitError -- TODO confirm at the point of use.
_NUMBER_OF_FEATURES_LIMIT: int = 200
# Time budget for permutation importance: the projected run time is compared
# against this value and a DeepchecksTimeoutError is raised when it is exceeded.
_PERMUTATION_IMPORTANCE_TIMEOUT: int = 120 # seconds
# Message template; the %s placeholder is filled with the number of top columns shown.
N_TOP_MESSAGE = '* showing only the top %s columns, you can change it using n_top_columns param'


Expand Down Expand Up @@ -83,14 +86,14 @@ def calculate_feature_importance_or_none(
try:
if model is None:
return None
# calculate feature importance if dataset has label and the model is fitted on it
# calculate feature importance if dataset has a label and the model is fitted on it
return calculate_feature_importance(
model=model,
dataset=dataset,
force_permutation=force_permutation,
permutation_kwargs=permutation_kwargs
)
except (errors.DeepchecksValueError, errors.NumberOfFeaturesLimitError) as error:
except (errors.DeepchecksValueError, errors.NumberOfFeaturesLimitError, errors.DeepchecksTimeoutError) as error:
# DeepchecksValueError:
# if model validation failed;
# if it was not possible to calculate features importance;
Expand Down Expand Up @@ -212,6 +215,16 @@ def _calc_importance(
n_samples = min(n_samples, dataset.n_samples)
dataset_sample_idx = dataset.label_col.sample(n_samples, random_state=random_state).index

scorer = get_scorer_single(model, dataset, multiclass_avg=False)

start_time = time.time()
scorer(model, dataset)
calc_time = time.time() - start_time

if calc_time * n_repeats * len(dataset.features) > _PERMUTATION_IMPORTANCE_TIMEOUT:
raise errors.DeepchecksTimeoutError('Permutation importance calculation was not projected to finish in'
f' {_PERMUTATION_IMPORTANCE_TIMEOUT} seconds.')

r = permutation_importance(
model,
dataset.features_columns.loc[dataset_sample_idx, :],
Expand Down

0 comments on commit 8076354

Please sign in to comment.