Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][py] better support for arith.constant construction #83259

Merged
merged 1 commit into from
Mar 5, 2024

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Feb 28, 2024

Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as array.array. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly.

Arithmetic constants for vector types can be constructed from objects
implementing Python buffer protocol such as `array.array`. Note that
until Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use array explicitly.
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 28, 2024

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as array.array. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly.


Full diff: https://github.com/llvm/llvm-project/pull/83259.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/arith.py (+21-2)
  • (modified) mlir/test/python/dialects/arith_dialect.py (+38)
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 61c6917393f1f9..83a50c7ef244f1 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,6 +5,8 @@
 from ._arith_ops_gen import *
 from ._arith_ops_gen import _Dialect
 from ._arith_enum_gen import *
+from array import array as _array
+from typing import overload
 
 try:
     from ..ir import *
@@ -43,13 +45,30 @@ def _is_float_type(type: Type):
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
+    @overload
+    def __init__(self, value: Attribute, *, loc=None, ip=None):
+        ...
+
+    @overload
     def __init__(
-        self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+        self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
     ):
+        ...
+
+    def __init__(self, result, value, *, loc=None, ip=None):
+        if value is None:
+            assert isinstance(result, Attribute)
+            super().__init__(result, loc=loc, ip=ip)
+            return
+
         if isinstance(value, int):
             super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
         elif isinstance(value, float):
             super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, _array) and value.typecode in ["i", "l"]:
+            super().__init__(DenseIntElementsAttr.get(value, type=result))
+        elif isinstance(value, _array) and value.typecode in ["f", "d"]:
+            super().__init__(DenseFPElementsAttr.get(value, type=result))
         else:
             super().__init__(value, loc=loc, ip=ip)
 
@@ -79,6 +98,6 @@ def literal_value(self) -> Union[int, float]:
 
 
 def constant(
-    result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+    result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 8bb80eed2b8105..ef0e1620bba990 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 import mlir.dialects.arith as arith
 import mlir.dialects.func as func
+from array import array
 
 
 def run(f):
@@ -92,3 +93,40 @@ def __str__(self):
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
             print(b)
+
+
+# CHECK-LABEL: TEST: testArrayConstantConstruction
+@run
+def testArrayConstantConstruction():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i32_array = array("i", [1, 2, 3, 4])
+            i32 = IntegerType.get_signless(32)
+            vec_i32 = VectorType.get([2, 2], i32)
+            arith.constant(vec_i32, i32_array)
+            arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
+
+            i64_array = array("l", [5, 6, 7, 8])
+            i64 = IntegerType.get_signless(64)
+            vec_i64 = VectorType.get([1, 4], i64)
+            arith.constant(vec_i64, i64_array)
+            arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
+
+            f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
+            f32 = F32Type.get()
+            vec_f32 = VectorType.get([4, 1], f32)
+            arith.constant(vec_f32, f32_array)
+            arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
+
+            f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
+            f64 = F64Type.get()
+            vec_f64 = VectorType.get([2, 1, 2], f64)
+            arith.constant(vec_f64, f64_array)
+            arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
+
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
+        print(module)

Copy link
Contributor

@martin-luecke martin-luecke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Do you think it makes sense to provide additional overloads where the result type is inferred from the type of the 'array' value?

e.g 'arith.constant(array("i", [1,2,3,4]))' could produce a value of type 'vector<4xi32>'.

@joker-eph
Copy link
Collaborator

Why are we using strings for the dtype?

@martin-luecke
Copy link
Contributor

Are you referring to the ’array.array’ interface?
That is from the Python api here: https://docs.python.org/3/library/array.html

@ftynse
Copy link
Member Author

ftynse commented Mar 5, 2024

e.g 'arith.constant(array("i", [1,2,3,4]))' could produce a value of type 'vector<4xi32>'.

We can have either vector or tensor type for these, I didn't want to prioritize one or another.

@ftynse ftynse merged commit a691f65 into llvm:main Mar 5, 2024
6 of 7 checks passed
@ftynse ftynse deleted the arith-const branch March 5, 2024 15:10
@joker-eph
Copy link
Collaborator

The MLIR tests didn't run in pre-merge because the Flang test failed first, but I believe this broke the arith_dialect.py test.

joker-eph added a commit that referenced this pull request Mar 6, 2024
joker-eph added a commit that referenced this pull request Mar 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants