Add a utility for writing a barebones config file (#371)
* Create a basic_config function
muellerzr committed May 18, 2022
1 parent 64e41a4 commit 043d2ec
Showing 3 changed files with 55 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/internal.mdx
@@ -67,3 +67,5 @@ The main work on your PyTorch `DataLoader` is done by the following function:
[[autodoc]] utils.synchronize_rng_states

[[autodoc]] utils.wait_for_everyone

[[autodoc]] utils.write_basic_config
9 changes: 8 additions & 1 deletion src/accelerate/utils/__init__.py
@@ -72,5 +72,12 @@

from .launch import PrepareForLaunch
from .memory import find_executable_batch_size
from .other import extract_model_from_parallel, get_pretty_name, patch_environment, save, wait_for_everyone
from .other import (
    extract_model_from_parallel,
    get_pretty_name,
    patch_environment,
    save,
    wait_for_everyone,
    write_basic_config,
)
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
45 changes: 45 additions & 0 deletions src/accelerate/utils/other.py
@@ -14,9 +14,12 @@

import os
from contextlib import contextmanager
from pathlib import Path

import torch

from ..commands.config.cluster import ClusterConfig
from ..commands.config.config_args import default_json_config_file
from ..state import AcceleratorState
from .dataclasses import DistributedType
from .imports import is_deepspeed_available, is_tpu_available
@@ -109,3 +112,45 @@ def get_pretty_name(obj):
    if hasattr(obj, "__name__"):
        return obj.__name__
    return str(obj)


def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
    set CPU as the device if the machine is CPU-only.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed precision to use. Should be one of "no", "fp16", or "bf16".
        save_location (`str`, *optional*, defaults to `default_json_config_file`):
            Optional custom save location; should be passed to `--config_file` when using `accelerate launch`. The
            default location is `accelerate/default_config.yaml` inside the Hugging Face cache folder
            (`~/.cache/huggingface`), which can be overridden by setting the `HF_HOME` environment variable.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16"]:
        raise ValueError(f"`mixed_precision` should be one of 'no', 'fp16', or 'bf16'. Received {mixed_precision}")
    config = {"compute_environment": "LOCAL_MACHINE", "mixed_precision": mixed_precision}
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        config["num_processes"] = num_gpus
        config["use_cpu"] = False
        if num_gpus > 1:
            config["distributed_type"] = "MULTI_GPU"
        else:
            config["distributed_type"] = "NO"
    else:
        num_gpus = 0
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    if not path.exists():
        config = ClusterConfig(**config)
        config.to_json_file(path)
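
For context, here is a minimal usage sketch of the new helper (the file and script names are illustrative, not part of this commit): write the config once, then point `accelerate launch` at it via `--config_file`, as the docstring above describes.

```python
from accelerate.utils import write_basic_config

# Write a bare-bones single-machine config; with no `save_location` it goes to
# the default Hugging Face cache path. "fp16" here is just an example value.
write_basic_config(mixed_precision="fp16", save_location="my_config.json")

# On a machine with 2 CUDA GPUs, the dict built by the function would be roughly:
# {"compute_environment": "LOCAL_MACHINE", "mixed_precision": "fp16",
#  "num_processes": 2, "use_cpu": False, "distributed_type": "MULTI_GPU"}
#
# The saved file can then be passed to the launcher:
#   accelerate launch --config_file my_config.json train.py
```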
