load_best_model_at_end failed due to "size mismatch" when DeepSpeed is used #14628

Closed
dunalduck0 opened this issue Dec 5, 2021 · 9 comments · Fixed by #14652
@dunalduck0

dunalduck0 commented Dec 5, 2021

Environment info

  • transformers version: 4.13.0.dev0
  • Platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.17
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.10.0+cu113 (True)
  • Tensorflow version (GPU?): 2.7.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.3.6 (cpu)
  • Jax version: 0.2.25
  • JaxLib version: 0.1.74
  • Using GPU in script?: Yes, 1 GPU
  • Using distributed or parallel set-up in script?: No

Who can help

@stas00

Information

Models I am using: EleutherAI/gpt-neo-1.3B and EleutherAI/gpt-j-6B

The problem arises when using: run_clm.py with DeepSpeed and --load_best_model_at_end

The task I am working on is: a toy fine-tuning run

To reproduce

Steps to reproduce the behavior:

  1. Running run_clm.py without DeepSpeed finishes successfully:
+ export TOKENIZERS_PARALLELISM=false
+ TOKENIZERS_PARALLELISM=false
+ CUDA_VISIBLE_DEVICES=0
+ ./run_clm.py --model_name_or_path=EleutherAI/gpt-neo-1.3B --dataset_name=wikitext --dataset_config_name=wikitext-2-raw-v1 --output_dir=output.test --overwrite_output_dir=true --do_train=true --max_train_samples=100 --do_eval=true --max_eval_samples=100 --logging_strategy=steps --logging_steps=10 --evaluation_strategy=steps --eval_steps=3 --save_strategy=steps --save_steps=3 --save_total_limit=2 --load_best_model_at_end=true --per_device_train_batch_size=16 --per_device_eval_batch_size=4 --gradient_accumulation_steps=1 --gradient_checkpointing=true --num_train_epochs=2
  2. Running a similar command with DeepSpeed fails:
+ export TOKENIZERS_PARALLELISM=false
+ TOKENIZERS_PARALLELISM=false
+ deepspeed --num_gpus 1 run_clm.py --deepspeed=zero3.json --model_name_or_path=EleutherAI/gpt-neo-1.3B --dataset_name=wikitext --dataset_config_name=wikitext-2-raw-v1 --output_dir=output.test --overwrite_output_dir=true --do_train=true --max_train_samples=100 --do_eval=true --max_eval_samples=100 --logging_strategy=steps --logging_steps=10 --evaluation_strategy=steps --eval_steps=3 --save_strategy=steps --save_steps=3 --save_total_limit=2 --load_best_model_at_end=true --per_device_train_batch_size=16 --per_device_eval_batch_size=4 --gradient_accumulation_steps=1 --gradient_checkpointing=true --num_train_epochs=2
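
The zero3.json file referenced above is not included in the report. For illustration only (this is a minimal sketch, not the reporter's actual file), a bare-bones ZeRO stage-3 config with "auto" placeholders that the HF Trainer integration fills in can be generated like this:

    # Illustrative only: a minimal ZeRO stage-3 config, not the actual zero3.json used above.
    import json

    ds_config = {
        "zero_optimization": {
            "stage": 3,  # ZeRO stage 3: partition optimizer states, gradients and parameters
        },
        # "auto" values are resolved by the HF Trainer's DeepSpeed integration
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
    }

    with open("zero3.json", "w") as f:
        json.dump(ds_config, f, indent=2)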

Error stack:

Training completed. Do not forget to share your model on huggingface.co/models =)


[INFO|trainer.py:1431] 2021-12-05 01:29:52,800 >> Loading best model from output.test/checkpoint-9 (score: 2.5225963592529297).
Traceback (most recent call last):
  File "run_clm.py", line 536, in <module>
    main()
  File "run_clm.py", line 484, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/home/meiyang/src/transformers_fork/src/transformers/trainer.py", line 1440, in train
    self._load_state_dict_in_model(state_dict)
  File "/home/meiyang/src/transformers_fork/src/transformers/trainer.py", line 1472, in _load_state_dict_in_model
    load_result = self.model.load_state_dict(state_dict, strict=False)
  File "/home/meiyang/bin/miniconda3/envs/gptj/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPTNeoForCausalLM:
        size mismatch for transformer.wte.weight: copying a param with shape torch.Size([50257, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.0.attn.attention.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.0.attn.attention.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.0.attn.attention.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.0.attn.attention.out_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.0.mlp.c_fc.weight: copying a param with shape torch.Size([8192, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([2048, 8192]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.1.attn.attention.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.1.attn.attention.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.1.attn.attention.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.1.attn.attention.out_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.1.mlp.c_fc.weight: copying a param with shape torch.Size([8192, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([2048, 8192]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.2.attn.attention.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.2.attn.attention.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        size mismatch for transformer.h.2.attn.attention.q_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([1]).
        ... (remaining size-mismatch entries truncated)

Expected behavior

--load_best_model_at_end should not crash with DeepSpeed
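
For context on the shapes in the trace: under ZeRO stage 3, DeepSpeed partitions the model parameters across processes, so the live module holds small placeholder tensors rather than the full weights, which is what the torch.Size([1]) shapes above reflect. A generic way to load a full state dict into such a model is to gather the partitioned parameters first. A minimal sketch, assuming an already-initialized ZeRO-3 setup (illustrative only, not the fix applied in #14652):

    # Sketch only: loading a full (unpartitioned) state_dict into a model whose
    # parameters are partitioned by DeepSpeed ZeRO-3. Not the actual fix from #14652.
    import torch.distributed as dist
    import deepspeed

    def load_full_state_dict_zero3(model, state_dict):
        # Temporarily gather the partitioned parameters so the full tensors exist;
        # modifier_rank=0 means modifications made on rank 0 are kept and
        # re-partitioned when the context exits.
        params = list(model.parameters())
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            if dist.get_rank() == 0:
                model.load_state_dict(state_dict, strict=False)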

@sgugger
Collaborator

sgugger commented Dec 6, 2021

cc @stas00

@stas00 stas00 self-assigned this Dec 6, 2021
@stas00
Contributor

stas00 commented Dec 6, 2021

Thank you for this report, @dunalduck0. I hadn't tested this "path" (or ever used it). I can reproduce the problem with a much faster command:

rm -r output_dir; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 deepspeed --num_gpus=1 examples/pytorch/translation/run_translation.py --model_name_or_path hf-internal-testing/tiny-random-t5 --output_dir output_dir --overwrite_output_dir --per_device_train_batch_size 1 --sortish_sampler --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" --source_prefix "translate English to Romanian: "  --deepspeed tests/deepspeed/ds_config_zero3.json  --do_train --max_train_samples 3 --do_eval --max_eval_samples 1 --logging_strategy steps --logging_steps 1 --evaluation_strategy steps --eval_steps 1 --save_strategy steps --save_steps 1  --load_best_model_at_end --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --num_train_epochs 1

will work on solving it and report back when I have something to show.

@stas00
Contributor

stas00 commented Dec 6, 2021

Please try with this PR #14652 and let me know if that fixes the issue for you.

Thank you.

@dunalduck0
Author

Thank you @stas00. What do I need to do to get your fix onto my local box? I have a Linux box and a local copy of a fork of huggingface/transformers. It looks like your fix has not been merged into huggingface/transformers yet.

@stas00
Contributor

stas00 commented Dec 7, 2021

Indeed, it's not merged yet. I need to add tests first.

Here are some of the ways you can try my PR's branch:

If you have gh installed (https://github.com/cli/cli), you can check out the PR branch directly:

git clone https://github.com/huggingface/transformers
cd transformers
gh pr checkout 14652
pip install -e .

Or you can clone my fork and switch to that branch:

git clone https://github.com/stas00/transformers
cd transformers
git checkout ds-load-best-model
pip install -e .
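
After either install, an optional sanity check to confirm the editable checkout is the one being imported:

    # Optional sanity check: confirm which transformers install is picked up.
    import transformers
    print(transformers.__version__)  # should show a .dev0 version from the checkout
    print(transformers.__file__)     # should point into the cloned repository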

Update: the PR is ready to be merged, but I'd like you to validate it first to make sure it indeed solves the problem for you.

@dunalduck0
Author

Will do it tonight

@stas00
Contributor

stas00 commented Dec 7, 2021

Perfect. I will merge as soon as you give me green light.

@dunalduck0
Author

The fix worked on my box. Thank you @stas00 for the quick fix.

Minor comment: the logging confused me at first. I saw DeepSpeed re-initialize everything and almost thought the program had restarted :P.

@stas00
Contributor

stas00 commented Dec 7, 2021

Thank you for testing, @dunalduck0.

Once DeepSpeed fixes this issue, the full restart will go away.
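
As the comments above suggest, the fix reloads the best checkpoint by re-initializing DeepSpeed rather than calling load_state_dict() on the partitioned model. In rough outline, with hypothetical names and not the actual code of PR #14652, that general approach looks like:

    # Sketch of the general idea (hypothetical names; not the code of PR #14652):
    # instead of model.load_state_dict(), rebuild the DeepSpeed engine and let it
    # load the checkpoint, which handles the ZeRO-3 partitioned parameters.
    import deepspeed

    def reload_best_checkpoint(model, ds_config, best_checkpoint_dir):
        engine, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            model_parameters=[p for p in model.parameters() if p.requires_grad],
            config=ds_config,
        )
        # Reload model/optimizer/scheduler state from the saved checkpoint directory.
        load_path, _ = engine.load_checkpoint(best_checkpoint_dir)
        return engine, optimizer, lr_scheduler, load_path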
