Create BlockSparse Tensor #202
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

# needed to register custom ops
import xformers  # noqa: F401
from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]


def _create_tensor(device, BLOCK=32, Z=8, C=2, H=64, W=64, dtype=torch.float32):
    # random 0/1 block layout; the first block-row and block-column are
    # forced to 1 so that no row of the resulting mask is entirely empty
    layout = torch.randint(2, (C, H // BLOCK, W // BLOCK))
    layout[:, :, 0] = 1
    layout[:, 0, :] = 1
    values = torch.randn(Z, layout.sum(), BLOCK, BLOCK, device=device).to(dtype)

    # expand the block layout to a dense (Z, C, H, W) boolean mask
    mask = (
        layout[None, :, :, None, :, None]
        .repeat(Z, 1, 1, BLOCK, 1, BLOCK)
        .reshape(Z, C, H, W)
    )

    return BlockSparseTensor(values, layout), mask.bool()


@pytest.mark.parametrize("device", _devices)
def test_masked_matmul(device):
    BLOCK = 16
    N, C, H, W, L = 8, 2, 64, 64, 32
    mask_block, _ = _create_tensor(device, BLOCK, N, C, H, W, dtype=torch.bool)
    mask = mask_block.to_dense()

    a = torch.randn(N, C, H, L, device=device)
    b = torch.randn(N, C, W, L, device=device)

    aa = a.clone()
    bb = b.clone()

    a.requires_grad_(True)
    b.requires_grad_(True)
    aa.requires_grad_(True)
    bb.requires_grad_(True)

    bt = b.transpose(-2, -1)
    bbt = bb.transpose(-2, -1)

    # res_gt = masked_matmul(a, b, mask)
    res_gt = a @ bt
    # res_gt[~mask] = 0
    res_gt = torch.where(mask, res_gt, torch.zeros_like(res_gt))
    res = masked_matmul(aa, bbt, mask_block)

    res_dense = res.to_dense()
    # res_dense[~mask] = float('-inf')

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res_dense, res_gt)

    # try to work around non-contiguous issues with triton for now
    res_gt.backward(torch.ones_like(res_gt))
    res._blocksparse_values.backward(torch.ones_like(res._blocksparse_values))
    # TODO: this is not passing!!!
Review comment: Is that only when a row is [0], or do you have other issues?

Reply: The failures here are due to triton-lang/triton#419. @ptillet, are we planning on releasing a new version of triton?
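(For reference, a hedged sketch of the case the first comment asks about: a layout with an entirely empty block-row. This helper is hypothetical, not part of the PR; it reuses the imports above, and whether BlockSparseTensor handles such a layout is precisely what is in question.)

def _layout_with_empty_row(BLOCK=16, Z=8, device="cuda:0"):
    # unlike _create_tensor, which forces the first block-row and
    # block-column to 1, leave one block-row entirely empty
    layout = torch.ones(2, 4, 4, dtype=torch.long)
    layout[:, 2, :] = 0  # empty block-row, i.e. a fully-masked row
    values = torch.randn(Z, int(layout.sum()), BLOCK, BLOCK, device=device)
    return BlockSparseTensor(values, layout)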
    # assert torch.allclose(a.grad, aa.grad, atol=1e-7)
    # assert torch.allclose(b.grad, bb.grad, atol=1e-7)


@pytest.mark.parametrize("device", _devices)
def test_bmm(device):
    BLOCK = 16
    N, C, H, W, L = 8, 2, 64, 64, 32
    a_block, mask = _create_tensor(device, BLOCK, N, C, H, W)
    a = a_block.to_dense()

    a_block.requires_grad_(True)
    a.requires_grad_(True)

    b = torch.randn(N, C, W, L, device=device)
    b2 = b.clone()

    b.requires_grad_(True)
    b2.requires_grad_(True)

    res_gt = a @ b
    res = a_block @ b2

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt, atol=1e-5)

    res_gt.sum().backward()
    res.sum().backward()

    a_grad = a.grad.clone().detach()
    a_grad[~mask] = 0

    assert torch.allclose(b.grad, b2.grad, atol=1e-5)
    assert torch.allclose(a_grad, a_block.grad.to_dense(), atol=1e-7)


@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(device):
    a_block, mask = _create_tensor(device)
    a = a_block.to_dense()
    a[~mask] = float("-inf")

    res_gt = torch.softmax(a, dim=-1)
    res_block = torch.softmax(a_block, dim=-1)

    res = res_block.to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt)


@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_backward(device):
    a_block, mask = _create_tensor(device)
    a = a_block.to_dense()
    a_block.requires_grad_(True)

    a[~mask] = float("-inf")
    a.requires_grad_(True)

    res_gt = torch.softmax(a, dim=-1)
    res_block = torch.softmax(a_block, dim=-1)

    # WARNING: gradients are modified in-place!
    res_block._blocksparse_values.backward(
        torch.ones_like(res_block._blocksparse_values)
    )
    res_gt.backward(torch.ones_like(res_gt))

    assert torch.allclose(a.grad, a_block.grad.to_dense(), atol=2e-7)


@pytest.mark.parametrize("device", _devices)
def test_deepcopy(device):
    import copy

    a_block, mask = _create_tensor(device)

    b_block = copy.deepcopy(a_block)
    assert torch.equal(a_block, b_block)


@pytest.mark.parametrize("device", _devices)
def test_module_buffer(device):
    a_block, _ = _create_tensor(device)
    b_block, _ = _create_tensor(device)

    module = torch.nn.Module()
    # test that register_buffer works
    module.register_buffer("a_block", a_block)

    assert module.a_block is a_block

    module.to(device)
    assert module.a_block.device == torch.device(device)

    state_dict = module.state_dict()
    assert "a_block" in state_dict
    assert torch.equal(a_block.to(device), state_dict["a_block"])

    module.load_state_dict(state_dict)

    module.load_state_dict({"a_block": b_block})
    assert torch.equal(module.a_block, b_block.to(device))
Review comment: I think the tests should also cover fp16; most people will use this in fp16, and for some GPUs it will follow a different code path (V100, for instance).
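A minimal sketch of what that could look like, assuming the block-sparse ops support fp16 at all (which the follow-up below questions); the test name and the fp16 tolerance are made up here and would need tuning:

@cuda_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_bmm_dtype(dtype):
    device = "cuda:0"
    BLOCK = 16
    N, C, H, W, L = 8, 2, 64, 64, 32
    a_block, _ = _create_tensor(device, BLOCK, N, C, H, W, dtype=dtype)
    a = a_block.to_dense()
    b = torch.randn(N, C, W, L, device=device, dtype=dtype)

    res_gt = a @ b
    res = a_block @ b

    assert res.dtype == res_gt.dtype
    # fp16 needs a looser tolerance; 1e-2 is a guess, not a tuned value
    atol = 1e-2 if dtype == torch.float16 else 1e-5
    assert torch.allclose(res, res_gt, atol=atol)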
Reply: I tried running in fp16, but I got consistent segfaults. This might be due to my version of triton, or something else I'm doing wrong. Anyway, I'll leave the traceback here in case it's of use.
cc @ptillet
Traceback of the segfault