raise error for duplicate accelerate config values when using deepspeed_config_file #941

Merged
merged 30 commits into huggingface:main from smangrul/ds-config-assertions on Dec 31, 2022

Conversation

pacman100
Collaborator

@pacman100 pacman100 commented Dec 23, 2022

What does this PR do?

  1. Fixes: [feature] assert on ambiguity config when using deepspeed #936

Example:

  1. accelerate config manually tweaked to have both deepspeed_config_file and other ds config entries that are available in the json config file:
command_file: null
commands: null
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: 'cpu'
  offload_param_device: 'cpu'
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
  deepspeed_config_file: 'ds_config.json'
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
gpu_ids: null
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_name: null
tpu_zone: null
use_cpu: false
  2. ds_config.json:
{
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": false,
        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        }
    },
    "gradient_clipping": 1.0,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": 10,
    "steps_per_print": 2000000
}
  3. Code:
from accelerate import Accelerator

def main():
    accelerator = Accelerator()

if __name__ == "__main__":
    main()
  4. output:
ValueError: When using `deepspeed_config_file`, the following accelerate config variables will be 
ignored: ['gradient_accumulation_steps', 'gradient_clipping', 'zero_stage', 
'offload_optimizer_device', 'offload_param_device', 'zero3_save_16bit_model', 'mixed_precision'].
Please specify them appropriately in the DeepSpeed config file.
If you are using accelerate config file, set `mixed_precision=no` and remove others config variables
mentioned in the above specified list; else don't specify these config variables in `accelerate 
launch` command. 
The easiest method is to create new config following the questionnaire via  `accelerate config`.
It will only ask for the necessary config variables when using `deepspeed_config_file`.
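(For readers skimming: a minimal sketch, in plain Python, of the kind of check this PR introduces. The names `check_duplicate_config` and `accelerate_values` are illustrative only; the actual logic lives in `DeepSpeedPlugin` in src/accelerate/utils/dataclasses.py.)

def check_duplicate_config(accelerate_values):
    """Raise if accelerate-side values duplicate settings owned by the DeepSpeed config file."""
    # `deepspeed_config_file` is meant to be the single source of truth, so any
    # DeepSpeed-related value also set on the accelerate side is treated as ambiguous.
    duplicates = [name for name, value in accelerate_values.items() if value is not None]
    if duplicates:
        raise ValueError(
            "When using `deepspeed_config_file`, the following accelerate config variables "
            f"will be ignored: {duplicates}. Please specify them in the DeepSpeed config file."
        )

# These entries came from the tweaked accelerate config file above, so this raises.
check_duplicate_config({"gradient_clipping": 1.0, "zero_stage": 3, "offload_optimizer_device": "cpu"})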

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@pacman100 pacman100 marked this pull request as ready for review December 23, 2022 06:01
Collaborator

@sgugger sgugger left a comment


LGTM but let's also wait for @stas00 to check this.
Thanks for working on it!

Contributor

@stas00 stas00 left a comment


Thank you very much for working on this, Sourab!

I think perhaps the Issue I posted didn't succeed in communicating what the main problem is.

The mismatch is one issue, and your PR takes care of it fantastically, but really what we want is one definitive source of information.

The problem is that unless the user has spent enough time with this code base they will try to maintain 2 sets of configs, which makes work very painful. This is not hypothetical; it comes from my own experience and that of others on my team who don't realize that they needn't maintain 2 definitions. I have just discovered that myself.

What I propose is that if ds_config is used any config duplicates are outright refused. As in:

The ds_config is already defining the value for gradient_accumulation, please remove the ambiguous gradient accumulation setting from the accelerate config file.

Or alternatively require a single source either way, as we don't want to restrict users' freedom in how they do things.

In other words, every time a duplicate is detected Accelerate will assert and say:

Detected gradient_accumulation defined in both ds_config and accelerate config file, as this creates ambiguity please remove one of the settings.

Does my thinking make sense?

@stas00
Contributor

stas00 commented Dec 24, 2022

I tried this branch, getting:

Traceback (most recent call last):
  File "m4/training/main.py", line 47, in <module>
    accelerator = Accelerator(
  File "/mnt/nvme0/code/huggingface/accelerate-master/src/accelerate/accelerator.py", line 246, in __init__
    DeepSpeedPlugin() if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" else None
  File "<string>", line 12, in __init__
  File "/mnt/nvme0/code/huggingface/accelerate-master/src/accelerate/utils/dataclasses.py", line 412, in __post_init__
    self._deepspeed_config_checks()
  File "/mnt/nvme0/code/huggingface/accelerate-master/src/accelerate/utils/dataclasses.py", line 560, in _deepspeed_config_checks
    if ds_gradient_accumulation_steps != int(accelerate_gradient_accumulation_steps):
ValueError: invalid literal for int() with base 10: 'None'

configs:

{
    "fp16": {
        "enabled": true,
        "auto_cast": true,
        "loss_scale": 0.0,
        "initial_scale_power": 10,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": false,
        "reduce_scatter": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true,
        "stage3_gather_16bit_weights_on_model_save": false,
        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        }
    },
    "gradient_clipping": 1.0,
    "gradient_accumulation_steps": 2,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 2000000
}

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_config_file: ./configs/vopt-large-z3/ds_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 1
#num_machines: 20
#num_processes: 80
use_cpu: false


@muellerzr
Collaborator

@pacman100 default needs to be 1 instead of None :)
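(For context, a minimal sketch of the failure and of the suggested fix; the env var and variable names here are illustrative, not necessarily the exact ones in dataclasses.py. The launcher-side value comes back as the string 'None' when nothing was set, so `int()` blows up; defaulting to 1, or skipping the comparison when the value is unset, avoids it.)

import os

# Default to "1" (as suggested) rather than letting the unset value become the string "None",
# and skip the duplicate-value comparison entirely when nothing was explicitly set.
raw = os.environ.get("GRADIENT_ACCUMULATION_STEPS", "1")  # hypothetical env var name
accelerate_gradient_accumulation_steps = None if raw == "None" else int(raw)
print(accelerate_gradient_accumulation_steps)  # 1 when unset, instead of a crash on int("None")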

@pacman100
Collaborator Author

pacman100 commented Dec 24, 2022

Also clarifying defaults for args in this PR. Now, `accelerate launch --use_deepspeed --help` shows the output below, wherein the default value info is specified as "If unspecified, will default to *":

...

DeepSpeed Arguments:
  Arguments related to DeepSpeed.

  --deepspeed_config_file DEEPSPEED_CONFIG_FILE
                        DeepSpeed config file.
  --zero_stage ZERO_STAGE
                        DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag
                        is passed). If unspecified, will default to `2`.
  --offload_optimizer_device OFFLOAD_OPTIMIZER_DEVICE
                        Decides where (none|cpu|nvme) to offload optimizer states (useful only
                        when `use_deepspeed` flag is passed). If unspecified, will default to
                        `none`.
  --offload_param_device OFFLOAD_PARAM_DEVICE
                        Decides where (none|cpu|nvme) to offload parameters (useful only when
                        `use_deepspeed` flag is passed). If unspecified, will default to `none`.
  --gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS
                        No of gradient_accumulation_steps used in your training script (useful
                        only when `use_deepspeed` flag is passed). If unspecified, will default to
                        `1`.
  --gradient_clipping GRADIENT_CLIPPING
                        gradient clipping value used in your training script (useful only when
                        `use_deepspeed` flag is passed). If unspecified, will default to `1.0`.
  --zero3_init_flag ZERO3_INIT_FLAG
                        Decides Whether (true|false) to enable `deepspeed.zero.Init` for
                        constructing massive models. Only applicable with DeepSpeed ZeRO Stage-3.
                        If unspecified, will default to `true`.
  --zero3_save_16bit_model ZERO3_SAVE_16BIT_MODEL
                        Decides Whether (true|false) to save 16-bit model weights when using ZeRO
                        Stage-3. Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will
                        default to `false`.
  --deepspeed_hostfile DEEPSPEED_HOSTFILE
                        DeepSpeed hostfile for configuring multi-node compute resources.
  --deepspeed_exclusion_filter DEEPSPEED_EXCLUSION_FILTER
                        DeepSpeed exclusion filter string when using multi-node setup.
  --deepspeed_inclusion_filter DEEPSPEED_INCLUSION_FILTER
                        DeepSpeed inclusion filter string when using multi-node setup.
  --deepspeed_multinode_launcher DEEPSPEED_MULTINODE_LAUNCHER
                        DeepSpeed multi-node launcher to use. If unspecified, will default to
                        `pdsh`.

@pacman100 pacman100 changed the title ds config vs accelerate config checks raise error when for duplicate accelerate config values when using deepspeed_config_file Dec 24, 2022
@pacman100 pacman100 changed the title raise error when for duplicate accelerate config values when using deepspeed_config_file raise error for duplicate accelerate config values when using deepspeed_config_file Dec 24, 2022
@pacman100
Collaborator Author

ValueError: When using `deepspeed_config_file`, the following accelerate config variables will be 
ignored: ['gradient_accumulation_steps', 'gradient_clipping', 'zero_stage', 
'offload_optimizer_device', 'offload_param_device', 'zero3_save_16bit_model', 'mixed_precision'].
Please specify them appropriately in the DeepSpeed config file.
If you are using accelerate config file, set `mixed_precision=no` and remove others config variables
mentioned in the above specified list; else don't specify these config variables in `accelerate 
launch` command. 
The easiest method is to create new config following the questionnaire via  `accelerate config`.
It will only ask for the necessary config variables when using `deepspeed_config_file`.

Now, the error will be this

@pacman100
Collaborator Author

but really what we want is one definitive source of information.

Hello @stas00, let us know if this addresses the issue. When `deepspeed_config_file` is specified, it is the single definitive source of information, and an error is raised when duplicates are found either in the accelerate config file or through the arguments of the `accelerate launch` command.

Collaborator

@sgugger sgugger left a comment


Thanks for working on this! Left a couple of nits.

Review comments (outdated, resolved) on: src/accelerate/commands/launch.py, src/accelerate/utils/dataclasses.py (2 comments)
pacman100 and others added 2 commits December 26, 2022 12:22
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Contributor

@stas00 stas00 left a comment


Thank you for further improvements, Sourab.

Will this still work with auto entries in ds_config.json, which were designed to be filled either by the system, when they need to be calculated at run time, or via cmd line args? I've yet to try it out, but I have a feeling that the latter will no longer work, would it?

Review comment (outdated, resolved) on: src/accelerate/utils/dataclasses.py
@stas00
Contributor

stas00 commented Dec 26, 2022

But this shows otherwise:

  File "/mnt/nvme0/code/huggingface/accelerate-master/src/accelerate/utils/dataclasses.py", line 409, in __post_init__
    raise ValueError("gradient_accumulation_steps cannot be set to 'auto' in the DeepSpeed config.")
ValueError: gradient_accumulation_steps cannot be set to 'auto' in the DeepSpeed config.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid

Why can't it be auto? I'd expect it to be set either (1) to the default value or (2) to the value passed via the --gradient_accumulation_steps cmd line arg.

I wonder if we are having a miscommunication here. I brought up the issue of duplication between the 2 styles of setting the ds config, since accelerate has used its own config files from the beginning, but I have never suggested that support for setting values via cmd line args should be dropped.

@pacman100
Collaborator Author

Regarding gradient_accumulation_steps and auto, that piece of code wasn't changed at all and the behaviour is the same as before. If one doesn't specify that entry in the config file, it is set to the default of 1; if it is auto, an error is raised; else the given value is used. When not using deepspeed_config_file, it is asked for when running the accelerate config cmd.
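(A simplified sketch of the behaviour described above, for illustration only; it is not the exact accelerate code: a missing entry falls back to 1, an explicit `auto` is rejected, and any other value is used as-is.)

def resolve_gradient_accumulation_steps(ds_config):
    value = ds_config.get("gradient_accumulation_steps", 1)  # default when the entry is absent
    if value == "auto":
        raise ValueError("gradient_accumulation_steps cannot be set to 'auto' in the DeepSpeed config.")
    return int(value)

assert resolve_gradient_accumulation_steps({}) == 1
assert resolve_gradient_accumulation_steps({"gradient_accumulation_steps": 10}) == 10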

@stas00
Contributor

stas00 commented Dec 27, 2022

Regarding gradient_accumulation_steps and auto, that piece of code wasn't changed at all and the behaviour is the same as before. If one doesn't specify that entry in the config file, it is set to the default of 1; if it is auto, an error is raised; else the given value is used. When not using deepspeed_config_file, it is asked for when running the accelerate config cmd.

OK, so your logic is different from HF Trainer then.

The HF Trainer was directing cmd args into the config file's auto values so that the user could override them via cmd line args.

I'm not saying the 2 logics have to match. If I am not mistaken the accelerate logic is less flexible, but it's ok if you prefer it that way.

In HF Trainer the auto feature was designed to be used:

  1. when the value can't be known before running - "boot"-time calculated configs
  2. values to be set via cmd line args and defaults

@pacman100
Collaborator Author

pacman100 commented Dec 27, 2022

As accelerate is meant to work with all models, not just Transformers, and the user is in control of the training loop, they are in charge of all the arguments, and the naming convention of arguments will differ across users. On the other hand, in Trainer, users are restricted to a given args set and as such those can be used to fill the DeepSpeed config, due to the clear mapping between args and DS config params. The idea is that the artifacts sent to accelerator.prepare have the params required by the DS config, and we know exactly the mapping between them, which makes filling in the params independent of the user's training loop and their argument naming convention.

In accelerate, the auto values are those that can be filled via the artifacts being sent to accelerator.prepare; everywhere else, the user has complete control over the training loop, argument naming and filling. The user still has the flexibility to fill in all the auto values themselves, as mentioned here #676 (comment)
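(As a rough illustration of that last point, assuming the `deepspeed_config` dict exposed by the DeepSpeed plugin on the accelerator state, which is the pattern the linked #676 comment describes: a user can resolve remaining `auto` entries from values their own script knows about before calling prepare.)

from accelerate import Accelerator

accelerator = Accelerator()
plugin = accelerator.state.deepspeed_plugin
if plugin is not None:
    ds_config = plugin.deepspeed_config  # plain dict holding the DeepSpeed config
    # Fill an `auto` entry from a value the user's own training script knows about.
    if ds_config.get("train_micro_batch_size_per_gpu") == "auto":
        ds_config["train_micro_batch_size_per_gpu"] = 8  # e.g. the script's per-device batch size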

@stas00
Contributor

stas00 commented Dec 27, 2022

As accelerate is meant to work with all models, not just Transformers, and the user is in control of the training loop, they are in charge of all the arguments, and the naming convention of arguments will differ across users. On the other hand, in Trainer, users are restricted to a given args set and as such those can be used to fill the DeepSpeed config, due to the clear mapping between args and DS config params. The idea is that the artifacts sent to accelerator.prepare have the params required by the DS config, and we know exactly the mapping between them, which makes filling in the params independent of the user's training loop and their argument naming convention.

Thank you for explaining this to me, Sourab, but I'm having a hard time following how Accelerate is any different from HF Trainer wrt sending cmd line arg values to the unfilled config values in ds_config, e.g. the Accelerate launcher provides an explicit list of cmd line args for DeepSpeed use. There is a 1:1 mapping here as well. Could you please explain how this is different from the HF Trainer?

But as I said above it's totally fine if you prefer to do it this way, Sourab. This doesn't prevent users from doing what they need.

In accelerate, the auto values are those that can be filled via the artifacts being sent to accelerator.prepare; everywhere else, the user has complete control over the training loop, argument naming and filling. The user still has the flexibility to fill in all the auto values themselves, as mentioned here #676 (comment)

Understood. More work, but doable. Thank you for the explanations.

So far, `accelerate launch` cmd args were used for filling deepspeed plugin fields and not for setting `auto` values. This PR enables that too.

It also raises assertions when ambiguous values are passed in accelerate config file when using `deepspeed_config_file`
@pacman100
Collaborator Author

pacman100 commented Dec 27, 2022

Could you please explain how is this different from the HF Trainer?

Users can have `bs` or `batch_size` as cmd arguments in their code, so we can't fill the ds config's `train_micro_batch_size_per_gpu`, whereas Trainer always maps `args.per_device_train_batch_size` to it. The same reasoning applies to other configs.

Please note that `accelerate launch` cmd args were primarily used for setting the accelerate config's deepspeed fields rather than setting `auto` values of `deepspeed_config_file`. Now I understand that you meant using `accelerate launch` cmd args for filling in the `auto` values of `deepspeed_config_file`, and I've made the respective changes.
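(Conceptually, the change amounts to something like the sketch below; the names and the mapping subset are illustrative, not the code added in launch.py/dataclasses.py: values passed on the `accelerate launch` command line fill the matching `auto` entries in the DeepSpeed config file, and entries that are not `auto` are left untouched.)

def fill_auto_entries(ds_config, launch_args):
    """Fill `auto` entries in a DeepSpeed config dict from launch-arg values."""
    # Dotted paths on the DeepSpeed side, launch-arg names on the accelerate side
    # (an illustrative subset of the full mapping).
    mapping = {
        "gradient_accumulation_steps": "gradient_accumulation_steps",
        "gradient_clipping": "gradient_clipping",
        "zero_optimization.stage": "zero_stage",
        "zero_optimization.offload_optimizer.device": "offload_optimizer_device",
        "zero_optimization.offload_param.device": "offload_param_device",
    }
    for ds_path, launch_arg in mapping.items():
        *parents, leaf = ds_path.split(".")
        node = ds_config
        ok = True
        for key in parents:
            node = node.get(key)
            if not isinstance(node, dict):
                ok = False
                break
        if ok and node.get(leaf) == "auto" and launch_arg in launch_args:
            node[leaf] = launch_args[launch_arg]
    return ds_config

cfg = {"gradient_accumulation_steps": "auto", "zero_optimization": {"stage": "auto"}}
print(fill_auto_entries(cfg, {"gradient_accumulation_steps": 5, "zero_stage": 3}))
# {'gradient_accumulation_steps': 5, 'zero_optimization': {'stage': 3}}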

Latest changes:

Code test.py:

from accelerate import Accelerator
from accelerate.state import AcceleratorState

def main():
    accelerator = Accelerator()
    accelerator.print(f"{AcceleratorState()}")
if __name__ == "__main__":
    main()

Scenario 1: a manually tweaked accelerate config file having `deepspeed_config_file` along with other entries.

  1. accelerate config:
command_file: null
commands: null
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: 'cpu'
  offload_param_device: 'cpu'
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
  deepspeed_config_file: 'ds_config.json'
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
gpu_ids: null
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_name: null
tpu_zone: null
use_cpu: false
  2. ds_config.json:
{
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": false,
        "offload_optimizer": {
            "device": "none"
        },
        "offload_param": {
            "device": "none"
        }
    },
    "gradient_clipping": 1.0,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": 10,
    "steps_per_print": 2000000
}
  3. Output of accelerate launch test.py:
ValueError: When using `deepspeed_config_file`, the following accelerate config variables will be ignored: 
['gradient_accumulation_steps', 'gradient_clipping', 'zero_stage', 'offload_optimizer_device', 'offload_param_device', 
'zero3_save_16bit_model', 'mixed_precision'].
Please specify them appropriately in the DeepSpeed config file.
If you are using an accelerate config file, remove others config variables mentioned in the above specified list.
The easiest method is to create a new config following the questionnaire via `accelerate config`.
It will only ask for the necessary config variables when using `deepspeed_config_file`.

Scenario 2: Use the solution from the error to create a new accelerate config and check that no ambiguity error is thrown.

  1. Run accelerate config:
$ accelerate config
--------------------------------------------------------------------------------
In which compute environment are you running?
This machine
--------------------------------------------------------------------------------
Which type of machine are you using?
multi-GPU
How many different machines will you use (use more than 1 for multi-node training)? [1]:
Do you wish to optimize your script with torch dynamo?[yes/NO]:
Do you want to use DeepSpeed? [yes/NO]: yes
Do you want to specify a json file to a DeepSpeed config? [yes/NO]: yes
Please enter the path to the json DeepSpeed config file: ds_config.json
Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: yes
How many GPU(s) should be used for distributed training? [1]:4
accelerate configuration saved at ds_config_sample.yaml
  2. accelerate config:
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
use_cpu: false
  3. Output of accelerate launch test.py:
Distributed environment: DEEPSPEED  Backend: nccl
Num processes: 4
Process index: 0
Local process index: 0
Device: cuda:0
Mixed precision type: bf16
ds_config: {'bf16': {'enabled': True}, 'zero_optimization': {'stage': 3, 'stage3_gather_16bit_weights_on_model_save': False, 'offload_optimizer': {'device': 'none'}, 'offload_param': {'device': 'none'}}, 'gradient_clipping': 1.0, 'train_batch_size': 'auto', 'train_micro_batch_size_per_gpu': 'auto', 'gradient_accumulation_steps': 10, 'steps_per_print': inf, 'fp16': {'enabled': False}}

Scenario 3: Set the `accelerate launch` cmd args related to deepspeed to `auto` in the `deepspeed_config_file` and check that things work as expected.

  1. new ds_config.json with auto for the accelerate launch deepspeed cmd args:
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": "auto",
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "offload_optimizer": {
            "device": "auto"
        },
        "offload_param": {
            "device": "auto"
        }
    },
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "steps_per_print": 2000000
}
  2. Output of accelerate launch --mixed_precision="fp16" --zero_stage=3 --gradient_accumulation_steps=5 --gradient_clipping=1.0 --offload_param_device="cpu" --offload_optimizer_device="nvme" --zero3_save_16bit_model="true" test.py:
Distributed environment: DEEPSPEED  Backend: nccl
Num processes: 4
Process index: 0
Local process index: 0
Device: cuda:0
Mixed precision type: fp16
ds_config: {'bf16': {'enabled': False}, 'zero_optimization': {'stage': 3, 'stage3_gather_16bit_weights_on_model_save': True, 'offload_optimizer': {'device': 'nvme'}, 'offload_param': {'device': 'cpu'}}, 'gradient_clipping': 1.0, 'train_batch_size': 'auto', 'train_micro_batch_size_per_gpu': 'auto', 'gradient_accumulation_steps': 5, 'steps_per_print': inf, 'fp16': {'enabled': True, 'auto_cast': True}}

Note: The remaining `auto` values are handled in the accelerator.prepare() call.
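(As a rough sketch of what "handled in the accelerator.prepare() call" means for the batch-size entries; this is conceptual only and not the accelerate internals: the per-GPU micro batch size comes from the dataloader passed to prepare, and the total train batch size follows from micro_batch * gradient_accumulation_steps * num_processes.)

def resolve_batch_entries(ds_config, dataloader_batch_size, num_processes):
    """Conceptual resolution of the batch-size `auto` entries from prepared artifacts."""
    if ds_config.get("train_micro_batch_size_per_gpu") == "auto":
        ds_config["train_micro_batch_size_per_gpu"] = dataloader_batch_size
    if ds_config.get("train_batch_size") == "auto":
        ds_config["train_batch_size"] = (
            ds_config["train_micro_batch_size_per_gpu"]
            * int(ds_config.get("gradient_accumulation_steps", 1))
            * num_processes
        )
    return ds_config

print(resolve_batch_entries(
    {"train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "gradient_accumulation_steps": 5},
    dataloader_batch_size=8,
    num_processes=4,
))  # {'train_batch_size': 160, 'train_micro_batch_size_per_gpu': 8, 'gradient_accumulation_steps': 5}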

@stas00
Contributor

stas00 commented Dec 27, 2022

Looks fantastic, Sourab! Thank you for the improvements and for taking the time to lay out the different scenarios. If I'm not mistaken those would make for perfect additions to the documentation, if that resonates (at the very least the last one, to demo how auto values work and why one would want to use them).

BTW, the config generates things like:

fsdp_config: {}
megatron_lm_config: {}

Why not just skip parts that the user hasn't asked for? It just makes the config scarier than it is, no? I'm asking since when I first looked at it I wasn't at all sure which of the empty placeholders were safe to remove and which weren't. My personal preference is for an active config, that is, to only ever list config entries that I work with; any defaults should be just that, defaults, and not be listed at all. Which I suppose isn't the case with typical configs, where everything is listed out whether it's being used or not.

And I can of course remove all those, so definitely it's not an issue, I'm just asking if my thinking resonates with you.

@stas00
Contributor

stas00 commented Dec 30, 2022

Sourab, I found one more ambiguous combo in one of our tests:

  zero3_init_flag: true
  zero_stage: 2

This combo is quietly getting accepted. I'm concerned that a developer may see zero3_init_flag: true and think it's zero3?

Do you think accelerate should assert when zero3_init_flag==True is used with stage < 3?

@pacman100
Collaborator Author

pacman100 commented Dec 30, 2022

Hello @stas00, with the current setup the below warning is given, which I think is fine:

UserWarning: DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.
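(For reference, a minimal sketch of the behaviour behind that warning; illustrative, not the exact check in DeepSpeedPlugin: zero3_init_flag only makes sense for ZeRO stage 3, so it is downgraded to False with a warning for lower stages.)

import warnings

def validate_zero3_init(zero_stage, zero3_init_flag):
    if zero3_init_flag and zero_stage != 3:
        warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
        zero3_init_flag = False
    return zero3_init_flag

print(validate_zero3_init(zero_stage=2, zero3_init_flag=True))  # warns, then prints False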

@stas00
Contributor

stas00 commented Dec 30, 2022

oh boy. I didn't see it. :(

If a tree falls in a forest and no one is around to hear it, does it make a sound?

I guess I need to start using this pragma to turn warnings into errors, but then some warnings can't be acted upon :(

import warnings
warnings.filterwarnings("error")

This is, for example, an even larger issue for tests, where the distributed setup hides most warnings, or again there are too many warnings to see anything.

Collaborator

@sgugger sgugger left a comment


Left some nits on the doc. Why did we switch the default for zero3_init_flag to True?

Review comments (outdated, resolved) on: docs/source/usage_guides/deepspeed.mdx (7 comments), src/accelerate/utils/dataclasses.py
pacman100 and others added 4 commits December 30, 2022 13:13
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@pacman100 pacman100 merged commit e60f3ca into huggingface:main Dec 31, 2022
@pacman100 pacman100 deleted the smangrul/ds-config-assertions branch January 19, 2023 11:37