[mlir][x86] Lower Int8 vector.contract to AVX2/AVX10 dp (online packing)#189386
Merged
Conversation
Member
|
@llvm/pr-subscribers-mlir Author: Arun Thangamani (arun-thmn) ChangesA transform pass to lower flat layout
Patch is 23.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/189386.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
index a4496f3620b97..d5b2ac0bc9e31 100644
--- a/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/VectorContractToPackedTypeDotProduct.cpp
@@ -82,22 +82,78 @@ static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
auto elemTy = Ty.getElementType();
auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
- auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
- opA->getResult(0));
- auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
- opB->getResult(0));
+ Value srcBuff;
+ SmallVector<Value> indexVals;
- static constexpr int64_t maskLo[] = {
+ llvm::TypeSwitch<Operation *>(opA).Case<TransferReadOp, LoadOp>(
+ [&](auto readOp) {
+ srcBuff = readOp.getOperand(0);
+
+ auto indices = readOp.getIndices();
+ indexVals.reserve(indices.size());
+
+ llvm::transform(
+ indices, std::back_inserter(indexVals), [&](OpFoldResult ofr) {
+ return mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+ });
+ });
+
+ auto vec1 = vector::LoadOp::create(rewriter, loc, flatTy, srcBuff, indexVals);
+
+ unsigned int offset = 1;
+ if (elemTy.isSignlessInteger(8))
+ offset = 2;
+
+ Value cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
+ auto nextIndx =
+ arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), cOffset,
+ indexVals[indexVals.size() - 2]);
+ indexVals[indexVals.size() - 2] = nextIndx;
+
+ auto vec2 = vector::LoadOp::create(rewriter, loc, flatTy, srcBuff, indexVals);
+
+ static constexpr int64_t maskLo_bf16[] = {
0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
- static constexpr int64_t maskHi[] = {
+ static constexpr int64_t maskHi_bf16[] = {
4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
- auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
- castB, maskLo);
- auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
- castB, maskHi);
+ static constexpr int64_t maskLo_int8_avx2[] = {
+ 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
+ 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59};
+ static constexpr int64_t maskHi_int8_avx2[] = {
+ 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
+ 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
+
+ static constexpr int64_t maskLo_int8_avx10[] = {
+ 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99,
+ 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
+ 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
+ 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123};
+ static constexpr int64_t maskHi_int8_avx10[] = {
+ 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103,
+ 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111,
+ 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
+ 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
+
+ mlir::DenseI64ArrayAttr maskLo = rewriter.getDenseI64ArrayAttr(maskLo_bf16);
+ mlir::DenseI64ArrayAttr maskHi = rewriter.getDenseI64ArrayAttr(maskHi_bf16);
+
+ if (elemTy.isSignlessInteger(8)) {
+ maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx10);
+ maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx10);
+
+ if (nonUnitDimAcc == 32) {
+ maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx2);
+ maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx2);
+ }
+ }
+
+ auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
+ vec2, maskLo);
+ auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
+ vec2, maskHi);
auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
@@ -159,8 +215,8 @@ struct VectorContractToPackedTypeDotProduct
isInVnniLayout(contractOp.getOperation(),
contractOp.getIndexingMapsArray(), blockingFactor);
- if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
- return failure();
+ // if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
+ // return failure();
VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
if (!accTy)
@@ -217,7 +273,7 @@ struct VectorContractToPackedTypeDotProduct
if (!isVnni && (extraFlatDim != blockingFactor))
return rewriter.notifyMatchFailure(
- contractOp, "The K or reduction dim for flat layout should be 2.");
+ contractOp, "The K or reduction dim for flat layout should be 2/4.");
if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
(lhsTy.getElementType().isSignlessInteger(8) &&
diff --git a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
index f4279a3eb507a..aea6bf6adcd4a 100644
--- a/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
+++ b/mlir/lib/Dialect/X86/Utils/X86Utils.cpp
@@ -116,15 +116,20 @@ struct ShuffleMasks {
llvm::ArrayRef<int64_t> maskHi;
};
-inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
+inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc, bool isInt8Avx2) {
// We only support these two layouts for now.
assert((nonUnitDimAcc == 8 || nonUnitDimAcc == 16) &&
"Unsupported nonUnitDimAcc value");
+
// Do interleaving between two <8xf32> targeting AVX2.
static constexpr int64_t maskLo8[] = {0, 8, 1, 9, 2, 10, 3, 11};
static constexpr int64_t maskHi8[] = {4, 12, 5, 13, 6, 14, 7, 15};
- // Shuffle two <16xf32> as below targeting AVX512.
+ // Do interleaving between two <8xi32> targeting AVX2.
+ static constexpr int64_t maskLo8_avx2_int8[] = {0, 1, 2, 3, 8, 9, 10, 11};
+ static constexpr int64_t maskHi8_avx2_int8[] = {4, 5, 6, 7, 12, 13, 14, 15};
+
+ // Shuffle two <16xf32/i32> as below targeting AVX512.
static constexpr int64_t maskLo16[] = {0, 1, 2, 3, 16, 17, 18, 19,
4, 5, 6, 7, 20, 21, 22, 23};
static constexpr int64_t maskHi16[] = {8, 9, 10, 11, 24, 25, 26, 27,
@@ -133,6 +138,9 @@ inline ShuffleMasks getShuffleMasks(int64_t nonUnitDimAcc) {
if (nonUnitDimAcc == 16)
return {maskLo16, maskHi16};
+ if (isInt8Avx2)
+ return {maskLo8_avx2_int8, maskHi8_avx2_int8};
+
return {maskLo8, maskHi8};
}
@@ -255,7 +263,8 @@ LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA,
auto castB =
vector::ShapeCastOp::create(rewriter, loc, flatTy, opB->getResult(0));
- auto masks = getShuffleMasks(nonUnitDimAcc);
+ auto masks = getShuffleMasks(
+ nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
auto shuffleLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
castB, masks.maskLo);
@@ -313,7 +322,8 @@ LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter,
auto castB = vector::ShapeCastOp::create(rewriter, loc, flatTy, vecB);
// TODO: derive shuffle masks instead of hard-coding
- auto masks = getShuffleMasks(nonUnitDimAcc);
+ auto masks = getShuffleMasks(
+ nonUnitDimAcc, (elemTy.isSignlessInteger(32) && nonUnitDimAcc == 8));
auto shuffledLo = vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
castB, masks.maskLo);
diff --git a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
index 0953ee042a24d..f861d357739a3 100644
--- a/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
+++ b/mlir/test/Dialect/X86/vector-contract-to-packed-type-dotproduct.mlir
@@ -412,6 +412,144 @@ module attributes {transform.with_named_sequence} {
// -----
+!vecA = vector<1x4xi8>
+!vecB = vector<4x16xi8>
+!vecC = vector<1x16xi32>
+!memrefA = memref<4x4xi8>
+!memrefB = memref<4x64xi8>
+!memrefC = memref<4x64xi32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_i8_avx10dp_flat_layout(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %arg2[%c0, %c16], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %5 = vector.transfer_read %arg1[%c0, %c16], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %4, %2
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %5, %3
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_i8_avx10dp_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+// CHECK: vector.shuffle{{.*}}[0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123] : vector<64xi8>, vector<64xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127] : vector<64xi8>, vector<64xi8>
+// CHECK: x86.avx10.dot.i8
+// CHECK: x86.avx10.dot.i8
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xi32>, vector<16xi32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x4xi8>
+!vecB = vector<4x8xi8>
+!vecC = vector<1x8xi32>
+!memrefA = memref<4x4xi8>
+!memrefB = memref<4x64xi8>
+!memrefC = memref<4x64xi32>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @matmul_i8_avx2dp_flat_layout(
+ %arg0: !memrefA, %arg1: !memrefB, %arg2: !memrefC) -> !memrefC
+{
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %0 = ub.poison : i8
+ %32 = ub.poison : i32
+ %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefA, !vecA
+ %2 = vector.transfer_read %arg2[%c0, %c0], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %3 = vector.transfer_read %arg2[%c0, %c8], %32 {in_bounds = [true, true]} :
+ !memrefC, !vecC
+ %4 = vector.transfer_read %arg1[%c0, %c0], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+ %5 = vector.transfer_read %arg1[%c0, %c8], %0 {in_bounds = [true, true]} :
+ !memrefB, !vecB
+
+ %6 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %4, %2
+ : !vecA, !vecB into !vecC
+
+ %7 = vector.contract {
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %1, %5, %3
+ : !vecA, !vecB into !vecC
+
+ vector.transfer_write %6, %arg2[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %7, %arg2[%c0, %c8] {in_bounds = [true, true]} : !vecC, !memrefC
+
+ return %arg2 : !memrefC
+}
+
+// CHECK-LABEL: @matmul_i8_avx2dp_flat_layout
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 8, 9, 10, 11] : vector<8xi32>, vector<8xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 5, 6, 7, 12, 13, 14, 15] : vector<8xi32>, vector<8xi32>
+// CHECK: vector.shuffle{{.*}}[0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59] : vector<32xi8>, vector<32xi8>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63] : vector<32xi8>, vector<32xi8>
+// CHECK: x86.avx.dot.i8
+// CHECK: x86.avx.dot.i8
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 8, 9, 10, 11] : vector<8xi32>, vector<8xi32>
+// CHECK-NEXT: vector.shuffle{{.*}}[4, 5, 6, 7, 12, 13, 14, 15] : vector<8xi32>, vector<8xi32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
!vecA = vector<1x2xbf16>
!vecB = vector<2x16xbf16>
!vecC = vector<1x16xf32>
@@ -640,6 +778,102 @@ func.func @matmul_bf16dp_flat_layout_B_shuffled(
// CHECK: x86.avx512.dot
// CHECK-NOT: vector.contract
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.x86.vector_contract_to_packed_type_dot_product
+ } : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+!vecA = vector<1x1x4xi8>
+!vecB = vector<1x4x16xi8>
+!vecC = vector<1x16xi32>
+!memrefA = memref<1x2x4xi8, strided<[16384, 256, 1], offset: ?>>
+!memrefB = memref<1x4x32xi8, strided<[32768, 128, 1], offset: ?>>
+!memrefC = memref<2x32xi32, strided<[128, 1], offset: ?>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+func.func @brgemm_int8_flat_avx10(%arg0: memref<16x64x256xi8>, %arg1: memref<16x256x128xi8>, %arg2: memref<64x128xi32>) -> memref<64x128xi32> {
+ %0 = ub.poison : i32
+ %1 = ub.poison : i8
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ %c2 = arith.constant 2 : index
+ %c32 = arith.constant 32 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg3 = %c0 to %c64 step %c2 {
+ scf.for %arg4 = %c0 to %c128 step %c32 {
+ %subview = memref.subview %arg2[%arg3, %arg4] [2, 32] [1, 1] : memref<64x128xi32> to !memrefC
+ %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %3 = vector.transfer_read %subview[%c0, %c16], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %4 = vector.transfer_read %subview[%c1, %c0], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %5 = vector.transfer_read %subview[%c1, %c16], %0 {in_bounds = [true, true]}
+ : !memrefC, !vecC
+ %6:4 = scf.for %arg5 = %c0 to %c16 step %c1 iter_args(%arg6 = %2, %arg7 = %3, %arg8 = %4, %arg9 = %5) -> (!vecC, !vecC, !vecC, !vecC) {
+ %7:4 = scf.for %arg10 = %c0 to %c256 step %c4 iter_args(%arg11 = %arg6, %arg12 = %arg7, %arg13 = %arg8, %arg14 = %arg9) -> (!vecC, !vecC, !vecC, !vecC) {
+ %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg10] [1, 2, 4] [1, 1, 1] : memref<16x64x256xi8> to !memrefA
+ %subview_1 = memref.subview %arg1[%arg5, %arg10, %arg4] [1, 4, 32] [1, 1, 1] : memref<16x256x128xi8> to !memrefB
+ %8 = vector.transfer_read %subview_0[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]}
+ : !memrefA, !vecA
+ %9 = vector.transfer_read %subview_0[%c0, %c1, %c0], %1 {in_bounds = [true, true, true]}
+ : !memrefA, !vecA
+ %10 = vector.transfer_read %subview_1[%c0, %c0, %c0], %1 {in_bounds = [true, true, true]}
+ : !memrefB, !vecB
+ %11 = vector.transfer_read %subview_1[%c0, %c0, %c16], %1 {in_bounds = [true, true, true]}
+ : !memrefB, !vecB
+ %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %10, %arg11 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+ %13 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %8, %11, %arg12 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+ %14 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %10, %arg13 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+ %15 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types =
+ ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+ %9, %11, %arg14 {unroll_shape = array<i64: 1, 1, 16, 4>} : !vecA, !vecB into !vecC
+ scf.yield %12, %13, %14, %15 : !vecC, !vecC, !vecC, !vecC
+ }
+ scf.yield %7#0, %7#1, %7#2, %7#3 : !vecC, !vecC, !vecC, !vecC
+ }
+ vector.transfer_write %6#3, %subview[%c1, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#2, %subview[%c1, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#1, %subview[%c0, %c16] {in_bounds = [true, true]} : !vecC, !memrefC
+ vector.transfer_write %6#0, %subview[%c0, %c0] {in_bounds = [true, true]} : !vecC, !memrefC
+ }
+ }
+ %alloc = memref.alloc() : memref<64x128xi32>
+ memref.copy %arg2, %alloc : memref<64x128xi32> to memref<64x128xi32>
+ return %alloc : memref<64x128xi32>
+}
+
+// CHECK-LABEL: @brgemm_int8_flat_avx10
+// CHECK: vector.shuffle{{.*}}[0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23] : vector<16xi32>, vector<16xi32>
+// CHECK-NEXT: vector.shuf...
[truncated]
|
Contributor
Author
|
Any Comments? Let me know please. |
adam-smnk
reviewed
Apr 8, 2026
adam-smnk
approved these changes
Apr 9, 2026
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/52/builds/16431 Here is the relevant piece of the build log for the reference |
arun-thmn
added a commit
to libxsmm/tpp-mlir
that referenced
this pull request
Apr 15, 2026
Bump TPP-MLIR on the recent LLVM commit to include the below changes: 1. llvm/llvm-project#189386 2. llvm/llvm-project#188192
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
A transform pass to lower flat layout
int8packed typevector.contractoperation to:x86.avx10.int8.dp- for vector length of16, andx86.avx.int8.dp- for vector length of8via online packing.