Skip to content

Commit

Permalink
Fix the Simple Model Comp condition (#2647)
Browse files Browse the repository at this point in the history
* Fix the Simple Model Comp condition - handle scorers that are not per class
  • Loading branch information
noamzbr committed Jul 26, 2023
1 parent de04894 commit deb16bc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
Expand Up @@ -371,7 +371,8 @@ def add_condition_gain_greater_than(self,
Used in classification models to limit condition only to given classes.
average : bool , default: False
Used in classification models to flag if to run condition on average of classes, or on
each class individually
each class individually. If any scorer that returns a single value is used, this parameter
is ignored (the condition will act as if average=True).
"""
name = f'Model performance gain over simple model is greater than {format_percent(min_allowed_gain)}'
if classes:
Expand All @@ -390,8 +391,13 @@ def condition(result: Dict, include_classes=None, average=False, max_gain=None,
task_type = result['type']
scorers_perfect = result['scorers_perfect']

# If the depth of the nested scores dict is 2, average is not relevant and is set to True
inner_dict = scores[list(scores.keys())[0]]
inner_inner_dict = inner_dict[list(inner_dict.keys())[0]]
force_average = isinstance(inner_inner_dict, Number)

passed_condition = True
if task_type in [TaskType.MULTICLASS, TaskType.BINARY] and not average:
if task_type in [TaskType.MULTICLASS, TaskType.BINARY] and not average and not force_average:
passed_metrics = {}
failed_classes = defaultdict(dict)
perfect_metrics = []
Expand Down Expand Up @@ -433,7 +439,7 @@ def condition(result: Dict, include_classes=None, average=False, max_gain=None,
passed_metrics = {}
failed_metrics = {}
perfect_metrics = []
if task_type in [TaskType.MULTICLASS, TaskType.BINARY]:
if task_type in [TaskType.MULTICLASS, TaskType.BINARY] and not force_average:
scores = average_scores(scores, include_classes)
for metric, models_scores in scores.items():
# If origin model is perfect, skip the gain calculation
Expand Down
Expand Up @@ -194,6 +194,23 @@ def test_condition_pass_for_multiclass_avg(iris_split_dataset_and_model):
))


def test_condition_pass_for_custom_scorer(iris_dataset_single_class, iris_random_forest_single_class):
    """Condition should pass when the scorer yields a single value (not per-class scores)."""
    # Arrange: same single-class dataset serves as both train and test;
    # 'f1' here resolves to a scorer returning one value per model.
    dataset = iris_dataset_single_class
    model = iris_random_forest_single_class
    check = SimpleModelComparison(scorers=['f1'], strategy='most_frequent').add_condition_gain_greater_than(0.43)
    # Act
    check_result = check.run(dataset, dataset, model)
    # Assert: the perfect-score path is taken, so no gain is computed and the condition passes.
    expected_condition = equal_condition_result(
        is_pass=True,
        details='Found metrics with perfect score, no gain is calculated: [\'f1\']',
        name='Model performance gain over simple model is greater than 43%')
    assert_that(check_result.conditions_results, has_items(expected_condition))


def test_condition_pass_for_multiclass_avg_with_classes(iris_split_dataset_and_model):
train_ds, test_ds, clf = iris_split_dataset_and_model
# Arrange
Expand Down

0 comments on commit deb16bc

Please sign in to comment.