diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 61c6917393f1f..83a50c7ef244f 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -5,6 +5,8 @@ from ._arith_ops_gen import * from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * +from array import array as _array +from typing import overload try: from ..ir import * @@ -43,13 +45,30 @@ def _is_float_type(type: Type): class ConstantOp(ConstantOp): """Specialization for the constant op class.""" + @overload + def __init__(self, value: Attribute, *, loc=None, ip=None): + ... + + @overload def __init__( - self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None ): + ... + + def __init__(self, result, value, *, loc=None, ip=None): + if value is None: + assert isinstance(result, Attribute) + super().__init__(result, loc=loc, ip=ip) + return + if isinstance(value, int): super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) elif isinstance(value, float): super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, _array) and value.typecode in ["i", "l"]: + super().__init__(DenseIntElementsAttr.get(value, type=result)) + elif isinstance(value, _array) and value.typecode in ["f", "d"]: + super().__init__(DenseFPElementsAttr.get(value, type=result)) else: super().__init__(value, loc=loc, ip=ip) @@ -79,6 +98,6 @@ def literal_value(self) -> Union[int, float]: def constant( - result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None ) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index 8bb80eed2b810..ef0e1620bba99 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -4,6 +4,7 @@ from mlir.ir import * import mlir.dialects.arith as arith import mlir.dialects.func as func +from array import array def run(f): @@ -92,3 +93,40 @@ def __str__(self): b = a * a # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) print(b) + + +# CHECK-LABEL: TEST: testArrayConstantConstruction +@run +def testArrayConstantConstruction(): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + i32_array = array("i", [1, 2, 3, 4]) + i32 = IntegerType.get_signless(32) + vec_i32 = VectorType.get([2, 2], i32) + arith.constant(vec_i32, i32_array) + arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32)) + + i64_array = array("l", [5, 6, 7, 8]) + i64 = IntegerType.get_signless(64) + vec_i64 = VectorType.get([1, 4], i64) + arith.constant(vec_i64, i64_array) + arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64)) + + f32_array = array("f", [1.0, 2.0, 3.0, 4.0]) + f32 = F32Type.get() + vec_f32 = VectorType.get([4, 1], f32) + arith.constant(vec_f32, f32_array) + arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32)) + + f64_array = array("d", [1.0, 2.0, 3.0, 4.0]) + f64 = F64Type.get() + vec_f64 = VectorType.get([2, 1, 2], f64) + arith.constant(vec_f64, f64_array) + arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64)) + + # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32> + # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64> + # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32> + # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64> + print(module)