Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds triton flash attention2 kernel #4337

Merged
merged 14 commits into from
Sep 21, 2023
Merged

adds triton flash attention2 kernel #4337

merged 14 commits into from
Sep 21, 2023

Conversation

stephen-youn
Copy link
Contributor

@stephen-youn stephen-youn commented Sep 14, 2023

This PR adds flash attention 2 kernel implemented in triton2.1 for inference.
Benchmarking on bert-base shows 4~13% further latency reduction (the longer the sequence length, the larger the performance gain) when compared to the case using regular attention triton kernel.

image

Copy link
Contributor

@lekurile lekurile left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a comment about the TestModelTask unit test, other than that LGTM!

tests/unit/inference/test_inference.py Outdated Show resolved Hide resolved
Copy link
Contributor

@lekurile lekurile left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@stephen-youn stephen-youn added this pull request to the merge queue Sep 20, 2023
Merged via the queue into master with commit 0e0748c Sep 21, 2023
16 checks passed
CurryRice233 pushed a commit to CurryRice233/DeepSpeed that referenced this pull request Sep 28, 2023
* origin/master:
  Allow multiple inference engines in single script (microsoft#4384)
  adds triton flash attention2 kernel (microsoft#4337)
  Fix llama meta tensor loading in AutoTP and kernel injected inference (microsoft#3608)
  Fix min torch version (microsoft#4375)
  Fix multinode runner to properly append to PDSH_SSH_ARGS_APPEND (microsoft#4373)
  add the missing method (microsoft#4363)
  Openfold fix (microsoft#4368)
  deepspeed4science japanese blog (microsoft#4369)
  deepspeed4science chinese blog (microsoft#4366)
  Enable workflow dispatch on Torch 1.10 CI tests (microsoft#4361)
  Update conda env to have max pydantic version (microsoft#4362)
  add deepspeed4science blog link (microsoft#4364)
  added check to avoid undefined behavior when the input_id length is greater than max_tokens (microsoft#4349)
  Add the policy to run llama model from the official repo (microsoft#4313)
  fix deepspeed4science links (microsoft#4358)
  DeepSpeed4Science (microsoft#4357)
  Support InternLM (microsoft#4137)
  Pass base_dir to model files can be loaded for auto-tp/meta-tensor. (microsoft#4348)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants