Skip to content

Commit

Permalink
[mlir][Linalg][Python] Create the body of builtin named Linalg ops
Browse files Browse the repository at this point in the history
This revision adds support to properly add the body of registered
builtin named linalg ops.
At this time, indexing_map and iterator_type support is still
missing so the op is not executable yet.

Differential Revision: https://reviews.llvm.org/D99578
  • Loading branch information
nicolasvasilache committed Mar 31, 2021
1 parent 465b9a4 commit 43b9fa3
Show file tree
Hide file tree
Showing 12 changed files with 196 additions and 14 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir-c/Dialect/Linalg.h
Expand Up @@ -17,6 +17,11 @@
extern "C" {
#endif

/// Apply the special region builder for the builtin named Linalg op.
/// Assert that `op` is a builtin named Linalg op.
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op);

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

#ifdef __cplusplus
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Expand Up @@ -37,6 +37,14 @@ def Linalg_Dialect : Dialect {
let dependentDialects = [
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
];
let extraClassDeclaration = [{
using RegionBuilderFunType = llvm::function_ref<void(Block &, ValueRange)>;
RegionBuilderFunType getRegionBuilder(StringRef name) {
return namedStructuredOpRegionBuilders.lookup(name);
}
private:
llvm::StringMap<RegionBuilderFunType> namedStructuredOpRegionBuilders;
}];
}

// Whether a type is a RangeType.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/StringMap.h"

#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Bindings/Python/CMakeLists.txt
Expand Up @@ -69,6 +69,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
INSTALL_DIR
python
SOURCES
DialectLinalg.cpp
MainModule.cpp
IRAffine.cpp
IRAttributes.cpp
Expand Down
34 changes: 34 additions & 0 deletions mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -0,0 +1,34 @@
//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "IRModule.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"

#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;

namespace mlir {
namespace python {

void populateDialectLinalgSubmodule(py::module &m) {
m.def(
"fill_builtin_region",
[](PyDialectDescriptor &dialect, PyOperation &op) {
return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get());
},
py::arg("dialect"), py::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
}

} // namespace python
} // namespace mlir
22 changes: 22 additions & 0 deletions mlir/lib/Bindings/Python/DialectLinalg.h
@@ -0,0 +1,22 @@
//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H

#include "PybindUtils.h"

namespace mlir {
namespace python {

void populateDialectLinalgSubmodule(pybind11::module &m);

} // namespace python
} // namespace mlir

#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H
6 changes: 6 additions & 0 deletions mlir/lib/Bindings/Python/MainModule.cpp
Expand Up @@ -10,6 +10,7 @@

#include "PybindUtils.h"

#include "DialectLinalg.h"
#include "ExecutionEngine.h"
#include "Globals.h"
#include "IRModule.h"
Expand Down Expand Up @@ -225,4 +226,9 @@ PYBIND11_MODULE(_mlir, m) {
auto executionEngineModule =
m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
populateExecutionEngineSubmodule(executionEngineModule);

// Define and populate Linalg submodule.
auto dialectsModule = m.def_submodule("dialects");
auto linalgModule = dialectsModule.def_submodule("linalg");
populateDialectLinalgSubmodule(linalgModule);
}
Expand Up @@ -61,11 +61,10 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
raise NotImplementedError(
f"Emission of composite linalg ops not supported: {op_configs}")

# TODO: this file should probably not be called dsl.py but rather is a client
# of the dsl.py.
from .... import linalg as linalg_ops
emit_generic = (emit_generic or
(not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys()))
ctx = ir.Context.current
linalgDialect = ctx.get_dialect_descriptor("linalg")
fully_qualified_name = 'linalg.' + self.op_name
emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name))

op_config = op_configs[0]
if op_config.structured_op:
Expand Down
Expand Up @@ -7,6 +7,9 @@
from mlir.ir import *
from mlir.dialects import linalg
from mlir.dialects import std
# TODO: resolve name collision for Linalg functionality that is injected inside
# the _mlir.dialects.linalg directly via pybind.
from _mlir.dialects.linalg import fill_builtin_region

from .scalar_expr import *
from .config import *
Expand All @@ -16,7 +19,6 @@
"emit_named_structured_op",
]


def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
outs: Value):
Expand Down Expand Up @@ -97,11 +99,18 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
type_mapping, indexing_maps_attr, iterator_types_attr = \
prepare_common_structured_op(op_config, *ins, outs = outs)

if not op_class_name in linalg.__dict__.keys():
# If we get here, there must exist a builtin class `op_class_name`.
ctx = Context.current
fully_qualified_name = 'linalg.' + op_name
if (not ctx.is_registered_operation(fully_qualified_name) or
not op_class_name in linalg.__dict__.keys()):
raise NotImplementedError(
f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}")

named_op = getattr(linalg, op_class_name)(ins, outs, out_types)
linalgDialect = ctx.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, named_op.operation)

if len(out_arg_defs) == 1:
return named_op.result
else:
Expand Down
29 changes: 27 additions & 2 deletions mlir/lib/CAPI/Dialect/Linalg.cpp
Expand Up @@ -10,5 +10,30 @@
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg,
mlir::linalg::LinalgDialect)
using namespace mlir;
using namespace mlir::linalg;

/// Apply the special region builder for the builtin named Linalg op.
/// Assert that `op` is a builtin named Linalg op.
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect,
MlirOperation mlirOp) {
Operation *op = unwrap(mlirOp);
LinalgDialect::RegionBuilderFunType fun =
static_cast<LinalgDialect *>(unwrap(linalgDialect))
->getRegionBuilder(op->getName().getStringRef());
assert(fun && "Expected a builtin named Linalg op.");
assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region");
assert(op->getRegion(0).getBlocks().empty() &&
"Expected Linalg op with 0 blocks");
SmallVector<Type, 8> argTypes;
auto linalgOp = cast<LinalgOp>(op);
for (auto t : linalgOp.getShapedOperandTypes())
argTypes.push_back(getElementTypeOrSelf(t));
OpBuilder b(op->getContext());
Region &region = op->getRegion(0);
Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes);
// TODO: allow captures.
fun(*body, ValueRange{});
}

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
Expand Up @@ -57,6 +57,38 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
// LinalgDialect
//===----------------------------------------------------------------------===//

/// Trait to check if T provides a `regionBuilder` method.
template <typename T, typename... Args>
using has_region_builder = decltype(T::regionBuilder);
template <typename T>
using detect_has_region_builder = llvm::is_detected<has_region_builder, T>;

/// SFINAE helper for single C++ class without a `regionBuilder` method (e.g.
/// an OpInterface).
template <typename OpType, typename = std::enable_if_t<
!detect_has_region_builder<OpType>::value>>
void addNamedOpBuilderImpl(
llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
// Do nothing.
}

template <typename OpType,
typename = std::enable_if_t<detect_has_region_builder<OpType>::value>,
typename = void>
void addNamedOpBuilderImpl(
llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
map.insert(std::make_pair(
OpType::getOperationName(),
static_cast<LinalgDialect::RegionBuilderFunType>(OpType::regionBuilder)));
}

template <typename... OpTypes>
void addNamedOpBuilders(
llvm::StringMap<LinalgDialect::RegionBuilderFunType> &map) {
(void)std::initializer_list<int>{0,
(addNamedOpBuilderImpl<OpTypes>(map), 0)...};
}

void mlir::linalg::LinalgDialect::initialize() {
addTypes<RangeType>();
addOperations<
Expand All @@ -72,6 +104,12 @@ void mlir::linalg::LinalgDialect::initialize() {
#include "mlir/Dialect/Linalg/IR/LinalgSparseOps.cpp.inc"
>();

// Fill the Linalg-specific OpName to RegionBuilder map.
addNamedOpBuilders<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(namedStructuredOpRegionBuilders);

addInterfaces<LinalgInlinerInterface>();
}

Expand Down
44 changes: 39 additions & 5 deletions mlir/test/Bindings/Python/dialects/linalg/ops.py
Expand Up @@ -5,7 +5,6 @@
from mlir.dialects import linalg
from mlir.dialects import std


def run(f):
print("\nTEST:", f.__name__)
f()
Expand Down Expand Up @@ -82,9 +81,9 @@ def testStructuredOpOnBuffers():
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
print(module)

# CHECK-LABEL: TEST: testNamedStructuredOp
# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOp():
def testNamedStructuredOpCustomForm():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
Expand All @@ -93,10 +92,45 @@ def testNamedStructuredOp():
RankedTensorType.get((16, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
# CHECK: linalg.matmul
# TODO: prperly hook up the region.
# First check the named form with custom format
# CHECK: linalg.matmul
# CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>)
# CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>)
# CHECK-SAME: -> tensor<4x8xf32>
# CHECK-NEXT: return
return linalg.matmul(lhs, rhs, outs=[init_result.result])

print(module)

# CHECK-LABEL: TEST: testNamedStructuredOpGenericForm
@run
def testNamedStructuredOpGenericForm():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
RankedTensorType.get((16, 8), f32))
def named_form(lhs, rhs):
init_result = linalg.InitTensorOp([4, 8], f32)
# CHECK: "linalg.matmul"(%{{.*}})
# CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
# CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
# CHECK-NEXT: {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])

module.operation.print(print_generic_op_form=True)

# CHECK-LABEL: TEST: testNamedStructuredAsGenericOp
@run
def testNamedStructuredAsGenericOp():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
RankedTensorType.get((16, 8), f32))
def generic_form(lhs, rhs):
Expand Down

0 comments on commit 43b9fa3

Please sign in to comment.