Skip to content

Commit

Permalink
Changed model_classes logger warning and made sure tabular only warns…
Browse files Browse the repository at this point in the history
… once (#2604)

* Changed the model_classes logger warning, and made sure the tabular context logs that warning only once by storing the observed classes in _model_classes, as the NLP context already does

* Fixed the BoostingOverfit error message shown when predictions/probabilities are passed instead of a model
  • Loading branch information
nirhutnik committed Jun 18, 2023
1 parent f15f3be commit b1fa92e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
4 changes: 3 additions & 1 deletion deepchecks/nlp/context.py
Expand Up @@ -343,7 +343,9 @@ def model_classes(self) -> t.List:
# If in infer_observed_and_model_labels we didn't find classes on model, or user didn't pass any,
# then using the observed
self._model_classes = self._observed_classes
get_logger().warning('Could not find model\'s classes, using the observed classes')
get_logger().warning('Could not find model\'s classes, using the observed classes. '
'In order to make sure the classes used by the model are inferred correctly, '
'please use the model_classes argument')
return self._model_classes

@property
Expand Down
35 changes: 19 additions & 16 deletions deepchecks/tabular/checks/model_evaluation/boosting_overfit.py
Expand Up @@ -32,10 +32,12 @@
class PartialBoostingModel:
"""Wrapper for boosting models which limits the number of estimators being used in the prediction."""

_UNSUPPORTED_MODEL_ERROR = (
'Check is relevant for Boosting models of type '
'{supported_models}, but received model of type {model_type}'
)
_UNSUPPORTED_MODEL_ERROR = \
'Check is relevant for Boosting models of type {supported_models}, but received model of type {model_type}'

_NO_MODEL_ERROR = \
'Check is relevant only when receiving the model, but predictions/probabilities were received instead. ' \
'In order to use this check, please pass the model to the run() method.'

_SUPPORTED_CLASSIFICATION_MODELS = (
'AdaBoostClassifier',
Expand Down Expand Up @@ -78,6 +80,16 @@ def __init__(self, model, step):
else:
self.model = model

@classmethod
def _raise_not_supported_model_error(cls, model_class):
    """Raise the ModelValidationError that matches the given model class name.

    Parameters
    ----------
    model_class : str
        Name of the model's class, compared against ``'_DummyModel'``.

    Raises
    ------
    ModelValidationError
        Always. Carries ``_NO_MODEL_ERROR`` when ``model_class`` is
        ``'_DummyModel'``, otherwise ``_UNSUPPORTED_MODEL_ERROR`` formatted
        with the supported models and the received class name.
    """
    # NOTE(review): '_DummyModel' is presumably the placeholder used when only
    # predictions/probabilities were passed (per the _NO_MODEL_ERROR text) - confirm.
    if model_class == '_DummyModel':
        raise ModelValidationError(cls._NO_MODEL_ERROR)

    message = cls._UNSUPPORTED_MODEL_ERROR.format(
        supported_models=cls._SUPPORTED_MODELS,
        model_type=model_class,
    )
    raise ModelValidationError(message)

def predict_proba(self, x):
if self.model_class in ['AdaBoostClassifier', 'GradientBoostingClassifier']:
return self.model.predict_proba(x)
Expand All @@ -88,10 +100,7 @@ def predict_proba(self, x):
elif self.model_class == 'CatBoostClassifier':
return self.model.predict_proba(x, ntree_end=self.step)
else:
raise ModelValidationError(self._UNSUPPORTED_MODEL_ERROR.format(
supported_models=self._SUPPORTED_CLASSIFICATION_MODELS,
model_type=self.model_class
))
self._raise_not_supported_model_error(self.model_class)

def predict(self, x):
if self.model_class in ['AdaBoostClassifier', 'GradientBoostingClassifier', 'AdaBoostRegressor',
Expand All @@ -104,10 +113,7 @@ def predict(self, x):
elif self.model_class in ['CatBoostClassifier', 'CatBoostRegressor']:
return self.model.predict(x, ntree_end=self.step)
else:
raise ModelValidationError(self._UNSUPPORTED_MODEL_ERROR.format(
supported_models=self._SUPPORTED_MODELS,
model_type=self.model_class
))
self._raise_not_supported_model_error(self.model_class)

@classmethod
def n_estimators(cls, model):
Expand All @@ -123,10 +129,7 @@ def n_estimators(cls, model):
elif model_class in ['CatBoostClassifier', 'CatBoostRegressor']:
return model.tree_count_
else:
raise ModelValidationError(cls._UNSUPPORTED_MODEL_ERROR.format(
supported_models=cls._SUPPORTED_MODELS,
model_type=model_class
))
cls._raise_not_supported_model_error(model_class=model_class)


class BoostingOverfit(TrainTestCheck):
Expand Down
7 changes: 5 additions & 2 deletions deepchecks/tabular/context.py
Expand Up @@ -291,8 +291,11 @@ def model_classes(self) -> t.List:
"""Return ordered list of possible label classes for classification tasks or None for regression."""
if self._model_classes is None and self.task_type in (TaskType.BINARY, TaskType.MULTICLASS):
# If in infer_task_type we didn't find classes on model, or user didn't pass any, then using the observed
get_logger().warning('Could not find model\'s classes, using the observed classes')
return self.observed_classes
get_logger().warning('Could not find model\'s classes, using the observed classes. '
'In order to make sure the classes used by the model are inferred correctly, '
'please use the model_classes argument')
self._model_classes = self.observed_classes

return self._model_classes

@property
Expand Down

0 comments on commit b1fa92e

Please sign in to comment.