In [157]:
import pandas
import numpy 
import dice_ml
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.compose import ColumnTransformer

In [158]:
path = "healthcare-dataset-stroke-data.csv"
healthcare_dataset = pandas.read_csv(path)

In [159]:
healthcare_dataset

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1
...,...,...,...,...,...,...,...,...,...,...,...,...
5105,18234,Female,80.0,1,0,Yes,Private,Urban,83.75,,never smoked,0
5106,44873,Female,81.0,0,0,Yes,Self-employed,Urban,125.20,40.0,never smoked,0
5107,19723,Female,35.0,0,0,Yes,Self-employed,Rural,82.99,30.6,never smoked,0
5108,37544,Male,51.0,0,0,Yes,Private,Rural,166.29,25.6,formerly smoked,0


In [160]:
simplified_dataset = healthcare_dataset.loc[:,['age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi', 'smoking_status', 'stroke']]



In [161]:
simplified_dataset.dropna(subset = ['bmi'], inplace = True )

In [162]:
mapping = {'formerly smoked':1,
            'smokes':1,
            'never smoked': 0,
            'Unknown':0  }
simplified_dataset.replace({'smoking_status':mapping}, inplace = True)

In [163]:
simplified_dataset['smoking_status'].unique()

array([1, 0], dtype=int64)

In [164]:
simplified_dataset

Unnamed: 0,age,hypertension,heart_disease,avg_glucose_level,bmi,smoking_status,stroke
0,67.0,0,1,228.69,36.6,1,1
2,80.0,0,1,105.92,32.5,0,1
3,49.0,0,0,171.23,34.4,1,1
4,79.0,1,0,174.12,24.0,0,1
5,81.0,0,0,186.21,29.0,1,1
...,...,...,...,...,...,...,...
5104,13.0,0,0,103.08,18.6,0,0
5106,81.0,0,0,125.20,40.0,0,0
5107,35.0,0,0,82.99,30.6,0,0
5108,51.0,0,0,166.29,25.6,1,0


In [165]:
target = simplified_dataset['stroke']
dataset_X = simplified_dataset.drop('stroke',axis=1)

In [166]:
x_train, x_test, y_train, y_test = train_test_split(dataset_X,
                                                    target,
                                                    test_size=0.2,
                                                    random_state=0,
                                                    stratify=target)

In [167]:
#List of names of continuous features: (Important, it is a list)
continuous_feat = ['age','avg_glucose_level','bmi']

numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])

In [168]:
transformation = ColumnTransformer(transformers=[('num',numeric_transformer,continuous_feat)])

In [169]:
clf = Pipeline(steps=[('preprocessor', transformation),('classifier', RandomForestClassifier())])
model = clf.fit(x_train, y_train)

In [170]:
y_pred = model.predict(x_test)


In [171]:
score = sklearn.metrics.accuracy_score(y_test, y_pred)

In [172]:
 import dice_ml

In [186]:
data_dice = dice_ml.Data(dataframe=simplified_dataset,continuous_features = continuous_feat, outcome_name = 'stroke')

In [187]:
model_dice = dice_ml.Model(model = model,backend='sklearn')

In [188]:
explainer = dice_ml.Dice(data_dice,model_dice,method='random')

In [189]:
input_datapoint = x_test[0:1]

In [190]:
counterfactual = explainer.generate_counterfactuals(input_datapoint,total_CFs=5,desired_class = 'opposite')

  cfs_df = cfs_df.append(rows_to_add)
100%|██████████| 1/1 [00:02<00:00,  2.56s/it]


In [191]:
counterfactual.visualize_as_dataframe()

Query instance (original outcome : 0)


Unnamed: 0,age,hypertension,heart_disease,avg_glucose_level,bmi,smoking_status,stroke
0,78.0,0,0,55.32,29.6,1,0



Diverse Counterfactual set (new outcome: 1.0)


Unnamed: 0,age,hypertension,heart_disease,avg_glucose_level,bmi,smoking_status,stroke
0,78.0,0.0,0.0,56.01,23.9,1.0,1
1,78.0,0.0,0.0,56.01,23.7,1.0,0
2,78.0,1.0,0.0,56.01,23.9,1.0,1
3,78.0,0.0,0.0,56.01,24.2,1.0,1
4,78.0,0.0,0.0,56.01,25.7,1.0,1
