Hi,
I'm trying to train Mixtral 8x22B using FSDP. I followed the zephyr-141b-A35b recipe to create my training config below:
```yaml
model_name_or_path: mistral-community/Mixtral-8x22B-v0.1
model_revision: main
torch_dtype: bfloat16
use_flash_attention_2: true
dataset_mixer:
  HuggingFaceH4/ultrachat_200k: 1.0
dataset_splits:
- train_sft
preprocessing_num_workers: 12
bf16: true
do_eval: false
evaluation_strategy: "no"
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-5
log_level: info
logging_steps: 10
lr_scheduler_type: cosine
max_seq_length: 2048
num_train_epochs: 3
optim: adamw_bnb_8bit
output_dir: ./models/zephyr-sft-141b-A35b
overwrite_output_dir: true
per_device_train_batch_size: 1
remove_unused_columns: true
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 5
seed: 42
warmup_steps: 0.1
```
However, a saving error occurred while FSDP was saving the optimizer state. What did I do wrong in the config?
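The only config line I can think of that might interact badly with FSDP checkpointing is `optim: adamw_bnb_8bit`, which (as far as I understand) makes the Trainer build bitsandbytes' 8-bit AdamW. That optimizer keeps its state as uint8 buffers plus float32 absmax/qmap tensors rather than a single dtype, and I'm guessing FSDP's optimizer state gathering may not expect that. This is only my assumption about the cause, not something I've confirmed; a quick single-GPU sketch of what I mean (needs bitsandbytes and a CUDA device):

```python
# Sketch: compare optimizer state dtypes of plain AdamW vs. the 8-bit AdamW
# that `optim: adamw_bnb_8bit` resolves to. The parameter has >= 4096 elements
# so bitsandbytes actually quantizes the state instead of keeping it in fp32.
import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(4096, 64, device="cuda"))

for opt in (torch.optim.AdamW([p], lr=2.0e-5), bnb.optim.AdamW8bit([p], lr=2.0e-5)):
    p.grad = torch.randn_like(p)
    opt.step()
    state_dtypes = {k: v.dtype for k, v in opt.state[p].items() if torch.is_tensor(v)}
    # AdamW: everything float32; AdamW8bit: uint8 state1/state2 plus float32 absmax/qmap
    print(type(opt).__name__, state_dtypes)
```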
Here is the log:
```
[-02:1]: Traceback (most recent call last):
[-02:1]:   File "/home/user/alignment-handbook/scripts/run_sft.py", line 233, in <module>
[-02:1]:     main()
[-02:1]:   File "/home/user/alignment-handbook/scripts/run_sft.py", line 188, in main
[-02:1]:     train_result = trainer.train(resume_from_checkpoint=checkpoint)
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
[-02:1]:     output = super().train(*args, **kwargs)
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
[-02:1]:     return inner_training_loop(
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2193, in _inner_training_loop
[-02:1]:     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2588, in _maybe_log_save_evaluate
[-02:1]:     self._save_checkpoint(model, trial, metrics=metrics)
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
[-02:1]:     self._save_optimizer_and_scheduler(output_dir)
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
[-02:1]:     save_fsdp_optimizer(
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
[-02:1]:     optim_state = FSDP.optim_state_dict(model, optimizer)
[-02:1]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
[-02:1]:     return FullyShardedDataParallel._optim_state_dict_impl(
[-01:2]:     self._save_checkpoint(model, trial, metrics=metrics)
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2660, in _save_checkpoint
[-01:2]:     self._save_optimizer_and_scheduler(output_dir)
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2757, in _save_optimizer_and_scheduler
[-01:2]:     save_fsdp_optimizer(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 157, in save_fsdp_optimizer
[-01:2]:     optim_state = FSDP.optim_state_dict(model, optimizer)
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1832, in optim_state_dict
[-01:2]:     return FullyShardedDataParallel._optim_state_dict_impl(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1255, in _optim_state_dict_impl
[-01:2]:     return _optim_state_dict(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[-01:2]:     return func(*args, **kwargs)
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1972, in _optim_state_dict
[-01:2]:     fsdp_osd_state = convert_fn(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1795, in _convert_state_with_orig_params
[-01:2]:     _gather_all_orig_param_state(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1689, in _gather_all_orig_param_state
[-01:2]:     output_states = _allgather_orig_param_states(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1519, in _allgather_orig_param_states
[-01:2]:     dtype, state_buffers = _convert_all_state_info(
[-01:2]:   File "/home/user/.local/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1378, in _convert_all_state_info
[-01:2]:     assert dtype == info.dtype
```
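For reference, the failing call seems to be `FSDP.optim_state_dict(model, optimizer)` inside accelerate's `save_fsdp_optimizer`. Below is an untested sketch of how I would try to reproduce just that call path outside the Trainer; the tiny model, the sizes, and the two-process launch are placeholders rather than my actual setup, and `use_orig_params=True` is chosen only because the `_convert_state_with_orig_params` frame in the traceback suggests that is what my run used:

```python
# Untested sketch: reproduce the failing FSDP.optim_state_dict call outside the
# Trainer. Launch with e.g. `torchrun --nproc_per_node=2 repro.py` (needs 2 GPUs
# and bitsandbytes installed). The model and sizes here are placeholders.
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import bitsandbytes as bnb

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.Linear(4096, 4096),
).cuda()
model = FSDP(model, use_orig_params=True)
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2.0e-5)

# One forward/backward/step so the optimizer state actually exists
model(torch.randn(8, 4096, device="cuda")).sum().backward()
optimizer.step()

# This is the call that save_fsdp_optimizer makes and where the assert fires in my run
optim_state = FSDP.optim_state_dict(model, optimizer)
if dist.get_rank() == 0:
    print("gathered optimizer state keys:", list(optim_state.keys()))

dist.destroy_process_group()
```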