In [1]:
from matplotlib import pyplot as plt
from sklearn import tree
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

In [2]:
# get and clean data
strokedata = pd.read_csv("stroke_data.csv")

# replacing null BMI values with median value
strokedata["bmi"].fillna(strokedata["bmi"].median(), inplace=True)
# drop instance with Other gender
strokedata = strokedata.drop(strokedata[strokedata["gender"] == "Other"].index)
# drop useless columns
strokedata = strokedata.drop(
    [
        "id"
    ],
    axis=1,
)
# create one-hot encoding
strokedata = pd.get_dummies(strokedata, columns=["smoking_status", 'work_type', 'Residence_type', 'ever_married', 'gender'])
strokedata = strokedata.drop(
    [ 
        'work_type_Govt_job',
        'Residence_type_Urban',
        'smoking_status_smokes',
        'smoking_status_Unknown',
        'ever_married_No',
        'Residence_type_Rural',
        'work_type_Private',
        'heart_disease',
        'smoking_status_never smoked',
        'work_type_children',
        'ever_married_Yes',
        'hypertension'
    ],
    axis=1,
)


len(strokedata.columns)
X = strokedata.iloc[:, 0:8]

In [3]:
#Splitting into training and testing sets
train_set, test_set = train_test_split(strokedata, test_size=.15, stratify=strokedata["stroke"], random_state=42)


#Creating input and label data
target = 'stroke'
X_train = train_set.drop(target, axis=1)
X_test = test_set.drop(target, axis=1)
y_train = train_set[target]
y_test = test_set[target]

count = Counter(y_train)

In [4]:
# preprocess
# scale
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# oversample
oversample = SMOTE()
X_train, y_train = oversample.fit_resample(X_train, y_train)

In [5]:
# Cross-Validation
param_grid = {
    "criterion": ['gini', 'entropy', 'log_loss'],
    "ccp_alpha": [0.1, 0.01, 0.001, 0.0001],
    "splitter": ['best', 'random'],
    'max_leaf_nodes': list(range(2, 50)),
    'max_depth': list(range(1, 7)),
    'min_samples_split': [2, 3, 4]
}

grid = GridSearchCV(DecisionTreeClassifier(), param_grid, refit=True, verbose=3, scoring="roc_auc")

grid.fit(X_train, y_train)

Fitting 5 folds for each of 20736 candidates, totalling 103680 fits
[CV 1/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=2, min_samples_split=2, splitter=best;, score=0.781 total time=   0.0s
[CV 2/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=2, min_samples_split=2, splitter=best;, score=0.772 total time=   0.0s
[CV 3/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=2, min_samples_split=2, splitter=best;, score=0.771 total time=   0.0s
[CV 4/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=2, min_samples_split=2, splitter=best;, score=0.791 total time=   0.0s
[CV 5/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=2, min_samples_split=2, splitter=best;, score=0.779 total time=   0.0s
[CV 1/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=2, min_samples_split=2, splitter=random;, score=0.730 total time=   0.0s
[CV 2/5] END ccp_alpha=0.1, criterion=gini, max_depth=1, max_leaf_nodes=

KeyboardInterrupt: 

In [None]:
dtf = grid.best_estimator_

dtf.fit(X_train, y_train)

y_pred = dtf.predict(X_test)
print(classification_report(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))

In [None]:
# Creating data visualizations DT
DTC = DecisionTreeClassifier(criterion='log_loss', ccp_alpha=0.001, splitter='best', max_leaf_nodes=48, min_samples_split=4,max_depth=6)
dt_model = DTC.fit(X_train, y_train)

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(DTC, 
                   feature_names=strokedata.columns,  
                   class_names=target,
                   filled=True)