This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
/
survival_fairness.py
108 lines (94 loc) · 3.86 KB
/
survival_fairness.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from credoai.artifacts import TabularData
from credoai.evaluators import Evaluator
from credoai.evaluators.utils.validation import (
check_data_instance,
check_existence,
)
from connect.evidence import TableContainer
from credoai.modules import CoxPH
from credoai.modules.stats_utils import columns_from_formula
from credoai.utils import ValidationError
class SurvivalFairness(Evaluator):
    """
    Calculate survival fairness.

    Fits ``CoxPH`` models (presumably Cox proportional-hazards — see
    ``credoai.modules.CoxPH``) on the assessment data. The data is augmented
    with the model's predictions, its predicted probabilities when the model
    provides them, and the sensitive feature. Each fitted model's summary,
    expected survival and survival curves are reported as ``TableContainer``
    results labelled with the sensitive feature name.

    Parameters
    ----------
    CoxPh_kwargs : dict, optional
        Keyword arguments forwarded to ``CoxPH.fit``.

        - If it contains a ``"formula"`` key, exactly one model is fit using
          that formula.
        - Otherwise one model is fit per prediction column, with the formula
          ``"<sensitive_feature> * <prediction column>"`` plus any confounds.

        By default None, which means
        ``{"duration_col": "duration", "event_col": "event"}``.
    confounds : list of str, optional
        Column names added as additive terms to the automatically generated
        formulas. Their presence in ``assessment_data.X`` is validated. They
        are NOT appended to a user-supplied ``"formula"``. By default None.
    """

    required_artifacts = ["model", "assessment_data", "sensitive_feature"]

    def __init__(self, CoxPh_kwargs=None, confounds=None):
        super().__init__()
        if CoxPh_kwargs is None:
            CoxPh_kwargs = {"duration_col": "duration", "event_col": "event"}
        self.coxPh_kwargs = CoxPh_kwargs
        self.confounds = confounds
        # Fitted CoxPH objects, rebuilt on every run of the analyses.
        self.stats = []

    def _validate_arguments(self):
        """Check the data type and that every referenced column exists.

        Raises
        ------
        ValidationError
            If a confound or a column referenced by a user-supplied formula
            is not found in ``assessment_data.X``.
        """
        check_data_instance(self.assessment_data, TabularData)
        check_existence(self.assessment_data.sensitive_features, "sensitive_features")

        # Collect the columns referenced by confounds and/or a user formula.
        # Start from an empty set: "formula" may be given without confounds.
        expected_columns = set(self.confounds or [])
        if "formula" in self.coxPh_kwargs:
            expected_columns.update(columns_from_formula(self.coxPh_kwargs["formula"]))

        # These columns are not part of X. _setup adds the prediction columns
        # and joins the sensitive feature onto survival_df, so they must not
        # be reported as missing.
        expected_columns -= {"predictions", "predicted_probabilities"}
        sensitive_feature = getattr(self.assessment_data, "sensitive_feature", None)
        if sensitive_feature is not None:
            expected_columns.discard(sensitive_feature.name)

        missing_columns = expected_columns.difference(self.assessment_data.X.columns)
        if missing_columns:
            raise ValidationError(
                f"Columns supplied to CoxPh formula not found in data. Columns are: {missing_columns}"
            )

    def _setup(self):
        """Build ``survival_df``: a copy of X plus predictions and the sensitive feature.

        A ``"predicted_probabilities"`` column is added when possible.
        Otherwise ``self.y_prob`` is set to None.
        """
        self.y_pred = self.model.predict(self.assessment_data.X)
        self.sensitive_name = self.assessment_data.sensitive_feature.name
        self.survival_df = self.assessment_data.X.copy()
        self.survival_df["predictions"] = self.y_pred
        self.survival_df = self.survival_df.join(self.assessment_data.sensitive_feature)
        # Probabilities are optional. The fallback covers models that cannot
        # produce them, and outputs that cannot be assigned as a single
        # column; both lead to predictions-only analyses. Catch Exception
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        try:
            self.y_prob = self.model.predict_proba(self.assessment_data.X)
            self.survival_df["predicted_probabilities"] = self.y_prob
        except Exception:
            self.y_prob = None
        return self

    def evaluate(self):
        """Fit the survival models and store their tables in ``self.results``.

        Results are ordered as: all summaries, then all expected-survival
        tables, then all survival curves, each in model-fit order.
        """
        self._run_survival_analyses()
        result_dfs = (
            self._get_summaries()
            + self._get_expected_survival()
            + self._get_survival_curves()
        )
        sens_feat_label = {"sensitive_feature": self.sensitive_name}
        self.results = [
            TableContainer(df, **self.get_container_info(labels=sens_feat_label))
            for df in result_dfs
        ]
        return self

    def _run_survival_analyses(self):
        """Fit one CoxPH model per formula and store them in ``self.stats``."""
        # Reset so repeated evaluate() calls do not accumulate stale models
        # (and therefore duplicated results).
        self.stats = []

        if "formula" in self.coxPh_kwargs:
            # A user-supplied formula is used verbatim (confounds not added).
            cph = CoxPH()
            cph.fit(self.survival_df, **self.coxPh_kwargs)
            self.stats.append(cph)
            return

        prediction_columns = ["predictions"]
        if self.y_prob is not None:
            prediction_columns.append("predicted_probabilities")

        for pred in prediction_columns:
            # e.g. "race * predictions + age + income"
            terms = [f"{self.sensitive_name} * {pred}", *(self.confounds or [])]
            run_kwargs = {**self.coxPh_kwargs, "formula": " + ".join(terms)}
            cph = CoxPH()
            cph.fit(self.survival_df, **run_kwargs)
            self.stats.append(cph)

    def _get_expected_survival(self):
        """Return ``expected_survival()`` from each fitted model, in fit order."""
        return [s.expected_survival() for s in self.stats]

    def _get_summaries(self):
        """Return ``summary()`` from each fitted model, in fit order."""
        return [s.summary() for s in self.stats]

    def _get_survival_curves(self):
        """Return ``survival_curves()`` from each fitted model, in fit order."""
        return [s.survival_curves() for s in self.stats]