Skip to content

Commit

Permalink
Move the root (all-data) cohort statistics to ErrorReport and ErrorAnalysisData
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Apr 13, 2022
1 parent 96016cb commit 7e08266
Show file tree
Hide file tree
Showing 7 changed files with 245 additions and 25 deletions.
23 changes: 23 additions & 0 deletions erroranalysis/erroranalysis/_internal/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,29 @@ class Metrics(str, Enum):
ERROR_RATE = 'error_rate'


class MetricKeys(str, Enum):
    """String keys under which a metric's name and value are stored.

    Every member is also a ``str`` instance equal to its value.
    """

    METRIC_NAME = 'metricName'
    METRIC_VALUE = 'metricValue'


class RootKeys(str, Enum):
    """String keys of the statistics reported for the root cohort.

    The metric name and metric value keys reuse the values defined
    on ``MetricKeys``.
    """

    METRIC_NAME = MetricKeys.METRIC_NAME.value
    METRIC_VALUE = MetricKeys.METRIC_VALUE.value
    TOTAL_SIZE = 'totalSize'
    ERROR_COVERAGE = 'errorCoverage'


class TreeNode(str, Enum):
    """String keys of the metric properties stored on a tree node.

    Both keys reuse the values defined on ``MetricKeys``.
    """

    METRIC_NAME = MetricKeys.METRIC_NAME.value
    METRIC_VALUE = MetricKeys.METRIC_VALUE.value


metric_to_display_name = {
Metrics.ACCURACY_SCORE: 'Accuracy score',
Metrics.MEAN_ABSOLUTE_ERROR: 'Mean absolute error',
Expand Down
8 changes: 4 additions & 4 deletions erroranalysis/erroranalysis/_internal/matrix_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from erroranalysis._internal.cohort_filter import filter_from_cohort
from erroranalysis._internal.constants import (DIFF, PRED_Y, ROW_INDEX, TRUE_Y,
MatrixParams, Metrics,
ModelTask,
MatrixParams, MetricKeys,
Metrics, ModelTask,
metric_to_display_name)
from erroranalysis._internal.metrics import (get_ordered_classes,
is_multi_agg_metric,
Expand All @@ -26,8 +26,8 @@
INTERVAL_MIN = 'intervalMin'
INTERVAL_MAX = 'intervalMax'
MATRIX = 'matrix'
METRIC_VALUE = 'metricValue'
METRIC_NAME = 'metricName'
METRIC_VALUE = MetricKeys.METRIC_VALUE
METRIC_NAME = MetricKeys.METRIC_NAME
VALUES = 'values'
PRECISION = 100
TP = 'tp'
Expand Down
6 changes: 3 additions & 3 deletions erroranalysis/erroranalysis/_internal/surrogate_error_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
PRED_Y, ROW_INDEX,
SPLIT_FEATURE, SPLIT_INDEX,
TRUE_Y, CohortFilterMethods,
Metrics, ModelTask,
Metrics, ModelTask, TreeNode,
error_metrics, f1_metrics,
metric_to_display_name,
precision_metrics,
Expand Down Expand Up @@ -620,8 +620,8 @@ def get_json_node(arg, condition, error, nodeid, method, node_name,
"size": float(total),
"sourceRowKeyHash": "hashkey", # Note: remove this eventually
"success": float(success), # Note: remove this eventually
"metricName": metric_name,
"metricValue": float(metric_value),
TreeNode.METRIC_NAME: metric_name,
TreeNode.METRIC_VALUE: float(metric_value),
"isErrorMetric": is_error_metric
}

Expand Down
64 changes: 61 additions & 3 deletions erroranalysis/erroranalysis/analyzer/error_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
mutual_info_regression)
from sklearn.preprocessing import OrdinalEncoder

from erroranalysis._internal.constants import MatrixParams, Metrics, ModelTask
from erroranalysis._internal.constants import (MatrixParams, Metrics,
ModelTask, RootKeys,
metric_to_display_name)
from erroranalysis._internal.matrix_filter import \
compute_matrix as _compute_matrix
from erroranalysis._internal.metrics import metric_to_func
from erroranalysis._internal.surrogate_error_tree import \
compute_error_tree as _compute_error_tree
from erroranalysis._internal.utils import generate_random_unique_indexes
Expand All @@ -23,6 +26,7 @@

# Bin threshold, taken from the shared matrix parameters.
BIN_THRESHOLD = MatrixParams.BIN_THRESHOLD
# NOTE(review): presumably a size limit applied when computing the
# importances; its usage is not visible here - confirm.
IMPORTANCES_THRESHOLD = 50000
# Value reported under the error coverage key of the root statistics.
ROOT_COVERAGE = 100


class BaseAnalyzer(ABC):
Expand Down Expand Up @@ -298,7 +302,8 @@ def create_error_report(self,
max_depth=None,
num_leaves=None,
min_child_samples=None,
compute_importances=False):
compute_importances=False,
compute_root_stats=False):
"""Creates the error analysis ErrorReport.
The ErrorReport contains the importances, heatmap and tree view json.
Expand All @@ -316,6 +321,9 @@ def create_error_report(self,
:type min_child_samples: int
:param compute_importances: If true, computes and adds the
correlation of features and the error to the ErrorReport.
:type compute_importances: bool
:param compute_root_stats: If true, computes and adds the root stats.
:type compute_root_stats: bool
:return: The computed error analysis ErrorReport.
:rtype: dict
"""
Expand All @@ -333,10 +341,14 @@ def create_error_report(self,
importances = None
if compute_importances:
importances = self.compute_importances()
root_stats = None
if compute_root_stats:
root_stats = self.compute_root_stats()
return ErrorReport(tree, matrix,
tree_features=self.feature_names,
matrix_features=filter_features,
importances=importances)
importances=importances,
root_stats=root_stats)

def compute_importances(self):
"""Compute the importances or correlation between features and error.
Expand Down Expand Up @@ -375,6 +387,32 @@ def compute_importances(self):
else:
return mutual_info_regression(input_data, diff).tolist()

def compute_root_stats(self):
    """Compute the statistics of the root cohort covering all data.

    :return: A dictionary holding the metric display name, the metric
        value, the total number of rows and the root error coverage.
    :rtype: dict
    """
    if self.metric == Metrics.ERROR_RATE:
        num_rows = len(self.true_y)
        if num_rows != 0:
            # Error rate is expressed as a percentage of all rows.
            num_errors = sum(self.get_diff())
            metric_value = (num_errors / num_rows) * 100
        else:
            # Nothing to evaluate, so report a zero error rate.
            metric_value = 0
    else:
        # NOTE(review): arguments are passed as (pred_y, true_y);
        # confirm this is the order the metric_to_func entries expect.
        metric_value = metric_to_func[self.metric](self.pred_y, self.true_y)
    return {
        RootKeys.METRIC_NAME: metric_to_display_name[self.metric],
        RootKeys.METRIC_VALUE: metric_value,
        RootKeys.TOTAL_SIZE: len(self.true_y),
        RootKeys.ERROR_COVERAGE: ROOT_COVERAGE
    }

def update_metric(self, metric):
"""Update the metric used by the error analyzer.
Expand Down Expand Up @@ -421,6 +459,15 @@ def get_diff(self):
"""
pass

@property
@abstractmethod
def pred_y(self):
    """Abstract property to get the predicted y labels.

    Declared as a property so that the abstract interface matches how
    the value is read (``self.pred_y``, without a call) and how the
    concrete analyzers implement it.

    :return: The predicted y labels.
    :rtype: numpy.ndarray or list[] or pandas.Series
    """
    pass


class ModelAnalyzer(BaseAnalyzer):
"""ModelAnalyzer Class.
Expand Down Expand Up @@ -524,6 +571,17 @@ def get_diff(self):
else:
return self.model.predict(self.dataset) - self.true_y

@property
def pred_y(self):
    """Predict the y values for the dataset with the model.

    The predictions are not cached; the model's predict method is
    invoked on the dataset every time this property is read.

    :return: The computed predicted y values.
    :rtype: numpy.ndarray or list[] or pandas.Series
    """
    predictions = self.model.predict(self.dataset)
    return predictions


class PredictionsAnalyzer(BaseAnalyzer):
"""PredictionsAnalyzer Class.
Expand Down
51 changes: 36 additions & 15 deletions erroranalysis/erroranalysis/report/error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
_ErrorReportVersion1 = '1.0'
_ErrorReportVersion2 = '2.0'
_ErrorReportVersion3 = '3.0'
_ErrorReportVersion4 = '4.0'
_AllVersions = [_ErrorReportVersion1,
_ErrorReportVersion2,
_ErrorReportVersion3]
_ErrorReportVersion3,
_ErrorReportVersion4]
_VERSION = 'version'

TREE = 'tree'
Expand All @@ -19,6 +21,7 @@
IMPORTANCES = 'importances'
ID = 'id'
METADATA = 'metadata'
ROOT_STATS = 'root_stats'


def json_converter(obj):
Expand Down Expand Up @@ -57,19 +60,18 @@ def as_error_report(error_dict):
return ErrorReport(error_dict[TREE],
error_dict[MATRIX],
error_dict[ID])
elif IMPORTANCES not in error_dict:
return ErrorReport(error_dict[TREE],
error_dict[MATRIX],
error_dict[TREE_FEATURES],
error_dict[MATRIX_FEATURES],
id=error_dict[ID])
else:
return ErrorReport(error_dict[TREE],
error_dict[MATRIX],
error_dict[TREE_FEATURES],
error_dict[MATRIX_FEATURES],
importances=error_dict[IMPORTANCES],
id=error_dict[ID])
extra_args = {}
if IMPORTANCES in error_dict:
extra_args[IMPORTANCES] = error_dict[IMPORTANCES]
if ROOT_STATS in error_dict:
extra_args[ROOT_STATS] = error_dict[ROOT_STATS]
if ID in error_dict:
extra_args[ID] = error_dict[ID]
return ErrorReport(error_dict[TREE],
error_dict[MATRIX],
error_dict[TREE_FEATURES],
error_dict[MATRIX_FEATURES],
**extra_args)
else:
return error_dict

Expand All @@ -93,6 +95,7 @@ def __init__(self,
tree_features=None,
matrix_features=None,
importances=None,
root_stats=None,
id=None):
"""Defines the ErrorReport, which contains the tree and matrix filter.
Expand All @@ -109,6 +112,8 @@ def __init__(self,
:param importances: The feature importances calculated using mutual
information with the error on the true labels.
:type importances: list[float]
:param root_stats: The statistics for the root all data cohort.
:type root_stats: dict
:param id: The unique identifier for the ErrorReport.
A new unique id is created if none is specified.
:type id: str
Expand All @@ -119,7 +124,8 @@ def __init__(self,
self._tree_features = tree_features
self._matrix_features = matrix_features
self._importances = importances
self._metadata = {_VERSION: _ErrorReportVersion3}
self._root_stats = root_stats
self._metadata = {_VERSION: _ErrorReportVersion4}

@property
def __dict__(self):
Expand All @@ -139,6 +145,8 @@ def __dict__(self):
METADATA: self._metadata}
if self._importances is not None:
error_report_dict[IMPORTANCES] = self._importances
if self._root_stats is not None:
error_report_dict[ROOT_STATS] = self._root_stats
return error_report_dict

@property
Expand Down Expand Up @@ -197,6 +205,19 @@ def importances(self):
"""
return self._importances

@property
def root_stats(self):
    """Returns the statistics of the root (all data) cohort.

    These statistics are shown in both the heatmap and the tree view
    and include the metric name and value of the all data cohort.

    :return: The root all data cohort statistics.
    :rtype: dict
    """
    return self._root_stats

@property
def id(self):
"""Returns the unique identifier for this ErrorReport.
Expand Down
1 change: 1 addition & 0 deletions erroranalysis/tests/test_error_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def run_error_analyzer(model, X_test, y_test, feature_names,
assert ea_deserialized.tree_features == report1.tree_features
assert ea_deserialized.matrix_features == report1.matrix_features
assert ea_deserialized.importances == report1.importances
assert ea_deserialized.root_stats == report1.root_stats

# validate error report does not modify original dataset in ModelAnalyzer
if isinstance(X_test, pd.DataFrame):
Expand Down

0 comments on commit 7e08266

Please sign in to comment.