
Commit 510455f

Authored Aug 20, 2021
[automl] Memory Aware Config Tuning (#1257)
1 parent 30d1fdf commit 510455f

File tree: 3 files changed (+178, -1 lines changed)

‎ludwig/automl/auto_tune_config.py

+164
@@ -0,0 +1,164 @@
import copy
from collections import OrderedDict

import psutil
import ray

try:
    import GPUtil
except ImportError:
    raise ImportError(
        'GPUtil is not installed. '
        'In order to use auto_train please run '
        'pip install ludwig[ray]'
    )

from ludwig.api import LudwigModel
from ludwig.automl.utils import get_available_resources
from ludwig.data.preprocessing import preprocess_for_training
from ludwig.features.feature_registries import update_config_with_metadata
from ludwig.utils.defaults import merge_with_defaults
from ludwig.constants import COMBINER, HYPEROPT, BATCH_SIZE, TRAINING, TYPE, PREPROCESSING, SPACE

# maps each modifiable search-space parameter to the minimum permissible
# value for its range; iteration order is the order in which parameters
# are reduced
RANKED_MODIFIABLE_PARAM_LIST = {
    'tabnet': OrderedDict({
        'training.batch_size': 32,
        'combiner.size': 8,
        'combiner.output_size': 8,
    }),
    'concat': OrderedDict({
        'training.batch_size': 32,
        'combiner.fc_size': 64,
        'combiner.num_fc_layers': 1,
    }),
    'tabtransformer': OrderedDict({
        'training.batch_size': 32,
        'combiner.num_heads': 4,
        'combiner.output_size': 8,
        'combiner.num_layers': 4,
        'combiner.num_fc_layers': 1,
    }),
}
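
Because each combiner type maps to an OrderedDict, iteration order doubles as reduction priority: the batch size is shrunk before the combiner dimensions. A quick illustration of how the table reads:

for param, min_val in RANKED_MODIFIABLE_PARAM_LIST['tabnet'].items():
    print(param, '-> floor', min_val)
# training.batch_size -> floor 32   (reduced first)
# combiner.size -> floor 8
# combiner.output_size -> floor 8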

BYTES_PER_MiB = 1048576  # 2**20

def get_trainingset_metadata(config, dataset):
    (_, _, _, training_set_metadata) = preprocess_for_training(
        config,
        dataset=dataset,
        preprocessing_params=config[PREPROCESSING])
    return training_set_metadata

def get_machine_memory():

    if ray.is_initialized():  # using ray cluster
        @ray.remote(num_gpus=1)
        def get_remote_gpu():
            gpus = GPUtil.getGPUs()
            total_mem_mb = gpus[0].memoryTotal
            return total_mem_mb * BYTES_PER_MiB

        @ray.remote(num_cpus=1)
        def get_remote_cpu():
            total_mem = psutil.virtual_memory().total
            return total_mem

        resources = get_available_resources()  # check if cluster has GPUs

        if resources['gpu'] > 0:
            machine_mem = ray.get(get_remote_gpu.remote())
        else:
            machine_mem = ray.get(get_remote_cpu.remote())
    else:  # not using ray cluster
        if GPUtil.getGPUs():
            machine_mem = GPUtil.getGPUs()[0].memoryTotal * BYTES_PER_MiB
        else:
            machine_mem = psutil.virtual_memory().total

    return machine_mem

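GPUtil reports memoryTotal in MiB (as given by nvidia-smi), hence the BYTES_PER_MiB conversion; psutil already returns bytes. A minimal local sanity check, no Ray cluster involved (printed value is illustrative):

import psutil

mem_bytes = psutil.virtual_memory().total
print(mem_bytes / BYTES_PER_MiB)  # e.g. ~16384 MiB on a 16 GiB machine
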
def compute_memory_usage(config, training_set_metadata) -> int:
    update_config_with_metadata(config, training_set_metadata)
    lm = LudwigModel.create_model(config)
    lm.get_connected_model()
    model_tensors = lm.collect_weights()
    total_size = 0
    batch_size = config[TRAINING][BATCH_SIZE]
    for tnsr in model_tensors:
        # scale each weight tensor's element count by the batch size
        total_size += tnsr[1].numpy().size * batch_size
    total_bytes = total_size * 4  # assumes 32-bit (4-byte) precision
    return total_bytes

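The estimate is deliberately coarse: total weight elements × batch size × 4 bytes per float32 element. A back-of-the-envelope check with made-up numbers:

# hypothetical model: 1,000,000 weight elements, batch size 256
total_size = 1_000_000 * 256        # elements scaled by batch size
total_bytes = total_size * 4        # float32 -> 4 bytes per element
print(total_bytes / BYTES_PER_MiB)  # ~976.6 MiB estimated footprint
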
def sub_new_params(config: dict, new_param_vals: dict):
    new_config = copy.deepcopy(config)
    for param, val in new_param_vals.items():
        # dotted keys address a section and a field, e.g. 'training.batch_size'
        config_section, param_name = param.split(".", 1)
        new_config[config_section][param_name] = val
    return new_config

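A usage sketch with a toy config (illustrative values only), assuming the function above is in scope:

base = {'training': {'batch_size': 256}, 'combiner': {'fc_size': 512}}
tuned = sub_new_params(base, {'training.batch_size': 32,
                              'combiner.fc_size': 64})
# tuned == {'training': {'batch_size': 32}, 'combiner': {'fc_size': 64}}
# base is untouched thanks to the deepcopy
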
def get_new_params(current_param_values, hyperparam_search_space, params_to_modify):
    for param, _ in params_to_modify.items():
        if param not in hyperparam_search_space:
            continue  # parameter is not part of the hyperopt search space
        if hyperparam_search_space[param][SPACE] == "choice":
            current_param_values[param] = hyperparam_search_space[param]['categories'][-1]
        else:
            current_param_values[param] = hyperparam_search_space[param]['upper']
    return current_param_values

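For a choice space the last category is taken (the categories are assumed sorted ascending, which the trimming logic below also relies on); for a range space, the upper bound. A toy illustration, assuming SPACE == 'space' as in ludwig.constants:

space = {
    'training.batch_size': {'space': 'choice', 'categories': [32, 64, 128]},
    'combiner.fc_size': {'space': 'randint', 'lower': 64, 'upper': 512},
}
params = OrderedDict({'training.batch_size': 32, 'combiner.fc_size': 64})
print(get_new_params({}, space, params))
# {'training.batch_size': 128, 'combiner.fc_size': 512}
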
def memory_tune_config(config, dataset):
    fits_in_memory = False
    raw_config = merge_with_defaults(config)
    training_set_metadata = get_trainingset_metadata(raw_config, dataset)
    modified_hyperparam_search_space = copy.deepcopy(
        raw_config[HYPEROPT]['parameters'])
    params_to_modify = RANKED_MODIFIABLE_PARAM_LIST[raw_config[COMBINER][TYPE]]
    param_list = list(params_to_modify.keys())
    current_param_values = {}
    max_memory = get_machine_memory()

    while param_list:  # stop once every parameter is exhausted
        # compute memory utilization
        current_param_values = get_new_params(
            current_param_values, modified_hyperparam_search_space, params_to_modify)
        temp_config = sub_new_params(raw_config, current_param_values)
        if compute_memory_usage(temp_config, training_set_metadata) < max_memory:
            fits_in_memory = True
            break
        # check if we have exhausted tuning of the current param
        # (i.e. we can no longer reduce its value)
        param, min_value = param_list[0], params_to_modify[param_list[0]]

        if param in modified_hyperparam_search_space.keys():
            param_space = modified_hyperparam_search_space[param]["space"]
            if param_space == "choice":
                if len(modified_hyperparam_search_space[param]['categories']) > 2 and \
                        modified_hyperparam_search_space[param]['categories'][-2] > min_value:
                    # drop the largest remaining category
                    modified_hyperparam_search_space[param]['categories'] = \
                        modified_hyperparam_search_space[param]['categories'][:-1]
                else:
                    param_list.pop(0)  # exhausted reduction of this parameter
            else:
                # reduce the upper bound of the range by 10%
                upper_bound = modified_hyperparam_search_space[param]["upper"]
                lower_bound = modified_hyperparam_search_space[param]["lower"]
                reduction_val = (upper_bound - lower_bound) * 0.1
                new_upper_bound = upper_bound - reduction_val
                if new_upper_bound > lower_bound and new_upper_bound > min_value:
                    modified_hyperparam_search_space[param]["upper"] = new_upper_bound
                else:
                    param_list.pop(0)  # exhausted reduction of this parameter
        else:
            param_list.pop(0)  # param not in hyperopt search space

    modified_config = copy.deepcopy(config)
    modified_config[HYPEROPT]["parameters"] = modified_hyperparam_search_space
    return modified_config, fits_in_memory
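
A usage sketch (model_config and the dataset path are hypothetical; in the codebase this is invoked from create_auto_config, shown below):

tuned_config, fits_in_memory = memory_tune_config(model_config, 'train.csv')
if not fits_in_memory:
    print('warning: even the smallest candidate config may not fit in memory')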

‎ludwig/automl/automl.py

(+12, -1)
@@ -16,6 +16,7 @@
 from ludwig.api import LudwigModel
 from ludwig.automl.base_config import _create_default_config, DatasetInfo
+from ludwig.automl.auto_tune_config import memory_tune_config
 from ludwig.automl.utils import _ray_init
 from ludwig.constants import COMBINER, TYPE
 from ludwig.hyperopt.run import hyperopt

@@ -61,6 +62,7 @@ def auto_train(
     target: str,
     time_limit_s: Union[int, float],
     output_directory: str = OUTPUT_DIR,
+    tune_for_memory: bool = False,
     **kwargs
 ) -> AutoTrainResults:
     """

@@ -81,7 +83,8 @@ def auto_train(
     # Returns
     :return: (AutoTrainResults) results containing hyperopt experiments and best model
     """
-    config = create_auto_config(dataset, target, time_limit_s)
+    config = create_auto_config(
+        dataset, target, time_limit_s, tune_for_memory, **kwargs)
     return train_with_config(
         dataset,
         config,

@@ -94,6 +97,7 @@ def create_auto_config(
     dataset: Union[str, pd.DataFrame, dd.core.DataFrame, DatasetInfo],
     target: str,
     time_limit_s: Union[int, float],
+    tune_for_memory: bool,
 ) -> dict:
     """
     Returns an auto-generated Ludwig config with the intent of training

@@ -111,6 +115,13 @@ def create_auto_config(
     """
     default_configs = _create_default_config(dataset, target, time_limit_s)
     model_config = _model_select(default_configs)
+    if tune_for_memory:
+        if ray.is_initialized():
+            model_config, _ = ray.get(ray.remote(num_cpus=1)(
+                memory_tune_config
+            ).remote(model_config, dataset))
+        else:
+            model_config, _ = memory_tune_config(model_config, dataset)
     return model_config
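
With the new flag wired through, memory-aware tuning can be requested from the top-level API. A sketch (dataset path and target column are hypothetical):

from ludwig.automl.automl import auto_train

results = auto_train(
    dataset='train.csv',      # hypothetical CSV path
    target='label',           # hypothetical target column
    time_limit_s=3600,
    tune_for_memory=True,     # shrink the search space to fit available memory
)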

‎requirements_ray.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
 ray[default,tune]
 pickle5
 tensorboardX<2.3
+GPUtil
+
