
[BUG] Gradient accumulation causing training loss differences in Deepspeed vs FSDP #5898

@gramesh-amd


Describe the bug
I am trying to pretrain an OLMo 1B model on 8 MI250 GPUs with the Docker image rocm/pytorch:latest (ROCm 6.1), using a small subset of the Dolma dataset.

Training loss is comparable between FSDP and DeepSpeed when gradient accumulation is small, but as gradient accumulation increases, the two training losses diverge.
[Figure: training loss, FSDP (dark blue) vs DeepSpeed (purple), gradient accumulation = 16]
^ For example, the run above uses a gradient accumulation of 16 (dark blue is FSDP, purple is DeepSpeed).
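One thing I want to rule out is a gradient scaling mismatch that would grow with the accumulation count. My understanding is that `DeepSpeedEngine.backward()` divides the loss by the configured gradient accumulation steps by default (its `scale_wrt_gas` argument). If so, a training loop that also divides the loss by the accumulation count before calling it would scale the gradients by 1/GAS twice. The toy sketch below is pure Python and uses a hypothetical one-parameter model, not OLMo. It shows that correctly scaled accumulation reproduces the full-batch gradient and that an extra division shrinks it by a factor of GAS.

```python
"""Sketch: gradient accumulation should reproduce the full-batch gradient.

Toy model (hypothetical, for illustration only): prediction = w * x,
per-sample loss = (w * x - y) ** 2, batch loss = mean over samples.
"""


def mean_grad(w, xs, ys):
    # d/dw of mean((w*x - y)^2) = mean(2 * x * (w*x - y))
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, num_micro, extra_scale=1.0):
    """Split the batch into num_micro equal micro-batches, scale each
    micro-batch loss by 1/num_micro (as gradient accumulation should),
    and sum the resulting gradients. extra_scale models an accidental
    second division of the loss."""
    size = len(xs) // num_micro
    total = 0.0
    for i in range(num_micro):
        mx = xs[i * size:(i + 1) * size]
        my = ys[i * size:(i + 1) * size]
        total += mean_grad(w, mx, my) / num_micro * extra_scale
    return total


xs = [float(i) for i in range(1, 33)]  # 32 toy samples
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
gas = 16  # same accumulation count as the run above

full = mean_grad(w, xs, ys)
acc = accumulated_grad(w, xs, ys, gas)
double = accumulated_grad(w, xs, ys, gas, extra_scale=1.0 / gas)

print(f"full batch:   {full:.6f}")
print(f"accumulated:  {acc:.6f}")     # equals the full-batch gradient
print(f"scaled twice: {double:.6f}")  # full-batch gradient / gas
```

With Adam, a uniform rescaling of the gradient largely cancels out except through the `eps` term. It does change how often `gradient_clipping: 1.0` activates, though, so double scaling would not necessarily show up as an obvious failure.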

I'm running all my training runs in amp_fp16 mixed precision. The reduce_dtype in FSDP.MixedPrecision is set to fp32, and I also set "data_types": { "grad_accum_dtype": "fp32" } in ds_config.
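Pinning the accumulation and reduction dtype to fp32 matters more as the number of accumulation steps grows. The sketch below illustrates this using only the standard library. It emulates bf16, which the ds_config in this issue enables, by rounding float32 bit patterns. That emulation is an approximation, since it rounds twice and skips NaN/Inf handling. The helper names and the constant per-micro-batch gradient of 1e-3 are illustrative assumptions and do not come from DeepSpeed or FSDP.

```python
import struct


def to_fp32(x: float) -> float:
    # Round a Python float (fp64) to the nearest IEEE float32 value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    # Approximate bfloat16: keep the top 16 bits of the float32 encoding,
    # rounding to nearest-even (NaN/Inf handling omitted for brevity).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(grads, cast):
    # Sum per-micro-batch gradients, rounding the running total to the
    # accumulation dtype after every step, as a low-precision buffer would.
    total = 0.0
    for g in grads:
        total = cast(total + cast(g))
    return total


grads = [1e-3] * 128  # 128 micro-batches, as in the FP32 run in this issue
exact = sum(grads)    # fp64 reference

err_fp32 = abs(accumulate(grads, to_fp32) - exact)
err_bf16 = abs(accumulate(grads, to_bf16) - exact)

print(f"fp32 accumulation error: {err_fp32:.2e}")
print(f"bf16 accumulation error: {err_bf16:.2e}")
```

The bf16 error comes both from rounding each gradient and from rounding the growing running sum, so it increases with the number of accumulation steps.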

Here is the relevant ds_config I'm using:

```python
# Assumption: d_model of OLMo 1B; substitute your model's hidden size.
model_hidden_size = 2048

prescale = False  # I've tried both True and False
overlap = False   # I've tried both True and False

ds_config = {
    "train_batch_size": 1024,
    # 8 per GPU x 8 GPUs x grad_acc 16 = 1024 effective batch size
    "train_micro_batch_size_per_gpu": 8,
    "prescale_gradients": prescale,
    "zero_optimization": {
        "stage": 0,
        "cpu_offload": False,
        "overlap_comm": overlap,
        "reduce_scatter": True,
        "reduce_bucket_size": model_hidden_size * model_hidden_size,
        "contiguous_gradients": True,
    },
    "gradient_clipping": 1.0,
    "data_types": {"grad_accum_dtype": "fp32"},
    "bf16": {
        "enabled": True,
    },
}
```
[Figure: training loss, DeepSpeed (blue) vs FSDP (yellow), full-precision FP32 run]
^ I also tried a full-precision FP32 run with a per-GPU batch size of 2 and a high gradient accumulation of 128, and I still see a large difference in training loss (blue is DeepSpeed, yellow is FSDP).
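DeepSpeed requires `train_batch_size` to equal `train_micro_batch_size_per_gpu` × gradient accumulation steps × world size. The two helpers below are hypothetical names, not part of DeepSpeed, and they make that arithmetic explicit for both runs. That makes it easy to confirm the FSDP side uses the same global batch per optimizer step. The FP32 figure assumes all 8 GPUs were used for that run.

```python
def grad_accum_steps(train_batch_size: int, micro_batch_per_gpu: int, world_size: int) -> int:
    """Return the gradient accumulation steps implied by
    train_batch_size = micro_batch_per_gpu * grad_accum_steps * world_size."""
    per_step = micro_batch_per_gpu * world_size
    if train_batch_size % per_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_per_gpu * world_size = {per_step}"
        )
    return train_batch_size // per_step


def global_batch(micro_batch_per_gpu: int, grad_accum: int, world_size: int) -> int:
    """Number of sequences contributing to one optimizer step."""
    return micro_batch_per_gpu * grad_accum * world_size


# bf16 run from the ds_config above: 8 GPUs, 8 sequences per GPU per micro-step.
print(grad_accum_steps(1024, 8, 8))  # 16

# FP32 run: 2 sequences per GPU, 128 accumulation steps, assuming 8 GPUs.
print(global_batch(2, 128, 8))  # 2048
```

If the FSDP run processes a different number of sequences (or tokens) per optimizer step, the two loss curves are not directly comparable, even with identical learning-rate schedules.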

Given that all other settings are the same (learning rate, LR scheduler, optimizer, etc.), what could be causing this difference?

To Reproduce

For the DeepSpeed version of OLMo, I'm using the changes in this pull request on top of the latest code. I can share more details if needed.

ds_report output

 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn is not compatible with ROCM
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn is not compatible with ROCM
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/conda/envs/olmo/lib/python3.9/site-packages/torch']
torch version .................... 2.3.0a0+gitae01701
deepspeed install path ........... ['/opt/conda/envs/olmo/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.14.5+unknown, unknown, unknown
torch cuda version ............... None
torch hip version ................ 6.1.40091-a8dbc0c19
nvcc version ..................... None
deepspeed wheel compiled w. ...... torch 2.3, hip 6.1
shared memory (/dev/shm) size .... 503.85 GB

System info (please complete the following information):

  • GPU count and types: 8× MI250
  • Interconnects (if applicable): single node
  • Python version: 3.9

Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else? I'm using torchrun.

Docker context
Are you using a specific docker image that you can share? rocm/pytorch:latest (ROCm 6.1)


Labels: bug (Something isn't working), training