Skip to content

Conversation

@newling
Copy link
Contributor

@newling newling commented Oct 30, 2025

This is a fix for a cluster size of 32 when the subgroup size is 64. Previously, only lanes [16, 32) u [48, 64) contained the correct clusterwise reduction value. This PR adds a swizzle instruction to broadcast the correct value down to lanes [0, 16) u [32, 48).

@newling newling marked this pull request as ready for review October 30, 2025 18:50
@newling newling requested a review from fabianmcg as a code owner October 30, 2025 18:50
@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2025

@llvm/pr-subscribers-mlir-gpu

Author: James Newling (newling)

Changes

This is a fix for a cluster size of 32 when the subgroup size is 64. Previously, only lanes [16, 32) u [48, 64) contained the correct clusterwise reduction value. This PR adds a swizzle instruction to broadcast the correct value down to lanes [0, 16) u [32, 48).


Full diff: https://github.com/llvm/llvm-project/pull/165764.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+26-1)
  • (modified) mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir (+82-90)
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 81c3069cec16e..680a46c7f68ac 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -416,13 +416,38 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
   if (ci.clusterSize >= 32) {
     if (chipset.majorVersion <= 9) {
       // Broadcast last value from each row to next row.
-      // Use row mask to avoid polluting rows 1 and 3.
+      // Use row mask to avoid polluting row 0 (and row 2 if wave-64).
       dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                   amdgpu::DPPPerm::row_bcast_15,
                                   rewriter.getUnitAttr(), 0xa, allBanks,
                                   /*bound_ctrl*/ false);
       res = vector::makeArithReduction(
           rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+
+      // For subgroupSize = 64, at this point lanes [16, 32) contain the full
+      // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly,
+      // lanes [48, 64) contain the full reduction over lanes [32, 64), but
+      // lanes [32, 48) do not.
+      //
+      // If subgroup size is 64 and cluster size is 64, we don't need lanes [0,
+      // 16) and [32, 48) to have the correct cluster-32 reduction values at
+      // this point, because only lane 63's value will ultimately be read in
+      // this full-cluster case.
+      //
+      // If subgroup size is 64 and cluster size is 32, we need to ensure that
+      // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction
+      // values (subgroup_reduce guarantees that all lanes within each cluster
+      // contain the final reduction value). We do this by broadcasting lane
+      // 31's value to lanes [0, 16) and lanes 63's value to lanes [32, 48).
+      //
+      // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations
+      // for an illustration of how this within-cluster broadcast works with a
+      // swizzle.
+      if (ci.subgroupSize == 64 && ci.clusterSize == 32) {
+        res = amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and*/ 0,
+                                               /*or*/ 31,
+                                               /*xor*/ 0);
+      }
     } else if (chipset.majorVersion <= 12) {
       // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
       Value uint32Max = arith::ConstantOp::create(
diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
index 87a31ca20eb7b..1adc4181e05d3 100644
--- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
@@ -8,11 +8,11 @@
 
 // RUN: mlir-opt --allow-unregistered-dialect \
 // RUN:   --test-gpu-subgroup-reduce-lowering="expand-to-shuffles target=gfx942" %s \
-// RUN:   | FileCheck %s --check-prefix=CHECK-GFX9
+// RUN:   | FileCheck %s --check-prefixes=CHECK-GFX,CHECK-GFX9
 
 // RUN: mlir-opt --allow-unregistered-dialect \
 // RUN:   --test-gpu-subgroup-reduce-lowering="expand-to-shuffles target=gfx1030" %s \
-// RUN:   | FileCheck %s --check-prefix=CHECK-GFX10
+// RUN:   | FileCheck %s --check-prefixes=CHECK-GFX,CHECK-GFX10
 
 // CHECK-SUB:  gpu.module @kernels {
 // CHECK-SHFL: gpu.module @kernels {
@@ -24,8 +24,7 @@ gpu.module @kernels {
   // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<5xf16>)
   //
   // CHECK-SHFL-LABEL: gpu.func @kernel0(
-  // CHECK-GFX9-LABEL: gpu.func @kernel0(
-  // CHECK-GFX10-LABEL: gpu.func @kernel0(
+  // CHECK-GFX-LABEL: gpu.func @kernel0(
   gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
     // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
     // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
@@ -56,8 +55,7 @@ gpu.module @kernels {
 
     // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster(size = 4)
     // CHECK-SUB: "test.consume"
-    // CHECK-GFX9-COUNT-2: amdgpu.dpp {{.+}}
-    // CHECK-GFX10-COUNT-2: amdgpu.dpp {{.+}}
+    // CHECK-GFX-COUNT-2: amdgpu.dpp {{.+}}
     %sum2 = gpu.subgroup_reduce mul %arg0 cluster(size = 4) : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum2) : (vector<5xf16>) -> ()
 
@@ -74,8 +72,7 @@ gpu.module @kernels {
   // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<1xf32>)
   //
   // CHECK-SHFL-LABEL: gpu.func @kernel1(
-  // CHECK-GFX9-LABEL: gpu.func @kernel1(
-  // CHECK-GFX10-LABEL: gpu.func @kernel1(
+  // CHECK-GFX-LABEL: gpu.func @kernel1(
   gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
     // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
     // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
@@ -100,17 +97,14 @@ gpu.module @kernels {
     // Note stride is dropped because it is == 1.
     // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster(size = 8) : (f32) -> f32
     // CHECK-SUB: "test.consume"
-    // CHECK-GFX9-COUNT-2: amdgpu.dpp {{.+}} quad_perm
-    // CHECK-GFX9: amdgpu.dpp {{.+}} row_half_mirror
-    // CHECK-GFX10-COUNT-2: amdgpu.dpp {{.+}} quad_perm
-    // CHECK-GFX10: amdgpu.dpp {{.+}} row_half_mirror
+    // CHECK-GFX-COUNT-2: amdgpu.dpp {{.+}} quad_perm
+    // CHECK-GFX: amdgpu.dpp {{.+}} row_half_mirror
     %sum2 = gpu.subgroup_reduce add %arg0 cluster(size = 8, stride = 1) : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum2) : (vector<1xf32>) -> ()
 
     // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster(size = 8, stride = 4) : (f32) -> f32
     // CHECK-SUB: "test.consume"
-    // CHECK-GFX9-NOT: amdgpu.dpp
-    // CHECK-GFX10-NOT: amdgpu.dpp
+    // CHECK-GFX-NOT: amdgpu.dpp
     // CHECK-GFX10-NOT: rocdl.permlanex16
     %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster(size = 8, stride = 4) : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum3) : (vector<1xf32>) -> ()
@@ -126,11 +120,8 @@ gpu.module @kernels {
   //
   // CHECK-SHFL-LABEL: gpu.func @kernel2(
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel2(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel2(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel2(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
     // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
     // CHECK-SUB: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
@@ -148,8 +139,7 @@ gpu.module @kernels {
 
   // CHECK-SHFL-LABEL: gpu.func @kernel3(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
-  // CHECK-GFX9-LABEL: gpu.func @kernel3(
-  // CHECK-GFX10-LABEL: gpu.func @kernel3(
+  // CHECK-GFX-LABEL: gpu.func @kernel3(
   gpu.func @kernel3(%arg0: i32) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -169,9 +159,9 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[S4:.+]], %{{.+}} = gpu.shuffle xor %[[A3]], %[[C16]], %[[C32]] : i32
     // CHECK-SHFL: %[[A4:.+]] = arith.addi %[[A3]], %[[S4]] : i32
     // CHECK-SHFL: "test.consume"(%[[A4]]) : (i32) -> ()
-    
+
     // CHECK-GFX9-COUNT-6: amdgpu.dpp
-    
+
     // CHECK-GFX10-COUNT-4: amdgpu.dpp
     // CHECK-GFX10: rocdl.permlanex16
     // CHECK-GFX10-COUNT-2: rocdl.readlane
@@ -185,11 +175,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel3_clustered(
-  // CHECK-GFX9-SAME:    %[[ARG0:.+]]: i32)
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel3_clustered(
-  // CHECK-GFX10-SAME:    %[[ARG0:.+]]: i32)
+  // CHECK-GFX-LABEL: gpu.func @kernel3_clustered(
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: i32)
   gpu.func @kernel3_clustered(%arg0: i32) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -204,19 +191,13 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
     // CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
 
-    // CHECK-GFX9: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX9: %[[A0:.+]] = arith.addi %[[ARG0]], %[[D0]] : i32
-    // CHECK-GFX9: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX9: %[[A1:.+]] = arith.addi %[[A0]], %[[D1]] : i32
-    // CHECK-GFX9: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : i32
-    // CHECK-GFX9: %[[A2:.+]] = arith.addi %[[A1]], %[[D2]] : i32
-
-    // CHECK-GFX10: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX10: %[[A0:.+]] = arith.addi %[[ARG0]], %[[D0]] : i32
-    // CHECK-GFX10: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX10: %[[A1:.+]] = arith.addi %[[A0]], %[[D1]] : i32
-    // CHECK-GFX10: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : i32
-    // CHECK-GFX10: %[[A2:.+]] = arith.addi %[[A1]], %[[D2]] : i32
+    // CHECK-GFX: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i32
+    // CHECK-GFX: %[[A0:.+]] = arith.addi %[[ARG0]], %[[D0]] : i32
+    // CHECK-GFX: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i32
+    // CHECK-GFX: %[[A1:.+]] = arith.addi %[[A0]], %[[D1]] : i32
+    // CHECK-GFX: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : i32
+    // CHECK-GFX: %[[A2:.+]] = arith.addi %[[A1]], %[[D2]] : i32
+
     // CHECK-GFX10: "test.consume"(%[[A2]]) : (i32) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 8) : (i32) -> i32
     "test.consume"(%sum0) : (i32) -> ()
@@ -228,11 +209,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered_strided(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel3_clustered_strided(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel3_clustered_strided(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel3_clustered_strided(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel3_clustered_strided(%arg0: i32) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 4 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 8 : i32
@@ -256,11 +234,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel4(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel4(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel4(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel4(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -298,11 +273,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel4_clustered(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel4_clustered(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel4_clustered(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel4_clustered(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel4_clustered(%arg0: vector<2xf16>) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -319,10 +291,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel5(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel5(
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel5(
-  // CHECK-GFX10-SAME:    %[[ARG0:.+]]: i16)
+  // CHECK-GFX-LABEL: gpu.func @kernel5(
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: i16)
   gpu.func @kernel5(%arg0: i16) kernel {
     // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
     // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
@@ -334,7 +304,7 @@ gpu.module @kernels {
     // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
     // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
     // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
-    
+
     // CHECK-GFX9-COUNT-6: amdgpu.dpp
 
     // CHECK-GFX10: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
@@ -361,11 +331,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel5_clustered(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel5_clustered
-  // CHECK-GFX9-SAME:    %[[ARG0:.+]]: i16)
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel5_clustered
-  // CHECK-GFX10-SAME:    %[[ARG0:.+]]: i16)
+  // CHECK-GFX-LABEL: gpu.func @kernel5_clustered
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: i16)
   gpu.func @kernel5_clustered(%arg0: i16) kernel {
     // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
     // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
@@ -378,25 +345,15 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
     // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
 
-    // CHECK-GFX9: %[[VAR0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR1:.+]] = arith.addi %[[ARG0]], %[[VAR0]] : i16
-    // CHECK-GFX9: %[[VAR2:.+]] = amdgpu.dpp %[[VAR1]] %[[VAR1]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR3:.+]] = arith.addi %[[VAR1]], %[[VAR2]] : i16
-    // CHECK-GFX9: %[[VAR4:.+]] = amdgpu.dpp %[[VAR3]] %[[VAR3]]  row_half_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR5:.+]] = arith.addi %[[VAR3]], %[[VAR4]] : i16
-    // CHECK-GFX9: %[[VAR6:.+]] = amdgpu.dpp %[[VAR5]] %[[VAR5]]  row_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR7:.+]] = arith.addi %[[VAR5]], %[[VAR6]] : i16
-    // CHECK-GFX9: "test.consume"(%[[VAR7]]) : (i16) -> ()
-
-    // CHECK-GFX10: %[[VAR0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR1:.+]] = arith.addi %[[ARG0]], %[[VAR0]] : i16
-    // CHECK-GFX10: %[[VAR2:.+]] = amdgpu.dpp %[[VAR1]] %[[VAR1]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR3:.+]] = arith.addi %[[VAR1]], %[[VAR2]] : i16
-    // CHECK-GFX10: %[[VAR4:.+]] = amdgpu.dpp %[[VAR3]] %[[VAR3]]  row_half_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR5:.+]] = arith.addi %[[VAR3]], %[[VAR4]] : i16
-    // CHECK-GFX10: %[[VAR6:.+]] = amdgpu.dpp %[[VAR5]] %[[VAR5]]  row_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR7:.+]] = arith.addi %[[VAR5]], %[[VAR6]] : i16
-    // CHECK-GFX10: "test.consume"(%[[VAR7]]) : (i16) -> ()
+    // CHECK-GFX: %[[VAR0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR1:.+]] = arith.addi %[[ARG0]], %[[VAR0]] : i16
+    // CHECK-GFX: %[[VAR2:.+]] = amdgpu.dpp %[[VAR1]] %[[VAR1]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR3:.+]] = arith.addi %[[VAR1]], %[[VAR2]] : i16
+    // CHECK-GFX: %[[VAR4:.+]] = amdgpu.dpp %[[VAR3]] %[[VAR3]]  row_half_mirror(unit) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR5:.+]] = arith.addi %[[VAR3]], %[[VAR4]] : i16
+    // CHECK-GFX: %[[VAR6:.+]] = amdgpu.dpp %[[VAR5]] %[[VAR5]]  row_mirror(unit) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR7:.+]] = arith.addi %[[VAR5]], %[[VAR6]] : i16
+    // CHECK-GFX: "test.consume"(%[[VAR7]]) : (i16) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 16) : (i16) -> i16
     "test.consume"(%sum0) : (i16) -> ()
 
@@ -407,11 +364,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel6(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel6(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel6(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel6(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
     // CHECK-SHFL: %[[CZ:.+]] = arith.constant dense<0> : vector<4xi8>
     // CHECK-SHFL: %[[V0:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CZ]] {offsets = [0], strides = [1]} : vector<3xi8> into vector<4xi8>
@@ -433,6 +387,44 @@ gpu.module @kernels {
     gpu.return
   }
 
+  // CHECK-GFX-LABEL: gpu.func @kernel7(
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: f32)
+  //
+  //   Checks, common to gfx942 and gfx1030, of
+  //     (1) quad_perm, followed by reduction resulting in reduction over 2 consecutive lanes,
+  //     (2) quad_perm, followed by reduction resulting in reduction over 4 consecutive lanes,
+  //     (3) row_half_mirror, followed by reduction resulting in reduction over 8 consecutive lanes, and
+  //     (4) row_mirror, followed by reduction resulting in reduction over 16 consecutive lanes.
+  // CHECK-GFX: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A0:.+]] = arith.addf %[[ARG0]], %[[D0]] : f32
+  // CHECK-GFX: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A1:.+]] = arith.addf %[[A0]], %[[D1]] : f32
+  // CHECK-GFX: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A2:.+]] = arith.addf %[[A1]], %[[D2]] : f32
+  // CHECK-GFX: %[[D3:.+]] = amdgpu.dpp %[[A2]] %[[A2]]  row_mirror(unit) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A3:.+]] = arith.addf %[[A2]], %[[D3]] : f32
+  //
+  //   Now, on gfx942:
+  //     (1) Lane 15 gets broadcast to lanes [16, 32) and lane 31 gets broadcast to lanes [48, 64], after which
+  //         the reduction in lanes [16, 32) is over the full cluster of the first 32 lanes, and the reduction in lanes
+  //         [48, 64) is over the full cluster of the last 32 lanes.
+  //     (2) Update the reduction value in lanes [0, 16) and [32, 48) with the final reduction result from
+  //         lanes [16, 32) and [48, 64), respectively.
+  // CHECK-GFX9: %[[BCAST15:.+]] = amdgpu.dpp %[[A3]] %[[A3]]  row_bcast_15(unit) {row_mask = 10 : i32} : f32
+  // CHECK-GFX9: %[[SUM:.+]] = arith.addf %[[A3]], %[[BCAST15]] : f32
+  // CHECK-GFX9: %[[SWIZ:.+]] = amdgpu.swizzle_bitmode %[[SUM]] 0 31 0 : f32
+  // CHECK-GFX9: "test.consume"(%[[SWIZ]]) : (f32) -> ()
+  //
+  //   On gfx1030, the final step is to permute the lanes and perform final reduction:
+  // CHECK-GFX10: rocdl.permlanex16
+  // CHECK-GFX10: arith.addf
+  // CHECK-GFX10: "test.consume"
+   gpu.func @kernel7(%arg0: f32) kernel {
+     %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 32) : (f32) -> (f32)
+     "test.consume"(%sum0) : (f32) -> ()
+     gpu.return
+   }
+
   // CHECK-SHFL-LABEL: gpu.func @kernel_cluster_size_is_subgroup_size(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
   //

@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2025

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

This is a fix for a cluster size of 32 when the subgroup size is 64. Previously, only lanes [16, 32) u [48, 64) contained the correct clusterwise reduction value. This PR adds a swizzle instruction to broadcast the correct value down to lanes [0, 16) u [32, 48).


Full diff: https://github.com/llvm/llvm-project/pull/165764.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp (+26-1)
  • (modified) mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir (+82-90)
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 81c3069cec16e..680a46c7f68ac 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -416,13 +416,38 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
   if (ci.clusterSize >= 32) {
     if (chipset.majorVersion <= 9) {
       // Broadcast last value from each row to next row.
-      // Use row mask to avoid polluting rows 1 and 3.
+      // Use row mask to avoid polluting row 0 (and row 2 if wave-64).
       dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
                                   amdgpu::DPPPerm::row_bcast_15,
                                   rewriter.getUnitAttr(), 0xa, allBanks,
                                   /*bound_ctrl*/ false);
       res = vector::makeArithReduction(
           rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+
+      // For subgroupSize = 64, at this point lanes [16, 32) contain the full
+      // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly,
+      // lanes [48, 64) contain the full reduction over lanes [32, 64), but
+      // lanes [32, 48) do not.
+      //
+      // If subgroup size is 64 and cluster size is 64, we don't need lanes [0,
+      // 16) and [32, 48) to have the correct cluster-32 reduction values at
+      // this point, because only lane 63's value will ultimately be read in
+      // this full-cluster case.
+      //
+      // If subgroup size is 64 and cluster size is 32, we need to ensure that
+      // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction
+      // values (subgroup_reduce guarantees that all lanes within each cluster
+      // contain the final reduction value). We do this by broadcasting lane
+      // 31's value to lanes [0, 16) and lanes 63's value to lanes [32, 48).
+      //
+      // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations
+      // for an illustration of how this within-cluster broadcast works with a
+      // swizzle.
+      if (ci.subgroupSize == 64 && ci.clusterSize == 32) {
+        res = amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and*/ 0,
+                                               /*or*/ 31,
+                                               /*xor*/ 0);
+      }
     } else if (chipset.majorVersion <= 12) {
       // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
       Value uint32Max = arith::ConstantOp::create(
diff --git a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
index 87a31ca20eb7b..1adc4181e05d3 100644
--- a/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-reduce-lowering.mlir
@@ -8,11 +8,11 @@
 
 // RUN: mlir-opt --allow-unregistered-dialect \
 // RUN:   --test-gpu-subgroup-reduce-lowering="expand-to-shuffles target=gfx942" %s \
-// RUN:   | FileCheck %s --check-prefix=CHECK-GFX9
+// RUN:   | FileCheck %s --check-prefixes=CHECK-GFX,CHECK-GFX9
 
 // RUN: mlir-opt --allow-unregistered-dialect \
 // RUN:   --test-gpu-subgroup-reduce-lowering="expand-to-shuffles target=gfx1030" %s \
-// RUN:   | FileCheck %s --check-prefix=CHECK-GFX10
+// RUN:   | FileCheck %s --check-prefixes=CHECK-GFX,CHECK-GFX10
 
 // CHECK-SUB:  gpu.module @kernels {
 // CHECK-SHFL: gpu.module @kernels {
@@ -24,8 +24,7 @@ gpu.module @kernels {
   // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<5xf16>)
   //
   // CHECK-SHFL-LABEL: gpu.func @kernel0(
-  // CHECK-GFX9-LABEL: gpu.func @kernel0(
-  // CHECK-GFX10-LABEL: gpu.func @kernel0(
+  // CHECK-GFX-LABEL: gpu.func @kernel0(
   gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
     // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
     // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
@@ -56,8 +55,7 @@ gpu.module @kernels {
 
     // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} cluster(size = 4)
     // CHECK-SUB: "test.consume"
-    // CHECK-GFX9-COUNT-2: amdgpu.dpp {{.+}}
-    // CHECK-GFX10-COUNT-2: amdgpu.dpp {{.+}}
+    // CHECK-GFX-COUNT-2: amdgpu.dpp {{.+}}
     %sum2 = gpu.subgroup_reduce mul %arg0 cluster(size = 4) : (vector<5xf16>) -> (vector<5xf16>)
     "test.consume"(%sum2) : (vector<5xf16>) -> ()
 
@@ -74,8 +72,7 @@ gpu.module @kernels {
   // CHECK-SUB-SAME:     %[[ARG0:.+]]: vector<1xf32>)
   //
   // CHECK-SHFL-LABEL: gpu.func @kernel1(
-  // CHECK-GFX9-LABEL: gpu.func @kernel1(
-  // CHECK-GFX10-LABEL: gpu.func @kernel1(
+  // CHECK-GFX-LABEL: gpu.func @kernel1(
   gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
     // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
     // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
@@ -100,17 +97,14 @@ gpu.module @kernels {
     // Note stride is dropped because it is == 1.
     // CHECK-SUB: gpu.subgroup_reduce add {{.+}} cluster(size = 8) : (f32) -> f32
     // CHECK-SUB: "test.consume"
-    // CHECK-GFX9-COUNT-2: amdgpu.dpp {{.+}} quad_perm
-    // CHECK-GFX9: amdgpu.dpp {{.+}} row_half_mirror
-    // CHECK-GFX10-COUNT-2: amdgpu.dpp {{.+}} quad_perm
-    // CHECK-GFX10: amdgpu.dpp {{.+}} row_half_mirror
+    // CHECK-GFX-COUNT-2: amdgpu.dpp {{.+}} quad_perm
+    // CHECK-GFX: amdgpu.dpp {{.+}} row_half_mirror
     %sum2 = gpu.subgroup_reduce add %arg0 cluster(size = 8, stride = 1) : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum2) : (vector<1xf32>) -> ()
 
     // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform cluster(size = 8, stride = 4) : (f32) -> f32
     // CHECK-SUB: "test.consume"
-    // CHECK-GFX9-NOT: amdgpu.dpp
-    // CHECK-GFX10-NOT: amdgpu.dpp
+    // CHECK-GFX-NOT: amdgpu.dpp
     // CHECK-GFX10-NOT: rocdl.permlanex16
     %sum3 = gpu.subgroup_reduce add %arg0 uniform cluster(size = 8, stride = 4) : (vector<1xf32>) -> (vector<1xf32>)
     "test.consume"(%sum3) : (vector<1xf32>) -> ()
@@ -126,11 +120,8 @@ gpu.module @kernels {
   //
   // CHECK-SHFL-LABEL: gpu.func @kernel2(
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel2(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel2(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel2(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
     // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
     // CHECK-SUB: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
@@ -148,8 +139,7 @@ gpu.module @kernels {
 
   // CHECK-SHFL-LABEL: gpu.func @kernel3(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
-  // CHECK-GFX9-LABEL: gpu.func @kernel3(
-  // CHECK-GFX10-LABEL: gpu.func @kernel3(
+  // CHECK-GFX-LABEL: gpu.func @kernel3(
   gpu.func @kernel3(%arg0: i32) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -169,9 +159,9 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[S4:.+]], %{{.+}} = gpu.shuffle xor %[[A3]], %[[C16]], %[[C32]] : i32
     // CHECK-SHFL: %[[A4:.+]] = arith.addi %[[A3]], %[[S4]] : i32
     // CHECK-SHFL: "test.consume"(%[[A4]]) : (i32) -> ()
-    
+
     // CHECK-GFX9-COUNT-6: amdgpu.dpp
-    
+
     // CHECK-GFX10-COUNT-4: amdgpu.dpp
     // CHECK-GFX10: rocdl.permlanex16
     // CHECK-GFX10-COUNT-2: rocdl.readlane
@@ -185,11 +175,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel3_clustered(
-  // CHECK-GFX9-SAME:    %[[ARG0:.+]]: i32)
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel3_clustered(
-  // CHECK-GFX10-SAME:    %[[ARG0:.+]]: i32)
+  // CHECK-GFX-LABEL: gpu.func @kernel3_clustered(
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: i32)
   gpu.func @kernel3_clustered(%arg0: i32) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -204,19 +191,13 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
     // CHECK-SHFL: "test.consume"(%[[A2]]) : (i32) -> ()
 
-    // CHECK-GFX9: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX9: %[[A0:.+]] = arith.addi %[[ARG0]], %[[D0]] : i32
-    // CHECK-GFX9: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX9: %[[A1:.+]] = arith.addi %[[A0]], %[[D1]] : i32
-    // CHECK-GFX9: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : i32
-    // CHECK-GFX9: %[[A2:.+]] = arith.addi %[[A1]], %[[D2]] : i32
-
-    // CHECK-GFX10: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX10: %[[A0:.+]] = arith.addi %[[ARG0]], %[[D0]] : i32
-    // CHECK-GFX10: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i32
-    // CHECK-GFX10: %[[A1:.+]] = arith.addi %[[A0]], %[[D1]] : i32
-    // CHECK-GFX10: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : i32
-    // CHECK-GFX10: %[[A2:.+]] = arith.addi %[[A1]], %[[D2]] : i32
+    // CHECK-GFX: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i32
+    // CHECK-GFX: %[[A0:.+]] = arith.addi %[[ARG0]], %[[D0]] : i32
+    // CHECK-GFX: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i32
+    // CHECK-GFX: %[[A1:.+]] = arith.addi %[[A0]], %[[D1]] : i32
+    // CHECK-GFX: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : i32
+    // CHECK-GFX: %[[A2:.+]] = arith.addi %[[A1]], %[[D2]] : i32
+
     // CHECK-GFX10: "test.consume"(%[[A2]]) : (i32) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 8) : (i32) -> i32
     "test.consume"(%sum0) : (i32) -> ()
@@ -228,11 +209,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel3_clustered_strided(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i32)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel3_clustered_strided(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel3_clustered_strided(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel3_clustered_strided(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel3_clustered_strided(%arg0: i32) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 4 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 8 : i32
@@ -256,11 +234,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel4(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel4(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel4(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel4(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -298,11 +273,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel4_clustered(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<2xf16>)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel4_clustered(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel4_clustered(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel4_clustered(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel4_clustered(%arg0: vector<2xf16>) kernel {
     // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
     // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
@@ -319,10 +291,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel5(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel5(
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel5(
-  // CHECK-GFX10-SAME:    %[[ARG0:.+]]: i16)
+  // CHECK-GFX-LABEL: gpu.func @kernel5(
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: i16)
   gpu.func @kernel5(%arg0: i16) kernel {
     // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
     // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
@@ -334,7 +304,7 @@ gpu.module @kernels {
     // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
     // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
     // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
-    
+
     // CHECK-GFX9-COUNT-6: amdgpu.dpp
 
     // CHECK-GFX10: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
@@ -361,11 +331,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel5_clustered(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: i16)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel5_clustered
-  // CHECK-GFX9-SAME:    %[[ARG0:.+]]: i16)
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel5_clustered
-  // CHECK-GFX10-SAME:    %[[ARG0:.+]]: i16)
+  // CHECK-GFX-LABEL: gpu.func @kernel5_clustered
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: i16)
   gpu.func @kernel5_clustered(%arg0: i16) kernel {
     // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
     // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
@@ -378,25 +345,15 @@ gpu.module @kernels {
     // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
     // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
 
-    // CHECK-GFX9: %[[VAR0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR1:.+]] = arith.addi %[[ARG0]], %[[VAR0]] : i16
-    // CHECK-GFX9: %[[VAR2:.+]] = amdgpu.dpp %[[VAR1]] %[[VAR1]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR3:.+]] = arith.addi %[[VAR1]], %[[VAR2]] : i16
-    // CHECK-GFX9: %[[VAR4:.+]] = amdgpu.dpp %[[VAR3]] %[[VAR3]]  row_half_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR5:.+]] = arith.addi %[[VAR3]], %[[VAR4]] : i16
-    // CHECK-GFX9: %[[VAR6:.+]] = amdgpu.dpp %[[VAR5]] %[[VAR5]]  row_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX9: %[[VAR7:.+]] = arith.addi %[[VAR5]], %[[VAR6]] : i16
-    // CHECK-GFX9: "test.consume"(%[[VAR7]]) : (i16) -> ()
-
-    // CHECK-GFX10: %[[VAR0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR1:.+]] = arith.addi %[[ARG0]], %[[VAR0]] : i16
-    // CHECK-GFX10: %[[VAR2:.+]] = amdgpu.dpp %[[VAR1]] %[[VAR1]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR3:.+]] = arith.addi %[[VAR1]], %[[VAR2]] : i16
-    // CHECK-GFX10: %[[VAR4:.+]] = amdgpu.dpp %[[VAR3]] %[[VAR3]]  row_half_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR5:.+]] = arith.addi %[[VAR3]], %[[VAR4]] : i16
-    // CHECK-GFX10: %[[VAR6:.+]] = amdgpu.dpp %[[VAR5]] %[[VAR5]]  row_mirror(unit) {bound_ctrl = true} : i16
-    // CHECK-GFX10: %[[VAR7:.+]] = arith.addi %[[VAR5]], %[[VAR6]] : i16
-    // CHECK-GFX10: "test.consume"(%[[VAR7]]) : (i16) -> ()
+    // CHECK-GFX: %[[VAR0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR1:.+]] = arith.addi %[[ARG0]], %[[VAR0]] : i16
+    // CHECK-GFX: %[[VAR2:.+]] = amdgpu.dpp %[[VAR1]] %[[VAR1]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR3:.+]] = arith.addi %[[VAR1]], %[[VAR2]] : i16
+    // CHECK-GFX: %[[VAR4:.+]] = amdgpu.dpp %[[VAR3]] %[[VAR3]]  row_half_mirror(unit) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR5:.+]] = arith.addi %[[VAR3]], %[[VAR4]] : i16
+    // CHECK-GFX: %[[VAR6:.+]] = amdgpu.dpp %[[VAR5]] %[[VAR5]]  row_mirror(unit) {bound_ctrl = true} : i16
+    // CHECK-GFX: %[[VAR7:.+]] = arith.addi %[[VAR5]], %[[VAR6]] : i16
+    // CHECK-GFX: "test.consume"(%[[VAR7]]) : (i16) -> ()
     %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 16) : (i16) -> i16
     "test.consume"(%sum0) : (i16) -> ()
 
@@ -407,11 +364,8 @@ gpu.module @kernels {
   // CHECK-SHFL-LABEL: gpu.func @kernel6(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
   //
-  // CHECK-GFX9-LABEL: gpu.func @kernel6(
-  // CHECK-GFX9-NOT: amdgpu.dpp
-  //
-  // CHECK-GFX10-LABEL: gpu.func @kernel6(
-  // CHECK-GFX10-NOT: amdgpu.dpp
+  // CHECK-GFX-LABEL: gpu.func @kernel6(
+  // CHECK-GFX-NOT: amdgpu.dpp
   gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
     // CHECK-SHFL: %[[CZ:.+]] = arith.constant dense<0> : vector<4xi8>
     // CHECK-SHFL: %[[V0:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CZ]] {offsets = [0], strides = [1]} : vector<3xi8> into vector<4xi8>
@@ -433,6 +387,44 @@ gpu.module @kernels {
     gpu.return
   }
 
+  // CHECK-GFX-LABEL: gpu.func @kernel7(
+  // CHECK-GFX-SAME:    %[[ARG0:.+]]: f32)
+  //
+  //   Checks, common to gfx942 and gfx1030, of
+  //     (1) quad_perm, followed by reduction resulting in reduction over 2 consecutive lanes,
+  //     (2) quad_perm, followed by reduction resulting in reduction over 4 consecutive lanes,
+  //     (3) row_half_mirror, followed by reduction resulting in reduction over 8 consecutive lanes, and
+  //     (4) row_mirror, followed by reduction resulting in reduction over 16 consecutive lanes.
+  // CHECK-GFX: %[[D0:.+]] = amdgpu.dpp %[[ARG0]] %[[ARG0]]  quad_perm([1 : i32, 0 : i32, 3 : i32, 2 : i32]) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A0:.+]] = arith.addf %[[ARG0]], %[[D0]] : f32
+  // CHECK-GFX: %[[D1:.+]] = amdgpu.dpp %[[A0]] %[[A0]]  quad_perm([2 : i32, 3 : i32, 0 : i32, 1 : i32]) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A1:.+]] = arith.addf %[[A0]], %[[D1]] : f32
+  // CHECK-GFX: %[[D2:.+]] = amdgpu.dpp %[[A1]] %[[A1]]  row_half_mirror(unit) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A2:.+]] = arith.addf %[[A1]], %[[D2]] : f32
+  // CHECK-GFX: %[[D3:.+]] = amdgpu.dpp %[[A2]] %[[A2]]  row_mirror(unit) {bound_ctrl = true} : f32
+  // CHECK-GFX: %[[A3:.+]] = arith.addf %[[A2]], %[[D3]] : f32
+  //
+  //   Now, on gfx942:
+  //     (1) Lane 15 gets broadcast to lanes [16, 32) and lane 31 gets broadcast to lanes [48, 64], after which
+  //         the reduction in lanes [16, 32) is over the full cluster of the first 32 lanes, and the reduction in lanes
+  //         [48, 64) is over the full cluster of the last 32 lanes.
+  //     (2) Update the reduction value in lanes [0, 16) and [32, 48) with the final reduction result from
+  //         lanes [16, 32) and [48, 64), respectively.
+  // CHECK-GFX9: %[[BCAST15:.+]] = amdgpu.dpp %[[A3]] %[[A3]]  row_bcast_15(unit) {row_mask = 10 : i32} : f32
+  // CHECK-GFX9: %[[SUM:.+]] = arith.addf %[[A3]], %[[BCAST15]] : f32
+  // CHECK-GFX9: %[[SWIZ:.+]] = amdgpu.swizzle_bitmode %[[SUM]] 0 31 0 : f32
+  // CHECK-GFX9: "test.consume"(%[[SWIZ]]) : (f32) -> ()
+  //
+  //   On gfx1030, the final step is to permute the lanes and perform final reduction:
+  // CHECK-GFX10: rocdl.permlanex16
+  // CHECK-GFX10: arith.addf
+  // CHECK-GFX10: "test.consume"
+   gpu.func @kernel7(%arg0: f32) kernel {
+     %sum0 = gpu.subgroup_reduce add %arg0 cluster(size = 32) : (f32) -> (f32)
+     "test.consume"(%sum0) : (f32) -> ()
+     gpu.return
+   }
+
   // CHECK-SHFL-LABEL: gpu.func @kernel_cluster_size_is_subgroup_size(
   // CHECK-SHFL-SAME:    %[[ARG0:.+]]: vector<3xi8>)
   //

@newling
Copy link
Contributor Author

newling commented Oct 30, 2025

I need to verify but this I think resolves iree-org/iree#22397

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Thanks!

Copy link
Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks LGTM

// swizzle.
if (ci.subgroupSize == 64 && ci.clusterSize == 32) {
res =
amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and_mask=*/0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you benchmark this vs

resLo = readlane(res, 31)
resHi = readlane(res, 63)
res = select (laneId < 32), resLo, resHi

?

I suspect the latter may be desirable since it doesn't go into the crossbar.

Copy link
Member

@Groverkss Groverkss Oct 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a bit hard to benchmark. The PR fixes a real correctness issue. Let's not block on something that might be slightly more performant. We can do this as a follow up.

// RUN: mlir-opt --allow-unregistered-dialect \
// RUN: --test-gpu-subgroup-reduce-lowering="expand-to-shuffles target=gfx1030" %s \
// RUN: | FileCheck %s --check-prefix=CHECK-GFX10
// RUN: | FileCheck %s --check-prefixes=CHECK-GFX,CHECK-GFX10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIce cleanup!

@newling newling merged commit 0928f46 into llvm:main Oct 31, 2025
10 checks passed
@newling
Copy link
Contributor Author

newling commented Oct 31, 2025

@krzysz00 I'll follow this up with your suggestion where we can discuss it in more detail (in flight)

DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025
…lvm#165764)

This is a fix for a cluster size of 32 when the subgroup size is 64.
Previously, only lanes [16, 32) u [48, 64) contained the correct
clusterwise reduction value. This PR adds a swizzle instruction to
broadcast the correct value down to lanes [0, 16) u [32, 48).
Groverkss added a commit to iree-org/iree that referenced this pull request Nov 3, 2025
#22521)

… to previous state (#22392)"

This reverts commit 4a716e2.

The underlying issue was fixed by
llvm/llvm-project#165764 . Thanks to @newling
for figuring out this tricky issue.

Fixes: #22397

ci-extra: test_torch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants