Skip to content

Commit

Permalink
[mlir][linalg] ValueBoundsOpInterface: Add LinalgOps
Browse files Browse the repository at this point in the history
Also add a few more complex test cases.

Differential Revision: https://reviews.llvm.org/D145806
  • Loading branch information
matthias-springer committed Apr 7, 2023
1 parent 10dbf23 commit edc8b60
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace linalg {
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Expand Up @@ -41,6 +41,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
Expand Down Expand Up @@ -141,6 +142,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerBufferizableOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
LinalgInterfaces.cpp
LinalgOps.cpp
LinalgDialect.cpp
ValueBoundsOpInterfaceImpl.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
Expand Down Expand Up @@ -30,5 +31,6 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRMemRefDialect
MLIRTensorDialect
MLIRTilingInterface
MLIRValueBoundsOpInterface
MLIRViewLikeInterface
)
43 changes: 43 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,43 @@
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

namespace mlir {
namespace linalg {
namespace {

/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
/// the `ValueBoundsOpInterface` with each of them.
template <typename... Ops> struct LinalgValueBoundsOpInterfaceHelper {
static void registerOpInterface(MLIRContext *ctx) {
(Ops::template attachInterface<DstValueBoundsOpInterfaceExternalModel<Ops>>(
*ctx),
...);
}
};

} // namespace
} // namespace linalg
} // namespace mlir

void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
// Register all Linalg structured ops.
LinalgValueBoundsOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::registerOpInterface(ctx);
});
}
79 changes: 79 additions & 0 deletions mlir/test/Dialect/Affine/value-bounds-reification.mlir
Expand Up @@ -20,3 +20,82 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index

return %4, %5, %6 : index, index, index
}

// -----

// CHECK-LABEL: func @reify_slice_bound(
// CHECK: %[[c5:.*]] = arith.constant 5 : index
// CHECK: "test.some_use"(%[[c5]])
func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
scf.for %iv = %c0 to %ub step %c4 {
%sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
%slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
%filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
%bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%bound) : (index) -> ()
}
return
}

// -----

// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 - s1 + 1)>
// CHECK-LABEL: func @scf_for(
// CHECK-SAME: %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index
// CHECK: %[[bound:.*]] = affine.apply #[[$map]]()[%[[ub]], %[[lb]]]
// CHECK: "test.some_use"(%[[bound]])
func.func @scf_for(%lb: index, %ub: index, %step: index) {
scf.for %iv = %lb to %ub step %step {
%0 = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv)[%ub]
%bound = "test.reify_bound"(%0) {type = "UB"} : (index) -> (index)
"test.some_use"(%bound) : (index) -> ()
}
return
}

// -----

// CHECK-LABEL: func @reify_slice_bound2(
func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
%ub2: index, %t1: tensor<1x?xi8>,
%t2: tensor<?x?xi8>, %t3: tensor<1x?xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
scf.for %iv0 = %lb0 to %ub0 step %step0 {
// CHECK: %[[c129:.*]] = arith.constant 129 : index
// CHECK: "test.some_use"(%[[c129]])
%ub1 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%iv0)[%ub0]
%ub1_ub = "test.reify_bound"(%ub1) {type = "UB"} : (index) -> (index)
"test.some_use"(%ub1_ub) : (index) -> ()

// CHECK: %[[c129:.*]] = arith.constant 129 : index
// CHECK: "test.some_use"(%[[c129]])
%lb1 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 32)>()[%ub1]
%lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index)
"test.some_use"(%lb1_ub) : (index) -> ()

scf.for %iv1 = %lb1 to %ub1 step %c32 {
// CHECK: %[[c32:.*]] = arith.constant 32 : index
// CHECK: "test.some_use"(%[[c32]])
%sz = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv1)[%ub1]
%sz_ub = "test.reify_bound"(%sz) {type = "UB"} : (index) -> (index)
"test.some_use"(%sz_ub) : (index) -> ()

scf.for %iv2 = %c0 to %ub2 step %c1 {
%slice1 = tensor.extract_slice %t1[0, %iv2] [1, 1] [1, 1] : tensor<1x?xi8> to tensor<1x1xi8>
%slice2 = tensor.extract_slice %t2[%iv2, 0] [1, %sz] [1, 1] : tensor<?x?xi8> to tensor<1x?xi8>
%slice3 = tensor.extract_slice %t3[0, 0] [1, %sz] [1, 1] : tensor<1x?xi32> to tensor<1x?xi32>
%matmul = linalg.matmul ins(%slice1, %slice2 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%slice3 : tensor<1x?xi32>) -> tensor<1x?xi32>

// CHECK: %[[c32:.*]] = arith.constant 32 : index
// CHECK: "test.some_use"(%[[c32]])
%matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
"test.some_use"(%matmul_ub) : (index) -> ()
}
}
}
return
}
13 changes: 13 additions & 0 deletions mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,13 @@
// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
// RUN: -split-input-file | FileCheck %s

// CHECK-LABEL: func @linalg_fill(
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]]
// CHECK: return %[[dim]]
func.func @linalg_fill(%t: tensor<?xf32>, %f: f32) -> index {
%0 = linalg.fill ins(%f : f32) outs(%t : tensor<?xf32>) -> tensor<?xf32>
%1 = "test.reify_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
return %1 : index
}
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Expand Up @@ -8654,6 +8654,7 @@ cc_library(
":Support",
":TensorDialect",
":TilingInterface",
":ValueBoundsOpInterface",
":ViewLikeInterface",
"//llvm:Support",
],
Expand Down

0 comments on commit edc8b60

Please sign in to comment.