Skip to content

Commit

Permalink
Integrate Triton up to [8e0c7b42](https://github.com/openai/triton/co…
Browse files Browse the repository at this point in the history
  • Loading branch information
Moerafaat authored and jax authors committed Apr 29, 2024
1 parent dfc1718 commit ac06df5
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,20 +1615,6 @@ def _get_lowering_rule(ctx: LoweringRuleContext, ptr, *idx, tree):
# Lookup table from the `str()` form of each `tt_dialect.CacheModifier` member
# to the member itself.
_STR_TO_CACHE_MODIFIER = {str(c): c for c in tt_dialect.CacheModifier}


def _infer_load_return_type(ptr: ir.Value) -> ir.Type:
  """Returns the result type of a load through ``ptr``.

  Args:
    ptr: Either a value whose type is a pointer, or a value whose type is a
      ranked tensor with pointer elements.

  Returns:
    For a non-tensor ``ptr``, the pointee type of its pointer type. For a
    ranked tensor ``ptr``, a ranked tensor type with the same shape and
    encoding whose element type is the pointee type of the tensor's element
    pointer type.
  """
  if not ir.RankedTensorType.isinstance(ptr.type):
    # Not a tensor: the load produces the pointee type directly.
    return tt_dialect.PointerType(ptr.type).pointee_type

  # Tensor of pointers: keep shape and encoding, swap each pointer element
  # for the type it points to.
  tensor_type = ir.RankedTensorType(ptr.type)
  pointee_type = tt_dialect.PointerType(tensor_type.element_type).pointee_type
  return ir.RankedTensorType.get(
      tensor_type.shape,
      pointee_type,
      tensor_type.encoding,
  )


def _load(
ptr: ir.Value,
mask: ir.Value | None = None,
Expand Down Expand Up @@ -1685,7 +1671,6 @@ def _load(
other = _ir_cast(other, pointee_type, signed=False)

result = tt_dialect.load(
_infer_load_return_type(ptr),
ptr,
mask=mask,
other=other,
Expand Down Expand Up @@ -1928,7 +1913,17 @@ def _dot(
else:
max_num_imprecise_acc = 0

return tt_dialect.dot(x, y, acc, allow_tf32, max_num_imprecise_acc)
# TODO: Replace all allow_tf32 usages with InputPrecision directly.
input_precision = tt_dialect.InputPrecision.IEEE
if allow_tf32:
input_precision = tt_dialect.InputPrecision.TF32
return tt_dialect.dot(
x,
y,
acc,
input_precision=input_precision,
max_num_imprecise_acc=max_num_imprecise_acc,
)


# `lax.Precision` values grouped under the TF32 name.
# NOTE(review): presumably the precisions for which TF32 is permitted in dot
# lowering; the usage is not visible here, so confirm against callers.
_TF32_PRECISIONS = (lax.Precision.HIGH, lax.Precision.DEFAULT)
Expand Down

0 comments on commit ac06df5

Please sign in to comment.