Skip to content

Commit

Permalink
tweak mlir shape_tensor helper, fewer MHLO ops
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed May 13, 2022
1 parent 8e4cf25 commit d719e5a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)

def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
) -> ir.RankedTensorType:
sizes = [ir_constant(np.array(d, np.dtype('int32'))) if type(d) is int else d
for d in sizes]
int1d = aval_to_ir_type(core.ShapedArray((1,), np.dtype('int32')))
return mhlo.ConcatenateOp([mhlo.ReshapeOp(int1d, d) for d in sizes],
i64_attr(0)).results
d, *ds = [ir_constant(np.array([d], np.dtype('int32'))) if type(d) is int
else mhlo.ReshapeOp(int1d, d) for d in sizes]
if not ds:
return d
else:
return mhlo.ConcatenateOp([d, *ds], i64_attr(0)).result


# IR Types
Expand Down

0 comments on commit d719e5a

Please sign in to comment.