diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index a3fdc7ee385ed..d54751098410b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final
   }
 };
 
+static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
+                                      Value view, mlir::OperandRange indices,
+                                      SmallVectorImpl<Value> &resolvedIndices,
+                                      Value &memrefBase, StringRef role) {
+  Operation *defOp = view.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
+      .Case([&](memref::SubViewOp subviewOp) {
+        mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+            rewriter, loc, subviewOp.getMixedOffsets(),
+            subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
+            resolvedIndices);
+        memrefBase = subviewOp.getSource();
+        return success();
+      })
+      .Case([&](memref::ExpandShapeOp expandShapeOp) {
+        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                loc, rewriter, expandShapeOp, indices, resolvedIndices,
+                false))) {
+          return failure();
+        }
+        memrefBase = expandShapeOp.getViewSource();
+        return success();
+      })
+      .Case([&](memref::CollapseShapeOp collapseShapeOp) {
+        if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                loc, rewriter, collapseShapeOp, indices, resolvedIndices))) {
+          return failure();
+        }
+        memrefBase = collapseShapeOp.getViewSource();
+        return success();
+      })
+      .Default([&](Operation *op) {
+        return rewriter.notifyMatchFailure(
+            op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
+                        "CollapseShapeOp")
+                    .str());
+      });
+}
+
 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(GatherToLDSOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
 
-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
+    SmallVector<Value> sourceIndices, destIndices;
+    Value memrefSource, memrefDest;
+
+    auto foldSrcResult =
+        foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
+                         sourceIndices, memrefSource, "source");
+
+    if (failed(foldSrcResult)) {
+      memrefSource = op.getSrc();
+      sourceIndices = llvm::to_vector(op.getSrcIndices());
+    }
+
+    auto foldDstResult =
+        foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
+                         destIndices, memrefDest, "destination");
 
-    if (failed(foldResult)) {
-      return failure();
+    if (failed(foldDstResult)) {
+      memrefDest = op.getDst();
+      destIndices = llvm::to_vector(op.getDstIndices());
     }
 
+    // If neither operand comes from a foldable view op there is nothing to
+    // fold; bail out instead of replacing the op with an identical clone.
+    if (failed(foldSrcResult) && failed(foldDstResult)) {
+      return rewriter.notifyMatchFailure(
+          op, "neither source nor destination is produced by a foldable view");
+    }
+
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                               op.getDst(), op.getDstIndices(),
+                                               memrefDest, destIndices,
                                                op.getTransferType());
 
     return success();
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
index 57afa127c9da8..8ca3dd60414df 100644
--- a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
 // CHECK: func @test_expand_shape
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
 func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
-  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
 
-  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<8192xf16>
-  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
 
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
     : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
@@ -80,15 +82,82 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
   // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
 
   %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
   %mem = memref.alloc() : memref<64x128xf16>
-  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
 
   %c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
+  amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+
+// CHECK: func @test_expand_shape_src_raw_buffer
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_expand_shape_src_raw_buffer(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG1]], %[[ARG2]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[IDXM]]], %[[LOCAL]][%[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  %expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>> into memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>
+
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %alloc[%c0]
+    : vector<8xf16>, memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape_dst_only
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape_dst_only(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX_LDS:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (64, 64) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]]], %[[LOCAL]][%[[IDX_LDS]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<8192xf16>
+  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
+
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %mem[%offset_i], %expand_alloc[%offset_j, %c0]
     : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_nop
+// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
+func.func @test_nop(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
+  // CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[ARG1]]], %[[LOCAL]][%[[ARG2]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>
+
+  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %mem[%offset_i], %alloc[%offset_j]
+    : vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
+  func.return
+}
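For reference, a minimal sketch of the destination-side folding this patch adds, in isolation. It mirrors test_expand_shape_dst_only above rather than adding new coverage; the function and value names are hypothetical and the expected output is shown in comments, assuming the pass rewrites the op the same way it already does for the source operand.

// Hypothetical input: the LDS destination is addressed through an
// expand_shape view of the flat allocation.
func.func @fold_dst_view(%src: memref<8192xf16>, %i: index, %j: index, %k: index) {
  %alloc = memref.alloc() : memref<4096xf16, 3>
  %view = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64]
      : memref<4096xf16, 3> into memref<64x64xf16, 3>
  amdgpu.gather_to_lds %src[%i], %view[%j, %k]
      : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
  func.return
}

// Expected rewrite: the view is folded away and the destination indices are
// linearized onto the original allocation, e.g.
//   %idx = affine.linearize_index [%j, %k] by (64, 64) : index
//   amdgpu.gather_to_lds %src[%i], %alloc[%idx]
//       : vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>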