[mlir][LLVM] Let decomposeValue/composeValue handle aggregates#183405
This commit updates the LLVM::decomposeValue and LLVM::composeValue methods to handle aggregate types (LLVM arrays and structs) and to offer configurable behavior for types, such as pointers, that can't be bitcast to fixed-size integers. This makes the "any type" on gpu.subgroup_broadcast more comprehensive: for example, you can broadcast a memref to a subgroup by decomposing it. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu Author: Krzysztof Drewniak (krzysz00) Changes: This commit updates the LLVM::decomposeValue and LLVM::composeValue methods to handle aggregate types (LLVM arrays and structs) and to offer configurable behavior for types, such as pointers, that can't be bitcast to fixed-size integers. This makes the "any type" on gpu.subgroup_broadcast more comprehensive: for example, you can broadcast a memref to a subgroup by decomposing it. (This branched off of getting an LLM to implement ValueBoundsOpInterface on subgroup_broadcast, having it add handling for the dimensions of shaped types, and realizing that there's no fundamental reason you can't broadcast a memref or the like.) Full diff: https://github.com/llvm/llvm-project/pull/183405.diff 5 Files Affected:
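As a rough model of what the updated decomposeValue does for a 1-D memref descriptor, the following Python sketch (not the MLIR API; types are modeled as tuples) counts how many i32 slots and pass-through pointers a type tree flattens into. Structs and arrays recurse, fixed-width leaves split into 32-bit pieces, and pointers are passed through under permitVariablySizedScalars:

```python
# Hypothetical model of decomposing a type tree into 32-bit slots.
# Structs/arrays recurse, fixed-width leaves split into 32-bit pieces,
# pointers pass through as-is (permitVariablySizedScalars=true).

def count_slots(ty):
    """Return (num_i32_slots, num_passthrough_ptrs) for a type tree."""
    kind = ty[0]
    if kind in ("struct", "array"):
        slots = ptrs = 0
        for field in ty[1]:
            s, p = count_slots(field)
            slots += s
            ptrs += p
        return slots, ptrs
    if kind == "ptr":
        return 0, 1
    # ("int", bitwidth): split into 32-bit pieces.
    assert ty[1] % 32 == 0
    return ty[1] // 32, 0

# memref<?xi32> descriptor: (ptr, ptr, i64, array<1 x i64>, array<1 x i64>)
i64 = ("int", 64)
desc = ("struct", [("ptr",), ("ptr",), i64,
                   ("array", [i64]), ("array", [i64])])
print(count_slots(desc))  # → (6, 2): three i64s make 6 i32 slots, plus 2 ptrs
```

This matches the broadcast_memref test below: six rocdl.readfirstlane ops on i32 and two on !llvm.ptr.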
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index cacd500d41291..562ce48e23f26 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -66,10 +66,16 @@ bool opHasUnsupportedFloatingPointTypes(Operation *op,
} // namespace detail
/// Decomposes a `src` value into a set of values of type `dstType` through
-/// series of bitcasts and vector ops. Src and dst types are expected to be int
-/// or float types or vector types of them.
-SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
- Type dstType);
+/// series of bitcasts and vector ops. Handles int, float, vector types as well
+/// as LLVM aggregate types (LLVMArrayType, LLVMStructType) by recursively
+/// extracting elements.
+///
+/// When `permitVariablySizedScalars` is true, leaf types that have no fixed
+/// bit width (e.g., `!llvm.ptr`) are passed through as-is (1 element in
+/// result). When false (default), encountering such a type returns failure.
+LogicalResult decomposeValue(OpBuilder &builder, Location loc, Value src,
+ Type dstType, SmallVectorImpl<Value> &result,
+ bool permitVariablySizedScalars = false);
/// Composes a set of `src` values into a single value of type `dstType` through
/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3c2c61b2426e9..379d6180596e9 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -2530,8 +2530,10 @@ struct AMDGPUSwizzleBitModeLowering
Location loc = op.getLoc();
Type i32 = rewriter.getI32Type();
Value src = adaptor.getSrc();
- SmallVector<Value> decomposed =
- LLVM::decomposeValue(rewriter, loc, src, i32);
+ SmallVector<Value> decomposed;
+ if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
+ return rewriter.notifyMatchFailure(op,
+ "failed to decompose value to i32");
unsigned andMask = op.getAndMask();
unsigned orMask = op.getOrMask();
unsigned xorMask = op.getXorMask();
@@ -2573,8 +2575,10 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
bool fi = op.getFetchInactive();
bool boundctrl = op.getBoundCtrl();
- SmallVector<Value> decomposed =
- LLVM::decomposeValue(rewriter, loc, src, i32);
+ SmallVector<Value> decomposed;
+ if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
+ return rewriter.notifyMatchFailure(op,
+ "failed to decompose value to i32");
SmallVector<Value> permuted;
for (Value v : decomposed) {
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 096554d53e031..c5f0aedb33143 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -271,8 +271,11 @@ struct GPUSubgroupBroadcastOpToROCDL
Type i32 = rewriter.getI32Type();
Location loc = op.getLoc();
- SmallVector<Value> decomposed =
- LLVM::decomposeValue(rewriter, loc, src, i32);
+ SmallVector<Value> decomposed;
+ if (failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed,
+ /*permitVariablySizedScalars=*/true)))
+ return rewriter.notifyMatchFailure(op,
+ "failed to decompose value to i32");
SmallVector<Value> results;
results.reserve(decomposed.size());
@@ -359,8 +362,11 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Value dwordAlignedDstLane =
LLVM::ShlOp::create(rewriter, loc, int32Type, selectDstLane, two);
- SmallVector<Value> decomposed =
- LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
+ SmallVector<Value> decomposed;
+ if (failed(LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type,
+ decomposed)))
+ return rewriter.notifyMatchFailure(op,
+ "failed to decompose value to i32");
SmallVector<Value> swizzled;
for (Value v : decomposed) {
Value res = ROCDL::DsBpermuteOp::create(rewriter, loc, int32Type,
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 640ff3d7c3c7d..38f71b916b5b8 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -384,23 +384,68 @@ static unsigned getBitWidth(Type type) {
return vec.getNumElements() * getBitWidth(vec.getElementType());
}
+/// Returns true if every leaf in `type` (recursing through LLVM arrays and
+/// structs) is either equal to `dstType` or has a fixed bit width.
+static bool isFixedSizeAggregate(Type type, Type dstType) {
+ if (type == dstType)
+ return true;
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
+ return isFixedSizeAggregate(arrayType.getElementType(), dstType);
+ if (auto structType = dyn_cast<LLVM::LLVMStructType>(type))
+ return llvm::all_of(structType.getBody(), [&](Type fieldType) {
+ return isFixedSizeAggregate(fieldType, dstType);
+ });
+ if (auto vecTy = dyn_cast<VectorType>(type))
+ return !vecTy.isScalable();
+ return type.isIntOrFloat();
+}
+
static Value createI32Constant(OpBuilder &builder, Location loc,
int32_t value) {
Type i32 = builder.getI32Type();
return LLVM::ConstantOp::create(builder, loc, i32, value);
}
-SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
- Value src, Type dstType) {
+/// Recursive implementation of decomposeValue. When
+/// `permitVariablySizedScalars` is false, callers must ensure
+/// isFixedSizeAggregate() holds before calling this.
+static void decomposeValueImpl(OpBuilder &builder, Location loc, Value src,
+ Type dstType, SmallVectorImpl<Value> &result) {
Type srcType = src.getType();
- if (srcType == dstType)
- return {src};
+ if (srcType == dstType) {
+ result.push_back(src);
+ return;
+ }
+
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(srcType)) {
+ for (auto i : llvm::seq(arrayType.getNumElements())) {
+ Value elem = LLVM::ExtractValueOp::create(builder, loc, src, i);
+ decomposeValueImpl(builder, loc, elem, dstType, result);
+ }
+ return;
+ }
+
+ if (auto structType = dyn_cast<LLVM::LLVMStructType>(srcType)) {
+ for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
+ Value field = LLVM::ExtractValueOp::create(builder, loc, src,
+ static_cast<int64_t>(i));
+ decomposeValueImpl(builder, loc, field, dstType, result);
+ }
+ return;
+ }
+
+ // Variably sized leaf types (e.g., ptr) — pass through as-is.
+ if (!srcType.isIntOrFloat() && !isa<VectorType>(srcType)) {
+ result.push_back(src);
+ return;
+ }
unsigned srcBitWidth = getBitWidth(srcType);
unsigned dstBitWidth = getBitWidth(dstType);
if (srcBitWidth == dstBitWidth) {
Value cast = LLVM::BitcastOp::create(builder, loc, dstType, src);
- return {cast};
+ result.push_back(cast);
+ return;
}
if (dstBitWidth > srcBitWidth) {
@@ -410,7 +455,8 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
auto largerInt = builder.getIntegerType(dstBitWidth);
Value res = LLVM::ZExtOp::create(builder, loc, largerInt, src);
- return {res};
+ result.push_back(res);
+ return;
}
assert(srcBitWidth % dstBitWidth == 0 &&
"src bit width must be a multiple of dst bit width");
@@ -419,47 +465,94 @@ SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
src = LLVM::BitcastOp::create(builder, loc, vecType, src);
- SmallVector<Value> res;
for (auto i : llvm::seq(numElements)) {
Value idx = createI32Constant(builder, loc, i);
Value elem = LLVM::ExtractElementOp::create(builder, loc, src, idx);
- res.emplace_back(elem);
+ result.push_back(elem);
}
-
- return res;
}
-Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
- Type dstType) {
- assert(!src.empty() && "src range must not be empty");
- if (src.size() == 1) {
- Value res = src.front();
- if (res.getType() == dstType)
- return res;
+LogicalResult mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
+ Value src, Type dstType,
+ SmallVectorImpl<Value> &result,
+ bool permitVariablySizedScalars) {
+ // Check the type tree before emitting any IR, so that a failing pattern
+ // leaves the IR unmodified.
+ if (!permitVariablySizedScalars &&
+ !isFixedSizeAggregate(src.getType(), dstType))
+ return failure();
- unsigned srcBitWidth = getBitWidth(res.getType());
- unsigned dstBitWidth = getBitWidth(dstType);
- if (dstBitWidth < srcBitWidth) {
- auto largerInt = builder.getIntegerType(srcBitWidth);
- if (res.getType() != largerInt)
- res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
+ decomposeValueImpl(builder, loc, src, dstType, result);
+ return success();
+}
+
+/// Recursive implementation of composeValue. Consumes elements from `src`
+/// starting at `offset`, advancing it past the consumed elements.
+static Value composeValueImpl(OpBuilder &builder, Location loc, ValueRange src,
+ size_t &offset, Type dstType) {
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(dstType)) {
+ Value result = LLVM::PoisonOp::create(builder, loc, arrayType);
+ Type elemType = arrayType.getElementType();
+ for (auto i : llvm::seq(arrayType.getNumElements())) {
+ Value elem = composeValueImpl(builder, loc, src, offset, elemType);
+ result = LLVM::InsertValueOp::create(builder, loc, result, elem, i);
+ }
+ return result;
+ }
- auto smallerInt = builder.getIntegerType(dstBitWidth);
- res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
+ if (auto structType = dyn_cast<LLVM::LLVMStructType>(dstType)) {
+ Value result = LLVM::PoisonOp::create(builder, loc, structType);
+ for (auto [i, fieldType] : llvm::enumerate(structType.getBody())) {
+ Value field = composeValueImpl(builder, loc, src, offset, fieldType);
+ result = LLVM::InsertValueOp::create(builder, loc, result, field,
+ static_cast<int64_t>(i));
}
+ return result;
+ }
+
+ // Variably sized leaf types (e.g., ptr) — consume and return as-is.
+ if (!dstType.isIntOrFloat() && !isa<VectorType>(dstType))
+ return src[offset++];
+
+ unsigned dstBitWidth = getBitWidth(dstType);
- if (res.getType() != dstType)
- res = LLVM::BitcastOp::create(builder, loc, dstType, res);
+ Value front = src[offset];
+ if (front.getType() == dstType) {
+ ++offset;
+ return front;
+ }
- return res;
+ // Single element wider than or equal to dst: bitcast/trunc.
+ if (front.getType().isIntOrFloat() || isa<VectorType>(front.getType())) {
+ unsigned srcBitWidth = getBitWidth(front.getType());
+ if (srcBitWidth >= dstBitWidth) {
+ ++offset;
+ Value res = front;
+ if (dstBitWidth < srcBitWidth) {
+ auto largerInt = builder.getIntegerType(srcBitWidth);
+ if (res.getType() != largerInt)
+ res = LLVM::BitcastOp::create(builder, loc, largerInt, res);
+
+ auto smallerInt = builder.getIntegerType(dstBitWidth);
+ res = LLVM::TruncOp::create(builder, loc, smallerInt, res);
+ }
+ if (res.getType() != dstType)
+ res = LLVM::BitcastOp::create(builder, loc, dstType, res);
+ return res;
+ }
}
- int64_t numElements = src.size();
- auto srcType = VectorType::get(numElements, src.front().getType());
- Value res = LLVM::PoisonOp::create(builder, loc, srcType);
- for (auto &&[i, elem] : llvm::enumerate(src)) {
+ // Multiple elements narrower than dst: gather into a vector and bitcast.
+ unsigned elemBitWidth = getBitWidth(front.getType());
+ assert(dstBitWidth % elemBitWidth == 0 &&
+ "dst bit width must be a multiple of element bit width");
+ int64_t numElements = dstBitWidth / elemBitWidth;
+ auto vecType = VectorType::get(numElements, front.getType());
+ Value res = LLVM::PoisonOp::create(builder, loc, vecType);
+ for (auto i : llvm::seq(numElements)) {
Value idx = createI32Constant(builder, loc, i);
- res = LLVM::InsertElementOp::create(builder, loc, srcType, res, elem, idx);
+ res = LLVM::InsertElementOp::create(builder, loc, vecType, res,
+ src[offset++], idx);
}
if (res.getType() != dstType)
@@ -468,6 +561,15 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
return res;
}
+Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
+ Type dstType) {
+ assert(!src.empty() && "src range must not be empty");
+ size_t offset = 0;
+ Value result = composeValueImpl(builder, loc, src, offset, dstType);
+ assert(offset == src.size() && "not all decomposed values were consumed");
+ return result;
+}
+
Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
const LLVMTypeConverter &converter,
MemRefType type, Value memRefDesc,
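The consuming-cursor scheme in composeValueImpl can be modeled standalone (an illustrative Python sketch; the real code builds LLVM IR, this only tracks values): decompose flattens leaves into a list, and compose walks the same type tree with a running offset that must end having consumed every element, mirroring the `size_t &offset` parameter and the final assertion in composeValue:

```python
# Model of the decompose/compose round trip with a consuming offset,
# analogous to composeValueImpl's `size_t &offset` cursor.

def decompose(value, ty):
    if ty[0] in ("struct", "array"):
        out = []
        for v, fty in zip(value, ty[1]):
            out.extend(decompose(v, fty))
        return out
    return [value]  # leaf: pointer or scalar, one slot in this model

def compose(slots, ty, offset=0):
    """Rebuild a value of type `ty`; returns (value, new_offset)."""
    if ty[0] in ("struct", "array"):
        fields = []
        for fty in ty[1]:
            v, offset = compose(slots, fty, offset)
            fields.append(v)
        return tuple(fields), offset
    return slots[offset], offset + 1  # consume one leaf

i64 = ("int", 64)
desc_ty = ("struct", [("ptr",), ("ptr",), i64,
                      ("array", [i64]), ("array", [i64])])
desc = ("p0", "p1", 42, (7,), (1,))
flat = decompose(desc, desc_ty)
rebuilt, used = compose(flat, desc_ty)
assert used == len(flat)  # all decomposed values were consumed
assert rebuilt == desc
```

Note that this model keeps each leaf as a single slot; the real implementation additionally bitcasts fixed-width leaves to/from the element type, as shown in the tests below.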
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 6b20a442c30c5..1f958791b41bb 100755
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -851,4 +851,73 @@ func.func @broadcast_4xi16(%arg0 : vector<4xi16>) -> vector<4xi16> {
%0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<4xi16>
func.return %0 : vector<4xi16>
}
+
+// CHECK-LABEL: func @broadcast_2x2xi16
+// CHECK-SAME: (%[[ARG:.+]]: !llvm.array<2 x vector<2xi16>>)
+func.func @broadcast_2x2xi16(%arg0 : vector<2x2xi16>) -> vector<2x2xi16> {
+ // CHECK-DAG: %[[E0:.+]] = llvm.extractvalue %[[ARG]][0] : !llvm.array<2 x vector<2xi16>>
+ // CHECK-DAG: %[[BC0:.+]] = llvm.bitcast %[[E0]] : vector<2xi16> to i32
+ // CHECK-DAG: %[[E1:.+]] = llvm.extractvalue %[[ARG]][1] : !llvm.array<2 x vector<2xi16>>
+ // CHECK-DAG: %[[BC1:.+]] = llvm.bitcast %[[E1]] : vector<2xi16> to i32
+ // CHECK: %[[R0:.+]] = rocdl.readfirstlane %[[BC0]] : i32
+ // CHECK: %[[R1:.+]] = rocdl.readfirstlane %[[BC1]] : i32
+ // CHECK: %[[POISON:.+]] = llvm.mlir.poison : !llvm.array<2 x vector<2xi16>>
+ // CHECK: %[[RBC0:.+]] = llvm.bitcast %[[R0]] : i32 to vector<2xi16>
+ // CHECK: %[[INS0:.+]] = llvm.insertvalue %[[RBC0]], %[[POISON]][0] : !llvm.array<2 x vector<2xi16>>
+ // CHECK: %[[RBC1:.+]] = llvm.bitcast %[[R1]] : i32 to vector<2xi16>
+ // CHECK: %[[INS1:.+]] = llvm.insertvalue %[[RBC1]], %[[INS0]][1] : !llvm.array<2 x vector<2xi16>>
+ // CHECK: llvm.return %[[INS1]]
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<2x2xi16>
+ func.return %0 : vector<2x2xi16>
+}
+
+// CHECK-LABEL: func @broadcast_memref
+func.func @broadcast_memref(%arg0 : memref<?xi32>) -> memref<?xi32> {
+ // CHECK-DAG: %[[PTR0:.+]] = llvm.extractvalue %{{.+}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[PTR1:.+]] = llvm.extractvalue %{{.+}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[OFF:.+]] = llvm.extractvalue %{{.+}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[OFFVEC:.+]] = llvm.bitcast %[[OFF]] : i64 to vector<2xi32>
+ // CHECK-DAG: %[[OFF0:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[OFF1:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[SZ_ARR:.+]] = llvm.extractvalue %{{.+}}[3] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[SZ:.+]] = llvm.extractvalue %[[SZ_ARR]][0] : !llvm.array<1 x i64>
+ // CHECK-DAG: %[[SZVEC:.+]] = llvm.bitcast %[[SZ]] : i64 to vector<2xi32>
+ // CHECK-DAG: %[[SZ0:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[SZ1:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[STR_ARR:.+]] = llvm.extractvalue %{{.+}}[4] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[STR:.+]] = llvm.extractvalue %[[STR_ARR]][0] : !llvm.array<1 x i64>
+ // CHECK-DAG: %[[STRVEC:.+]] = llvm.bitcast %[[STR]] : i64 to vector<2xi32>
+ // CHECK-DAG: %[[STR0:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[STR1:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+ //
+ // CHECK: %[[RPTR0:.+]] = rocdl.readfirstlane %[[PTR0]] : !llvm.ptr
+ // CHECK: %[[RPTR1:.+]] = rocdl.readfirstlane %[[PTR1]] : !llvm.ptr
+ // CHECK: %[[ROFF0:.+]] = rocdl.readfirstlane %[[OFF0]] : i32
+ // CHECK: %[[ROFF1:.+]] = rocdl.readfirstlane %[[OFF1]] : i32
+ // CHECK: %[[RSZ0:.+]] = rocdl.readfirstlane %[[SZ0]] : i32
+ // CHECK: %[[RSZ1:.+]] = rocdl.readfirstlane %[[SZ1]] : i32
+ // CHECK: %[[RSTR0:.+]] = rocdl.readfirstlane %[[STR0]] : i32
+ // CHECK: %[[RSTR1:.+]] = rocdl.readfirstlane %[[STR1]] : i32
+ //
+ // CHECK: %[[SOUT:.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[S0:.+]] = llvm.insertvalue %[[RPTR0]], %[[SOUT]][0]
+ // CHECK: %[[S1:.+]] = llvm.insertvalue %[[RPTR1]], %[[S0]][1]
+ // CHECK: %[[OVEC:.+]] = llvm.insertelement %[[ROFF0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[OVEC2:.+]] = llvm.insertelement %[[ROFF1]], %[[OVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[OCAST:.+]] = llvm.bitcast %[[OVEC2]] : vector<2xi32> to i64
+ // CHECK: %[[S2:.+]] = llvm.insertvalue %[[OCAST]], %[[S1]][2]
+ // CHECK: %[[SZVEC2:.+]] = llvm.insertelement %[[RSZ0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[SZVEC3:.+]] = llvm.insertelement %[[RSZ1]], %[[SZVEC2]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[SZCAST:.+]] = llvm.bitcast %[[SZVEC3]] : vector<2xi32> to i64
+ // CHECK: %[[SZA:.+]] = llvm.insertvalue %[[SZCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+ // CHECK: %[[S3:.+]] = llvm.insertvalue %[[SZA]], %[[S2]][3]
+ // CHECK: %[[STRVEC2:.+]] = llvm.insertelement %[[RSTR0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[STRVEC3:.+]] = llvm.insertelement %[[RSTR1]], %[[STRVEC2]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[STRCAST:.+]] = llvm.bitcast %[[STRVEC3]] : vector<2xi32> to i64
+ // CHECK: %[[STR_ARR:.+]] = llvm.insertvalue %[[STRCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+ // CHECK: %[[S4:.+]] = llvm.insertvalue %[[STR_ARR]], %[[S3]][4]
+ // CHECK: llvm.return %[[S4]]
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : memref<?xi32>
+ func.return %0 : memref<?xi32>
+}
}
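The i64 → vector<2xi32> bitcasts checked above split each 64-bit descriptor field into two 32-bit lanes. The arithmetic can be sketched as follows (low lane first, assuming the little-endian lane order a bitcast produces on little-endian targets):

```python
def split_i64(x):
    """Split a 64-bit value into two 32-bit lanes (low lane first)."""
    lo = x & 0xFFFFFFFF
    hi = (x >> 32) & 0xFFFFFFFF
    return [lo, hi]

def join_i64(lanes):
    """Inverse: recombine two 32-bit lanes into one 64-bit value."""
    lo, hi = lanes
    return (hi << 32) | lo

x = 0x0000000100000002
assert split_i64(x) == [0x2, 0x1]
assert join_i64(split_i64(x)) == x  # round trip is lossless
```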
@llvm/pr-subscribers-mlir
+ // CHECK: %[[RBC1:.+]] = llvm.bitcast %[[R1]] : i32 to vector<2xi16>
+ // CHECK: %[[INS1:.+]] = llvm.insertvalue %[[RBC1]], %[[INS0]][1] : !llvm.array<2 x vector<2xi16>>
+ // CHECK: llvm.return %[[INS1]]
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : vector<2x2xi16>
+ func.return %0 : vector<2x2xi16>
+}
+
+// CHECK-LABEL: func @broadcast_memref
+func.func @broadcast_memref(%arg0 : memref<?xi32>) -> memref<?xi32> {
+ // CHECK-DAG: %[[PTR0:.+]] = llvm.extractvalue %{{.+}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[PTR1:.+]] = llvm.extractvalue %{{.+}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[OFF:.+]] = llvm.extractvalue %{{.+}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[OFFVEC:.+]] = llvm.bitcast %[[OFF]] : i64 to vector<2xi32>
+ // CHECK-DAG: %[[OFF0:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[OFF1:.+]] = llvm.extractelement %[[OFFVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[SZ_ARR:.+]] = llvm.extractvalue %{{.+}}[3] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[SZ:.+]] = llvm.extractvalue %[[SZ_ARR]][0] : !llvm.array<1 x i64>
+ // CHECK-DAG: %[[SZVEC:.+]] = llvm.bitcast %[[SZ]] : i64 to vector<2xi32>
+ // CHECK-DAG: %[[SZ0:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[SZ1:.+]] = llvm.extractelement %[[SZVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[STR_ARR:.+]] = llvm.extractvalue %{{.+}}[4] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK-DAG: %[[STR:.+]] = llvm.extractvalue %[[STR_ARR]][0] : !llvm.array<1 x i64>
+ // CHECK-DAG: %[[STRVEC:.+]] = llvm.bitcast %[[STR]] : i64 to vector<2xi32>
+ // CHECK-DAG: %[[STR0:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK-DAG: %[[STR1:.+]] = llvm.extractelement %[[STRVEC]][%{{.+}} : i32] : vector<2xi32>
+ //
+ // CHECK: %[[RPTR0:.+]] = rocdl.readfirstlane %[[PTR0]] : !llvm.ptr
+ // CHECK: %[[RPTR1:.+]] = rocdl.readfirstlane %[[PTR1]] : !llvm.ptr
+ // CHECK: %[[ROFF0:.+]] = rocdl.readfirstlane %[[OFF0]] : i32
+ // CHECK: %[[ROFF1:.+]] = rocdl.readfirstlane %[[OFF1]] : i32
+ // CHECK: %[[RSZ0:.+]] = rocdl.readfirstlane %[[SZ0]] : i32
+ // CHECK: %[[RSZ1:.+]] = rocdl.readfirstlane %[[SZ1]] : i32
+ // CHECK: %[[RSTR0:.+]] = rocdl.readfirstlane %[[STR0]] : i32
+ // CHECK: %[[RSTR1:.+]] = rocdl.readfirstlane %[[STR1]] : i32
+ //
+ // CHECK: %[[SOUT:.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[S0:.+]] = llvm.insertvalue %[[RPTR0]], %[[SOUT]][0]
+ // CHECK: %[[S1:.+]] = llvm.insertvalue %[[RPTR1]], %[[S0]][1]
+ // CHECK: %[[OVEC:.+]] = llvm.insertelement %[[ROFF0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[OVEC2:.+]] = llvm.insertelement %[[ROFF1]], %[[OVEC]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[OCAST:.+]] = llvm.bitcast %[[OVEC2]] : vector<2xi32> to i64
+ // CHECK: %[[S2:.+]] = llvm.insertvalue %[[OCAST]], %[[S1]][2]
+ // CHECK: %[[SZVEC2:.+]] = llvm.insertelement %[[RSZ0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[SZVEC3:.+]] = llvm.insertelement %[[RSZ1]], %[[SZVEC2]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[SZCAST:.+]] = llvm.bitcast %[[SZVEC3]] : vector<2xi32> to i64
+ // CHECK: %[[SZA:.+]] = llvm.insertvalue %[[SZCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+ // CHECK: %[[S3:.+]] = llvm.insertvalue %[[SZA]], %[[S2]][3]
+ // CHECK: %[[STRVEC2:.+]] = llvm.insertelement %[[RSTR0]], %{{.+}}[%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[STRVEC3:.+]] = llvm.insertelement %[[RSTR1]], %[[STRVEC2]][%{{.+}} : i32] : vector<2xi32>
+ // CHECK: %[[STRCAST:.+]] = llvm.bitcast %[[STRVEC3]] : vector<2xi32> to i64
+ // CHECK: %[[STR_ARR:.+]] = llvm.insertvalue %[[STRCAST]], %{{.+}}[0] : !llvm.array<1 x i64>
+ // CHECK: %[[S4:.+]] = llvm.insertvalue %[[STR_ARR]], %[[S3]][4]
+ // CHECK: llvm.return %[[S4]]
+ %0 = gpu.subgroup_broadcast %arg0, first_active_lane : memref<?xi32>
+ func.return %0 : memref<?xi32>
+}
}
amd-eochoalo
left a comment
Looks good, just a couple of questions!
Currently, as pointed out in the reviews for llvm#183405, decomposeValue and composeValue should be able to emit zexts and truncations for cases like i48 or vector<3xi16> becoming i32s, but those cases currently hit an assert. This commit fixes that limitation. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
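As a rough illustration of the zext/trunc path that commit describes (a hypothetical sketch, not the exact IR the patch emits; all value names are illustrative), splitting an i48 into i32 pieces could extend to the next multiple of 32 bits, bitcast to a vector, and extract lanes:

```mlir
// Hypothetical sketch: decomposing an i48 value %v into two i32 pieces.
// The zext pads to 64 bits so the bitcast target is a whole number of i32s.
%ext = llvm.zext %v : i48 to i64
%vec = llvm.bitcast %ext : i64 to vector<2xi32>
%c0  = llvm.mlir.constant(0 : i32) : i32
%lo  = llvm.extractelement %vec[%c0 : i32] : vector<2xi32>
%c1  = llvm.mlir.constant(1 : i32) : i32
%hi  = llvm.extractelement %vec[%c1 : i32] : vector<2xi32>
// Recomposition would reverse this: insertelement both pieces into a
// vector<2xi32>, bitcast to i64, then llvm.trunc back to i48.
```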
This commit updates the LLVM::decomposeValue and LLVM::composeValue methods to handle aggregate types - LLVM arrays and structs, and to have different behaviors on dealing with types like pointers that can't be bitcast to fixed-size integers. This allows the "any type" on gpu.subgroup_broadcast to be more comprehensive - you can broadcast a memref to a subgroup by decomposing it, for example.
(This branched off of getting an LLM to implement ValueBoundsOpInterface on subgroup_broadcast, having it add handling for the dimensions of shaped types, and realizing that there's no fundamental reason you can't broadcast a memref or the like)
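With this change, a broadcast of a memref (sketched below, mirroring the new `@broadcast_memref` test with the element type swapped to f32) lowers by decomposing the memref descriptor field by field, broadcasting each piece, and recomposing:

```mlir
func.func @bcast(%m : memref<?xf32>) -> memref<?xf32> {
  %0 = gpu.subgroup_broadcast %m, first_active_lane : memref<?xf32>
  func.return %0 : memref<?xf32>
}
```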