Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] cutlass FlashAttention bias+dropout support #587

Merged
merged 8 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
build/
dist/

# for autocomplete
compile_commands.json

# Pytest verbose output
test-results/

Expand Down
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- fMHA: Added CUTLASS-based kernel for `xformers.ops.memory_efficient_attention`. This kernel is automatically depending on the inputs, and works on any GPU after P100 [facebookresearch/xformers#362]

## [0.0.15] - 2022-12-13
### Fixed

### Added
- Added tensor attn bias support to CUTLASS FlashAttention
- Added tensor attn bias grad support to CUTLASS FlashAttention
- Added dropout support to CUTLASS FlashAttention

## [0.0.12] - 2022-08-08
### Fixed
- Removed duplicated biases in the FusedMLP layers [facebookresearch/xformers#317]
Expand Down
90 changes: 70 additions & 20 deletions tests/test_mem_eff_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
)

_out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
query, key, value
query, key, value, op=op
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

)
ref_lse = ((query.float() / k**0.5) @ key.float().transpose(-2, -1)).logsumexp(-1)

Expand All @@ -616,7 +616,13 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):

@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@pytest.mark.parametrize(
"attn_bias_type", [None, xformers.ops.LowerTriangularMask, torch.Tensor]
"attn_bias_cfg", # (type(bias), bias.requires_grad)
[
(None, False),
(xformers.ops.LowerTriangularMask, False),
(torch.Tensor, True),
(torch.Tensor, False),
],
)
@pytest.mark.parametrize("grad_out_contiguous", [False, True])
@pytest.mark.parametrize(
Expand All @@ -627,9 +633,10 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
def test_backward(
op_device_dtype_B_Mq_Mkv_H_K_Kv,
grad_out_contiguous,
attn_bias_type,
attn_bias_cfg,
fmt,
):
attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
(
op_bw,
device,
Expand All @@ -646,9 +653,13 @@ def test_backward(
attn_bias_type=attn_bias_type,
fmt=fmt,
)
op_fw = sample_random_supported_fw(
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
seed=q_len * kv + kv_len * k,
op_fw = (
sample_random_supported_fw(
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
seed=q_len * kv + kv_len * k,
)
if op_bw != fmha.cutlass.BwOp
else fmha.cutlass.FwOp
)
qkv = None

Expand All @@ -666,6 +677,11 @@ def test_backward(
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
if isinstance(attn_bias, torch.Tensor):
attn_bias.requires_grad_(attn_bias_requires_grad)

if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)):
pytest.skip("inputs not supported")

out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=(op_fw, op_bw)
Expand All @@ -692,6 +708,9 @@ def test_backward(
else:
grads = [qkv.grad]
qkv.grad = None
if attn_bias_requires_grad:
grads.append(attn_bias.grad)
attn_bias.grad = None

ref = ref_attention(query, key, value, attn_bias)
ref.backward(grad_out)
Expand All @@ -713,6 +732,12 @@ def test_backward(
assert isinstance(qkv.grad, torch.Tensor)
grads_ref = [qkv.grad]
grads_name = ["qkv"]

if attn_bias_requires_grad:
assert isinstance(attn_bias.grad, torch.Tensor)
grads_ref.append(attn_bias.grad)
grads_name.append("bias")

del query
del key
del value
Expand Down Expand Up @@ -755,49 +780,64 @@ def _vec_binom_test(x, n, p):
return pval


def _get_drop_mask(op, batch_size, q_len, kv_len, p, device):
if op == fmha.cutlass.FwOp:
mask = torch.empty((batch_size, 1, q_len, kv_len), device=device)
rand_uniform = torch.ops.xformers._cutlass_rand_uniform(p, mask)
mask = (rand_uniform > p).to(torch.float32)
mask = mask.reshape(batch_size, q_len, kv_len)
else:
mask = torch.empty((batch_size, q_len, kv_len), device=device)
mask = torch.ops.xformers._temp_dropout(mask, p)

return mask


@cuda_only
@pytest.mark.parametrize("seed", [42, 124])
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
@pytest.mark.parametrize("device", ["cuda"])
def test_dropout(device, q_len, kv_len, batch_size, k_len, p, seed):
@pytest.mark.parametrize("op", ALL_FW_OPS, ids=list(map(lambda t: t.NAME, ALL_FW_OPS)))
def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed):
device = "cuda"
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale

attn_bias = None
op = (fmha.small_k.FwOp, None)

inputs_for_support_check = fmha.Inputs(query, key, value, attn_bias, p, None)
if not op.supports(inputs_for_support_check):
del query, key, value, attn_bias
pytest.skip(f"{op.NAME}: unsupported input")

torch.manual_seed(seed)
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, p, op=op
query, key, value, attn_bias, p, op=(op, None)
)

torch.manual_seed(seed)
out2 = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, p, op=op
query, key, value, attn_bias, p, op=(op, None)
)

assert_allclose(out, out2)

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)

mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
ref = ref_attention(query, key, value, attn_bias, mask, p)
assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"

num_trials = 1000
p_val_tol = 0.0001
p_val_tol = 1e-6
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to mention this. dropped p_val_tol from 1e-4 -> 1e-6. otherwise the cutlass dropout mask fails the second binomial test.

took a look at the results it was outputting and it doesn't seem outlandish or anything. for example this test

test_dropout[cutlassF-33-32-2-32-0.7-42]

fails with one of the 2048 elements of masks ending up with 248/1000 keeps (p=0.7), resulting in a p value of 3.9172e-05.

looks like the other implementation uses a new subsequence every 4 elements which might have better independence guarantees but its unlikely to be as performant as the way the CUTLASS dropout implementation is done now

whats your take on this? maybe we can soften the constraint by using a percentage of elements have to pass the test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @fmassa who implemented this test and the smallK kernel which also supports dropout

keep_prob = 1 - p
masks = []
for i in range(num_trials):
mask = torch.ops.xformers._temp_dropout(mask, p)
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
masks.append(mask.clone().cpu())
masks = torch.stack(masks, dim=0)
p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob)
Expand Down Expand Up @@ -840,10 +880,8 @@ def _test_dropout_backward(q_len, kv_len, batch_size, k_len, p, op, dtype):
key.grad = None
value.grad = None

mask = torch.empty((batch_size, q_len, kv_len), device=device)

torch.manual_seed(seed)
mask = torch.ops.xformers._temp_dropout(mask, p)
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)

ref = ref_attention(query, key, value, None, mask, p)
ref.backward(grad_out)
Expand Down Expand Up @@ -881,6 +919,18 @@ def test_dropout_backward_flash(q_len, kv_len, batch_size, k_len, p):
)


@cuda_only
@pytest.mark.parametrize("p", [0.3, 0.7])
jfc4050 marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize("k_len", [16, 32])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
def test_dropout_backward_cutlass(q_len, kv_len, batch_size, k_len, p):
_test_dropout_backward(
q_len, kv_len, batch_size, k_len, p, op=fmha.cutlass.FwOp, dtype=torch.float16
)


@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("kv_len", [3 * 32])
Expand Down