diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 76b08e664ee76..5e2461b4508e4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -87,21 +87,21 @@ class LLVM_TernarySameArgsIntrOpF traits = []> : class LLVM_CountZerosIntrOp traits = []> : LLVM_OneResultIntrOp { let arguments = (ins LLVM_ScalarOrVectorOf:$in, I1Attr:$is_zero_poison); } def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure], - /*requiresFastmath=*/0, + /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0, /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_int_min_poison"]> { let arguments = (ins LLVM_ScalarOrVectorOf:$in, I1Attr:$is_int_min_poison); } def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure], - /*requiresFastmath=*/0, + /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0, /*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> { let arguments = (ins LLVM_ScalarOrVectorOf:$in, I32Attr:$bit); } @@ -360,8 +360,8 @@ def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">; def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1], [DeclareOpInterfaceMethods], - /*requiresFastmath=*/0, /*immArgPositions=*/[0], - /*immArgAttrNames=*/["size"]> { + /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0, + /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> { let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr); let results = (outs LLVM_DefaultPointer:$res); let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))"; @@ -412,6 +412,7 @@ class LLVM_ConstrainedIntr], true : []), /*requiresFastmath=*/0, + /*requiresArgAndResultAttrs=*/0, /*immArgPositions=*/[], /*immArgAttrNames=*/[]> { dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i)); @@ -589,7 +590,7 @@ def LLVM_ExpectOp def LLVM_ExpectWithProbabilityOp : LLVM_OneResultIntrOp<"expect.with.probability", [], [0], [Pure, AllTypesMatch<["val", "expected", "res"]>], - /*requiresFastmath=*/0, + /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0, /*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> { let arguments = (ins AnySignlessInteger:$val, AnySignlessInteger:$expected, @@ -825,7 +826,7 @@ class LLVM_VecReductionAccBase /*overloadedResults=*/[], /*overloadedOperands=*/[1], /*traits=*/[Pure, SameOperandsAndResultElementType], - /*equiresFastmath=*/1>, + /*requiresFastmath=*/1>, Arguments<(ins element:$start_value, LLVM_VectorOf:$input, DefaultValuedAttr:$fastmathFlags)>; @@ -1069,14 +1070,36 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> { } /// Create a call to Masked Expand Load intrinsic. -def LLVM_masked_expandload : LLVM_IntrOp<"masked.expandload", [0], [], [], 1> { - let arguments = (ins LLVM_AnyPointer, LLVM_VectorOf, LLVM_AnyVector); +def LLVM_masked_expandload + : LLVM_OneResultIntrOp<"masked.expandload", [0], [], + /*traits=*/[], /*requiresFastMath=*/0, /*requiresArgAndResultAttrs=*/1, + /*immArgPositions=*/[], /*immArgAttrNames=*/[]> { + dag args = (ins LLVM_AnyPointer:$ptr, + LLVM_VectorOf:$mask, + LLVM_AnyVector:$passthru); + + let arguments = !con(args, baseArgs); + + let builders = [ + OpBuilder<(ins "TypeRange":$resTy, "Value":$ptr, "Value":$mask, "Value":$passthru, CArg<"uint64_t", "1">:$align)> + ]; } /// Create a call to Masked Compress Store intrinsic. def LLVM_masked_compressstore - : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0> { - let arguments = (ins LLVM_AnyVector, LLVM_AnyPointer, LLVM_VectorOf); + : LLVM_ZeroResultIntrOp<"masked.compressstore", [0], + /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0, + /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0, + /*immArgPositions=*/[], /*immArgAttrNames=*/[]> { + dag args = (ins LLVM_AnyVector:$value, + LLVM_AnyPointer:$ptr, + LLVM_VectorOf:$mask); + + let arguments = !con(args, baseArgs); + + let builders = [ + OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask, CArg<"uint64_t", "1">:$align)> + ]; } // @@ -1155,7 +1178,7 @@ def LLVM_vector_insert PredOpTrait<"it is not inserting scalable into fixed-length vectors.", CPred<"!isScalableVectorType($srcvec.getType()) || " "isScalableVectorType($dstvec.getType())">>], - /*requiresFastmath=*/0, + /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0, /*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> { let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec, I64Attr:$pos); @@ -1189,7 +1212,7 @@ def LLVM_vector_extract PredOpTrait<"it is not extracting scalable from fixed-length vectors.", CPred<"!isScalableVectorType($res.getType()) || " "isScalableVectorType($srcvec.getType())">>], - /*requiresFastmath=*/0, + /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0, /*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> { let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos); let results = (outs LLVM_AnyVector:$res); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index a8d7cf2069547..d6aa9580870a8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -475,11 +475,12 @@ class LLVM_OneResultIntrOp overloadedResults = [], list overloadedOperands = [], list traits = [], bit requiresFastmath = 0, + bit requiresArgAndResultAttrs = 0, list immArgPositions = [], list immArgAttrNames = []> : LLVM_IntrOp; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 422039f81855a..a6e89f63c822b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) { return success(); } +static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, + bool isExpandLoad, + uint64_t alignment = 1) { + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The pointer alignment defaults to 1. + if (alignment == 1) { + return nullptr; + } + + auto emptyDictAttr = builder.getDictionaryAttr({}); + auto alignmentAttr = builder.getI64IntegerAttr(alignment); + auto namedAttr = + builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr); + SmallVector attrs = {namedAttr}; + auto alignDictAttr = builder.getDictionaryAttr(attrs); + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The align parameter attribute can be provided for [expandload]'s first + // argument. The align parameter attribute can be provided for + // [compressstore]'s second argument. + int pos = isExpandLoad ? 0 : 1; + return pos == 0 ? builder.getArrayAttr( + {alignDictAttr, emptyDictAttr, emptyDictAttr}) + : builder.getArrayAttr( + {emptyDictAttr, alignDictAttr, emptyDictAttr}); +} + //===----------------------------------------------------------------------===// // Operand bundle helpers. //===----------------------------------------------------------------------===// @@ -4116,6 +4148,32 @@ LogicalResult LLVM::masked_scatter::verify() { return success(); } +//===----------------------------------------------------------------------===// +// masked_expandload (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state, + mlir::TypeRange resTys, Value ptr, + Value mask, Value passthru, + uint64_t align) { + ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align); + build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// +// masked_compressstore (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_compressstore::build(OpBuilder &builder, + OperationState &state, Value value, + Value ptr, Value mask, uint64_t align) { + ArrayAttr argAttrs = + getLLVMAlignParamForCompressExpand(builder, false, align); + build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + //===----------------------------------------------------------------------===// // InlineAsmOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 9f882ad6f22e8..07d22120153fe 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -545,6 +545,15 @@ define void @masked_expand_compress_intrinsics(ptr %0, <7 x i1> %1, <7 x float> ret void } +; CHECK-LABEL: llvm.func @masked_expand_compress_intrinsics_with_alignment +define void @masked_expand_compress_intrinsics_with_alignment(ptr %0, <7 x i1> %1, <7 x float> %2) { + ; CHECK: %[[val1:.+]] = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32> + %4 = call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %0, <7 x i1> %1, <7 x float> %2) + ; CHECK: "llvm.intr.masked.compressstore"(%[[val1]], %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> () + call void @llvm.masked.compressstore.v7f32(<7 x float> %4, ptr align 8 %0, <7 x i1> %1) + ret void +} + ; CHECK-LABEL: llvm.func @annotate_intrinsics define void @annotate_intrinsics(ptr %var, ptr %ptr, i16 %int, ptr %annotation, ptr %fileName, i32 %line, ptr %args) { ; CHECK: "llvm.intr.var.annotation"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr) -> () diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 2b420ed246fb2..c99dde36f5ccb 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -577,6 +577,17 @@ llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr, %mask: vector<7xi1 llvm.return } +// CHECK-LABEL: @masked_expand_compress_intrinsics_with_alignment +llvm.func @masked_expand_compress_intrinsics_with_alignment(%ptr: !llvm.ptr, %mask: vector<7xi1>, %passthru: vector<7xf32>) { + // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}}) + %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru) {arg_attrs = [{llvm.align = 8 : i32}, {}, {}]} + : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> (vector<7xf32>) + // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, ptr align 8 %{{.*}}, <7 x i1> %{{.*}}) + "llvm.intr.masked.compressstore"(%0, %ptr, %mask) {arg_attrs = [{}, {llvm.align = 8 : i32}, {}]} + : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> () + llvm.return +} + // CHECK-LABEL: @annotate_intrinsics llvm.func @annotate_intrinsics(%var: !llvm.ptr, %int: i16, %ptr: !llvm.ptr, %annotation: !llvm.ptr, %fileName: !llvm.ptr, %line: i32, %attr: !llvm.ptr) { // CHECK: call void @llvm.var.annotation.p0.p0(ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, ptr %{{.*}})