diff --git a/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp b/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp
index 14257cdfe..34973f04e 100644
--- a/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp
+++ b/compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp
@@ -24,6 +24,16 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+// Some code comes from
+// compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
+// of the IREE project.
+// Original license:
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//===----------------------------------------------------------------------===//
 
 #include "byteir/Dialect/Linalg/Transforms/LinalgPromotion.h"
 #include "byteir/Dialect/GPU/Transforms/Utils.h"
@@ -125,11 +135,14 @@ LogicalResult copyWorkgroupMemoryToGlobalMemory(OpBuilder &b, Value src,
   OpBuilder::InsertionGuard guard(b);
   auto op = src.getDefiningOp();
+  // Get the single scf.for op inside the scf.forall op.
   scf::ForallOp forallOp = op->getParentOfType<scf::ForallOp>();
 
-  // copyWorkgroupMemoryToGlobalMemory before the GPU kernel end.
-  Operation *terminator = forallOp.getBody()->getTerminator();
-  b.setInsertionPoint(terminator);
+  auto forOps = llvm::to_vector(forallOp.getOps<scf::ForOp>());
+  if (forOps.size() != 1)
+    return forallOp.emitError("expected a single scf.for op");
+  // copyWorkgroupMemoryToGlobalMemory after the gemm computation ends.
+  b.setInsertionPointAfter(forOps[0]);
 
   Operation *copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
   setLinalgTransformationMarker(copyOp,
                                 getCopyRelatedToWorkgroupMemoryMarker());
@@ -167,6 +180,81 @@ static LogicalResult promotionImpl(OpBuilder &builder, Operation *op) {
   return success();
 }
 
+// Split the input/output operand that is fed by a copy from shared memory
+// into a separate input operand of the generic op.
+static void insertInputValueIntoGeneric(Value source,
+                                        linalg::GenericOp genericOp) {
+  Location loc = genericOp.getLoc();
+  SmallVector<Value> inputOperands;
+  SmallVector<AffineMap> operandMaps;
+
+  // Get and add existing input operands and their corresponding indexing maps.
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    inputOperands.push_back(inputOperand->get());
+    operandMaps.push_back(genericOp.getMatchingIndexingMap(inputOperand));
+  }
+
+  // Add the new input operand.
+  inputOperands.push_back(source);
+
+  // Ensure there is only one output operand.
+  assert(genericOp.getNumDpsInits() == 1);
+  OpOperand *outputOperand = genericOp.getDpsInitOperand(0);
+
+  // Reuse the output indexing map for the new input and for the output.
+  operandMaps.push_back(genericOp.getMatchingIndexingMap(outputOperand));
+  operandMaps.push_back(genericOp.getMatchingIndexingMap(outputOperand));
+
+  SmallVector<utils::IteratorType> iteratorTypes(genericOp.getNumLoops(),
+                                                 utils::IteratorType::parallel);
+
+  OpBuilder builder(genericOp);
+
+  // Create a new GenericOp.
+  auto newGenericOp = builder.create<linalg::GenericOp>(
+      loc, inputOperands, outputOperand->get(), operandMaps, iteratorTypes);
+
+  // Move the original operation's blocks to the new operation.
+  newGenericOp.getRegion().getBlocks().splice(
+      newGenericOp.getRegion().begin(), genericOp.getRegion().getBlocks());
+
+  // Add a new argument to the payload.
+  Block &payload = newGenericOp.getRegion().front();
+  payload.addArgument(payload.getArguments().back().getType(), loc);
+
+  // Set the Linalg transformation marker.
+  setLinalgTransformationMarker(newGenericOp,
+                                getCopyRelatedToWorkgroupMemoryMarker());
+}
+
+/// Propagate the shared memory copy into the consumer op if it's a fully
+/// parallel linalg.generic.
+static bool
+propagateCopySourceIntoConsumerGeneric(linalg::CopyOp copyOp,
+                                       SmallVector<Operation *> &toDelete) {
+  // Look for a linalg.generic op reading the copyOp target.
+  Operation *nextOp = copyOp->getNextNode();
+  while (nextOp) {
+    if (isMemoryEffectFree(nextOp)) {
+      nextOp = nextOp->getNextNode();
+      continue;
+    }
+    auto consumer = dyn_cast<linalg::GenericOp>(nextOp);
+    if (!consumer || consumer.getNumDpsInits() != 1 ||
+        !consumer.getMatchingIndexingMap(consumer.getDpsInitOperand(0))
+             .isIdentity())
+      break;
+    auto linalgCopyTarget = copyOp.getDpsInitOperand(0)->get();
+    auto linalgCopySource = copyOp.getDpsInputOperand(0)->get();
+    if (*consumer.getOutputs().begin() != linalgCopyTarget)
+      break;
+    insertInputValueIntoGeneric(linalgCopySource, consumer);
+    toDelete.push_back(consumer);
+    return true;
+  }
+  return false;
+}
+
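+// Illustrative sketch of what the two helpers above do, abridged from the
+// test case added in this change (types shortened, markers and indexing maps
+// omitted). Before:
+//   linalg.copy ins(%alloca : memref<..., #gpu.address_space<workgroup>>)
+//               outs(%subview : memref<..., strided<...>>)
+//   linalg.generic {...} outs(%subview : ...) {
+//   ^bb0(%out: f16):
+//     %0 = arith.maximumf %out, %cst : f16
+//     linalg.yield %0 : f16
+//   }
+// After (the caller erases the copy and the old generic):
+//   linalg.generic {...} ins(%alloca : ...) outs(%subview : ...) {
+//   ^bb0(%in: f16, %out: f16):
+//     %0 = arith.maximumf %in, %cst : f16
+//     linalg.yield %0 : f16
+//   }
+// Note that the spliced payload keeps its original argument, which now reads
+// the new input %in; the appended trailing argument %out is left unused.
+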
 struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
 public:
   LinalgPromotionPass() = default;
@@ -187,22 +275,40 @@ struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
       if (isa<linalg::MatmulOp>(linalgOp))
         toPromote.push_back(linalgOp);
     });
+    if (toPromote.empty())
+      return;
+
+    assert(toPromote.size() == 1);
+    auto linalgContractOp = toPromote[0];
+    OpBuilder builder(linalgContractOp);
 
-    for (auto linalgOp : toPromote) {
-      OpBuilder builder(linalgOp);
-
-      // As we want to mark every generated op, so we do promote seperately.
-      (void)promotionImpl<MatmulOperands::A>(builder, linalgOp);
-      (void)promotionImpl<MatmulOperands::B>(builder, linalgOp);
-
-      // TODO:
-      // If we do promotion before we split K, it will be much easier.
-      // The right order should be split i, j, promote C, split k, promote A\B
-      // As we know linalg.matmul is in a scf.for, and the subview promotionImpl
-      // inserts should be in the scf.forall op.
-      auto forOp = linalgOp->getParentOfType<scf::ForOp>();
-      builder.setInsertionPoint(forOp); // before forOp
-      (void)promotionImpl<MatmulOperands::C>(builder, linalgOp);
+    // Since we want to mark every generated op, we promote separately.
+    (void)promotionImpl<MatmulOperands::A>(builder, linalgContractOp);
+    (void)promotionImpl<MatmulOperands::B>(builder, linalgContractOp);
+
+    // TODO:
+    // It would be much easier to do promotion before we split K.
+    // The right order should be: split i, j; promote C; split k; promote A/B.
+    // Since linalg.matmul is inside an scf.for, the subviews promotionImpl
+    // inserts should be placed in the scf.forall op.
+    auto forOp = linalgContractOp->getParentOfType<scf::ForOp>();
+    builder.setInsertionPoint(forOp); // before forOp
+    (void)promotionImpl<MatmulOperands::C>(builder, linalgContractOp);
+
+    // The linalg.copy should be fused with its consumer linalg.generic,
+    // so first find the linalg.copy that carries the
+    // "__byteir_store_matrix_c__" marker.
+    linalg::CopyOp copyToGlobalOp;
+    forallOp.walk([&](linalg::CopyOp copyOp) {
+      if (hasMarker(copyOp, copyMarker[MatmulOperands::C])) {
+        copyToGlobalOp = copyOp;
+      }
+    });
+    SmallVector<Operation *> toDelete;
+    if (propagateCopySourceIntoConsumerGeneric(copyToGlobalOp, toDelete)) {
+      toDelete.push_back(copyToGlobalOp);
+      for (Operation *op : toDelete)
+        op->erase();
     }
   }
 };
diff --git a/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir b/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir
new file mode 100644
index 000000000..7d2443d24
--- /dev/null
+++ b/compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir
@@ -0,0 +1,61 @@
+// RUN: byteir-opt -linalg-promotion --cse --canonicalize %s | FileCheck %s
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func private @Unknown0(%arg0: memref<5376x2048xf16>, %arg1: memref<2048x5376xf16>) -> memref<5376x5376xf16> attributes {__byteir_gemm_block_size__ = [64, 2, 1], __byteir_gemm_pipeline_depth__ = 3 : i64, __byteir_gemm_tile_config__ = [128, 128, 32], __byteir_matmul_epilogue_fusion__} {
+    %cst = arith.constant 0.000000e+00 : f16
+    %c0 = arith.constant 0 : index
+    %c2048 = arith.constant 2048 : index
+    %c32 = arith.constant 32 : index
+    %alloc = memref.alloc() : memref<5376x5376xf16>
+    scf.forall (%arg2, %arg3) in (42, 42) {
+      %0 = affine.apply #map(%arg2)
+      %1 = affine.apply #map(%arg3)
+      %subview = memref.subview %alloc[%0, %1] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>>
+      linalg.fill ins(%cst : f16) outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>)
+      scf.for %arg4 = %c0 to %c2048 step %c32 {
+        %subview_0 = memref.subview %arg0[%0, %arg4] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>>
+        %subview_1 = memref.subview %arg1[%arg4, %1] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>>
+        linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%subview_0, %subview_1 : memref<128x32xf16, strided<[2048, 1], offset: ?>>, memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>)
+      }
+      linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) {
+      ^bb0(%out: f16):
+        %6 = arith.maximumf %out, %cst : f16
+        linalg.yield %6 : f16
+      }
+    } {mapping = [#gpu.block<x>, #gpu.block<y>]}
+    return %alloc : memref<5376x5376xf16>
+  }
+}
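+
+// The pass is expected to promote the A/B tiles and the C accumulator into
+// workgroup-memory allocas, keep the accumulator alloca outside the k-loop,
+// and fold the copy back to global memory into the trailing ReLU generic, as
+// the CHECK lines below verify.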
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-NEXT: module {
+// CHECK-NEXT:   func.func private @Unknown0(%[[ARG0:.*]]: memref<5376x2048xf16>, %[[ARG1:.*]]: memref<2048x5376xf16>) -> memref<5376x5376xf16> attributes {__byteir_gemm_block_size__ = [64, 2, 1], __byteir_gemm_pipeline_depth__ = 3 : i64, __byteir_gemm_tile_config__ = [128, 128, 32], __byteir_matmul_epilogue_fusion__} {
+// CHECK-NEXT:     %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+// CHECK-NEXT:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT:     %[[C2048:.*]] = arith.constant 2048 : index
+// CHECK-NEXT:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK-NEXT:     %[[ALLOC:.*]] = memref.alloc() : memref<5376x5376xf16>
+// CHECK-NEXT:     scf.forall (%[[ARG2:.*]], %[[ARG3:.*]]) in (42, 42) {
+// CHECK-NEXT:       %[[ALLOCA:.*]] = memref.alloca() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space<workgroup>>
+// CHECK-NEXT:       %[[ALLOCA_0:.*]] = memref.alloca() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space<workgroup>>
+// CHECK-NEXT:       %[[ALLOCA_1:.*]] = memref.alloca() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space<workgroup>>
+// CHECK-NEXT:       %[[APPLY_MAP0:.*]] = affine.apply #[[MAP]](%[[ARG2]])
+// CHECK-NEXT:       %[[APPLY_MAP1:.*]] = affine.apply #[[MAP]](%[[ARG3]])
+// CHECK-NEXT:       %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[APPLY_MAP0]], %[[APPLY_MAP1]]] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>>
+// CHECK-NEXT:       linalg.fill ins(%[[CST]] : f16) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space<workgroup>>)
+// CHECK-NEXT:       scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2048]] step %[[C32]] {
+// CHECK-NEXT:         %[[SUBVIEW_2:.*]] = memref.subview %[[ARG0]][%[[APPLY_MAP0]], %[[ARG4]]] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>>
+// CHECK-NEXT:         %[[SUBVIEW_3:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[APPLY_MAP1]]] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>>
+// CHECK-NEXT:         linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_2]] : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%[[ALLOCA_1]] : memref<128x32xf16, #gpu.address_space<workgroup>>)
+// CHECK-NEXT:         linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_3]] : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%[[ALLOCA_0]] : memref<32x128xf16, #gpu.address_space<workgroup>>)
+// CHECK-NEXT:         linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%[[ALLOCA_1]], %[[ALLOCA_0]] : memref<128x32xf16, #gpu.address_space<workgroup>>, memref<32x128xf16, #gpu.address_space<workgroup>>) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space<workgroup>>)
+// CHECK-NEXT:       }
+// CHECK-NEXT:       linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space<workgroup>>) outs(%[[SUBVIEW]] : memref<128x128xf16, strided<[5376, 1], offset: ?>>) attrs = {__internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} {
+// CHECK-NEXT:       ^bb0(%in: f16, %out: f16):
+// CHECK-NEXT:         %2 = arith.maximumf %in, %cst : f16
+// CHECK-NEXT:         linalg.yield %2 : f16
+// CHECK-NEXT:       }
+// CHECK-NEXT:     } {mapping = [#gpu.block<x>, #gpu.block<y>]}
+// CHECK-NEXT:     return %[[ALLOC]] : memref<5376x5376xf16>
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
\ No newline at end of file