Skip to content

Commit

Permalink
Update automl.py
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Apr 2, 2021
1 parent 23beefe commit ba571ac
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions supervised/automl.py
Expand Up @@ -59,7 +59,7 @@ def __init__(
total_time_limit (int): The total time limit in seconds for AutoML training.
It is not used when `model_time_limit` is not `None`.
mode (str): Can be {`Explain`, `Perform`, `Compete`}. This parameter defines the goal of AutoML and how intensive the AutoML search will be.
mode (str): Can be {`Explain`, `Perform`, `Compete`, `Optuna`}. This parameter defines the goal of AutoML and how intensive the AutoML search will be.
- `Explain` : To to be used when the user wants to explain and understand the data.
- Uses 75%/25% train/test split.
Expand All @@ -70,10 +70,15 @@ def __init__(
- Uses the following models: `Linear`, `Random Forest`, `LightGBM`, `XGBoost`, `CatBoost`, `Neural Network`, and `Ensemble`.
- Has learning curves and importance plots in reports.
- `Compete` : To be used for machine learning competitions (maximum performance).
- Uses 10-fold CV (Cross-Validation).
- Uses the following models: `Decision Tree`, `Random Forest`, `Extra Trees`, `XGBoost`, `CatBoost`, `Neural Network`,
- Uses 80/20 train/test split, or 5-fold CV, or 10-fold CV (Cross-Validation) - it depends on `total_time_limit`. If not set directly, AutoML will select validation automatically.
- Uses the following models: `Decision Tree`, `Random Forest`, `Extra Trees`, `LightGBM`, `XGBoost`, `CatBoost`, `Neural Network`,
`Nearest Neighbors`, `Ensemble`, and `Stacking`.
- It has only learning curves in the reports.
- `Optuna` : To be used for creating highly-tuned machine learning models.
- Uses 10-fold CV (Cross-Validation).
- It tunes with Optuna the following algorithms: `Random Forest`, `Extra Trees`, `LightGBM`, `XGBoost`, `CatBoost`, `Neural Network`.
- It applies `Ensemble` and `Stacking` for trained models.
- It has only learning curves in the reports.
ml_task (str): Can be {"auto", "binary_classification", "multiclass_classification", "regression"}.
Expand Down

0 comments on commit ba571ac

Please sign in to comment.