Commit 65e10b1: move method defs to tblgen

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
IanWood1 committed May 17, 2024
1 parent cfa9fe1 commit 65e10b1
Showing 8 changed files with 120 additions and 96 deletions.
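In short: the hand-written C++ definitions of the LinalgFusionOpInterface methods on ScatterOp (getNumParallelLoops, getNumLoops, getStaticLoopRanges, getIndexingMaps) move into defaultImplementation bodies on the interface itself, so any op that declares the interface inherits them. A before/after sketch; the "before" body is exactly the code deleted from LinalgExtOps.cpp further down:

// Before: each op spelled out the interface methods by hand.
unsigned ScatterOp::getNumParallelLoops() {
  return llvm::count(getLoopIteratorTypes(), utils::IteratorType::parallel);
}

// After: the same body lives in the InterfaceMethod's defaultImplementation
// in the .td file, and mlir-tblgen instantiates it for every op that
// declares LinalgFusionInterface without overriding the method.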
@@ -586,7 +586,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,
// Either fuse pad with producer or with consumer.
if (auto padOp = dyn_cast<tensor::PadOp>(consumer)) {
if (options.fusePadWithProducers || isPadUsedInSetEncoding(padOp)) {
- return isa<LinalgExt::LinalgFusionOpInterface>(producer);
+ return isa<linalg::LinalgOp>(producer);
}
return false;
}
@@ -720,7 +720,7 @@ isFusableWithProducer(OpOperand &operand,

if (auto padOp = dyn_cast<tensor::PadOp>(consumer)) {
if (options.fusePadWithProducers || isPadUsedInSetEncoding(padOp)) {
- return isa<LinalgExt::LinalgFusionOpInterface>(producer);
+ return isa<linalg::LinalgOp>(producer);
}
return false;
}
@@ -35,6 +35,7 @@ iree_lit_test_suite(
"form_dispatch_regions.mlir",
"form_dispatch_workgroups.mlir",
"form_scalar_dispatches.mlir",
"dispatch_linalg_ext_fusion.mlir",
"fusion_of_tensor_ops.mlir",
"fusion_preprocessing.mlir",
"initialize_empty_tensors.mlir",
@@ -24,6 +24,7 @@ iree_lit_test_suite(
"collapse_reduction.mlir"
"convert_region_to_workgroups.mlir"
"deduplicate_executables.mlir"
"dispatch_linalg_ext_fusion.mlir"
"dispatch_linalg_on_tensors.mlir"
"dispatch_linalg_on_tensors_default.mlir"
"dispatch_linalg_on_tensors_fusion_with_transpose.mlir"
@@ -0,0 +1,84 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-form-dispatch-workgroups), cse, canonicalize, cse)" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
%3 = tensor.empty() : tensor<4x1xi32>
%expanded = tensor.empty() : tensor<4x1xi64>
%expanded_0 = tensor.empty() : tensor<4x1x16x8x128xf32>
%2 = tensor.empty() : tensor<8192x16x8x128xf32>
%result = tensor.empty() : tensor<8192x16x8x128xf32>

%4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%expanded : tensor<4x1xi64>)
outs(%3 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<4x1xi32>
%5 = iree_linalg_ext.scatter
dimension_map = [0]
unique_indices(false)
ins(%expanded_0, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
outs(%2 : tensor<8192x16x8x128xf32>) {
^bb0(%arg5: f32, %arg6: f32):
iree_linalg_ext.yield %arg5 : f32
} -> tensor<8192x16x8x128xf32>
%6 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<8192x16x8x128xf32>) outs(%result : tensor<8192x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<8192x16x8x128xf32>

util.return %6 : tensor<8192x16x8x128xf32>
}

// CHECK: util.func public @linalgext_scatter_fusion
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK: %[[EXPANDED:.+]] = linalg.generic
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[UPDATE_TENSOR:.+]], %[[GEN:.+]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
// CHECK: %[[GEN:.+]] = linalg.generic
// CHECK-SAME: ins(%[[SCATTER_RESULT]] : tensor<8192x16x8x128xf32>)


// -----


#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
%input = tensor.empty() : tensor<10x10xi64>
%shrunk = tensor.empty() : tensor<10x10xi32>

%4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%input: tensor<10x10xi64>)
outs(%shrunk : tensor<10x10xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<10x10xi32>
%reversed_outs = tensor.empty() : tensor<10x10xi32>
%reversed = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%4 : tensor<10x10xi32>) outs(%reversed_outs : tensor<10x10xi32>) : tensor<10x10xi32>
%generic_outs = tensor.empty() : tensor<10x10xi32>
%6 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]}
ins(%reversed : tensor<10x10xi32>)
outs(%generic_outs : tensor<10x10xi32>) {
^bb0(%in: i32, %out: i32):
%10 = arith.addi %in, %out : i32
linalg.yield %10 : i32
} -> tensor<10x10xi32>

util.return %6 : tensor<10x10xi32>
}

// COM: // CHECK: util.func public @linalgext_reverse_fusion
// COM: // CHECK: %[[SHRUNK:.+]] = linalg.generic
// COM: // CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse
// COM: // CHECK-SAME: ins(%[[SHRUNK]] : tensor<10x10xi32>)
// COM: // CHECK: %[[ADD:.+]] = linalg.generic
// COM: // CHECK-SAME: ins(%[[REVERSED]] : tensor<10x10xi32>)
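As a worked example of the new interface defaults against the scatter in the first test above (a sketch; the values follow directly from the defaultImplementation bodies in the .td diff below): the operands are the updates tensor<4x1x16x8x128xf32>, the indices tensor<4x1xi32>, and the destination tensor<8192x16x8x128xf32>, so the default getStaticLoopRanges concatenates their shapes in operand order:

// Shapes of (updates, indices, destination), concatenated in operand order.
SmallVector<int64_t, 4> ranges = {4, 1, 16, 8, 128, 4, 1, 8192, 16, 8, 128};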
@@ -605,51 +605,3 @@ util.func public @no_dequantization_like_fusion(%arg0: tensor<32x1x16x1x8xi16>,
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: flow.return %[[MMT4D]] :
// CHECK: util.return %[[DISP]]


// -----


#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
%3 = tensor.empty() : tensor<4x1xi32>
%expanded = tensor.empty() : tensor<4x1xi64>
%expanded_0 = tensor.empty() : tensor<4x1x16x8x128xf32>
%2 = tensor.empty() : tensor<8192x16x8x128xf32>
%result = tensor.empty() : tensor<8192x16x8x128xf32>

%4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%expanded : tensor<4x1xi64>)
outs(%3 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<4x1xi32>
%5 = iree_linalg_ext.scatter
dimension_map = [0]
unique_indices(false)
ins(%expanded_0, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
outs(%2 : tensor<8192x16x8x128xf32>) {
^bb0(%arg5: f32, %arg6: f32):
iree_linalg_ext.yield %arg5 : f32
} -> tensor<8192x16x8x128xf32>
%6 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<8192x16x8x128xf32>) outs(%result : tensor<8192x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<8192x16x8x128xf32>

util.return %6 : tensor<8192x16x8x128xf32>
}

// COM: Only checking for 2nd and 3rd op fusion
// CHECK: util.func public @linalgext_scatter_fusion
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK: %[[GEN:.+]] = linalg.generic
// CHECK-SAME: ins(%[[SCATTER_RESULT]] : tensor<8192x16x8x128xf32>)
// CHECK: flow.return %[[GEN]]
@@ -57,18 +57,21 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface", [DestinationS
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return $_op.getNumParallelLoops();
+ return llvm::count($_op.getLoopIteratorTypes(), utils::IteratorType::parallel);
}]
>,

InterfaceMethod<
/*desc=*/[{
- Return the number of loops.
+ Return the total number of loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumLoops",
/*args=*/(ins),
/*methodBody=*/""
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getLoopIteratorTypes().size();
}]
>,
InterfaceMethod<
/*desc=*/[{
@@ -77,11 +80,20 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface", [DestinationS
/*retTy=*/"SmallVector<int64_t, 4>",
/*methodName=*/"getStaticLoopRanges",
/*args=*/(ins),
/*methodBody=*/""
/*methodBody=*/"",
/*defaultImplementation=*/[{
SmallVector<int64_t, 4> loopRanges;
llvm::for_each($_op.getOperands(), [&](Value operand) {
if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
llvm::append_range(loopRanges, shapedType.getShape());
}
});
return loopRanges;
}]
>,
InterfaceMethod<
/*desc=*/[{
- Return the indexing map for a `result`.
+ Return the indexing map for an op's `result`.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getIndexingMapMatchingResult",
@@ -112,12 +124,23 @@ def LinalgFusionInterface : OpInterface<"LinalgFusionOpInterface", [DestinationS
>,
InterfaceMethod<
/*desc=*/[{
- Return the indexing maps attribute within the current operation.
+ Return the indexing maps for this op.
}],
/*retTy=*/"ArrayAttr",
/*methodName=*/"getIndexingMaps",
/*args=*/(ins),
/*methodBody=*/""
/*methodBody=*/"",
/*defaultImplementation=*/[{
Builder builder($_op.getContext());
SmallVector<AffineMap> maps;
llvm::append_range(maps, llvm::map_range($_op.getOperands(), [&](Value operand) {
// All inputs must be shaped
assert(llvm::isa<ShapedType>(operand.getType()) && "expected ShapedType operand");
return builder.getMultiDimIdentityMap(
cast<ShapedType>(operand.getType()).getRank());
}));
return builder.getAffineMapArrayAttr(maps);
}]
>,

];
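With the default implementations above, a pass can query the loop structure of any op implementing the interface without op-specific code. A minimal usage sketch (helper name hypothetical, enclosing namespaces elided):

static bool hasOnlyParallelLoops(Operation *op) {
  auto fusionOp = dyn_cast<LinalgExt::LinalgFusionOpInterface>(op);
  if (!fusionOp)
    return false;
  // Both counts come from the tblgen defaultImplementation bodies unless
  // the op overrides them.
  return fusionOp.getNumParallelLoops() == fusionOp.getNumLoops();
}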
32 changes: 0 additions & 32 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -250,38 +250,6 @@ ScatterOp::reifyResultShapes(OpBuilder &b,
.reifyResultShapes(b, reifiedReturnShapes);
}

// TODO: Could be implemented for all ops in tblgen
unsigned ScatterOp::getNumParallelLoops() {
return llvm::count(getLoopIteratorTypes(), utils::IteratorType::parallel);
}

// TODO: Could be implemented for all ops in tblgen
unsigned ScatterOp::getNumLoops() { return getLoopIteratorTypes().size(); }

SmallVector<int64_t, 4> ScatterOp::getStaticLoopRanges() {
// TODO: remove stub implementation.
SmallVector<int64_t, 4> loopRanges;
for (auto operand : getOperands()) {
if (auto shapedType = dyn_cast<ShapedType>(operand.getType())) {
auto shape = shapedType.getShape();
llvm::append_range(loopRanges, shape);
}
}
return loopRanges;
}

ArrayAttr ScatterOp::getIndexingMaps() {
// TODO: remove stub implementation.
Builder builder(getContext());
int64_t updateRank = getUpdateType().getRank();
int64_t indicesRank = getIndicesType().getRank();
int64_t originalRank = getOriginalType().getRank();
return builder.getAffineMapArrayAttr(
SmallVector<AffineMap>({builder.getMultiDimIdentityMap(updateRank),
builder.getMultiDimIdentityMap(indicesRank),
builder.getMultiDimIdentityMap(originalRank)}));
}

//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
@@ -78,12 +78,7 @@ let opDocGroup = OpGroupNonStructuredOps in {

def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
- DeclareOpInterfaceMethods<LinalgFusionInterface,
- ["getNumParallelLoops",
- "getNumLoops",
- "getStaticLoopRanges",
- "getIndexingMaps",
- ]>,
+ DeclareOpInterfaceMethods<LinalgFusionInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
@@ -375,7 +370,7 @@ def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",

def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
- DeclareOpInterfaceMethods<
+ DeclareOpInterfaceMethods<
TilingInterface,
["generateScalarImplementation",
"getIterationDomain",
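Dropping the explicit method list from DeclareOpInterfaceMethods is what makes the .cpp deletions above safe: naming methods in that list forces tblgen to re-declare them on the op, which must then define them out of line, while omitting the list lets the interface defaults apply. A minimal sketch of the effect (function name hypothetical, namespaces elided):

// Compiles even though ScatterOp::getIndexingMaps and ScatterOp::getNumLoops
// no longer exist in LinalgExtOps.cpp; both resolve to the tblgen defaults.
void inspectScatter(LinalgExt::ScatterOp scatterOp) {
  auto fusionOp = cast<LinalgExt::LinalgFusionOpInterface>(scatterOp.getOperation());
  ArrayAttr maps = fusionOp.getIndexingMaps(); // one identity map per operand
  unsigned loops = fusionOp.getNumLoops();     // total iterator count
  (void)maps;
  (void)loops;
}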
