
[mlir][LLVM] Let decomposeValue/composeValue handle aggregates#183405

Merged
krzysz00 merged 2 commits into llvm:main from
krzysz00:gpu-subgroup-broadcast-lower-structs
Feb 27, 2026

Conversation

@krzysz00
Contributor

This commit updates the LLVM::decomposeValue and LLVM::composeValue methods to handle aggregate types (LLVM arrays and structs) and to support different behaviors for types like pointers that can't be bitcast to fixed-size integers. This makes the "any type" on gpu.subgroup_broadcast more comprehensive: you can broadcast a memref to a subgroup by decomposing it, for example.

(This branched off of getting an LLM to implement ValueBoundsOpInterface on subgroup_broadcast, having it add handling for the dimensions of shaped types, and realizing that there's no fundamental reason you can't broadcast a memref or the like.)
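To make the recursion scheme concrete, here is a small Python model of the decompose/compose round trip. This is illustrative only: the real patch builds MLIR ops (llvm.extractvalue, llvm.bitcast, llvm.insertvalue, ...) rather than manipulating Python values, and the names `decompose`/`compose`, the tuple-as-aggregate encoding, and the 64-bit-leaf assumption are this sketch's, not the patch's.

```python
WORD_BITS = 32
LEAF_BITS = 64  # assume every fixed-width leaf is an i64 for this sketch

def decompose(value, out):
    """Flatten `value` into WORD_BITS-sized pieces, recursing through
    tuples (standing in for LLVM structs/arrays). Leaves without a fixed
    bit width (modeled here as strings, like !llvm.ptr) pass through
    unchanged, mirroring permitVariablySizedScalars=true."""
    if isinstance(value, tuple):          # aggregate: recurse per field
        for field in value:
            decompose(field, out)
    elif isinstance(value, int):          # fixed-width leaf: split into words
        for i in range(LEAF_BITS // WORD_BITS):
            out.append((value >> (i * WORD_BITS)) & ((1 << WORD_BITS) - 1))
    else:                                 # variably sized leaf: pass through
        out.append(value)

def compose(parts, template, pos=0):
    """Inverse of decompose: consume `parts` left to right, guided by the
    shape of `template`. Returns (rebuilt_value, next_pos), mirroring how
    composeValueImpl threads an `offset` through the recursion."""
    if isinstance(template, tuple):
        fields = []
        for field in template:
            rebuilt, pos = compose(parts, field, pos)
            fields.append(rebuilt)
        return tuple(fields), pos
    if isinstance(template, int):         # reassemble words into one leaf
        value = 0
        for i in range(LEAF_BITS // WORD_BITS):
            value |= parts[pos] << (i * WORD_BITS)
            pos += 1
        return value, pos
    return parts[pos], pos + 1            # variably sized leaf: consumed as-is
```

For instance, a rank-1 memref descriptor modeled as `("p0", "p1", 7, (11,), (13,))` (two pointers, offset, sizes array, strides array) decomposes into the two pointers plus six 32-bit words, and composing those parts against the same shape reproduces the original value with every part consumed.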

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
@llvmbot
Member

llvmbot commented Feb 25, 2026

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes

This commit updates the LLVM::decomposeValue and LLVM::composeValue methods to handle aggregate types (LLVM arrays and structs) and to support different behaviors for types like pointers that can't be bitcast to fixed-size integers. This makes the "any type" on gpu.subgroup_broadcast more comprehensive: you can broadcast a memref to a subgroup by decomposing it, for example.

(This branched off of getting an LLM to implement ValueBoundsOpInterface on subgroup_broadcast, having it add handling for the dimensions of shaped types, and realizing that there's no fundamental reason you can't broadcast a memref or the like.)


Full diff: https://github.com/llvm/llvm-project/pull/183405.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+10-4)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+8-4)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+10-4)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+135-33)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+69)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index cacd500d41291..562ce48e23f26 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -66,10 +66,16 @@ bool opHasUnsupportedFloatingPointTypes(Operation *op,
 } // namespace detail
 
 /// Decomposes a `src` value into a set of values of type `dstType` through
-/// series of bitcasts and vector ops. Src and dst types are expected to be int
-/// or float types or vector types of them.
-SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
-                                  Type dstType);
+/// series of bitcasts and vector ops. Handles int, float, vector types as well
+/// as LLVM aggregate types (LLVMArrayType, LLVMStructType) by recursively
+/// extracting elements.
+///
+/// When `permitVariablySizedScalars` is true, leaf types that have no fixed
+/// bit width (e.g., `!llvm.ptr`) are passed through as-is (1 element in
+/// result). When false (default), encountering such a type returns failure.
+LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src,
+                             Type dstType, SmallVectorImpl<Value> &result,
+                             bool permitVariablySizedScalars = false);
 
 /// Composes a set of `src` values into a single value of type `dstType` through
 /// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3c2c61b2426e9..379d6180596e9 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2530,8 +2530,10 @@ struct AMDGPUSwizzleBitModeLowering
     Location loc = op.getLoc();
     Type i32 = rewriter.getI32Type();
     Value src = adaptor.getSrc();
-    SmallVector<Value> decomposed =
-        LLVM::decomposeValue(rewriter, loc, src, i32);
+    SmallVector<Value> decomposed;
+    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to decompose value to i32");
     unsigned andMask = op.getAndMask();
     unsigned orMask = op.getOrMask();
     unsigned xorMask = op.getXorMask();
@@ -2573,8 +2575,10 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
     bool fi = op.getFetchInactive();
     bool boundctrl = op.getBoundCtrl();
 
-    SmallVector<Value> decomposed =
-        LLVM::decomposeValue(rewriter, loc, src, i32);
+    SmallVector<Value> decomposed;
+    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to decompose value to i32");
 
     SmallVector<Value> permuted;
     for (Value v : decomposed) {
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 096554d53e031..c5f0aedb33143 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -271,8 +271,11 @@ struct GPUSubgroupBroadcastOpToROCDL
 
     Type i32 = rewriter.getI32Type();
     Location loc = op.getLoc();
-    SmallVector<Value> decomposed =
-        LLVM::decomposeValue(rewriter, loc, src, i32);
+    SmallVector<Value> decomposed;
+    if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed,
+                                    /*permitVariablySizedScalars=*/true)))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to decompose value to i32");
 
     SmallVector<Value> results;
     results.reserve(decomposed.size());
@@ -359,8 +362,11 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Value dwordAlignedDstLane =
         LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
 
-    SmallVector<Value> decomposed =
-        LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
+    SmallVector<Value> decomposed;
+    if (failed(LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type,
+                                    decomposed)))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to decompose value to i32");
     SmallVector<Value> swizzled;
     for (Value v : decomposed) {
       Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 640ff3d7c3c7d..38f71b916b5b8 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -384,23 +384,68 @@ static unsigned getBitWidth(Type type) {
   return vec.getNumElements() * getBitWidth(vec.getElementType());
 }
 
+/// Returns true if every leaf in `type` (recursing through LLVM arrays and
+/// structs) is either equal to `dstType` or has a fixed bit width.
+static bool isFixedSizeAggregate(Type type, Type dstType) {
+  if (type == dstType)
+    return true;
+  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
+    return isFixedSizeAggregate(arrayType.getElementType(), dstType);
+  if (auto structType = dyn_cast<LLVM::LLVMStructType>(type))
+    return llvm::all_of(structType.getBody(), [&](Type fieldType) {
+      return isFixedSizeAggregate(fieldType, dstType);
+    });
+  if (auto vecTy = dyn_cast<VectorType>(type))
+    return !vecTy.isScalable();
+  return type.isIntOrFloat();
+}
+
 static Value createI32Constant(OpBuilder &builder, Location loc,
                                int32_t value) {
   Type i32 = builder.getI32Type();
   return LLVM::ConstantOp::create(builder, loc, i32, value);
 }
 
-SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
-                                              Value src, Type dstType) {
+/// Recursive implementation of decomposeValue. When
+/// `permitVariablySizedScalars` is false, callers must ensure
+/// isFixedSizeAggregate() holds before calling this.
+static void decomposeValueImpl(OpBuilder &builder, Location loc, Value src,
+                               Type dstType, SmallVectorImpl<Value> &result) {
   Type srcType = src.getType();
-  if (srcType == dstType)
-    return {src};
+  if (srcType == dstType) {
+    result.push_back(src);
+    return;
+  }
+
+  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(srcType)) {
+    for (auto i : llvm::seq(arrayType.getNumElements())) {
+      Value elem = LLVM::ExtractValueOp::create(builder, loc, src, i);
+      decomposeValueImpl(builder, loc, elem, dstType, result);
+    }
+    return;
+  }
+
+  if (auto structType = dyn_cast<LLVM::LLVMStructType>(srcType)) {
+    for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
+      Value field = LLVM::ExtractValueOp::create(builder, loc, src,
+                                                 static_cast<int64_t>(i));
+      decomposeValueImpl(builder, loc, field, dstType, result);
+    }
+    return;
+  }
+
+  // Variably sized leaf types (e.g., ptr) — pass through as-is.
+  if (!srcType.isIntOrFloat() && !isa<VectorType>(srcType)) {
+    result.push_back(src);
+    return;
+  }
 
   unsigned srcBitWidth = getBitWidth(srcType);
   unsigned dstBitWidth = getBitWidth(dstType);
   if (srcBitWidth == dstBitWidth) {
     Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
-    return {cast};
+    result.push_back(cast);
+    return;
   }
 
   if (dstBitWidth > srcBitWidth) {
@@ -410,7 +455,8 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
 
     auto largerInt = builder.getIntegerType(dstBitWidth);
     Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
-    return {res};
+    result.push_back(res);
+    return;
   }
   assert(srcBitWidth % dstBitWidth == 0 &&
          "src bit width must be a multiple of dst bit width");
@@ -419,47 +465,94 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
 
   src = LLVM::BitcastOp::create(builder, loc, vecType, src);
 
-  SmallVector<Value> res;
   for (auto i : llvm::seq(numElements)) {
     Value idx = createI32Constant(builder, loc, i);
     Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
-    res.emplace_back(elem);
+    result.push_back(elem);
   }
-
-  return res;
 }
 
-Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
-                               Type dstType) {
-  assert(!src.empty() && "src range must not be empty");
-  if (src.size() == 1) {
-    Value res = src.front();
-    if (res.getType() == dstType)
-      return res;
+LogicalResult mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+                                         Value src, Type dstType,
+                                         SmallVectorImpl<Value> &result,
+                                         bool permitVariablySizedScalars) {
+  // Check the type tree before emitting any IR, so that a failing pattern
+  // leaves the IR unmodified.
+  if (!permitVariablySizedScalars &&
+      !isFixedSizeAggregate(src.getType(), dstType))
+    return failure();
 
-    unsigned srcBitWidth = getBitWidth(res.getType());
-    unsigned dstBitWidth = getBitWidth(dstType);
-    if (dstBitWidth < srcBitWidth) {
-      auto largerInt = builder.getIntegerType(srcBitWidth);
-      if (res.getType() != largerInt)
-        res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
+  decomposeValueImpl(builder, loc, src, dstType, result);
+  return success();
+}
+
+/// Recursive implementation of composeValue. Consumes elements from `src`
+/// starting at `offset`, advancing it past the consumed elements.
+static Value composeValueImpl(OpBuilder &builder, Location loc, ValueRange src,
+                              size_t &offset, Type dstType) {
+  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(dstType)) {
+    Value result = LLVM::PoisonOp::create(builder, loc, arrayType);
+    Type elemType = arrayType.getElementType();
+    for (auto i : llvm::seq(arrayType.getNumElements())) {
+      Value elem = composeValueImpl(builder, loc, src, offset, elemType);
+      result = LLVM::InsertValueOp::create(builder, loc, result, elem, i);
+    }
+    return result;
+  }
 
-      auto smallerInt = builder.getIntegerType(dstBitWidth);
-      res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
+  if (auto structType = dyn_cast<LLVM::LLVMStructType>(dstType)) {
+    Value result = LLVM::PoisonOp::create(builder, loc, structType);
+    for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
+      Value field = composeValueImpl(builder, loc, src, offset, fieldType);
+      result = LLVM::InsertValueOp::create(builder, loc, result, field,
+                                           static_cast<int64_t>(i));
     }
+    return result;
+  }
+
+  // Variably sized leaf types (e.g., ptr) — consume and return as-is.
+  if (!dstType.isIntOrFloat() && !isa<VectorType>(dstType))
+    return src[offset++];
+
+  unsigned dstBitWidth = getBitWidth(dstType);
 
-    if (res.getType() != dstType)
-      res = LLVM::BitcastOp::create(builder, loc, dstType, res);
+  Value front = src[offset];
+  if (front.getType() == dstType) {
+    ++offset;
+    return front;
+  }
 
-    return res;
+  // Single element wider than or equal to dst: bitcast/trunc.
+  if (front.getType().isIntOrFloat() || isa<VectorType>(front.getType())) {
+    unsigned srcBitWidth = getBitWidth(front.getType());
+    if (srcBitWidth >= dstBitWidth) {
+      ++offset;
+      Value res = front;
+      if (dstBitWidth < srcBitWidth) {
+        auto largerInt = builder.getIntegerType(srcBitWidth);
+        if (res.getType() != largerInt)
+          res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
+
+        auto smallerInt = builder.getIntegerType(dstBitWidth);
+        res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
+      }
+      if (res.getType() != dstType)
+        res = LLVM::BitcastOp::create(builder, loc, dstType, res);
+      return res;
+    }
   }
 
-  int64_t numElements = src.size();
-  auto srcType = VectorType::get(numElements, src.front().getType());
-  Value res = LLVM::PoisonOp::create(builder, loc, srcType);
-  for (auto &&[i, elem] : llvm::enumerate(src)) {
+  // Multiple elements narrower than dst: gather into a vector and bitcast.
+  unsigned elemBitWidth = getBitWidth(front.getType());
+  assert(dstBitWidth % elemBitWidth == 0 &&
+         "dst bit width must be a multiple of element bit width");
+  int64_t numElements = dstBitWidth / elemBitWidth;
+  auto vecType = VectorType::get(numElements, front.getType());
+  Value res = LLVM::PoisonOp::create(builder, loc, vecType);
+  for (auto i : llvm::seq(numElements)) {
     Value idx = createI32Constant(builder, loc, i);
-    res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
+    res = LLVM::InsertElementOp::create(builder, loc, vecType, res,
+                                        src[offset++], idx);
   }
 
   if (res.getType() != dstType)
@@ -468,6 +561,15 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
   return res;
 }
 
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                               Type dstType) {
+  assert(!src.empty() && "src range must not be empty");
+  size_t offset = 0;
+  Value result = composeValueImpl(builder, loc, src, offset, dstType);
+  assert(offset == src.size() && "not all decomposed values were consumed");
+  return result;
+}
+
 Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
                                        const LLVMTypeConverter &converter,
                                        MemRefType type, Value memRefDesc,
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 6b20a442c30c5..1f958791b41bb 100755
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -851,4 +851,73 @@ func.func @broadcast_4xi16(%arg0 : vector<4xi16>) -> vector<4xi16> {
   %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<4xi16>
   func.return %0 : vector<4xi16>
 }
+
+// CHECK-LABEL: func @broadcast_2x2xi16
+//  CHECK-SAME:   (%[[ARG:.+]]: !llvm.array<2 x vector<2xi16>>)
+func.func @broadcast_2x2xi16(%arg0 : vector<2x2xi16>) -> vector<2x2xi16> {
+  // CHECK-DAG: %[[E0:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.array<2 x vector<2xi16>>
+  // CHECK-DAG: %[[BC0:.+]] = llvm.bitcast %[[E0]] : vector<2xi16> to i32
+  // CHECK-DAG: %[[E1:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.array<2 x vector<2xi16>>
+  // CHECK-DAG: %[[BC1:.+]] = llvm.bitcast %[[E1]] : vector<2xi16> to i32
+  // CHECK: %[[R0:.+]] = rocdl.readfirstlane %[[BC0]] : i32
+  // CHECK: %[[R1:.+]] = rocdl.readfirstlane %[[BC1]] : i32
+  // CHECK: %[[POISON:.+]] = llvm.mlir.poison : !llvm.array<2 x vector<2xi16>>
+  // CHECK: %[[RBC0:.+]] = llvm.bitcast %[[R0]] : i32 to vector<2xi16>
+  // CHECK: %[[INS0:.+]] = llvm.insertvalue %[[RBC0]], %[[POISON]][0] : !llvm.array<2 x vector<2xi16>>
+  // CHECK: %[[RBC1:.+]] = llvm.bitcast %[[R1]] : i32 to vector<2xi16>
+  // CHECK: %[[INS1:.+]] = llvm.insertvalue %[[RBC1]], %[[INS0]][1] : !llvm.array<2 x vector<2xi16>>
+  // CHECK: llvm.return %[[INS1]]
+  %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<2x2xi16>
+  func.return %0 : vector<2x2xi16>
+}
+
+// CHECK-LABEL: func @broadcast_memref
+func.func @broadcast_memref(%arg0 : memref<?xi32>) -> memref<?xi32> {
+  // CHECK-DAG: %[[PTR0:.+]] = llvm.extractvalue %{{.+}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[PTR1:.+]] = llvm.extractvalue %{{.+}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[OFF:.+]] = llvm.extractvalue %{{.+}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[OFFVEC:.+]] = llvm.bitcast %[[OFF]] : i64 to vector<2xi32>
+  // CHECK-DAG: %[[OFF0:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[OFF1:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[SZ_ARR:.+]] = llvm.extractvalue %{{.+}}[3] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[SZ:.+]] = llvm.extractvalue %[[SZ_ARR]][0] : !llvm.array<1 x i64>
+  // CHECK-DAG: %[[SZVEC:.+]] = llvm.bitcast %[[SZ]] : i64 to vector<2xi32>
+  // CHECK-DAG: %[[SZ0:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[SZ1:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[STR_ARR:.+]] = llvm.extractvalue %{{.+}}[4] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[STR:.+]] = llvm.extractvalue %[[STR_ARR]][0] : !llvm.array<1 x i64>
+  // CHECK-DAG: %[[STRVEC:.+]] = llvm.bitcast %[[STR]] : i64 to vector<2xi32>
+  // CHECK-DAG: %[[STR0:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[STR1:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+  //
+  // CHECK: %[[RPTR0:.+]] = rocdl.readfirstlane %[[PTR0]] : !llvm.ptr
+  // CHECK: %[[RPTR1:.+]] = rocdl.readfirstlane %[[PTR1]] : !llvm.ptr
+  // CHECK: %[[ROFF0:.+]] = rocdl.readfirstlane %[[OFF0]] : i32
+  // CHECK: %[[ROFF1:.+]] = rocdl.readfirstlane %[[OFF1]] : i32
+  // CHECK: %[[RSZ0:.+]] = rocdl.readfirstlane %[[SZ0]] : i32
+  // CHECK: %[[RSZ1:.+]] = rocdl.readfirstlane %[[SZ1]] : i32
+  // CHECK: %[[RSTR0:.+]] = rocdl.readfirstlane %[[STR0]] : i32
+  // CHECK: %[[RSTR1:.+]] = rocdl.readfirstlane %[[STR1]] : i32
+  //
+  // CHECK: %[[SOUT:.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[S0:.+]] = llvm.insertvalue %[[RPTR0]], %[[SOUT]][0]
+  // CHECK: %[[S1:.+]] = llvm.insertvalue %[[RPTR1]], %[[S0]][1]
+  // CHECK: %[[OVEC:.+]] = llvm.insertelement %[[ROFF0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[OVEC2:.+]] = llvm.insertelement %[[ROFF1]], %[[OVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[OCAST:.+]] = llvm.bitcast %[[OVEC2]] : vector<2xi32> to i64
+  // CHECK: %[[S2:.+]] = llvm.insertvalue %[[OCAST]], %[[S1]][2]
+  // CHECK: %[[SZVEC2:.+]] = llvm.insertelement %[[RSZ0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[SZVEC3:.+]] = llvm.insertelement %[[RSZ1]], %[[SZVEC2]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[SZCAST:.+]] = llvm.bitcast %[[SZVEC3]] : vector<2xi32> to i64
+  // CHECK: %[[SZA:.+]] = llvm.insertvalue %[[SZCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+  // CHECK: %[[S3:.+]] = llvm.insertvalue %[[SZA]], %[[S2]][3]
+  // CHECK: %[[STRVEC2:.+]] = llvm.insertelement %[[RSTR0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[STRVEC3:.+]] = llvm.insertelement %[[RSTR1]], %[[STRVEC2]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[STRCAST:.+]] = llvm.bitcast %[[STRVEC3]] : vector<2xi32> to i64
+  // CHECK: %[[STR_ARR:.+]] = llvm.insertvalue %[[STRCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+  // CHECK: %[[S4:.+]] = llvm.insertvalue %[[STR_ARR]], %[[S3]][4]
+  // CHECK: llvm.return %[[S4]]
+  %0 = gpu.subgroup_broadcast %arg0, first_active_lane : memref<?xi32>
+  func.return %0 : memref<?xi32>
+}
 }

@llvmbot
Member

llvmbot commented Feb 25, 2026

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)
+      result = LLVM::InsertValueOp::create(builder, loc, result, elem, i);
+    }
+    return result;
+  }
 
-      auto smallerInt = builder.getIntegerType(dstBitWidth);
-      res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
+  if (auto structType = dyn_cast<LLVM::LLVMStructType>(dstType)) {
+    Value result = LLVM::PoisonOp::create(builder, loc, structType);
+    for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
+      Value field = composeValueImpl(builder, loc, src, offset, fieldType);
+      result = LLVM::InsertValueOp::create(builder, loc, result, field,
+                                           static_cast<int64_t>(i));
     }
+    return result;
+  }
+
+  // Variably sized leaf types (e.g., ptr) — consume and return as-is.
+  if (!dstType.isIntOrFloat() && !isa<VectorType>(dstType))
+    return src[offset++];
+
+  unsigned dstBitWidth = getBitWidth(dstType);
 
-    if (res.getType() != dstType)
-      res = LLVM::BitcastOp::create(builder, loc, dstType, res);
+  Value front = src[offset];
+  if (front.getType() == dstType) {
+    ++offset;
+    return front;
+  }
 
-    return res;
+  // Single element wider than or equal to dst: bitcast/trunc.
+  if (front.getType().isIntOrFloat() || isa<VectorType>(front.getType())) {
+    unsigned srcBitWidth = getBitWidth(front.getType());
+    if (srcBitWidth >= dstBitWidth) {
+      ++offset;
+      Value res = front;
+      if (dstBitWidth < srcBitWidth) {
+        auto largerInt = builder.getIntegerType(srcBitWidth);
+        if (res.getType() != largerInt)
+          res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
+
+        auto smallerInt = builder.getIntegerType(dstBitWidth);
+        res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
+      }
+      if (res.getType() != dstType)
+        res = LLVM::BitcastOp::create(builder, loc, dstType, res);
+      return res;
+    }
   }
 
-  int64_t numElements = src.size();
-  auto srcType = VectorType::get(numElements, src.front().getType());
-  Value res = LLVM::PoisonOp::create(builder, loc, srcType);
-  for (auto &&[i, elem] : llvm::enumerate(src)) {
+  // Multiple elements narrower than dst: gather into a vector and bitcast.
+  unsigned elemBitWidth = getBitWidth(front.getType());
+  assert(dstBitWidth % elemBitWidth == 0 &&
+         "dst bit width must be a multiple of element bit width");
+  int64_t numElements = dstBitWidth / elemBitWidth;
+  auto vecType = VectorType::get(numElements, front.getType());
+  Value res = LLVM::PoisonOp::create(builder, loc, vecType);
+  for (auto i : llvm::seq(numElements)) {
     Value idx = createI32Constant(builder, loc, i);
-    res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
+    res = LLVM::InsertElementOp::create(builder, loc, vecType, res,
+                                        src[offset++], idx);
   }
 
   if (res.getType() != dstType)
@@ -468,6 +561,15 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
   return res;
 }
 
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+                               Type dstType) {
+  assert(!src.empty() && "src range must not be empty");
+  size_t offset = 0;
+  Value result = composeValueImpl(builder, loc, src, offset, dstType);
+  assert(offset == src.size() && "not all decomposed values were consumed");
+  return result;
+}
+
 Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
                                        const LLVMTypeConverter &converter,
                                        MemRefType type, Value memRefDesc,
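The recursive decompose/compose pair above can be modeled outside MLIR. The sketch below is plain Python with a hypothetical type encoding (the real code emits `llvm.extractvalue`/`llvm.insertvalue` ops, not lists); it mirrors the two key invariants: decomposition flattens an aggregate depth-first into leaf values, and composition consumes that flat list in the same order via a running offset, so a round trip rebuilds the original value and consumes every element.

```python
# Conceptual model of decomposeValueImpl / composeValueImpl (hypothetical
# names and type encoding; the real code builds MLIR ops, not Python lists).

def decompose(value, ty, out):
    """Depth-first flatten `value` (shaped like `ty`) into leaf values."""
    kind, children = ty
    if kind in ("struct", "array"):
        for elem, elem_ty in zip(value, children):
            decompose(elem, elem_ty, out)
    else:
        out.append(value)  # leaf (int/float/vector/ptr): pass through

def compose(flat, ty, offset=0):
    """Rebuild an aggregate, consuming `flat` starting at `offset`.

    Returns (value, new_offset), mirroring the `size_t &offset` parameter."""
    kind, children = ty
    if kind in ("struct", "array"):
        parts = []
        for elem_ty in children:
            part, offset = compose(flat, elem_ty, offset)
            parts.append(part)
        return parts, offset
    return flat[offset], offset + 1

# A rank-1 memref descriptor: (ptr, ptr, i64, array<1 x i64>, array<1 x i64>).
i64 = ("i64", None)
memref_ty = ("struct", [("ptr", None), ("ptr", None), i64,
                        ("array", [i64]), ("array", [i64])])
desc = ["base", "aligned", 0, [16], [1]]

flat = []
decompose(desc, memref_ty, flat)
rebuilt, consumed = compose(flat, memref_ty)
```

Here `flat` ends up as the five leaf values in declaration order, and the final assertion in `composeValue` corresponds to `consumed == len(flat)`.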
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 6b20a442c30c5..1f958791b41bb 100755
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -851,4 +851,73 @@ func.func @broadcast_4xi16(%arg0 : vector<4xi16>) -> vector<4xi16> {
   %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<4xi16>
   func.return %0 : vector<4xi16>
 }
+
+// CHECK-LABEL: func @broadcast_2x2xi16
+//  CHECK-SAME:   (%[[ARG:.+]]: !llvm.array<2 x vector<2xi16>>)
+func.func @broadcast_2x2xi16(%arg0 : vector<2x2xi16>) -> vector<2x2xi16> {
+  // CHECK-DAG: %[[E0:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.array<2 x vector<2xi16>>
+  // CHECK-DAG: %[[BC0:.+]] = llvm.bitcast %[[E0]] : vector<2xi16> to i32
+  // CHECK-DAG: %[[E1:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.array<2 x vector<2xi16>>
+  // CHECK-DAG: %[[BC1:.+]] = llvm.bitcast %[[E1]] : vector<2xi16> to i32
+  // CHECK: %[[R0:.+]] = rocdl.readfirstlane %[[BC0]] : i32
+  // CHECK: %[[R1:.+]] = rocdl.readfirstlane %[[BC1]] : i32
+  // CHECK: %[[POISON:.+]] = llvm.mlir.poison : !llvm.array<2 x vector<2xi16>>
+  // CHECK: %[[RBC0:.+]] = llvm.bitcast %[[R0]] : i32 to vector<2xi16>
+  // CHECK: %[[INS0:.+]] = llvm.insertvalue %[[RBC0]], %[[POISON]][0] : !llvm.array<2 x vector<2xi16>>
+  // CHECK: %[[RBC1:.+]] = llvm.bitcast %[[R1]] : i32 to vector<2xi16>
+  // CHECK: %[[INS1:.+]] = llvm.insertvalue %[[RBC1]], %[[INS0]][1] : !llvm.array<2 x vector<2xi16>>
+  // CHECK: llvm.return %[[INS1]]
+  %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<2x2xi16>
+  func.return %0 : vector<2x2xi16>
+}
+
+// CHECK-LABEL: func @broadcast_memref
+func.func @broadcast_memref(%arg0 : memref<?xi32>) -> memref<?xi32> {
+  // CHECK-DAG: %[[PTR0:.+]] = llvm.extractvalue %{{.+}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[PTR1:.+]] = llvm.extractvalue %{{.+}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[OFF:.+]] = llvm.extractvalue %{{.+}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[OFFVEC:.+]] = llvm.bitcast %[[OFF]] : i64 to vector<2xi32>
+  // CHECK-DAG: %[[OFF0:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[OFF1:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[SZ_ARR:.+]] = llvm.extractvalue %{{.+}}[3] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[SZ:.+]] = llvm.extractvalue %[[SZ_ARR]][0] : !llvm.array<1 x i64>
+  // CHECK-DAG: %[[SZVEC:.+]] = llvm.bitcast %[[SZ]] : i64 to vector<2xi32>
+  // CHECK-DAG: %[[SZ0:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[SZ1:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[STR_ARR:.+]] = llvm.extractvalue %{{.+}}[4] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-DAG: %[[STR:.+]] = llvm.extractvalue %[[STR_ARR]][0] : !llvm.array<1 x i64>
+  // CHECK-DAG: %[[STRVEC:.+]] = llvm.bitcast %[[STR]] : i64 to vector<2xi32>
+  // CHECK-DAG: %[[STR0:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK-DAG: %[[STR1:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+  //
+  // CHECK: %[[RPTR0:.+]] = rocdl.readfirstlane %[[PTR0]] : !llvm.ptr
+  // CHECK: %[[RPTR1:.+]] = rocdl.readfirstlane %[[PTR1]] : !llvm.ptr
+  // CHECK: %[[ROFF0:.+]] = rocdl.readfirstlane %[[OFF0]] : i32
+  // CHECK: %[[ROFF1:.+]] = rocdl.readfirstlane %[[OFF1]] : i32
+  // CHECK: %[[RSZ0:.+]] = rocdl.readfirstlane %[[SZ0]] : i32
+  // CHECK: %[[RSZ1:.+]] = rocdl.readfirstlane %[[SZ1]] : i32
+  // CHECK: %[[RSTR0:.+]] = rocdl.readfirstlane %[[STR0]] : i32
+  // CHECK: %[[RSTR1:.+]] = rocdl.readfirstlane %[[STR1]] : i32
+  //
+  // CHECK: %[[SOUT:.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[S0:.+]] = llvm.insertvalue %[[RPTR0]], %[[SOUT]][0]
+  // CHECK: %[[S1:.+]] = llvm.insertvalue %[[RPTR1]], %[[S0]][1]
+  // CHECK: %[[OVEC:.+]] = llvm.insertelement %[[ROFF0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[OVEC2:.+]] = llvm.insertelement %[[ROFF1]], %[[OVEC]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[OCAST:.+]] = llvm.bitcast %[[OVEC2]] : vector<2xi32> to i64
+  // CHECK: %[[S2:.+]] = llvm.insertvalue %[[OCAST]], %[[S1]][2]
+  // CHECK: %[[SZVEC2:.+]] = llvm.insertelement %[[RSZ0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[SZVEC3:.+]] = llvm.insertelement %[[RSZ1]], %[[SZVEC2]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[SZCAST:.+]] = llvm.bitcast %[[SZVEC3]] : vector<2xi32> to i64
+  // CHECK: %[[SZA:.+]] = llvm.insertvalue %[[SZCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+  // CHECK: %[[S3:.+]] = llvm.insertvalue %[[SZA]], %[[S2]][3]
+  // CHECK: %[[STRVEC2:.+]] = llvm.insertelement %[[RSTR0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[STRVEC3:.+]] = llvm.insertelement %[[RSTR1]], %[[STRVEC2]][%{{.+}} : i32] : vector<2xi32>
+  // CHECK: %[[STRCAST:.+]] = llvm.bitcast %[[STRVEC3]] : vector<2xi32> to i64
+  // CHECK: %[[STR_ARR:.+]] = llvm.insertvalue %[[STRCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+  // CHECK: %[[S4:.+]] = llvm.insertvalue %[[STR_ARR]], %[[S3]][4]
+  // CHECK: llvm.return %[[S4]]
+  %0 = gpu.subgroup_broadcast %arg0, first_active_lane : memref<?xi32>
+  func.return %0 : memref<?xi32>
+}
 }
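In the memref test above, each i64 offset/size/stride field is bitcast to `vector<2xi32>` so that every value handed to `rocdl.readfirstlane` fits in an i32. The bit-level effect of that split and rejoin can be sketched as follows (assuming little-endian element order, i.e. element 0 holds the low 32 bits, which is what the bitcast yields on the targets in question):

```python
# Bit-level model of the i64 -> vector<2xi32> -> i64 round trip used for the
# offset/size/stride fields (element 0 = low 32 bits; little-endian target).

MASK32 = 0xFFFFFFFF

def split_i64(x):
    """i64 -> [lo32, hi32], like bitcast i64 to vector<2xi32> + extracts."""
    return [x & MASK32, (x >> 32) & MASK32]

def join_i64(lo, hi):
    """[lo32, hi32] -> i64, like inserts + bitcast vector<2xi32> to i64."""
    return ((hi & MASK32) << 32) | (lo & MASK32)

stride = 0x0000000100000004  # an example 64-bit stride value
lo, hi = split_i64(stride)   # lo = 4, hi = 1
```

Broadcasting the two halves independently is sound because `readfirstlane` is applied elementwise and the rejoin is a pure bit reassembly.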

@kuhar kuhar requested a review from amd-eochoalo February 25, 2026 22:51
@amd-eochoalo left a comment:

Looks good, just a couple of questions!

@krzysz00 krzysz00 merged commit 852c6ef into llvm:main Feb 27, 2026
10 checks passed
krzysz00 added a commit to krzysz00/llvm-project that referenced this pull request Feb 27, 2026
As pointed out in the reviews for llvm#183405, decomposeValue and
composeValue should be able to emit zexts and truncations for
cases like i48 and vector<3xi16> becoming i32s, but currently that
path asserts. This commit fixes that limitation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
krzysz00 added a commit that referenced this pull request Feb 27, 2026
…3825)

As pointed out in the reviews for #183405, decomposeValue and
composeValue should be able to emit zexts and truncations for cases
like i48 and vector<3xi16> becoming i32s, but currently that path
asserts. This commit fixes that limitation.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
…183405)

This commit updates the LLVM::decomposeValue and LLVM::composeValue
methods to handle aggregate types - LLVM arrays and structs, and to have
different behaviors on dealing with types like pointers that can't be
bitcast to fixed-size integers. This allows the "any type" on
gpu.subgroup_broadcast to be more comprehensive - you can broadcast a
memref to a subgroup by decomposing it, for example.

(This branched off of getting an LLM to implement
ValueBoundsOpInterface on subgroup_broadcast, having it add handling
for the dimensions of shaped types, and realizing that there's no
fundamental reason you can't broadcast a memref or the like)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
…m#183825)

As pointed out in the reviews for llvm#183405, decomposeValue and
composeValue should be able to emit zexts and truncations for cases
like i48 and vector<3xi16> becoming i32s, but currently that path
asserts. This commit fixes that limitation.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>