34 changes: 33 additions & 1 deletion llvm/test/MC/X86/apx/bzhi-att.s
@@ -1,8 +1,40 @@
# RUN: llvm-mc -triple x86_64 --show-encoding %s | FileCheck %s
# RUN: not llvm-mc -triple i386 -show-encoding %s 2>&1 | FileCheck %s --check-prefix=ERROR

# ERROR-COUNT-12: error:
# ERROR-NOT: error:
# CHECK: {nf} bzhil %ecx, %edx, %r10d
# CHECK: encoding: [0x62,0x72,0x74,0x0c,0xf5,0xd2]
{nf} bzhil %ecx, %edx, %r10d

# CHECK: {evex} bzhil %ecx, %edx, %r10d
# CHECK: encoding: [0x62,0x72,0x74,0x08,0xf5,0xd2]
{evex} bzhil %ecx, %edx, %r10d

# CHECK: {nf} bzhil %ecx, 123(%rax,%rbx,4), %edx
# CHECK: encoding: [0x62,0xf2,0x74,0x0c,0xf5,0x54,0x98,0x7b]
{nf} bzhil %ecx, 123(%rax,%rbx,4), %edx

# CHECK: {evex} bzhil %ecx, 123(%rax,%rbx,4), %edx
# CHECK: encoding: [0x62,0xf2,0x74,0x08,0xf5,0x54,0x98,0x7b]
{evex} bzhil %ecx, 123(%rax,%rbx,4), %edx

# CHECK: {nf} bzhiq %r9, %r15, %r11
# CHECK: encoding: [0x62,0x52,0xb4,0x0c,0xf5,0xdf]
{nf} bzhiq %r9, %r15, %r11

# CHECK: {evex} bzhiq %r9, %r15, %r11
# CHECK: encoding: [0x62,0x52,0xb4,0x08,0xf5,0xdf]
{evex} bzhiq %r9, %r15, %r11

# CHECK: {nf} bzhiq %r9, 123(%rax,%rbx,4), %r15
# CHECK: encoding: [0x62,0x72,0xb4,0x0c,0xf5,0x7c,0x98,0x7b]
{nf} bzhiq %r9, 123(%rax,%rbx,4), %r15

# CHECK: {evex} bzhiq %r9, 123(%rax,%rbx,4), %r15
# CHECK: encoding: [0x62,0x72,0xb4,0x08,0xf5,0x7c,0x98,0x7b]
{evex} bzhiq %r9, 123(%rax,%rbx,4), %r15

# CHECK: bzhil %r18d, %r22d, %r26d
# CHECK: encoding: [0x62,0x6a,0x6c,0x00,0xf5,0xd6]
bzhil %r18d, %r22d, %r26d
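
For orientation: `{evex}` forces the EVEX-promoted encoding, and `{nf}` additionally requests the APX no-flags form. Comparing the encoding pairs above, the two variants differ only in one bit of the fourth EVEX prefix byte. A minimal sketch of that observation — a reading aid based solely on the bytes shown here, with an illustrative helper name, not LLVM's decoder:

```cpp
#include <cstdint>
#include <cstdio>

// Grounded only in the encodings above: the {nf} and {evex} forms of each
// instruction differ solely in bit 2 (0x04) of the fourth EVEX prefix byte,
// which carries the APX no-flags (NF) hint.
static bool hasNFHint(const uint8_t (&insn)[6]) {
  return insn[0] == 0x62 && (insn[3] & 0x04) != 0;
}

int main() {
  const uint8_t nf[6]   = {0x62, 0x72, 0x74, 0x0c, 0xf5, 0xd2}; // {nf} bzhil
  const uint8_t evex[6] = {0x62, 0x72, 0x74, 0x08, 0xf5, 0xd2}; // {evex} bzhil
  std::printf("nf=%d evex=%d\n", hasNFHint(nf), hasNFHint(evex)); // nf=1 evex=0
  return 0;
}
```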
32 changes: 32 additions & 0 deletions llvm/test/MC/X86/apx/bzhi-intel.s
@@ -1,5 +1,37 @@
# RUN: llvm-mc -triple x86_64 -x86-asm-syntax=intel -output-asm-variant=1 --show-encoding %s | FileCheck %s

# CHECK: {nf} bzhi r10d, edx, ecx
# CHECK: encoding: [0x62,0x72,0x74,0x0c,0xf5,0xd2]
{nf} bzhi r10d, edx, ecx

# CHECK: {evex} bzhi r10d, edx, ecx
# CHECK: encoding: [0x62,0x72,0x74,0x08,0xf5,0xd2]
{evex} bzhi r10d, edx, ecx

# CHECK: {nf} bzhi edx, dword ptr [rax + 4*rbx + 123], ecx
# CHECK: encoding: [0x62,0xf2,0x74,0x0c,0xf5,0x54,0x98,0x7b]
{nf} bzhi edx, dword ptr [rax + 4*rbx + 123], ecx

# CHECK: {evex} bzhi edx, dword ptr [rax + 4*rbx + 123], ecx
# CHECK: encoding: [0x62,0xf2,0x74,0x08,0xf5,0x54,0x98,0x7b]
{evex} bzhi edx, dword ptr [rax + 4*rbx + 123], ecx

# CHECK: {nf} bzhi r11, r15, r9
# CHECK: encoding: [0x62,0x52,0xb4,0x0c,0xf5,0xdf]
{nf} bzhi r11, r15, r9

# CHECK: {evex} bzhi r11, r15, r9
# CHECK: encoding: [0x62,0x52,0xb4,0x08,0xf5,0xdf]
{evex} bzhi r11, r15, r9

# CHECK: {nf} bzhi r15, qword ptr [rax + 4*rbx + 123], r9
# CHECK: encoding: [0x62,0x72,0xb4,0x0c,0xf5,0x7c,0x98,0x7b]
{nf} bzhi r15, qword ptr [rax + 4*rbx + 123], r9

# CHECK: {evex} bzhi r15, qword ptr [rax + 4*rbx + 123], r9
# CHECK: encoding: [0x62,0x72,0xb4,0x08,0xf5,0x7c,0x98,0x7b]
{evex} bzhi r15, qword ptr [rax + 4*rbx + 123], r9

# CHECK: bzhi r26d, r22d, r18d
# CHECK: encoding: [0x62,0x6a,0x6c,0x00,0xf5,0xd6]
bzhi r26d, r22d, r18d
12 changes: 12 additions & 0 deletions llvm/test/TableGen/x86-fold-tables.inc
@@ -622,8 +622,10 @@ static const X86FoldTableEntry Table1[] = {
{X86::AND8rr_NF_ND, X86::AND8mr_NF_ND, 0},
{X86::BEXTR32rr, X86::BEXTR32rm, 0},
{X86::BEXTR32rr_EVEX, X86::BEXTR32rm_EVEX, 0},
{X86::BEXTR32rr_NF, X86::BEXTR32rm_NF, 0},
{X86::BEXTR64rr, X86::BEXTR64rm, 0},
{X86::BEXTR64rr_EVEX, X86::BEXTR64rm_EVEX, 0},
{X86::BEXTR64rr_NF, X86::BEXTR64rm_NF, 0},
{X86::BEXTRI32ri, X86::BEXTRI32mi, 0},
{X86::BEXTRI64ri, X86::BEXTRI64mi, 0},
{X86::BLCFILL32rr, X86::BLCFILL32rm, 0},
@@ -640,18 +642,24 @@ static const X86FoldTableEntry Table1[] = {
{X86::BLSFILL64rr, X86::BLSFILL64rm, 0},
{X86::BLSI32rr, X86::BLSI32rm, 0},
{X86::BLSI32rr_EVEX, X86::BLSI32rm_EVEX, 0},
{X86::BLSI32rr_NF, X86::BLSI32rm_NF, 0},
{X86::BLSI64rr, X86::BLSI64rm, 0},
{X86::BLSI64rr_EVEX, X86::BLSI64rm_EVEX, 0},
{X86::BLSI64rr_NF, X86::BLSI64rm_NF, 0},
{X86::BLSIC32rr, X86::BLSIC32rm, 0},
{X86::BLSIC64rr, X86::BLSIC64rm, 0},
{X86::BLSMSK32rr, X86::BLSMSK32rm, 0},
{X86::BLSMSK32rr_EVEX, X86::BLSMSK32rm_EVEX, 0},
{X86::BLSMSK32rr_NF, X86::BLSMSK32rm_NF, 0},
{X86::BLSMSK64rr, X86::BLSMSK64rm, 0},
{X86::BLSMSK64rr_EVEX, X86::BLSMSK64rm_EVEX, 0},
{X86::BLSMSK64rr_NF, X86::BLSMSK64rm_NF, 0},
{X86::BLSR32rr, X86::BLSR32rm, 0},
{X86::BLSR32rr_EVEX, X86::BLSR32rm_EVEX, 0},
{X86::BLSR32rr_NF, X86::BLSR32rm_NF, 0},
{X86::BLSR64rr, X86::BLSR64rm, 0},
{X86::BLSR64rr_EVEX, X86::BLSR64rm_EVEX, 0},
{X86::BLSR64rr_NF, X86::BLSR64rm_NF, 0},
{X86::BSF16rr, X86::BSF16rm, 0},
{X86::BSF32rr, X86::BSF32rm, 0},
{X86::BSF64rr, X86::BSF64rm, 0},
@@ -660,8 +668,10 @@ static const X86FoldTableEntry Table1[] = {
{X86::BSR64rr, X86::BSR64rm, 0},
{X86::BZHI32rr, X86::BZHI32rm, 0},
{X86::BZHI32rr_EVEX, X86::BZHI32rm_EVEX, 0},
{X86::BZHI32rr_NF, X86::BZHI32rm_NF, 0},
{X86::BZHI64rr, X86::BZHI64rm, 0},
{X86::BZHI64rr_EVEX, X86::BZHI64rm_EVEX, 0},
{X86::BZHI64rr_NF, X86::BZHI64rm_NF, 0},
{X86::CMP16rr, X86::CMP16rm, 0},
{X86::CMP32rr, X86::CMP32rm, 0},
{X86::CMP64rr, X86::CMP64rm, 0},
@@ -1876,8 +1886,10 @@ static const X86FoldTableEntry Table2[] = {
{X86::AND8rr_NF_ND, X86::AND8rm_NF_ND, 0},
{X86::ANDN32rr, X86::ANDN32rm, 0},
{X86::ANDN32rr_EVEX, X86::ANDN32rm_EVEX, 0},
{X86::ANDN32rr_NF, X86::ANDN32rm_NF, 0},
{X86::ANDN64rr, X86::ANDN64rm, 0},
{X86::ANDN64rr_EVEX, X86::ANDN64rm_EVEX, 0},
{X86::ANDN64rr_NF, X86::ANDN64rm_NF, 0},
{X86::ANDNPDrr, X86::ANDNPDrm, TB_ALIGN_16},
{X86::ANDNPSrr, X86::ANDNPSrm, TB_ALIGN_16},
{X86::ANDPDrr, X86::ANDPDrm, TB_ALIGN_16},
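
These generated tables drive X86's memory-operand folding: each row pairs a register-form opcode with its memory-form equivalent, and the new `_NF` rows extend folding to the APX no-flags variants. As a reading aid, a hedged sketch of how such a sorted table can be queried — the entry layout and `lookupFold` name are illustrative stand-ins, not LLVM's actual API:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Illustrative stand-ins; the real opcode enums and entry type come from
// LLVM's generated X86 headers.
struct FoldTableEntry {
  uint16_t RegOp; // register-form opcode
  uint16_t MemOp; // memory-form opcode
  uint16_t Flags; // constraints, e.g. TB_ALIGN_16 for aligned loads
};

// The tables are sorted by RegOp, so checking whether a register-form
// instruction has a foldable memory form is a binary search.
const FoldTableEntry *lookupFold(const FoldTableEntry *table, std::size_t n,
                                 uint16_t regOp) {
  const FoldTableEntry *end = table + n;
  const FoldTableEntry *it = std::lower_bound(
      table, end, regOp,
      [](const FoldTableEntry &e, uint16_t op) { return e.RegOp < op; });
  return (it != end && it->RegOp == regOp) ? it : nullptr;
}
```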
56 changes: 48 additions & 8 deletions mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -22,6 +22,12 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
}

/// Returns the number of dimension sizes that are either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
return llvm::count_if(
shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Packing a one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -34,26 +40,60 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
reassociation);
}

/// Returns success() if the op only packs on the innermost dimension.
LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
PackOp packOp) const {
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
return rewriter.notifyMatchFailure(
packOp,
"expects outer_dims_perm is empty or an identity permutation");
}

int64_t srcRank = packOp.getSourceRank();
ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
return rewriter.notifyMatchFailure(
packOp, "expects packing at the innermost dimension");
}
return success();
}

/// Returns success() if at most one dimension size in the source is greater
/// than 1 and packing only happens on that dimension. It assumes that the
/// pack op does not have a padding value.
LogicalResult isPack1DSrc(RewriterBase &rewriter, PackOp packOp) const {
assert(!packOp.getPaddingValue() &&
"expect the op does not have padding value.");
ArrayRef<int64_t> srcShape = packOp.getSourceType().getShape();
if (getNumGtOneDims(srcShape) > 1) {
return rewriter.notifyMatchFailure(
packOp, "expects source to have at most one non-unit dims");
}

// The pack op does not have a padding value, so the non-unit inner tile
// size must be used by the non-unit dimension.
SmallVector<int64_t> innerTiles = packOp.getStaticTiles();
if (getNumGtOneDims(innerTiles) > 1) {
return rewriter.notifyMatchFailure(
packOp, "expects at most one non-unit inner tiles");
}

return success();
}

LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");

if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPack1DSrc(rewriter, packOp))) {
return failure();
}

RankedTensorType sourceType = packOp.getSourceType();
RankedTensorType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
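
The new `isPack1DSrc` path keys off `getNumGtOneDims`: a pack is treated as effectively one-dimensional when at most one source dimension and at most one inner tile are non-unit. A standalone sketch of that count, with MLIR's dynamic-size sentinel stubbed out (the real predicate is `ShapedType::isDynamic`):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for MLIR's dynamic-dimension sentinel; illustrative only.
constexpr int64_t kDynamic = INT64_MIN;

// Counts dimension sizes that are dynamic or greater than 1, mirroring
// getNumGtOneDims above.
static int64_t numGtOneDims(const std::vector<int64_t> &shape) {
  return std::count_if(shape.begin(), shape.end(),
                       [](int64_t v) { return v == kDynamic || v > 1; });
}

int main() {
  // tensor<1x32>: one non-unit source dim, so eligible for expand_shape.
  assert(numGtOneDims({1, 32}) == 1);
  // inner_tiles = [1, 2]: at most one non-unit tile, still eligible.
  assert(numGtOneDims({1, 2}) == 1);
  // A dynamic size counts as potentially non-unit.
  assert(numGtOneDims({1, kDynamic}) == 1);
  // Two non-unit sizes: isPack1DSrc would reject this shape.
  assert(numGtOneDims({16, 2}) == 2);
  return 0;
}
```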
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -83,6 +83,57 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x

// -----

// CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
%empty = tensor.empty() : tensor<1x32x1x1xf32>
%pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %empty
: tensor<1x32xf32> -> tensor<1x32x1x1xf32>
return %pack : tensor<1x32x1x1xf32>
}

// -----

// CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
%empty = tensor.empty() : tensor<1x16x1x2xf32>
%pack = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 2] into %empty
: tensor<1x32xf32> -> tensor<1x16x1x2xf32>
return %pack : tensor<1x16x1x2xf32>
}

// -----

// CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: return %[[EXPANDED]]
func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
%empty = tensor.empty() : tensor<1x16x2x1xf32>
%pack = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [2, 1] into %empty
: tensor<32x1xf32> -> tensor<1x16x2x1xf32>
return %pack : tensor<1x16x2x1xf32>
}

// -----

// CHECK-LABEL: func.func @pack_32x1_to_16x1x1x2
// CHECK-NOT: tensor.expand_shape
// CHECK: tensor.pack
func.func @pack_32x1_to_16x1x1x2(%arg0 : tensor<32x1xf32>) -> tensor<16x1x1x2xf32> {
%empty = tensor.empty() : tensor<16x1x1x2xf32>
%pack = tensor.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [1, 2] into %empty
: tensor<32x1xf32> -> tensor<16x1x1x2xf32>
return %pack : tensor<16x1x1x2xf32>
}

// -----

// CHECK-LABEL: func.func @unpack_1d_to_collapse
// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32xf32>)
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>