In [1]:
# Import necessary libraries
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import matplotlib


In [5]:


# Load the data
data = pd.read_csv('stroke_data.csv')

# Preprocessing
# Drop the ID column as it does not contribute to the prediction
data.drop('id', axis=1, inplace=True)

# Convert categorical features to numerical using one-hot encoding
data = pd.get_dummies(data, columns=['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status'], drop_first=True)



# Replace NaN values with the mean value of that column
data.fillna(data.mean(), inplace=True)

print(data.head())



    age  hypertension  heart_disease  avg_glucose_level        bmi  stroke  \
0  67.0             0              1             228.69  36.600000       1   
1  61.0             0              0             202.21  28.893237       1   
2  80.0             0              1             105.92  32.500000       1   
3  49.0             0              0             171.23  34.400000       1   
4  79.0             1              0             174.12  24.000000       1   

   gender_Male  gender_Other  ever_married_Yes  work_type_Never_worked  \
0            1             0                 1                       0   
1            0             0                 1                       0   
2            1             0                 1                       0   
3            0             0                 1                       0   
4            0             0                 1                       0   

   work_type_Private  work_type_Self-employed  work_type_children  \
0                

In [6]:
# Check for any NaN or infinity values in the data
print(data.isnull().values.any())
print(np.isfinite(data).all())

False
age                               True
hypertension                      True
heart_disease                     True
avg_glucose_level                 True
bmi                               True
stroke                            True
gender_Male                       True
gender_Other                      True
ever_married_Yes                  True
work_type_Never_worked            True
work_type_Private                 True
work_type_Self-employed           True
work_type_children                True
Residence_type_Urban              True
smoking_status_formerly smoked    True
smoking_status_never smoked       True
smoking_status_smokes             True
dtype: bool


In [3]:

# Split the data into training and testing sets
X = data.drop('stroke', axis=1)
y = data['stroke']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the decision tree classifier
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

# Evaluate the model
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

# Print the evaluation metrics
print("Accuracy:", accuracy)
print("Confusion Matrix:\n", conf_matrix)
print("Classification Report:\n", class_report)



Accuracy: 0.9060665362035225
Confusion Matrix:
 [[912  48]
 [ 48  14]]
Classification Report:
               precision    recall  f1-score   support

           0       0.95      0.95      0.95       960
           1       0.23      0.23      0.23        62

    accuracy                           0.91      1022
   macro avg       0.59      0.59      0.59      1022
weighted avg       0.91      0.91      0.91      1022



In [9]:
from sklearn.tree import export_graphviz
import graphviz

# Export the decision tree to a Graphviz format
dot_data = export_graphviz(clf, out_file=None, 
                            feature_names=data.columns[:-1], 
                            class_names=['No Stroke', 'Stroke'],  
                            filled=True, rounded=True,  
                            special_characters=True)  


graph = graphviz.Source(dot_data)
graph.render('decision_tree')

ModuleNotFoundError: No module named 'graphviz'