
Commit

Merge branch 'main' into vector-transfer
XG-zheng committed Jun 19, 2024
2 parents 400418b + d81a078 commit fdf8af8
Showing 2 changed files with 185 additions and 18 deletions.
142 changes: 124 additions & 18 deletions compiler/lib/Dialect/Linalg/Transforms/LinalgPromotion.cpp
@@ -24,6 +24,16 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Some code comes from
// compiler/src/iree/compiler/Codegen/Utils/GPUUtils.cpp
// of the IREE project.
// Original license:
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//

#include "byteir/Dialect/Linalg/Transforms/LinalgPromotion.h"
#include "byteir/Dialect/GPU/Transforms/Utils.h"
@@ -125,11 +135,14 @@ LogicalResult copyWorkgroupMemoryToGlobalMemory(OpBuilder &b, Value src,
OpBuilder::InsertionGuard guard(b);

auto op = src.getDefiningOp();
// get the only scf.for op inside the scf.forall op.
scf::ForallOp forallOp = op->getParentOfType<scf::ForallOp>();
// copyWorkgroupMemoryToGlobalMemory before the GPU kernel ends.
Operation *terminator = forallOp.getBody()->getTerminator();
b.setInsertionPoint(terminator);
auto forOps = llvm::to_vector(forallOp.getOps<scf::ForOp>());
if (forOps.size() != 1)
return forallOp.emitError("expected a single scf.for op");

// copyWorkgroupMemoryToGlobalMemory after the GEMM computation ends.
b.setInsertionPointAfter(forOps[0]);
Operation *copyOp = b.create<linalg::CopyOp>(src.getLoc(), src, dst);
setLinalgTransformationMarker(copyOp,
getCopyRelatedToWorkgroupMemoryMarker());
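// With this change the generated linalg.copy lands right after the single
// scf.for, instead of just before the forall terminator. A minimal sketch of
// the assumed intermediate IR (names mirror the test below; the copy is later
// fused into the epilogue generic):
//
//   scf.forall (%arg2, %arg3) in (42, 42) {
//     %alloca = memref.alloca() {__byteir_alloca_accumulator__}
//         : memref<128x128xf16, #gpu.address_space<workgroup>>
//     %subview = memref.subview %alloc[...] [128, 128] [1, 1]
//         : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>>
//     scf.for %arg4 = %c0 to %c2048 step %c32 {
//       linalg.matmul ins(...) outs(%alloca : ...)
//     }
//     linalg.copy ins(%alloca : ...) outs(%subview : ...)   // <- inserted here
//   } {mapping = [#gpu.block<y>, #gpu.block<x>]}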
@@ -167,6 +180,81 @@ static LogicalResult promotionImpl(OpBuilder &builder, Operation *op) {
return success();
}

// Split the input/output operand fed by the copy from shared memory into a
// separate input of the generic op.
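// E.g. (a sketch mirroring the relu epilogue in the test below; %alloca is
// the promoted workgroup-memory accumulator, %subview the global output, and
// #map1 = affine_map<(d0, d1) -> (d0, d1)>):
//
//   linalg.generic {indexing_maps = [#map1],
//                   iterator_types = ["parallel", "parallel"]}
//       outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) {
//   ^bb0(%out: f16):
//     %0 = arith.maximumf %out, %cst : f16
//     linalg.yield %0 : f16
//   }
//
// becomes, once %alloca is inserted as an explicit input (the spliced payload
// argument now maps to the input and a fresh output argument is appended):
//
//   linalg.generic {indexing_maps = [#map1, #map1],
//                   iterator_types = ["parallel", "parallel"]}
//       ins(%alloca : memref<128x128xf16, #gpu.address_space<workgroup>>)
//       outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) {
//   ^bb0(%in: f16, %out: f16):
//     %0 = arith.maximumf %in, %cst : f16
//     linalg.yield %0 : f16
//   }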
static void insertInputValueIntoGeneric(Value source,
linalg::GenericOp genericOp) {
Location loc = genericOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<AffineMap> operandMaps;

// Get and add existing input operands and their corresponding indexing maps.
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
inputOperands.push_back(inputOperand->get());
operandMaps.push_back(genericOp.getMatchingIndexingMap(inputOperand));
}

// Add the new input operand.
inputOperands.push_back(source);

// Ensure there is only one output operand.
assert(genericOp.getNumDpsInits() == 1);
OpOperand *outputOperand = genericOp.getDpsInitOperand(0);

// Reuse the output operand's indexing map for both the new input and the output.
operandMaps.push_back(genericOp.getMatchingIndexingMap(outputOperand));
operandMaps.push_back(genericOp.getMatchingIndexingMap(outputOperand));

SmallVector<utils::IteratorType> iteratorTypes(genericOp.getNumLoops(),
utils::IteratorType::parallel);

OpBuilder builder(genericOp);

// Create a new GenericOp.
auto newGenericOp = builder.create<linalg::GenericOp>(
loc, inputOperands, outputOperand->get(), operandMaps, iteratorTypes);

// Move the original operation's blocks to the new operation.
newGenericOp.getRegion().getBlocks().splice(
newGenericOp.getRegion().begin(), genericOp.getRegion().getBlocks());

// Add a new argument to the payload.
Block &payload = newGenericOp.getRegion().front();
payload.addArgument(payload.getArguments().back().getType(), loc);

// Set the Linalg transformation marker.
setLinalgTransformationMarker(newGenericOp,
getCopyRelatedToWorkgroupMemoryMarker());
}

/// Propagate the shared memory copy into the consumer op if it's a fully
/// parallel linalg.generic.
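/// E.g. (assumed intermediate IR after promoting operand C; attribute names
/// are illustrative):
///
///   linalg.copy {__byteir_store_matrix_c__}
///       ins(%alloca : memref<128x128xf16, #gpu.address_space<workgroup>>)
///       outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>)
///   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
///                   iterator_types = ["parallel", "parallel"]}
///       outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) { ... }
///
/// When the consumer qualifies, a new generic that reads %alloca directly is
/// built via insertInputValueIntoGeneric above, the old generic is queued in
/// toDelete, and the caller erases the copy as well.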
static bool
propagateCopySourceIntoConsumerGeneric(linalg::CopyOp copyOp,
SmallVector<Operation *> &toDelete) {
// Look for a generic Op reading the copyOp target.
Operation *nextOp = copyOp->getNextNode();
while (nextOp) {
if (isMemoryEffectFree(nextOp)) {
nextOp = nextOp->getNextNode();
continue;
}
auto consumer = dyn_cast<linalg::GenericOp>(nextOp);
if (!consumer || consumer.getNumDpsInits() != 1 ||
!consumer.getMatchingIndexingMap(consumer.getDpsInitOperand(0))
.isIdentity())
break;
auto linalgCopyTarget = copyOp.getDpsInitOperand(0)->get();
auto linalgCopySource = copyOp.getDpsInputOperand(0)->get();
if (*consumer.getOutputs().begin() != linalgCopyTarget)
break;
insertInputValueIntoGeneric(linalgCopySource, consumer);
toDelete.push_back(consumer);
return true;
}
return false;
}

struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
public:
LinalgPromotionPass() = default;
@@ -187,22 +275,40 @@ struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
if (isa<linalg::MatmulOp, linalg::BatchMatmulOp>(linalgOp))
toPromote.push_back(linalgOp);
});
if (toPromote.empty())
return;

assert(toPromote.size() == 1);
auto linalgContractOp = toPromote[0];
OpBuilder builder(linalgContractOp);

for (auto linalgOp : toPromote) {
OpBuilder builder(linalgOp);

// Since we want to mark every generated op, we promote each operand separately.
(void)promotionImpl<MatmulOperands::A>(builder, linalgOp);
(void)promotionImpl<MatmulOperands::B>(builder, linalgOp);

// TODO:
// Promotion would be much easier if done before splitting K.
// The right order would be: split i and j, promote C, split k, promote A/B.
// For now, linalg.matmul sits inside an scf.for, while the subviews that
// promotionImpl inserts must live in the enclosing scf.forall op.
auto forOp = linalgOp->getParentOfType<scf::ForOp>();
builder.setInsertionPoint(forOp); // before forOp
(void)promotionImpl<MatmulOperands::C>(builder, linalgOp);
// Since we want to mark every generated op, we promote each operand separately.
(void)promotionImpl<MatmulOperands::A>(builder, linalgContractOp);
(void)promotionImpl<MatmulOperands::B>(builder, linalgContractOp);

// TODO:
// Promotion would be much easier if done before splitting K.
// The right order would be: split i and j, promote C, split k, promote A/B.
// For now, linalg.matmul sits inside an scf.for, while the subviews that
// promotionImpl inserts must live in the enclosing scf.forall op.
auto forOp = linalgContractOp->getParentOfType<scf::ForOp>();
builder.setInsertionPoint(forOp); // before forOp
(void)promotionImpl<MatmulOperands::C>(builder, linalgContractOp);

// The linalg.copy should be fused with its consumer linalg.generic.
// So first find the linalg.copy that carries the marker
// "__byteir_store_matrix_c__".
linalg::CopyOp copyToGlobalOp;
forallOp.walk([&](linalg::CopyOp copyOp) {
if (hasMarker(copyOp, copyMarker[MatmulOperands::C])) {
copyToGlobalOp = copyOp;
}
});
SmallVector<Operation *> toDelete;
if (propagateCopySourceIntoConsumerGeneric(copyToGlobalOp, toDelete)) {
toDelete.push_back(copyToGlobalOp);
for (Operation *op : toDelete)
op->erase();
}
}
};
61 changes: 61 additions & 0 deletions compiler/test/Dialect/Linalg/linalg-promotion-epilogue-fusion.mlir
@@ -0,0 +1,61 @@
// RUN: byteir-opt -linalg-promotion --cse --canonicalize %s | FileCheck %s
#map = affine_map<(d0) -> (d0 * 128)>
module {
func.func private @Unknown0(%arg0: memref<5376x2048xf16>, %arg1: memref<2048x5376xf16>) -> memref<5376x5376xf16> attributes {__byteir_gemm_block_size__ = [64, 2, 1], __byteir_gemm_pipeline_depth__ = 3 : i64, __byteir_gemm_tile_config__ = [128, 128, 32], __byteir_matmul_epilogue_fusion__} {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%c2048 = arith.constant 2048 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<5376x5376xf16>
scf.forall (%arg2, %arg3) in (42, 42) {
%0 = affine.apply #map(%arg2)
%1 = affine.apply #map(%arg3)
%subview = memref.subview %alloc[%0, %1] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>>
linalg.fill ins(%cst : f16) outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>)
scf.for %arg4 = %c0 to %c2048 step %c32 {
%subview_0 = memref.subview %arg0[%0, %arg4] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>>
%subview_1 = memref.subview %arg1[%arg4, %1] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>>
linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%subview_0, %subview_1 : memref<128x32xf16, strided<[2048, 1], offset: ?>>, memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>)
}
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%subview : memref<128x128xf16, strided<[5376, 1], offset: ?>>) {
^bb0(%out: f16):
%6 = arith.maximumf %out, %cst : f16
linalg.yield %6 : f16
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
return %alloc : memref<5376x5376xf16>
}
}
// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 128)>
// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func private @Unknown0(%[[ARG0:.*]]: memref<5376x2048xf16>, %[[ARG1:.*]]: memref<2048x5376xf16>) -> memref<5376x5376xf16> attributes {__byteir_gemm_block_size__ = [64, 2, 1], __byteir_gemm_pipeline_depth__ = 3 : i64, __byteir_gemm_tile_config__ = [128, 128, 32], __byteir_matmul_epilogue_fusion__} {
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[C2048:.*]] = arith.constant 2048 : index
// CHECK-NEXT: %[[C32:.*]] = arith.constant 32 : index
// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<5376x5376xf16>
// CHECK-NEXT: scf.forall (%[[ARG2:.*]], %[[ARG3:.*]]) in (42, 42) {
// CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {__byteir_alloca_accumulator__} : memref<128x128xf16, #gpu.address_space<workgroup>>
// CHECK-NEXT: %[[ALLOCA_0:.*]] = memref.alloca() {__byteir_alloca_matrix_b__} : memref<32x128xf16, #gpu.address_space<workgroup>>
// CHECK-NEXT: %[[ALLOCA_1:.*]] = memref.alloca() {__byteir_alloca_matrix_a__} : memref<128x32xf16, #gpu.address_space<workgroup>>
// CHECK-NEXT: %[[APPLY_MAP0:.*]] = affine.apply #[[MAP]](%[[ARG2]])
// CHECK-NEXT: %[[APPLY_MAP1:.*]] = affine.apply #[[MAP]](%[[ARG3]])
// CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[APPLY_MAP0]], %[[APPLY_MAP1]]] [128, 128] [1, 1] : memref<5376x5376xf16> to memref<128x128xf16, strided<[5376, 1], offset: ?>>
// CHECK-NEXT: linalg.fill ins(%[[CST]] : f16) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space<workgroup>>)
// CHECK-NEXT: scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2048]] step %[[C32]] {
// CHECK-NEXT: %[[SUBVIEW_2:.*]] = memref.subview %[[ARG0]][%[[APPLY_MAP0]], %[[ARG4]]] [128, 32] [1, 1] : memref<5376x2048xf16> to memref<128x32xf16, strided<[2048, 1], offset: ?>>
// CHECK-NEXT: %[[SUBVIEW_3:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[APPLY_MAP1]]] [32, 128] [1, 1] : memref<2048x5376xf16> to memref<32x128xf16, strided<[5376, 1], offset: ?>>
// CHECK-NEXT: linalg.copy {__byteir_load_matrix_a__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_2]] : memref<128x32xf16, strided<[2048, 1], offset: ?>>) outs(%[[ALLOCA_1]] : memref<128x32xf16, #gpu.address_space<workgroup>>)
// CHECK-NEXT: linalg.copy {__byteir_load_matrix_b__, __internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} ins(%[[SUBVIEW_3]] : memref<32x128xf16, strided<[5376, 1], offset: ?>>) outs(%[[ALLOCA_0]] : memref<32x128xf16, #gpu.address_space<workgroup>>)
// CHECK-NEXT: linalg.matmul {__byteir_gpu_tile_gemm_0, __byteir_mma__, __byteir_mma_level__ = "Threadblock", __byteir_target__ = "nv_sm_80"} ins(%[[ALLOCA_1]], %[[ALLOCA_0]] : memref<128x32xf16, #gpu.address_space<workgroup>>, memref<32x128xf16, #gpu.address_space<workgroup>>) outs(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space<workgroup>>)
// CHECK-NEXT: }
// CHECK-NEXT: linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[ALLOCA]] : memref<128x128xf16, #gpu.address_space<workgroup>>) outs(%[[SUBVIEW]] : memref<128x128xf16, strided<[5376, 1], offset: ?>>) attrs = {__internal_linalg_transform__ = "__byteir_copy_related_to_workgroup_memory__"} {
// CHECK-NEXT: ^bb0(%in: f16, %out: f16):
// CHECK-NEXT: %2 = arith.maximumf %in, %cst : f16
// CHECK-NEXT: linalg.yield %2 : f16
// CHECK-NEXT: }
// CHECK-NEXT: } {mapping = [#gpu.block<y>, #gpu.block<x>]}
// CHECK-NEXT: return %[[ALLOC]] : memref<5376x5376xf16>
// CHECK-NEXT: }
// CHECK-NEXT: }
