diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..e1024eeefd4bd 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -88,13 +88,13 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
       return failure();
     }
 
+    Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+                                               maskedOp, /*passthru=*/true);
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
     if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
-      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
-                                                 maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
       return success();
     }
@@ -156,9 +156,7 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
-                                                /*passthru=*/true);
-      scf::YieldOp::create(rewriter, loc, res);
+      scf::YieldOp::create(rewriter, loc, load);
     };
 
     auto ifOp =
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index f1d0ad545539a..fae0d3870d7fd 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -9,13 +9,13 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
   %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru
     : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32>
       into vector<4xf32>
   return %res : vector<4xf32>
 }
-
+// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
+// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 // CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
-// CHECK: %[[SELECT:.*]] = arith.select %[[ARG2]], %[[LOAD]]
+// CHECK: scf.yield %[[SELECT]]
 // CHECK: return %[[IF]] : vector<4xf32>
 
@@ -36,18 +36,17 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
 // CHECK-DAG: %[[BYTES:.*]] = arith.constant 2
 // CHECK-DAG: %[[C4:.*]] = arith.constant 4
 
+// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[ARG2]]]
 // CHECK: %[[DELTA:.*]] = arith.subi %[[SIZE]], %[[LINEAR]]
 // CHECK: %[[COND1:.*]] = arith.cmpi ult, %[[DELTA]], %[[C4]]
 // CHECK: %[[REM:.*]] = arith.remui %[[DELTA]], %[[BYTES]]
 // CHECK: %[[COND2:.*]] = arith.cmpi ne, %[[REM]], %[[C0]]
-
 // CHECK: %[[COND:.*]] = arith.andi %[[COND1]], %[[COND2]]
 
 // CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (vector<4xf16>) {
 // CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: } else {
-// CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
 // CHECK: return %[[IF]] : vector<4xf16>
 
 // -----
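
For reference, the change hoists the call to createVectorLoadForMaskedLoad, which emits a plain vector.load followed by an arith.select against the mask, so a single value now serves both the full-mask fast path and the else branch of the scf.if instead of being re-created in each place. Below is a hand-written MLIR sketch of the resulting IR shape for the f32 test case, derived from the CHECK lines in the test above; the function name and the %cond argument, which stands in for the out-of-bounds condition the pass actually computes from the buffer size, are illustrative only.

```mlir
// Sketch of the IR shape after the rewrite (illustrative, not the exact
// output of the pass; %cond abstracts the computed bounds check).
func.func @sketch(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                  %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>,
                  %cond : i1) -> vector<4xf32> {
  // Hoisted out of the scf.if: the unconditional load plus the select that
  // reproduces the masked load's passthru semantics.
  %load = vector.load %mem[%idx, %idx]
      : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
  %sel = arith.select %mask, %load, %passthru : vector<4xi1>, vector<4xf32>
  %res = scf.if %cond -> (vector<4xf32>) {
    // Slow path: keep the original masked load for the out-of-bounds case.
    %slow = vector.maskedload %mem[%idx, %idx], %mask, %passthru
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
    scf.yield %slow : vector<4xf32>
  } else {
    // Fast path: yield the hoisted value instead of re-emitting the
    // load/select inside the else region.
    scf.yield %sel : vector<4xf32>
  }
  return %res : vector<4xf32>
}
```

Hoisting does not change which path executes at runtime; it only avoids duplicating the load/select IR between the full-mask early exit and the else region.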