Fixed auto eval batch size when train batch size is set (#1410)
tgaddair committed Oct 25, 2021
1 parent 5300ebb commit a06fd57
Showing 5 changed files with 123 additions and 40 deletions.
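
As the diff below shows, batch-size tuning now runs when either `batch_size` or `eval_batch_size` is set to `'auto'`, so an explicitly chosen training batch size no longer prevents the eval batch size from being tuned. For orientation only (not part of the commit), a minimal config sketch that exercises the fixed path; the feature definitions and file name here are illustrative placeholders:

from ludwig.api import LudwigModel

# Illustrative config: an explicit training batch_size combined with an 'auto'
# eval_batch_size. Before this fix the 'auto' value was left unresolved; with
# this change it is tuned to a concrete integer.
config = {
    'input_features': [{'name': 'text', 'type': 'sequence'}],      # placeholder feature
    'output_features': [{'name': 'class', 'type': 'category'}],    # placeholder feature
    'training': {
        'epochs': 2,
        'batch_size': 64,            # explicitly set by the user
        'eval_batch_size': 'auto',   # still tuned, thanks to the broadened condition
    },
}

model = LudwigModel(config)
# model.train(dataset='training.csv')  # afterwards, config['training']['eval_batch_size'] is an int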
24 changes: 18 additions & 6 deletions ludwig/api.py
@@ -39,7 +39,8 @@

from ludwig.backend import Backend, initialize_backend
from ludwig.callbacks import Callback
from ludwig.constants import FULL, PREPROCESSING, TEST, TRAINING, VALIDATION, LEARNING_RATE, BATCH_SIZE, AUTO
from ludwig.constants import FULL, PREPROCESSING, TEST, TRAINING, VALIDATION, LEARNING_RATE, BATCH_SIZE, AUTO, \
EVAL_BATCH_SIZE
from ludwig.data.dataset.base import Dataset
from ludwig.data.postprocessing import convert_predictions, postprocess
from ludwig.data.preprocessing import (load_metadata,
@@ -504,24 +505,35 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
)

# auto tune batch size
if self.config[TRAINING][BATCH_SIZE] == AUTO:
if self.config[TRAINING][BATCH_SIZE] == AUTO or \
self.config[TRAINING][EVAL_BATCH_SIZE] == AUTO:
# TODO (ASN): add support for substitute_with_max parameter
tuned_batch_size = trainer.tune_batch_size(
self.config,
training_set,
random_seed=random_seed
)
self.config[TRAINING][BATCH_SIZE] = tuned_batch_size

# TODO(travis): pass these in as args to trainer when we call train,
# to avoid setting state on possibly remote trainer
if self.config[TRAINING][BATCH_SIZE] == AUTO:
self.config[TRAINING][BATCH_SIZE] = tuned_batch_size
trainer.batch_size = tuned_batch_size

if self.config[TRAINING][EVAL_BATCH_SIZE] == AUTO:
self.config[TRAINING][EVAL_BATCH_SIZE] = tuned_batch_size
trainer.eval_batch_size = tuned_batch_size

# auto tune learning rate
if self.config[TRAINING][LEARNING_RATE] == AUTO:
new_learning_rate = trainer.tune_learning_rate(
tuned_learning_rate = trainer.tune_learning_rate(
self.config,
LudwigModel.create_model(self.config, random_seed),
training_set,
random_seed=random_seed
)
self.config[TRAINING][LEARNING_RATE] = new_learning_rate
self.config[TRAINING][LEARNING_RATE] = tuned_learning_rate
trainer.learning_rate = tuned_learning_rate

# train model
if self.backend.is_coordinator():
@@ -1136,7 +1148,7 @@ def experiment(
f"Using validation set instead")

if eval_set is not None:
if self.config[TRAINING]['eval_batch_size'] > 0:
if self.config[TRAINING]['eval_batch_size']:
batch_size = self.config[TRAINING]['eval_batch_size']
else:
batch_size = self.config[TRAINING]['batch_size']
1 change: 1 addition & 0 deletions ludwig/constants.py
@@ -140,6 +140,7 @@
SRC = 'dataset_src'

BATCH_SIZE = 'batch_size'
EVAL_BATCH_SIZE = 'eval_batch_size'
LEARNING_RATE = 'learning_rate'
AUTO = 'auto'
CONFIG = 'config'
70 changes: 37 additions & 33 deletions ludwig/models/trainer.py
@@ -31,7 +31,10 @@
from re import M
import numpy as np
import tensorflow as tf
from typing import Dict, Any

from ludwig.constants import COMBINED, LOSS, TEST, TRAINING, TYPE, VALIDATION
from ludwig.data.dataset.base import Dataset
from ludwig.globals import (MODEL_HYPERPARAMETERS_FILE_NAME,
MODEL_WEIGHTS_FILE_NAME,
TRAINING_CHECKPOINTS_DIR_PATH,
@@ -114,7 +117,7 @@ def __init__(
decay_steps=10000,
staircase=False,
batch_size=128,
eval_batch_size=0,
eval_batch_size=None,
should_shuffle=True,
shuffle_buffer_size=None,
bucketing_field=None,
@@ -245,7 +248,7 @@ def __init__(
self.decay_steps = decay_steps
self.staircase = staircase
self.batch_size = batch_size
self.eval_batch_size = batch_size if eval_batch_size < 1 else eval_batch_size
self.eval_batch_size = batch_size if eval_batch_size is None else eval_batch_size
self.should_shuffle = should_shuffle
self.shuffle_buffer_size = shuffle_buffer_size
self.bucketing_field = bucketing_field
@@ -337,14 +340,15 @@ def train_for_tuning(
self,
model,
dataset,
total_steps=3,
batch_size: int,
total_steps: int = 3,
):
""" function to be used by tune_batch_size """
with dataset.initialize_batcher(
batch_size=self.batch_size,
should_shuffle=self.should_shuffle,
shuffle_buffer_size=self.shuffle_buffer_size,
horovod=self.horovod
batch_size=batch_size,
should_shuffle=False,
shuffle_buffer_size=0,
horovod=None
) as batcher:

step_count = 0
@@ -372,19 +376,19 @@ def tune_learning_rate(
self,
config,
model,
training_set,
training_set: Dataset,
random_seed: int = default_random_seed,
min_lr: float = 1e-8,
max_lr: float = 1.0,
total_training_steps: int = 100,
mode: str = "exponential",
early_stop_threshold: int = 3,
beta: float = 0.98
):
) -> float:
# TODO (ASN): Circle back on how we want to set default placeholder value
# Currently, since self.learning_rate is originally set to auto, we provide a
# placeholder starting value (namely, .001)
self.learning_rate = 0.001
learning_rate = 0.001

current_learning_rate = min_lr
losses = []
@@ -420,7 +424,7 @@ def get_optimal_lr(losses, learning_rates, skip_begin: int = 10, skip_end: int =
) as batcher:
step_count = 0
while epoch < self.epochs and step_count < total_training_steps and not diverging:
batcher.set_epoch(epoch)
batcher.set_epoch(epoch, self.batch_size)
model.reset_metrics()
while not batcher.last_batch() and step_count < total_training_steps:
batch = batcher.next_batch()
@@ -470,17 +474,17 @@ def get_optimal_lr(losses, learning_rates, skip_begin: int = 10, skip_end: int =

optimal_lr = get_optimal_lr(losses, learning_rates)
if optimal_lr:
self.learning_rate = optimal_lr
return self.learning_rate
learning_rate = optimal_lr
return learning_rate

def tune_batch_size(
self,
config,
training_set,
config: Dict[str, Any],
training_set: Dataset,
random_seed: int = default_random_seed,
max_trials: int = 10,
halving_limit: int = 3
):
) -> int:
from ludwig.api import LudwigModel

def _is_valid_batch_size(batch_size):
@@ -489,10 +493,11 @@ def _is_valid_batch_size(batch_size):
# TODO (ASN) : Circle back on how we want to set default placeholder value
# Currently, since self.batch_size is originally set to auto, we provide a
# placeholder starting value (namely, 128)
self.batch_size = 128
batch_size = 128
skip_save_model = self.skip_save_model
skip_save_progress = self.skip_save_progress
skip_save_log = self.skip_save_log

# Set temporary values
self.skip_save_model = True
self.skip_save_progress = True
@@ -508,56 +513,54 @@ def _is_valid_batch_size(batch_size):
halving_count = 0
while halving_count < halving_limit:
gc.collect()

low = batch_size
prev_batch_size = batch_size
try:
# re-initialize model...
model = LudwigModel.create_model(config, random_seed)
self.train_for_tuning(model, training_set, total_steps=3)
self.train_for_tuning(model, training_set, batch_size, total_steps=3)
count += 1
if count >= max_trials:
break
low = self.batch_size
prev_batch_size = self.batch_size
if high:
if high - low <= 1:
break
midval = (high + low) // 2
self.batch_size = midval
batch_size = midval
else:
self.batch_size *= 2 # double batch size
batch_size *= 2 # double batch size

if self.batch_size == prev_batch_size:
if batch_size == prev_batch_size:
break

except tf.errors.ResourceExhaustedError as e:
gc.collect()
high = self.batch_size
high = batch_size
halving_count += 1
midval = (high + low) // 2
self.batch_size = midval
batch_size = midval
if high - low <= 1:
break

# make sure that batch size is valid (e.g. less than size of ds)
if not _is_valid_batch_size(self.batch_size):
self.batch_size = min(self.batch_size, len(training_set))
if not _is_valid_batch_size(batch_size):
batch_size = min(batch_size, len(training_set))

# edge case where bs is no longer increasing
if self.batch_size == prev_batch_size:
if batch_size == prev_batch_size:
break

finally:
# Restore original parameters to defaults
# self.epochs = original_epochs
self.skip_save_model = skip_save_model
self.skip_save_progress = skip_save_progress
self.skip_save_log = skip_save_log

if self.eval_batch_size == "auto":
self.eval_batch_size = self.batch_size
finally:
# Turn eager mode off
tf.config.run_functions_eagerly(False)

return self.batch_size
return batch_size

def train(
self,
@@ -901,6 +904,7 @@ def train(
tables[COMBINED] = [[COMBINED, LOSS]]

# eval metrics on train
self.eval_batch_size = max(self.eval_batch_size, progress_tracker.batch_size)
self.evaluation(
model,
training_set,
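Aside (not part of the commit): the tune_batch_size loop above doubles the candidate batch size until training hits tf.errors.ResourceExhaustedError, then bisects between the last size that fit and the one that failed, finally capping the result at the dataset size. A minimal standalone sketch of that search, where try_batch is a hypothetical callable standing in for a short tuning run (the real code calls train_for_tuning and catches the TensorFlow OOM error):

# Sketch only: mirrors the doubling-then-bisecting search in tune_batch_size.
def find_tuned_batch_size(try_batch, dataset_size, max_trials=10, halving_limit=3):
    batch_size = 128                       # placeholder starting value, as in the diff
    low, high = batch_size, None
    prev_batch_size = None
    trials = halvings = 0
    while halvings < halving_limit and trials < max_trials:
        try:
            try_batch(batch_size)          # raises MemoryError if the batch does not fit
            trials += 1
            low = prev_batch_size = batch_size
            if high is not None:           # already saw a failure: bisect toward it
                if high - low <= 1:
                    break
                batch_size = (high + low) // 2
            else:                          # no failure yet: keep doubling
                batch_size *= 2
            if batch_size == prev_batch_size:
                break
        except MemoryError:
            high = batch_size
            halvings += 1
            batch_size = (high + low) // 2
            if high - low <= 1:
                break
    # never return a batch size larger than the dataset itself
    return min(batch_size, dataset_size)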
2 changes: 1 addition & 1 deletion ludwig/utils/defaults.py
@@ -52,7 +52,7 @@
'regularization_lambda': 0,
'learning_rate': 0.001,
'batch_size': 128,
'eval_batch_size': 0,
'eval_batch_size': None,
'early_stop': 5,
'reduce_learning_rate_on_plateau': 0,
'reduce_learning_rate_on_plateau_patience': 5,
66 changes: 66 additions & 0 deletions tests/integration_tests/test_trainer.py
@@ -0,0 +1,66 @@
import os
import shutil
import tempfile

from ludwig.api import LudwigModel
from ludwig.constants import TRAINING, BATCH_SIZE, EVAL_BATCH_SIZE, LEARNING_RATE
from tests.integration_tests.utils import sequence_feature, category_feature, generate_data, LocalTestBackend


def test_tune_batch_size_and_lr(tmpdir):
with tempfile.TemporaryDirectory() as outdir:
input_features = [sequence_feature(reduce_output='sum')]
output_features = [category_feature(vocab_size=2, reduce_input='sum')]

csv_filename = os.path.join(tmpdir, 'training.csv')
data_csv = generate_data(input_features, output_features, csv_filename)
val_csv = shutil.copyfile(data_csv,
os.path.join(tmpdir, 'validation.csv'))
test_csv = shutil.copyfile(data_csv, os.path.join(tmpdir, 'test.csv'))

config = {
'input_features': input_features,
'output_features': output_features,
'combiner': {'type': 'concat', 'fc_size': 14},
'training': {
'epochs': 2,
'batch_size': 'auto',
'eval_batch_size': 'auto',
'learning_rate': 'auto',
},
}

model = LudwigModel(config, backend=LocalTestBackend())

# check preconditions
assert model.config[TRAINING][BATCH_SIZE] == 'auto'
assert model.config[TRAINING][EVAL_BATCH_SIZE] == 'auto'
assert model.config[TRAINING][LEARNING_RATE] == 'auto'

_, _, output_directory = model.train(
training_set=data_csv,
validation_set=val_csv,
test_set=test_csv,
output_directory=outdir
)

def check_postconditions(model):
# check batch size
assert model.config[TRAINING][BATCH_SIZE] != 'auto'
assert model.config[TRAINING][BATCH_SIZE] > 1

assert model.config[TRAINING][EVAL_BATCH_SIZE] != 'auto'
assert model.config[TRAINING][EVAL_BATCH_SIZE] > 1

assert model.config[TRAINING][BATCH_SIZE] == model.config[TRAINING][EVAL_BATCH_SIZE]

# check learning rate
assert model.config[TRAINING][LEARNING_RATE] != 'auto'
assert model.config[TRAINING][LEARNING_RATE] > 0

check_postconditions(model)

model = LudwigModel.load(os.path.join(output_directory, 'model'))

# loaded model should retain the tuned params
check_postconditions(model)
