diff --git a/deepchecks/tabular/checks/model_evaluation/simple_model_comparison.py b/deepchecks/tabular/checks/model_evaluation/simple_model_comparison.py
index b5e6438274..b92d23cd84 100644
--- a/deepchecks/tabular/checks/model_evaluation/simple_model_comparison.py
+++ b/deepchecks/tabular/checks/model_evaluation/simple_model_comparison.py
@@ -371,7 +371,8 @@ def add_condition_gain_greater_than(self,
             Used in classification models to limit condition only to given classes.
         average : bool , default: False
             Used in classification models to flag if to run condition on average of classes, or on
-            each class individually
+            each class individually. If any scorer that returns a single value is used, this parameter
+            is ignored (will act as if average=True).
         """
         name = f'Model performance gain over simple model is greater than {format_percent(min_allowed_gain)}'
         if classes:
@@ -390,8 +391,13 @@ def condition(result: Dict, include_classes=None, average=False, max_gain=None,
             task_type = result['type']
             scorers_perfect = result['scorers_perfect']
 
+            # If the depth of the nested scores dict is 2, average is not relevant and is set to True
+            inner_dict = scores[list(scores.keys())[0]]
+            inner_inner_dict = inner_dict[list(inner_dict.keys())[0]]
+            force_average = isinstance(inner_inner_dict, Number)
+
             passed_condition = True
-            if task_type in [TaskType.MULTICLASS, TaskType.BINARY] and not average:
+            if task_type in [TaskType.MULTICLASS, TaskType.BINARY] and not average and not force_average:
                 passed_metrics = {}
                 failed_classes = defaultdict(dict)
                 perfect_metrics = []
@@ -433,7 +439,7 @@ def condition(result: Dict, include_classes=None, average=False, max_gain=None,
                 passed_metrics = {}
                 failed_metrics = {}
                 perfect_metrics = []
-                if task_type in [TaskType.MULTICLASS, TaskType.BINARY]:
+                if task_type in [TaskType.MULTICLASS, TaskType.BINARY] and not force_average:
                     scores = average_scores(scores, include_classes)
                 for metric, models_scores in scores.items():
                     # If origin model is perfect, skip the gain calculation
diff --git a/tests/tabular/checks/model_evaluation/simple_model_comparison_test.py b/tests/tabular/checks/model_evaluation/simple_model_comparison_test.py
index 017d89adc2..d8e8fdee08 100644
--- a/tests/tabular/checks/model_evaluation/simple_model_comparison_test.py
+++ b/tests/tabular/checks/model_evaluation/simple_model_comparison_test.py
@@ -194,6 +194,23 @@ def test_condition_pass_for_multiclass_avg(iris_split_dataset_and_model):
     ))
 
 
+def test_condition_pass_for_custom_scorer(iris_dataset_single_class, iris_random_forest_single_class):
+    train_ds = iris_dataset_single_class
+    test_ds = iris_dataset_single_class
+    clf = iris_random_forest_single_class
+    # Arrange
+    check = SimpleModelComparison(scorers=['f1'], strategy='most_frequent').add_condition_gain_greater_than(0.43)
+    # Act
+    result = check.run(train_ds, test_ds, clf)
+    # Assert
+    assert_that(result.conditions_results, has_items(
+        equal_condition_result(
+            is_pass=True,
+            details='Found metrics with perfect score, no gain is calculated: [\'f1\']',
+            name='Model performance gain over simple model is greater than 43%')
+    ))
+
+
 def test_condition_pass_for_multiclass_avg_with_classes(iris_split_dataset_and_model):
     train_ds, test_ds, clf = iris_split_dataset_and_model
     # Arrange
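
For context on the new `force_average` guard: it infers from the shape of the nested `scores` dict whether per-class handling even applies. Below is a minimal standalone sketch of that idea, assuming per-class scorers yield a {scorer: {model: {class: value}}} mapping while single-value scorers yield {scorer: {model: value}}; the helper name `needs_per_class_handling` and the sample dicts are hypothetical illustrations, not part of the deepchecks API.

from numbers import Number

def needs_per_class_handling(scores: dict) -> bool:
    """Return True for per-class (depth-3) scores, False for single-value (depth-2) scorers."""
    # Peek at the first scorer's first model entry; if it is already a plain
    # number, the scorer returned a single value and averaging is irrelevant.
    inner_dict = scores[next(iter(scores))]
    inner_inner = inner_dict[next(iter(inner_dict))]
    return not isinstance(inner_inner, Number)

# Hypothetical examples of the two shapes:
per_class_scores = {'F1': {'Origin': {0: 0.9, 1: 0.8}, 'Simple': {0: 0.5, 1: 0.4}}}
single_value_scores = {'f1': {'Origin': 0.95, 'Simple': 0.6}}

assert needs_per_class_handling(per_class_scores) is True
assert needs_per_class_handling(single_value_scores) is False

When the second shape is detected, the condition skips the per-class branch and the extra averaging step, which is what the two `and not force_average` guards in the diff accomplish.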