In [1]:
import pandas as pd

from autogluon.tabular import TabularDataset, TabularPredictor
from sklearn.model_selection import train_test_split

In [2]:
df = pd.read_csv('../../data/strokes/stroke-gc.csv')

In [3]:
df.head()

Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,Male,67,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,Female,61,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,Male,80,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,Female,49,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,Female,79,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1


In [4]:
df_train, df_test=train_test_split(df,test_size=0.4, random_state=42)
df_train.shape, df_test.shape

((5833, 11), (3889, 11))

In [5]:
predictor= TabularPredictor(
    label='stroke',
    problem_type='binary',
    eval_metric='accuracy',
    path='AutogluonModels/strokes/synthetic/gc',
    learner_kwargs = {'positive_class':1}
    )

In [6]:
predictor.fit(
    train_data = df_train,
    verbosity = 2,
    presets=['best_quality'],
    time_limit = 350
)

Verbosity: 2 (Standard Logging)
AutoGluon Version:  1.1.1
Python Version:     3.10.2
Operating System:   Windows
Platform Machine:   AMD64
Platform Version:   10.0.26100
CPU Count:          20
Memory Avail:       8.10 GB / 15.71 GB (51.5%)
Disk Space Avail:   664.97 GB / 933.46 GB (71.2%)
Presets specified: ['best_quality']
Setting dynamic_stacking from 'auto' to True. Reason: Enable dynamic_stacking when use_bag_holdout is disabled. (use_bag_holdout=False)
Stack configuration (auto_stack=True): num_stack_levels=1, num_bag_folds=8, num_bag_sets=1
DyStack is enabled (dynamic_stacking=True). AutoGluon will try to determine whether the input data is affected by stacked overfitting and enable or disable stacking as a consequence.
	This is used to identify the optimal `num_stack_levels` value. Copies of AutoGluon will be fit on subsets of the data. Then holdout validation data is used to detect stacked overfitting.
	Running DyStack for up to 87s of the 350s of remaining time (25%).
	Running

<autogluon.tabular.predictor.predictor.TabularPredictor at 0x29098656bc0>

In [7]:
predictor.fit_summary()

*** Summary of fit() ***
Estimated performance of each model:
                      model  score_val eval_metric  pred_time_val    fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0            XGBoost_BAG_L2   0.786388    accuracy       2.478624  130.806585                0.076924           3.819755            2       True         25
1       WeightedEnsemble_L3   0.786388    accuracy       2.478624  131.211007                0.000000           0.404422            3       True         26
2           CatBoost_BAG_L2   0.783816    accuracy       2.439869  160.374412                0.038169          33.387582            2       True         21
3           LightGBM_BAG_L2   0.782616    accuracy       2.497420  130.061250                0.095720           3.074420            2       True         18
4         LightGBMXT_BAG_L2   0.782102    accuracy       2.450650  129.544280                0.048950           2.557450            2       True         17
5 



{'model_types': {'KNeighborsUnif_BAG_L1': 'StackerEnsembleModel_KNN',
  'KNeighborsDist_BAG_L1': 'StackerEnsembleModel_KNN',
  'LightGBMXT_BAG_L1': 'StackerEnsembleModel_LGB',
  'LightGBM_BAG_L1': 'StackerEnsembleModel_LGB',
  'RandomForestGini_BAG_L1': 'StackerEnsembleModel_RF',
  'RandomForestEntr_BAG_L1': 'StackerEnsembleModel_RF',
  'CatBoost_BAG_L1': 'StackerEnsembleModel_CatBoost',
  'ExtraTreesGini_BAG_L1': 'StackerEnsembleModel_XT',
  'ExtraTreesEntr_BAG_L1': 'StackerEnsembleModel_XT',
  'NeuralNetFastAI_BAG_L1': 'StackerEnsembleModel_NNFastAiTabular',
  'XGBoost_BAG_L1': 'StackerEnsembleModel_XGBoost',
  'NeuralNetTorch_BAG_L1': 'StackerEnsembleModel_TabularNeuralNetTorch',
  'LightGBMLarge_BAG_L1': 'StackerEnsembleModel_LGB',
  'CatBoost_r177_BAG_L1': 'StackerEnsembleModel_CatBoost',
  'LightGBM_r131_BAG_L1': 'StackerEnsembleModel_LGB',
  'WeightedEnsemble_L2': 'WeightedEnsembleModel',
  'LightGBMXT_BAG_L2': 'StackerEnsembleModel_LGB',
  'LightGBM_BAG_L2': 'StackerEnsembleMod

In [8]:
predictor.evaluate(df_test)

{'accuracy': 0.7732064798148625,
 'balanced_accuracy': 0.7729109656446627,
 'mcc': 0.5462864693514804,
 'roc_auc': 0.8532771220500188,
 'f1': 0.7660477453580902,
 'precision': 0.7771797631862217,
 'recall': 0.7552301255230126}

In [9]:
predictor.leaderboard(df_train, silent=True)

Unnamed: 0,model,score_test,score_val,eval_metric,pred_time_test,pred_time_val,fit_time,pred_time_test_marginal,pred_time_val_marginal,fit_time_marginal,stack_level,can_infer,fit_order
0,KNeighborsDist_BAG_L1,1.0,0.667581,accuracy,0.039738,0.043153,0.00963,0.039738,0.043153,0.00963,1,True,2
1,RandomForestGini_BAG_L1,1.0,0.756215,accuracy,0.20253,0.303263,0.92706,0.20253,0.303263,0.92706,1,True,5
2,RandomForestEntr_BAG_L1,1.0,0.754329,accuracy,0.21525,0.338248,0.891416,0.21525,0.338248,0.891416,1,True,6
3,ExtraTreesEntr_BAG_L1,1.0,0.7449,accuracy,0.241079,0.280172,0.809603,0.241079,0.280172,0.809603,1,True,9
4,ExtraTreesGini_BAG_L1,1.0,0.740271,accuracy,0.248679,0.285701,0.914918,0.248679,0.285701,0.914918,1,True,8
5,LightGBMLarge_BAG_L1,0.972227,0.762558,accuracy,0.286237,0.139235,3.741095,0.286237,0.139235,3.741095,1,True,13
6,ExtraTreesGini_BAG_L2,0.875536,0.770273,accuracy,4.121466,2.690228,127.844404,0.25568,0.288528,0.857574,2,True,22
7,XGBoost_BAG_L1,0.87485,0.766158,accuracy,0.328042,0.089384,2.88898,0.328042,0.089384,2.88898,1,True,11
8,ExtraTreesEntr_BAG_L2,0.873821,0.76993,accuracy,4.088714,2.691895,127.832425,0.222928,0.290194,0.845595,2,True,23
9,LightGBM_r131_BAG_L1,0.860449,0.762729,accuracy,0.330314,0.197407,3.365595,0.330314,0.197407,3.365595,1,True,15


In [10]:
predictor.feature_importance(data=df_train)

Computing feature importance via permutation shuffling for 10 features using 5000 rows with 5 shuffle sets...
	224.33s	= Expected runtime (44.87s per shuffle set)
	128.75s	= Actual runtime (Completed 5 of 5 shuffle sets)


Unnamed: 0,importance,stddev,p_value,n,p99_high,p99_low
age,0.19516,0.004251,2.698755e-08,5,0.203912,0.186408
avg_glucose_level,0.13516,0.00639,5.976015e-07,5,0.148316,0.122004
work_type,0.12256,0.006511,9.520062e-07,5,0.135965,0.109155
ever_married,0.0972,0.005963,1.691466e-06,5,0.109478,0.084922
bmi,0.06564,0.003025,5.394247e-07,5,0.071868,0.059412
smoking_status,0.02828,0.004144,5.377412e-05,5,0.036812,0.019748
hypertension,0.0216,0.002585,2.413646e-05,5,0.026922,0.016278
heart_disease,0.02028,0.000955,5.883235e-07,5,0.022246,0.018314
gender,0.01128,0.001753,6.775392e-05,5,0.014889,0.007671
Residence_type,0.00272,0.000657,0.0003791366,5,0.004073,0.001367
