From 3ec4809f2b5a76d14eba1a0707e30791cc5cc805 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Wed, 19 Nov 2025 01:04:38 +0800 Subject: [PATCH 1/8] [MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings --- mlir/python/CMakeLists.txt | 4 +- .../dialects/{irdl.py => irdl/__init__.py} | 13 +- mlir/python/mlir/dialects/irdl/dsl.py | 343 ++++++++++++++++++ mlir/test/python/dialects/irdsl.py | 308 ++++++++++++++++ 4 files changed, 661 insertions(+), 7 deletions(-) rename mlir/python/mlir/dialects/{irdl.py => irdl/__init__.py} (91%) create mode 100644 mlir/python/mlir/dialects/irdl/dsl.py create mode 100644 mlir/test/python/dialects/irdsl.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 2acb6ee6cfda5..27b87a8b70144 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/IRDLOps.td - SOURCES dialects/irdl.py + SOURCES + dialects/irdl/__init__.py + dialects/irdl/dsl.py DIALECT_NAME irdl GEN_ENUM_BINDINGS ) diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl/__init__.py similarity index 91% rename from mlir/python/mlir/dialects/irdl.py rename to mlir/python/mlir/dialects/irdl/__init__.py index 1ec951b69b646..6b2787ed7966c 100644 --- a/mlir/python/mlir/dialects/irdl.py +++ b/mlir/python/mlir/dialects/irdl/__init__.py @@ -2,13 +2,14 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._irdl_ops_gen import * -from ._irdl_ops_gen import _Dialect -from ._irdl_enum_gen import * -from .._mlir_libs._mlirDialectsIRDL import * -from ..ir import register_attribute_builder -from ._ods_common import _cext as _ods_cext +from .._irdl_ops_gen import * +from .._irdl_ops_gen import _Dialect +from .._irdl_enum_gen import * +from ..._mlir_libs._mlirDialectsIRDL import * +from ...ir import register_attribute_builder +from .._ods_common import _cext as _ods_cext from typing import Union, Sequence +from . import dsl _ods_ir = _ods_cext.ir diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py new file mode 100644 index 0000000000000..3cc234503665a --- /dev/null +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -0,0 +1,343 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...dialects import irdl as _irdl +from .._ods_common import ( + _cext as _ods_cext, + segmented_accessor as _ods_segmented_accessor, +) +from . import Variadicity +from typing import Dict, List, Union, Callable, Tuple +from dataclasses import dataclass +from inspect import Parameter as _Parameter, Signature as _Signature +from types import SimpleNamespace as _SimpleNameSpace + +_ods_ir = _ods_cext.ir + + +class ConstraintExpr: + def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value: + raise NotImplementedError() + + def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr": + return AnyOf(self, other) + + def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr": + return AllOf(self, other) + + +class ConstraintLoweringContext: + def __init__(self): + # Cache so that the same ConstraintExpr instance reuses its SSA value. + self._cache: Dict[int, _ods_ir.Value] = {} + + def lower(self, expr: ConstraintExpr) -> _ods_ir.Value: + key = id(expr) + if key in self._cache: + return self._cache[key] + v = expr._lower(self) + self._cache[key] = v + return v + + +class Is(ConstraintExpr): + def __init__(self, attr: _ods_ir.Attribute): + self.attr = attr + + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: + return _irdl.is_(self.attr) + + +class IsType(Is): + def __init__(self, typ: _ods_ir.Type): + super().__init__(_ods_ir.TypeAttr.get(typ)) + + +class AnyOf(ConstraintExpr): + def __init__(self, *exprs: ConstraintExpr): + self.exprs = exprs + + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: + return _irdl.any_of(ctx.lower(expr) for expr in self.exprs) + + +class AllOf(ConstraintExpr): + def __init__(self, *exprs: ConstraintExpr): + self.exprs = exprs + + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: + return _irdl.all_of(ctx.lower(expr) for expr in self.exprs) + + +class Any(ConstraintExpr): + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: + return _irdl.any() + + +class BaseName(ConstraintExpr): + def __init__(self, name: str): + self.name = name + + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: + return _irdl.base(base_name=self.name) + + +class BaseRef(ConstraintExpr): + def __init__(self, ref): + self.ref = ref + + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: + return _irdl.base(base_ref=self.ref) + + +class FieldDef: + def __set_name__(self, owner, name: str): + self.name = name + + +@dataclass +class Operand(FieldDef): + constraint: ConstraintExpr + variadicity: Variadicity = Variadicity.single + + +@dataclass +class Result(FieldDef): + constraint: ConstraintExpr + variadicity: Variadicity = Variadicity.single + + +@dataclass +class Attribute(FieldDef): + constraint: ConstraintExpr + + def __post_init__(self): + # just for unified processing, + # currently optional attribute is not supported by IRDL + self.variadicity = Variadicity.single + + +@dataclass +class Operation: + dialect_name: str + name: str + # We store operands and attributes into one list to maintain relative orders + # among them for generating OpView class. + operands_and_attrs: List[Union[Operand, Attribute]] + results: List[Result] + + def _emit(self) -> None: + op = _irdl.operation_(self.name) + ctx = ConstraintLoweringContext() + + operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] + attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] + + with _ods_ir.InsertionPoint(op.body): + if operands: + _irdl.operands_( + [ctx.lower(i.constraint) for i in operands], + [i.name for i in operands], + [i.variadicity for i in operands], + ) + if attrs: + _irdl.attributes_( + [ctx.lower(i.constraint) for i in attrs], + [i.name for i in attrs], + ) + if self.results: + _irdl.results_( + [ctx.lower(i.constraint) for i in self.results], + [i.name for i in self.results], + [i.variadicity for i in self.results], + ) + + def _make_op_view_and_builder(self) -> Tuple[type, Callable]: + operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] + attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] + + def variadicity_to_segment(variadicity: Variadicity) -> int: + if variadicity == Variadicity.variadic: + return -1 + if variadicity == Variadicity.optional: + return 0 + return 1 + + operand_segments = None + if any(i.variadicity != Variadicity.single for i in operands): + operand_segments = [variadicity_to_segment(i.variadicity) for i in operands] + + result_segments = None + if any(i.variadicity != Variadicity.single for i in self.results): + result_segments = [ + variadicity_to_segment(i.variadicity) for i in self.results + ] + + args = self.results + self.operands_and_attrs + positional_args = [ + i.name for i in args if i.variadicity != Variadicity.optional + ] + optional_args = [i.name for i in args if i.variadicity == Variadicity.optional] + + params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)] + for i in positional_args: + params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD)) + for i in optional_args: + params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None)) + params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None)) + params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None)) + + sig = _Signature(params) + op = self + + class _OpView(_ods_ir.OpView): + OPERATION_NAME = f"{op.dialect_name}.{op.name}" + _ODS_REGIONS = (0, True) + _ODS_OPERAND_SEGMENTS = operand_segments + _ODS_RESULT_SEGMENTS = result_segments + + def __init__(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + args = bound.arguments + + _operands = [args[operand.name] for operand in operands] + _results = [args[result.name] for result in op.results] + _attributes = dict( + (attr.name, args[attr.name]) + for attr in attrs + if args[attr.name] is not None + ) + _regions = None + _ods_successors = None + self = args["self"] + super(_OpView, self).__init__( + self.OPERATION_NAME, + self._ODS_REGIONS, + self._ODS_OPERAND_SEGMENTS, + self._ODS_RESULT_SEGMENTS, + attributes=_attributes, + results=_results, + operands=_operands, + successors=_ods_successors, + regions=_regions, + loc=args["loc"], + ip=args["ip"], + ) + + __init__.__signature__ = sig + + for attr in attrs: + setattr( + _OpView, + attr.name, + property(lambda self, name=attr.name: self.attributes[name]), + ) + + def value_range_getter( + value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList], + variadicity: Variadicity, + ): + if variadicity == Variadicity.single: + return value_range[0] + if variadicity == Variadicity.optional: + return value_range[0] if len(value_range) > 0 else None + return value_range + + for i, operand in enumerate(operands): + if operand_segments: + + def getter(self, i=i, operand=operand): + operand_range = _ods_segmented_accessor( + self.operation.operands, + self.operation.attributes["operandSegmentSizes"], + i, + ) + return value_range_getter(operand_range, operand.variadicity) + + setattr(_OpView, operand.name, property(getter)) + else: + setattr( + _OpView, operand.name, property(lambda self, i=i: self.operands[i]) + ) + for i, result in enumerate(self.results): + if result_segments: + + def getter(self, i=i, result=result): + result_range = _ods_segmented_accessor( + self.operation.results, + self.operation.attributes["resultSegmentSizes"], + i, + ) + return value_range_getter(result_range, result.variadicity) + + setattr(_OpView, result.name, property(getter)) + else: + setattr( + _OpView, result.name, property(lambda self, i=i: self.results[i]) + ) + + def _builder(*args, **kwargs) -> _OpView: + return _OpView(*args, **kwargs) + + _builder.__signature__ = _Signature(params[1:]) + + return _OpView, _builder + + +class Dialect: + def __init__(self, name: str): + self.name = name + self.operations: List[Operation] = [] + self.namespace = _SimpleNameSpace() + + def _emit(self) -> None: + d = _irdl.dialect(self.name) + with _ods_ir.InsertionPoint(d.body): + for op in self.operations: + op._emit() + + def _make_module(self) -> _ods_ir.Module: + with _ods_ir.Location.unknown(): + m = _ods_ir.Module.create() + with _ods_ir.InsertionPoint(m.body): + self._emit() + return m + + def _make_dialect_class(self) -> type: + class _Dialect(_ods_ir.Dialect): + DIALECT_NAMESPACE = self.name + + return _Dialect + + def load(self) -> _SimpleNameSpace: + _irdl.load_dialects(self._make_module()) + dialect_class = self._make_dialect_class() + _ods_cext.register_dialect(dialect_class) + for op in self.operations: + _ods_cext.register_operation(dialect_class)(op.op_view) + return self.namespace + + def op(self, name: str) -> Callable[[type], type]: + def decorator(cls: type) -> type: + operands_and_attrs: List[Union[Operand, Attribute]] = [] + results: List[Result] = [] + + for field in cls.__dict__.values(): + if isinstance(field, Operand) or isinstance(field, Attribute): + operands_and_attrs.append(field) + elif isinstance(field, Result): + results.append(field) + + op_def = Operation(self.name, name, operands_and_attrs, results) + op_view, builder = op_def._make_op_view_and_builder() + setattr(op_def, "op_view", op_view) + setattr(op_def, "builder", builder) + self.operations.append(op_def) + self.namespace.__dict__[cls.__name__] = op_view + op_view.__name__ = cls.__name__ + self.namespace.__dict__[name.replace(".", "_")] = builder + return cls + + return decorator diff --git a/mlir/test/python/dialects/irdsl.py b/mlir/test/python/dialects/irdsl.py new file mode 100644 index 0000000000000..8ef30ae0a4c13 --- /dev/null +++ b/mlir/test/python/dialects/irdsl.py @@ -0,0 +1,308 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.ir import * +from mlir.dialects.irdl import dsl as irdsl +from mlir.dialects import arith +import sys + + +def run(f): + print("\nTEST:", f.__name__, file=sys.stderr) + with Context(): + f() + + +# CHECK: TEST: testMyInt +@run +def testMyInt(): + myint = irdsl.Dialect("myint") + iattr = irdsl.BaseName("#builtin.integer") + i32 = irdsl.IsType(IntegerType.get_signless(32)) + + @myint.op("constant") + class ConstantOp: + value = irdsl.Attribute(iattr) + cst = irdsl.Result(i32) + + @myint.op("add") + class AddOp: + lhs = irdsl.Operand(i32) + rhs = irdsl.Operand(i32) + res = irdsl.Result(i32) + + # CHECK: irdl.dialect @myint { + # CHECK: irdl.operation @constant { + # CHECK: %0 = irdl.base "#builtin.integer" + # CHECK: irdl.attributes {"value" = %0} + # CHECK: %1 = irdl.is i32 + # CHECK: irdl.results(cst: %1) + # CHECK: } + # CHECK: irdl.operation @add { + # CHECK: %0 = irdl.is i32 + # CHECK: irdl.operands(lhs: %0, rhs: %0) + # CHECK: irdl.results(res: %0) + # CHECK: } + # CHECK: } + print(myint._make_module()) + myint = myint.load() + + # CHECK: ['ConstantOp', 'constant', 'AddOp', 'add'] + print([i for i in myint.__dict__.keys()]) + + i32 = IntegerType.get_signless(32) + with Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + two = myint.constant(i32, IntegerAttr.get(i32, 2)) + three = myint.constant(i32, IntegerAttr.get(i32, 3)) + add1 = myint.add(i32, two, three) + add2 = myint.add(i32, add1, two) + add3 = myint.add(i32, add2, three) + + # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32 + # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32 + # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32 + # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32 + # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32 + print(module) + assert module.operation.verify() + + # CHECK: AddOp + print(type(add1).__name__) + # CHECK: ConstantOp + print(type(two).__name__) + # CHECK: myint.add + print(add1.OPERATION_NAME) + # CHECK: None + print(add1._ODS_OPERAND_SEGMENTS) + # CHECK: None + print(add1._ODS_RESULT_SEGMENTS) + # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32 + print(add1.lhs.owner) + # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32 + print(add1.rhs.owner) + # CHECK: 2 : i32 + print(two.value) + # CHECK: Value(%0 + print(two.cst) + # CHECK: (res, lhs, rhs, *, loc=None, ip=None) + print(myint.add.__signature__) + # CHECK: (cst, value, *, loc=None, ip=None) + print(myint.constant.__signature__) + + +@run +def testIRDSL(): + test = irdsl.Dialect("irdsl_test") + i32 = irdsl.IsType(IntegerType.get_signless(32)) + i64 = irdsl.IsType(IntegerType.get_signless(64)) + i32or64 = i32 | i64 + any = irdsl.Any() + f32 = irdsl.IsType(F32Type.get()) + iattr = irdsl.BaseName("#builtin.integer") + fattr = irdsl.BaseName("#builtin.float") + + @test.op("constraint") + class ConstraintOp: + a = irdsl.Operand(i32or64) + b = irdsl.Operand(any) + c = irdsl.Operand(f32 | i32) + d = irdsl.Operand(any) + x = irdsl.Attribute(iattr) + y = irdsl.Attribute(fattr) + + @test.op("optional") + class OptionalOp: + a = irdsl.Operand(i32) + b = irdsl.Operand(i32, irdsl.Variadicity.optional) + out1 = irdsl.Result(i32) + out2 = irdsl.Result(i32, irdsl.Variadicity.optional) + out3 = irdsl.Result(i32) + + @test.op("optional2") + class Optional2Op: + a = irdsl.Operand(i32, irdsl.Variadicity.optional) + b = irdsl.Result(i32, irdsl.Variadicity.optional) + + @test.op("variadic") + class VariadicOp: + a = irdsl.Operand(i32) + b = irdsl.Operand(i32, irdsl.Variadicity.optional) + c = irdsl.Operand(i32, irdsl.Variadicity.variadic) + out1 = irdsl.Result(i32, irdsl.Variadicity.variadic) + out2 = irdsl.Result(i32, irdsl.Variadicity.variadic) + out3 = irdsl.Result(i32, irdsl.Variadicity.optional) + out4 = irdsl.Result(i32) + + @test.op("variadic2") + class Variadic2Op: + a = irdsl.Operand(i32, irdsl.Variadicity.variadic) + b = irdsl.Result(i32, irdsl.Variadicity.variadic) + + @test.op("mixed") + class MixedOp: + out = irdsl.Result(i32) + in1 = irdsl.Operand(i32) + in2 = irdsl.Attribute(iattr) + in3 = irdsl.Operand(i32, irdsl.Variadicity.optional) + in4 = irdsl.Attribute(iattr) + in5 = irdsl.Operand(i32) + + # CHECK: irdl.dialect @irdsl_test { + # CHECK: irdl.operation @constraint { + # CHECK: %0 = irdl.is i32 + # CHECK: %1 = irdl.is i64 + # CHECK: %2 = irdl.any_of(%0, %1) + # CHECK: %3 = irdl.any + # CHECK: %4 = irdl.is f32 + # CHECK: %5 = irdl.any_of(%4, %0) + # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %3) + # CHECK: %6 = irdl.base "#builtin.integer" + # CHECK: %7 = irdl.base "#builtin.float" + # CHECK: irdl.attributes {"x" = %6, "y" = %7} + # CHECK: } + # CHECK: irdl.operation @optional { + # CHECK: %0 = irdl.is i32 + # CHECK: irdl.operands(a: %0, b: optional %0) + # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0) + # CHECK: } + # CHECK: irdl.operation @optional2 { + # CHECK: %0 = irdl.is i32 + # CHECK: irdl.operands(a: optional %0) + # CHECK: irdl.results(b: optional %0) + # CHECK: } + # CHECK: irdl.operation @variadic { + # CHECK: %0 = irdl.is i32 + # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0) + # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0) + # CHECK: } + # CHECK: irdl.operation @variadic2 { + # CHECK: %0 = irdl.is i32 + # CHECK: irdl.operands(a: variadic %0) + # CHECK: irdl.results(b: variadic %0) + # CHECK: } + # CHECK: irdl.operation @mixed { + # CHECK: %0 = irdl.is i32 + # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0) + # CHECK: %1 = irdl.base "#builtin.integer" + # CHECK: irdl.attributes {"in2" = %1, "in4" = %1} + # CHECK: irdl.results(out: %0) + # CHECK: } + # CHECK: } + print(test._make_module()) + test = test.load() + + # CHECK: (a, b, c, d, x, y, *, loc=None, ip=None) + print(test.constraint.__signature__) + # CHECK: (out1, out3, a, *, out2=None, b=None, loc=None, ip=None) + print(test.optional.__signature__) + # CHECK: (*, b=None, a=None, loc=None, ip=None) + print(test.optional2.__signature__) + # CHECK: (out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None) + print(test.variadic.__signature__) + # CHECK: (b, a, *, loc=None, ip=None) + print(test.variadic2.__signature__) + # CHECK: (out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None) + print(test.mixed.__signature__) + + # CHECK: None None + print( + test.ConstraintOp._ODS_OPERAND_SEGMENTS, test.ConstraintOp._ODS_RESULT_SEGMENTS + ) + # CHECK: [1, 0] [1, 0, 1] + print(test.OptionalOp._ODS_OPERAND_SEGMENTS, test.OptionalOp._ODS_RESULT_SEGMENTS) + # CHECK: [0] [0] + print(test.Optional2Op._ODS_OPERAND_SEGMENTS, test.Optional2Op._ODS_RESULT_SEGMENTS) + # CHECK: [1, 0, -1] [-1, -1, 0, 1] + print(test.VariadicOp._ODS_OPERAND_SEGMENTS, test.VariadicOp._ODS_RESULT_SEGMENTS) + # CHECK: [-1] [-1] + print(test.Variadic2Op._ODS_OPERAND_SEGMENTS, test.Variadic2Op._ODS_RESULT_SEGMENTS) + + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + f32 = F32Type.get() + + with Location.unknown(): + iattr = IntegerAttr.get(i32, 2) + fattr = FloatAttr.get_f32(2.3) + + module = Module.create() + with InsertionPoint(module.body): + ione = arith.constant(i32, 1) + fone = arith.constant(f32, 1.2) + + # CHECK: "irdsl_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> () + c1 = test.constraint(ione, ione, fone, ione, iattr, fattr) + # CHECK: "irdsl_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> () + test.constraint(ione, fone, fone, fone, iattr, fattr) + # CHECK: irdsl_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> () + test.constraint(ione, fone, ione, fone, iattr, fattr) + + # CHECK: %0:2 = "irdsl_test.optional"(%c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32) -> (i32, i32) + o1 = test.optional(i32, i32, ione) + # CHECK: %1:3 = "irdsl_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32, i32) -> (i32, i32, i32) + o2 = test.optional(i32, i32, ione, out2=i32, b=ione) + # CHECK: irdsl_test.optional2"() {operandSegmentSizes = array, resultSegmentSizes = array} : () -> () + o3 = test.optional2() + # CHECK: %2 = "irdsl_test.optional2"() {operandSegmentSizes = array, resultSegmentSizes = array} : () -> i32 + o4 = test.optional2(b=i32) + # CHECK: "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32) -> () + o5 = test.optional2(a=ione) + # CHECK: %3 = "irdsl_test.optional2"(%c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32) -> i32 + o6 = test.optional2(b=i32, a=ione) + + # CHECK: %4:4 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32, i32, i32) -> (i32, i32, i32, i32) + v1 = test.variadic([i32], [i32, i32], i32, ione, [ione, ione]) + # CHECK: %5:5 = "irdsl_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32, i32, i32) -> (i32, i32, i32, i32, i32) + v2 = test.variadic([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione) + # CHECK: %6:4 = "irdsl_test.variadic"(%c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32) -> (i32, i32, i32, i32) + v3 = test.variadic([i32, i32], [i32], i32, ione, []) + # CHECK: "irdsl_test.variadic2"() {operandSegmentSizes = array, resultSegmentSizes = array} : () -> () + v4 = test.variadic2([], []) + # CHECK: "irdsl_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32, i32, i32) -> () + v5 = test.variadic2([], [ione, ione, ione]) + # CHECK: %7:2 = "irdsl_test.variadic2"(%c1_i32) {operandSegmentSizes = array, resultSegmentSizes = array} : (i32) -> (i32, i32) + v6 = test.variadic2([i32, i32], [ione]) + + # CHECK: %8 = "irdsl_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array} : (i32, i32) -> i32 + m1 = test.mixed(i32, ione, iattr, iattr, ione) + # CHECK: %9 = "irdsl_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array} : (i32, i32, i32) -> i32 + m2 = test.mixed(i32, ione, iattr, iattr, ione, in3=ione) + + print(module) + assert module.operation.verify() + + # CHECK: Value(%c1_i32 + print(c1.a) + # CHECK: 2 : i32 + print(c1.x) + # CHECK: Value(%c1_i32 + print(o1.a) + # CHECK: None + print(o1.b) + # CHECK: Value(%c1_i32 + print(o2.b) + # CHECK: 0 + print(o1.out1.result_number) + # CHECK: None + print(o1.out2) + # CHECK: 0 + print(o2.out1.result_number) + # CHECK: 1 + print(o2.out2.result_number) + # CHECK: None + print(o3.a) + # CHECK: Value(%c1_i32 + print(o5.a) + # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)', 'Value(%c1_i32 = arith.constant 1 : i32)'] + print([str(i) for i in v1.c]) + # CHECK: ['Value(%c1_i32 = arith.constant 1 : i32)'] + print([str(i) for i in v2.c]) + # CHECK: [] + print([str(i) for i in v3.c]) + # CHECK: 0 0 + print(len(v4.a), len(v4.b)) + # CHECK: 3 0 + print(len(v5.a), len(v5.b)) + # CHECK: 1 2 + print(len(v6.a), len(v6.b)) From c0b8ba70e97dd09179ba75c6ae5dde04d3be152e Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 12:55:21 +0800 Subject: [PATCH 2/8] make FieldDef private --- mlir/python/mlir/dialects/irdl/dsl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index 3cc234503665a..b1916fdbd5f9b 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -91,25 +91,25 @@ def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: return _irdl.base(base_ref=self.ref) -class FieldDef: +class _FieldDef: def __set_name__(self, owner, name: str): self.name = name @dataclass -class Operand(FieldDef): +class Operand(_FieldDef): constraint: ConstraintExpr variadicity: Variadicity = Variadicity.single @dataclass -class Result(FieldDef): +class Result(_FieldDef): constraint: ConstraintExpr variadicity: Variadicity = Variadicity.single @dataclass -class Attribute(FieldDef): +class Attribute(_FieldDef): constraint: ConstraintExpr def __post_init__(self): From 7de9ad40b6c293efc2bbb8f51140b82ad17522b2 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 13:18:07 +0800 Subject: [PATCH 3/8] refactor some methods --- mlir/python/mlir/dialects/irdl/dsl.py | 63 ++++++++++++++++----------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index b1916fdbd5f9b..307609c1b342e 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -127,13 +127,16 @@ class Operation: operands_and_attrs: List[Union[Operand, Attribute]] results: List[Result] - def _emit(self) -> None: - op = _irdl.operation_(self.name) - ctx = ConstraintLoweringContext() - + def _partition_operands_and_attrs(self) -> Tuple[List[Operand], List[Attribute]]: operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] + return operands, attrs + + def _emit(self) -> None: + ctx = ConstraintLoweringContext() + operands, attrs = self._partition_operands_and_attrs() + op = _irdl.operation_(self.name) with _ods_ir.InsertionPoint(op.body): if operands: _irdl.operands_( @@ -153,27 +156,26 @@ def _emit(self) -> None: [i.variadicity for i in self.results], ) - def _make_op_view_and_builder(self) -> Tuple[type, Callable]: - operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] - attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] - - def variadicity_to_segment(variadicity: Variadicity) -> int: - if variadicity == Variadicity.variadic: - return -1 - if variadicity == Variadicity.optional: - return 0 - return 1 - - operand_segments = None - if any(i.variadicity != Variadicity.single for i in operands): - operand_segments = [variadicity_to_segment(i.variadicity) for i in operands] - - result_segments = None - if any(i.variadicity != Variadicity.single for i in self.results): - result_segments = [ - variadicity_to_segment(i.variadicity) for i in self.results + @staticmethod + def _variadicity_to_segment(variadicity: Variadicity) -> int: + if variadicity == Variadicity.variadic: + return -1 + if variadicity == Variadicity.optional: + return 0 + return 1 + + @staticmethod + def _generate_segments( + operands_or_results: List[Union[Operand, Result]], + ) -> List[int]: + if any(i.variadicity != Variadicity.single for i in operands_or_results): + return [ + Operation._variadicity_to_segment(i.variadicity) + for i in operands_or_results ] + return None + def _generate_init_params(self) -> List[_Parameter]: args = self.results + self.operands_and_attrs positional_args = [ i.name for i in args if i.variadicity != Variadicity.optional @@ -188,7 +190,16 @@ def variadicity_to_segment(variadicity: Variadicity) -> int: params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None)) params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None)) - sig = _Signature(params) + return params + + def _make_op_view_and_builder(self) -> Tuple[type, Callable]: + operands, attrs = self._partition_operands_and_attrs() + + operand_segments = Operation._generate_segments(operands) + result_segments = Operation._generate_segments(self.results) + + params = self._generate_init_params() + init_sig = _Signature(params) op = self class _OpView(_ods_ir.OpView): @@ -198,7 +209,7 @@ class _OpView(_ods_ir.OpView): _ODS_RESULT_SEGMENTS = result_segments def __init__(*args, **kwargs): - bound = sig.bind(*args, **kwargs) + bound = init_sig.bind(*args, **kwargs) bound.apply_defaults() args = bound.arguments @@ -226,7 +237,7 @@ def __init__(*args, **kwargs): ip=args["ip"], ) - __init__.__signature__ = sig + __init__.__signature__ = init_sig for attr in attrs: setattr( From 985cd25e1486ea4e8a7a67f76daf835172c58112 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 14:05:26 +0800 Subject: [PATCH 4/8] more refactor --- mlir/python/mlir/dialects/irdl/dsl.py | 47 ++++++++++++++------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index 307609c1b342e..0a30678e1318c 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -124,17 +124,17 @@ class Operation: name: str # We store operands and attributes into one list to maintain relative orders # among them for generating OpView class. - operands_and_attrs: List[Union[Operand, Attribute]] - results: List[Result] + fields: List[Union[Operand, Attribute, Result]] - def _partition_operands_and_attrs(self) -> Tuple[List[Operand], List[Attribute]]: - operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] - attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] - return operands, attrs + def _partition_fields(self) -> Tuple[List[Operand], List[Attribute], List[Result]]: + operands = [i for i in self.fields if isinstance(i, Operand)] + attrs = [i for i in self.fields if isinstance(i, Attribute)] + results = [i for i in self.fields if isinstance(i, Result)] + return operands, attrs, results def _emit(self) -> None: ctx = ConstraintLoweringContext() - operands, attrs = self._partition_operands_and_attrs() + operands, attrs, results = self._partition_fields() op = _irdl.operation_(self.name) with _ods_ir.InsertionPoint(op.body): @@ -149,11 +149,11 @@ def _emit(self) -> None: [ctx.lower(i.constraint) for i in attrs], [i.name for i in attrs], ) - if self.results: + if results: _irdl.results_( - [ctx.lower(i.constraint) for i in self.results], - [i.name for i in self.results], - [i.variadicity for i in self.results], + [ctx.lower(i.constraint) for i in results], + [i.name for i in results], + [i.variadicity for i in results], ) @staticmethod @@ -176,7 +176,11 @@ def _generate_segments( return None def _generate_init_params(self) -> List[_Parameter]: - args = self.results + self.operands_and_attrs + # results are placed at the beginning of the parameter list, + # but operands and attributes can appear in any relative order. + args = [i for i in self.fields if isinstance(i, Result)] + [ + i for i in self.fields if not isinstance(i, Result) + ] positional_args = [ i.name for i in args if i.variadicity != Variadicity.optional ] @@ -193,10 +197,10 @@ def _generate_init_params(self) -> List[_Parameter]: return params def _make_op_view_and_builder(self) -> Tuple[type, Callable]: - operands, attrs = self._partition_operands_and_attrs() + operands, attrs, results = self._partition_fields() operand_segments = Operation._generate_segments(operands) - result_segments = Operation._generate_segments(self.results) + result_segments = Operation._generate_segments(results) params = self._generate_init_params() init_sig = _Signature(params) @@ -214,7 +218,7 @@ def __init__(*args, **kwargs): args = bound.arguments _operands = [args[operand.name] for operand in operands] - _results = [args[result.name] for result in op.results] + _results = [args[result.name] for result in results] _attributes = dict( (attr.name, args[attr.name]) for attr in attrs @@ -272,7 +276,7 @@ def getter(self, i=i, operand=operand): setattr( _OpView, operand.name, property(lambda self, i=i: self.operands[i]) ) - for i, result in enumerate(self.results): + for i, result in enumerate(results): if result_segments: def getter(self, i=i, result=result): @@ -332,16 +336,13 @@ def load(self) -> _SimpleNameSpace: def op(self, name: str) -> Callable[[type], type]: def decorator(cls: type) -> type: - operands_and_attrs: List[Union[Operand, Attribute]] = [] - results: List[Result] = [] + fields: List[Union[Operand, Attribute, Result]] = [] for field in cls.__dict__.values(): - if isinstance(field, Operand) or isinstance(field, Attribute): - operands_and_attrs.append(field) - elif isinstance(field, Result): - results.append(field) + if isinstance(field, _FieldDef): + fields.append(field) - op_def = Operation(self.name, name, operands_and_attrs, results) + op_def = Operation(self.name, name, fields) op_view, builder = op_def._make_op_view_and_builder() setattr(op_def, "op_view", op_view) setattr(op_def, "builder", builder) From 2405fe4bd522c21ca7e71bd35645d49b938da570 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 14:10:48 +0800 Subject: [PATCH 5/8] more refactor --- mlir/python/mlir/dialects/irdl/dsl.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index 0a30678e1318c..d98d82d1ba4c9 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -336,20 +336,20 @@ def load(self) -> _SimpleNameSpace: def op(self, name: str) -> Callable[[type], type]: def decorator(cls: type) -> type: - fields: List[Union[Operand, Attribute, Result]] = [] - - for field in cls.__dict__.values(): - if isinstance(field, _FieldDef): - fields.append(field) - + fields = [ + field for field in cls.__dict__.values() if isinstance(field, _FieldDef) + ] op_def = Operation(self.name, name, fields) + op_view, builder = op_def._make_op_view_and_builder() + op_view.__name__ = cls.__name__ setattr(op_def, "op_view", op_view) setattr(op_def, "builder", builder) self.operations.append(op_def) + self.namespace.__dict__[cls.__name__] = op_view - op_view.__name__ = cls.__name__ self.namespace.__dict__[name.replace(".", "_")] = builder + return cls return decorator From d74e6603052698a35b3efc84c303d91d28d9a9df Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 14:16:18 +0800 Subject: [PATCH 6/8] add a method to convert op name --- mlir/python/mlir/dialects/irdl/dsl.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index d98d82d1ba4c9..2092c2417145a 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -334,6 +334,10 @@ def load(self) -> _SimpleNameSpace: _ods_cext.register_operation(dialect_class)(op.op_view) return self.namespace + @staticmethod + def op_name_to_python_id(name): + return name.replace(".", "_").replace("$", "_") + def op(self, name: str) -> Callable[[type], type]: def decorator(cls: type) -> type: fields = [ @@ -348,7 +352,7 @@ def decorator(cls: type) -> type: self.operations.append(op_def) self.namespace.__dict__[cls.__name__] = op_view - self.namespace.__dict__[name.replace(".", "_")] = builder + self.namespace.__dict__[Dialect.op_name_to_python_id(name)] = builder return cls From bff9fb79903788ad4d2115def90ac2b1771fa81a Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 14:16:58 +0800 Subject: [PATCH 7/8] make it private --- mlir/python/mlir/dialects/irdl/dsl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index 2092c2417145a..0db7cbbed81ca 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -335,7 +335,7 @@ def load(self) -> _SimpleNameSpace: return self.namespace @staticmethod - def op_name_to_python_id(name): + def _op_name_to_python_id(name): return name.replace(".", "_").replace("$", "_") def op(self, name: str) -> Callable[[type], type]: @@ -352,7 +352,7 @@ def decorator(cls: type) -> type: self.operations.append(op_def) self.namespace.__dict__[cls.__name__] = op_view - self.namespace.__dict__[Dialect.op_name_to_python_id(name)] = builder + self.namespace.__dict__[Dialect._op_name_to_python_id(name)] = builder return cls From 4c4c81cf9f4b9f99d52053b4267ce1db8689d083 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Mon, 24 Nov 2025 14:17:25 +0800 Subject: [PATCH 8/8] add type annotation --- mlir/python/mlir/dialects/irdl/dsl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/irdl/dsl.py b/mlir/python/mlir/dialects/irdl/dsl.py index 0db7cbbed81ca..41520f0e0f68c 100644 --- a/mlir/python/mlir/dialects/irdl/dsl.py +++ b/mlir/python/mlir/dialects/irdl/dsl.py @@ -335,7 +335,7 @@ def load(self) -> _SimpleNameSpace: return self.namespace @staticmethod - def _op_name_to_python_id(name): + def _op_name_to_python_id(name: str) -> str: return name.replace(".", "_").replace("$", "_") def op(self, name: str) -> Callable[[type], type]: