Different dtype while saving optimizer with FSDP #153

@heraclex12

Hi,

I'm trying to train Mixtral 8x22B using FSDP. I followed the zephyr-141b-A35b recipe to create my training config below:

model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true

dataset_mixer:
  HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
preprocessing_num_workers: 12

bf16: true
do_eval: false
evaluation_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-5
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_seq_length: 2048
num_train_epochs: 3
optim: adamw_bnb_8bit
output_dir: ./models/zephyr-sft-141b-A35b
overwrite_output_dir: true
per_device_train_batch_size: 1
remove_unused_columns: true
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 5
seed: 42
warmup_steps: 0.1
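
For completeness, I launch training through accelerate's FSDP integration, roughly like this (the accelerate config path, process count, and recipe filename are placeholders rather than my exact values):

ACCELERATE_LOG_LEVEL=info accelerate launch \
  --config_file fsdp_config.yaml \
  --num_processes 8 \
  scripts/run_sft.py config_sft_mixtral.yaml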

However, an error occurs whenever FSDP saves the optimizer state at a checkpoint. What did I do wrong in the config?

Here is the log:

[-02:1]:Traceback (most recent call last):
[-02:1]:  File "/home/user/alignment-handbook/scripts/run_sft.py", line 233, in <module>
[-02:1]:    main()
[-02:1]:  File "/home/user/alignment-handbook/scripts/run_sft.py", line 188, in main
[-02:1]:    train_result = trainer.train(resume_from_checkpoint=checkpoint)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
[-02:1]:    output = super().train(*args, **kwargs)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
[-02:1]:    return inner_training_loop(
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2193, in _inner_training_loop
[-02:1]:    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2588, in _maybe_log_save_evaluate
[-02:1]:    self._save_checkpoint(model, trial, metrics=metrics)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
[-02:1]:    self._save_optimizer_and_scheduler(output_dir)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
[-02:1]:    save_fsdp_optimizer(
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
[-02:1]:    optim_state = FSDP.optim_state_dict(model, optimizer)
[-02:1]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
[-02:1]:    return FullyShardedDataParallel._optim_state_dict_impl(
[-01:2]:    self._save_checkpoint(model, trial, metrics=metrics)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
[-01:2]:    self._save_optimizer_and_scheduler(output_dir)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
[-01:2]:    save_fsdp_optimizer(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
[-01:2]:    optim_state = FSDP.optim_state_dict(model, optimizer)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
[-01:2]:    return FullyShardedDataParallel._optim_state_dict_impl(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
[-01:2]:    return _optim_state_dict(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[-01:2]:    return func(*args, **kwargs)
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
[-01:2]:    fsdp_osd_state = convert_fn(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
[-01:2]:    _gather_all_orig_param_state(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
[-01:2]:    output_states = _allgather_orig_param_states(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
[-01:2]:    dtype, state_buffers = _convert_all_state_info(
[-01:2]:  File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1378, in _convert_all_state_info
[-01:2]:    assert dtype == info.dtype
