Skip to content

Commit ff96267

Browse files
committed
[mlir][vector] Add folder for bitcast of integer splat constants
This is similar to the existing folder for f16 to f32 added with D96041, but for integer types where the destination bit width is greater than the source bit width. Reviewed By: ThomasRaoux. Differential Revision: https://reviews.llvm.org/D142922
1 parent d6eaaa1 commit ff96267

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4950,6 +4950,27 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
49504950
}
49514951
}
49524952

4953+
if (auto intPack = sourceConstant.dyn_cast<DenseIntElementsAttr>()) {
4954+
if (intPack.isSplat()) {
4955+
auto splat = intPack.getSplatValue<IntegerAttr>();
4956+
4957+
if (dstElemType.isa<IntegerType>()) {
4958+
uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
4959+
uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
4960+
4961+
// Casting to a larger integer bit width.
4962+
if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
4963+
APInt intBits = splat.getValue().zext(dstBitWidth);
4964+
4965+
// Duplicate the lower width element.
4966+
for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
4967+
intBits = (intBits << srcBitWidth) | intBits;
4968+
return DenseElementsAttr::get(getResultVectorType(), intBits);
4969+
}
4970+
}
4971+
}
4972+
}
4973+
49534974
return {};
49544975
}
49554976

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -741,6 +741,20 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
741741
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
742742
}
743743

744+
// CHECK-LABEL: func @bitcast_i8_to_i32
745+
// bit pattern: 0xA0A0A0A0
746+
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
747+
// bit pattern: 0x00000000
748+
// CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi32>
749+
// CHECK: return %[[CST0]], %[[CST1]]
750+
func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
751+
%cst0 = arith.constant dense<0> : vector<16xi8> // bit pattern: 0x00
752+
%cst1 = arith.constant dense<160> : vector<16xi8> // bit pattern: 0xA0
753+
%cast0 = vector.bitcast %cst0: vector<16xi8> to vector<4xi32>
754+
%cast1 = vector.bitcast %cst1: vector<16xi8> to vector<4xi32>
755+
return %cast0, %cast1: vector<4xi32>, vector<4xi32>
756+
}
757+
744758
// -----
745759

746760
// CHECK-LABEL: broadcast_folding1

0 commit comments

Comments
 (0)