
# PertCF


PertCF is a perturbation-based counterfactual explanation method that combines SHAP feature attribution with nearest-neighbour search to generate high-quality, stable counterfactuals for tabular classification models.

> **What is a counterfactual explanation?**
> Given a model's prediction for an instance x, a counterfactual x' is the minimal change to x that would flip the prediction. For example: "If Leo earned $500 more per month, his loan application would be accepted."
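The idea can be illustrated with a toy decision rule (the threshold, feature name, and `predict` function below are made up purely for illustration):

```python
# Toy loan model: accept if monthly income is at least 2000.
def predict(instance):
    return "accepted" if instance["income"] >= 2000 else "rejected"

x = {"income": 1500}      # original instance: rejected
x_cf = {"income": 2000}   # counterfactual: the minimal change that flips the outcome

print(predict(x))     # rejected
print(predict(x_cf))  # accepted
```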


## Why PertCF?

| Feature | PertCF | DiCE | CF-SHAP |
|---------|:------:|:----:|:-------:|
| Multi-class support | ✓ | | |
| SHAP-weighted distance | ✓ | | Partial |
| Custom domain knowledge | ✓ | | |
| Works with sklearn, PyTorch, Keras | ✓ | | Partial |
| No external server needed | ✓ | | |

PertCF achieves lower dissimilarity and instability than DiCE and CF-SHAP on both benchmark datasets (South German Credit and User Knowledge Modeling). See the paper for full results.


## Installation

```bash
pip install pertcf
```

For PyTorch or Keras model support:

```bash
pip install pertcf[torch]      # + PyTorch adapter
pip install pertcf[tensorflow] # + Keras/TF adapter
pip install pertcf[viz]        # + matplotlib/seaborn for plots
```

Requirements: Python ≥ 3.9, numpy, pandas, scikit-learn, shap.
No Java. No REST server. No external frameworks.


## Quick Start (30 seconds)

```python
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from pertcf import PertCFExplainer

# 1. Load data and train a model (example: German Credit dataset)
df = pd.read_csv("german_credit.csv")
X = df.drop(columns=["credit_risk"])
y = df["credit_risk"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = GradientBoostingClassifier(random_state=42).fit(X_train, y_train)

# 2. Create and fit the explainer
explainer = PertCFExplainer(
    model=clf,
    X_train=X_train,
    y_train=y_train,
    categorical_features=["purpose", "personal_status", "housing"],
    label="credit_risk",
    num_iter=5,
    coef=5,
)
explainer.fit()

# 3. Explain a prediction
instance = X_test.iloc[0].copy()
instance["credit_risk"] = clf.predict(X_test.iloc[[0]])[0]

counterfactual = explainer.explain(instance)
print("Original:       ", instance.to_dict())
print("Counterfactual: ", counterfactual.to_dict())
```

## Feature Highlights

### Works with any classifier

```python
# scikit-learn (auto-detected)
from sklearn.ensemble import RandomForestClassifier
explainer = PertCFExplainer(model=RandomForestClassifier().fit(X, y), ...)

# PyTorch
from pertcf import PertCFExplainer
explainer = PertCFExplainer(
    model=my_torch_model,
    class_names=["bad", "good"],
    ...
)

# Keras / TensorFlow
explainer = PertCFExplainer(
    model=my_keras_model,
    class_names=["bad", "good"],
    ...
)

# Any callable
explainer = PertCFExplainer(
    model=my_model,
    predict_fn=lambda X: my_model.predict(X),
    predict_proba_fn=lambda X: my_model.predict_proba(X),
    class_names=["bad", "good"],
    ...
)
```

### Domain knowledge via custom similarity matrices

```python
# Model the relationship between credit purposes
explainer = PertCFExplainer(
    model=clf,
    ...
    similarity_matrices={
        "purpose": {
            ("car", "furniture"): 0.7,   # similar purposes
            ("car", "education"): 0.2,   # less similar
        }
    }
)
```
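One way such sparse pairwise entries could be interpreted (a sketch of the general idea, not PertCF's actual internals; `expand_similarity` is a hypothetical helper): unspecified pairs fall back to a default similarity, identical categories get similarity 1.0, and the table is kept symmetric.

```python
def expand_similarity(categories, pairs, default=0.0):
    """Expand sparse pairwise entries into a full symmetric similarity table."""
    sim = {(a, b): (1.0 if a == b else default)
           for a in categories for b in categories}
    for (a, b), s in pairs.items():
        sim[(a, b)] = s
        sim[(b, a)] = s  # mirror each entry to keep the table symmetric
    return sim

sim = expand_similarity(
    ["car", "furniture", "education"],
    {("car", "furniture"): 0.7, ("car", "education"): 0.2},
)
print(sim[("furniture", "car")])        # 0.7 (mirrored)
print(sim[("education", "education")])  # 1.0 (identity)
```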

### Pre-computed SHAP values (for large datasets)

```python
import shap

# Compute SHAP once, reuse across experiments
shap_exp = shap.TreeExplainer(clf)
shap_vals = shap_exp.shap_values(X_train)
# … build shap_df with shape (n_classes, n_features) …

explainer = PertCFExplainer(
    model=clf, shap_values=shap_df, ...
)
explainer.fit()  # skips SHAP computation
```
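The `… build shap_df …` step is left to the user; one plausible construction (an assumption on our part, not necessarily the exact format PertCF expects) is the mean absolute SHAP value per class and feature. Sketched below with synthetic arrays standing in for `shap_values(X_train)`:

```python
import numpy as np
import pandas as pd

# Stand-in for TreeExplainer's multi-class output: one
# (n_samples, n_features) array per class.
rng = np.random.default_rng(0)
feature_names = ["amount", "duration", "age"]
shap_vals = [rng.normal(size=(50, 3)) for _ in range(2)]  # 2 classes

# Aggregate to one non-negative importance weight per (class, feature).
shap_df = pd.DataFrame(
    [np.abs(cls_vals).mean(axis=0) for cls_vals in shap_vals],
    columns=feature_names,
)
print(shap_df.shape)  # (2, 3) -> (n_classes, n_features)
```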

### Built-in benchmark

```python
# Reproduce the paper's Table 2 results
results = explainer.benchmark(X_test, n=100, coef=5, verbose=True)
# Results (n=100/100):
#   dissimilarity       : 0.0517
#   sparsity            : 0.7983
#   runtime_mean        : 0.4069
```

### Evaluation metrics

```python
from pertcf import metrics

print(metrics.dissimilarity(query, cf, explainer.sim_fn, cf_class))
print(metrics.sparsity(query, cf))
print(metrics.instability(query, cf, explainer))

# All at once:
results = metrics.evaluate(queries, counterfactuals, explainer)
```
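For intuition, sparsity is commonly defined as the fraction of features the counterfactual leaves unchanged (higher means fewer edits). A minimal sketch of that common definition, which may differ in detail from `metrics.sparsity`:

```python
def sparsity(query, cf):
    """Fraction of features left unchanged by the counterfactual."""
    unchanged = sum(query[f] == cf[f] for f in query)
    return unchanged / len(query)

q  = {"amount": 5000, "duration": 24, "purpose": "car"}
cf = {"amount": 3500, "duration": 24, "purpose": "car"}
print(sparsity(q, cf))  # 2 of 3 features unchanged -> 0.666...
```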

## How PertCF Works

1. Compute SHAP values per class → class-specific feature importance weights.
2. For each query `x`:
   a. Find the Nearest Unlike Neighbour (NUN) using SHAP-weighted similarity.
   b. Perturb `x` toward the NUN using SHAP weights:
      - Numeric:     `p_f = x_f + shap_target_f * (nun_f - x_f)`
      - Categorical: `p_f = nun_f  if sim(x_f, nun_f) < 0.5  else  x_f`
   c. If the perturbed instance flips class → refine (approach the source).
   d. If not → push harder (approach the target).
   e. Terminate when the step size falls below a threshold or the iteration limit is reached.
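The numeric perturbation rule in step 2b can be sketched as follows (the feature names and per-feature SHAP weights here are made up, and `shap_target` is assumed to be normalised to [0, 1]):

```python
def perturb_numeric(x, nun, shap_target):
    """Move each numeric feature from x toward the NUN, scaled by its SHAP weight."""
    return {f: x[f] + shap_target[f] * (nun[f] - x[f]) for f in x}

x   = {"amount": 5000.0, "duration": 12.0}  # query
nun = {"amount": 3000.0, "duration": 36.0}  # nearest unlike neighbour
w   = {"amount": 0.8, "duration": 0.2}      # hypothetical SHAP weights for the target class

print(perturb_numeric(x, nun, w))  # {'amount': 3400.0, 'duration': 16.8}
```

A weight near 1 pulls the feature almost all the way to the NUN's value; a weight near 0 leaves it mostly untouched, so unimportant features stay close to the original instance.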

See the paper for full algorithmic details.


## Examples

| Notebook | Dataset | Description |
|----------|---------|-------------|
| `quickstart_german_credit.ipynb` | South German Credit | Basic usage, benchmark, comparison to DiCE |
| `quickstart_knowledge.ipynb` | User Knowledge Modeling | Multi-class classification |
| `custom_similarity.ipynb` | German Credit | Domain knowledge with custom similarity matrices |



## Parameter Guide

| Parameter | Default | Description |
|-----------|---------|-------------|
| `num_iter` | 10 | Max perturbation iterations. Higher → better quality, slower. |
| `coef` | 5 | Step-size threshold coefficient. Higher → finer convergence. |

Recommended settings from the paper:

| Dataset | num_iter | coef | Notes |
|---------|----------|------|-------|
| South German Credit | 5 | 5 | Mostly categorical features |
| User Knowledge Modeling | 5 | 3 | All numeric features |

## Citation

If you use PertCF in your research, please cite:

```bibtex
@inproceedings{bayrak2023pertcf,
  title     = {PertCF: A Perturbation-Based Counterfactual Generation Approach},
  author    = {Bayrak, Bet{\"u}l and Bach, Kerstin},
  booktitle = {Artificial Intelligence XXXVIII},
  series    = {Lecture Notes in Computer Science},
  volume    = {14381},
  pages     = {174--187},
  year      = {2023},
  publisher = {Springer, Cham},
  doi       = {10.1007/978-3-031-47994-6_13}
}
```

## License

MIT © Betül Bayrak

This work was supported by the Research Council of Norway through the EXAIGON project (ID 304843).
