-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][AMDGPU] Updated `PermlaneSwapOp` to select correct val
#157586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][AMDGPU] Updated `PermlaneSwapOp` to select correct val
#157586
Conversation
[mlir][AMDGPU] Updated `PermlaneSwapOp` to select correct val
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Gaurav Verma (xintin) Changes
Full diff: https://github.com/llvm/llvm-project/pull/157586.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 203790ed95153..0078eed8b7a67 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1915,7 +1915,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
else
llvm_unreachable("unsupported row length");
- Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+
+ const Value isEqual =
+ rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+
+ // Per `permlane(16|32)` semantics: if the first extracted element equals
+ // 'v', the result is the second element; otherwise it is the first.
+ Value vdstNew =
+ rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
old mode 100644
new mode 100755
index aae2b1d0fd90c..a92321da8f357
--- a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
@@ -4,8 +4,11 @@
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: return %[[RES]] : i32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: return %[[SEL]] : i32
%0 = amdgpu.permlane_swap %arg0 16 : i32
return %0 : i32
}
@@ -14,8 +17,11 @@ func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: return %[[RES]] : i32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: return %[[SEL]] : i32
%0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32
return %0 : i32
}
@@ -24,8 +30,11 @@ func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: return %[[RES]] : i32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: return %[[SEL]] : i32
%0 = amdgpu.permlane_swap %arg0 32 : i32
return %0 : i32
}
@@ -35,8 +44,11 @@ func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 16 : f32
return %0 : f32
@@ -47,8 +59,11 @@ func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 32 : f32
return %0 : f32
@@ -60,8 +75,11 @@ func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 16 : f16
@@ -74,8 +92,11 @@ func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 32 : f16
@@ -90,10 +111,16 @@ func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
-// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
+// CHECK: %[[T1:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
@@ -109,10 +136,16 @@ func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
-// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
+// CHECK: %[[T1:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
@@ -130,9 +163,15 @@ func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
@@ -151,9 +190,15 @@ func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
old mode 100644
new mode 100755
index c6261b37ef8f2..ef631ce8a12e5
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -749,13 +749,19 @@ gpu.module @test_module {
%shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
- // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
- // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
+ // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
%shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
- // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
- // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
+ // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
}
|
@llvm/pr-subscribers-backend-amdgpu Author: Gaurav Verma (xintin) Changes
Full diff: https://github.com/llvm/llvm-project/pull/157586.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 203790ed95153..0078eed8b7a67 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1915,7 +1915,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
else
llvm_unreachable("unsupported row length");
- Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+
+ const Value isEqual =
+ rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+
+ // Per `permlane(16|32)` semantics: if the first extracted element equals
+ // 'v', the result is the second element; otherwise it is the first.
+ Value vdstNew =
+ rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
old mode 100644
new mode 100755
index aae2b1d0fd90c..a92321da8f357
--- a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
@@ -4,8 +4,11 @@
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: return %[[RES]] : i32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: return %[[SEL]] : i32
%0 = amdgpu.permlane_swap %arg0 16 : i32
return %0 : i32
}
@@ -14,8 +17,11 @@ func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: return %[[RES]] : i32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: return %[[SEL]] : i32
%0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32
return %0 : i32
}
@@ -24,8 +30,11 @@ func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: return %[[RES]] : i32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ARG0]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: return %[[SEL]] : i32
%0 = amdgpu.permlane_swap %arg0 32 : i32
return %0 : i32
}
@@ -35,8 +44,11 @@ func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 16 : f32
return %0 : f32
@@ -47,8 +59,11 @@ func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[CAST]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[SEL]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 32 : f32
return %0 : f32
@@ -60,8 +75,11 @@ func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 16 : f16
@@ -74,8 +92,11 @@ func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[E0:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[E1:.*]] = llvm.extractvalue %[[PERM]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[E0]], %[[ZEXT]] : i32
+// CHECK: %[[SEL:.*]] = llvm.select %[[CMP]], %[[E1]], %[[E0]] : i1, i32
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SEL]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 32 : f16
@@ -90,10 +111,16 @@ func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
-// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
+// CHECK: %[[T1:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
@@ -109,10 +136,16 @@ func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
-// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
-// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T0_0:.*]] = llvm.extractvalue %[[T0]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T0_1:.*]] = llvm.extractvalue %[[T0]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[T0_0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[T0_1]], %[[T0_0]] : i1, i32
+// CHECK: %[[T1:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[T1_0:.*]] = llvm.extractvalue %[[T1]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[T1_1:.*]] = llvm.extractvalue %[[T1]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[T1_0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[T1_1]], %[[T1_0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
@@ -130,9 +163,15 @@ func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
@@ -151,9 +190,15 @@ func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM0_E1:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP0:.*]] = llvm.icmp "eq" %[[PERM0_E0]], %[[ELEM0]] : i32
+// CHECK: %[[PERM0:.*]] = llvm.select %[[CMP0]], %[[PERM0_E1]], %[[PERM0_E0]] : i1, i32
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
-// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E0:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_E1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][1] : !llvm.struct<(i32, i32)>
+// CHECK: %[[CMP1:.*]] = llvm.icmp "eq" %[[PERM1_E0]], %[[ELEM1]] : i32
+// CHECK: %[[PERM1:.*]] = llvm.select %[[CMP1]], %[[PERM1_E1]], %[[PERM1_E0]] : i1, i32
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
old mode 100644
new mode 100755
index c6261b37ef8f2..ef631ce8a12e5
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -749,13 +749,19 @@ gpu.module @test_module {
%shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
- // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
- // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
+ // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
%shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
- // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
- // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE:]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
+ // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like I could use a bit more context as to why the lowering is like this, but the code style seems fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to check - is this to ensure consistent semantics for inactive lanes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to be consistent with the thread IDs of VDST, which remain unchanged for both `PERMLANE16_SWAP` and `PERMLANE32_SWAP`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I'm flagging is that I'm trying to work out why, for example, we have
`res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl);`
earlier - clearly we're not lowering to the full generality of the instruction.
And I don't remember the `amdgpu` op actually documenting its semantics, unless that changed while I wasn't looking.
This might mean that the `amdgpu` op we're lowering from could use an update to its documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On further reading, I think I have some broad understanding of why this is set up the way it is, so I think we can land it.
37a7670
to
5c4cbfc
Compare
Signed-off-by: xintin <gaurav.verma@amd.com>
Signed-off-by: xintin <gaurav.verma@amd.com>
27f784b
to
d15f560
Compare
`PermlaneSwapOp` to select correct val

Issue it resolves: the block reduction was failing otherwise, as we were always selecting element `{0}`.