
Fixed automl to work when combiner is not specified (#1293)
tgaddair authored and ShreyaR committed Sep 17, 2021
1 parent ee0f0c6 commit 29cdc9b
Showing 7 changed files with 124 additions and 28 deletions.
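The crash this commit fixes comes from AutoML code paths that indexed the user config directly, e.g. config[COMBINER][TYPE], which raises a KeyError when the config omits the combiner section. A minimal sketch of the failure mode, assuming the COMBINER and TYPE constants resolve to the strings 'combiner' and 'type' as in ludwig.constants:

# Minimal sketch, not code from this commit.
COMBINER, TYPE = 'combiner', 'type'  # assumed values of the ludwig.constants names

config = {
    'input_features': [{'name': 'doc', 'type': 'text'}],
    'output_features': [{'name': 'label', 'type': 'category'}],
    # no 'combiner' section supplied
}

config[COMBINER][TYPE]  # KeyError: 'combiner' -- what automl hit before this fix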
4 changes: 2 additions & 2 deletions ludwig/automl/auto_tune_config.py
@@ -14,7 +14,7 @@
)

from ludwig.api import LudwigModel
-from ludwig.automl.utils import get_available_resources
+from ludwig.automl.utils import get_available_resources, get_model_name
from ludwig.data.preprocessing import preprocess_for_training
from ludwig.features.feature_registries import update_config_with_metadata
from ludwig.utils.defaults import merge_with_defaults
@@ -120,7 +120,7 @@ def memory_tune_config(config, dataset):
training_set_metadata = get_trainingset_metadata(raw_config, dataset)
modified_hyperparam_search_space = copy.deepcopy(
raw_config[HYPEROPT]['parameters'])
-    params_to_modify = RANKED_MODIFIABLE_PARAM_LIST[raw_config[COMBINER][TYPE]]
+    params_to_modify = RANKED_MODIFIABLE_PARAM_LIST[get_model_name(raw_config)]
param_list = list(params_to_modify.keys())
current_param_values = {}
max_memory = get_machine_memory()
10 changes: 6 additions & 4 deletions ludwig/automl/automl.py
@@ -8,16 +8,18 @@
(2) Tunes config based on resource constraints
(3) Runs hyperparameter optimization experiment
"""
-from typing import Dict, Union

+import os
+import warnings
+from typing import Dict, Union

import numpy as np
import pandas as pd

from ludwig.api import LudwigModel
from ludwig.automl.base_config import _create_default_config, DatasetInfo
from ludwig.automl.auto_tune_config import memory_tune_config
-from ludwig.automl.utils import _ray_init
+from ludwig.automl.utils import _ray_init, get_model_name
from ludwig.constants import COMBINER, TYPE
from ludwig.hyperopt.run import hyperopt

@@ -54,7 +56,7 @@ def best_trial_id(self) -> str:

@property
def best_model(self) -> LudwigModel:
-        return LudwigModel.load(self.path_to_best_model)
+        return LudwigModel.load(os.path.join(self.path_to_best_model, 'model'))


def auto_train(
@@ -146,7 +148,7 @@ def train_with_config(
:return: (AutoTrainResults) results containing hyperopt experiments and best model
"""
_ray_init()
-    model_name = config[COMBINER][TYPE]
+    model_name = get_model_name(config)
hyperopt_results = _train(
config,
dataset,
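Besides switching to get_model_name, this file changes AutoTrainResults.best_model to append a 'model' subdirectory before calling LudwigModel.load, which is where a Ludwig experiment run saves the trained model. A hedged sketch of the equivalent call; the trial path below is a placeholder:

# Sketch only; the path is illustrative, not produced by this commit.
import os

from ludwig.api import LudwigModel

path_to_best_model = '/tmp/hyperopt_results/trial_0'  # placeholder trial directory
model = LudwigModel.load(os.path.join(path_to_best_model, 'model'))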
13 changes: 11 additions & 2 deletions ludwig/automl/utils.py
@@ -1,9 +1,12 @@
import logging
-from dataclasses import dataclass

+from dataclasses import dataclass
from dataclasses_json import LetterCase, dataclass_json
from pandas import Series

+from ludwig.constants import COMBINER, TYPE
+from ludwig.utils.defaults import default_combiner_type

try:
import ray
except ImportError:
@@ -56,7 +59,7 @@ def avg_num_tokens(field: Series) -> int:
return avg_words


-def get_available_resources():
+def get_available_resources() -> dict:
# returns total number of gpus and cpus
resources = ray.cluster_resources()
gpus = resources.get('GPU', 0)
@@ -68,6 +71,12 @@ def get_available_resources():
return resources


+def get_model_name(config: dict) -> str:
+    if COMBINER in config and TYPE in config[COMBINER]:
+        return config[COMBINER][TYPE]
+    return default_combiner_type


def _ray_init():
if ray.is_initialized():
return
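The new get_model_name helper is the heart of the fix: it returns combiner.type when the config provides one and otherwise falls back to default_combiner_type. A small usage sketch; the combiner name is illustrative, and the fallback value is assumed to be 'concat', Ludwig's default combiner:

# Usage sketch under the assumptions stated above.
from ludwig.automl.utils import get_model_name

with_combiner = {'combiner': {'type': 'tabnet'}}
without_combiner = {'input_features': [], 'output_features': []}

print(get_model_name(with_combiner))     # 'tabnet'
print(get_model_name(without_combiner))  # falls back, assumed 'concat'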
8 changes: 7 additions & 1 deletion ludwig/utils/defaults.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

+import copy
import logging

from ludwig.constants import *
@@ -186,13 +188,16 @@ def _set_proc_column(config: dict) -> None:


def _merge_hyperopt_with_training(config: dict) -> None:
-    if 'hyperopt' not in config or TRAINING not in config:
+    if 'hyperopt' not in config:
return

scheduler = config['hyperopt'].get('sampler', {}).get('scheduler')
if not scheduler:
return

+    if TRAINING not in config:
+        config[TRAINING] = {}

# Disable early stopping when using a scheduler. We achieve this by setting the parameter
# to -1, which ensures the condition to apply early stopping is never met.
training = config[TRAINING]
@@ -220,6 +225,7 @@ def _merge_hyperopt_with_training(config: dict) -> None:


def merge_with_defaults(config):
+    config = copy.deepcopy(config)
_perform_sanity_checks(config)
_set_feature_column(config)
_set_proc_column(config)
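Two behavioral changes land in this file: merge_with_defaults now deep-copies the incoming config instead of mutating the caller's dict, and _merge_hyperopt_with_training creates the training section when a hyperopt scheduler is configured but no training block was supplied, instead of silently skipping the early-stopping override. A hedged sketch mirroring the new test added later in this commit; the feature definitions are minimal placeholders:

# Sketch only; behavior inferred from this diff and the test below.
from ludwig.utils.defaults import merge_with_defaults

config = {
    'input_features': [{'name': 'x', 'type': 'numerical'}],
    'output_features': [{'name': 'y', 'type': 'binary'}],
    'hyperopt': {
        'parameters': {
            'training.learning_rate': {'space': 'loguniform', 'lower': 0.001, 'upper': 0.1},
        },
        'sampler': {'type': 'ray', 'scheduler': {'type': 'async_hyperband'}},
        'executor': {'type': 'ray'},
        'goal': 'minimize',
    },
    # note: no 'training' and no 'combiner' section
}

merged = merge_with_defaults(config)
assert merged['training']['early_stop'] == -1  # a scheduler disables early stopping
assert 'training' not in config                # the caller's config is left untouched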
16 changes: 7 additions & 9 deletions tests/conftest.py
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
import os
+import tempfile
import uuid

import pytest
@@ -51,10 +52,9 @@ def csv_filename():
temporary data. After the data is used, all the temporary data is deleted.
:return: None
"""
-    csv_filename = uuid.uuid4().hex[:10].upper() + '.csv'
-    yield csv_filename
-
-    delete_temporary_data(csv_filename)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        csv_filename = os.path.join(tmpdir, uuid.uuid4().hex[:10].upper() + '.csv')
+        yield csv_filename


@pytest.fixture()
@@ -64,11 +64,9 @@ def yaml_filename():
a config file. After the test runs, this file will be deleted
:return: None
"""
-    yaml_filename = 'model_def_' + uuid.uuid4().hex[:10].upper() + '.yaml'
-    yield yaml_filename
-
-    if os.path.exists(yaml_filename):
-        os.remove(yaml_filename)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yaml_filename = os.path.join(tmpdir, 'model_def_' + uuid.uuid4().hex[:10].upper() + '.yaml')
+        yield yaml_filename


def delete_temporary_data(csv_path):
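Both fixtures now hand the test a path inside a tempfile.TemporaryDirectory, so cleanup happens automatically when the fixture generator resumes past the yield, replacing the manual os.remove and delete_temporary_data calls. A standalone illustration of the same pattern, not code from this commit:

import os
import tempfile
import uuid

import pytest


@pytest.fixture()
def scratch_csv():
    # The context manager deletes the directory (and the CSV inside it)
    # once the consuming test finishes and control returns past the yield.
    with tempfile.TemporaryDirectory() as tmpdir:
        yield os.path.join(tmpdir, uuid.uuid4().hex[:10].upper() + '.csv')


def test_writes_csv(scratch_csv):
    with open(scratch_csv, 'w') as f:
        f.write('a,b\n1,2\n')
    assert os.path.exists(scratch_csv)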
14 changes: 4 additions & 10 deletions tests/integration_tests/test_experiment.py
@@ -149,21 +149,21 @@ def test_experiment_seq_seq_model_def_file(csv_filename, yaml_filename):
)


-def test_experiment_seq_seq_train_test_valid(csv_filename):
+def test_experiment_seq_seq_train_test_valid(tmpdir):
# seq-to-seq test to use train, test, validation files
input_features = [text_feature(reduce_output=None, encoder='rnn')]
output_features = [
text_feature(reduce_input=None, vocab_size=3, decoder='tagger')
]

train_csv = generate_data(
-        input_features, output_features, 'tr_' + csv_filename
+        input_features, output_features, os.path.join(tmpdir, 'train.csv')
)
test_csv = generate_data(
-        input_features, output_features, 'test_' + csv_filename, 20
+        input_features, output_features, os.path.join(tmpdir, 'test.csv'), 20
)
valdation_csv = generate_data(
-        input_features, output_features, 'val_' + csv_filename, 20
+        input_features, output_features, os.path.join(tmpdir, 'val.csv'), 20
)

run_experiment(
@@ -184,12 +184,6 @@ def test_experiment_seq_seq_train_test_valid(csv_filename):
validation_set=valdation_csv
)

-    # Delete the temporary data created
-    # This test is saving the processed data to hdf5
-    for prefix in ['tr_', 'test_', 'val_']:
-        if os.path.isfile(prefix + csv_filename):
-            os.remove(prefix + csv_filename)


def test_experiment_multi_input_intent_classification(csv_filename):
# Multiple inputs, Single category output
87 changes: 87 additions & 0 deletions tests/ludwig/utils/test_defaults.py
@@ -0,0 +1,87 @@
import copy

import pytest

from ludwig.constants import TRAINING, HYPEROPT
from ludwig.utils.defaults import merge_with_defaults, default_training_params
from tests.integration_tests.utils import binary_feature, category_feature, \
numerical_feature, text_feature, sequence_feature, vector_feature


HYPEROPT_CONFIG = {
"parameters": {
"training.learning_rate": {
"space": "loguniform",
"lower": 0.001,
"upper": 0.1,
},
"combiner.num_fc_layers": {
"space": "randint",
"lower": 2,
"upper": 6
},
"utterance.cell_type": {
"space": "grid_search",
"values": ["rnn", "gru"]
},
"utterance.bidirectional": {
"space": "choice",
"categories": [True, False]
},
"utterance.fc_layers": {
"space": "choice",
"categories": [
[{"fc_size": 512}, {"fc_size": 256}],
[{"fc_size": 512}],
[{"fc_size": 256}],
]
}
},
"sampler": {"type": "ray"},
"executor": {"type": "ray"},
"goal": "minimize"
}

SCHEDULER = {'type': 'async_hyperband', 'time_attr': 'time_total_s'}

default_early_stop = default_training_params['early_stop']


@pytest.mark.parametrize("use_train,use_hyperopt_scheduler", [
(True,True),
(False,True),
(True,False),
(False,False),
])
def test_merge_with_defaults_early_stop(use_train, use_hyperopt_scheduler):
all_input_features = [
binary_feature(),
category_feature(),
numerical_feature(),
text_feature(),
]
all_output_features = [
category_feature(),
sequence_feature(),
vector_feature(),
]

# validate config with all features
config = {
'input_features': all_input_features,
'output_features': all_output_features,
HYPEROPT: HYPEROPT_CONFIG,
}
config = copy.deepcopy(config)

if use_train:
config[TRAINING] = {'batch_size': '42'}

if use_hyperopt_scheduler:
# hyperopt scheduler cannot be used with early stopping
config[HYPEROPT]['sampler']['scheduler'] = SCHEDULER

merged_config = merge_with_defaults(config)

expected = -1 if use_hyperopt_scheduler else default_early_stop
assert merged_config[TRAINING]['early_stop'] == expected
