From f5f5ef3bf39779d10df4cb447bd576f06792668a Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 12 Sep 2025 08:38:13 +0000 Subject: [PATCH 1/2] updated vdst selection Signed-off-by: xintin --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 203790ed95153..0078eed8b7a67 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1915,7 +1915,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern { else llvm_unreachable("unsupported row length"); - Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0}); + const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0}); + const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1}); + + const Value isEqual = + rewriter.create(loc, LLVM::ICmpPredicate::eq, vdst0, v); + + // Per `permlane(16|32)` semantics: if the first extracted element equals + // 'v', the result is the second element; otherwise it is the first. + Value vdstNew = + rewriter.create(loc, isEqual, vdst1, vdst0); permuted.emplace_back(vdstNew); } From d15f560326fdd007fad73daf0f25cec539adadcd Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 12 Sep 2025 08:38:14 +0000 Subject: [PATCH 2/2] updated lit tests Signed-off-by: xintin --- .../Conversion/AMDGPUToROCDL/permlane.mlir | 97 ++++++++++++++----- .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir | 14 ++- 2 files changed, 81 insertions(+), 30 deletions(-) mode change 100644 => 100755 mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir mode change 100644 => 100755 mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir old mode 100644 new mode 100755 index aae2b1d0fd90c..a92321da8f357 --- a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir @@ -4,8 +4,11 @@ // CHECK-SAME: (%[[ARG0:.*]]: i32) func.func @test_permlane16_i32(%arg0 : i32) -> i32 { // CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: return %[[RES]] : i32 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: return %[[SEL]] : i32 %0 = amdgpu.permlane_swap %arg0 16 : i32 return %0 : i32 } @@ -14,8 +17,11 @@ func.func @test_permlane16_i32(%arg0 : i32) -> i32 { // CHECK-SAME: (%[[ARG0:.*]]: i32) func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 { // CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: return %[[RES]] : i32 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: return %[[SEL]] : i32 %0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32 return %0 : i32 } @@ -24,8 +30,11 @@ func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 { // CHECK-SAME: (%[[ARG0:.*]]: i32) func.func @test_permlane32_i32(%arg0 : i32) -> i32 { // CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: return %[[RES]] : i32 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: return %[[SEL]] : i32 %0 = amdgpu.permlane_swap %arg0 32 : i32 return %0 : i32 } @@ -35,8 +44,11 @@ func.func @test_permlane32_i32(%arg0 : i32) -> i32 { func.func @test_permlane16_f32(%arg0 : f32) -> f32 { // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32 // CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32 // CHECK: return %[[RES_CAST]] : f32 %0 = amdgpu.permlane_swap %arg0 16 : f32 return %0 : f32 @@ -47,8 +59,11 @@ func.func @test_permlane16_f32(%arg0 : f32) -> f32 { func.func @test_permlane32_f32(%arg0 : f32) -> f32 { // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32 // CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32 // CHECK: return %[[RES_CAST]] : f32 %0 = amdgpu.permlane_swap %arg0 32 : f32 return %0 : f32 @@ -60,8 +75,11 @@ func.func @test_permlane16_f16(%arg0 : f16) -> f16 { // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16 // CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32 // CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16 // CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16 // CHECK: return %[[RES_CAST]] : f16 %0 = amdgpu.permlane_swap %arg0 16 : f16 @@ -74,8 +92,11 @@ func.func @test_permlane32_f16(%arg0 : f16) -> f16 { // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16 // CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32 // CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> -// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16 +// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32 +// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32 +// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16 // CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16 // CHECK: return %[[RES_CAST]] : f16 %0 = amdgpu.permlane_swap %arg0 32 : f16 @@ -90,10 +111,16 @@ func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> { // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32> -// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> -// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[T0:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32 +// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32 +// CHECK: %[[T1:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32 +// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32 // CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> // CHECK: return %[[VEC_INSERT1]] : vector<2xi32> @@ -109,10 +136,16 @@ func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> { // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32> -// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> -// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[T0:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32 +// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32 +// CHECK: %[[T1:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32 +// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32 // CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> // CHECK: return %[[VEC_INSERT1]] : vector<2xi32> @@ -130,9 +163,15 @@ func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> { // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32> // CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32 +// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32 // CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32 +// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32 // CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> // CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16> @@ -151,9 +190,15 @@ func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> { // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32> // CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32 +// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32 // CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> -// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)> +// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32 +// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32 // CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> // CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> // CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16> diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir old mode 100644 new mode 100755 index c6261b37ef8f2..ef631ce8a12e5 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -749,13 +749,19 @@ gpu.module @test_module { %shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32 // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)> - // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)> - // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32 + // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)> + // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)> + // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32 + // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32 + // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32 %shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32 // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32 // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)> - // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)> - // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32 + // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)> + // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)> + // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32 + // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32 + // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32 %shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32 func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32 }