[inductor] Stop using x + tl.zeros(...) in generated triton (pytorch#100163)

For reductions, this changes the accumulator
```python
_tmp2 = tl.zeros([XBLOCK, RBLOCK], tl.int8) + -128
```
to
```python
_tmp2 = tl.full([XBLOCK, RBLOCK], -128, tl.int32)
```
which is equivalent, since the addition promotes the accumulator from `int8` to `int32`.
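
A minimal sketch, not from the PR, illustrating that equivalence (the kernel name and block sizes are made up; it assumes Triton's usual rule that an integer literal is `tl.int32` and that mixed-width integer addition promotes to the wider type):
```python
import triton
import triton.language as tl


@triton.jit
def _acc_init_equivalence(XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    # Old form: int8 zeros plus the int32 literal -128; the add promotes to int32.
    _tmp_old = tl.zeros([XBLOCK, RBLOCK], tl.int8) + -128
    # New form: build the int32 accumulator filled with -128 directly.
    _tmp_new = tl.full([XBLOCK, RBLOCK], -128, tl.int32)
    # Both are [XBLOCK, RBLOCK] blocks of int32 with every element equal to -128,
    # so either one can seed a max-style reduction accumulator.
```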

For constant indexing, this changes
```python
tl.store(in_out_ptr0 + (0 + tl.zeros([XBLOCK, 1], tl.int32)), tmp4, None)
```
to
```python
tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```

For variable indexing, this changes
```python
tl.store(out_ptr0 + (0 + tl.zeros([XBLOCK], tl.int32)), tmp1, None)
```
to
```python
tl.store(out_ptr0 + (tl.broadcast_to(x0, [XBLOCK])), tmp1, None)
```
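
A hedged sketch, written by hand in the style of the generated kernels rather than taken from actual Inductor output (pointer names, stored values, and block sizes are invented), showing the new constant- and variable-indexing forms together:
```python
import triton
import triton.language as tl


@triton.jit
def _store_indexing_sketch(in_out_ptr0, out_ptr0, XBLOCK: tl.constexpr):
    # Variable indexing: x0 is a [XBLOCK]-shaped index derived from the program id.
    xoffset = tl.program_id(0) * XBLOCK
    x0 = xoffset + tl.arange(0, XBLOCK)
    tmp1 = x0.to(tl.float32)
    # New form: expand the index to the block shape with tl.broadcast_to
    # instead of adding tl.zeros([XBLOCK], tl.int32) to it.
    tl.store(out_ptr0 + tl.broadcast_to(x0, [XBLOCK]), tmp1, None)

    # Constant indexing: materialize the constant offset 0 with tl.full
    # instead of writing 0 + tl.zeros([XBLOCK, 1], tl.int32). Every lane
    # writes the same scalar slot here, as in the generated code above.
    tmp4 = tl.full([XBLOCK, 1], 3.0, tl.float32)
    tl.store(in_out_ptr0 + tl.full([XBLOCK, 1], 0, tl.int32), tmp4, None)
```
In each case the replacement states the intent directly (a constant block, or a shape expansion) instead of leaning on the add-a-zeros-tensor trick to force the shape and dtype.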

Pull Request resolved: pytorch#100163
Approved by: https://github.com/ngimel
peterbell10 authored and pytorchmergebot committed Apr 28, 2023
1 parent 270a331 commit 5b98910
Showing 1 changed file with 38 additions and 21 deletions.
torch/_inductor/codegen/triton.py

```diff
@@ -13,6 +13,7 @@
 import torch
 
 import torch._logging
+from torch._prims_common import is_integer_dtype
 from ..._dynamo import config as dynamo_config
 from .. import config, ir, scheduler
 from ..codecache import get_code_path
@@ -109,6 +110,13 @@ def triton_compute_type(dtype):
     return f"tl.{triton_type_name}"
 
 
+def triton_acc_type(dtype):
+    if is_integer_dtype(dtype) and dtype.is_signed:
+        nbits = 64 if dtype == torch.int64 else 32
+        return f"tl.int{nbits}"
+    return triton_compute_type(dtype)
+
+
 def triton_constant(value):
     if value == float("inf"):
         return 'float("inf")'
@@ -366,11 +374,11 @@ def libdevice_log(x):
 
     @staticmethod
     def isinf(x):
-        return f"tl.math.isinf({x})"
+        return f"tl.math.isinf({x}).to(tl.int1)"
 
     @staticmethod
     def isnan(x):
-        return f"tl.math.isnan({x})"
+        return f"tl.math.isnan({x}).to(tl.int1)"
 
     @staticmethod
     def round(x):
@@ -944,20 +952,18 @@ def indexing(
 
         expand_str = None
 
-        if (need_dense and not have_dense) or isinstance(index, sympy.Integer):
-            if copy_shape:
-                index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
-                expand_str = f"{copy_shape}.shape"
-            else:
-                index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
-                expand_str = self.dense_size_str()
-            if isinstance(index, sympy.Integer):
-                return index_str, set(), "None", expand_str
-            else:
-                mask_vars = dense_mask_vars
+        if isinstance(index, sympy.Integer):
+            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
+            index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
+            return index_str, set(), "None", expand_str
+
+        if need_dense and not have_dense:
+            expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
+            index_str = f"tl.broadcast_to({index_str}, {expand_str})"
+            mask_vars = dense_mask_vars
         elif not have_loop_vars and copy_shape:
+            index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)"
             mask_vars = dense_mask_vars
-            index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
 
         if override_mask:
             mask_vars = {override_mask}
@@ -1213,13 +1219,8 @@ def final_argreduce(buffer, result_var, value, index):
         elif (src_dtype, reduction_type, value) not in self.cse.reduction_cache:
             self.cse.reduction_cache[(src_dtype, reduction_type, value)] = result_var
             accumulator = f"_{result_var}"
-            # NOTE: We should be using tl.full here, but this also does type
-            # promotion e.g. bool to int32, which is sometimes necessary if
-            # similar promotion happened elsewhere in the pre-reduction
-            # operation. We should identify any such cases and fix them.
-            default_value = f" + {default}" if default != 0 else ""
             self.body.writeline(
-                f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(src_dtype)}){default_value}"
+                f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {triton_acc_type(src_dtype)})"
             )
 
             if reduction_type in {"argmax", "argmin"}:
@@ -1259,7 +1260,23 @@ def final_argreduce(buffer, result_var, value, index):
                 self.compute.writeline(
                     f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
                 )
-                self.suffix.writeline(f"{result_var} = {final_reduction(accumulator)}")
+
+                if src_dtype == torch.bool:
+                    # This is only really used for aten.any. It changes the
+                    # final reduction of a non-persistent reduction from
+                    # tmp5 = triton_helpers.max(_tmp5, 1)[:, None]
+                    # to
+                    # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1)
+                    # which is needed because tl.reduce doesn't support tl.int1
+                    accumulator = f"{accumulator}.to(tl.int8)"
+                    result_type = triton_compute_type(dtype)
+                    self.suffix.writeline(
+                        f"{result_var} = {final_reduction(accumulator)}.to({result_type})"
+                    )
+                else:
+                    self.suffix.writeline(
+                        f"{result_var} = {final_reduction(accumulator)}"
+                    )
             else:
                 var_name = self.cse.reduction_cache[(src_dtype, reduction_type, value)]
                 self.suffix.writeline(f"{result_var} = {var_name}")
```
