Skip to content

Commit

Permalink
[mlir] enable python bindings for nvgpu transforms (#68088)
Browse files Browse the repository at this point in the history
Expose the autogenerated bindings.

Co-authored-by: Martin Lücke <mluecke@google.com>
  • Loading branch information
2 people authored and pull[bot] committed Feb 10, 2024
1 parent 8219d8e commit 7e9c948
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Expand Up @@ -200,6 +200,15 @@ declare_mlir_dialect_extension_python_bindings(
DIALECT_NAME transform
EXTENSION_NAME memref_transform)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/NVGPUTransformOps.td
SOURCES
dialects/transform/nvgpu.py
DIALECT_NAME transform
EXTENSION_NAME nvgpu_transform)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
20 changes: 20 additions & 0 deletions mlir/python/mlir/dialects/NVGPUTransformOps.td
@@ -0,0 +1,20 @@
//===-- NVGPUTransformOps.td -------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the Python bindings generator for the transform ops provided
// by the NVGPU dialect.
//
//===----------------------------------------------------------------------===//


#ifndef PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS
#define PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS

include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td"

#endif // PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS
5 changes: 5 additions & 0 deletions mlir/python/mlir/dialects/transform/nvgpu.py
@@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._nvgpu_transform_ops_gen import *
27 changes: 27 additions & 0 deletions mlir/test/python/dialects/transform_nvgpu_ext.py
@@ -0,0 +1,27 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import nvgpu


def run(f):
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
print("\nTEST:", f.__name__)
f()
print(module)
return f


@run
def testCreateAsyncGroups():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
nvgpu.CreateAsyncGroupsOp(transform.AnyOpType.get(), sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: testCreateAsyncGroups
# CHECK: transform.nvgpu.create_async_groups
20 changes: 20 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
Expand Up @@ -1209,6 +1209,25 @@ gentbl_filegroup(
],
)

gentbl_filegroup(
name = "NVGPUTransformOpsPyGen",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=transform",
"-dialect-extension=nvgpu_transform",
],
"mlir/dialects/_nvgpu_transform_ops_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/NVGPUTransformOps.td",
deps = [
"//mlir:NVGPUTransformOpsTdFiles",
],
)

gentbl_filegroup(
name = "PDLTransformOpsPyGen",
tbl_outs = [
Expand Down Expand Up @@ -1327,6 +1346,7 @@ filegroup(
":GPUTransformOpsPyGen",
":LoopTransformOpsPyGen",
":MemRefTransformOpsPyGen",
":NVGPUTransformOpsPyGen",
":PDLTransformOpsPyGen",
":SparseTensorTransformOpsPyGen",
":StructureTransformEnumPyGen",
Expand Down

0 comments on commit 7e9c948

Please sign in to comment.