Skip to content

Commit

Permalink
[Pallas] Add support for casting int8->fp* in Mosaic lowering
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559313313
  • Loading branch information
sharadmv authored and jax authors committed Aug 23, 2023
1 parent f08df0f commit bad217b
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,20 +736,47 @@ def _dot_general_lowering_rule(

lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule

_INT_DTYPES = {
8: np.dtype(np.int8),
16: np.dtype(np.int16),
32: np.dtype(np.int32),
}


def _convert_element_type_lowering_rule(
ctx: LoweringRuleContext, x, *, new_dtype, weak_type
):
del weak_type
out_aval = ctx.avals_out[0]
old_dtype = ctx.avals_in[0].dtype
out_type = aval_to_ir_type(ctx.avals_out[0])
if old_dtype == jnp.float32 and new_dtype == jnp.bfloat16:
return arith.TruncFOp(out_type, x).result
elif old_dtype == jnp.bfloat16 and new_dtype == jnp.float32:
return arith.ExtFOp(out_type, x).result
elif old_dtype == jnp.bool_ and new_dtype == jnp.int32:
out_type = aval_to_ir_type(out_aval)
if old_dtype == new_dtype:
return x
if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
new_dtype, jnp.floating
):
if old_dtype.itemsize < new_dtype.itemsize:
return arith.ExtFOp(out_type, x).result
else:
return arith.TruncFOp(out_type, x).result
elif old_dtype == jnp.bool_ and jnp.issubdtype(new_dtype, jnp.integer):
return arith.ExtSIOp(out_type, x).result
# TODO(sharadmv,apaszke): Defaulting to bitcast is unreasonable.
return arith.BitcastOp(out_type, x).result
elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
new_dtype, jnp.floating
):
return arith.SIToFPOp(out_type, x).result
elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype(
new_dtype, jnp.signedinteger
):
if old_dtype.itemsize < new_dtype.itemsize:
return arith.ExtSIOp(out_type, x).result
else:
return arith.TruncIOp(out_type, x).result
elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
new_dtype, jnp.signedinteger
):
return arith.FPToSIOp(out_type, x).result
raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")


lowering_rules[lax.convert_element_type_p] = _convert_element_type_lowering_rule
Expand Down

0 comments on commit bad217b

Please sign in to comment.