Skip to content

Commit

Permalink
[mlir] provide Python bindings for the Transform dialect
Browse files Browse the repository at this point in the history
Python bindings for extensions of the Transform dialect are defined in separate
Python source files that can be imported on-demand, i.e., that are not imported
with the "main" transform dialect. This requires a minor addition to the
ODS-based bindings generator. This approach is consistent with the current
model for downstream projects that are expected to bundle MLIR Python bindings:
such projects can include their custom extensions into the bundle similarly to
how they include their dialects.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D126208
  • Loading branch information
ftynse committed May 30, 2022
1 parent cc6c159 commit 3f71765
Show file tree
Hide file tree
Showing 12 changed files with 690 additions and 2 deletions.
55 changes: 55 additions & 0 deletions mlir/cmake/modules/AddMLIRPython.cmake
Expand Up @@ -355,6 +355,61 @@ function(declare_mlir_dialect_python_bindings)
endif()
endfunction()

# Function: declare_mlir_dialect_extension_python_bindings
# Helper to generate source groups for dialect extensions, including both
# static source files and a TD_FILE to generate wrappers.
#
# This will generate a source group named ${ADD_TO_PARENT}.${EXTENSION_NAME}.
#
# Arguments:
# ROOT_DIR: Same as for declare_mlir_python_sources().
# ADD_TO_PARENT: Same as for declare_mlir_python_sources(). Unique names
# for the subordinate source groups are derived from this.
# TD_FILE: Tablegen file to generate source for (relative to ROOT_DIR).
# DIALECT_NAME: Python name of the dialect.
# EXTENSION_NAME: Python name of the dialect extension.
# SOURCES: Same as declare_mlir_python_sources().
# SOURCES_GLOB: Same as declare_mlir_python_sources().
# DEPENDS: Additional dependency targets.
function(declare_mlir_dialect_extension_python_bindings)
cmake_parse_arguments(ARG
""
"ROOT_DIR;ADD_TO_PARENT;TD_FILE;DIALECT_NAME;EXTENSION_NAME"
"SOURCES;SOURCES_GLOB;DEPENDS"
${ARGN})
# Source files.
set(_extension_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}")
declare_mlir_python_sources(${_extension_target}
ROOT_DIR "${ARG_ROOT_DIR}"
ADD_TO_PARENT "${ARG_ADD_TO_PARENT}"
SOURCES "${ARG_SOURCES}"
SOURCES_GLOB "${ARG_SOURCES_GLOB}"
)

# Tablegen
if(ARG_TD_FILE)
set(tblgen_target "${ARG_ADD_TO_PARENT}.${ARG_EXTENSION_NAME}.tablegen")
set(td_file "${ARG_ROOT_DIR}/${ARG_TD_FILE}")
get_filename_component(relative_td_directory "${ARG_TD_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${relative_td_directory}")
set(output_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_ops_gen.py")
set(LLVM_TARGET_DEFINITIONS ${td_file})
mlir_tablegen("${output_filename}" -gen-python-op-bindings
-bind-dialect=${ARG_DIALECT_NAME}
-dialect-extension=${ARG_EXTENSION_NAME})
add_public_tablegen_target(${tblgen_target})
if(ARG_DEPENDS)
add_dependencies(${tblgen_target} ${ARG_DEPENDS})
endif()

declare_mlir_python_sources("${_extension_target}.ops_gen"
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT "${_extension_target}"
SOURCES "${output_filename}"
)
endif()
endfunction()

# Function: mlir_python_setup_extension_rpath
# Sets RPATH properties on a target, assuming that it is being output to
# an _mlir_libs directory with all other libraries. For static linkage,
Expand Down
19 changes: 19 additions & 0 deletions mlir/python/CMakeLists.txt
Expand Up @@ -116,6 +116,25 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME linalg
DEPENDS LinalgOdsGen)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
DIALECT_NAME transform)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
21 changes: 21 additions & 0 deletions mlir/python/mlir/dialects/LinalgStructuredTransformOps.td
@@ -0,0 +1,21 @@
//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the Python bindings generator for the structured transform ops
// provided by Linalg (and other dialects).
//
//===----------------------------------------------------------------------===//


#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS

include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td"

#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS
15 changes: 15 additions & 0 deletions mlir/python/mlir/dialects/TransformOps.td
@@ -0,0 +1,15 @@
//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_TRANSFORM_OPS
#define PYTHON_BINDINGS_TRANSFORM_OPS

include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Transform/IR/TransformOps.td"

#endif // PYTHON_BINDINGS_TRANSFORM_OPS
178 changes: 178 additions & 0 deletions mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -0,0 +1,178 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union

IntOrAttrList = Sequence[Union[IntegerAttr, int]]
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]


def _get_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
"""Creates an array attribute from its operand."""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values

return ArrayAttr.get(values)


def _get_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]]
) -> ArrayAttr:
"""Creates an integer array attribute from its operand.
If the operand is already an array attribute, forwards it. Otherwise treats
the operand as a list of attributes or integers, possibly intersperced, to
create a new array attribute containing integer attributes. Expects the
thread-local MLIR context to have been set by the context manager.
"""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values

attributes = []
for value in values:
if isinstance(value, IntegerAttr):
attributes.append(value)
else:
attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
return ArrayAttr.get(attributes)


def _get_int_int_array_attr(
values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
IntOrAttrList]]]]
) -> ArrayAttr:
"""Creates an array attribute containing array attributes of integers.
If the operand is already an array attribute, forwards it. Otherwise treats
the operand as a list of attributes or integers, potentially interpserced, to
create a new array-of-array attribute. Expects the thread-local MLIR context
to have been set by the context manager.
"""
if values is None:
return ArrayAttr.get([])
if isinstance(values, ArrayAttr):
return values

return ArrayAttr.get([_get_int_array_attr(value) for value in values])


class InterchangeOp:
"""Specialization for InterchangeOp class."""

def __init__(self,
target: Union[Operation, Value],
*,
iterator_interchange: OptionalIntList = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
interchange_attr = _get_int_array_attr(iterator_interchange)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
iterator_interchange=interchange_attr,
loc=loc,
ip=ip)


class PadOp:
"""Specialization for PadOp class."""

def __init__(self,
target: Union[Operation, Value],
*,
padding_values: Optional[Union[ArrayAttr,
Sequence[Attribute]]] = None,
padding_dimensions: OptionalIntList = None,
pack_paddings: OptionalIntList = None,
hoist_paddings: OptionalIntList = None,
transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[
ArrayAttr, IntOrAttrList]]]] = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
padding_values_attr = _get_array_attr(padding_values)
padding_dimensions_attr = _get_int_array_attr(padding_dimensions)
pack_paddings_attr = _get_int_array_attr(pack_paddings)
hoist_paddings_attr = _get_int_array_attr(hoist_paddings)
transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
padding_values=padding_values_attr,
padding_dimensions=padding_dimensions_attr,
pack_paddings=pack_paddings_attr,
hoist_paddings=hoist_paddings_attr,
transpose_paddings=transpose_paddings_attr,
loc=loc,
ip=ip)


class ScalarizeOp:
"""Specialization for ScalarizeOp class."""

def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)


class TileOp:
"""Specialization for TileOp class."""

def __init__(self,
target: Union[Operation, Value],
*,
sizes: OptionalIntList = None,
interchange: OptionalIntList = None,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
sizes_attr = _get_int_array_attr(sizes)
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
super().__init__(
pdl_operation_type, [pdl_operation_type] * num_loops,
_get_op_result_or_value(target),
sizes=sizes_attr,
interchange=_get_int_array_attr(interchange) if interchange else None,
loc=loc,
ip=ip)

def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]:
if not attr:
return []
return [IntegerAttr(element).value for element in attr]


class VectorizeOp:
"""Specialization for VectorizeOp class."""

def __init__(self,
target: Union[Operation, Value],
*,
vectorize_padding: Union[bool, BoolAttr] = False,
loc=None,
ip=None):
pdl_operation_type = pdl.OperationType.get()
if isinstance(vectorize_padding, bool):
vectorize_padding = BoolAttr.get(vectorize_padding)
super().__init__(
pdl_operation_type,
_get_op_result_or_value(target),
vectorize_padding=vectorize_padding,
loc=loc,
ip=ip)

0 comments on commit 3f71765

Please sign in to comment.