Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply "[mlir][py] better support for arith.constant construction" #84142

Merged
merged 1 commit into from
Mar 7, 2024

Conversation

ftynse
Copy link
Member

@ftynse ftynse commented Mar 6, 2024

Arithmetic constants for vector types can be constructed from objects
implementing Python buffer protocol such as array.array. Note that
until Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use array explicitly.

Reverts #84103

@ftynse ftynse force-pushed the revert-84103-revert-83259-arith-const branch 4 times, most recently from 20815bc to b246e0d Compare March 7, 2024 10:44
Arithmetic constants for vector types can be constructed from objects
implementing Python buffer protocol such as `array.array`. Note that
until Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use array explicitly.
@ftynse ftynse force-pushed the revert-84103-revert-83259-arith-const branch from b246e0d to d39cb79 Compare March 7, 2024 15:02
@ftynse ftynse changed the title WIP: Reapply "[mlir][py] better support for arith.constant construction"" Reapply "[mlir][py] better support for arith.constant construction"" Mar 7, 2024
@ftynse ftynse changed the title Reapply "[mlir][py] better support for arith.constant construction"" Reapply "[mlir][py] better support for arith.constant construction" Mar 7, 2024
@ftynse ftynse marked this pull request as ready for review March 7, 2024 16:13
@ftynse ftynse merged commit 5d59fa9 into main Mar 7, 2024
6 checks passed
@ftynse ftynse deleted the revert-84103-revert-83259-arith-const branch March 7, 2024 16:14
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Mar 7, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 7, 2024

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

Changes

Arithmetic constants for vector types can be constructed from objects
implementing Python buffer protocol such as array.array. Note that
until Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use array explicitly.

Reverts llvm/llvm-project#84103


Full diff: https://github.com/llvm/llvm-project/pull/84142.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/arith.py (+28-2)
  • (modified) mlir/test/python/dialects/arith_dialect.py (+40)
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 61c6917393f1f9..92da5df9bce665 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,6 +5,8 @@
 from ._arith_ops_gen import *
 from ._arith_ops_gen import _Dialect
 from ._arith_enum_gen import *
+from array import array as _array
+from typing import overload
 
 try:
     from ..ir import *
@@ -43,13 +45,37 @@ def _is_float_type(type: Type):
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
+    @overload
+    def __init__(self, value: Attribute, *, loc=None, ip=None):
+        ...
+
+    @overload
     def __init__(
-        self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+        self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
     ):
+        ...
+
+    def __init__(self, result, value, *, loc=None, ip=None):
+        if value is None:
+            assert isinstance(result, Attribute)
+            super().__init__(result, loc=loc, ip=ip)
+            return
+
         if isinstance(value, int):
             super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
         elif isinstance(value, float):
             super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, _array):
+            if 8 * value.itemsize != result.element_type.width:
+                raise ValueError(
+                    f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
+                )
+            if value.typecode in ["i", "l", "q"]:
+                super().__init__(DenseIntElementsAttr.get(value, type=result))
+            elif value.typecode in ["f", "d"]:
+                super().__init__(DenseFPElementsAttr.get(value, type=result))
+            else:
+                raise ValueError(f'Unsupported typecode: "{value.typecode}".')
         else:
             super().__init__(value, loc=loc, ip=ip)
 
@@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]:
 
 
 def constant(
-    result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+    result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
 ) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 8bb80eed2b8105..c9af5e7b46db84 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,6 +4,7 @@
 from mlir.ir import *
 import mlir.dialects.arith as arith
 import mlir.dialects.func as func
+from array import array
 
 
 def run(f):
@@ -92,3 +93,42 @@ def __str__(self):
             b = a * a
             # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
             print(b)
+
+
+# CHECK-LABEL: TEST: testArrayConstantConstruction
+@run
+def testArrayConstantConstruction():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            i32_array = array("i", [1, 2, 3, 4])
+            i32 = IntegerType.get_signless(32)
+            vec_i32 = VectorType.get([2, 2], i32)
+            arith.constant(vec_i32, i32_array)
+            arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
+
+            # "q" is the equivalent of `long long` in C and requires at least
+            # 64 bit width integers on both Linux and Windows.
+            i64_array = array("q", [5, 6, 7, 8])
+            i64 = IntegerType.get_signless(64)
+            vec_i64 = VectorType.get([1, 4], i64)
+            arith.constant(vec_i64, i64_array)
+            arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
+
+            f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
+            f32 = F32Type.get()
+            vec_f32 = VectorType.get([4, 1], f32)
+            arith.constant(vec_f32, f32_array)
+            arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
+
+            f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
+            f64 = F64Type.get()
+            vec_f64 = VectorType.get([2, 1, 2], f64)
+            arith.constant(vec_f64, f64_array)
+            arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
+
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
+        # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
+        print(module)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants