In [12]:
import pandas as pd
import numpy as np
import dill

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.metrics import roc_auc_score, precision_recall_curve

Ссылка на датасет: https://www.kaggle.com/nareshbhat/health-care-data-set-on-heart-attack-possibility

Необходимо предсказать, есть ли у человека риск сердечного заболевания.

Описание признаков:
1. age - возраст
2. sex - пол
3. cp - тип боли в груди (4 варианта ответа)
4. trestbps - артериальное давление в состоянии покоя
5. chol - уровень холестерина мг/дл
6. fbs - уровень сахара в крови натощак > 120 мг/дл
7. restecg - результаты электрокардиографии в покое (значения 0,1,2)
8. thalach - достигнутая максимальная частота пульса
9. exang - стенокардия, вызванная физической нагрузкой
10. oldpeak - депрессия сегмента ST, вызванная физической нагрузкой, по сравнению с состоянием покоя
11. slope - наклон сегмента ST при пиковой нагрузке
12. ca - количество крупных сосудов (0-3), окрашенных при флюороскопии
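13. thal - 0 = нормально; 1 = фиксированный (необратимый) дефект; 2 = обратимый дефект (в данных встречается также значение 3, поэтому соответствие значений стоит уточнить по источнику датасета)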
14. target - 1 - высокий шанс сердечного приступа; 0 - низкий шанс сердечного приступа

In [3]:
# Load the raw dataset; the CSV is expected in the current working directory.
df_path = "heart.csv"
df = pd.read_csv(filepath_or_buffer=df_path)

In [6]:
target = 'target'

# Hold out 25% of rows for testing; stratify on the label so both splits keep
# the same class balance, and fix the seed for a reproducible split.
X_train, X_test, y_train, y_test = train_test_split(
    df.drop(columns=[target]),
    df[target],
    train_size=0.75,
    random_state=42,
    stratify=df[target],
)

In [7]:
def get_metrics(y_true, y_pred):
    """Find the F1-optimal threshold and report metrics at that point.

    Parameters
    ----------
    y_true : array-like
        Binary ground-truth labels.
    y_pred : array-like
        Continuous scores for the positive class (e.g. ``predict_proba[:, 1]``),
        not hard 0/1 predictions.

    Returns
    -------
    tuple
        ``(fscore, precision, recall, roc_auc, threshold)``. ``fscore``,
        ``precision``, ``recall`` and ``threshold`` describe the point of the
        precision-recall curve with the highest F1; ``roc_auc`` is computed
        over all thresholds.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

    # precision/recall carry one extra trailing point (precision=1, recall=0)
    # that has no matching threshold; drop it so thresholds[ix] is always valid.
    precision, recall = precision[:-1], recall[:-1]

    # F1 is 0/0 when precision + recall == 0. Define it as 0 there: a NaN would
    # win np.argmax (it returns the first NaN) and yield a bogus "best" point.
    denom = precision + recall
    fscore = np.divide(
        2 * precision * recall,
        denom,
        out=np.zeros_like(denom, dtype=float),
        where=denom > 0,
    )
    ix = np.argmax(fscore)

    return fscore[ix], precision[ix], recall[ix], roc_auc_score(y_true, y_pred), thresholds[ix]

In [9]:
# Fit a random forest with a fixed seed; fit() returns the estimator itself.
model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

In [13]:
# Serialize the fitted model to disk with dill.
with open("heart_attack_random_forest.dill", mode="wb") as f:
    dill.dump(model, file=f)

Проверка модели

In [14]:
# Reload the model from disk to confirm the saved artifact is usable.
# NOTE: dill.load (like pickle) can execute arbitrary code; only load files
# you trust, such as the one written by this notebook.
with open('heart_attack_random_forest.dill', mode='rb') as in_strm:
    model = dill.load(file=in_strm)

In [15]:
# Column 1 of predict_proba is the probability of model.classes_[1]
# (presumably class 1 of the 0/1 target); get_metrics needs scores, not labels.
preds = model.predict_proba(X_test)[:, 1]
metrics = get_metrics(y_true=y_test, y_pred=preds)

In [16]:
# Round to two decimals for display. This replaces the tuple in `metrics`
# with a rounded ndarray (re-running the cell is harmless).
metrics = np.round(np.array(metrics), 2)
print(
    "fscore: {}; precision: {}; recall: {}; roc_auc: {}; threshold: {}".format(*metrics)
)

fscore: 0.84; precision: 0.74; recall: 0.98; roc_auc: 0.88; threshold: 0.4


In [19]:
# Inspect `preds`: as computed in the evaluation cell, the positive-class
# probabilities on X_test.
preds

array([0.17, 0.01, 0.79, 0.69, 0.68, 0.01, 0.73, 0.64, 0.26, 0.14, 0.36,
       0.99, 0.98, 0.83, 0.21, 0.89, 0.83, 0.59, 0.84, 0.4 , 0.71, 0.26,
       0.71, 0.99, 0.81, 0.56, 0.11, 0.71, 0.86, 0.47, 0.67, 0.36, 0.42,
       0.28, 0.27, 0.78, 0.91, 0.46, 0.55, 0.91, 1.  , 0.87, 0.03, 0.45,
       0.61, 0.88, 0.05, 0.74, 0.46, 0.91, 0.73, 0.39, 0.99, 0.02, 0.42,
       0.6 , 0.13, 0.65, 0.04, 0.89, 0.4 , 0.91, 0.97, 0.07, 0.85, 0.72,
       0.87, 0.27, 0.18, 0.72, 0.27, 0.51, 0.83, 0.76, 0.83, 0.61])

In [20]:
# Full class-probability matrix (one column per entry of model.classes_).
# Kept under its own name so it does not overwrite `preds`, the 1-D
# positive-class scores used for the metrics; reusing the name would make
# earlier cells behave differently when re-run.
pred_proba = model.predict_proba(X_test)

In [26]:
# Display the loaded dataset (long frames are truncated in the rendered output).
df

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,63,1,3,145,233,1,0,150,0,2.3,0,0,1,1
1,37,1,2,130,250,0,1,187,0,3.5,0,0,2,1
2,41,0,1,130,204,0,0,172,0,1.4,2,0,2,1
3,56,1,1,120,236,0,1,178,0,0.8,2,0,2,1
4,57,0,0,120,354,0,1,163,1,0.6,2,0,2,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
298,57,0,0,140,241,0,1,123,1,0.2,1,0,3,0
299,45,1,3,110,264,0,1,132,0,1.2,1,0,3,0
300,68,1,0,144,193,1,1,141,0,3.4,1,2,3,0
301,57,1,0,130,131,0,1,115,1,1.2,1,1,3,0


In [27]:
# Display the training features; the target column was dropped before the split.
X_train

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal
66,51,1,2,100,222,0,1,143,1,1.2,1,0,2
260,66,0,0,178,228,1,1,165,1,1.0,1,2,3
289,55,0,0,128,205,0,2,130,1,2.0,1,1,3
237,60,1,0,140,293,0,0,170,0,1.2,1,2,3
144,76,0,2,140,197,0,2,116,0,1.1,1,0,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...
170,56,1,2,130,256,1,0,142,1,0.6,1,1,1
60,71,0,2,110,265,1,0,130,0,0.0,2,1,2
128,52,0,2,136,196,0,0,169,0,0.1,1,0,2
53,44,0,2,108,141,0,1,175,0,0.6,1,0,2
