Initial implementation of the end-to-end autotrain module (#1219)
ANarayan committed Jul 14, 2021
1 parent a2a076a commit c3fffea
Showing 13 changed files with 1,007 additions and 70 deletions.
73 changes: 49 additions & 24 deletions ludwig/api.py
@@ -36,7 +36,7 @@

from ludwig.backend import Backend, initialize_backend
from ludwig.callbacks import Callback
-from ludwig.constants import FULL, PREPROCESSING, TEST, TRAINING, VALIDATION
+from ludwig.constants import FULL, PREPROCESSING, TEST, TRAINING, VALIDATION, LEARNING_RATE, BATCH_SIZE, AUTO
from ludwig.data.dataset.base import Dataset
from ludwig.data.postprocessing import convert_predictions, postprocess
from ludwig.data.preprocessing import (load_metadata,
@@ -347,12 +347,12 @@ def train(
# if we are skipping all saving,
# there is no need to create a directory that will remain empty
should_create_output_directory = not (
-skip_save_training_description and
-skip_save_training_statistics and
-skip_save_model and
-skip_save_progress and
-skip_save_log and
-skip_save_processed_input
+skip_save_training_description and
+skip_save_training_statistics and
+skip_save_model and
+skip_save_progress and
+skip_save_log and
+skip_save_processed_input
)

output_url = output_directory
@@ -375,7 +375,8 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
output_directory)

if isinstance(training_set, Dataset) and training_set_metadata is not None:
-preprocessed_data = (training_set, validation_set, test_set, training_set_metadata)
+preprocessed_data = (
+    training_set, validation_set, test_set, training_set_metadata)
else:
# save description
if self.backend.is_coordinator():
@@ -394,10 +394,12 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
# print description
logger.info('Experiment name: {}'.format(experiment_name))
logger.info('Model name: {}'.format(model_name))
-logger.info('Output directory: {}'.format(output_directory))
+logger.info(
+    'Output directory: {}'.format(output_directory))
logger.info('\n')
for key, value in description.items():
-logger.info('{}: {}'.format(key, pformat(value, indent=4)))
+logger.info('{}: {}'.format(
+    key, pformat(value, indent=4)))
logger.info('\n')

preprocessed_data = self.preprocess(
@@ -431,7 +434,8 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
if self.backend.is_coordinator():
logger.info('Training set: {0}'.format(len(training_set)))
if validation_set is not None:
-logger.info('Validation set: {0}'.format(len(validation_set)))
+logger.info('Validation set: {0}'.format(
+    len(validation_set)))
if test_set is not None:
logger.info('Test set: {0}'.format(len(test_set)))
if not skip_save_model:
@@ -486,6 +490,26 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
config_fp=self.config_fp,
)

+# auto tune batch size
+if self.config[TRAINING][BATCH_SIZE] == AUTO:
+    # TODO (ASN): add support for substitute_with_max parameter
+    tuned_batch_size = trainer.tune_batch_size(
+        self.config,
+        training_set,
+        random_seed=random_seed
+    )
+    self.config[TRAINING][BATCH_SIZE] = tuned_batch_size
+
+# auto tune learning rate
+if self.config[TRAINING][LEARNING_RATE] == AUTO:
+    new_learning_rate = trainer.tune_learning_rate(
+        self.config,
+        LudwigModel.create_model(self.config, random_seed),
+        training_set,
+        random_seed=random_seed
+    )
+    self.config[TRAINING][LEARNING_RATE] = new_learning_rate

# train model
if self.backend.is_coordinator():
print_boxed('TRAINING')
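
For context, here is a hedged sketch of the kind of user config the hunk above responds to, assuming the new AUTO constant resolves to the string 'auto' (the feature definitions below are illustrative, not part of this commit):

from ludwig.api import LudwigModel

config = {
    'input_features': [{'name': 'doc_text', 'type': 'text'}],
    'output_features': [{'name': 'label', 'type': 'category'}],
    'training': {
        'batch_size': 'auto',     # picked up by trainer.tune_batch_size(...)
        'learning_rate': 'auto',  # picked up by trainer.tune_learning_rate(...)
    },
}

model = LudwigModel(config)
# model.train(dataset='my_dataset.csv')  # tuning happens inside train()
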
@@ -522,7 +546,8 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
# results of the model with highest validation test performance
if self.backend.is_coordinator() and validation_set is not None:
epoch_best_vali_metric, best_vali_metric = best_function(
-enumerate(validation_field_result[validation_metric]),
+enumerate(
+    validation_field_result[validation_metric]),
key=lambda pair: pair[1]
)
logger.info(
@@ -717,7 +742,7 @@ def predict(
# if we are skipping all saving,
# there is no need to create a directory that will remain empty
should_create_exp_dir = not (
-skip_save_unprocessed_output and skip_save_predictions
+skip_save_unprocessed_output and skip_save_predictions
)
if should_create_exp_dir:
makedirs(output_directory, exist_ok=True)
@@ -730,7 +755,7 @@
output_directory=output_directory,
backend=self.backend,
skip_save_unprocessed_output=skip_save_unprocessed_output
-or not self.backend.is_coordinator(),
+or not self.backend.is_coordinator(),
)
converted_postproc_predictions = convert_predictions(
postproc_predictions,
@@ -869,9 +894,9 @@ def evaluate(
# if we are skipping all saving,
# there is no need to create a directory that will remain empty
should_create_exp_dir = not (
-skip_save_unprocessed_output and
-skip_save_predictions and
-skip_save_eval_stats
+skip_save_unprocessed_output and
+skip_save_predictions and
+skip_save_eval_stats
)
if should_create_exp_dir:
makedirs(output_directory, exist_ok=True)
@@ -885,16 +910,16 @@
output_directory=output_directory,
backend=self.backend,
skip_save_unprocessed_output=skip_save_unprocessed_output
-or not self.backend.is_coordinator(),
+or not self.backend.is_coordinator(),
)
else:
postproc_predictions = predictions # = {}

if self.backend.is_coordinator():
should_save_predictions = (
-collect_predictions
-and postproc_predictions is not None
-and not skip_save_predictions
+collect_predictions
+and postproc_predictions is not None
+and not skip_save_predictions
)
if should_save_predictions:
save_prediction_outputs(
@@ -1295,7 +1320,7 @@ def preprocess(
training_set_metadata) = preprocessed_data

return proc_training_set, proc_validation_set, proc_test_set, \
-training_set_metadata
+training_set_metadata

@staticmethod
def load(
@@ -1357,7 +1382,7 @@ def load(
model_dir,
MODEL_HYPERPARAMETERS_FILE_NAME
)
-))
+))

if backend_param is None and 'backend' in config:
# Reset backend from config
@@ -1707,7 +1732,7 @@ def kfold_cross_validate(
else:
raise ValueError(
"{} format is not supported for k_fold_cross_validate()"
-.format(data_format)
+.format(data_format)
)

kfold_cv_stats = {}
Empty file added ludwig/automl/__init__.py
Empty file.
108 changes: 108 additions & 0 deletions ludwig/automl/automl.py
@@ -0,0 +1,108 @@
"""
automl.py
Driver script which:
(1) Builds a base config by performing type inference and populating the
config with default combiner parameters, training parameters, and a
hyperopt search space
(2) Tunes config based on resource constraints
(3) Runs hyperparameter optimization experiment
"""
import warnings
from typing import Dict, Optional, Union

import numpy as np
import pandas as pd

from ludwig.automl.base_config import _create_default_config
from ludwig.constants import COMBINER, TYPE
from ludwig.hyperopt.run import hyperopt

try:
import dask.dataframe as dd
import ray
except ImportError:
raise ImportError(
'ray is not installed. '
'In order to use auto_train, please run: '
'pip install ludwig[ray]'
)


OUTPUT_DIR = "."


def _model_select(default_configs):
"""
Performs model selection based on the dataset.
Note: The current implementation returns the tabnet config by default;
this will be improved in subsequent iterations.
"""
return default_configs['tabnet']
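# `default_configs` is presumably a dict of candidate configs keyed by model
# type, e.g. {'concat': {...}, 'tabnet': {...}, 'transformer': {...}};
# the exact shape is an assumption (it is built by _create_default_config).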


def auto_train(
dataset: Union[str, pd.DataFrame, dd.core.DataFrame],
target: str,
time_limit_s: Union[int, float],
output_dir: str = OUTPUT_DIR,
config: Optional[Dict] = None,
):
"""
Main auto train API that first builds configs for each model type
(e.g. concat, tabnet, transformer). Then selects model based on dataset
attributes. And finally runs a hyperparameter optimization experiment.
All batch and learning rate tuning is done @ training time.
# Inputs
:param dataset: (str) filepath to dataset.
:param target_name: (str) name of target feature
:param time_limit_s: (int, float) total time allocated to auto_train. acts
as the stopping parameter
# Returns
:return: (str) path to best trained model
"""
if config is None:
config = create_auto_config(dataset, target, time_limit_s)
model_name = config[COMBINER][TYPE]
hyperopt_results = _train(config, dataset,
output_dir, model_name=model_name)
experiment_analysis = hyperopt_results.experiment_analysis
# catch edge case where metric_score is nan
# TODO (ASN): Decide how we want to proceed if at least one trial has
# completed
for trial in hyperopt_results.ordered_trials:
if np.isnan(trial.metric_score):
warnings.warn(
"There was an error running the experiment. "
"A trial failed to start. "
"Consider increasing the time budget for experiment. "
)

autotrain_results = {
'path_to_best_model': experiment_analysis.best_checkpoint,
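# the trial id is recovered from the best trial's logdir name: everything
# after the first '_' in the final path component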
'trial_id': "_".join(experiment_analysis.best_logdir.split("/")[-1].split("_")[1:])
}
return autotrain_results


def create_auto_config(dataset, target, time_limit_s) -> dict:
default_configs = _create_default_config(dataset, target, time_limit_s)
model_config = _model_select(default_configs)
return model_config
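# A possible usage pattern (values illustrative): build the config up front,
# inspect or edit it, then pass it to auto_train to skip re-generation:
#   config = create_auto_config('data.csv', 'label', time_limit_s=3600)
#   results = auto_train('data.csv', 'label', 3600, config=config)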


def _train(
config: Dict,
dataset: Union[str, pd.DataFrame, dd.core.DataFrame],
output_dir: str,
model_name: str
):
hyperopt_results = hyperopt(
config,
dataset=dataset,
output_directory=output_dir,
model_name=model_name
)
return hyperopt_results
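
For reference, a minimal usage sketch of the new module (the dataset path and target column below are hypothetical; the import path follows this commit's file layout):

from ludwig.automl.automl import auto_train

# runs config generation, model selection, and hyperopt end to end
results = auto_train(
    dataset='path/to/data.csv',  # hypothetical dataset path
    target='label',              # hypothetical target column
    time_limit_s=3600,           # stop the experiment after one hour
)
print(results['path_to_best_model'])
print(results['trial_id'])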
