Skip to content

Commit

Permalink
Add vp2intersect to AVX512 dialect.
Browse files Browse the repository at this point in the history
Adds vp2intersect to the AVX512 dialect and defines a lowering to the
LLVM dialect.

Author: Matthias Springer <springerm@google.com>

Differential Revision: https://reviews.llvm.org/D95301
  • Loading branch information
matthias-springer authored and nicolasvasilache committed Jan 26, 2021
1 parent d705c2f commit 90ebc48
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 7 deletions.
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/AVX512/AVX512.td
Expand Up @@ -96,4 +96,41 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
}

def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
AllTypesMatch<["a", "b"]>,
TypesMatchWith<"k1 has the same number of bits as elements in a",
"a", "k1",
"IntegerType::get($_self.getContext(), "
"($_self.cast<VectorType>().getShape()[0]))">,
TypesMatchWith<"k2 has the same number of bits as elements in b",
// Should use `b` instead of `a`, but that would require
// adding `type($b)` to assemblyFormat.
"a", "k2",
"IntegerType::get($_self.getContext(), "
"($_self.cast<VectorType>().getShape()[0]))">]> {
let summary = "Vp2Intersect op";
let description = [{
The `vp2intersect` op is an AVX512 specific op that can lower to the proper
LLVMAVX512 operation: `llvm.vp2intersect.d.512` or
`llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is
applied to.

#### From the Intel Intrinsics Guide:

Compute intersection of packed integer vectors `a` and `b`, and store
indication of match in the corresponding bit of two mask registers
specified by `k1` and `k2`. A match in corresponding elements of `a` and
`b` is indicated by a set bit in the corresponding bit of the mask
registers.
}];
let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
);
let results = (outs AnyTypeOf<[I16, I8]>:$k1,
AnyTypeOf<[I16, I8]>:$k2
);
let assemblyFormat =
"$a `,` $b attr-dict `:` type($a)";
}

#endif // AVX512_OPS
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/AVX512/AVX512Dialect.h.inc"
Expand Down
20 changes: 14 additions & 6 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
Expand Up @@ -28,25 +28,33 @@ def LLVMAVX512_Dialect : Dialect {
// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
//----------------------------------------------------------------------------//

class LLVMAVX512_IntrOp<string mnemonic, list<OpTrait> traits = []> :
class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
"x86_avx512_" # !subst(".", "_", mnemonic),
[], [], traits, 1>;
[], [], traits, numResults>;

def LLVM_x86_avx512_mask_rndscale_ps_512 :
LLVMAVX512_IntrOp<"mask.rndscale.ps.512">,
LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;

def LLVM_x86_avx512_mask_rndscale_pd_512 :
LLVMAVX512_IntrOp<"mask.rndscale.pd.512">,
LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;

def LLVM_x86_avx512_mask_scalef_ps_512 :
LLVMAVX512_IntrOp<"mask.scalef.ps.512">,
LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;

def LLVM_x86_avx512_mask_scalef_pd_512 :
LLVMAVX512_IntrOp<"mask.scalef.pd.512">,
LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;

def LLVM_x86_avx512_vp2intersect_d_512 :
LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
Arguments<(ins LLVM_Type, LLVM_Type)>;

def LLVM_x86_avx512_vp2intersect_q_512 :
LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>,
Arguments<(ins LLVM_Type, LLVM_Type)>;

#endif // AVX512_OPS
27 changes: 26 additions & 1 deletion mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
Expand Up @@ -77,13 +77,38 @@ struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
return failure();
}
};

struct Vp2IntersectOp512Conversion
: public ConvertOpToLLVMPattern<Vp2IntersectOp> {
explicit Vp2IntersectOp512Conversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertOpToLLVMPattern<Vp2IntersectOp>(typeConverter) {}

LogicalResult
matchAndRewrite(Vp2IntersectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type elementType =
op.a().getType().template cast<VectorType>().getElementType();
if (elementType.isInteger(32))
return LLVM::detail::oneToOneRewrite(
op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands,
*getTypeConverter(), rewriter);
if (elementType.isInteger(64))
return LLVM::detail::oneToOneRewrite(
op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands,
*getTypeConverter(), rewriter);
return failure();
}
};
} // namespace

/// Populate the given list with patterns that convert from AVX512 to LLVM.
void mlir::populateAVX512ToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<MaskRndScaleOp512Conversion,
ScaleFOp512Conversion>(&converter.getContext(), converter);
ScaleFOp512Conversion,
Vp2IntersectOp512Conversion>(&converter.getContext(),
converter);
// clang-format on
}
10 changes: 10 additions & 0 deletions mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
Expand Up @@ -16,3 +16,13 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
// Keep results alive.
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
}

func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (i16, i16, i8, i8)
{
// CHECK: llvm_avx512.vp2intersect.d.512
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: llvm_avx512.vp2intersect.q.512
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : i16, i16, i8, i8
}
10 changes: 10 additions & 0 deletions mlir/test/Dialect/AVX512/roundtrip.mlir
Expand Up @@ -19,3 +19,13 @@ func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16,
%1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
return %0, %1: vector<16xf32>, vector<8xf64>
}

func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (i16, i16, i8, i8)
{
// CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : i16, i16, i8, i8
}
20 changes: 20 additions & 0 deletions mlir/test/Target/avx512.mlir
Expand Up @@ -29,3 +29,23 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
(vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64>
llvm.return %1: vector<8xf64>
}

// CHECK-LABEL: define <{ i16, i16 }> @LLVM_x86_vp2intersect_d_512
llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
-> !llvm.struct<packed (i16, i16)>
{
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
%0 = "llvm_avx512.vp2intersect.d.512"(%a, %b) :
(vector<16xi32>, vector<16xi32>) -> !llvm.struct<packed (i16, i16)>
llvm.return %0 : !llvm.struct<packed (i16, i16)>
}

// CHECK-LABEL: define <{ i8, i8 }> @LLVM_x86_vp2intersect_q_512
llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
-> !llvm.struct<packed (i8, i8)>
{
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
%0 = "llvm_avx512.vp2intersect.q.512"(%a, %b) :
(vector<8xi64>, vector<8xi64>) -> !llvm.struct<packed (i8, i8)>
llvm.return %0 : !llvm.struct<packed (i8, i8)>
}

0 comments on commit 90ebc48

Please sign in to comment.