
Commit a9c6065: "code review, thanks @fmassa !"

blefaudeux committed Apr 21, 2022 (1 parent: 7ec9726)
Showing 4 changed files with 12 additions and 6 deletions.
docs/source/conf.py (2 changes: 1 addition & 1 deletion)

@@ -35,7 +35,7 @@
 author = "Facebook AI Research"

 # The full version, including alpha/beta/rc tags
-release = "0.0.11.dev"
+release = "0.0.10"

 # -- General configuration ---------------------------------------------------

tests/test_triton_dropout.py (9 changes: 6 additions & 3 deletions)

@@ -88,16 +88,19 @@ def test_dropout(shape, amp, bias, p):
     x_ref = (x + b if bias else x).to(y.dtype)
     assert torch.allclose(x_ref, y, rtol=tol), f"{x[x>y]}"

-    # Check that .99 means dropout for sure
+    # Check that 1 means drop all
+    y = triton_dropout(x, p=1, bias=b)
+    x_ref = (x + b if bias else x).to(y.dtype)
+    assert torch.allclose(torch.zeros_like(y), y, rtol=tol)
+
+    # Check that .99 means probably dropout
     y = triton_dropout(x, p=0.99, bias=b)
     x_ref = (x + b if bias else x).to(y.dtype)
     assert not torch.allclose(x_ref, y, rtol=tol)

     # Check that the drops are different for every row (could catch broken seeds per row)
     y = triton_dropout(x, p=0.5)

-    print(y)
-
     y = y.flatten(0, 1) if y.ndim == 3 else y
     assert not torch.sum(torch.eq(y[0, :] == 0.0, y[1, :] == 0.0)) == y.shape[1]

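For context on the last assertion above: it compares the zero (dropped) positions of the first two rows and only fails if they agree on every column, which would point at all rows sharing one RNG seed. Below is a minimal standalone sketch of the same check using plain PyTorch dropout; the tensor shape and the use of torch.nn.functional.dropout are illustrative stand-ins, not part of this diff.

import torch

# Stand-in for the triton_dropout output: two rows, dropout applied independently.
y = torch.nn.functional.dropout(torch.randn(2, 512), p=0.5)

row0_mask = y[0, :] == 0.0  # positions dropped in row 0
row1_mask = y[1, :] == 0.0  # positions dropped in row 1

# A broken per-row seed would make both masks agree at every one of the
# y.shape[1] positions; independently seeded rows almost surely differ somewhere.
assert not torch.sum(torch.eq(row0_mask, row1_mask)) == y.shape[1]
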
xformers/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@
 import torch

 # Please update the doc version in docs/source/conf.py as well.
-__version__ = "0.0.11.dev"
+__version__ = "0.0.10"

 _is_sparse_available = True
 _is_triton_available = torch.cuda.is_available()

xformers/triton/dropout.py (5 changes: 4 additions & 1 deletion)

@@ -169,7 +169,10 @@ def dropout(
     Optionally add a bias, the computation will be fused.
     """

-    assert p < 1.0, f"We don't want to drop all the values, most probably {p}"
+    assert p <= 1.0 and p >= 0.0
+
+    if p == 1.0:
+        return torch.zeros_like(x)

     # Micro optim, skip dropout
     if p == 0.0:
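Based on this change, p == 1.0 is now accepted and short-circuits to an all-zero tensor instead of tripping the old assert, p == 0.0 keeps its existing fast path, and anything outside [0, 1] fails the new bounds check. A rough usage sketch under those assumptions (it presumes a CUDA device and that xformers.triton is importable, and that the p == 0.0 path hands back the input untouched when no bias is passed, which the diff only hints at):

import torch
from xformers.triton import dropout as triton_dropout

x = torch.randn(8, 256, device="cuda", dtype=torch.float16)

# p=1 now drops everything rather than raising an AssertionError.
assert torch.equal(triton_dropout(x, p=1.0), torch.zeros_like(x))

# Assumed: p=0 skips the kernel and returns the input unchanged (no bias given).
assert torch.equal(triton_dropout(x, p=0.0), x)

# Out-of-range probabilities hit the new bounds assert.
try:
    triton_dropout(x, p=1.5)
except AssertionError:
    pass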
