Skip to content

Commit

Permalink
Merge pull request #24 from Garve/main
Browse files Browse the repository at this point in the history
A few fixes and a proposal for the ShannonEntropyReason
  • Loading branch information
koaning committed Dec 15, 2021
2 parents 8ff9bdf + 41de343 commit d1e22ad
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 11 deletions.
13 changes: 3 additions & 10 deletions doubtlab/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,9 @@ def get_predicates(self, X, y=None):
predicates = doubt.get_predicates(X, y)
```
"""
df = pd.DataFrame({"i": range(len(X))})
for name, func in self.reasons.items():
df[f"predicate_{name}"] = func(X, y)
predicates = [c for c in df.columns if "predicate" in c]
return (
df[predicates]
.assign(s=lambda d: d.sum(axis=1))
.sort_values("s", ascending=False)
.drop(columns=["s"])
)
df = pd.DataFrame({f"predicate_{name}": func(X, y) for name, func in self.reasons.items()})
sorted_index = df.sum(axis=1).sort_values(ascending=False).index
return df.reindex(sorted_index)

def get_indices(self, X, y=None):
"""
Expand Down
48 changes: 47 additions & 1 deletion doubtlab/reason.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,52 @@ def __call__(self, X, y=None):
return np.where(rvals < self.probability, rvals, 0)


class ShannonEntropyReason:
    """
    Assign doubt when the normalized Shannon entropy is too high, see
    https://math.stackexchange.com/questions/395121/how-entropy-scales-with-sample-size
    for a discussion.

    The entropy is divided by ``log(n_classes)`` so it always lies in ``[0, 1]``
    regardless of the number of classes.

    Arguments:
        model: scikit-learn classifier
        threshold: normalized-entropy threshold above which doubt is assigned

    Usage:
    ```python
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from doubtlab.ensemble import DoubtEnsemble
    from doubtlab.reason import ShannonEntropyReason

    X, y = load_iris(return_X_y=True)
    model = LogisticRegression(max_iter=1_000)
    model.fit(X, y)

    doubt = DoubtEnsemble(reason = ShannonEntropyReason(model=model))

    indices = doubt.get_indices(X, y)
    ```
    """

    def __init__(self, model, threshold=0.5):
        self.model = model
        self.threshold = threshold

    def __call__(self, X, y=None):
        # `y` is accepted (and ignored) so the reason matches the
        # `func(X, y)` calling convention used by DoubtEnsemble; it now
        # defaults to None for consistency with the other reasons.
        probas = self.model.predict_proba(X)
        return self.from_proba(probas, len(self.model.classes_), self.threshold)

    @staticmethod
    def from_proba(proba, n_classes, threshold=0.5):
        """Outputs a reason array from a prediction array, skipping the need for a model.

        Arguments:
            proba: array of shape (n_samples, n_classes) with class probabilities
            n_classes: number of classes, used to normalize the entropy into [0, 1]
            threshold: normalized-entropy threshold above which doubt is assigned

        Returns:
            Array of shape (n_samples,) holding the normalized entropy where it
            exceeds `threshold`, and 0 elsewhere.
        """
        proba = np.asarray(proba, dtype=float)
        # Apply the convention 0 * log(0) == 0: a zero probability contributes
        # nothing to the entropy. Taking log(proba) directly would yield -inf
        # and then NaN (0 * -inf) for confident one-hot predictions.
        safe = np.where(proba > 0, proba, 1.0)
        entropies = -(proba * np.log(safe)).sum(axis=1) / np.log(n_classes)
        return np.where(entropies > threshold, entropies, 0)


class WrongPredictionReason:
"""
Assign doubt when the model prediction doesn't match the label.
Expand Down Expand Up @@ -153,7 +199,7 @@ def __call__(self, X, y):

class MarginConfidenceReason:
"""
Assign doubt when the difference between the top two most confident classes is too small.
Assign doubt when the difference between the top two most confident classes is too small.
Throws an error when there are only two classes.
Expand Down

0 comments on commit d1e22ad

Please sign in to comment.