Skip to content

Commit

Permalink
Revert "Revert "[mlir][py] better support for arith.constant construc…
Browse files Browse the repository at this point in the history
…tion" (#…"

This reverts commit 96fc548.
  • Loading branch information
ftynse committed Mar 6, 2024
1 parent 6e27dd4 commit 10361ae
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
23 changes: 21 additions & 2 deletions mlir/python/mlir/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
from array import array as _array
from typing import overload

try:
from ..ir import *
Expand Down Expand Up @@ -43,13 +45,30 @@ def _is_float_type(type: Type):
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""

@overload
def __init__(self, value: Attribute, *, loc=None, ip=None):
...

@overload
def __init__(
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
):
...

def __init__(self, result, value, *, loc=None, ip=None):
if value is None:
assert isinstance(result, Attribute)
super().__init__(result, loc=loc, ip=ip)
return

if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, _array) and value.typecode in ["i", "l"]:
super().__init__(DenseIntElementsAttr.get(value, type=result))
elif isinstance(value, _array) and value.typecode in ["f", "d"]:
super().__init__(DenseFPElementsAttr.get(value, type=result))
else:
super().__init__(value, loc=loc, ip=ip)

Expand Down Expand Up @@ -79,6 +98,6 @@ def literal_value(self) -> Union[int, float]:


def constant(
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
38 changes: 38 additions & 0 deletions mlir/test/python/dialects/arith_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
from array import array


def run(f):
Expand Down Expand Up @@ -92,3 +93,40 @@ def __str__(self):
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)


# CHECK-LABEL: TEST: testArrayConstantConstruction
@run
def testArrayConstantConstruction():
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
i32_array = array("i", [1, 2, 3, 4])
i32 = IntegerType.get_signless(32)
vec_i32 = VectorType.get([2, 2], i32)
arith.constant(vec_i32, i32_array)
arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))

i64_array = array("l", [5, 6, 7, 8])
i64 = IntegerType.get_signless(64)
vec_i64 = VectorType.get([1, 4], i64)
arith.constant(vec_i64, i64_array)
arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))

f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
f32 = F32Type.get()
vec_f32 = VectorType.get([4, 1], f32)
arith.constant(vec_f32, f32_array)
arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))

f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
f64 = F64Type.get()
vec_f64 = VectorType.get([2, 1, 2], f64)
arith.constant(vec_f64, f64_array)
arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))

# CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
# CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
# CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
# CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
print(module)

0 comments on commit 10361ae

Please sign in to comment.