In [58]:
# Import necessary libraries

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib
import seaborn as sns

In [59]:


# Load the data
data = pd.read_csv('stroke_data.csv')

data.head(10)


Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1
5,56669,Male,81.0,0,0,Yes,Private,Urban,186.21,29.0,formerly smoked,1
6,53882,Male,74.0,1,1,Yes,Private,Rural,70.09,27.4,never smoked,1
7,10434,Female,69.0,0,0,No,Private,Urban,94.39,22.8,never smoked,1
8,27419,Female,59.0,0,0,Yes,Private,Rural,76.15,,Unknown,1
9,60491,Female,78.0,0,0,Yes,Private,Urban,58.57,24.2,Unknown,1


In [60]:
data.describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
id,5110.0,36517.829354,21161.721625,67.0,17741.25,36932.0,54682.0,72940.0
age,5110.0,43.226614,22.612647,0.08,25.0,45.0,61.0,82.0
hypertension,5110.0,0.097456,0.296607,0.0,0.0,0.0,0.0,1.0
heart_disease,5110.0,0.054012,0.226063,0.0,0.0,0.0,0.0,1.0
avg_glucose_level,5110.0,106.147677,45.28356,55.12,77.245,91.885,114.09,271.74
bmi,4909.0,28.893237,7.854067,10.3,23.5,28.1,33.1,97.6
stroke,5110.0,0.048728,0.21532,0.0,0.0,0.0,0.0,1.0


In [61]:
# Quick view of distribution of target: stroke
data.groupby("stroke")["stroke"].count()

stroke
0    4861
1     249
Name: stroke, dtype: int64

In [62]:
# Quick view of distribution of feature: gender
data.groupby("gender")["gender"].count()


gender
Female    2994
Male      2115
Other        1
Name: gender, dtype: int64

In [63]:
# Quick view of distribution of feature: gender
data.groupby("Residence_type")["Residence_type"].count()


Residence_type
Rural    2514
Urban    2596
Name: Residence_type, dtype: int64

In [64]:
# Quick view of distribution of feature: gender
data.groupby("smoking_status")["smoking_status"].count()

smoking_status
Unknown            1544
formerly smoked     885
never smoked       1892
smokes              789
Name: smoking_status, dtype: int64

In [65]:
# Quick view of distribution of feature: gender
data.groupby("work_type")["work_type"].count()

work_type
Govt_job          657
Never_worked       22
Private          2925
Self-employed     819
children          687
Name: work_type, dtype: int64

In [66]:

# Preprocessing

data = data.dropna()

# Replace 'Other' gender with 'Female'
data = data[data.gender != "Other"]


# Gender, Residence_type, and ever_married are binary variables that must be converted to the 0/1 format
data["gender"].replace({"Male": 0, "Female": 1}, inplace = True)
data["Residence_type"].replace({"Urban": 0, "Rural": 1}, inplace = True)
data["ever_married"].replace({"No": 0, "Yes": 1}, inplace = True)



# Remove redundant columns
data = data.drop(["id", "work_type", "smoking_status"], axis = 1)

data.head()


Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,Residence_type,avg_glucose_level,bmi,stroke
0,0,67.0,0,1,1,0,228.69,36.6,1
2,0,80.0,0,1,1,1,105.92,32.5,1
3,1,49.0,0,0,1,0,171.23,34.4,1
4,1,79.0,1,0,1,1,174.12,24.0,1
5,0,81.0,0,0,1,0,186.21,29.0,1


In [43]:
data.describe().T

Unnamed: 0,count,mean,std,min,25%,50%,75%,max
gender,4909.0,0.590344,0.49182,0.0,0.0,1.0,1.0,1.0
age,4909.0,42.865374,22.555115,0.08,25.0,44.0,60.0,82.0
hypertension,4909.0,0.091872,0.288875,0.0,0.0,0.0,0.0,1.0
heart_disease,4909.0,0.049501,0.216934,0.0,0.0,0.0,0.0,1.0
ever_married,4909.0,0.652679,0.476167,0.0,0.0,1.0,1.0,1.0
Residence_type,4909.0,0.492768,0.499999,0.0,0.0,0.0,1.0,1.0
avg_glucose_level,4909.0,105.30515,44.424341,55.12,77.07,91.68,113.57,271.74
bmi,4909.0,28.893237,7.854067,10.3,23.5,28.1,33.1,97.6
stroke,4909.0,0.042575,0.201917,0.0,0.0,0.0,0.0,1.0


In [94]:
# Split the data into training and testing sets
X = data.drop('stroke', axis=1)
y = data['stroke']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the decision tree classifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# Evaluate the model
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

# Print the evaluation metrics
print("Accuracy:", accuracy)
print("Classification Report:\n", class_report)

Accuracy: 0.9032586558044806
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.94      0.95       929
           1       0.16      0.19      0.17        53

    accuracy                           0.90       982
   macro avg       0.56      0.57      0.56       982
weighted avg       0.91      0.90      0.91       982



In [76]:
from sklearn.tree import export_graphviz
import graphviz

# Export the decision tree to a Graphviz format
dot_data = export_graphviz(clf, out_file=None, 
                            feature_names=data.columns[:-1], 
                            class_names=['No Stroke', 'Stroke'],  
                            filled=True, rounded=True,  
                            special_characters=True)  


graph = graphviz.Source(dot_data)
graph.render('decision_tree')

'decision_tree.pdf'

In [90]:
# [gender,	age,	hypertension,	heart_disease,	ever_married,	Residence_type,	avg_glucose_level,	bmi]

case_number_1 = [0,	70, 1,	0,	1,	1,	250.69,	40]
case_number_2 = [0,	50, 0,	0,	0,	0,	280.69,	45]


prediction = clf.predict([case_number_2])
print(prediction)


if prediction == 0:
  print("No stroke")
else:
  print("stroke!!!!")





[0]
No stroke


