NotImplementedError: Cannot copy out of meta tensor; no data! with Multi-node training #26971

Closed
2 of 4 tasks
ari9dam opened this issue Oct 21, 2023 · 5 comments

@ari9dam

ari9dam commented Oct 21, 2023

System Info

A100
CUDA 11.7
PyTorch 2.0.1
# This dependencies file is produced by 'conda export'
{
  "channels": [
    "pytorch",
    "defaults"
  ],
  "dependencies": [
    "_libgcc_mutex=0.1=main",
    "_openmp_mutex=5.1=1_gnu",
    "ca-certificates=2023.01.10=h06a4308_0",
    "ld_impl_linux-64=2.38=h1181459_1",
    "libffi=3.4.4=h6a678d5_0",
    "libgcc-ng=11.2.0=h1234567_1",
    "libgomp=11.2.0=h1234567_1",
    "libstdcxx-ng=11.2.0=h1234567_1",
    "magma-cuda117=2.6.1=1",
    "ncurses=6.4=h6a678d5_0",
    "openssl=1.1.1t=h7f8727e_0",
    "pip=23.0.1=py38h06a4308_0",
    "python=3.8.16=h7a1cb2a_3",
    "readline=8.2=h5eee18b_0",
    "sqlite=3.41.2=h5eee18b_0",
    "tk=8.6.12=h1ccaba5_0",
    "xz=5.4.2=h5eee18b_0",
    "zlib=1.2.13=h5eee18b_0",
    {
      "pip": [
        "absl-py==2.0.0",
        "accelerate==0.24.0.dev0",
        "adal==1.2.7",
        "aiofiles==23.1.0",
        "aiohttp==3.8.4",
        "aiosignal==1.3.1",
        "altair==5.1.2",
        "antlr4-python3-runtime==4.9.3",
        "anyio==3.7.1",
        "apex==0.1",
        "applicationinsights==0.11.10",
        "argcomplete==2.1.2",
        "asttokens==2.4.0",
        "async-timeout==4.0.2",
        "attrs==23.1.0",
        "azure-common==1.1.28",
        "azure-core==1.26.4",
        "azure-graphrbac==0.61.1",
        "azure-identity==1.13.0",
        "azure-mgmt-authorization==3.0.0",
        "azure-mgmt-containerregistry==10.2.0",
        "azure-mgmt-core==1.4.0",
        "azure-mgmt-keyvault==10.3.0",
        "azure-mgmt-resource==22.0.0",
        "azure-mgmt-storage==21.0.0",
        "azure-ml==0.0.1",
        "azure-ml-component==0.9.18.post2",
        "azure-storage-blob==12.13.0",
        "azureml-automl-common-tools==1.51.0",
        "azureml-automl-core==1.51.0.post1",
        "azureml-contrib-services==1.51.0",
        "azureml-core==1.51.0",
        "azureml-dataprep==4.10.9",
        "azureml-dataprep-native==38.0.0",
        "azureml-dataprep-rslex==2.17.12",
        "azureml-dataset-runtime==1.51.0",
        "azureml-defaults==1.51.0",
        "azureml-inference-server-http==0.8.4.1",
        "azureml-mlflow==1.51.0",
        "azureml-pipeline==1.51.0",
        "azureml-pipeline-core==1.51.0",
        "azureml-pipeline-steps==1.51.0",
        "azureml-sdk==1.51.0",
        "azureml-telemetry==1.51.0",
        "azureml-train-automl-client==1.51.0.post1",
        "azureml-train-core==1.51.0",
        "azureml-train-restclients-hyperdrive==1.51.0",
        "backcall==0.2.0",
        "backports-tempfile==1.0",
        "backports-weakref==1.0.post1",
        "bcrypt==4.0.1",
        "bytecode==0.15.1",
        "cachetools==5.3.0",
        "cerberus==1.3.4",
        "certifi==2023.5.7",
        "cffi==1.15.1",
        "charset-normalizer==3.1.0",
        "click==8.1.7",
        "cloudpickle==2.2.1",
        "cmake==3.26.3",
        "coloredlogs==15.0.1",
        "comm==0.1.4",
        "contextlib2==21.6.0",
        "coverage==6.3.1",
        "cryptography==40.0.2",
        "cycler==0.12.1",
        "databricks-cli==0.18.0",
        "datasets==2.14.5",
        "debugpy==1.6.7.post1",
        "decorator==5.1.1",
        "deepspeed==0.9.1",
        "dill==0.3.7",
        "distro==1.8.0",
        "docker==6.1.3",
        "dotnetcore2==3.1.23",
        "einops==0.7.0",
        "entrypoints==0.4",
        "evaluate==0.4.1",
        "exceptiongroup==1.1.3",
        "executing==2.0.0",
        "fairscale==0.4.13",
        "fastapi==0.104.0",
        "ffmpy==0.3.1",
        "filelock==3.12.0",
        "flash-attn==2.3.2",
        "flask==2.2.5",
        "flask-cors==3.0.10",
        "flatbuffers==23.5.9",
        "fonttools==4.43.1",
        "frozenlist==1.3.3",
        "fsspec==2023.5.0",
        "fusepy==3.0.1",
        "gitdb==4.0.11",
        "gitpython==3.1.40",
        "google-api-core==2.11.0",
        "google-auth==2.19.0",
        "google-auth-oauthlib==0.4.6",
        "googleapis-common-protos==1.59.0",
        "gradio==3.23.0",
        "grpcio==1.59.0",
        "gunicorn==20.1.0",
        "h11==0.14.0",
        "h5py==3.8.0",
        "hjson==3.1.0",
        "horovod==0.24.2",
        "httpcore==0.18.0",
        "httpx==0.25.0",
        "huggingface-hub==0.17.3",
        "humanfriendly==10.0",
        "idna==3.4",
        "igraph==0.10.4",
        "importlib-metadata==6.6.0",
        "importlib-resources==6.1.0",
        "inference-schema==1.5.1",
        "inflector==3.1.0",
        "iniconfig==2.0.0",
        "intel-openmp==2021.4.0",
        "ipykernel==6.25.2",
        "ipython==8.12.3",
        "isodate==0.6.1",
        "itsdangerous==2.1.2",
        "jedi==0.19.1",
        "jeepney==0.8.0",
        "jinja2==3.1.2",
        "jmespath==1.0.1",
        "joblib==1.3.2",
        "jsonlines==4.0.0",
        "jsonpickle==3.0.2",
        "jsonschema==4.19.1",
        "jsonschema-specifications==2023.7.1",
        "jupyter-client==8.4.0",
        "jupyter-core==5.4.0",
        "kiwisolver==1.4.5",
        "knack==0.10.1",
        "lightning-utilities==0.8.0",
        "linkify-it-py==2.0.2",
        "lit==16.0.5",
        "lxml==4.9.2",
        "markdown==3.5",
        "markdown-it-py==2.2.0",
        "markdown2==2.4.10",
        "markupsafe==2.1.2",
        "matplotlib==3.5.3",
        "matplotlib-inline==0.1.6",
        "mdit-py-plugins==0.3.3",
        "mdurl==0.1.2",
        "mkl==2021.4.0",
        "mkl-include==2021.4.0",
        "mlflow-skinny==2.7.1",
        "mpi4py==3.1.1",
        "mpmath==1.3.0",
        "msal==1.22.0",
        "msal-extensions==1.0.0",
        "msccl==2.3.0",
        "msrest==0.7.1",
        "msrestazure==0.6.4",
        "multidict==6.0.4",
        "multiprocess==0.70.15",
        "ndg-httpsclient==0.5.1",
        "nebulaml==0.16.2",
        "nest-asyncio==1.5.6",
        "networkx==3.1",
        "ninja==1.10.2",
        "nltk==3.8.1",
        "numpy==1.22.2",
        "oauthlib==3.2.2",
        "omegaconf==2.3.0",
        "onnx==1.14.0",
        "onnxruntime-gpu==1.16.1",
        "onnxruntime-training==1.14.1",
        "opencensus==0.11.2",
        "opencensus-context==0.1.3",
        "opencensus-ext-azure==1.1.9",
        "opencensus-ext-logging==0.1.1",
        "orjson==3.9.9",
        "packaging==23.0",
        "pandas==2.0.3",
        "paramiko==3.3.1",
        "parso==0.8.3",
        "pathspec==0.11.2",
        "pexpect==4.8.0",
        "pickleshare==0.7.5",
        "pillow==9.5.0",
        "pkginfo==1.9.6",
        "pkgutil-resolve-name==1.3.10",
        "platformdirs==3.11.0",
        "pluggy==1.0.0",
        "portalocker==2.7.0",
        "prompt-toolkit==3.0.39",
        "protobuf==3.20.3",
        "psutil==5.8.0",
        "ptyprocess==0.7.0",
        "pure-eval==0.2.2",
        "py==1.11.0",
        "py-cpuinfo==5.0.0",
        "py-spy==0.3.12",
        "pyarrow==9.0.0",
        "pyasn1==0.5.0",
        "pyasn1-modules==0.3.0",
        "pybind11==2.11.1",
        "pycparser==2.21",
        "pydantic==1.10.8",
        "pydash==7.0.6",
        "pydub==0.25.1",
        "pygments==2.16.1",
        "pyjwt==2.7.0",
        "pynacl==1.5.0",
        "pyopenssl==23.2.0",
        "pyparsing==3.1.1",
        "pysocks==1.7.1",
        "pytest==7.1.0",
        "pytest-mpi==0.6",
        "python-dateutil==2.8.2",
        "python-multipart==0.0.6",
        "pytorch-lightning==1.9.3",
        "pytz==2023.3.post1",
        "pyyaml==6.0",
        "pyzmq==25.1.1",
        "referencing==0.30.2",
        "regex==2023.10.3",
        "requests==2.31.0",
        "requests-oauthlib==1.3.1",
        "responses==0.18.0",
        "rouge-score==0.1.2",
        "rpds-py==0.10.6",
        "rsa==4.9",
        "ruamel-yaml==0.17.16",
        "ruamel-yaml-clib==0.2.8",
        "safetensors==0.4.0",
        "scipy==1.7.3",
        "secretstorage==3.3.3",
        "semantic-version==2.10.0",
        "sentencepiece==0.1.99",
        "setuptools==67.6.0",
        "six==1.16.0",
        "smmap==5.0.1",
        "sniffio==1.3.0",
        "sqlparse==0.4.4",
        "stack-data==0.6.3",
        "starlette==0.27.0",
        "supervisor==4.2.5",
        "svgwrite==1.4.3",
        "sympy==1.12",
        "tabulate==0.9.0",
        "tbb==2021.9.0",
        "tensorboard==2.11.2",
        "tensorboard-data-server==0.6.1",
        "tensorboard-plugin-wit==1.8.1",
        "texttable==1.6.7",
        "timm==0.9.7",
        "tokenizers==0.14.1",
        "toml==0.10.2",
        "tomli==2.0.1",
        "toolz==0.12.0",
        "torch==2.0.1+cu117",
        "torch-nebula==0.16.2",
        "torch-ort==1.14.0",
        "torch-tb-profiler==0.4.3",
        "torchaudio==2.0.2+cu117",
        "torchmetrics==0.11.3",
        "torchsnapshot==0.1.0",
        "torchvision==0.15.2+cu117",
        "tornado==6.3.3",
        "tqdm==4.62.3",
        "traitlets==5.11.2",
        "transformers==4.35.0.dev0",
        "triton==2.0.0",
        "tutel==0.1",
        "typing-extensions==4.8.0",
        "tzdata==2023.3",
        "uc-micro-py==1.0.2",
        "urllib3==1.26.16",
        "uvicorn==0.23.2",
        "wavedrom==2.0.3.post3",
        "wcwidth==0.2.8",
        "websocket-client==1.6.4",
        "websockets==11.0.3",
        "werkzeug==3.0.0",
        "wheel==0.40.0",
        "wrapt==1.12.1",
        "xxhash==3.4.1",
        "yarl==1.9.2",
        "z3-solver==4.12.2.0",
        "zipp==3.15.0"
      ]
    }
  ],
  "name": "ptca",
  "prefix": "/opt/conda/envs/ptca"
}

Who can help?

@muellerz @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

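    import torch
    import transformers
    from transformers import Trainer

    # tokenizer, training_args and data_module are created elsewhere in the script (not shown)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        torch_dtype=torch.bfloat16,
        use_flash_attention_2=True,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        compute_metrics=None,
        **data_module,
    )

    trainer.train()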


The training job works on A100s with 1 node and 8 GPUs. It fails when the job uses more than one node, with this error:

  File "./trainer.py", line 206, in <module>
    train()
  File "./trainer.py", line 157, in train
    model = transformers.AutoModelForCausalLM.from_pretrained(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py", line 565, in from_pretrained
    return model_class.from_pretrained(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3333, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py", line 3723, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py", line 744, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/accelerate/utils/modeling.py", line 317, in set_module_tensor_to_device
    new_value = value.to(device)
NotImplementedError: Cannot copy out of meta tensor; no data!
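The last frame is accelerate's set_module_tensor_to_device calling value.to(device). PyTorch raises this error whenever the tensor being copied lives on the "meta" device: meta tensors carry a shape and dtype but no data, so there is nothing to copy. The minimal snippet below (requires only PyTorch) reproduces the message in isolation; which checkpoint tensors ended up on meta in this job is an inference, consistent with the map_location = "meta" condition discussed later in the thread.

```python
import torch

# A meta tensor records shape and dtype but allocates no storage.
placeholder = torch.empty(4, 4, dtype=torch.bfloat16, device="meta")
print(placeholder.shape, placeholder.dtype, placeholder.is_meta)

message = None
try:
    # Copying to a real device needs actual values, which a meta tensor lacks.
    placeholder.to("cpu")
except NotImplementedError as err:
    message = str(err)

print(message)
```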

Expected behavior

Multi-node training should load the model and run without errors, as it does on a single node.

@ari9dam
Author

ari9dam commented Oct 21, 2023

Relevant: #26631 @pacman100

@ari9dam
Author

ari9dam commented Oct 21, 2023

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
num_machines: 2
num_processes: 16
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
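With num_machines: 2 and num_processes: 16 (accelerate counts processes across all machines), each machine runs 16 / 2 = 8 processes. Assuming ranks are assigned contiguously per machine, which is the usual launcher behavior but an assumption here, rank r runs on machine r // 8 as local rank r % 8. So rank 8 is a local rank 0 without being global rank 0, which is exactly the kind of process a "rank 0 only" or "local rank 0 only" loading rule treats differently. The helper name rank_layout is hypothetical.

```python
from typing import List, Tuple


def rank_layout(num_machines: int, num_processes: int) -> List[Tuple[int, int, int]]:
    """Return (global_rank, machine_index, local_rank) for every process.

    Assumes num_processes is the total across all machines and that ranks are
    assigned contiguously, one block of ranks per machine.
    """
    if num_processes % num_machines != 0:
        raise ValueError("num_processes must be divisible by num_machines")
    per_machine = num_processes // num_machines
    return [(rank, rank // per_machine, rank % per_machine) for rank in range(num_processes)]


# Values taken from the accelerate config above.
layout = rank_layout(num_machines=2, num_processes=16)
print(layout[8])  # (8, 1, 0): global rank 8 is local rank 0 on the second machine
```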

@pacman100
Contributor

Hello @ari9dam,

The PR you tagged above should resolve this issue. Please recreate the FSDP config with the accelerate config command and answer False when asked about RAM-efficient loading of the pretrained model.
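To confirm which answer a launched process actually received, one option is to inspect its environment. This sketch assumes accelerate launch exports the RAM-efficient-loading answer as the FSDP_CPU_RAM_EFFICIENT_LOADING environment variable with a "true"/"false" value, as recent accelerate releases do; verify the variable name against your installed version. The helper name is hypothetical.

```python
import os


def ram_efficient_loading_enabled(env=None) -> bool:
    """Hypothetical helper: report whether RAM-efficient FSDP loading is switched on.

    Assumes the launcher exports FSDP_CPU_RAM_EFFICIENT_LOADING as "true"/"false";
    a missing variable is treated as disabled.
    """
    env = os.environ if env is None else env
    value = env.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "false")
    return value.strip().lower() in {"1", "true", "yes", "y", "on"}


print("RAM-efficient loading:", ram_efficient_loading_enabled())
```

Passing a plain dict as env keeps the check testable outside a launched job.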

@ari9dam
Author

ari9dam commented Oct 21, 2023

Thank you, that solved it. I have one more question, @pacman100:

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        use_flash_attention_2=True,
    )

Should I pass torch_dtype here when loading the model? I'm using bf16 in the accelerate config, and I get these warnings:

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').

@Muennighoff
Contributor

I also had this issue and fixed it by changing

        if (
            is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0
        ) or (is_fsdp_enabled() and not is_local_dist_rank_0()):
            map_location = "meta"

to

        if (
            (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
            and torch.distributed.is_initialized()
            and (torch.distributed.get_rank() % 8 != 0)
        ):
            map_location = "meta"

here


(This assumes 8 GPUs per node; with 4 GPUs per node, replace 8 with 4, and so on.)
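The hard-coded 8 must match the number of GPUs per node. A sketch of the same condition with the per-node count passed in instead, assuming the launcher exports LOCAL_WORLD_SIZE (torchrun does) and that ranks are laid out contiguously per node. local_world_size and should_load_on_meta are hypothetical names, and the ZeRO-3/FSDP and distributed checks are passed in as booleans so the logic can be exercised without a cluster.

```python
import os


def local_world_size(default: int = 8) -> int:
    """Processes per node: torchrun exports LOCAL_WORLD_SIZE; otherwise use the default."""
    return int(os.environ.get("LOCAL_WORLD_SIZE", default))


def should_load_on_meta(sharded: bool, dist_initialized: bool, global_rank: int, per_node: int) -> bool:
    """Mirror of the patched condition.

    sharded stands in for is_deepspeed_zero3_enabled() or is_fsdp_enabled();
    every process except the first on each node loads onto the meta device.
    """
    return sharded and dist_initialized and global_rank % per_node != 0


# 2 nodes x 8 GPUs: only the first rank on each node reads real weights.
real_loaders = [r for r in range(16) if not should_load_on_meta(True, True, r, per_node=8)]
print(real_loaders)  # [0, 8]
```

In a real run, per_node would come from local_world_size(). Reading LOCAL_RANK (also exported by torchrun) and comparing it to 0 would avoid the contiguous-layout assumption altogether.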
