Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deepspeed enhancements and fixes #676

Merged
merged 4 commits into from Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 24 additions & 19 deletions src/accelerate/accelerator.py
Expand Up @@ -723,25 +723,28 @@ def _prepare_deepspeed(self, *args):

deepspeed_plugin = self.state.deepspeed_plugin

result = [
self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
for obj in args
]

batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
if self.split_batches:
batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
if len(batch_sizes) == 0:
raise ValueError(
"You must specify a training or evaluation dataloader in `accelerate.prepare()` when using DeepSpeed."
)
if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto":
result = [
self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
for obj in args
]

batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
if self.split_batches:
batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
if len(batch_sizes) == 0:
raise ValueError(
"You must specify a training or evaluation dataloader in `accelerate.prepare()` when using DeepSpeed."
)

batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
if len(batch_sizes) > 1:
logger.info(
"Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
)
batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
if len(batch_sizes) > 1:
logger.info(
"Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
)
else:
batch_size_per_device = deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"]

config_kwargs = {
"train_micro_batch_size_per_gpu": batch_size_per_device,
Expand Down Expand Up @@ -916,7 +919,9 @@ def backward(self, loss, **kwargs):

Should be used in lieu of `loss.backward()`.
"""
loss /= self.gradient_accumulation_steps
if self.distributed_type != DistributedType.DEEPSPEED:
# deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
loss = loss / self.gradient_accumulation_steps
if self.distributed_type == DistributedType.DEEPSPEED:
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
elif self.scaler is not None:
Expand Down
3 changes: 2 additions & 1 deletion src/accelerate/commands/launch.py
Expand Up @@ -549,7 +549,8 @@ def deepspeed_launcher(args):
current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower()
current_env["DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower()
current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower()
current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file).lower()
if args.deepspeed_config_file is not None:
current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file)

if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
with open(".deepspeed_env", "a") as f:
Expand Down