Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Apr 12, 2024
1 parent 236b1c9 commit 974f4d1
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 9 deletions.
22 changes: 15 additions & 7 deletions compiler/lib/Dialect/mhlo/Transforms/HloAggressiveFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,23 @@ bool isCustomMhloRngUniformOp(Operation *op) {
return false;
}

// Returns true if `op` merely re-interprets its operand's data rather than
// computing anything new: an mhlo.reshape, or an mhlo.slice that the helper
// classifies as a strideless subview of its input.
bool isAliasLikeOp(Operation *op) {
  // Reshape never materializes new values; always alias-like.
  if (llvm::isa<mhlo::ReshapeOp>(op))
    return true;
  // A slice is alias-like only when it forms a no-stride subview.
  auto slice = llvm::dyn_cast_if_present<mhlo::SliceOp>(op);
  return slice && isSliceNoStrideSubview(slice);
}

// Returns true if `op` may participate in aggressive fusion.
// Custom rng-uniform ops are always admitted; alias-like ops
// (reshape / no-stride slice) and mhlo.custom_call are rejected;
// any other mhlo op qualifies.
//
// NOTE: the flattened diff left the old `return isMhlo(op) && ...` line in
// place ahead of the new checks, making them unreachable — removed here.
bool isFusibleCandidate(Operation *op) {
  if (isCustomMhloRngUniformOp(op))
    return true;
  // Alias-like ops are deliberately kept out of fusion clusters;
  // presumably they lower to buffer views downstream (see the AliasLike
  // E2E test, where the slice/reshape become subview/expand_shape).
  if (isAliasLikeOp(op))
    return false;
  // custom_call has opaque semantics; never fuse it.
  if (llvm::isa<mhlo::CustomCallOp>(op))
    return false;
  return isMhlo(op);
}

bool isFusibleStart(Operation *) { return true; }
Expand All @@ -55,12 +68,7 @@ bool isFusibleTrigger(Operation *) { return true; }

bool isFusibleWith(Operation *, Operation *) { return true; }

// Any candidate op may stand alone as its own single-op fusion. The old
// reshape exclusion is no longer needed because alias-like ops (including
// reshape) are already filtered out by isFusibleCandidate.
//
// NOTE: the flattened diff kept both the old multi-line definition and the
// new one-liner, i.e. a redefinition of the same symbol — only the
// post-commit definition is kept here.
bool isValidSingleOp(Operation *op) { return true; }

bool isValidFusionPattern(const MhloFusionPattern &) { return true; }

Expand Down
2 changes: 1 addition & 1 deletion compiler/lib/Dialect/mhlo/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ bool mlir::isSplatMhloConstantValue(Value val, double splat_val) {
return isSplatMhloConstantValue(val.getDefiningOp(), splat_val);
}

bool isSliceSubviewWithoutStride(mhlo::SliceOp op) {
bool mlir::isSliceNoStrideSubview(mhlo::SliceOp op) {
auto type = cast<RankedTensorType>(op.getOperand().getType());
if (!type.hasStaticShape()) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func.func @reshape_add(%arg0: tensor<2xf32>, %arg1: tensor<2x1xf32>) -> (tensor<
return %1 : tensor<2x1xf32>
}
// CHECK-LABEL: func.func @reshape_add
// CHECK-NEXT: mhlo.reshape
// CHECK-NEXT: mhlo.fusion
// CHECK-NEXT: mhlo.reshape
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
// CHECK: {__byteir_hlo_aggressive_fusion__}
Expand Down
10 changes: 10 additions & 0 deletions compiler/test/Pipelines/Host/E2E/AliasLike/00_Input.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: byteir-opt %s --hlo-opt="target=CPU" --linalg-tensor-opt="target=CPU" --byre-tensor-opt="entry-func=main append-arg-types" --byteir-bufferize-opt --scf-opt="target=CPU" | FileCheck %s

// CHECK-LABEL: func.func @main

// Alias-like prologue: a unit-stride slice (rows 10..137 of %arg0) feeds a
// reshape, whose result is added elementwise to %arg1.
func.func @main(%arg0: tensor<512x200xf32>, %arg1: tensor<128x2x100xf32>) -> tensor<128x2x100xf32> {
// 128x200 window of %arg0; strides are all 1, so this is a pure subview.
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[138, 200]> : tensor<2xi64>, start_indices = dense<[10, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<512x200xf32>) -> tensor<128x200xf32>
%1 = mhlo.reshape %0 : (tensor<128x200xf32>) -> tensor<128x2x100xf32>
%2 = mhlo.add %1, %arg1 : tensor<128x2x100xf32>
return %2 : tensor<128x2x100xf32>
}
27 changes: 27 additions & 0 deletions compiler/test/Pipelines/Host/E2E/AliasLike/01_HostOpt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Expected host-opt output: the slice+reshape stay in @main as a
// subview/expand_shape (alias-like, unfused), while the add is outlined
// into the fused kernel @Unknown0.
module {
// Elementwise add over two 128x2x100 buffers, flattened to 25600 elements
// and written into a fresh allocation.
func.func private @Unknown0(%arg0: memref<128x2x100xf32>, %arg1: memref<128x2x100xf32>) -> memref<128x2x100xf32> attributes {__byteir_hlo_aggressive_fusion__} {
%c1 = arith.constant 1 : index
%c25600 = arith.constant 25600 : index
%c0 = arith.constant 0 : index
%alloc = memref.alloc() : memref<128x2x100xf32>
// Collapse all three dims so the loop below can index linearly.
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<128x2x100xf32> into memref<25600xf32>
%collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2]] : memref<128x2x100xf32> into memref<25600xf32>
%collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2]] : memref<128x2x100xf32> into memref<25600xf32>
scf.for %arg2 = %c0 to %c25600 step %c1 {
%0 = memref.load %collapse_shape[%arg2] : memref<25600xf32>
%1 = memref.load %collapse_shape_0[%arg2] : memref<25600xf32>
%2 = arith.addf %0, %1 : f32
memref.store %2, %collapse_shape_1[%arg2] : memref<25600xf32>
}
return %alloc : memref<128x2x100xf32>
}
func.func @main(%arg0: memref<512x200xf32>, %arg1: memref<128x2x100xf32>) -> memref<128x2x100xf32> attributes {__placeholder__byre.entry_point} {
// The mhlo slice/reshape lowered to a strided subview plus expand_shape —
// no computation, just re-indexing of %arg0's buffer.
%subview = memref.subview %arg0[10, 0] [128, 200] [1, 1] : memref<512x200xf32> to memref<128x200xf32, strided<[200, 1], offset: 2000>>
%expand_shape = memref.expand_shape %subview [[0], [1, 2]] : memref<128x200xf32, strided<[200, 1], offset: 2000>> into memref<128x2x100xf32, strided<[200, 100, 1], offset: 2000>>
// Copy into a contiguous buffer before calling the fused kernel.
%alloc = memref.alloc() : memref<128x2x100xf32>
memref.copy %expand_shape, %alloc : memref<128x2x100xf32, strided<[200, 100, 1], offset: 2000>> to memref<128x2x100xf32>
%0 = call @Unknown0(%alloc, %arg1) : (memref<128x2x100xf32>, memref<128x2x100xf32>) -> memref<128x2x100xf32>
return %0 : memref<128x2x100xf32>
}
}

8 changes: 8 additions & 0 deletions compiler/test/Pipelines/Host/E2E/AliasLike/TotalPipeline.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

// Same slice -> reshape -> add graph as 00_Input.mlir, used as the
// whole-pipeline (TotalPipeline) entry point.
func.func @main(%arg0: tensor<512x200xf32>, %arg1: tensor<128x2x100xf32>) -> tensor<128x2x100xf32> {
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[138, 200]> : tensor<2xi64>, start_indices = dense<[10, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<512x200xf32>) -> tensor<128x200xf32>
%1 = mhlo.reshape %0 : (tensor<128x200xf32>) -> tensor<128x2x100xf32>
%2 = mhlo.add %1, %arg1 : tensor<128x2x100xf32>
return %2 : tensor<128x2x100xf32>
}

33 changes: 33 additions & 0 deletions compiler/test/Pipelines/Host/E2E/AliasLike/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# E2E test-case template for the AliasLike scenario: a slice -> reshape -> add
# graph, compiled through each host pipeline stage and checked with FileCheck.
Testcase(
contents=[Content(stages=(Input, E2E), content=r"""
func.func @main(%arg0: tensor<512x200xf32>, %arg1: tensor<128x2x100xf32>) -> tensor<128x2x100xf32> {
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[138, 200]> : tensor<2xi64>, start_indices = dense<[10, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<512x200xf32>) -> tensor<128x200xf32>
%1 = mhlo.reshape %0 : (tensor<128x200xf32>) -> tensor<128x2x100xf32>
%2 = mhlo.add %1, %arg1 : tensor<128x2x100xf32>
return %2 : tensor<128x2x100xf32>
}
""")],
# One FileCheck pattern per pipeline stage; each raw string holds the CHECK
# lines applied to that stage's output.
pipelines=[
InputPipeline(r"""
// CHECK-LABEL: func.func @main
"""),
HostOptPipeline(r"""
// CHECK-LABEL: func.func @Unknown
"""),
ToLLVMPipeline(r"""
// CHECK: llvm.func
"""),
ToLLVMIRPipeline(r"""
// CHECK-LABEL: define void @_mlir_ciface_Unknown
"""),
ByreHostPipeline(r"""
// CHECK-LABEL: func.func @main
"""),
TotalPipeline(r"""
// CHECK-LABEL: define void @_mlir_ciface_Unknown
"""),
ByreOutPipeline(r"""
// CHECK-LABEL: func.func @main
"""),
]
)

0 comments on commit 974f4d1

Please sign in to comment.