bsr_dense_mm Triton kernel: fix out kwarg (#96648)
As per the title: the kernel did not handle `out=` correctly and returned a different tensor that only shared storage with `out`.
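
Below is a hedged usage sketch of the behavior this fix guarantees. It calls the private `torch.sparse._triton_ops.bsr_dense_mm` helper patched here; the shapes, dtype, and blocksize are hypothetical, and a CUDA build with Triton available is assumed.

```python
import torch

# Private helper patched by this commit; not a public, stable API.
from torch.sparse._triton_ops import bsr_dense_mm

# Hypothetical sizes and blocksize; the kernel runs on CUDA via Triton.
m, k, n, blocksize = 64, 64, 64, 16
bsr = torch.randn(m, k, dtype=torch.half, device="cuda").to_sparse_bsr(blocksize)
dense = torch.randn(k, n, dtype=torch.half, device="cuda")
out = torch.empty(m, n, dtype=torch.half, device="cuda")

result = bsr_dense_mm(bsr, dense, out=out)

# Before the fix: `result` was a reshaped tensor that only shared storage
# with `out`. After the fix: the tensor passed via `out=` is returned.
assert result is out
```

The diff achieves this by snapshotting the user-supplied tensor (`out_backup = out`) before `out` is rebound to intermediate views, and returning that snapshot at the end.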

Pull Request resolved: pytorch/pytorch#96648
Approved by: https://github.com/cpuhrsch
nikitaved authored and cyyever committed Mar 27, 2023
1 parent 5787435 commit 09e1133
Showing 1 changed file with 10 additions and 9 deletions.
torch/sparse/_triton_ops.py (19 changes: 10 additions & 9 deletions)
```diff
--- a/torch/sparse/_triton_ops.py
+++ b/torch/sparse/_triton_ops.py
@@ -451,9 +451,15 @@ def check(cond, msg):
             "should be True.",
         )
 
+        # Allocate out
+        if out is None:
+            out = dense.new_zeros(original_batch_dims_broadcasted + (m, n))
+        else:
+            out.zero_()
+
         # Short circuit if lhs is zero
         if bsr._nnz() == 0:
-            return dense.new_zeros(original_batch_dims_broadcasted + (m, n))
+            return out
 
         # TODO: insert switch
         if is_sparse_rowspace_mode is None:
@@ -486,10 +492,6 @@ def make_triton_contiguous(t):
         dense_batch_dims = dense.shape[:-2]
         batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)
 
-        # Allocate out
-        if out is None:
-            out = dense.new_zeros(batch_dims_broadcasted + (m, n))
-
         # Broadcast batch dimensions and squash
         def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
             return t.broadcast_to(batch_dims + invariant_dims).flatten(
@@ -520,6 +522,8 @@ def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
         dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
 
         # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view
+        # out gets modified in-place, so we store a backup copy.
+        out_backup = out
         out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])
 
         # NOTE: this function will ALWAYS create a view
@@ -570,10 +574,7 @@ def valid_grid_dim(g, mg):
 
         kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
 
-        # Block dims need to rejoin with the corresponding block dimensions
-        # prior to reshape so that blocks do not end up being transposed.
-        # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point
-        return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n))  # type: ignore[union-attr]
+        return out_backup
 else:
     bsr_dense_mm = None  # type: ignore[assignment]
 
```

