Disable removing shared tensors by default #28630

Closed
2 of 4 tasks
imoneoi opened this issue Jan 22, 2024 · 8 comments

Comments

@imoneoi

imoneoi commented Jan 22, 2024

System Info

- `transformers` version: 4.36.2
- Platform: Linux-5.4.0-167-generic-x86_64-with-glibc2.31
- Python version: 3.11.5
- Huggingface_hub version: 0.20.1
- Safetensors version: 0.4.1
- Accelerate version: 0.25.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes, torchrun

Who can help?

@younesbelkada @Narsil

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

A minimal reproduction with DeepSpeed can be found in #27293, where disabling safe_serialization resolves this issue.

Related (DeepSpeed): #27293

Expected behavior

Consider disabling the removal of shared tensors by default in https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2409-L2452. This piece of code identifies shared tensors by their storage locations, but there are many cases where tensors are views of a larger tensor and therefore share the same location.

One example is when q_proj, k_proj, and v_proj are views of qkv_proj; another is DeepSpeed ZeRO, where all parameters are views of one large flat tensor. We've observed failures in both cases.
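
To make the view case concrete, here is a minimal sketch (my own illustration, not from the original report; it assumes PyTorch 2.x for untyped_storage) showing that q/k/v views of a fused buffer report the same storage pointer, which is exactly the signal a storage-based check uses to declare tensors "shared":

import torch

hidden = 4
qkv_proj = torch.randn(3 * hidden, hidden)          # one fused parameter buffer
q_proj, k_proj, v_proj = qkv_proj.chunk(3, dim=0)   # views into it, no copies made

# All three views point at the same underlying storage, so a storage-based
# check cannot tell them apart from genuinely tied weights.
ptrs = {t.untyped_storage().data_ptr() for t in (q_proj, k_proj, v_proj)}
print(ptrs)  # a single pointer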

Besides, keeping shared tensors usually does not cause a large storage overhead, since common shared tensors (such as tied embeddings) account for only a small fraction of the total parameters.
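
As a rough back-of-the-envelope (illustrative numbers, not measurements from this issue), even writing a tied embedding out twice adds little relative to the checkpoint size:

vocab, hidden = 32_000, 4_096                      # assumed LLaMA-like sizes
total_params, bytes_per_param = 7_000_000_000, 2   # 7B model in fp16
extra = vocab * hidden * bytes_per_param           # ~0.26 GB duplicated embedding
checkpoint = total_params * bytes_per_param        # ~14 GB checkpoint
print(f"extra copy ≈ {extra / checkpoint:.1%} of the checkpoint")  # ≈ 1.9%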

Removed shared tensor {'model.layers.27.self_attn.k_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.28.self_attn.k_proj.weight', 'model.layers.29.self_attn.k_proj.weight', 'model.layers.30.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.28.self_attn.v_proj.weight', 'model.layers.30.self_attn.v_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.29.self_attn.v_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.31.self_attn.k_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.31.self_attn.v_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.11.self_attn.k_proj.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading
@younesbelkada
Contributor

Hi @imoneoi
Thanks for the issue!
I don't think we can disable sharding by default as it might break many things, such as the ability to load models on a free-tier Google Colab instance. Among many possible options, a few fixes that I see for your case and for #27293 are:

1- Warn users who are using DeepSpeed not to save their model with safe_serialization
2- Make that block optional through an argument such as shard_weights=True, and either set it to False for DeepSpeed or warn users about it when DeepSpeed is in use

--> in general we encourage users to use safetensors, so I would say option 2 might be the best solution here

Would you be happy to open a PR with one of these solutions? cc @amyeroberts @pacman100 @muellerzr, what do you think?

@LysandreJik
Member

Hmmm, I think what @imoneoi is reporting is a different issue from what you're describing, @younesbelkada: safetensors refuses to serialize shared (not sharded) tensors, so transformers removes the duplicate copies of those tensors from the state dict.

We're definitely aiming for this to be frictionless, so the more insights we have in the code that fails, the better we'll be able to help.
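
For context, here is a minimal sketch (my own, not from the thread; the exact exception type and message may vary across safetensors versions) of safetensors refusing aliased tensors, which is what forces the removal step in transformers:

import torch
from safetensors.torch import save_file

weight = torch.randn(8, 8)
state_dict = {"encoder.weight": weight, "decoder.weight": weight}  # two names, one storage

try:
    save_file(state_dict, "tied.safetensors")
except Exception as err:  # safetensors raises instead of silently duplicating memory on disk
    print(type(err).__name__, err)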

Thanks @muellerzr for the minimal reproducer on the other thread, I'm pasting it below:

import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, HfDeepSpeedConfig
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model

transformers_config = HfDeepSpeedConfig({
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 2,
    "gradient_clipping": 1.0,
    "offload_optimizer_device": None,
    "offload_param_device": None,
    "zero3_init_flag": False,
    "zero_optimization": {
    "stage": 2,
    },
})

plugin = DeepSpeedPlugin(transformers_config)

accelerator = Accelerator(deepspeed_plugin=plugin)

model_name = "bert-base-cased"
model = AutoModelForCausalLM.from_pretrained(model_name)

opt = torch.optim.Adam(model.parameters(), lr=1e-5)

model, opt = accelerator._prepare_deepspeed(model, opt)

state_dict = accelerator.get_state_dict(model)

model = unwrap_model(model)
model.save_pretrained(
    "testing_fuyu_8b",
    state_dict=state_dict,
    safe_serialization=True
)

cc @Narsil if you have the bandwidth to take a look, this looks like it's impacting quite a few deepspeed users. Thanks a lot 🙌

@imoneoi
Author

imoneoi commented Jan 24, 2024

Temporary solution: setting safe_serialization=False works.
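
Applied to the reproducer above, the workaround changes only the final call (the output directory name is illustrative):

# Falls back to torch .bin serialization, which tolerates aliased/shared tensors.
model.save_pretrained(
    "output_dir",
    state_dict=state_dict,
    safe_serialization=False,
)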

@Narsil
Contributor

Narsil commented Jan 24, 2024

I did look into it, and this snippet works for me with the latest revisions of everything (accelerate, deepspeed, transformers).

@pacman100
Contributor

Hello,

  1. Versions:
- `transformers` version: 4.37.0
- `Accelerate` version: 0.26.1
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.13
- Numpy version: 1.26.0
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 503.54 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
	Not found
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.20.2
- Safetensors version: 0.4.0
- Accelerate version: 0.26.1
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Tensorflow version (GPU?): 2.15.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/torch']
torch version .................... 2.1.2+cu121
deepspeed install path ........... ['/raid/sourab/miniconda3/envs/hf/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.12.6, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.1
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 251.77 GB
  2. Code:
import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, HfDeepSpeedConfig
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model

transformers_config = HfDeepSpeedConfig({
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 2,
    "gradient_clipping": 1.0,
    "offload_optimizer_device": None,
    "offload_param_device": None,
    "zero3_init_flag": False,
    "zero_optimization": {
    "stage": 3,
    "stage3_gather_16bit_weights_on_model_save": True
    },
})

plugin = DeepSpeedPlugin(transformers_config)

accelerator = Accelerator(deepspeed_plugin=plugin)

model_name = "bert-base-cased"
model = AutoModelForCausalLM.from_pretrained(model_name)

opt = torch.optim.Adam(model.parameters(), lr=1e-5)

model, opt = accelerator._prepare_deepspeed(model, opt)

state_dict = accelerator.get_state_dict(model)

model = unwrap_model(model)
model.save_pretrained(
    "remove",
    state_dict=state_dict,
    safe_serialization=True
)
  3. Command:
torchrun --nproc-per-node 2 issue_28630.py
  4. Output:
[2024-01-24 13:01:29,798] [INFO] [config.py:974:print_user_config]   json = {
    "train_micro_batch_size_per_gpu": 2, 
    "gradient_accumulation_steps": 2, 
    "gradient_clipping": 1.0, 
    "offload_optimizer_device": null, 
    "offload_param_device": null, 
    "zero3_init_flag": false, 
    "zero_optimization": {
        "stage": 3, 
        "stage3_gather_16bit_weights_on_model_save": true
    }, 
    "steps_per_print": inf, 
    "fp16": {
        "enabled": false
    }, 
    "bf16": {
        "enabled": false
    }, 
    "zero_allow_untested_optimizer": true
}
Removed shared tensor {'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.2.intermediate.dense.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.1.attention.self.value.weight', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.1.attention.self.query.weight', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.2.output.dense.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.4.attention.self.value.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.0.attention.output.dense.weight', 'bert.encoder.layer.1.intermediate.dense.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.1.attention.output.dense.weight', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.2.attention.self.value.weight', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.2.attention.output.dense.weight', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.1.output.dense.weight', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.8.intermediate.dense.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.2.attention.self.query.weight', 'bert.embeddings.position_embeddings.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.2.attention.self.key.weight', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.8.attention.self.query.weight', 
'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.1.attention.self.key.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading

Observations:

  1. This happens when using DeepSpeed Stage 3, where the weights of many layers are concatenated, flattened, and sharded across device(s), i.e., flat tensors from which views are taken for individual layers, as mentioned by @imoneoi.
  2. This is not limited to DeepSpeed. It also happens with torch.compile, as shown in "Shared tensors not correctly saved." #27293 (comment), and I can reproduce it.
  3. It also happens with FSDP. I am able to reproduce "Accelerate FSDP always removed {'model.norm.weight'} layer of model when saving them" accelerate#2155 (comment) with the command below:
accelerate launch --config_file fsdp_config.yaml run_mlm_no_trainer.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --model_name_or_path bert-base-cased \
    --output_dir /tmp/test-mlm

with config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

output:

Removed shared tensor {'cls.predictions.transform.dense.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.embeddings.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading

Possible Solutions:
Disable safetensors for DeepSpeed/FSDP when there are shared tensors other than the ones specified via model.config.tie_encoder_decoder and model.config.tie_word_embeddings.
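
A minimal sketch of what such a check could look like on the caller side (the function names, the expected_tied placeholder, and the decision logic are illustrative assumptions, not existing transformers API):

from collections import defaultdict

import torch

def find_shared_groups(state_dict):
    # Group parameter names by (device, storage pointer); groups with more than
    # one name are aliased tensors that safetensors would refuse to serialize.
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        if isinstance(tensor, torch.Tensor):
            groups[(tensor.device, tensor.untyped_storage().data_ptr())].append(name)
    return [names for names in groups.values() if len(names) > 1]

def safe_serialization_ok(state_dict, expected_tied=("lm_head.weight",)):
    # Keep safetensors only if every shared group is an expected tie
    # (expected_tied stands in for what the model config declares).
    unexpected = [group for group in find_shared_groups(state_dict)
                  if not any(name in expected_tied for name in group)]
    return not unexpected

# usage sketch:
# model.save_pretrained(out_dir, state_dict=sd,
#                       safe_serialization=safe_serialization_ok(sd))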

@pacman100
Contributor

I think the reproducer from Zach needs a fix. With the change below, which calls save_pretrained only on the main process, the checkpoint is saved properly when using DeepSpeed.

import torch
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, HfDeepSpeedConfig
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import unwrap_model

transformers_config = HfDeepSpeedConfig({
    "train_micro_batch_size_per_gpu": 2,
    "gradient_accumulation_steps": 2,
    "gradient_clipping": 1.0,
    "offload_optimizer_device": None,
    "offload_param_device": None,
    "zero3_init_flag": False,
    "zero_optimization": {
    "stage": 3,
    "stage3_gather_16bit_weights_on_model_save": True
    },
})

plugin = DeepSpeedPlugin(transformers_config)

accelerator = Accelerator(deepspeed_plugin=plugin)

model_name = "bert-base-cased"
model = AutoModelForCausalLM.from_pretrained(model_name)

opt = torch.optim.Adam(model.parameters(), lr=1e-5)

model, opt = accelerator._prepare_deepspeed(model, opt)

state_dict = accelerator.get_state_dict(model)

+ if accelerator.is_main_process:
        model = unwrap_model(model)
        model.save_pretrained(
            "remove",
            state_dict=state_dict,
            safe_serialization=True
        )

@imoneoi
Copy link
Author

imoneoi commented Jan 25, 2024

@pacman100 @younesbelkada Thanks for your observations! As a quick fix to mitigate the issues with DeepSpeed, FSDP, and torch.compile, should we consider disabling safetensors and warning the user that safetensors has been disabled when shared tensors are found?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
