-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][py] better support for arith.constant construction #83259
Conversation
Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as `array.array`. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly.
@llvm/pr-subscribers-mlir Author: Oleksandr "Alex" Zinenko (ftynse) ChangesArithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as Full diff: https://github.com/llvm/llvm-project/pull/83259.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 61c6917393f1f9..83a50c7ef244f1 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,6 +5,8 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
+from array import array as _array
+from typing import overload
try:
from ..ir import *
@@ -43,13 +45,30 @@ def _is_float_type(type: Type):
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
+ @overload
+ def __init__(self, value: Attribute, *, loc=None, ip=None):
+ ...
+
+ @overload
def __init__(
- self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
):
+ ...
+
+ def __init__(self, result, value, *, loc=None, ip=None):
+ if value is None:
+ assert isinstance(result, Attribute)
+ super().__init__(result, loc=loc, ip=ip)
+ return
+
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, _array) and value.typecode in ["i", "l"]:
+ super().__init__(DenseIntElementsAttr.get(value, type=result))
+ elif isinstance(value, _array) and value.typecode in ["f", "d"]:
+ super().__init__(DenseFPElementsAttr.get(value, type=result))
else:
super().__init__(value, loc=loc, ip=ip)
@@ -79,6 +98,6 @@ def literal_value(self) -> Union[int, float]:
def constant(
- result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 8bb80eed2b8105..ef0e1620bba990 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
+from array import array
def run(f):
@@ -92,3 +93,40 @@ def __str__(self):
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
+
+
+# CHECK-LABEL: TEST: testArrayConstantConstruction
+@run
+def testArrayConstantConstruction():
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ i32_array = array("i", [1, 2, 3, 4])
+ i32 = IntegerType.get_signless(32)
+ vec_i32 = VectorType.get([2, 2], i32)
+ arith.constant(vec_i32, i32_array)
+ arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
+
+ i64_array = array("l", [5, 6, 7, 8])
+ i64 = IntegerType.get_signless(64)
+ vec_i64 = VectorType.get([1, 4], i64)
+ arith.constant(vec_i64, i64_array)
+ arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
+
+ f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
+ f32 = F32Type.get()
+ vec_f32 = VectorType.get([4, 1], f32)
+ arith.constant(vec_f32, f32_array)
+ arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
+
+ f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
+ f64 = F64Type.get()
+ vec_f64 = VectorType.get([2, 1, 2], f64)
+ arith.constant(vec_f64, f64_array)
+ arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
+
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
+ # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
+ print(module)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Do you think it makes sense to provide additional overloads where the result type is inferred from the type of the 'array' value?
e.g 'arith.constant(array("i", [1,2,3,4]))' could produce a value of type 'vector<4xi32>'.
Why are we using strings for the dtype? |
Are you referring to the ’array.array’ interface? |
We can have either vector or tensor type for these, I didn't want to prioritize one or another. |
The MLIR tests didn't run in pre-merge because the Flang test failed first, but I believe this broke the arith_dialect.py test. |
Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as
array.array
. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly.