memory_efficient_attention() nondeterministic warning (#635)
* add warning of non-deterministic behavior for efficient_attention_forward_cutlass

* add non-deterministic note

* add alertNotDeterministic to mem_efficient_attention_backward_cutlass

* Update document, thanks @danthe3rd
takuma104 committed Jan 11, 2023
1 parent 3df785c commit 7aea476
Showing 3 changed files with 10 additions and 0 deletions.
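
For context, at::globalContext().alertNotDeterministic(name) is PyTorch's standard hook for flagging nondeterministic kernels: it is a no-op by default, raises a RuntimeError naming the op when the user has enabled torch.use_deterministic_algorithms(True), and downgrades to a warning under warn_only=True. A minimal sketch of how the alert added in this commit surfaces to users — shapes, dtypes, and device here are illustrative assumptions, and it presumes an xformers build with the CUTLASS kernels and a CUDA GPU:

# Sketch: surfacing the new nondeterminism alert from Python.
# Assumes xformers is installed with the memory-efficient CUTLASS kernels
# and a CUDA GPU is available; shapes/dtypes are illustrative only.
import torch
import xformers.ops as xops

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

# With deterministic mode on, any kernel that calls alertNotDeterministic
# raises instead of running silently.
torch.use_deterministic_algorithms(True)
try:
    xops.memory_efficient_attention(q, k, v)
except RuntimeError as e:
    print(e)  # names the offending op, e.g. efficient_attention_forward_cutlass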
@@ -4,6 +4,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
+#include <ATen/Context.h>

 #include "kernel_backward.h"

@@ -67,6 +68,8 @@ mem_efficient_attention_backward_cutlass(
       false,
       "MemoryEfficient build has been disabled at build time with -DXFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD");
 #else
+  at::globalContext().alertNotDeterministic("mem_efficient_attention_backward_cutlass");
+
   // ndim
   TORCH_CHECK(query.dim() == grad_out_.dim());
   TORCH_CHECK(query.dim() == key.dim());
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
+#include <ATen/Context.h>

 #include "kernel_forward.h"

@@ -144,6 +145,8 @@ std::tuple<at::Tensor, at::Tensor> efficient_attention_forward_cutlass(
       false,
       "MemoryEfficient build has been disabled at build time with -DXFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD");
 #else
+  at::globalContext().alertNotDeterministic("efficient_attention_forward_cutlass");
+
   TORCH_CHECK(query.dim() == 4);
   TORCH_CHECK(key.dim() == 4);
   TORCH_CHECK(value.dim() == 4);
4 changes: 4 additions & 0 deletions xformers/ops/fmha/__init__.py

@@ -172,6 +172,10 @@ def memory_efficient_attention(
         NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.

+    :Note:
+
+        This operator may be nondeterministic.
+
     Raises:
         NotImplementedError: if there is no operator available to compute the MHA
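
When aborting is too strict, the same alert can be downgraded to a UserWarning via warn_only=True, which makes it easy to audit which nondeterministic ops a model actually hits. A hedged sketch under the same assumptions as above; the backward alert fires on the .backward() call:

# Sketch: auditing nondeterministic ops with warn_only instead of erroring.
import warnings
import torch
import xformers.ops as xops

torch.use_deterministic_algorithms(True, warn_only=True)

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = xops.memory_efficient_attention(q, k, v)
    out.sum().backward()  # dispatches mem_efficient_attention_backward_cutlass

for w in caught:
    print(w.message)  # forward and backward cutlass ops are both reported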
