Skip to content

Commit

Permalink
Merge 65e10b1 into a3b74bc
Browse files Browse the repository at this point in the history
  • Loading branch information
IanWood1 committed May 17, 2024
2 parents a3b74bc + 65e10b1 commit 724cbc0
Show file tree
Hide file tree
Showing 12 changed files with 289 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
#include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -390,8 +394,10 @@ static bool hasCompatibleOuterParallelLoops(
// relationship through `operand` have compatible outer-parallel loops.
static bool hasCompatibleOuterParallelLoops(
OpOperand &operand, const llvm::SmallBitVector &rootOuterParallelLoops) {
auto producer = operand.get().getDefiningOp<linalg::LinalgOp>();
auto consumer = dyn_cast<linalg::LinalgOp>(operand.getOwner());
auto producer =
operand.get().getDefiningOp<LinalgExt::LinalgFusionOpInterface>();
auto consumer =
dyn_cast<LinalgExt::LinalgFusionOpInterface>(operand.getOwner());
if (!producer || !consumer)
return false;

Expand Down Expand Up @@ -605,8 +611,10 @@ isFusableWithConsumer(OpOperand &fusedOperand,
return false;
}

auto producerLinalgOp = dyn_cast<linalg::LinalgOp>(producer);
auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer);
auto producerLinalgOp =
dyn_cast<LinalgExt::LinalgFusionOpInterface>(producer);
auto consumerLinalgOp =
dyn_cast<LinalgExt::LinalgFusionOpInterface>(consumer);
if (!producerLinalgOp || !consumerLinalgOp)
return false;

Expand Down Expand Up @@ -744,12 +752,13 @@ isFusableWithProducer(OpOperand &operand,
.Default([](Operation *) { return false; });
}

if (!isa<linalg::LinalgOp>(consumer) || !isa<linalg::LinalgOp>(producer)) {
if (!isa<LinalgExt::LinalgFusionOpInterface>(consumer) ||
!isa<LinalgExt::LinalgFusionOpInterface>(producer)) {
return false;
}

if (!options.aggressiveFusion) {
auto consumerLinalgOp = cast<linalg::LinalgOp>(consumer);
auto consumerLinalgOp = cast<LinalgExt::LinalgFusionOpInterface>(consumer);
if (!consumerLinalgOp.isDpsInit(&operand)) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def FormDispatchRegionsPass :
"mlir::scf::SCFDialect",
"mlir::tensor::TensorDialect",
"IREE::Flow::FlowDialect",
"IREE::LinalgExt::IREELinalgExtDialect",
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ iree_lit_test_suite(
"form_dispatch_regions.mlir",
"form_dispatch_workgroups.mlir",
"form_scalar_dispatches.mlir",
"dispatch_linalg_ext_fusion.mlir",
"fusion_of_tensor_ops.mlir",
"fusion_preprocessing.mlir",
"initialize_empty_tensors.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_lit_test_suite(
"collapse_reduction.mlir"
"convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
"dispatch_linalg_ext_fusion.mlir"
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_default.mlir"
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-form-dispatch-workgroups), cse, canonicalize, cse)" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Test input: a linalg.generic that truncates i64 to i32 produces the second
// `ins` operand of an iree_linalg_ext.scatter, and the scatter result is in
// turn the input of an elementwise linalg.generic (arith.addf). All tensors
// come from tensor.empty, so only the op structure matters here.
util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
%3 = tensor.empty() : tensor<4x1xi32>
%expanded = tensor.empty() : tensor<4x1xi64>
%expanded_0 = tensor.empty() : tensor<4x1x16x8x128xf32>
%2 = tensor.empty() : tensor<8192x16x8x128xf32>
%result = tensor.empty() : tensor<8192x16x8x128xf32>

// Producer: elementwise i64 -> i32 truncation.
%4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%expanded : tensor<4x1xi64>)
outs(%3 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<4x1xi32>
// Scatter consuming the producer result (%4) as its second input.
%5 = iree_linalg_ext.scatter
dimension_map = [0]
unique_indices(false)
ins(%expanded_0, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
outs(%2 : tensor<8192x16x8x128xf32>) {
^bb0(%arg5: f32, %arg6: f32):
iree_linalg_ext.yield %arg5 : f32
} -> tensor<8192x16x8x128xf32>
// Consumer: elementwise add reading the scatter result (%5).
%6 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<8192x16x8x128xf32>) outs(%result : tensor<8192x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<8192x16x8x128xf32>

util.return %6 : tensor<8192x16x8x128xf32>
}

// CHECK: util.func public @linalgext_scatter_fusion
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK: %[[EXPANDED:.+]] = linalg.generic
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[UPDATE_TENSOR:.+]], %[[GEN:.+]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
// CHECK: %[[GEN:.+]] = linalg.generic
// CHECK-SAME: ins(%[[SCATTER_RESULT]] : tensor<8192x16x8x128xf32>)


// -----


#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
// Test input: a linalg.generic that truncates i64 to i32 feeds an
// iree_linalg_ext.reverse, whose result is the input of an elementwise
// linalg.generic (arith.addi).
// NOTE(review): the FileCheck directives following this function are prefixed
// with `COM:`, so they appear to be disabled — confirm whether that is
// intentional.
util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
%input = tensor.empty() : tensor<10x10xi64>
%shrunk = tensor.empty() : tensor<10x10xi32>

// Producer: elementwise i64 -> i32 truncation.
%4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%input: tensor<10x10xi64>)
outs(%shrunk : tensor<10x10xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<10x10xi32>
%reversed_outs = tensor.empty() : tensor<10x10xi32>
// Reverse consuming the producer result (%4).
%reversed = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%4 : tensor<10x10xi32>) outs(%reversed_outs : tensor<10x10xi32>) : tensor<10x10xi32>
%generic_outs = tensor.empty() : tensor<10x10xi32>
// Consumer: elementwise add reading the reversed tensor.
%6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]}
ins(%reversed : tensor<10x10xi32>)
outs(%generic_outs : tensor<10x10xi32>) {
^bb0(%in: i32, %out: i32):
%10 = arith.addi %in, %out : i32
linalg.yield %10 : i32
} -> tensor<10x10xi32>

util.return %6 : tensor<10x10xi32>
}

// COM: // CHECK: util.func public @linalgext_reverse_fusion
// COM: // CHECK: %[[SHRUNK:.+]] = linalg.generic
// COM: // CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse
// COM: // CHECK-SAME: ins(%[[SHRUNK]] : tensor<10x10xi32>)
// COM: // CHECK: %[[ADD:.+]] = linalg.generic
// COM: // CHECK-SAME: ins(%[[REVERSED]] : tensor<10x10xi32>)
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgStructuredOpsIncGen",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ iree_cc_library(
MLIRIR
MLIRInferTypeOpInterface
MLIRLinalgDialect
MLIRLinalgStructuredOpsIncGenLib
MLIRLinalgUtils
MLIRMathDialect
MLIRMemRefDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
Expand Down Expand Up @@ -46,9 +49,69 @@ struct IREELinalgExtInlinerInterface : public DialectInlinerInterface {
}
};

// External model used to attach LinalgFusionOpInterface to the linalg ops.
namespace {
/// Implements LinalgFusionOpInterface for `ConcreteType` by casting the
/// generic operation to `ConcreteType` and delegating each interface method
/// to the method of the same name on that op.
template <typename ConcreteType>
struct LinalgFusionOpInterfaceAdapter
    : public LinalgFusionOpInterface::ExternalModel<
          LinalgFusionOpInterfaceAdapter<ConcreteType>, ConcreteType> {
  /// Forwards to `ConcreteType::getNumParallelLoops`.
  unsigned getNumParallelLoops(mlir::Operation *op) const {
    auto concreteOp = llvm::cast<ConcreteType>(op);
    return concreteOp.getNumParallelLoops();
  }

  /// Forwards to `ConcreteType::getNumLoops`.
  unsigned getNumLoops(mlir::Operation *op) const {
    auto concreteOp = llvm::cast<ConcreteType>(op);
    return concreteOp.getNumLoops();
  }

  /// Forwards to `ConcreteType::getStaticLoopRanges`.
  SmallVector<int64_t, 4> getStaticLoopRanges(mlir::Operation *op) const {
    auto concreteOp = llvm::cast<ConcreteType>(op);
    return concreteOp.getStaticLoopRanges();
  }

  /// Forwards to `ConcreteType::getIndexingMapMatchingResult` for `result`.
  AffineMap getIndexingMapMatchingResult(mlir::Operation *op,
                                         OpResult result) const {
    auto concreteOp = llvm::cast<ConcreteType>(op);
    return concreteOp.getIndexingMapMatchingResult(result);
  }

  /// Forwards to `ConcreteType::getMatchingIndexingMap` for `operand`.
  AffineMap getMatchingIndexingMap(mlir::Operation *op,
                                   OpOperand *operand) const {
    auto concreteOp = llvm::cast<ConcreteType>(op);
    return concreteOp.getMatchingIndexingMap(operand);
  }

  /// Forwards to `ConcreteType::getIndexingMaps`.
  ArrayAttr getIndexingMaps(mlir::Operation *op) const {
    auto concreteOp = llvm::cast<ConcreteType>(op);
    return concreteOp.getIndexingMaps();
  }
};
} // namespace

/// Attaches `LinalgFusionOpInterfaceAdapter<Op>` to every op type `Op` in the
/// `OpTys` pack, passing `*context` to each `attachInterface` call. The calls
/// are expanded with a comma fold, one per op type.
template <typename... OpTys>
static void registerOpsWithLinalgExtOpInterface(mlir::MLIRContext *context) {
  (OpTys::template attachInterface<LinalgFusionOpInterfaceAdapter<OpTys>>(
       *context),
   ...);
}

void IREELinalgExtDialect::initialize() {
mlir::MLIRContext *context = getContext();
context->loadDialect<mlir::linalg::LinalgDialect>();

#define GET_OP_LIST
declarePromisedInterfaces<LinalgFusionOpInterface,
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();

#define GET_OP_LIST
registerOpsWithLinalgExtOpInterface<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(context);
addInterfaces<IREELinalgExtInlinerInterface>();

[[maybe_unused]] bool isInterfacePromised =
hasPromisedInterface<linalg::GenericOp, LinalgFusionOpInterface>();
assert(isInterfacePromised &&
"linalg::GenericOp should have LinalgFusionOpInterface");

addAttributes<
#define GET_ATTRDEF_LIST
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtAttrs.cpp.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTINTERFACES_H_

#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"

Expand Down
Loading

0 comments on commit 724cbc0

Please sign in to comment.