-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reapply "[mlir][py] better support for arith.constant construction" #84142
Conversation
20815bc
to
b246e0d
Compare
Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as `array.array`. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly.
b246e0d
to
d39cb79
Compare
@llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) ChangesArithmetic constants for vector types can be constructed from objects Reverts llvm/llvm-project#84103 Full diff: https://github.com/llvm/llvm-project/pull/84142.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 61c6917393f1f9..92da5df9bce665 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,6 +5,8 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
+from array import array as _array
+from typing import overload
try:
from ..ir import *
@@ -43,13 +45,37 @@ def _is_float_type(type: Type):
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
+ @overload
+ def __init__(self, value: Attribute, *, loc=None, ip=None):
+ ...
+
+ @overload
def __init__(
- self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
):
+ ...
+
+ def __init__(self, result, value, *, loc=None, ip=None):
+ if value is None:
+ assert isinstance(result, Attribute)
+ super().__init__(result, loc=loc, ip=ip)
+ return
+
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, _array):
+ if 8 * value.itemsize != result.element_type.width:
+ raise ValueError(
+ f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
+ )
+ if value.typecode in ["i", "l", "q"]:
+ super().__init__(DenseIntElementsAttr.get(value, type=result))
+ elif value.typecode in ["f", "d"]:
+ super().__init__(DenseFPElementsAttr.get(value, type=result))
+ else:
+ raise ValueError(f'Unsupported typecode: "{value.typecode}".')
else:
super().__init__(value, loc=loc, ip=ip)
@@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]:
def constant(
- result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 8bb80eed2b8105..c9af5e7b46db84 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
+from array import array
def run(f):
@@ -92,3 +93,42 @@ def __str__(self):
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
+
+
+# CHECK-LABEL: TEST: testArrayConstantConstruction
+@run
+def testArrayConstantConstruction():
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i32_array = array("i", [1, 2, 3, 4])
+ i32 = IntegerType.get_signless(32)
+ vec_i32 = VectorType.get([2, 2], i32)
+ arith.constant(vec_i32, i32_array)
+ arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
+
+ # "q" is the equivalent of `long long` in C and requires at least
+ # 64 bit width integers on both Linux and Windows.
+ i64_array = array("q", [5, 6, 7, 8])
+ i64 = IntegerType.get_signless(64)
+ vec_i64 = VectorType.get([1, 4], i64)
+ arith.constant(vec_i64, i64_array)
+ arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
+
+ f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
+ f32 = F32Type.get()
+ vec_f32 = VectorType.get([4, 1], f32)
+ arith.constant(vec_f32, f32_array)
+ arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
+
+ f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
+ f64 = F64Type.get()
+ vec_f64 = VectorType.get([2, 1, 2], f64)
+ arith.constant(vec_f64, f64_array)
+ arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
+
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
+ print(module)
|
Arithmetic constants for vector types can be constructed from objects
implementing Python buffer protocol such as
array.array
. Note thatuntil Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use array explicitly.
Reverts #84103