Use ir.FloatType instead of a Pallas-local shim
PiperOrigin-RevId: 607635063
superbobry authored and jax authors committed Feb 16, 2024
1 parent 86db2de commit 8026d19
Showing 1 changed file with 18 additions and 62 deletions:
jax/_src/pallas/triton/lowering.py
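The change is mechanical: every check that previously went through the Pallas-local FloatType shim now uses ir.FloatType from the MLIR Python bindings directly, which became possible once upstream exposed a common base class for float types (llvm/llvm-project#81720). Below is a minimal standalone sketch of the resulting dispatch pattern, not part of the commit; the helper name emit_add is invented, and the import paths assume jaxlib's bundled MLIR bindings.

# Minimal sketch (not JAX code) of the isinstance-based dispatch this commit switches to.
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith as arith_dialect


def emit_add(x: ir.Value, y: ir.Value) -> ir.Value:
  """Adds two values, choosing addi or addf based on the element type."""
  elem = x.type
  if ir.RankedTensorType.isinstance(elem):
    elem = ir.RankedTensorType(elem).element_type
  if isinstance(elem, ir.IntegerType):
    return arith_dialect.addi(x, y)
  if isinstance(elem, ir.FloatType):  # covers f16/bf16/f32/f64 and the fp8 types
    return arith_dialect.addf(x, y)
  raise NotImplementedError(f"unsupported element type: {elem}")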
@@ -803,7 +803,7 @@ def _add(x: ir.Value, y: ir.Value):
   assert x.type == y.type, (str(x.type), str(y.type))
   if isinstance(x_element_type, ir.IntegerType):
     return arith_dialect.addi(x, y)
-  elif isinstance(x_element_type, FloatType):
+  elif isinstance(x_element_type, ir.FloatType):
     return arith_dialect.addf(x, y)
   else:
     raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}")
@@ -818,7 +818,7 @@ def _sub(x: ir.Value, y: ir.Value) -> ir.Value:
   assert x.type == y.type, (str(x.type), str(y.type))
   if isinstance(x_element_type, ir.IntegerType):
     return arith_dialect.subi(x, y)
-  elif isinstance(x_element_type, FloatType):
+  elif isinstance(x_element_type, ir.FloatType):
     return arith_dialect.subf(x, y)
   raise NotImplementedError(f"unsupported dtype: {y.type}")

@@ -828,7 +828,7 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value:
   x_element_type = _element_type(x.type)
   if isinstance(x_element_type, ir.IntegerType):
     return arith_dialect.muli(x, y)
-  elif isinstance(x_element_type, FloatType):
+  elif isinstance(x_element_type, ir.FloatType):
     return arith_dialect.mulf(x, y)
   raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")

@@ -851,7 +851,7 @@ def _truediv(x: ir.Value, y: ir.Value) -> ir.Value:
     x_element_type = ir.F32Type.get()
     x = _int_float_cast(x, x_element_type)
     y = _int_float_cast(y, x_element_type)
-  if isinstance(x_element_type, FloatType):
+  if isinstance(x_element_type, ir.FloatType):
     return arith_dialect.divf(x, y)
   raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")

@@ -880,7 +880,7 @@ def _cmp(
     return arith_dialect.cmpi(
         si_pred if x_element_type.is_signed else ui_pred, x, y
     )
-  elif isinstance(x_element_type, FloatType):
+  elif isinstance(x_element_type, ir.FloatType):
     return arith_dialect.cmpf(f_pred, x, y)
   else:
     raise NotImplementedError(f"unsupported types: {x.type} and {y.type}")
@@ -1105,47 +1105,6 @@ def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension):
 triton_lowering_rules[lax.iota_p] = _iota_lowering_rule


-_FLOAT_WIDTH = {
-    ir.Float8E4M3FNUZType: 8,
-    ir.Float8E4M3FNType: 8,
-    ir.Float8E4M3B11FNUZType: 8,
-    ir.Float8E5M2Type: 8,
-    ir.BF16Type: 16,
-    ir.F16Type: 16,
-    ir.F32Type: 32,
-    ir.F64Type: 64,
-}
-_FLOAT_TYPES = tuple(_FLOAT_WIDTH)
-
-
-class FloatTypeMeta(type):
-
-  def __instancecheck__(cls, instance: object) -> bool:
-    return isinstance(instance, _FLOAT_TYPES)
-
-  def __subclasscheck__(cls, subclass: type[object]) -> bool:
-    return issubclass(subclass, _FLOAT_TYPES)
-
-
-# TODO(slebedev): Remove once https://github.com/llvm/llvm-project/pull/81720 is merged.
-class FloatType(metaclass=FloatTypeMeta):
-  """Fake base class for MLIR floating point types."""
-
-  def __init__(self, type: ir.Type):
-    assert isinstance(type, _FLOAT_TYPES)
-    self.type = type
-
-  @property
-  def is_standard(self) -> bool:
-    return isinstance(
-        self.type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)
-    )
-
-  @property
-  def width(self) -> int:
-    return _FLOAT_WIDTH[type(self.type)]
-
-
 def _element_type(t: ir.Type) -> ir.Type:
   if ir.RankedTensorType.isinstance(t):
     return ir.RankedTensorType(t).element_type
@@ -1171,7 +1130,7 @@ def _full(t: ir.Type, v: object) -> ir.Type:
   element_type = _element_type(t)
   if isinstance(element_type, ir.IntegerType):
     result = arith_dialect.constant(element_type, int(v))
-  elif isinstance(element_type, FloatType):
+  elif isinstance(element_type, ir.FloatType):
     result = arith_dialect.constant(element_type, float(v))
   else:
     raise NotImplementedError
@@ -1199,8 +1158,8 @@ def _expand_dims(x: ir.Value, axis: int) -> ir.Value:


 def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
-  src_element_type = FloatType(_element_type(src.type))
-  dst_element_type = FloatType(_element_type(dst_type))
+  src_element_type = ir.FloatType(_element_type(src.type))
+  dst_element_type = ir.FloatType(_element_type(dst_type))
   if src_element_type.width == 8 or dst_element_type.width == 8:
     return tt_dialect.fp_to_fp(
         dst_type,
@@ -1234,8 +1193,8 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:


 def _float_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
-  src_element_type = FloatType(_element_type(src.type))
-  if not src_element_type.is_standard:
+  src_element_type = _element_type(src.type)
+  if not isinstance(src_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)):
     raise NotImplementedError(f"cannot cast {src} tp {dst_type}")
   dst_element_type = ir.IntegerType(_element_type(dst_type))
   if dst_element_type.width == 1:
@@ -1248,8 +1207,8 @@ def _float_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:

 def _int_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
   src_element_type = ir.IntegerType(_element_type(src.type))
-  dst_element_type = FloatType(_element_type(dst_type))
-  if not dst_element_type.is_standard:
+  dst_element_type = _element_type(dst_type)
+  if not isinstance(dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)):
     raise NotImplementedError(f"cannot cast {src} tp {dst_type}")
   if src_element_type.width == 1 or not src_element_type.is_signed:
     return arith_dialect.uitofp(dst_type, src)
@@ -1283,8 +1242,8 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
   ):
     return _cast(_cast(src, ir.F32Type.get()), dst_type)

-  if isinstance(src_element_type, FloatType) and isinstance(
-      dst_element_type, FloatType
+  if isinstance(src_element_type, ir.FloatType) and isinstance(
+      dst_element_type, ir.FloatType
   ):
     return _float_float_cast(src, dst_type)

@@ -1293,12 +1252,12 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value:
   ):
     return _int_int_cast(src, dst_type)

-  if isinstance(src_element_type, FloatType) and isinstance(
+  if isinstance(src_element_type, ir.FloatType) and isinstance(
       dst_element_type, ir.IntegerType
   ):
     return _float_int_cast(src, dst_type)
   if isinstance(src_element_type, ir.IntegerType) and isinstance(
-      dst_element_type, FloatType
+      dst_element_type, ir.FloatType
   ):
     return _int_float_cast(src, dst_type)

@@ -1865,10 +1824,7 @@ def _dot(
   acc = _full(ir.RankedTensorType.get([m, n], element_type), 0)

   if max_num_imprecise_acc is None:
-    if (
-        FloatType(x_type.element_type).width == 8
-        and FloatType(y_type.element_type).width == 8
-    ):
+    if isinstance(element_type, ir.FloatType) and element_type.width == 8:
       # TODO(slebedev): Fill in from options.
       raise NotImplementedError
     else:
@@ -2435,7 +2391,7 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value:
     if isinstance(t, ir.IntegerType):
       v = int(v)
     else:
-      assert isinstance(t, FloatType)
+      assert isinstance(t, ir.FloatType)
       v = float(v)
     return arith_dialect.constant(t, v)
   raise NotImplementedError
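For context on what the deleted shim was doing: before the upstream change, MLIR's concrete float type classes (F32Type, BF16Type, the fp8 types, and so on) exposed no common Python base class, so the shim faked one through the metaclass hooks __instancecheck__ and __subclasscheck__. The stripped-down illustration below shows that technique in isolation; the names VirtualBase, _VirtualBaseMeta, and _MEMBERS are invented, and plain Python types stand in for the MLIR classes.

# Simplified illustration (not JAX or MLIR code) of the metaclass trick used by the removed shim:
# isinstance() succeeds against a class that is not an actual ancestor of its "members".
_MEMBERS = (int, float)  # stand-ins for the concrete MLIR float type classes


class _VirtualBaseMeta(type):

  def __instancecheck__(cls, instance: object) -> bool:
    # Called for isinstance(x, VirtualBase); delegate to the member tuple.
    return isinstance(instance, _MEMBERS)

  def __subclasscheck__(cls, subclass: type) -> bool:
    return issubclass(subclass, _MEMBERS)


class VirtualBase(metaclass=_VirtualBaseMeta):
  """Acts like a base class of _MEMBERS without actually being one."""


assert isinstance(3, VirtualBase)        # int is a "member"
assert isinstance(2.5, VirtualBase)      # so is float
assert not isinstance("x", VirtualBase)  # str is not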
