<a href="https://colab.research.google.com/github/marco10507/ml-portfolio/blob/main/stroke_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import gdown
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import classification_report

file_id = '1AbKB9-FUmOwJpAipx8ij4V60eLvPRCKr'
download_url = f'https://drive.google.com/uc?id={file_id}'
gdown.download(download_url, 'healthcare-dataset-stroke-data.csv', quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1AbKB9-FUmOwJpAipx8ij4V60eLvPRCKr
To: /content/healthcare-dataset-stroke-data.csv
100%|██████████| 317k/317k [00:00<00:00, 70.5MB/s]


'healthcare-dataset-stroke-data.csv'

In [20]:
import pandas as pd
stroke_data = pd.read_csv('healthcare-dataset-stroke-data.csv');
stroke_data.info()
stroke_data.describe()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 5110 non-null   int64  
 1   gender             5110 non-null   object 
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64  
 4   heart_disease      5110 non-null   int64  
 5   ever_married       5110 non-null   object 
 6   work_type          5110 non-null   object 
 7   Residence_type     5110 non-null   object 
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object 
 11  stroke             5110 non-null   int64  
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB


Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
count,5110.0,5110.0,5110.0,5110.0,5110.0,4909.0,5110.0
mean,36517.829354,43.226614,0.097456,0.054012,106.147677,28.893237,0.048728
std,21161.721625,22.612647,0.296607,0.226063,45.28356,7.854067,0.21532
min,67.0,0.08,0.0,0.0,55.12,10.3,0.0
25%,17741.25,25.0,0.0,0.0,77.245,23.5,0.0
50%,36932.0,45.0,0.0,0.0,91.885,28.1,0.0
75%,54682.0,61.0,0.0,0.0,114.09,33.1,0.0
max,72940.0,82.0,1.0,1.0,271.74,97.6,1.0


In [21]:
print("gender unique values:", stroke_data['gender'].unique())
print("ever_married unique values:", stroke_data['ever_married'].unique())
print("work_type unique values:", stroke_data['work_type'].unique())
print("Residence_type unique values:", stroke_data['Residence_type'].unique())
print("smoking_status unique values:", stroke_data['smoking_status'].unique())

gender unique values: ['Male' 'Female' 'Other']
ever_married unique values: ['Yes' 'No']
work_type unique values: ['Private' 'Self-employed' 'Govt_job' 'children' 'Never_worked']
Residence_type unique values: ['Urban' 'Rural']
smoking_status unique values: ['formerly smoked' 'never smoked' 'smokes' 'Unknown']


In [22]:
print(stroke_data['stroke'].value_counts())

stroke
0    4861
1     249
Name: count, dtype: int64


In [23]:
stroke_data = stroke_data.dropna()
stroke_data = stroke_data[stroke_data['smoking_status'] != 'Unknown'];

In [24]:
print(stroke_data['stroke'].value_counts())

stroke
0    3246
1     180
Name: count, dtype: int64


In [25]:
stroke_data = pd.get_dummies(stroke_data, columns=['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status'])

In [26]:
print(stroke_data.columns)

Index(['id', 'age', 'hypertension', 'heart_disease', 'avg_glucose_level',
       'bmi', 'stroke', 'gender_Female', 'gender_Male', 'gender_Other',
       'ever_married_No', 'ever_married_Yes', 'work_type_Govt_job',
       'work_type_Never_worked', 'work_type_Private',
       'work_type_Self-employed', 'work_type_children', 'Residence_type_Rural',
       'Residence_type_Urban', 'smoking_status_formerly smoked',
       'smoking_status_never smoked', 'smoking_status_smokes'],
      dtype='object')


In [27]:
majority_class = stroke_data[stroke_data['stroke'] == 0]
minority_class = stroke_data[stroke_data['stroke'] == 1]

majority_class_sampled = majority_class.sample(n=len(minority_class), random_state=42)

stroke_data_balanced = pd.concat([majority_class_sampled, minority_class])

print(stroke_data_balanced['stroke'].value_counts())

stroke
0    180
1    180
Name: count, dtype: int64


In [28]:
features = stroke_data_balanced.drop(columns=['id', 'stroke'])
target = stroke_data_balanced['stroke'];


X_train, X_test, y_train, y_test = train_test_split(features, target, random_state=42, test_size=0.2)

In [29]:
gbm = GradientBoostingClassifier()

param_grip = {
    "n_estimators": [50, 100, 150],
    "learning_rate": [0.01, 0.1, 0.2],
    "max_depth": [3, 5, 7],
    "subsample" : [0.8, 0.9, 1]
}

search_grid = GridSearchCV(param_grid=param_grip, estimator=gbm, cv=5, n_jobs=-1)
search_grid.fit(X_train, y_train)

In [30]:
print("best params", search_grid.best_params_)
print("best score", search_grid.best_score_)

best params {'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 50, 'subsample': 0.9}
best score 0.7082274652147611


In [31]:
y_pred = search_grid.predict(X_test)

print(classification_report(y_test, y_pred))


feature_importance_df = pd.DataFrame({
    'Feature': X_test.columns,
    'Importance': search_grid.best_estimator_.feature_importances_
})

feature_importance_df = feature_importance_df.sort_values(by='Importance', ascending=False)


print(feature_importance_df)


              precision    recall  f1-score   support

           0       0.75      0.57      0.65        37
           1       0.64      0.80      0.71        35

    accuracy                           0.68        72
   macro avg       0.69      0.68      0.68        72
weighted avg       0.69      0.68      0.68        72

                           Feature  Importance
0                              age    0.453144
3                avg_glucose_level    0.209076
4                              bmi    0.163345
1                     hypertension    0.024607
19           smoking_status_smokes    0.024089
18     smoking_status_never smoked    0.021497
16            Residence_type_Urban    0.014366
13         work_type_Self-employed    0.013724
9                 ever_married_Yes    0.013404
12               work_type_Private    0.013239
6                      gender_Male    0.011402
2                    heart_disease    0.010073
10              work_type_Govt_job    0.007367
17  smoking_sta