Add effective_batch_size to auto-adjust gradient accumulation #3533

Merged
merged 29 commits on Aug 23, 2023

Changes from all commits (29)
d981c51
WIP: effective_batch_size
tgaddair Aug 10, 2023
e0b556a
Auto update gradient_accumulation_steps
tgaddair Aug 11, 2023
a1c3367
Set num_training_workers
tgaddair Aug 11, 2023
96f0bcb
Use num_training_workers
tgaddair Aug 11, 2023
2954298
Fixed gradient accum
tgaddair Aug 12, 2023
9e8e373
Fixed schema
tgaddair Aug 12, 2023
7afa0f6
WIP: update grad_accum in tuner
tgaddair Aug 14, 2023
ec797b7
Fixed rendering batch_size and gradient_accumulation_steps
tgaddair Aug 15, 2023
27ff1c0
Fixed backend num_workers
tgaddair Aug 15, 2023
bf849a6
Remove restriction on setting everything to auto
tgaddair Aug 15, 2023
70903f6
Unit tests
tgaddair Aug 15, 2023
6f49cfb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2023
107d4b2
Fixed none trainer
tgaddair Aug 15, 2023
1b4faff
Merge
tgaddair Aug 15, 2023
ee615fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2023
65f882e
GBM
tgaddair Aug 15, 2023
ddee785
Merge branch 'total-bs' of https://github.com/ludwig-ai/ludwig into t…
tgaddair Aug 15, 2023
39b73c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2023
7cf6f75
Fixed tuning gbms
tgaddair Aug 16, 2023
73c1940
Merge
tgaddair Aug 16, 2023
07381bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2023
a09705a
Merge branch 'master' into total-bs
tgaddair Aug 16, 2023
a7c2170
Fixed try finally
tgaddair Aug 18, 2023
6ce70f9
Tune for training
tgaddair Aug 18, 2023
1d4b33e
Break out batch size tuning
tgaddair Aug 18, 2023
3885f80
Plumb ray backend
tgaddair Aug 18, 2023
9ce68e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2023
cbc25cd
Fixed signature
tgaddair Aug 21, 2023
49d264a
Merge branch 'total-bs' of https://github.com/ludwig-ai/ludwig into t…
tgaddair Aug 21, 2023
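
For context on the feature itself: the effective batch size this PR introduces is the total number of samples contributing to a single optimizer step across all data-parallel workers. A purely illustrative sketch of the arithmetic, with made-up numbers:

```python
# Illustrative arithmetic only (hypothetical numbers): the effective batch size is the
# number of samples contributing to one optimizer step across all training workers.
batch_size = 128                  # per-worker micro-batch size
gradient_accumulation_steps = 2   # backward passes accumulated per optimizer step
num_training_workers = 4          # data-parallel workers (see the backend changes below)
effective_batch_size = batch_size * gradient_accumulation_steps * num_training_workers
print(effective_batch_size)  # 1024
```

With effective_batch_size fixed, gradient_accumulation_steps can be auto-adjusted once the per-worker batch_size is known (set explicitly or tuned) and the backend reports its worker count.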
67 changes: 44 additions & 23 deletions ludwig/api.py
@@ -613,11 +613,7 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
random_seed=random_seed,
) as trainer:
# auto tune batch size
if (
self.config_obj.trainer.to_dict().get(BATCH_SIZE, None) == AUTO
or self.config_obj.trainer.to_dict().get(EVAL_BATCH_SIZE, None) == AUTO
):
self._tune_batch_size(trainer, training_set, random_seed=random_seed)
self._tune_batch_size(trainer, training_set, random_seed=random_seed)

# train model
if self.backend.is_coordinator():
@@ -795,35 +791,60 @@ def train_online(
config=self.config_obj.trainer, model=self.model, random_seed=random_seed
)

if (
self.config_obj.trainer.to_dict().get(BATCH_SIZE, None) == AUTO
or self.config_obj.trainer.to_dict().get(EVAL_BATCH_SIZE, None) == AUTO
):
self._tune_batch_size(self._online_trainer, dataset, random_seed=random_seed)
self._tune_batch_size(self._online_trainer, dataset, random_seed=random_seed)

self.model = self._online_trainer.train_online(training_dataset)

def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_seed):
if not self.config_obj.trainer.can_tune_batch_size():
# Models like GBMs don't have batch sizes to be tuned
return
Comment on lines +799 to +801

Contributor: Is it worth logging a message here to indicate this very clearly?

Collaborator (author): This code path always gets executed regardless of user config, so I wouldn't add a message. Would likely confuse the user.

# Render the batch size and gradient accumulation steps prior to batch size tuning. This is needed in the event
# the effective_batch_size and gradient_accumulation_steps are set explicitly, but batch_size is AUTO. In this
# case, we can infer the batch_size directly without tuning.
num_workers = self.backend.num_training_workers
self.config_obj.trainer.update_batch_size_grad_accum(num_workers)

# TODO (ASN): add support for substitute_with_max parameter
# TODO(travis): detect train and eval batch sizes separately (enable / disable gradients)
if self.backend.supports_batch_size_tuning():
tuned_batch_size = trainer.tune_batch_size(self.config_obj.to_dict(), dataset, random_seed=random_seed)
else:
logger.warning(
f"Backend {self.backend.BACKEND_TYPE} does not support batch size tuning, "
f"using fallback batch size {FALLBACK_BATCH_SIZE}."
)
tuned_batch_size = FALLBACK_BATCH_SIZE

# TODO(travis): pass these in as args to trainer when we call train,
# to avoid setting state on possibly remote trainer
if self.config_obj.trainer.batch_size == AUTO:
if self.backend.supports_batch_size_tuning():
tuned_batch_size = trainer.tune_batch_size(
self.config_obj.to_dict(), dataset, random_seed=random_seed, tune_for_training=True
)
else:
logger.warning(
f"Backend {self.backend.BACKEND_TYPE} does not support batch size tuning, "
f"using fallback training batch size {FALLBACK_BATCH_SIZE}."
)
tuned_batch_size = FALLBACK_BATCH_SIZE

# TODO(travis): pass these in as args to trainer when we call train,
# to avoid setting state on possibly remote trainer
self.config_obj.trainer.batch_size = tuned_batch_size
trainer.batch_size = tuned_batch_size

# Re-render the gradient_accumulation_steps to account for the explicit batch size.
self.config_obj.trainer.update_batch_size_grad_accum(num_workers)

if self.config_obj.trainer.eval_batch_size in {AUTO, None}:
if self.backend.supports_batch_size_tuning():
tuned_batch_size = trainer.tune_batch_size(
self.config_obj.to_dict(), dataset, random_seed=random_seed, tune_for_training=False
)
else:
logger.warning(
f"Backend {self.backend.BACKEND_TYPE} does not support batch size tuning, "
f"using fallback eval batch size {FALLBACK_BATCH_SIZE}."
)
tuned_batch_size = FALLBACK_BATCH_SIZE

self.config_obj.trainer.eval_batch_size = tuned_batch_size
trainer.eval_batch_size = tuned_batch_size

# Update trainer params separate to config params for backends with stateful trainers
trainer.batch_size = self.config_obj.trainer.batch_size
trainer.eval_batch_size = self.config_obj.trainer.eval_batch_size
trainer.gradient_accumulation_steps = self.config_obj.trainer.gradient_accumulation_steps

def predict(
self,
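The comment in the hunk above notes that batch_size and gradient_accumulation_steps are rendered before tuning, so an explicit effective_batch_size plus gradient_accumulation_steps lets batch_size be inferred without tuning. A minimal sketch of that rendering logic, assuming the relationship effective = batch_size × grad_accum × num_workers; the real `update_batch_size_grad_accum` lives on the trainer config object and may differ in detail:

```python
# A minimal sketch (not the actual implementation) of the batch_size /
# gradient_accumulation_steps rendering referenced in the diff above.
AUTO = "auto"


def render_batch_size_grad_accum(effective_batch_size, batch_size, grad_accum, num_workers):
    if batch_size == AUTO and AUTO not in (effective_batch_size, grad_accum):
        # Case called out in the comment: batch_size is AUTO but the other two are
        # explicit, so batch_size can be inferred directly without tuning.
        batch_size = max(1, effective_batch_size // (grad_accum * num_workers))
    elif grad_accum == AUTO and AUTO not in (effective_batch_size, batch_size):
        # Once batch_size is known (explicit or tuned), derive gradient accumulation.
        grad_accum = max(1, effective_batch_size // (batch_size * num_workers))
    return batch_size, grad_accum


# Re-rendering after tuning: a tuned per-worker batch of 64 on 4 workers with an
# effective batch size of 1024 yields 4 accumulation steps.
print(render_batch_size_grad_accum(1024, 64, AUTO, 4))  # (64, 4)
```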
13 changes: 13 additions & 0 deletions ludwig/backend/base.py
@@ -118,6 +118,11 @@ def read_binary_files(self, column: Series, map_fn: Optional[Callable] = None) -
def num_nodes(self) -> int:
raise NotImplementedError()

@property
@abstractmethod
def num_training_workers(self) -> int:
raise NotImplementedError()

@abstractmethod
def get_available_resources(self) -> Resources:
raise NotImplementedError()
@@ -250,6 +255,10 @@ def __init__(self, **kwargs):
def num_nodes(self) -> int:
return 1

@property
def num_training_workers(self) -> int:
return 1

def get_available_resources(self) -> Resources:
return Resources(cpus=psutil.cpu_count(), gpus=torch.cuda.device_count())

@@ -308,6 +317,10 @@ def is_coordinator(self):

@property
def num_nodes(self) -> int:
return self._distributed.size() // self._distributed.local_size()

@property
def num_training_workers(self) -> int:
return self._distributed.size()

def get_available_resources(self) -> Resources:
29 changes: 23 additions & 6 deletions ludwig/backend/ray.py
@@ -250,6 +250,7 @@ def tune_batch_size_fn(
training_set_metadata: TrainingSetMetadataDict = None,
features: Dict[str, Dict] = None,
remote_trainer_cls: Callable[[], Trainer] = None,
tune_for_training: bool = True,
**kwargs,
):
# Pin GPU before loading the model to prevent memory leaking onto other devices
@@ -276,6 +277,7 @@ def on_best_batch_size_updated(best_batch_size: int, best_samples_per_sec: float
train_shard,
snapshot_weights=False,
on_best_batch_size_updated=on_best_batch_size_updated,
tune_for_training=tune_for_training,
**kwargs,
)
session.report(
@@ -539,6 +541,7 @@ def tune_batch_size(
self,
config: ModelConfigDict,
training_set: RayDataset,
tune_for_training: bool = True,
**kwargs,
) -> int:
with create_runner(**self.trainer_kwargs) as runner:
@@ -552,6 +555,7 @@
ludwig_config=config,
training_set_metadata=training_set.training_set_metadata,
features=training_set.features,
tune_for_training=tune_for_training,
**kwargs,
),
exception_on_error=False,
@@ -592,6 +596,14 @@ def eval_batch_size(self) -> int:
def eval_batch_size(self, value: int):
self.config.eval_batch_size = value

@property
def gradient_accumulation_steps(self) -> int:
return self.config.gradient_accumulation_steps

@gradient_accumulation_steps.setter
def gradient_accumulation_steps(self, value: int):
self.config.gradient_accumulation_steps = value

@property
def resources_per_worker(self) -> Dict[str, Any]:
trainer_kwargs = get_trainer_kwargs(**self.trainer_kwargs)
@@ -876,7 +888,7 @@ def __init__(
super().__init__(dataset_manager=RayDatasetManager(self), **kwargs)
self._preprocessor_kwargs = preprocessor_kwargs or {}
self._df_engine = _get_df_engine(processor)
self._horovod_kwargs = trainer or {}
self._distributed_kwargs = trainer or {}
self._pytorch_kwargs = {}
self._data_loader_kwargs = loader or {}
self._preprocessor_pg = None
@@ -943,7 +955,7 @@ def create_trainer(self, model: BaseModel, **kwargs) -> "BaseTrainer":  # noqa:

all_kwargs = {
"model": model,
"trainer_kwargs": self._horovod_kwargs,
"trainer_kwargs": self._distributed_kwargs,
"data_loader_kwargs": self._data_loader_kwargs,
"executable_kwargs": executable_kwargs,
}
@@ -956,18 +968,18 @@ def create_predictor(self, model: BaseModel, **kwargs):
return RayPredictor(
model,
self.df_engine,
self._horovod_kwargs,
self._distributed_kwargs,
self._data_loader_kwargs,
**executable_kwargs,
)

@property
def distributed_kwargs(self):
return self._horovod_kwargs
return self._distributed_kwargs

@distributed_kwargs.setter
def distributed_kwargs(self, value):
self._horovod_kwargs = value
self._distributed_kwargs = value

@property
def df_engine(self):
@@ -1068,6 +1080,11 @@ def num_nodes(self) -> int:
return 1
return len(ray.nodes())

@property
def num_training_workers(self) -> int:
trainer_kwargs = get_trainer_kwargs(**self._distributed_kwargs)
return trainer_kwargs["num_workers"]

def get_available_resources(self) -> Resources:
resources = ray.cluster_resources()
return Resources(cpus=resources.get("CPU", 0), gpus=resources.get("GPU", 0))
@@ -1122,7 +1139,7 @@ def batch_transform(self, df: DataFrame, batch_size: int, transform_fn: Callable
return self.df_engine.from_ray_dataset(ds)

def _get_transform_kwargs(self) -> Dict[str, Any]:
trainer_kwargs = get_trainer_kwargs(**self._horovod_kwargs)
trainer_kwargs = get_trainer_kwargs(**self._distributed_kwargs)
resources_per_worker = trainer_kwargs.get("resources_per_worker", {})
num_gpus = resources_per_worker.get("GPU", 0)
num_cpus = resources_per_worker.get("CPU", (1 if num_gpus == 0 else 0))
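The new tune_for_training flag is threaded from the backend down to the batch size evaluator. Per the TODO in the api.py diff ("enable / disable gradients"), the idea is that the training probe must account for gradient memory while the eval probe does not, which is why train and eval batch sizes are tuned in separate passes. A simplified illustration under those assumptions — `probe_batch` is a hypothetical stand-in for the evaluator, not code from this PR:

```python
import torch
from torch import nn


def probe_batch(model: nn.Module, batch: torch.Tensor, tune_for_training: bool) -> None:
    """Run one forward (and optionally backward) pass to exercise peak memory."""
    if tune_for_training:
        model.train()
        loss = model(batch).sum()
        loss.backward()   # gradient buffers are allocated, as in a real training step
        model.zero_grad()
    else:
        model.eval()
        with torch.no_grad():  # no gradient buffers, so larger eval batches usually fit
            model(batch)


model = nn.Linear(32, 4)
probe_batch(model, torch.randn(64, 32), tune_for_training=True)
probe_batch(model, torch.randn(256, 32), tune_for_training=False)
```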
6 changes: 5 additions & 1 deletion ludwig/distributed/horovod.py
@@ -13,6 +13,7 @@
from torch import nn
from torch.optim import Optimizer

from ludwig.constants import AUTO
from ludwig.distributed.base import DistributedStrategy
from ludwig.modules.optimization_modules import create_optimizer
from ludwig.utils.horovod_utils import gather_all_tensors, is_distributed_available
@@ -35,10 +36,13 @@ def prepare(
base_learning_rate: float,
) -> Tuple[nn.Module, Optimizer]:
optimizer = create_optimizer(model, trainer_config.optimizer, base_learning_rate)
grad_accum_steps = (
trainer_config.gradient_accumulation_steps if trainer_config.gradient_accumulation_steps != AUTO else 1
)
dist_optimizer = hvd.DistributedOptimizer(
optimizer,
named_parameters=model.named_parameters(),
backward_passes_per_step=trainer_config.gradient_accumulation_steps,
backward_passes_per_step=grad_accum_steps,
)
return model, dist_optimizer

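On the Horovod path, backward_passes_per_step tells the distributed optimizer how many backward passes to accumulate before synchronizing gradients, which is how gradient accumulation composes with data parallelism. A rough, self-contained sketch of the surrounding training loop, assuming a toy model and random data (scaling the loss by the accumulation count is a common but optional convention, not something this PR prescribes):

```python
import horovod.torch as hvd
import torch
from torch import nn

hvd.init()
torch.manual_seed(0)

grad_accum_steps = 4  # hypothetical; per the diff above, AUTO falls back to 1
model = nn.Linear(8, 1)
loss_fn = nn.MSELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    backward_passes_per_step=grad_accum_steps,
)

for step in range(8):
    inputs, targets = torch.randn(16, 8), torch.randn(16, 1)
    loss = loss_fn(model(inputs), targets) / grad_accum_steps  # optional averaging across passes
    loss.backward()  # gradients accumulate locally; Horovod defers synchronization
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()       # apply the accumulated, synchronized gradients
        optimizer.zero_grad()
```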
7 changes: 7 additions & 0 deletions ludwig/schema/metadata/configs/trainer.yaml
@@ -1,4 +1,11 @@
ecd:
effective_batch_size:
commonly_used: true
expected_impact: 2
related_parameters:
- batch_size
suggested_values: auto
ui_display_name: Effective Batch Size
batch_size:
commonly_used: true
default_value_reasoning: Not too big, not too small.
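For reference, a hypothetical user-facing sketch exercising the new field: effective_batch_size is set explicitly while batch_size and gradient_accumulation_steps stay on auto, so accumulation is derived from the tuned batch size and the number of training workers. Feature names and the dataset path are made up for illustration:

```python
from ludwig.api import LudwigModel

config = {
    "input_features": [{"name": "review", "type": "text"}],
    "output_features": [{"name": "sentiment", "type": "category"}],
    "trainer": {
        "effective_batch_size": 1024,            # total samples per optimizer step
        "batch_size": "auto",                    # tuned per worker
        "gradient_accumulation_steps": "auto",   # derived: effective / (batch_size * workers)
    },
}

model = LudwigModel(config)
# model.train(dataset="reviews.csv")  # hypothetical dataset
```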