From 8026d198b180d5b19457bffa43f3a3028e7743f7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 16 Feb 2024 03:29:35 -0800 Subject: [PATCH] Use ir.FloatType instead of a Pallas-local shim PiperOrigin-RevId: 607635063 --- jax/_src/pallas/triton/lowering.py | 80 +++++++----------------------- 1 file changed, 18 insertions(+), 62 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 8c0a3649a10b..8f4d44880fe8 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -803,7 +803,7 @@ def _add(x: ir.Value, y: ir.Value): assert x.type == y.type, (str(x.type), str(y.type)) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.addi(x, y) - elif isinstance(x_element_type, FloatType): + elif isinstance(x_element_type, ir.FloatType): return arith_dialect.addf(x, y) else: raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") @@ -818,7 +818,7 @@ def _sub(x: ir.Value, y: ir.Value) -> ir.Value: assert x.type == y.type, (str(x.type), str(y.type)) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.subi(x, y) - elif isinstance(x_element_type, FloatType): + elif isinstance(x_element_type, ir.FloatType): return arith_dialect.subf(x, y) raise NotImplementedError(f"unsupported dtype: {y.type}") @@ -828,7 +828,7 @@ def _mul(x: ir.Value, y: ir.Value) -> ir.Value: x_element_type = _element_type(x.type) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.muli(x, y) - elif isinstance(x_element_type, FloatType): + elif isinstance(x_element_type, ir.FloatType): return arith_dialect.mulf(x, y) raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") @@ -851,7 +851,7 @@ def _truediv(x: ir.Value, y: ir.Value) -> ir.Value: x_element_type = ir.F32Type.get() x = _int_float_cast(x, x_element_type) y = _int_float_cast(y, x_element_type) - if isinstance(x_element_type, FloatType): + if isinstance(x_element_type, ir.FloatType): return arith_dialect.divf(x, y) raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") @@ -880,7 +880,7 @@ def _cmp( return arith_dialect.cmpi( si_pred if x_element_type.is_signed else ui_pred, x, y ) - elif isinstance(x_element_type, FloatType): + elif isinstance(x_element_type, ir.FloatType): return arith_dialect.cmpf(f_pred, x, y) else: raise NotImplementedError(f"unsupported types: {x.type} and {y.type}") @@ -1105,47 +1105,6 @@ def _iota_lowering_rule(ctx: LoweringRuleContext, *, dtype, shape, dimension): triton_lowering_rules[lax.iota_p] = _iota_lowering_rule -_FLOAT_WIDTH = { - ir.Float8E4M3FNUZType: 8, - ir.Float8E4M3FNType: 8, - ir.Float8E4M3B11FNUZType: 8, - ir.Float8E5M2Type: 8, - ir.BF16Type: 16, - ir.F16Type: 16, - ir.F32Type: 32, - ir.F64Type: 64, -} -_FLOAT_TYPES = tuple(_FLOAT_WIDTH) - - -class FloatTypeMeta(type): - - def __instancecheck__(cls, instance: object) -> bool: - return isinstance(instance, _FLOAT_TYPES) - - def __subclasscheck__(cls, subclass: type[object]) -> bool: - return issubclass(subclass, _FLOAT_TYPES) - - -# TODO(slebedev): Remove once https://github.com/llvm/llvm-project/pull/81720 is merged. -class FloatType(metaclass=FloatTypeMeta): - """Fake base class for MLIR floating point types.""" - - def __init__(self, type: ir.Type): - assert isinstance(type, _FLOAT_TYPES) - self.type = type - - @property - def is_standard(self) -> bool: - return isinstance( - self.type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type) - ) - - @property - def width(self) -> int: - return _FLOAT_WIDTH[type(self.type)] - - def _element_type(t: ir.Type) -> ir.Type: if ir.RankedTensorType.isinstance(t): return ir.RankedTensorType(t).element_type @@ -1171,7 +1130,7 @@ def _full(t: ir.Type, v: object) -> ir.Type: element_type = _element_type(t) if isinstance(element_type, ir.IntegerType): result = arith_dialect.constant(element_type, int(v)) - elif isinstance(element_type, FloatType): + elif isinstance(element_type, ir.FloatType): result = arith_dialect.constant(element_type, float(v)) else: raise NotImplementedError @@ -1199,8 +1158,8 @@ def _expand_dims(x: ir.Value, axis: int) -> ir.Value: def _float_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: - src_element_type = FloatType(_element_type(src.type)) - dst_element_type = FloatType(_element_type(dst_type)) + src_element_type = ir.FloatType(_element_type(src.type)) + dst_element_type = ir.FloatType(_element_type(dst_type)) if src_element_type.width == 8 or dst_element_type.width == 8: return tt_dialect.fp_to_fp( dst_type, @@ -1234,8 +1193,8 @@ def _int_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: def _float_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: - src_element_type = FloatType(_element_type(src.type)) - if not src_element_type.is_standard: + src_element_type = _element_type(src.type) + if not isinstance(src_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)): raise NotImplementedError(f"cannot cast {src} tp {dst_type}") dst_element_type = ir.IntegerType(_element_type(dst_type)) if dst_element_type.width == 1: @@ -1248,8 +1207,8 @@ def _float_int_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: def _int_float_cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: src_element_type = ir.IntegerType(_element_type(src.type)) - dst_element_type = FloatType(_element_type(dst_type)) - if not dst_element_type.is_standard: + dst_element_type = _element_type(dst_type) + if not isinstance(dst_element_type, (ir.BF16Type, ir.F16Type, ir.F32Type, ir.F64Type)): raise NotImplementedError(f"cannot cast {src} tp {dst_type}") if src_element_type.width == 1 or not src_element_type.is_signed: return arith_dialect.uitofp(dst_type, src) @@ -1283,8 +1242,8 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: ): return _cast(_cast(src, ir.F32Type.get()), dst_type) - if isinstance(src_element_type, FloatType) and isinstance( - dst_element_type, FloatType + if isinstance(src_element_type, ir.FloatType) and isinstance( + dst_element_type, ir.FloatType ): return _float_float_cast(src, dst_type) @@ -1293,12 +1252,12 @@ def _cast(src: ir.Value, dst_type: ir.Type) -> ir.Value: ): return _int_int_cast(src, dst_type) - if isinstance(src_element_type, FloatType) and isinstance( + if isinstance(src_element_type, ir.FloatType) and isinstance( dst_element_type, ir.IntegerType ): return _float_int_cast(src, dst_type) if isinstance(src_element_type, ir.IntegerType) and isinstance( - dst_element_type, FloatType + dst_element_type, ir.FloatType ): return _int_float_cast(src, dst_type) @@ -1865,10 +1824,7 @@ def _dot( acc = _full(ir.RankedTensorType.get([m, n], element_type), 0) if max_num_imprecise_acc is None: - if ( - FloatType(x_type.element_type).width == 8 - and FloatType(y_type.element_type).width == 8 - ): + if isinstance(element_type, ir.FloatType) and element_type.width == 8: # TODO(slebedev): Fill in from options. raise NotImplementedError else: @@ -2435,7 +2391,7 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: if isinstance(t, ir.IntegerType): v = int(v) else: - assert isinstance(t, FloatType) + assert isinstance(t, ir.FloatType) v = float(v) return arith_dialect.constant(t, v) raise NotImplementedError