
Init on meta device and then materialize on gpu leads to very large errors #36577

@fingertap

Description

System Info

  • transformers version: 4.47.1
  • Platform: Linux-5.4.0-153-generic-x86_64-with-glibc2.35
  • Python version: 3.10.16
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.5.1
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A800-SXM4-80GB

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Issue

I prepared a repo to reproduce this issue.

The key observation is that the language loss produced by a model built with the following initialization pipeline is very large compared to initializing it directly on the GPU with from_pretrained. The pipeline that produces the large-error model is:

  1. Init the model on the meta device with init_empty_weights.
  2. Materialize the model with to_empty.
  3. Load the weights with load_state_dict from a model created by from_pretrained.
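The three steps above can be sketched with plain torch (a small nn.Linear stands in for the transformers model here, and the reference model stands in for from_pretrained; these names are illustrative, not from the repro repo):

```python
import torch
import torch.nn as nn

# 1. init on the meta device: no storage is allocated
with torch.device("meta"):
    model = nn.Linear(16, 16)

# 2. materialize with to_empty: storage is allocated but left UNINITIALIZED
model.to_empty(device="cpu")  # the issue uses device="cuda"

# 3. copy weights from a reference model (stands in for from_pretrained)
ref = nn.Linear(16, 16)
model.load_state_dict(ref.state_dict())

# Tensors present in the state dict are restored exactly; anything to_empty
# allocated that the state dict does not cover is never re-initialized.
assert torch.equal(model.weight, ref.weight)
```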

Motivation

The reason I want to initialize the model this way is that I need to load large models onto GPUs using FSDP. I found that when passing a meta-device model to FSDP with param_init_fn set to to_empty, the language loss is large; if I pass a CUDA model instead, it is fine. I later found that this is not an FSDP problem: using to_empty in the single-GPU case also leads to a large loss. In the era of large language models, this issue is VERY CRITICAL for scaling up to larger models.

Expected behavior

A normal loss after loading the state dict, no matter how the model was initialized (it should work after to_empty). In my repro example, the loss should be ~0.15; at the very least it should be < 1.
