
[XLA:GPU] Update cuDNN flash attention backward to use new cuDNN frontend (>1.0) #11249

Closed
wants to merge 13 commits from Cjkkkk:flash-attn-backward

Conversation

Cjkkkk (Contributor) commented Apr 5, 2024

  • Replace the complicated node-by-node flash attention cuDNN graph construction with the single cuDNN SDPA call provided by the new cuDNN frontend (>1.0); see the sketch after this list.
  • Remove the softmax_sum/dq_accum tensors from the custom-call outputs, since they are no longer required by the new cuDNN frontend.
  • Add a pass that queries the workspace size from the cuDNN graph instead of computing it manually, since the workspace size is subject to change.
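
For context, a minimal sketch of what the simplified construction can look like against the cuDNN frontend (>1.0) C++ API. This is not the PR's exact code: the BHSD shapes and strides, BF16 I/O types, causal mask, and 1/sqrt(d) scale are illustrative assumptions; the pieces this PR actually adopts are the single `sdpa_backward` node and the `get_workspace_size` query.

```cpp
// Minimal sketch (not the PR's code) of a flash-attention backward graph
// built with one sdpa_backward node instead of hand-wired matmul/softmax ops.
#include <cassert>
#include <cmath>
#include <cstdint>

#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

int64_t SdpaBackwardWorkspaceSize(cudnnHandle_t handle, int64_t b, int64_t h,
                                  int64_t s, int64_t d) {
  fe::graph::Graph graph;
  graph.set_io_data_type(fe::DataType_t::BFLOAT16)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);

  // BHSD layout; a small helper keeps the sketch short.
  auto make_bhsd = [&](const char* name) {
    return graph.tensor(fe::graph::Tensor_attributes()
                            .set_name(name)
                            .set_dim({b, h, s, d})
                            .set_stride({h * s * d, s * d, d, 1}));
  };
  auto q = make_bhsd("Q"), k = make_bhsd("K"), v = make_bhsd("V");
  auto o = make_bhsd("O"), dO = make_bhsd("dO");

  // Softmax stats saved by the forward pass; fp32, matching the
  // "fix stat datatype to be fp32" commit in this PR.
  auto stats = graph.tensor(fe::graph::Tensor_attributes()
                                .set_name("stats")
                                .set_dim({b, h, s, 1})
                                .set_stride({h * s, s, 1, 1})
                                .set_data_type(fe::DataType_t::FLOAT));

  // One sdpa_backward node replaces the node-by-node subgraph.
  auto opts = fe::graph::SDPA_backward_attributes()
                  .set_name("flash_attention_bwd")
                  .set_causal_mask(true)
                  .set_attn_scale(1.0f / std::sqrt(static_cast<float>(d)));
  auto [dQ, dK, dV] = graph.sdpa_backward(q, k, v, o, dO, stats, opts);
  dQ->set_output(true);
  dK->set_output(true);
  dV->set_output(true);

  assert(graph.validate().is_good());
  assert(graph.build_operation_graph(handle).is_good());
  assert(graph.create_execution_plans({fe::HeurMode_t::A}).is_good());
  assert(graph.check_support(handle).is_good());
  assert(graph.build_plans(handle).is_good());

  // Ask the built graph for its scratch requirement instead of computing a
  // manual upper bound; this is what the new workspace pass relies on.
  int64_t workspace_size = 0;
  assert(graph.get_workspace_size(workspace_size).is_good());
  return workspace_size;
}
```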

Cjkkkk (Contributor, Author) commented Apr 5, 2024

@xla-rotation

@Cjkkkk Cjkkkk requested a review from akuegel April 5, 2024 10:17
@Cjkkkk Cjkkkk changed the title from "[XLA:GPU] Update cuDNN flash attention to use new cuDNN frontend (>1.0)" to "[XLA:GPU] Update cuDNN flash attention backward to use new cuDNN frontend (>1.0)" Apr 5, 2024
Review threads (outdated, resolved):
  • xla/service/gpu/cudnn_workspace_rewriter.cc (3 threads)
  • xla/stream_executor/cuda/cuda_dnn.cc
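
The new cudnn_workspace_rewriter.cc reviewed above implements the workspace-query pass from the description. As a rough, hypothetical sketch of the shape such an XLA pass can take — the `IsCudnnFlashAttentionCall` predicate, the `GetCudnnGraphWorkspaceSize` helper, and the position of the workspace in the result tuple are all assumptions, not the PR's actual code:

```cpp
// Hypothetical sketch of a workspace-rewriting HLO pass; helpers and the
// tuple layout are assumptions, not the contents of this PR.
#include <cstdint>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tsl/platform/statusor.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/shape_util.h"

namespace xla::gpu {

// Assumed helpers: recognize a cuDNN flash-attention custom call, and rebuild
// its cuDNN frontend graph to call graph.get_workspace_size(...) on it.
bool IsCudnnFlashAttentionCall(const HloInstruction* instr);
absl::StatusOr<int64_t> GetCudnnGraphWorkspaceSize(const HloInstruction*);

class CuDnnWorkspaceRewriter : public HloModulePass {
 public:
  absl::string_view name() const override {
    return "cudnn-workspace-rewriter";
  }

  using HloPassInterface::Run;
  absl::StatusOr<bool> Run(HloModule* module,
                           const absl::flat_hash_set<absl::string_view>&
                               execution_threads) override {
    bool changed = false;
    for (HloComputation* comp :
         module->MakeNonfusionComputations(execution_threads)) {
      for (HloInstruction* instr : comp->instructions()) {
        if (instr->opcode() != HloOpcode::kCustomCall ||
            !instr->shape().IsTuple() || !IsCudnnFlashAttentionCall(instr)) {
          continue;
        }
        TF_ASSIGN_OR_RETURN(int64_t workspace_size,
                            GetCudnnGraphWorkspaceSize(instr));
        // Assume the scratch workspace is the trailing tuple element; resize
        // it to whatever the built cuDNN graph actually requested.
        Shape* ws = instr->mutable_shape()->mutable_tuple_shapes(
            instr->shape().tuple_shapes_size() - 1);
        *ws = ShapeUtil::MakeShape(U8, {workspace_size});
        changed = true;
      }
    }
    return changed;
  }
};

}  // namespace xla::gpu
```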
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 9, 2024

[XLA:GPU] Update cuDNN flash attention backward to use new cuDNN frontend (>1.0)

Imported from GitHub PR openxla/xla#11249

* Replace the complicated node-by-node flash attention cuDNN graph construction with the single cuDNN SDPA call provided by the new cuDNN frontend (>1.0).
* Remove the softmax_sum/dq_accum tensors from the custom-call outputs, since they are no longer required by the new cuDNN frontend.
* Add a pass that queries the workspace size from the cuDNN graph instead of computing it manually, since the workspace size is subject to change.

Copybara import of the project:

* edac8073869a2f6557d1735e53652a810c9b5e8e by cjkkkk <ske@nvidia.com>: init
* 21ba463e41fe84e05760274968886269f51f0c20 by cjkkkk <ske@nvidia.com>: fix stat datatype to be fp32
* 9e57b11dc1eddb9b1d5f8fce49b0dbbefc9e178a by cjkkkk <ske@nvidia.com>: use temp value for workspace & remove softmax buffer/dq_accum
* 5fc87a3dfecfafbce2815d86a0d23478a0f6abb0 by cjkkkk <ske@nvidia.com>: add workspace upperbound
* 38b17189c2fc4ecb1fb3909df6731336b2b928ef by cjkkkk <ske@nvidia.com>: use tighter upperbound
* b6d906fbbe5b3b5bbc4a6d8b03ffd9271d9e4c5d by cjkkkk <ske@nvidia.com>: remove unused param
* 06c5b69735b08f8f17a6b2ff9fdffad15d3edbf0 by cjkkkk <ske@nvidia.com>: add pass to fix workspace size
* f8dd6f56fe7b99a344ac39b93a8fad304ace5c7b by cjkkkk <ske@nvidia.com>: remove comments
* 07af43d9eec096c0d0ab9d11da71178ed35ec9f4 by cjkkkk <ske@nvidia.com>: rm header order change
* c979f10eaac514d9801b45aac61088a412092d0d by cjkkkk <ske@nvidia.com>: move def close to use and remove comments
* 58855fb510adb223d1c4600a06e35c21f1ae0847 by cjkkkk <ske@nvidia.com>: move use to def

Merging this change closes #11249

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11249 from Cjkkkk:flash-attn-backward 58855fb510adb223d1c4600a06e35c21f1ae0847
PiperOrigin-RevId: 623056335
akuegel (Member) commented Apr 9, 2024

It looks like there is a missing dependency:

error: module //third_party/tensorflow/compiler/xla/service/gpu:cudnn_workspace_rewriter does not depend on a module exporting 'third_party/tensorflow/compiler/xla/service/gpu/cublas_cudnn.h'

Can you please add that?
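
For reference, the fix that later landed as the "add missing dependency in BUILD" commit amounts to a BUILD edit along these lines; the target's surrounding attributes and remaining deps are assumed context, the `":cublas_cudnn"` entry is the point:

```starlark
# Sketch of the requested fix in xla/service/gpu/BUILD; only the
# ":cublas_cudnn" line is the substance, the rest is assumed context.
cc_library(
    name = "cudnn_workspace_rewriter",
    srcs = ["cudnn_workspace_rewriter.cc"],
    hdrs = ["cudnn_workspace_rewriter.h"],
    deps = [
        ":cublas_cudnn",  # exports cublas_cudnn.h, the missing dependency
        # ... existing deps unchanged ...
    ],
)
```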

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 10, 2024

[XLA:GPU] Update cuDNN flash attention backward to use new cuDNN frontend (>1.0)

Imported from GitHub PR openxla/xla#11249 (same description and commit list as the Apr 9 push above, plus:)

* 6a7c5448594abb8aed19bc966557f7842239212d by cjkkkk <ske@nvidia.com>: add missing dependency in BUILD

Merging this change closes #11249

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11249 from Cjkkkk:flash-attn-backward 6a7c5448594abb8aed19bc966557f7842239212d
PiperOrigin-RevId: 623056335
Cjkkkk (Contributor, Author) commented Apr 11, 2024

@akuegel the Google-internal test failure should be fixed now.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 11, 2024

[XLA:GPU] Update cuDNN flash attention backward to use new cuDNN frontend (>1.0)

Imported from GitHub PR openxla/xla#11249 (same description as the pushes above; the commit list now carries new hashes and one new commit)

Copybara import of the project:

* 2a1ee169fb2c12fbbd4e791ff406b3272490cf09 by cjkkkk <ske@nvidia.com>: init
* e84388a1859ef26acad7ba68cebedf546098d99b by cjkkkk <ske@nvidia.com>: fix stat datatype to be fp32
* 3a0b75ccfa88791d5cc437a9cc86333849d5fee3 by cjkkkk <ske@nvidia.com>: use temp value for workspace & remove softmax buffer/dq_accum
* 77e6db53f4f4d0b5344e3492da21e871ceb0bb0b by cjkkkk <ske@nvidia.com>: add workspace upperbound
* 11dd85db2ebea5cbd12b7108bdc995d8f8f5f1ba by cjkkkk <ske@nvidia.com>: use tighter upperbound
* 2f245303893f0aef5dc48557f8254d16ae373d9d by cjkkkk <ske@nvidia.com>: remove unused param
* 5e648ae07c743e95955f23fbd1c4fa50a13be5de by cjkkkk <ske@nvidia.com>: add pass to fix workspace size
* 368e51879c10266a3ecc200c105afc37dcba544c by cjkkkk <ske@nvidia.com>: remove comments
* 2fb7b3854f57c80e0c2e7805349ff5b637ba2007 by cjkkkk <ske@nvidia.com>: rm header order change
* d0cb184a32a5489f17ca5a20cd11901544e88ba1 by cjkkkk <ske@nvidia.com>: move def close to use and remove comments
* f09eb59328d2416374c31c5336744231b67954cd by cjkkkk <ske@nvidia.com>: move use to def
* 1c56f1fe670cfc43cc1bfcbc16ed75cf9ce5357f by cjkkkk <ske@nvidia.com>: add missing dependency in BUILD
* efd8ff45b5cf13e0ae820c03150068a6412204d0 by cjkkkk <ske@nvidia.com>: pass in correct descriptor to activation

Merging this change closes #11249

Reverts de029c0

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11249 from Cjkkkk:flash-attn-backward efd8ff45b5cf13e0ae820c03150068a6412204d0
PiperOrigin-RevId: 623056335
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 11, 2024

[XLA:GPU] Update cuDNN flash attention backward to use new cuDNN frontend (>1.0)

Imported from GitHub PR openxla/xla#11249 (same description and commit list as the previous Apr 11 push)

Merging this change closes #11249

PiperOrigin-RevId: 623771291