Skip to content

Commit

Permalink
[mlir][python] Enable python bindings for Index dialect (#85827)
Browse files Browse the repository at this point in the history
This small patch enables python bindings for the index dialect.

---------

Co-authored-by: Steven Varoumas <steven.varoumas1@huawei.com>
  • Loading branch information
stevenvar and stevenvar committed Mar 20, 2024
1 parent d209d13 commit eb861ac
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Expand Up @@ -108,6 +108,15 @@ declare_mlir_dialect_python_bindings(
dialects/complex.py
DIALECT_NAME complex)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IndexOps.td
SOURCES
dialects/index.py
DIALECT_NAME index
GEN_ENUM_BINDINGS)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/IndexOps.td
@@ -0,0 +1,14 @@
//===-- IndexOps.td - Entry point for Index bindings -----*- tablegen -*---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_INDEX_OPS
#define PYTHON_BINDINGS_INDEX_OPS

include "mlir/Dialect/Index/IR/IndexOps.td"

#endif
6 changes: 6 additions & 0 deletions mlir/python/mlir/dialects/index.py
@@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._index_ops_gen import *
from ._index_enum_gen import *
235 changes: 235 additions & 0 deletions mlir/test/python/dialects/index_dialect.py
@@ -0,0 +1,235 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import index, arith


def run(f):
print("\nTEST:", f.__name__)
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f(ctx)
print(module)


# CHECK-LABEL: TEST: testConstantOp
@run
def testConstantOp(ctx):
a = index.ConstantOp(value=42)
# CHECK: %{{.*}} = index.constant 42


# CHECK-LABEL: TEST: testBoolConstantOp
@run
def testBoolConstantOp(ctx):
a = index.BoolConstantOp(value=True)
# CHECK: %{{.*}} = index.bool.constant true


# CHECK-LABEL: TEST: testAndOp
@run
def testAndOp(ctx):
a = index.ConstantOp(value=42)
r = index.AndOp(a, a)
# CHECK: %{{.*}} = index.and %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testOrOp
@run
def testOrOp(ctx):
a = index.ConstantOp(value=42)
r = index.OrOp(a, a)
# CHECK: %{{.*}} = index.or %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testXOrOp
@run
def testXOrOp(ctx):
a = index.ConstantOp(value=42)
r = index.XOrOp(a, a)
# CHECK: %{{.*}} = index.xor %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testCastSOp
@run
def testCastSOp(ctx):
a = index.ConstantOp(value=42)
b = arith.ConstantOp(value=23, result=IntegerType.get_signless(64))
c = index.CastSOp(input=a, output=IntegerType.get_signless(32))
d = index.CastSOp(input=b, output=IndexType.get())
# CHECK: %{{.*}} = index.casts %{{.*}} : index to i32
# CHECK: %{{.*}} = index.casts %{{.*}} : i64 to index


# CHECK-LABEL: TEST: testCastUOp
@run
def testCastUOp(ctx):
a = index.ConstantOp(value=42)
b = arith.ConstantOp(value=23, result=IntegerType.get_signless(64))
c = index.CastUOp(input=a, output=IntegerType.get_signless(32))
d = index.CastUOp(input=b, output=IndexType.get())
# CHECK: %{{.*}} = index.castu %{{.*}} : index to i32
# CHECK: %{{.*}} = index.castu %{{.*}} : i64 to index


# CHECK-LABEL: TEST: testCeilDivSOp
@run
def testCeilDivSOp(ctx):
a = index.ConstantOp(value=42)
r = index.CeilDivSOp(a, a)
# CHECK: %{{.*}} = index.ceildivs %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testCeilDivUOp
@run
def testCeilDivUOp(ctx):
a = index.ConstantOp(value=42)
r = index.CeilDivUOp(a, a)
# CHECK: %{{.*}} = index.ceildivu %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testCmpOp
@run
def testCmpOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
r = index.CmpOp(pred, lhs=a, rhs=b)
# CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})


# CHECK-LABEL: TEST: testAddOp
@run
def testAddOp(ctx):
a = index.ConstantOp(value=42)
r = index.AddOp(a, a)
# CHECK: %{{.*}} = index.add %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testSubOp
@run
def testSubOp(ctx):
a = index.ConstantOp(value=42)
r = index.SubOp(a, a)
# CHECK: %{{.*}} = index.sub %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testMulOp
@run
def testMulOp(ctx):
a = index.ConstantOp(value=42)
r = index.MulOp(a, a)
# CHECK: %{{.*}} = index.mul %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testDivSOp
@run
def testDivSOp(ctx):
a = index.ConstantOp(value=42)
r = index.DivSOp(a, a)
# CHECK: %{{.*}} = index.divs %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testDivUOp
@run
def testDivUOp(ctx):
a = index.ConstantOp(value=42)
r = index.DivUOp(a, a)
# CHECK: %{{.*}} = index.divu %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testFloorDivSOp
@run
def testFloorDivSOp(ctx):
a = index.ConstantOp(value=42)
r = index.FloorDivSOp(a, a)
# CHECK: %{{.*}} = index.floordivs %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testMaxSOp
@run
def testMaxSOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
r = index.MaxSOp(a, b)
# CHECK: %{{.*}} = index.maxs %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testMaxUOp
@run
def testMaxUOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
r = index.MaxUOp(a, b)
# CHECK: %{{.*}} = index.maxu %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testMinSOp
@run
def testMinSOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
r = index.MinSOp(a, b)
# CHECK: %{{.*}} = index.mins %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testMinUOp
@run
def testMinUOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
r = index.MinUOp(a, b)
# CHECK: %{{.*}} = index.minu %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testRemSOp
@run
def testRemSOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
r = index.RemSOp(a, b)
# CHECK: %{{.*}} = index.rems %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testRemUOp
@run
def testRemUOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
r = index.RemUOp(a, b)
# CHECK: %{{.*}} = index.remu %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testShlOp
@run
def testShlOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=3)
r = index.ShlOp(a, b)
# CHECK: %{{.*}} = index.shl %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testShrSOp
@run
def testShrSOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=3)
r = index.ShrSOp(a, b)
# CHECK: %{{.*}} = index.shrs %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testShrUOp
@run
def testShrUOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=3)
r = index.ShrUOp(a, b)
# CHECK: %{{.*}} = index.shru %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testSizeOfOp
@run
def testSizeOfOp(ctx):
r = index.SizeOfOp()
# CHECK: %{{.*}} = index.sizeof

0 comments on commit eb861ac

Please sign in to comment.