In [2]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, accuracy_score

In [3]:
import pandas as pd
pd.set_option('display.max_columns', None)
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import RandomOverSampler

class DataLoader():
    def __init__(self):
        self.data = None

    def load_dataset(self, path="healthcare-dataset-stroke-data.csv"):
        self.data = pd.read_csv(path)

    def preprocess_data(self):
        categorical_cols = ["gender",
                            "ever_married",
                            "work_type",
                            "Residence_type",
                            "smoking_status"]
        encoded = pd.get_dummies(self.data[categorical_cols],
                                prefix=categorical_cols)

        self.data = pd.concat([encoded, self.data], axis=1)
        self.data.drop(categorical_cols, axis=1, inplace=True)

        self.data.bmi = self.data.bmi.fillna(0)

        self.data.drop(["id"], axis=1, inplace=True)

    def get_data_split(self):
        X = self.data.iloc[:,:-1]
        y = self.data.iloc[:,-1]
        return train_test_split(X, y, test_size=0.20, random_state=2021)

    def oversample(self, X_train, y_train):
        oversample = RandomOverSampler(sampling_strategy='minority')
        x_np = X_train.to_numpy()
        y_np = y_train.to_numpy()
        x_np, y_np = oversample.fit_resample(x_np, y_np)
        x_over = pd.DataFrame(x_np, columns=X_train.columns)
        y_over = pd.Series(y_np, name=y_train.name)
        return x_over, y_over

In [4]:
data_loader = DataLoader()
data_loader.load_dataset()
data_loader.preprocess_data()

In [5]:
X_train, X_test, y_train, y_test = data_loader.get_data_split()
# Oversample the train data means eliminating skewness
X_train, y_train = data_loader.oversample(X_train, y_train)
print(X_train.shape)
print(X_test.shape)

(7778, 21)
(1022, 21)


In [6]:
rf = RandomForestClassifier()
rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)
print(f"F1 Score {f1_score(y_test, y_pred, average='macro')}")
print(f"Accuracy {accuracy_score(y_test, y_pred)}")

F1 Score 0.5459447865919049
Accuracy 0.9432485322896281


In [7]:
!pip install dice-ml

Collecting dice-ml
  Downloading dice_ml-0.10-py3-none-any.whl (2.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
Collecting raiutils>=0.4.0 (from dice-ml)
  Downloading raiutils-0.4.1-py3-none-any.whl (17 kB)
Installing collected packages: raiutils, dice-ml
Successfully installed dice-ml-0.10 raiutils-0.4.1


In [8]:
import dice_ml
# Data object tells what the data looks like
data_dice = dice_ml.Data(dataframe=data_loader.data,
                         # For perturbation we need to tell continuous and discrete data as stretegy is different
                         continuous_features=['age',
                                              'avg_glucose_level',
                                              'bmi'],
                         outcome_name='stroke') #need to specify target variable



In [9]:
rf_dice = dice_ml.Model(model=rf,
                        #need to tell background model type as data handling  is done accordingly
                        backend="sklearn")
#now giving data and model object to Dice
explainer = dice_ml.Dice(data_dice,
                         rf_dice,
                         #we have different methods here like Random sampling, genetic algorithm,..etc
                         method="random")#here used random sampling

In [10]:
input_datapoint = X_test[0:1] #data point selected
cf = explainer.generate_counterfactuals(input_datapoint,
                                  total_CFs=3, #no. of counterfactuals to be generated
                                  desired_class="opposite") #for binary

100%|██████████| 1/1 [00:00<00:00,  4.71it/s]


In [11]:
cf.visualize_as_dataframe(show_only_changes=True)

Query instance (original outcome : 0)


Unnamed: 0,gender_Female,gender_Male,gender_Other,ever_married_No,ever_married_Yes,work_type_Govt_job,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,70.0,0,0,72.559998,30.4,0



Diverse Counterfactual set (new outcome: 1.0)


Unnamed: 0,gender_Female,gender_Male,gender_Other,ever_married_No,ever_married_Yes,work_type_Govt_job,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,149.72,30.4,1
1,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,143.0,30.4,1
2,0.0,1.0,0.0,0.0,1.0,0.0,-,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,151.45,30.4,1


In [12]:
features_to_vary=['avg_glucose_level',
                  'bmi',
                  'smoking_status_smokes']
permitted_range={'avg_glucose_level':[50,250],
                'bmi':[18, 35]}

In [13]:
cf = explainer.generate_counterfactuals(input_datapoint,
                                  total_CFs=3,
                                  desired_class="opposite",
                                  permitted_range=permitted_range,
                                  features_to_vary=features_to_vary)
# Visualize it
cf.visualize_as_dataframe(show_only_changes=True)

100%|██████████| 1/1 [00:00<00:00,  3.30it/s]

Query instance (original outcome : 0)





Unnamed: 0,gender_Female,gender_Male,gender_Other,ever_married_No,ever_married_Yes,work_type_Govt_job,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
0,0,1,0,0,1,0,0,1,0,0,1,0,0,1,0,0,70.0,0,0,72.559998,30.4,0



Diverse Counterfactual set (new outcome: 1.0)


Unnamed: 0,gender_Female,gender_Male,gender_Other,ever_married_No,ever_married_Yes,work_type_Govt_job,work_type_Never_worked,work_type_Private,work_type_Self-employed,work_type_children,Residence_type_Rural,Residence_type_Urban,smoking_status_Unknown,smoking_status_formerly smoked,smoking_status_never smoked,smoking_status_smokes,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,143.78,30.4,1
1,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,148.8,30.4,1
2,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,-,0.0,0.0,136.62,30.4,1
