diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index ce8cf9433..75971c999 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -55,7 +55,7 @@ def test_dropout(shape, amp): # Check that 0 means no dropout y = dropout(x, p=0) - assert torch.allclose(x.to(y.dtype), y, rtol=tol) + assert torch.allclose(x.to(y.dtype), y, rtol=tol), f"{x[x>y]}" # Check that 1 means dropout for sure y = dropout(x, p=1) diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index 94bcf8478..4ad1a3d44 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -69,4 +69,7 @@ def grid(meta): def dropout(x: torch.Tensor, p: float): - return _dropout.apply(x, p) + if p > 0.0: + return _dropout.apply(x, p) + + return x diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py index e81e13fd1..6aa6b7ecc 100644 --- a/xformers/triton/k_dropout.py +++ b/xformers/triton/k_dropout.py @@ -24,7 +24,7 @@ ) @triton.jit def k_dropout( - Y, X, S, + Y, X, SEEDS, stride, N, p, @@ -51,7 +51,7 @@ def k_dropout( x = tl.load(x_ptrs, mask=mask) # randomly prune it - seed = S + row + seed = SEEDS + row random = tl.rand(seed.to(tl.int32), offsets) x_keep = random > p