
Support InternLM #4137

Merged: 17 commits into microsoft:master from support_internlm_0.10.0 on Sep 18, 2023
Conversation

wangruohui
Contributor

This PR is to support a new model named InternLM.

This model is similar to llama but has bias on the QKVO matmuls, so I primarily duplicated the llama model code and added support for bias in the Attention module's Python code.

This branch was originally checked out from tag v0.10.0, and I have tested it locally on a single GPU and on multiple GPUs with TP.
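As a rough illustration (a hypothetical class, not the actual DeepSpeed container code), the structural difference from llama is just bias on the four attention projections:

import torch.nn as nn

class InternLMAttentionProjections(nn.Module):
    # Sketch: llama-style QKVO projections, but with bias enabled.
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        head_dim = hidden_size // num_heads
        # llama constructs these four projections with bias=False;
        # InternLM's variant uses bias=True on all of them.
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=True)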

@wangruohui
Contributor Author

@microsoft-github-policy-service agree

@wangruohui wangruohui closed this Aug 11, 2023
@wangruohui wangruohui deleted the support_internlm_0.10.0 branch August 11, 2023 13:00
@wangruohui wangruohui restored the support_internlm_0.10.0 branch August 11, 2023 13:01
@wangruohui wangruohui reopened this Aug 11, 2023
@molly-smith molly-smith self-assigned this Aug 18, 2023
@lekurile lekurile self-assigned this Aug 18, 2023
@lekurile
Contributor

Hi @wangruohui,

Thank you for expanding our inference capabilities and adding the container for InternLM. Can you please provide an example of a model being used with our inference test?

For example, for the BLOOM model, a command may look like this:

deepspeed --num_gpus 1 inference-test.py --model bigscience/bloom-3b --use_meta --use_kernel

It would be nice to have a similar command for InternLM so we can do some testing on our side as well.

Thanks!

@wangruohui
Contributor Author

wangruohui commented Aug 30, 2023

Hello @lekurile

I am working with the test script you provided. Some notes:

  1. As InternLM is not integrated into the main branch of transformers, one still needs to pass trust_remote_code=True to load the model from the hub. See support trust_remote_code in inference test DeepSpeedExamples#709 for the detailed modifications; a minimal loading sketch follows this list.

  2. To make things easier to manage, I cloned the internlm-7b model from the Hugging Face hub to my home directory:

git lfs install
git clone https://huggingface.co/internlm/internlm-7b

So all test commands below are based on the directory ~/internlm-7b; you may need to change this path on your working server.

  3. InternLM has two variants, the base model internlm-7b and the fine-tuned internlm-chat-7b. The tests below use internlm-7b; please check the model name in the test commands below.

  4. I set --greedy to make the results reproducible.
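For reference, here is a minimal, hedged sketch of what trust_remote_code=True does for loading (illustrative only; the actual test-script changes live in DeepSpeedExamples#709):

import os
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path from the git clone step above; expand ~ since transformers will not.
model_path = os.path.expanduser("~/internlm-7b")

# trust_remote_code=True lets transformers execute the modeling code shipped
# with the InternLM checkpoint, which is not in the transformers main branch.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)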

Test results

HF Baseline

deepspeed --num_gpus 1 inference-test.py --model ~/internlm-7b/ --hf_baseline --greedy --trust_remote_code
generation time is 1.5033915042877197 sec

in=DeepSpeed is a machine learning framework
out=DeepSpeed is a machine learning framework for deep learning. It is a Python package that provides a set of tools for training and evaluating deep neural networks. It is designed to be easy to use and to provide a consistent interface for all of the different types of neural networks that can be trained

Use Kernels

deepspeed --num_gpus 1 inference-test.py --model ~/internlm-7b/ --use_kernel --greedy --trust_remote_code
generation time is 0.5656006336212158 sec

in=DeepSpeed is a machine learning framework
out=DeepSpeed is a machine learning framework for deep learning. It is a Python package that provides a set of tools for training and evaluating deep neural networks. It is designed to be easy to use and to provide a consistent interface for all of the different types of neural networks that can be trained

Tensor Parallel

deepspeed --num_gpus 2 inference-test.py --model ~/internlm-7b/ --use_kernel --greedy --trust_remote_code
generation time is 0.6287708282470703 sec

in=DeepSpeed is a machine learning framework
out=DeepSpeed is a machine learning framework for deep learning. It is a Python package that provides a set of tools for training and evaluating deep neural networks. It is designed to be easy to use and to provide a consistent interface for all of the different types of neural networks that can be trained

@lekurile
Contributor

@wangruohui Appreciate the very detailed testing and reproduction details. I've merged the latest master into this branch and kicked off some tests. I'll also try getting the model to run on my end.

Thanks,
Lev

@wangruohui
Contributor Author

Hello,

Any updates? Feel free to reach out if you need any help from my side.

@lekurile
Contributor

lekurile commented Sep 11, 2023

> Hello,
>
> Any updates? Feel free to reach out if you need any help from my side.

Hi @wangruohui,

Looks like the changes in this PR may have caused some issues with other models, specifically in the following unit tests:

  • unit/inference/test_inference.py::TestMPSize::test[fp32-gpt-neo] FAILED [ 80%]
  • unit/inference/test_inference.py::TestMPSize::test[fp16-bloom] FAILED [ 86%]
  • unit/inference/test_inference.py::TestMPSize::test[fp16-gpt-neo] FAILED [ 88%]
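For reference, one way to rerun just these cases locally (assuming the usual pytest setup in the DeepSpeed repo) is:

pytest "unit/inference/test_inference.py::TestMPSize::test[fp32-gpt-neo]" \
       "unit/inference/test_inference.py::TestMPSize::test[fp16-bloom]" \
       "unit/inference/test_inference.py::TestMPSize::test[fp16-gpt-neo]"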

I suspect this is due to the changes in deepspeed/ops/transformer/inference/ds_attention.py and deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py potentially breaking the behavior of the GPT-Neo and BLOOM models, since all the other file changes are self-contained to InternLM.

Can you kindly test these models for compatibility with the InternLM changes on your side? I can try testing as well.

Thanks,
Lev

@wangruohui
Contributor Author

Hello @lekurile

I made some modifications to make GPT-Neo compatible, but I cannot set up exactly the same environment to run all the tests. Would you please allow the CI to run to check whether the problem is solved?
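For context, the gist of the fix as a hedged sketch (the function name here is hypothetical, not the literal qkv_gemm.py code): the bias add is strictly opt-in, so models whose containers pass no QKV bias keep their previous code path.

import torch
from typing import Optional

def qkv_gemm_with_optional_bias(x: torch.Tensor,
                                weight: torch.Tensor,
                                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Fused QKV projection shared by all injected models.
    out = torch.matmul(x, weight)
    # Only add a bias when one is actually provided, so bias-free models
    # (e.g. GPT-Neo, BLOOM) behave exactly as they did before this PR.
    if bias is not None:
        out = out + bias
    return out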

@molly-smith molly-smith removed their assignment Sep 14, 2023
@lekurile lekurile added this pull request to the merge queue Sep 18, 2023
Merged via the queue into microsoft:master with commit 367d6f9 Sep 18, 2023
16 checks passed
CurryRice233 pushed a commit to CurryRice233/DeepSpeed that referenced this pull request Sep 28, 2023
* origin/master:
  Allow multiple inference engines in single script (microsoft#4384)
  adds triton flash attention2 kernel (microsoft#4337)
  Fix llama meta tensor loading in AutoTP and kernel injected inference (microsoft#3608)
  Fix min torch version (microsoft#4375)
  Fix multinode runner to properly append to PDSH_SSH_ARGS_APPEND (microsoft#4373)
  add the missing method (microsoft#4363)
  Openfold fix (microsoft#4368)
  deepspeed4science japanese blog (microsoft#4369)
  deepspeed4science chinese blog (microsoft#4366)
  Enable workflow dispatch on Torch 1.10 CI tests (microsoft#4361)
  Update conda env to have max pydantic version (microsoft#4362)
  add deepspeed4science blog link (microsoft#4364)
  added check to avoid undefined behavior when the input_id length is greater than max_tokens (microsoft#4349)
  Add the policy to run llama model from the official repo (microsoft#4313)
  fix deepspeed4science links (microsoft#4358)
  DeepSpeed4Science (microsoft#4357)
  Support InternLM (microsoft#4137)
  Pass base_dir to model files can be loaded for auto-tp/meta-tensor. (microsoft#4348)