Skip to content

Commit

Permalink
catering for triton blocksparse being probably more reliable in fp16
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Apr 18, 2022
1 parent 33a9416 commit 6369798
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions tests/test_sparse_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ def _create_tensor(tensor_type, device, dtype, shape, sparsity):
)


def _seed():
torch.random.manual_seed(42)
torch.cuda.manual_seed_all(42)


def _get_dtype_atol(device: str, tensor_type):
_seed()

# Upstream GPU blocksparse (Triton op) is only validated in fp16 for now
if tensor_type == BlockSparseTensor and "cuda" in device:
return torch.float16, 1e-4

return torch.float32, 1e-5


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("func", [torch.add, torch.mul])
def test_sparse_binary_ops(func, device):
Expand Down Expand Up @@ -83,6 +98,7 @@ def test_sparse_binary_ops(func, device):
def test_masked_matmul(tensor_type, device):
N, C, H, W, L = 8, 2, 64, 64, 32
sparsity = 0.7
dtype, atol = _get_dtype_atol(tensor_type, device)

shape0 = (N, C, H, W)
shape1 = (N, C, H, L)
Expand All @@ -98,8 +114,8 @@ def test_masked_matmul(tensor_type, device):
)
mask = mask_sparse.to_dense()

a = torch.randn(shape1, device=device)
b = torch.randn(shape2, device=device)
a = torch.randn(shape1, device=device, dtype=dtype)
b = torch.randn(shape2, device=device, dtype=dtype)

aa = a.clone()
bb = b.clone()
Expand All @@ -119,24 +135,25 @@ def test_masked_matmul(tensor_type, device):
res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

assert res.dtype == res_gt.dtype
assert torch.allclose(res_dense, res_gt, atol=5e-6)
assert torch.allclose(
res_dense, res_gt, atol=atol
), f"{torch.max(torch.abs(res_dense-res_gt))}"

# try to workaround non-contiguous issues with triton for now
res_gt.backward(torch.ones_like(res_gt))
res.values().backward(torch.ones_like(res.values()))
# TODO: this is not passing for BlockSparse!!!
if tensor_type != BlockSparseTensor:
assert torch.allclose(a.grad, aa.grad, atol=5e-6)
assert torch.allclose(b.grad, bb.grad, atol=5e-6)

assert torch.allclose(a.grad, aa.grad, atol=atol)
assert torch.allclose(b.grad, bb.grad, atol=atol)


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_bmm(tensor_type, device):
N, C, H, W, L = 8, 2, 64, 64, 32
dtype = torch.float32
sparsity = 0.8
dtype, atol = _get_dtype_atol(tensor_type, device)

sparsity = 0.8
shape0 = (N, C, H, W)
shape1 = (N, C, W, L)

Expand All @@ -153,7 +170,7 @@ def test_bmm(tensor_type, device):
a_sparse.requires_grad_(True)
a.requires_grad_(True)

b = torch.randn(shape1, device=device)
b = torch.randn(shape1, device=device, dtype=dtype)
b2 = b.clone()

b.requires_grad_(True)
Expand All @@ -163,23 +180,26 @@ def test_bmm(tensor_type, device):
res = a_sparse @ b2

assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt, atol=1e-5)
assert torch.allclose(res, res_gt, atol=atol), f"{torch.max(torch.abs(res-res_gt))}"

res_gt.sum().backward()
res.sum().backward()

a_grad = a.grad.clone().detach()
a_grad[~mask] = 0

assert torch.allclose(b.grad, b2.grad, atol=1e-5)
assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-5)
assert torch.allclose(b.grad, b2.grad, atol=atol)
assert torch.allclose(
a_grad, a_sparse.grad.to_dense(), atol=atol
), f"{torch.max(torch.abs(a_grad-a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(tensor_type, device):
N, C, H, W = 8, 2, 64, 64
dtype = torch.float32
dtype, atol = _get_dtype_atol(tensor_type, device)

sparsity = 0.8

shape0 = (N, C, H, W)
Expand All @@ -203,7 +223,9 @@ def test_sparse_softmax(tensor_type, device):
res = res_sparse.to_dense()

assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)
assert torch.allclose(
res, res_gt, atol=atol
), f"{torch.max(torch.abs(res- res_gt))}"

# WARNING: gradients are modified in-place!
res_sparse.values().backward(torch.ones_like(res_sparse.values()))
Expand All @@ -212,7 +234,9 @@ def test_sparse_softmax(tensor_type, device):
a_grad = a.grad.clone()
a_grad[~mask] = 0

assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-6)
assert torch.allclose(
a_grad, a_sparse.grad.to_dense(), atol=atol
), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
Expand Down

0 comments on commit 6369798

Please sign in to comment.