diff --git a/ludwig/api.py b/ludwig/api.py index 3302f51d4d8..381e6297620 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -613,11 +613,7 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): random_seed=random_seed, ) as trainer: # auto tune batch size - if ( - self.config_obj.trainer.to_dict().get(BATCH_SIZE, None) == AUTO - or self.config_obj.trainer.to_dict().get(EVAL_BATCH_SIZE, None) == AUTO - ): - self._tune_batch_size(trainer, training_set, random_seed=random_seed) + self._tune_batch_size(trainer, training_set, random_seed=random_seed) # train model if self.backend.is_coordinator(): @@ -795,35 +791,60 @@ def train_online( config=self.config_obj.trainer, model=self.model, random_seed=random_seed ) - if ( - self.config_obj.trainer.to_dict().get(BATCH_SIZE, None) == AUTO - or self.config_obj.trainer.to_dict().get(EVAL_BATCH_SIZE, None) == AUTO - ): - self._tune_batch_size(self._online_trainer, dataset, random_seed=random_seed) + self._tune_batch_size(self._online_trainer, dataset, random_seed=random_seed) self.model = self._online_trainer.train_online(training_dataset) def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_seed): + if not self.config_obj.trainer.can_tune_batch_size(): + # Models like GBMs don't have batch sizes to be tuned + return + + # Render the batch size and gradient accumulation steps prior to batch size tuning. This is needed in the event + # the effective_batch_size and gradient_accumulation_steps are set explicitly, but batch_size is AUTO. In this + # case, we can infer the batch_size directly without tuning. + num_workers = self.backend.num_training_workers + self.config_obj.trainer.update_batch_size_grad_accum(num_workers) + # TODO (ASN): add support for substitute_with_max parameter # TODO(travis): detect train and eval batch sizes separately (enable / disable gradients) - if self.backend.supports_batch_size_tuning(): - tuned_batch_size = trainer.tune_batch_size(self.config_obj.to_dict(), dataset, random_seed=random_seed) - else: - logger.warning( - f"Backend {self.backend.BACKEND_TYPE} does not support batch size tuning, " - f"using fallback batch size {FALLBACK_BATCH_SIZE}." - ) - tuned_batch_size = FALLBACK_BATCH_SIZE - - # TODO(travis): pass these in as args to trainer when we call train, - # to avoid setting state on possibly remote trainer if self.config_obj.trainer.batch_size == AUTO: + if self.backend.supports_batch_size_tuning(): + tuned_batch_size = trainer.tune_batch_size( + self.config_obj.to_dict(), dataset, random_seed=random_seed, tune_for_training=True + ) + else: + logger.warning( + f"Backend {self.backend.BACKEND_TYPE} does not support batch size tuning, " + f"using fallback training batch size {FALLBACK_BATCH_SIZE}." + ) + tuned_batch_size = FALLBACK_BATCH_SIZE + + # TODO(travis): pass these in as args to trainer when we call train, + # to avoid setting state on possibly remote trainer self.config_obj.trainer.batch_size = tuned_batch_size - trainer.batch_size = tuned_batch_size + + # Re-render the gradient_accumulation_steps to account for the explicit batch size. 
+ self.config_obj.trainer.update_batch_size_grad_accum(num_workers) if self.config_obj.trainer.eval_batch_size in {AUTO, None}: + if self.backend.supports_batch_size_tuning(): + tuned_batch_size = trainer.tune_batch_size( + self.config_obj.to_dict(), dataset, random_seed=random_seed, tune_for_training=False + ) + else: + logger.warning( + f"Backend {self.backend.BACKEND_TYPE} does not support batch size tuning, " + f"using fallback eval batch size {FALLBACK_BATCH_SIZE}." + ) + tuned_batch_size = FALLBACK_BATCH_SIZE + self.config_obj.trainer.eval_batch_size = tuned_batch_size - trainer.eval_batch_size = tuned_batch_size + + # Update trainer params separate to config params for backends with stateful trainers + trainer.batch_size = self.config_obj.trainer.batch_size + trainer.eval_batch_size = self.config_obj.trainer.eval_batch_size + trainer.gradient_accumulation_steps = self.config_obj.trainer.gradient_accumulation_steps def predict( self, diff --git a/ludwig/backend/base.py b/ludwig/backend/base.py index 39f7a97d6e8..586bba17780 100644 --- a/ludwig/backend/base.py +++ b/ludwig/backend/base.py @@ -118,6 +118,11 @@ def read_binary_files(self, column: Series, map_fn: Optional[Callable] = None) - def num_nodes(self) -> int: raise NotImplementedError() + @property + @abstractmethod + def num_training_workers(self) -> int: + raise NotImplementedError() + @abstractmethod def get_available_resources(self) -> Resources: raise NotImplementedError() @@ -250,6 +255,10 @@ def __init__(self, **kwargs): def num_nodes(self) -> int: return 1 + @property + def num_training_workers(self) -> int: + return 1 + def get_available_resources(self) -> Resources: return Resources(cpus=psutil.cpu_count(), gpus=torch.cuda.device_count()) @@ -308,6 +317,10 @@ def is_coordinator(self): @property def num_nodes(self) -> int: + return self._distributed.size() // self._distributed.local_size() + + @property + def num_training_workers(self) -> int: return self._distributed.size() def get_available_resources(self) -> Resources: diff --git a/ludwig/backend/ray.py b/ludwig/backend/ray.py index 96c784445b1..afed70707b8 100644 --- a/ludwig/backend/ray.py +++ b/ludwig/backend/ray.py @@ -250,6 +250,7 @@ def tune_batch_size_fn( training_set_metadata: TrainingSetMetadataDict = None, features: Dict[str, Dict] = None, remote_trainer_cls: Callable[[], Trainer] = None, + tune_for_training: bool = True, **kwargs, ): # Pin GPU before loading the model to prevent memory leaking onto other devices @@ -276,6 +277,7 @@ def on_best_batch_size_updated(best_batch_size: int, best_samples_per_sec: float train_shard, snapshot_weights=False, on_best_batch_size_updated=on_best_batch_size_updated, + tune_for_training=tune_for_training, **kwargs, ) session.report( @@ -539,6 +541,7 @@ def tune_batch_size( self, config: ModelConfigDict, training_set: RayDataset, + tune_for_training: bool = True, **kwargs, ) -> int: with create_runner(**self.trainer_kwargs) as runner: @@ -552,6 +555,7 @@ def tune_batch_size( ludwig_config=config, training_set_metadata=training_set.training_set_metadata, features=training_set.features, + tune_for_training=tune_for_training, **kwargs, ), exception_on_error=False, @@ -592,6 +596,14 @@ def eval_batch_size(self) -> int: def eval_batch_size(self, value: int): self.config.eval_batch_size = value + @property + def gradient_accumulation_steps(self) -> int: + return self.config.gradient_accumulation_steps + + @gradient_accumulation_steps.setter + def gradient_accumulation_steps(self, value: int): + 
self.config.gradient_accumulation_steps = value + @property def resources_per_worker(self) -> Dict[str, Any]: trainer_kwargs = get_trainer_kwargs(**self.trainer_kwargs) @@ -876,7 +888,7 @@ def __init__( super().__init__(dataset_manager=RayDatasetManager(self), **kwargs) self._preprocessor_kwargs = preprocessor_kwargs or {} self._df_engine = _get_df_engine(processor) - self._horovod_kwargs = trainer or {} + self._distributed_kwargs = trainer or {} self._pytorch_kwargs = {} self._data_loader_kwargs = loader or {} self._preprocessor_pg = None @@ -943,7 +955,7 @@ def create_trainer(self, model: BaseModel, **kwargs) -> "BaseTrainer": # noqa: all_kwargs = { "model": model, - "trainer_kwargs": self._horovod_kwargs, + "trainer_kwargs": self._distributed_kwargs, "data_loader_kwargs": self._data_loader_kwargs, "executable_kwargs": executable_kwargs, } @@ -956,18 +968,18 @@ def create_predictor(self, model: BaseModel, **kwargs): return RayPredictor( model, self.df_engine, - self._horovod_kwargs, + self._distributed_kwargs, self._data_loader_kwargs, **executable_kwargs, ) @property def distributed_kwargs(self): - return self._horovod_kwargs + return self._distributed_kwargs @distributed_kwargs.setter def distributed_kwargs(self, value): - self._horovod_kwargs = value + self._distributed_kwargs = value @property def df_engine(self): @@ -1068,6 +1080,11 @@ def num_nodes(self) -> int: return 1 return len(ray.nodes()) + @property + def num_training_workers(self) -> int: + trainer_kwargs = get_trainer_kwargs(**self._distributed_kwargs) + return trainer_kwargs["num_workers"] + def get_available_resources(self) -> Resources: resources = ray.cluster_resources() return Resources(cpus=resources.get("CPU", 0), gpus=resources.get("GPU", 0)) @@ -1122,7 +1139,7 @@ def batch_transform(self, df: DataFrame, batch_size: int, transform_fn: Callable return self.df_engine.from_ray_dataset(ds) def _get_transform_kwargs(self) -> Dict[str, Any]: - trainer_kwargs = get_trainer_kwargs(**self._horovod_kwargs) + trainer_kwargs = get_trainer_kwargs(**self._distributed_kwargs) resources_per_worker = trainer_kwargs.get("resources_per_worker", {}) num_gpus = resources_per_worker.get("GPU", 0) num_cpus = resources_per_worker.get("CPU", (1 if num_gpus == 0 else 0)) diff --git a/ludwig/distributed/horovod.py b/ludwig/distributed/horovod.py index 3acad2c478b..80ea4f784cc 100644 --- a/ludwig/distributed/horovod.py +++ b/ludwig/distributed/horovod.py @@ -13,6 +13,7 @@ from torch import nn from torch.optim import Optimizer +from ludwig.constants import AUTO from ludwig.distributed.base import DistributedStrategy from ludwig.modules.optimization_modules import create_optimizer from ludwig.utils.horovod_utils import gather_all_tensors, is_distributed_available @@ -35,10 +36,13 @@ def prepare( base_learning_rate: float, ) -> Tuple[nn.Module, Optimizer]: optimizer = create_optimizer(model, trainer_config.optimizer, base_learning_rate) + grad_accum_steps = ( + trainer_config.gradient_accumulation_steps if trainer_config.gradient_accumulation_steps != AUTO else 1 + ) dist_optimizer = hvd.DistributedOptimizer( optimizer, named_parameters=model.named_parameters(), - backward_passes_per_step=trainer_config.gradient_accumulation_steps, + backward_passes_per_step=grad_accum_steps, ) return model, dist_optimizer diff --git a/ludwig/schema/metadata/configs/trainer.yaml b/ludwig/schema/metadata/configs/trainer.yaml index cfba30b4d42..59264676a09 100644 --- a/ludwig/schema/metadata/configs/trainer.yaml +++ b/ludwig/schema/metadata/configs/trainer.yaml @@ 
-1,4 +1,11 @@ ecd: + effective_batch_size: + commonly_used: true + expected_impact: 2 + related_parameters: + - batch_size + suggested_values: auto + ui_display_name: Effective Batch Size batch_size: commonly_used: true default_value_reasoning: Not too big, not too small. diff --git a/ludwig/schema/trainer.py b/ludwig/schema/trainer.py index 60cdb4caace..7082b1f1849 100644 --- a/ludwig/schema/trainer.py +++ b/ludwig/schema/trainer.py @@ -5,15 +5,7 @@ from packaging.version import parse as parse_version from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import ( - DEFAULT_BATCH_SIZE, - LOSS, - MAX_POSSIBLE_BATCH_SIZE, - MODEL_ECD, - MODEL_GBM, - MODEL_LLM, - TRAINING, -) +from ludwig.constants import AUTO, LOSS, MAX_POSSIBLE_BATCH_SIZE, MODEL_ECD, MODEL_GBM, MODEL_LLM, TRAINING from ludwig.error import ConfigValidationError from ludwig.schema import utils as schema_utils from ludwig.schema.lr_scheduler import LRSchedulerConfig, LRSchedulerDataclassField @@ -92,6 +84,9 @@ class BaseTrainerConfig(schema_utils.BaseMarshmallowConfig, ABC): ), ) + def can_tune_batch_size(self) -> bool: + return True + @DeveloperAPI @register_trainer_schema(MODEL_ECD) @@ -105,6 +100,38 @@ def __post_init__(self): "Trainer param `compile: true` requires PyTorch 2.0.0 or higher. Please upgrade PyTorch and try again." ) + if self.effective_batch_size != AUTO and self.max_batch_size < self.effective_batch_size: + raise ConfigValidationError( + f"`max_batch_size` ({self.max_batch_size}) must be greater than or equal to " + f"`effective_batch_size` ({self.effective_batch_size})." + ) + + if self.effective_batch_size != AUTO and self.batch_size != AUTO: + if self.effective_batch_size < self.batch_size: + raise ConfigValidationError( + f"`effective_batch_size` ({self.effective_batch_size}) " + f"must be greater than or equal to `batch_size` ({self.batch_size})." + ) + + if self.effective_batch_size % self.batch_size != 0: + raise ConfigValidationError( + f"`effective_batch_size` ({self.effective_batch_size}) " + f"must be divisible by `batch_size` ({self.batch_size})." + ) + + if self.effective_batch_size != AUTO and self.gradient_accumulation_steps != AUTO: + if self.effective_batch_size < self.gradient_accumulation_steps: + raise ConfigValidationError( + f"`effective_batch_size` ({self.effective_batch_size}) must be greater than or equal to " + f"`gradient_accumulation_steps` ({self.gradient_accumulation_steps})." + ) + + if self.effective_batch_size % self.gradient_accumulation_steps != 0: + raise ConfigValidationError( + f"`effective_batch_size` ({self.effective_batch_size}) must be divisible by " + f"`gradient_accumulation_steps` ({self.gradient_accumulation_steps})." + ) + learning_rate: Union[float, str] = schema_utils.OneOfOptionsField( default=0.001, allow_none=False, @@ -159,8 +186,27 @@ def __post_init__(self): parameter_metadata=TRAINER_METADATA[MODEL_ECD]["steps_per_checkpoint"], ) + effective_batch_size: Union[int, str] = schema_utils.OneOfOptionsField( + default=AUTO, + allow_none=False, + description=( + "The effective batch size is the total number of samples used to compute a single gradient update " + "to the model weights. This differs from `batch_size` by taking `gradient_accumulation_steps` and number " + "of training worker processes into account. In practice, " + "`effective_batch_size = batch_size * gradient_accumulation_steps * num_workers`. 
" + "If 'auto', the effective batch size is derivied implicitly from `batch_size`, but if set explicitly, then " + "one of `batch_size` or `gradient_accumulation_steps` must be set to something other than 'auto', and " + "consequently will be set following the formula given above." + ), + parameter_metadata=TRAINER_METADATA[MODEL_ECD]["effective_batch_size"], + field_options=[ + schema_utils.PositiveInteger(default=128, description="", allow_none=False), + schema_utils.StringOptions(options=["auto"], default="auto", allow_none=False), + ], + ) + batch_size: Union[int, str] = schema_utils.OneOfOptionsField( - default=DEFAULT_BATCH_SIZE, + default=AUTO, allow_none=False, description=( "The number of training examples utilized in one training step of the model. If ’auto’, the " @@ -185,6 +231,17 @@ def __post_init__(self): parameter_metadata=TRAINER_METADATA[MODEL_ECD]["max_batch_size"], ) + gradient_accumulation_steps: Union[int, str] = schema_utils.OneOfOptionsField( + default=AUTO, + allow_none=False, + description="Number of steps to accumulate gradients over before performing a weight update.", + parameter_metadata=TRAINER_METADATA[MODEL_ECD]["gradient_accumulation_steps"], + field_options=[ + schema_utils.PositiveInteger(default=1, description="", allow_none=False), + schema_utils.StringOptions(options=["auto"], default="auto", allow_none=False), + ], + ) + early_stop: int = schema_utils.IntegerRange( default=5, min=-1, @@ -343,11 +400,10 @@ def __post_init__(self): parameter_metadata=TRAINER_METADATA[MODEL_ECD]["compile"], ) - gradient_accumulation_steps: int = schema_utils.PositiveInteger( - default=1, - description="Number of steps to accumulate gradients over before performing a weight update.", - parameter_metadata=TRAINER_METADATA[MODEL_ECD]["gradient_accumulation_steps"], - ) + def update_batch_size_grad_accum(self, num_workers: int): + from ludwig.utils.trainer_utils import get_rendered_batch_size_grad_accum + + self.batch_size, self.gradient_accumulation_steps = get_rendered_batch_size_grad_accum(self, num_workers) @DeveloperAPI @@ -705,6 +761,9 @@ class GBMTrainerConfig(BaseTrainerConfig): parameter_metadata=TRAINER_METADATA[MODEL_GBM]["feature_pre_filter"], ) + def can_tune_batch_size(self) -> bool: + return False + @DeveloperAPI @ludwig_dataclass @@ -795,6 +854,9 @@ class NoneTrainerConfig(LLMTrainerConfig): parameter_metadata=TRAINER_METADATA[MODEL_LLM]["type"], ) + def can_tune_batch_size(self) -> bool: + return False + @DeveloperAPI @register_llm_trainer_schema("finetune") diff --git a/ludwig/trainers/base.py b/ludwig/trainers/base.py index baa41f58ad8..1d56461aba4 100644 --- a/ludwig/trainers/base.py +++ b/ludwig/trainers/base.py @@ -26,6 +26,7 @@ def tune_batch_size( random_seed: int = default_random_seed, max_trials: int = 10, halving_limit: int = 3, + tune_for_training: bool = True, ) -> int: raise NotImplementedError() diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 7e6bb3d1294..4af8b705c9f 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -31,7 +31,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from ludwig.constants import LOSS, MAX_CPU_BATCH_SIZE, MINIMIZE, MODEL_ECD, TEST, TRAIN, TRAINING, VALIDATION +from ludwig.constants import AUTO, LOSS, MAX_CPU_BATCH_SIZE, MINIMIZE, MODEL_ECD, TEST, TRAIN, TRAINING, VALIDATION from ludwig.data.dataset.base import Dataset from ludwig.distributed.base import DistributedStrategy, LocalStrategy from ludwig.globals import ( @@ -139,6 +139,7 @@ def __init__( 
self.regularization_lambda = config.regularization_lambda self.regularization_type = config.regularization_type self.batch_size = config.batch_size + self.effective_batch_size = config.effective_batch_size self.max_batch_size = config.max_batch_size self.eval_batch_size = config.batch_size if config.eval_batch_size is None else config.eval_batch_size self.should_shuffle = config.should_shuffle @@ -154,7 +155,9 @@ def __init__( self.increase_batch_size_eval_metric = config.increase_batch_size_eval_metric self.increase_batch_size_eval_split = config.increase_batch_size_eval_split self.gradient_accumulation_steps = ( - config.gradient_accumulation_steps if self.distributed.allow_gradient_accumulation() else 1 + config.gradient_accumulation_steps + if self.distributed.allow_gradient_accumulation() and config.gradient_accumulation_steps != AUTO + else 1 ) self.resume = resume self.skip_save_model = skip_save_model @@ -168,12 +171,6 @@ def __init__( if self.device is None: self.device = get_torch_device() - base_learning_rate = config.learning_rate - if self.distributed: - lr_scale_fn = learning_rate_scale_fns[config.learning_rate_scaling] - base_learning_rate *= lr_scale_fn(self.distributed.size() * self.gradient_accumulation_steps) - self.base_learning_rate = base_learning_rate - self.model = model self.model.prepare_for_training() self.model = self.distributed.to_device(self.model) @@ -207,6 +204,12 @@ def __init__( self.original_sigint_handler = None def prepare(self): + base_learning_rate = self.config.learning_rate + if self.distributed: + lr_scale_fn = learning_rate_scale_fns[self.config.learning_rate_scaling] + base_learning_rate *= lr_scale_fn(self.distributed.size() * self.gradient_accumulation_steps) + self.base_learning_rate = base_learning_rate + self.dist_model, self.optimizer = self.distributed.prepare( self.compiled_model, self.config, @@ -372,6 +375,7 @@ def tune_batch_size( halving_limit: int = 3, snapshot_weights: bool = True, on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None, + tune_for_training: bool = True, ) -> int: logger.info("Tuning batch size...") skip_save_model = self.skip_save_model @@ -388,8 +392,19 @@ def tune_batch_size( max_batch_size = ( self.max_batch_size if torch.cuda.is_available() else min(self.max_batch_size, MAX_CPU_BATCH_SIZE) ) + + if self.effective_batch_size != AUTO: + # If an effective batch size is set, we must ensure that batch size tuning doesn't exceed it + max_batch_size = min(self.effective_batch_size, max_batch_size) + + if not tune_for_training: + # No need to save and restore model and optimizer states, as they aren't modified during predict + snapshot_weights = False + self.dist_model.train() # Sets model training mode. - evaluator = self._create_batch_size_evaluator() + evaluator = ( + self._create_batch_size_evaluator() if tune_for_training else self._create_predict_batch_size_evaluator() + ) with tempfile.TemporaryDirectory() as tmpdir: if snapshot_weights: # Save a snapshot of the model and optimizer state to restore later, as they will be modified @@ -403,12 +418,23 @@ def tune_batch_size( best_batch_size = evaluator.select_best_batch_size( len(training_set), max_batch_size, max_trials, self.is_coordinator() ) - return self.distributed.broadcast_object(best_batch_size) + best_batch_size = self.distributed.broadcast_object(best_batch_size) + + if tune_for_training: + # Update batch size / gradient accumulation before preparing the trainer. 
This is needed primarily + # for DeepSpeed, which needs to know the batch size and gradient accumulation steps before init + self.config.batch_size = best_batch_size + self.config.update_batch_size_grad_accum(self.distributed.size()) + self.batch_size = self.config.batch_size + self.gradient_accumulation_steps = self.config.gradient_accumulation_steps + + return best_batch_size finally: # Restore original parameters to defaults self.skip_save_model = skip_save_model self.skip_save_progress = skip_save_progress self.skip_save_log = skip_save_log + if snapshot_weights: # Restore the model weights prior to batch size tuning to undo any updates made to the weights if self.distributed.prepare_before_load(): @@ -438,6 +464,29 @@ def step(self, batch_size: int): return _TrainerBatchSizeEvaluator() + def _create_predict_batch_size_evaluator(self) -> BatchSizeEvaluator: + trainer = self + + class _PredictBatchSizeEvaluator(BatchSizeEvaluator): + def reset(self): + trainer.model.reset_metrics() + trainer.optimizer.zero_grad() + + def step(self, batch_size: int): + trainer.distributed.set_batch_size(trainer.dist_model, batch_size) + inputs = { + input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(trainer.device) + for input_feature_name, input_feature in trainer.model.input_features.items() + } + targets = { + output_feature_name: output_feature.create_sample_output(batch_size=batch_size).to(trainer.device) + for output_feature_name, output_feature in trainer.model.output_features.items() + } + with torch.no_grad(): + trainer.dist_model((inputs, targets)) + + return _PredictBatchSizeEvaluator() + def run_evaluation( self, training_set, diff --git a/ludwig/trainers/trainer_lightgbm.py b/ludwig/trainers/trainer_lightgbm.py index bc4ff113529..88d59901bfb 100644 --- a/ludwig/trainers/trainer_lightgbm.py +++ b/ludwig/trainers/trainer_lightgbm.py @@ -170,6 +170,7 @@ def tune_batch_size( random_seed: int, max_trials: int = 10, halving_limit: int = 3, + tune_for_training: bool = True, ) -> int: raise NotImplementedError("Tuning batch size is not supported for LightGBM.") diff --git a/ludwig/trainers/trainer_llm.py b/ludwig/trainers/trainer_llm.py index 8d1c28a6b1c..7dbf242aef2 100644 --- a/ludwig/trainers/trainer_llm.py +++ b/ludwig/trainers/trainer_llm.py @@ -204,6 +204,7 @@ def tune_batch_size( halving_limit: int = 3, snapshot_weights: bool = True, on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None, + tune_for_training: bool = True, ) -> int: # TODO: Implement batch size tuning for LLM, currently just returns the default batch size # Compared to ECD, this just requires forward passes till we OOM. 
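Reviewer note: the rendering helper added in ludwig/utils/trainer_utils.py below maintains the relationship effective_batch_size = batch_size * gradient_accumulation_steps * num_workers, filling in whichever of batch_size / gradient_accumulation_steps is still "auto" once the other values are known, and leaving everything as "auto" when it cannot. A minimal sketch of the expected behavior follows; the values mirror the parametrized cases in tests/ludwig/utils/test_trainer_utils.py and no names beyond those in this diff are introduced.

from ludwig.constants import AUTO
from ludwig.schema.trainer import ECDTrainerConfig
from ludwig.utils.trainer_utils import get_rendered_batch_size_grad_accum

# batch_size is "auto": 128 effective / 4 grad-accum steps / 2 workers -> batch_size 16
config = ECDTrainerConfig.from_dict(
    {"effective_batch_size": 128, "batch_size": AUTO, "gradient_accumulation_steps": 4}
)
assert get_rendered_batch_size_grad_accum(config, num_workers=2) == (16, 4)

# gradient_accumulation_steps is "auto": 128 effective / batch_size 16 / 2 workers -> 4 steps
config = ECDTrainerConfig.from_dict(
    {"effective_batch_size": 128, "batch_size": 16, "gradient_accumulation_steps": AUTO}
)
assert get_rendered_batch_size_grad_accum(config, num_workers=2) == (16, 4)

# Everything "auto": nothing can be rendered until batch size tuning has run,
# so both values are left as "auto" for the trainer to resolve later.
config = ECDTrainerConfig.from_dict(
    {"effective_batch_size": AUTO, "batch_size": AUTO, "gradient_accumulation_steps": AUTO}
)
assert get_rendered_batch_size_grad_accum(config, num_workers=2) == (AUTO, AUTO)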
diff --git a/ludwig/utils/trainer_utils.py b/ludwig/utils/trainer_utils.py index c963394b8ea..13c2360423b 100644 --- a/ludwig/utils/trainer_utils.py +++ b/ludwig/utils/trainer_utils.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict, OrderedDict -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, TYPE_CHECKING try: from typing import Literal @@ -8,13 +8,16 @@ from typing_extensions import Literal from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import COMBINED, LOSS -from ludwig.features.base_feature import OutputFeature +from ludwig.constants import AUTO, COMBINED, LOSS from ludwig.models.base import BaseModel from ludwig.modules.metric_modules import get_best_function from ludwig.utils.data_utils import save_json from ludwig.utils.metric_utils import TrainerMetric +if TYPE_CHECKING: + from ludwig.features.base_feature import OutputFeature + from ludwig.schema.trainer import BaseTrainerConfig + logger = logging.getLogger(__name__) @@ -52,7 +55,7 @@ def get_new_progress_tracker( best_eval_metric_value: float, best_increase_batch_size_eval_metric: float, learning_rate: float, - output_features: Dict[str, OutputFeature], + output_features: Dict[str, "OutputFeature"], ): """Returns a new instance of a ProgressTracker with empty metrics.""" return ProgressTracker( @@ -357,3 +360,22 @@ def get_training_report( ] ) return training_report + + +def get_rendered_batch_size_grad_accum(config: "BaseTrainerConfig", num_workers: int) -> Tuple[int, int]: + effective_batch_size = config.effective_batch_size + batch_size = config.batch_size + gradient_accumulation_steps = config.gradient_accumulation_steps + + if config.batch_size == AUTO: + if config.effective_batch_size != AUTO and config.gradient_accumulation_steps != AUTO: + batch_size = max(int(effective_batch_size / gradient_accumulation_steps / num_workers), 1) + + if config.gradient_accumulation_steps == AUTO: + if config.batch_size != AUTO: + if config.effective_batch_size != AUTO: + gradient_accumulation_steps = max(int(effective_batch_size / batch_size / num_workers), 1) + else: + gradient_accumulation_steps = 1 + + return batch_size, gradient_accumulation_steps diff --git a/tests/integration_tests/test_trainer.py b/tests/integration_tests/test_trainer.py index b7f623875e8..b973de8e384 100644 --- a/tests/integration_tests/test_trainer.py +++ b/tests/integration_tests/test_trainer.py @@ -88,8 +88,9 @@ def test_tune_learning_rate(tmpdir): @pytest.mark.parametrize("is_cpu", [True, False]) +@pytest.mark.parametrize("effective_batch_size", ["auto", 256]) @pytest.mark.parametrize("eval_batch_size", ["auto", None, 128]) -def test_tune_batch_size_and_lr(tmpdir, eval_batch_size, is_cpu): +def test_tune_batch_size_and_lr(tmpdir, eval_batch_size, effective_batch_size, is_cpu): input_features = [sequence_feature(encoder={"reduce_output": "sum"})] output_features = [ category_feature(decoder={"vocab_size": 2}, reduce_input="sum"), @@ -106,7 +107,9 @@ def test_tune_batch_size_and_lr(tmpdir, eval_batch_size, is_cpu): trainer = { "epochs": 2, + "effective_batch_size": effective_batch_size, "batch_size": "auto", + "gradient_accumulation_steps": "auto", "learning_rate": "auto", } @@ -123,7 +126,9 @@ def test_tune_batch_size_and_lr(tmpdir, eval_batch_size, is_cpu): model = LudwigModel(config, backend=LocalTestBackend(), logging_level=logging.INFO) # check preconditions + assert model.config_obj.trainer.effective_batch_size == effective_batch_size assert model.config_obj.trainer.batch_size == "auto" 
+ assert model.config_obj.trainer.gradient_accumulation_steps == "auto" assert model.config_obj.trainer.eval_batch_size == eval_batch_size assert model.config_obj.trainer.learning_rate == "auto" @@ -135,9 +140,18 @@ def test_tune_batch_size_and_lr(tmpdir, eval_batch_size, is_cpu): def check_postconditions(model): # check batch size + assert model.config_obj.trainer.effective_batch_size == effective_batch_size assert model.config_obj.trainer.batch_size != "auto" assert model.config_obj.trainer.batch_size > 1 + # check gradient accumulation + assert model.config_obj.trainer.gradient_accumulation_steps != "auto" + if effective_batch_size == "auto": + assert model.config_obj.trainer.gradient_accumulation_steps == 1 + else: + batch_size = model.config_obj.trainer.batch_size + assert model.config_obj.trainer.gradient_accumulation_steps == effective_batch_size // batch_size + # 4 is the largest possible batch size for this dataset (20% of dataset size) assert model.config_obj.trainer.batch_size <= MAX_BATCH_SIZE_DATASET_FRACTION * num_samples diff --git a/tests/ludwig/utils/test_trainer_utils.py b/tests/ludwig/utils/test_trainer_utils.py index c548f854e73..3979bbdc9fc 100644 --- a/tests/ludwig/utils/test_trainer_utils.py +++ b/tests/ludwig/utils/test_trainer_utils.py @@ -1,11 +1,13 @@ from collections import OrderedDict +from typing import Union import pytest -from ludwig.constants import BATCH_SIZE, COMBINED, LOSS +from ludwig.constants import AUTO, BATCH_SIZE, COMBINED, LOSS from ludwig.features.category_feature import CategoryOutputFeature from ludwig.features.feature_utils import LudwigFeatureDict from ludwig.schema.features.category_feature import ECDCategoryOutputFeatureConfig +from ludwig.schema.trainer import ECDTrainerConfig from ludwig.schema.utils import load_config_with_kwargs from ludwig.utils import trainer_utils from ludwig.utils.metric_utils import TrainerMetric @@ -316,3 +318,36 @@ def test_get_final_steps_per_checkpoint(): ) == 1024 ) + + +@pytest.mark.parametrize( + "effective_batch_size,batch_size,gradient_accumulation_steps,num_workers,expected_batch_size,expected_grad_accum", + [ + (128, 16, 4, 2, 16, 4), + (AUTO, 16, 4, 2, 16, 4), + (128, 16, AUTO, 2, 16, 4), + (128, AUTO, 4, 2, 16, 4), + (128, AUTO, AUTO, 2, AUTO, AUTO), + (AUTO, AUTO, AUTO, 2, AUTO, AUTO), + (AUTO, 16, AUTO, 2, 16, 1), + (AUTO, AUTO, 4, 2, AUTO, 4), + ], +) +def test_get_rendered_batch_size_grad_accum( + effective_batch_size: Union[str, int], + batch_size: Union[str, int], + gradient_accumulation_steps: Union[str, int], + num_workers: int, + expected_batch_size: int, + expected_grad_accum: int, +): + config = ECDTrainerConfig.from_dict( + { + "effective_batch_size": effective_batch_size, + "batch_size": batch_size, + "gradient_accumulation_steps": gradient_accumulation_steps, + } + ) + rendered_batch_size, rendered_grad_accum = trainer_utils.get_rendered_batch_size_grad_accum(config, num_workers) + assert rendered_batch_size == expected_batch_size + assert rendered_grad_accum == expected_grad_accum
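Reviewer note: end-to-end, the new options compose as exercised by the updated integration test above. A hedged usage sketch under the same assumptions (single local training worker, so num_workers == 1); the trainer section mirrors the test, while the feature definitions and the "train.csv" path are placeholders, not taken from this diff.

from ludwig.api import LudwigModel

config = {
    "input_features": [{"name": "text", "type": "sequence"}],   # placeholder features
    "output_features": [{"name": "label", "type": "category"}],
    "trainer": {
        "epochs": 2,
        "effective_batch_size": 256,            # set explicitly
        "batch_size": "auto",                   # tuned at training time
        "gradient_accumulation_steps": "auto",  # rendered from the two values above
        "learning_rate": "auto",
    },
}

model = LudwigModel(config)
model.train(dataset="train.csv")  # placeholder dataset path

# After training, batch_size has been tuned (capped at effective_batch_size) and
# gradient_accumulation_steps rendered so that
# batch_size * gradient_accumulation_steps * num_workers ~= effective_batch_size.
# With effective_batch_size == "auto", gradient_accumulation_steps would instead be 1.
batch_size = model.config_obj.trainer.batch_size
assert model.config_obj.trainer.gradient_accumulation_steps == 256 // batch_size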