Skip to content

Commit

Permalink
[mlir][py] better support for arith.constant construction
Browse files Browse the repository at this point in the history
Arithmetic constants for vector types can be constructed from objects
implementing Python buffer protocol such as `array.array`. Note that
until Python 3.12, there is no typing support for buffer protocol
implementers, so the annotations use array explicitly.
  • Loading branch information
ftynse committed Mar 7, 2024
1 parent 03588a2 commit d39cb79
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
30 changes: 28 additions & 2 deletions mlir/python/mlir/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
from array import array as _array
from typing import overload

try:
from ..ir import *
Expand Down Expand Up @@ -43,13 +45,37 @@ def _is_float_type(type: Type):
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""

@overload
def __init__(self, value: Attribute, *, loc=None, ip=None):
...

@overload
def __init__(
self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
):
...

def __init__(self, result, value, *, loc=None, ip=None):
if value is None:
assert isinstance(result, Attribute)
super().__init__(result, loc=loc, ip=ip)
return

if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, _array):
if 8 * value.itemsize != result.element_type.width:
raise ValueError(
f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
)
if value.typecode in ["i", "l", "q"]:
super().__init__(DenseIntElementsAttr.get(value, type=result))
elif value.typecode in ["f", "d"]:
super().__init__(DenseFPElementsAttr.get(value, type=result))
else:
raise ValueError(f'Unsupported typecode: "{value.typecode}".')
else:
super().__init__(value, loc=loc, ip=ip)

Expand Down Expand Up @@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]:


def constant(
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
40 changes: 40 additions & 0 deletions mlir/test/python/dialects/arith_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
from array import array


def run(f):
Expand Down Expand Up @@ -92,3 +93,42 @@ def __str__(self):
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)


# CHECK-LABEL: TEST: testArrayConstantConstruction
@run
def testArrayConstantConstruction():
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
i32_array = array("i", [1, 2, 3, 4])
i32 = IntegerType.get_signless(32)
vec_i32 = VectorType.get([2, 2], i32)
arith.constant(vec_i32, i32_array)
arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))

# "q" is the equivalent of `long long` in C and requires at least
# 64 bit width integers on both Linux and Windows.
i64_array = array("q", [5, 6, 7, 8])
i64 = IntegerType.get_signless(64)
vec_i64 = VectorType.get([1, 4], i64)
arith.constant(vec_i64, i64_array)
arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))

f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
f32 = F32Type.get()
vec_f32 = VectorType.get([4, 1], f32)
arith.constant(vec_f32, f32_array)
arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))

f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
f64 = F64Type.get()
vec_f64 = VectorType.get([2, 1, 2], f64)
arith.constant(vec_f64, f64_array)
arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))

# CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
# CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
# CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
# CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
print(module)

0 comments on commit d39cb79

Please sign in to comment.