# Local Model Interpretability: **dice-ml**

### References
- [pypi: dice-ml](https://pypi.org/project/dice-ml/)
- [dice-ml - Diverse Counterfactual Explanations for ML Models](https://coderzcolumn.com/tutorials/machine-learning/dice-ml-diverse-counterfactual-explanations-for-ml-models)
- [Advanced options to customize Counterfactual Explanations](https://interpret.ml/DiCE/notebooks/DiCE_with_advanced_options.html)
- [Generating counterfactuals for multi-class classification and regression models](https://interpret.ml/DiCE/notebooks/DiCE_multiclass_classification_and_regression.html#Regression)

In [2]:
%%capture
pip install dice-ml

In [131]:
import pandas as pd
import numpy as np
import random
import dice_ml
from sklearn.datasets import fetch_california_housing, load_boston
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import accuracy_score, classification_report
import warnings
# Silence every warning for the rest of the notebook.
# NOTE(review): this also hides library deprecation warnings - confirm that is intended.
warnings.filterwarnings("ignore")
# Let pandas display up to 35 columns, enough for the 31-column breast-cancer frame below.
pd.set_option("display.max_columns", 35)


## Load datasets

In [132]:
# Boston housing data (regression): features plus the target stored in a "Price" column.
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed in 1.2,
# so this cell needs an older scikit-learn release to run.
boston = load_boston()
boston_df = pd.DataFrame(boston.data, columns=boston.feature_names).assign(Price=boston.target)
boston_df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


In [133]:

# Breast-cancer data (classification): features plus the label stored in a "TumorType" column.
breast_cancer = load_breast_cancer()
breast_cancer_df = pd.DataFrame(
    breast_cancer.data,
    columns=breast_cancer.feature_names,
).assign(TumorType=breast_cancer.target)
breast_cancer_df.head()

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,TumorType
0,17.99,10.38,122.8,1001.0,0.1184,0.2776,0.3001,0.1471,0.2419,0.07871,1.095,0.9053,8.589,153.4,0.006399,0.04904,0.05373,0.01587,0.03003,0.006193,25.38,17.33,184.6,2019.0,0.1622,0.6656,0.7119,0.2654,0.4601,0.1189,0
1,20.57,17.77,132.9,1326.0,0.08474,0.07864,0.0869,0.07017,0.1812,0.05667,0.5435,0.7339,3.398,74.08,0.005225,0.01308,0.0186,0.0134,0.01389,0.003532,24.99,23.41,158.8,1956.0,0.1238,0.1866,0.2416,0.186,0.275,0.08902,0
2,19.69,21.25,130.0,1203.0,0.1096,0.1599,0.1974,0.1279,0.2069,0.05999,0.7456,0.7869,4.585,94.03,0.00615,0.04006,0.03832,0.02058,0.0225,0.004571,23.57,25.53,152.5,1709.0,0.1444,0.4245,0.4504,0.243,0.3613,0.08758,0
3,11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,0.4956,1.156,3.445,27.23,0.00911,0.07458,0.05661,0.01867,0.05963,0.009208,14.91,26.5,98.87,567.7,0.2098,0.8663,0.6869,0.2575,0.6638,0.173,0
4,20.29,14.34,135.1,1297.0,0.1003,0.1328,0.198,0.1043,0.1809,0.05883,0.7572,0.7813,5.438,94.44,0.01149,0.02461,0.05688,0.01885,0.01756,0.005115,22.54,16.67,152.2,1575.0,0.1374,0.205,0.4,0.1625,0.2364,0.07678,0


## Regression

In [134]:
# Keep 90% of the Boston rows for training and hold out the rest for testing;
# the fixed random_state makes the split repeatable.
print("Dataset Size : ", boston.data.shape, boston.target.shape)
X_train, X_test, Y_train, Y_test = train_test_split(
    boston.data,
    boston.target,
    train_size=0.90,
    random_state=123,
)
print("Train/Test Sizes : ", *(part.shape for part in (X_train, X_test, Y_train, Y_test)))

Dataset Size :  (506, 13) (506,)
Train/Test Sizes :  (455, 13) (51, 13) (455,) (51,)


In [135]:
# Fit a random forest of depth-2 trees on the training split (seeded for repeatability).
model = RandomForestRegressor(max_depth=2, random_state=0).fit(X_train, Y_train)

# Predict once per split and reuse the predictions for both metrics.
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# Report error and goodness of fit on both splits.
print("Train MSE : %.2f" % mean_squared_error(Y_train, train_preds))
print("Test  MSE : %.2f" % mean_squared_error(Y_test, test_preds))
print("Train R2 Score : %.2f" % r2_score(Y_train, train_preds))
print("Test  R2 Score : %.2f" % r2_score(Y_test, test_preds))

Train MSE : 17.73
Test  MSE : 60.44
Train R2 Score : 0.78
Test  R2 Score : 0.47


In [136]:
# Describe the data to DiCE: every Boston feature is declared continuous, "Price" is the outcome.
d = dice_ml.Data(
    dataframe=boston_df,
    continuous_features=boston.feature_names.tolist(),
    outcome_name='Price',
)
# Wrap the fitted scikit-learn regressor for DiCE.
m = dice_ml.Model(model=model, backend="sklearn", model_type='regressor')
# Build the explainer with the default method (the repr below shows DiceRandom);
# method="genetic" can be passed to dice_ml.Dice as an alternative.
exp = dice_ml.Dice(d, m)
exp

<dice_ml.explainer_interfaces.dice_random.DiceRandom at 0x7fac9d3f4358>

In [138]:

# Pick one test row at random to use as the query instance.
# randrange(n) covers every valid index 0..n-1; randint(1, n) is inclusive on both
# ends, so it could return n (IndexError on Y_test[idx]) and could never select row 0.
idx = random.randrange(len(X_test))
print("Actual Price : %.2f"%Y_test[idx])
# Keep the row as a one-row DataFrame with named columns, the form passed to DiCE below.
sample = pd.DataFrame(X_test[idx:idx+1], columns = list(boston.feature_names))
sample

Actual Price : 21.60


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT
0,0.26938,0.0,9.9,0.0,0.544,6.266,82.8,3.2628,4.0,304.0,18.4,393.39,7.9


In [139]:
# Request 4 counterfactuals for the query row whose predicted price falls in [30.0, 35.0].
dice_exp = exp.generate_counterfactuals(
    sample,
    total_CFs=4,
    desired_range=[30.0, 35.0],
)

100%|██████████| 1/1 [00:00<00:00,  3.57it/s]


In [140]:
# Show the query instance (with the model's original outcome) followed by the counterfactual rows.
dice_exp.visualize_as_dataframe()

Query instance (original outcome : 24)


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,0.26938,0.0,9.9,0.0,0.544,6.266,82.800003,3.2628,4.0,304.0,18.4,393.390015,7.9,24.0



Diverse Counterfactual set (new outcome: [30.0, 35.0])


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,0.26938,0.0,9.9,0.0,0.427,7.361,82.8,3.2628,4.0,304.0,18.4,393.39,7.9,32.42625
1,0.26938,0.0,9.9,0.8,0.544,7.436,82.8,3.2628,4.0,304.0,18.4,393.39,7.9,34.799614
2,0.26938,0.0,9.9,0.0,0.544,7.223,82.8,3.2628,4.0,304.0,18.4,393.39,7.9,32.42625
3,0.26938,0.0,9.9,0.0,0.544,7.361,82.8,3.2628,4.0,304.0,18.4,393.39,7.9,32.42625


In [141]:
# Same view, asking DiCE to show only what changed; unchanged values are rendered as "-".
# NOTE(review): in the output below several values identical to the query are still printed -
# presumably a float-comparison effect inside DiCE; confirm.
dice_exp.visualize_as_dataframe(show_only_changes=True)

Query instance (original outcome : 24)


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,0.26938,0.0,9.9,0.0,0.544,6.266,82.800003,3.2628,4.0,304.0,18.4,393.390015,7.9,24.0



Diverse Counterfactual set (new outcome: [30.0, 35.0])


Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,Price
0,0.26938,-,9.9,-,0.427,7.361,82.8,3.2628,-,-,18.4,393.39,7.9,32.42625045776367
1,0.26938,-,9.9,0.8,0.544,7.436,82.8,3.2628,-,-,18.4,393.39,7.9,34.79961395263672
2,0.26938,-,9.9,-,0.544,7.223,82.8,3.2628,-,-,18.4,393.39,7.9,32.42625045776367
3,0.26938,-,9.9,-,0.544,7.361,82.8,3.2628,-,-,18.4,393.39,7.9,32.42625045776367


## Classification

In [142]:
# Keep 90% of the breast-cancer rows for training and hold out the rest for testing;
# stratifying on the label keeps the class balance equal across both splits.
print("Dataset Size : ", breast_cancer.data.shape, breast_cancer.target.shape)
X_train, X_test, Y_train, Y_test = train_test_split(
    breast_cancer.data,
    breast_cancer.target,
    train_size=0.90,
    stratify=breast_cancer.target,
    random_state=123,
)
print("Train/Test Sizes : ", *(part.shape for part in (X_train, X_test, Y_train, Y_test)))

Dataset Size :  (569, 30) (569,)
Train/Test Sizes :  (512, 30) (57, 30) (512,) (57,)


In [143]:
# Fit a random forest classifier of depth-2 trees on the training split (seeded for repeatability).
model = RandomForestClassifier(max_depth=2, random_state=0)
model.fit(X_train, Y_train)

# predict() already returns hard class labels (0/1), so no 0.5 thresholding is needed.
train_preds = model.predict(X_train)
test_preds = model.predict(X_test)

# Report accuracy on both splits and the per-class breakdown on the test split.
print("Train Accuracy : %.2f"%accuracy_score(Y_train, train_preds))
print("Test  Accuracy : %.2f"%accuracy_score(Y_test, test_preds))
print("\nTest  Classification Report : ")
print(classification_report(Y_test, test_preds))

Train Accuracy : 0.97
Test  Accuracy : 0.91

Test  Classification Report : 
              precision    recall  f1-score   support

           0       0.90      0.86      0.88        21
           1       0.92      0.94      0.93        36

    accuracy                           0.91        57
   macro avg       0.91      0.90      0.90        57
weighted avg       0.91      0.91      0.91        57



In [144]:
# Describe the data to DiCE: every breast-cancer feature is declared continuous,
# "TumorType" is the outcome.
d = dice_ml.Data(
    dataframe=breast_cancer_df,
    continuous_features=breast_cancer.feature_names.tolist(),
    outcome_name='TumorType',
)
# Wrap the fitted scikit-learn classifier for DiCE.
m = dice_ml.Model(model=model, backend="sklearn", model_type='classifier')
# Build the explainer with DiCE's default method.
exp = dice_ml.Dice(d, m)

In [145]:
# Pick one test row at random to use as the query instance.
# randrange(n) covers every valid index 0..n-1; randint(1, n) is inclusive on both
# ends, so it could return n (IndexError on Y_test[idx]) and could never select row 0.
idx = random.randrange(len(X_test))
print("Actual Class : %d"%Y_test[idx])
# Keep the row as a one-row DataFrame with named columns, the form passed to DiCE below.
sample = pd.DataFrame(X_test[idx:idx+1], columns = list(breast_cancer.feature_names))
sample

Actual Class : 1


Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension
0,14.26,19.65,97.83,629.9,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,29.25,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.3,23.73,107.0,709.0,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082


In [146]:

# Request 4 counterfactuals that flip the model's prediction for the query row.
# "opposite" targets the class the model did NOT predict, so the request stays meaningful
# whichever random sample was drawn; a fixed desired_class=1 asks for no change at all
# whenever the row is already predicted as class 1.
dice_exp = exp.generate_counterfactuals(sample, total_CFs=4, desired_class="opposite")

100%|██████████| 1/1 [00:10<00:00, 10.14s/it]


In [147]:
# Show the query instance (with the model's original outcome) followed by the counterfactual rows.
dice_exp.visualize_as_dataframe()

Query instance (original outcome : 0)


Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,TumorType
0,14.26,19.65,97.830002,629.900024,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,29.25,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.3,23.73,107.0,709.0,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082,0



Diverse Counterfactual set (new outcome: 1)


Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,TumorType
0,14.0,19.65,98.0,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,30.0,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.0,24.0,78.4,709.0,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082,1
1,14.0,19.65,98.0,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,30.0,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.0,24.0,55.8,709.0,0.0962,0.4193,0.6783,0.1505,0.2398,0.1082,1
2,15.01,19.65,98.0,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,30.0,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,16.01,24.0,107.0,709.0,0.1006,0.4193,0.6783,0.1,0.2398,0.1082,1
3,14.0,19.65,98.0,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,29.0,0.005298,0.07446,0.1435,0.02292,0.02566,0.004017,15.0,24.0,89.6,709.0,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082,1


In [148]:
# Same view, asking DiCE to show only what changed; unchanged values are rendered as "-".
# NOTE(review): in the output below many values identical to the query are still printed
# (some with float noise such as 98.00000000000009) - presumably a float-comparison effect
# inside DiCE; confirm.
dice_exp.visualize_as_dataframe(show_only_changes=True)

Query instance (original outcome : 0)


Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,TumorType
0,14.26,19.65,97.830002,629.900024,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,29.25,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.3,23.73,107.0,709.0,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082,0



Diverse Counterfactual set (new outcome: 1)


Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,radius error,texture error,perimeter error,area error,smoothness error,compactness error,concavity error,concave points error,symmetry error,fractal dimension error,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension,TumorType
0,14.0,19.65,98.00000000000009,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,30.00000000000012,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.0,24.000000000000043,78.4,-,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082,1.0
1,14.0,19.65,98.00000000000009,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,30.00000000000012,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,15.0,24.000000000000043,55.8,-,0.0962,0.4193,0.6783,0.1505,0.2398,0.1082,1.0
2,15.009999999999984,19.65,98.00000000000009,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,30.00000000000012,0.005298,0.07446,0.1435,0.02292,0.02566,0.01298,16.009999999999987,24.000000000000043,-,-,0.1006,0.4193,0.6783,0.1,0.2398,0.1082,1.0
3,14.0,19.65,98.00000000000009,630.0,0.07837,0.2233,0.3003,0.07798,0.1704,0.07769,0.3628,1.49,3.399,29.0,0.005298,0.07446,0.1435,0.02292,0.02566,0.004017,15.0,24.000000000000043,89.6,-,0.08949,0.4193,0.6783,0.1505,0.2398,0.1082,1.0
