
Commit

test
blefaudeux committed Feb 5, 2022
1 parent 8008a3f commit e5fa51f
Showing 2 changed files with 49 additions and 37 deletions.
77 changes: 41 additions & 36 deletions tests/test_triton_blocksparse.py
@@ -139,7 +139,7 @@ def test_softmax(BLOCK, WIDTH, DTYPE):


 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("block", [32])  # 16, 32,
+@pytest.mark.parametrize("block", [32, 43])  # 16, 32,
 def test_attention_fwd_bwd(
     block,
     input_scale=1.0,
@@ -177,45 +177,50 @@ def loss_fn(x):
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    block_sparse_attention = BlockSparseAttention(layout, block)
-    attn_out = block_sparse_attention(
-        att_mask=attn_mask, q=query, k=key, v=value, scale=scale
-    )
-
-    # ad hoc loss
-    loss = loss_fn(attn_out)
-    loss.backward()
-    grads = [query.grad, key.grad, value.grad]
-
-    # Torch version:
-    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
-    torch_q = torch_q / math.sqrt(head_dim)
-    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
-    torch_q.retain_grad()
-    torch_k.retain_grad()
-    torch_v.retain_grad()
-    scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
-    scores = scores + attn_mask
-    probs = torch.softmax(scores, dim=-1)
-    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
-
-    # ad hoc loss
-    torch_loss = loss_fn(torch_attn_out)
-    torch_loss.backward()
-    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
-
-    # comparison
-    assert_almost_equal(
-        loss, torch_loss, err_msg=f"Triton loss {loss} and torch loss {torch_loss}"
-    )
-
-    for g1, g2 in zip(grads, torch_grads):
-        assert_almost_equal(
-            torch.norm(g1),
-            torch.norm(g2),
-            err_msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
-        )
+    if block not in [16, 32, 64]:
+        # Check that unsupported block sizes are caught
+        with pytest.raises(AssertionError):
+            _ = BlockSparseAttention(layout, block)
+    else:
+        block_sparse_attention = BlockSparseAttention(layout, block)
+        attn_out = block_sparse_attention(
+            att_mask=attn_mask, q=query, k=key, v=value, scale=scale
+        )
+
+        # ad hoc loss
+        loss = loss_fn(attn_out)
+        loss.backward()
+        grads = [query.grad, key.grad, value.grad]
+
+        # Torch version:
+        torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
+        torch_q = torch_q / math.sqrt(head_dim)
+        attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
+        torch_q.retain_grad()
+        torch_k.retain_grad()
+        torch_v.retain_grad()
+        scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
+        scores = scores + attn_mask
+        probs = torch.softmax(scores, dim=-1)
+        torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
+
+        # ad hoc loss
+        torch_loss = loss_fn(torch_attn_out)
+        torch_loss.backward()
+        torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
+
+        # comparison
+        assert_almost_equal(
+            loss, torch_loss, err_msg=f"Triton loss {loss} and torch loss {torch_loss}"
+        )
+
+        for g1, g2 in zip(grads, torch_grads):
+            assert_almost_equal(
+                torch.norm(g1),
+                torch.norm(g2),
+                err_msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
+            )


 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
 def test_blocksparse_attention_parity():
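For context on the changed test above: with a supported block size, the else branch checks the Triton kernel against a dense PyTorch reference. A condensed sketch of that reference, using the same einsum and mask convention as the test (the helper name is ours, and the 1/sqrt(head_dim) rescaling of q that the test applies beforehand is left out):

import torch


def dense_reference(q, k, v, mask_01, scale):
    # q, k, v: (batch, heads, seq, head_dim); mask_01 holds 1 where attention is allowed, 0 where masked
    additive_mask = 1e6 * (mask_01 - 1)  # 0 on allowed positions, -1e6 on masked ones
    scores = scale * torch.einsum("bhsd,bhtd->bhst", q, k) + additive_mask
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bhtd->bhsd", probs, v)

The test then compares the loss and the gradient norms of q, k and v between this dense path and the block-sparse Triton path.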
9 changes: 8 additions & 1 deletion xformers/components/attention/blocksparse.py
@@ -52,6 +52,9 @@ class BlockSparseAttention(Attention):
     .. warning: for now, the sequence (context) length has to be a power of two. This constraint could
         be relaxed in the future.

+    .. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks.
+        It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.
+
     .. note: it is possible to pass a specific per batch mask in the forward call,
         but this will not lead to any speed up.
         Any constant sparsity pattern is better passed through the layout parameter.
@@ -76,7 +79,11 @@ def __init__(
             layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
             logging.warning(f"New layout dimensions: {layout.shape}")

-        assert block_size >= 16, "Minimum block size is 16, for now at least"
+        assert block_size in (
+            16,
+            32,
+            64,
+        ), "Only block sizes in [16, 32, 64] are supported"

         super().__init__()
         self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
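The new docstring warning and the stricter assert pin down how the pieces fit together: the layout is a block-level mask whose side is sequence_length // block_size, and block_size itself must now be one of 16, 32 or 64. A minimal usage sketch under those assumptions; the import path, the causal pattern and the num_heads keyword follow the names visible in the diff but are not part of it, and, like the tests, this needs Triton and a recent CUDA GPU:

import torch

from xformers.components.attention import BlockSparseAttention  # import path assumed

SEQ, BLOCK, HEADS = 1024, 32, 4  # BLOCK has to be 16, 32 or 64 after this change
num_blocks = SEQ // BLOCK  # the layout is defined at block granularity

# Causal pattern at block level; a 2D layout is expanded across heads by __init__
layout = torch.tril(torch.ones(num_blocks, num_blocks, dtype=torch.long))

attention = BlockSparseAttention(layout, BLOCK, num_heads=HEADS)

# An unsupported block size now fails fast at construction time
try:
    BlockSparseAttention(layout, 43)
except AssertionError as err:
    print(err)  # "Only block sizes in [16, 32, 64] are supported"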
