In [242]:
!pip install --quiet fairlearn shap

In [243]:
import pandas as pd
import numpy as np
import sklearn
import fairlearn
import shap

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import datetime as dt

print(f"Import sucessful at {dt.datetime.now()}")

Import sucessful at 2025-08-10 16:00:14.713467


## Import data  

In [244]:
df = pd.read_csv("iris_combined.csv")
df.shape, df.columns, df.head()

((300, 7),
 Index(['flower_id', 'event_timestamp', 'species', 'sepal_length',
        'sepal_width', 'petal_length', 'petal_width'],
       dtype='object'),
    flower_id            event_timestamp    species  sepal_length  sepal_width  \
 0        110  2023-12-31 23:59:59+00:00  virginica           6.8          3.0   
 1        105  2023-12-31 23:59:59+00:00  virginica           6.7          3.3   
 2         20  2023-12-31 23:59:59+00:00     setosa           4.6          3.2   
 3        132  2023-12-31 23:59:59+00:00  virginica           6.9          3.1   
 4          1  2023-12-31 23:59:59+00:00     setosa           4.6          3.6   
 
    petal_length  petal_width  
 0           5.5          2.1  
 1           5.7          2.5  
 2           1.4          0.2  
 3           5.4          2.1  
 4           1.0          0.2  )

## Fairlearn  

- Introduce a “location” attribute in IRIS dataset with values 0 and 1 assigned randomly.       
- Incorporate fairlearn explainer with location as sensitive attribute.

### Prepare data for fairlearn  

In [245]:
# Add a column named 'location' and fill it with 1 and 0 randomly
df['location'] = np.random.randint(0, 2, df.shape[0])
df.shape, df.columns, df.head()

((300, 8),
 Index(['flower_id', 'event_timestamp', 'species', 'sepal_length',
        'sepal_width', 'petal_length', 'petal_width', 'location'],
       dtype='object'),
    flower_id            event_timestamp    species  sepal_length  sepal_width  \
 0        110  2023-12-31 23:59:59+00:00  virginica           6.8          3.0   
 1        105  2023-12-31 23:59:59+00:00  virginica           6.7          3.3   
 2         20  2023-12-31 23:59:59+00:00     setosa           4.6          3.2   
 3        132  2023-12-31 23:59:59+00:00  virginica           6.9          3.1   
 4          1  2023-12-31 23:59:59+00:00     setosa           4.6          3.6   
 
    petal_length  petal_width  location  
 0           5.5          2.1         0  
 1           5.7          2.5         1  
 2           1.4          0.2         0  
 3           5.4          2.1         0  
 4           1.0          0.2         0  )

In [246]:
df['location'].value_counts()

location
0    151
1    149
Name: count, dtype: int64

In [247]:
# Location value counts species wise
df.groupby('location')['species'].value_counts()

location  species   
0         virginica     58
          versicolor    47
          setosa        46
1         setosa        54
          versicolor    53
          virginica     42
Name: count, dtype: int64

In [248]:
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'location']]
y = df['species']
X.shape, y.shape

((300, 5), (300,))

y.value_counts().plot.bar(color='lavender', 
                          # edgecolor='black'
                         )

### Evaluate fairness metrics

In [249]:
from fairlearn.metrics import MetricFrame
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split

In [250]:
np.random.seed(42)

In [251]:
X_train, X_test, y_train, y_test, A_train, A_test = train_test_split(X, y, df['location'], test_size=0.4, random_state=42)
X_train.shape, X_test.shape, y_train.shape, y_test.shape, A_train.shape, A_test.shape

((180, 5), (120, 5), (180,), (120,), (180,), (120,))

In [252]:
# Train a decision tree model
# clf = DecisionTreeClassifier(min_samples_leaf=5, max_depth=2, random_state=42)
clf = LogisticRegression(random_state=42)
# clf = SVC(kernel='linear', C=0.1, random_state=42)

clf.fit(X_train, y_train)

In [253]:
class_probs = pd.DataFrame([clf.predict(X_test), clf.predict_proba(X_test)]).T
class_probs.shape, class_probs.head()

((120, 2),
             0                                                  1
 0   virginica  [0.00029747083427913843, 0.4494671549145677, 0...
 1  versicolor  [0.018656109514411063, 0.9520458342942902, 0.0...
 2  versicolor  [0.0014139458058033286, 0.5958298571567567, 0....
 3      setosa  [0.9823106324915788, 0.017689363493477498, 4.0...
 4      setosa  [0.9865807596364145, 0.013419234302697291, 6.0...)

In [254]:
clf.predict(X_test)[:5]

array(['virginica', 'versicolor', 'versicolor', 'setosa', 'setosa'],
      dtype=object)

In [255]:
y_pred = clf.predict(X_test)
print(y_pred[:5])
accuracy_score(y_test, y_pred)

['virginica' 'versicolor' 'versicolor' 'setosa' 'setosa']


0.9416666666666667

In [256]:
mf = MetricFrame(metrics=accuracy_score,
                y_true=y_test,
                y_pred=y_pred,
                sensitive_features=A_test)
mf

<fairlearn.metrics._metric_frame.MetricFrame at 0x7f0f42b2c5e0>

In [257]:
mf.overall.item()

0.9416666666666667

In [258]:
mf.by_group

location
0    0.910714
1    0.968750
Name: accuracy_score, dtype: float64

In [259]:
print("Accuracy difference:", mf.difference())  # max group accuracy - min group accuracy
print("Accuracy ratio:", mf.ratio())            # min / max

Accuracy difference: 0.0580357142857143
Accuracy ratio: 0.9400921658986175


In [260]:
print("Gap in pp:", 100 * mf.difference().item())

Gap in pp: 5.803571428571431


In [261]:
from fairlearn.metrics import MetricFrame, selection_rate
from sklearn.metrics import recall_score
import numpy as np

classes = clf.classes_
proba   = clf.predict_proba(X_test)

for c_idx, c in enumerate(classes):
    y_true_bin = (y_test == c).astype(int)
    y_pred_bin = (proba[:, c_idx] >= 0.5).astype(int)

    mf_c = MetricFrame(
        metrics={
            "selection_rate": selection_rate, 
            "recall": recall_score   # from sklearn
        },
        y_true=y_true_bin,
        y_pred=y_pred_bin,
        sensitive_features=A_test
    )

    print(f"\nClass: {c}")
    print(mf_c.by_group)
    print("Selection rate diff:", mf_c.difference()["selection_rate"])
    print("Recall diff:", mf_c.difference()["recall"])



Class: setosa
          selection_rate  recall
location                        
0               0.232143     1.0
1               0.406250     1.0
Selection rate diff: 0.17410714285714285
Recall diff: 0.0

Class: versicolor
          selection_rate    recall
location                          
0               0.303571  0.833333
1               0.312500  1.000000
Selection rate diff: 0.008928571428571452
Recall diff: 0.16666666666666663

Class: virginica
          selection_rate  recall
location                        
0               0.464286    0.92
1               0.281250    0.90
Selection rate diff: 0.1830357142857143
Recall diff: 0.020000000000000018


## SHAP  

- Explain in simple words what do the SHAP full dataset explainer plots (similar to what was shown in the demo) for class virginica mean.

In [270]:
for i, cls in enumerate(clf.classes_):
    print(f"{cls}: {i}")

setosa: 0
versicolor: 1
virginica: 2


In [267]:
import shap

shap.initjs()
explainer = shap.KernelExplainer(clf.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)
shap_values

Using 180 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.


  0%|          | 0/120 [00:00<?, ?it/s]

array([[[-0.01068946,  0.02400673, -0.01331726],
        [-0.01766078, -0.03048833,  0.04814912],
        [-0.30600733,  0.1129604 ,  0.19304693],
        [-0.00491509,  0.0039275 ,  0.00098759],
        [ 0.00068234, -0.01650768,  0.01582534]],

       [[-0.00764349, -0.00425729,  0.01190078],
        [-0.00317349,  0.00512342, -0.00194993],
        [-0.3041087 ,  0.51583808, -0.21172938],
        [-0.0042337 ,  0.07187696, -0.06764325],
        [-0.00107231,  0.00789613, -0.00682382]],

       [[-0.01332237,  0.04842657, -0.0351042 ],
        [-0.00733323, -0.0054979 ,  0.01283112],
        [-0.30985242,  0.30764571,  0.00220671],
        [-0.00762427, -0.0947265 ,  0.10235077],
        [ 0.00065843, -0.01558656,  0.01492813]],

       ...,

       [[-0.01443735,  0.0410095 , -0.02657216],
        [ 0.001421  ,  0.00476215, -0.00618315],
        [-0.31830945, -0.19905755,  0.51736699],
        [-0.0081545 , -0.18331828,  0.19147278],
        [ 0.00059436, -0.00714462,  0.00655026]],


In [266]:
# Shap plot for class 0 (setosa)

shap.force_plot(explainer.expected_value[0], shap_values[..., 0], X_test)

In [264]:
# Shap plot for class 1 (versicolor)
shap.force_plot(explainer.expected_value[0], shap_values[..., 1], X_test)

In [265]:
# Shap plot for class 2 (virginica)
shap.force_plot(explainer.expected_value[0], shap_values[..., 2], X_test)