Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Feat/multiclass metrics #280

Closed
wants to merge 9 commits into from
10 changes: 5 additions & 5 deletions credoai/artifacts/model/base_model.py
Expand Up @@ -40,7 +40,11 @@ def __init__(
self.model_info = get_model_info(model_like)
self._validate(necessary_functions)
self._build(possible_functions)
self._update_functionality()
self.__post_init__()

def __post_init__(self):
    """Hook invoked at the end of the base ``__init__``; no-op by default.

    Subclasses override this to adjust functionality after construction
    (e.g. ``ClassificationModel`` refines its ``type`` and wraps
    ``predict_proba`` here).
    """
    pass

@property
def tags(self):
Expand Down Expand Up @@ -88,7 +92,3 @@ def _add_functionality(self, key: str):
func = getattr(self.model_like, key, None)
if func:
self.__dict__[key] = func

def _update_functionality(self):
"""Optional framework specific functionality update"""
pass
18 changes: 11 additions & 7 deletions credoai/artifacts/model/classification_model.py
@@ -1,7 +1,6 @@
"""Model artifact wrapping any classification model"""
from .base_model import Model

PREDICT_PROBA_FRAMEWORKS = ["sklearn", "xgboost"]
from .constants_model import SKLEARN_LIKE_FRAMEWORKS


class ClassificationModel(Model):
Expand All @@ -24,20 +23,25 @@ class ClassificationModel(Model):

def __init__(self, name: str, model_like=None, tags=None):
super().__init__(
"Classification",
"classification",
["predict", "predict_proba"],
["predict"],
name,
model_like,
tags,
)

def _update_functionality(self):
def __post_init__(self):
"""Conditionally updates functionality based on framework"""
if self.model_info["framework"] in PREDICT_PROBA_FRAMEWORKS:
if self.model_info["framework"] in SKLEARN_LIKE_FRAMEWORKS:
func = getattr(self, "predict_proba", None)
if func and len(self.model_like.classes_) == 2:
self.__dict__["predict_proba"] = lambda x: func(x)[:, 1]
if len(self.model_like.classes_) == 2:
self.type = "binary_classification"
# if binary, replace probability array with one-dimensional vector
if func:
self.__dict__["predict_proba"] = lambda x: func(x)[:, 1]
else:
self.type = "multiclass_classification"


class DummyClassifier:
Expand Down
2 changes: 1 addition & 1 deletion credoai/artifacts/model/comparison_model.py
Expand Up @@ -23,7 +23,7 @@ class ComparisonModel(Model):

def __init__(self, name: str, model_like=None):
super().__init__(
"ComparisonModel",
"comparison",
["compare"],
["compare"],
name,
Expand Down
7 changes: 7 additions & 0 deletions credoai/artifacts/model/constants_model.py
@@ -0,0 +1,7 @@
# Frameworks whose models expose a sklearn-style interface
# (``predict``/``predict_proba``/``classes_``); used to gate
# framework-specific post-init behavior in ClassificationModel.
SKLEARN_LIKE_FRAMEWORKS = ["sklearn", "xgboost"]

# Recognized model ``type`` identifiers for Model artifacts.
# NOTE(review): the generic "classification" type assigned before
# __post_init__ refines it is intentionally absent here — confirm.
MODEL_TYPES = [
    "regression",
    "binary_classification",
    "multiclass_classification",
    "comparison",
]
2 changes: 1 addition & 1 deletion credoai/artifacts/model/regression_model.py
Expand Up @@ -21,7 +21,7 @@ class RegressionModel(Model):
"""

def __init__(self, name: str, model_like=None, tags=None):
super().__init__("Regression", ["predict"], ["predict"], name, model_like, tags)
super().__init__("regression", ["predict"], ["predict"], name, model_like, tags)


class DummyRegression:
Expand Down
4 changes: 3 additions & 1 deletion credoai/evaluators/fairness.py
Expand Up @@ -304,7 +304,9 @@ def _process_metrics(self, metrics):
for metric in metrics:
if isinstance(metric, str):
metric_name = metric
metric = find_metrics(metric, MODEL_METRIC_CATEGORIES)
metric_categories_to_include = MODEL_METRIC_CATEGORIES
metric_categories_to_include.append(self.model.type.upper())
metric = find_metrics(metric, metric_categories_to_include)
if len(metric) == 1:
metric = metric[0]
elif len(metric) == 0:
Expand Down
18 changes: 16 additions & 2 deletions credoai/modules/constants_metrics.py
Expand Up @@ -19,10 +19,13 @@

THRESHOLD_METRIC_CATEGORIES = ["BINARY_CLASSIFICATION_THRESHOLD"]

MODEL_METRIC_CATEGORIES = [
MODEL_TYPE_METRIC_CATEGORIES = [
fabrizio-credo marked this conversation as resolved.
Show resolved Hide resolved
"BINARY_CLASSIFICATION",
"MULTICLASS_CLASSIFICATION",
"REGRESSION",
]

MODEL_METRIC_CATEGORIES = [
"CLUSTERING",
"FAIRNESS",
] + THRESHOLD_METRIC_CATEGORIES
Expand All @@ -35,12 +38,17 @@
]

METRIC_CATEGORIES = (
MODEL_METRIC_CATEGORIES + THRESHOLD_METRIC_CATEGORIES + NON_MODEL_METRIC_CATEGORIES
MODEL_TYPE_METRIC_CATEGORIES
+ MODEL_METRIC_CATEGORIES
+ THRESHOLD_METRIC_CATEGORIES
+ NON_MODEL_METRIC_CATEGORIES
)

SCALAR_METRIC_CATEGORIES = MODEL_METRIC_CATEGORIES + NON_MODEL_METRIC_CATEGORIES

# MODEL METRICS
# Define Binary classification name mapping.
# Binary classification metrics must have a similar signature to sklearn metrics
BINARY_CLASSIFICATION_FUNCTIONS = {
"false_positive_rate": fl_metrics.false_positive_rate,
"false_negative_rate": fl_metrics.false_negative_rate,
Expand All @@ -61,6 +69,12 @@
"gini_coefficient": gini_coefficient_discriminatory,
}

# Define Multiclass classification name mapping.
# Multiclass classification metrics must have a similar signature to sklearn metrics.
# ``average="macro"`` takes the unweighted mean of per-class precision, so
# every class contributes equally regardless of its support.
MULTICLASS_CLASSIFICATION_FUNCTIONS = {
    "precision_score": partial(sk_metrics.precision_score, average="macro")
}

# Define Fairness Metric Name Mapping
# Fairness metrics must have a similar signature to fairlearn.metrics.equalized_odds_difference
# (they should take sensitive_features and method)
Expand Down
10 changes: 9 additions & 1 deletion credoai/modules/metrics.py
Expand Up @@ -127,7 +127,14 @@ def find_metrics(metric_name, metric_category=None):
# Convert To List of Metrics
BINARY_CLASSIFICATION_METRICS = metrics_from_dict(
BINARY_CLASSIFICATION_FUNCTIONS,
"BINARY_CLASSIFICATION",
"binary_classification",
PROBABILITY_FUNCTIONS,
METRIC_EQUIVALENTS,
)

MULTICLASS_CLASSIFICATION_METRICS = metrics_from_dict(
MULTICLASS_CLASSIFICATION_FUNCTIONS,
"MULTICLASS_CLASSIFICATION",
PROBABILITY_FUNCTIONS,
METRIC_EQUIVALENTS,
)
Expand Down Expand Up @@ -168,6 +175,7 @@ def find_metrics(metric_name, metric_category=None):

ALL_METRICS = (
list(BINARY_CLASSIFICATION_METRICS.values())
+ list(MULTICLASS_CLASSIFICATION_METRICS.values())
+ list(THRESHOLD_VARYING_METRICS.values())
+ list(FAIRNESS_METRICS.values())
+ list(DATASET_METRICS.values())
Expand Down
1 change: 1 addition & 0 deletions credoai/utils/model_utils.py
Expand Up @@ -23,6 +23,7 @@ def get_generic_classifier():


def get_model_info(model):
"""Returns basic information about model info"""
try:
framework = model.__class__.__module__.split(".")[0]
except AttributeError:
Expand Down