211 changes: 124 additions & 87 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,6 @@ def LLVM_ZeroResultOpBuilder :
}
}]>;

class LLVM_TwoBuilders<OpBuilderDAG b1, OpBuilderDAG b2> {
list<OpBuilderDAG> builders = [b1, b2];
}

// Base class for LLVM operations with one result.
class LLVM_OneResultOp<string mnemonic, list<OpTrait> traits = []> :
LLVM_Op<mnemonic, traits>, Results<(outs LLVM_Type:$res)> {
let builders = [LLVM_OneResultOpBuilder];
}

// Compatibility builder that takes an instance of wrapped llvm::VoidType
// to indicate no result.
def LLVM_VoidResultTypeOpBuilder :
Expand All @@ -66,10 +56,6 @@ def LLVM_VoidResultTypeOpBuilder :
build($_builder, $_state, operands, attributes);
}]>;

// Base class for LLVM operations with zero results.
class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
LLVM_Op<mnemonic, traits>, Results<(outs)>,
LLVM_TwoBuilders<LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder>;

// Opaque builder used for terminator operations that contain successors.
def LLVM_TerminatorPassthroughOpBuilder :
Expand All @@ -89,14 +75,14 @@ class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
// Class for arithmetic binary operations.
class LLVM_ArithmeticOpBase<Type type, string mnemonic,
string builderFunc, list<OpTrait> traits = []> :
LLVM_OneResultOp<mnemonic,
LLVM_Op<mnemonic,
!listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
LLVM_ScalarOrVectorOf<type>:$rhs);
let parser =
[{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }];
let results = (outs LLVM_ScalarOrVectorOf<type>:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)";
}
class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
list<OpTrait> traits = []> :
Expand All @@ -108,13 +94,13 @@ class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
// Class for arithmetic unary operations.
class LLVM_UnaryArithmeticOp<Type type, string mnemonic,
string builderFunc, list<OpTrait> traits = []> :
LLVM_OneResultOp<mnemonic,
LLVM_Op<mnemonic,
!listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> {
let arguments = (ins type:$operand);
let parser =
[{ return impl::parseOneResultSameOperandTypeOp(parser, result); }];
let printer = [{ mlir::impl::printOneResultOp(this->getOperation(), p); }];
let results = (outs type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$operand attr-dict `:` type($res)";
}

// Integer binary operations.
Expand Down Expand Up @@ -153,10 +139,11 @@ def ICmpPredicate : I64EnumAttr<
}

// Other integer operations.
def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]> {
def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
let arguments = (ins ICmpPredicate:$predicate,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>:$lhs,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>:$rhs);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_i1>:$res);
let llvmBuilder = [{
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
Expand Down Expand Up @@ -200,10 +187,11 @@ def FCmpPredicate : I64EnumAttr<
}

// Other integer operations.
def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]> {
def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> {
let arguments = (ins FCmpPredicate:$predicate,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_i1>:$res);
let llvmBuilder = [{
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
Expand Down Expand Up @@ -252,11 +240,10 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
}

// Memory-related operations.
def LLVM_AllocaOp :
MemoryOpWithAlignmentBase,
LLVM_OneResultOp<"alloca"> {
def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
let arguments = (ins LLVM_AnyInteger:$arraySize,
OptionalAttr<I64Attr>:$alignment);
let results = (outs LLVM_AnyPointer:$res);
string llvmBuilder = [{
auto *inst = builder.CreateAlloca(
$_resultType->getPointerElementType(), $arraySize);
Expand All @@ -276,21 +263,22 @@ def LLVM_AllocaOp :
let printer = [{ printAllocaOp(p, *this); }];
}

def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]>,
LLVM_Builder<"$res = builder.CreateGEP($base, $indices);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
Variadic<LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>:$indices);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = [{
$base `[` $indices `]` attr-dict `:` functional-type(operands, results)
}];
}

def LLVM_LoadOp :
MemoryOpWithAlignmentAndAttributes,
LLVM_OneResultOp<"load"> {
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
let arguments = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
UnitAttr:$nontemporal);
let results = (outs LLVM_Type:$res);
string llvmBuilder = [{
auto *inst = builder.CreateLoad($addr, $volatile_);
}] # setAlignmentCode # setNonTemporalMetadataCode # [{
Expand All @@ -309,9 +297,8 @@ def LLVM_LoadOp :
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
}
def LLVM_StoreOp :
MemoryOpWithAlignmentAndAttributes,
LLVM_ZeroResultOp<"store"> {

def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
let arguments = (ins LLVM_LoadableType:$value,
LLVM_PointerTo<LLVM_LoadableType>:$addr,
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
Expand All @@ -330,52 +317,64 @@ def LLVM_StoreOp :

// Casts.
class LLVM_CastOp<string mnemonic, string builderFunc, Type type,
list<OpTrait> traits = []> :
LLVM_OneResultOp<mnemonic,
!listconcat([NoSideEffect], traits)>,
Type resultType, list<OpTrait> traits = []> :
LLVM_Op<mnemonic, !listconcat([NoSideEffect], traits)>,
LLVM_Builder<"$res = builder." # builderFunc # "($arg, $_resultType);"> {
let arguments = (ins type:$arg);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return mlir::impl::parseCastOp(parser, result); }];
let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }];
}
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast",
LLVM_AnyNonAggregate>;
LLVM_AnyNonAggregate, LLVM_AnyNonAggregate>;
def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast",
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt",
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;

// Call-related operations.
def LLVM_InvokeOp : LLVM_Op<"invoke", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
Terminator
]>,
Results<(outs Variadic<LLVM_Type>)> {
Terminator]> {
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$operands,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands);
let results = (outs Variadic<LLVM_Type>);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);

Expand All @@ -398,17 +397,19 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
let printer = [{ printInvokeOp(p, *this); }];
}

def LLVM_LandingpadOp : LLVM_OneResultOp<"landingpad"> {
def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let arguments = (ins UnitAttr:$cleanup, Variadic<LLVM_Type>);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseLandingpadOp(parser, result); }];
let printer = [{ printLandingpadOp(p, *this); }];
}

def LLVM_CallOp : LLVM_Op<"call">,
Results<(outs Variadic<LLVM_Type>)> {
def LLVM_CallOp : LLVM_Op<"call"> {
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>);
let results = (outs Variadic<LLVM_Type>);
let builders = [
OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
Expand All @@ -424,8 +425,9 @@ def LLVM_CallOp : LLVM_Op<"call">,
let parser = [{ return parseCallOp(parser, result); }];
let printer = [{ printCallOp(p, *this); }];
}
def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]> {
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, LLVM_AnyInteger:$position);
let results = (outs LLVM_Type:$res);
string llvmBuilder = [{
$res = builder.CreateExtractElement($vector, $position);
}];
Expand All @@ -435,26 +437,31 @@ def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]> {
let parser = [{ return parseExtractElementOp(parser, result); }];
let printer = [{ printExtractElementOp(p, *this); }];
}
def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]> {
def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position);
let results = (outs LLVM_Type:$res);
string llvmBuilder = [{
$res = builder.CreateExtractValue($container, extractPosition($position));
}];
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return parseExtractValueOp(parser, result); }];
let printer = [{ printExtractValueOp(p, *this); }];
}
def LLVM_InsertElementOp : LLVM_OneResultOp<"insertelement", [NoSideEffect]> {
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
LLVM_AnyInteger:$position);
let results = (outs LLVM_AnyVector:$res);
string llvmBuilder = [{
$res = builder.CreateInsertElement($vector, $value, $position);
}];
let builders = [LLVM_OneResultOpBuilder];
let parser = [{ return parseInsertElementOp(parser, result); }];
let printer = [{ printInsertElementOp(p, *this); }];
}
def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]> {
def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value,
ArrayAttr:$position);
let results = (outs LLVM_AnyAggregate:$res);
string llvmBuilder = [{
$res = builder.CreateInsertValue($container, $value,
extractPosition($position));
Expand All @@ -467,8 +474,9 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]> {
let parser = [{ return parseInsertValueOp(parser, result); }];
let printer = [{ printInsertValueOp(p, *this); }];
}
def LLVM_ShuffleVectorOp : LLVM_OneResultOp<"shufflevector", [NoSideEffect]> {
def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask);
let results = (outs LLVM_AnyVector:$res);
string llvmBuilder = [{
SmallVector<unsigned, 4> position = extractPosition($mask);
SmallVector<int, 4> mask(position.begin(), position.end());
Expand All @@ -493,21 +501,24 @@ def LLVM_ShuffleVectorOp : LLVM_OneResultOp<"shufflevector", [NoSideEffect]> {

// Misc operations.
def LLVM_SelectOp
: LLVM_OneResultOp<"select",
: LLVM_Op<"select",
[NoSideEffect, AllTypesMatch<["trueValue", "falseValue", "res"]>]>,
LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_i1>:$condition,
LLVM_Type:$trueValue, LLVM_Type:$falseValue);
let results = (outs LLVM_Type:$res);
let builders = [
OpBuilderDAG<(ins "Value":$condition, "Value":$lhs, "Value":$rhs),
[{
build($_builder, $_state, lhs.getType(), condition, lhs, rhs);
}]>];
let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)";
}
def LLVM_FreezeOp : LLVM_OneResultOp<"freeze", [SameOperandsAndResultType]> {
def LLVM_FreezeOp : LLVM_Op<"freeze", [SameOperandsAndResultType]> {
let arguments = (ins LLVM_Type:$val);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$val attr-dict `:` type($val)";
string llvmBuilder = "builder.CreateFreeze($val);";
}
Expand Down Expand Up @@ -639,8 +650,9 @@ def Linkage : LLVM_EnumAttr<
}


def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof"> {
def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
let arguments = (ins FlatSymbolRefAttr:$global_name);
let results = (outs LLVM_Type:$res);

let summary = "Creates a pointer pointing to a global or a function";

Expand Down Expand Up @@ -677,10 +689,8 @@ def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof"> {
let verifier = "return ::verify(*this);";
}

def LLVM_GlobalOp
: LLVM_ZeroResultOp<"mlir.global",
[IsolatedFromAbove,
SingleBlockImplicitTerminator<"ReturnOp">, Symbol]> {
def LLVM_GlobalOp : LLVM_Op<"mlir.global",
[IsolatedFromAbove, SingleBlockImplicitTerminator<"ReturnOp">, Symbol]> {
let arguments = (ins
TypeAttr:$type,
UnitAttr:$constant,
Expand Down Expand Up @@ -740,9 +750,8 @@ def LLVM_GlobalOp
let verifier = "return ::verify(*this);";
}

def LLVM_LLVMFuncOp
: LLVM_ZeroResultOp<"func", [AutomaticAllocationScope, IsolatedFromAbove,
FunctionLike, Symbol]> {
def LLVM_LLVMFuncOp : LLVM_Op<"func",
[AutomaticAllocationScope, IsolatedFromAbove, FunctionLike, Symbol]> {
let summary = "LLVM dialect function, has wrapped LLVM IR function type";

let arguments = (ins DefaultValuedAttr<Linkage, "Linkage::External">:$linkage,
Expand Down Expand Up @@ -793,28 +802,32 @@ def LLVM_LLVMFuncOp
}

def LLVM_NullOp
: LLVM_OneResultOp<"mlir.null", [NoSideEffect]>,
: LLVM_Op<"mlir.null", [NoSideEffect]>,
LLVM_Builder<"$res = llvm::ConstantPointerNull::get("
" cast<llvm::PointerType>($_resultType));"> {
let results = (outs LLVM_AnyPointer:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "attr-dict `:` type($res)";
let verifier = [{ return ::verify(*this); }];
}

def LLVM_UndefOp : LLVM_OneResultOp<"mlir.undef", [NoSideEffect]>,
def LLVM_UndefOp : LLVM_Op<"mlir.undef", [NoSideEffect]>,
LLVM_Builder<"$res = llvm::UndefValue::get($_resultType);"> {
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "attr-dict `:` type($res)";
}
def LLVM_ConstantOp
: LLVM_OneResultOp<"mlir.constant", [NoSideEffect]>,
: LLVM_Op<"mlir.constant", [NoSideEffect]>,
LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
{
let arguments = (ins AnyAttr:$value);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
let verifier = [{ return ::verify(*this); }];
}

def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,
Results<(outs AnyType:$res)> {
def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]> {
let summary = "Type cast between LLVM dialect and Standard.";
let description = [{
llvm.mlir.cast op casts between Standard and LLVM dialects. It only changes
Expand All @@ -828,6 +841,7 @@ def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,
llvm.mlir.cast %v : !llvm<"<2 x float>"> to vector<2xf32>
}];
let arguments = (ins AnyType:$in);
let results = (outs AnyType:$res);
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($res)";
let verifier = "return ::verify(*this);";
}
Expand Down Expand Up @@ -947,10 +961,11 @@ def LLVM_vector_reduce_fmul : LLVM_VectorReductionAcc<"fmul">;
/// isVolatile - True if the load operation is marked as volatile.
/// columns - Number of columns in matrix (must be a constant)
/// stride - Space between columns
def LLVM_MatrixColumnMajorLoadOp
: LLVM_OneResultOp<"intr.matrix.column.major.load"> {
def LLVM_MatrixColumnMajorLoadOp : LLVM_Op<"intr.matrix.column.major.load"> {
let arguments = (ins LLVM_Type:$data, LLVM_Type:$stride, I1Attr:$isVolatile,
I32Attr:$rows, I32Attr:$columns);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
const llvm::DataLayout &dl =
Expand All @@ -973,10 +988,10 @@ def LLVM_MatrixColumnMajorLoadOp
/// rows - Number of rows in matrix (must be a constant)
/// columns - Number of columns in matrix (must be a constant)
/// stride - Space between columns
def LLVM_MatrixColumnMajorStoreOp
: LLVM_ZeroResultOp<"intr.matrix.column.major.store"> {
def LLVM_MatrixColumnMajorStoreOp : LLVM_Op<"intr.matrix.column.major.store"> {
let arguments = (ins LLVM_Type:$matrix, LLVM_Type:$data, LLVM_Type:$stride,
I1Attr:$isVolatile, I32Attr:$rows, I32Attr:$columns);
let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
const llvm::DataLayout &dl =
Expand All @@ -993,10 +1008,11 @@ def LLVM_MatrixColumnMajorStoreOp

/// Create a llvm.matrix.multiply call, multiplying 2-D matrices LHS and RHS, as
/// specified in the LLVM MatrixBuilder.
def LLVM_MatrixMultiplyOp
: LLVM_OneResultOp<"intr.matrix.multiply"> {
def LLVM_MatrixMultiplyOp : LLVM_Op<"intr.matrix.multiply"> {
let arguments = (ins LLVM_Type:$lhs, LLVM_Type:$rhs, I32Attr:$lhs_rows,
I32Attr:$lhs_columns, I32Attr:$rhs_columns);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
$res = mb.CreateMatrixMultiply(
Expand All @@ -1009,8 +1025,10 @@ def LLVM_MatrixMultiplyOp

/// Create a llvm.matrix.transpose call, transposing a `rows` x `columns` 2-D
/// `matrix`, as specified in the LLVM MatrixBuilder.
def LLVM_MatrixTransposeOp : LLVM_OneResultOp<"intr.matrix.transpose"> {
def LLVM_MatrixTransposeOp : LLVM_Op<"intr.matrix.transpose"> {
let arguments = (ins LLVM_Type:$matrix, I32Attr:$rows, I32Attr:$columns);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
$res = mb.CreateMatrixTranspose(
Expand All @@ -1032,9 +1050,11 @@ def LLVM_GetActiveLaneMaskOp
}

/// Create a call to Masked Load intrinsic.
def LLVM_MaskedLoadOp : LLVM_OneResultOp<"intr.masked.load"> {
def LLVM_MaskedLoadOp : LLVM_Op<"intr.masked.load"> {
let arguments = (ins LLVM_Type:$data, LLVM_Type:$mask,
Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
string llvmBuilder = [{
$res = $pass_thru.empty() ? builder.CreateMaskedLoad(
$data, llvm::Align($alignment), $mask) :
Expand All @@ -1046,9 +1066,10 @@ def LLVM_MaskedLoadOp : LLVM_OneResultOp<"intr.masked.load"> {
}

/// Create a call to Masked Store intrinsic.
def LLVM_MaskedStoreOp : LLVM_ZeroResultOp<"intr.masked.store"> {
def LLVM_MaskedStoreOp : LLVM_Op<"intr.masked.store"> {
let arguments = (ins LLVM_Type:$value, LLVM_Type:$data, LLVM_Type:$mask,
I32Attr:$alignment);
let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
string llvmBuilder = [{
builder.CreateMaskedStore(
$value, $data, llvm::Align($alignment), $mask);
Expand All @@ -1058,9 +1079,11 @@ def LLVM_MaskedStoreOp : LLVM_ZeroResultOp<"intr.masked.store"> {
}

/// Create a call to Masked Gather intrinsic.
def LLVM_masked_gather : LLVM_OneResultOp<"intr.masked.gather"> {
def LLVM_masked_gather : LLVM_Op<"intr.masked.gather"> {
let arguments = (ins LLVM_Type:$ptrs, LLVM_Type:$mask,
Variadic<LLVM_Type>:$pass_thru, I32Attr:$alignment);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
string llvmBuilder = [{
$res = $pass_thru.empty() ? builder.CreateMaskedGather(
$ptrs, llvm::Align($alignment), $mask) :
Expand All @@ -1072,9 +1095,10 @@ def LLVM_masked_gather : LLVM_OneResultOp<"intr.masked.gather"> {
}

/// Create a call to Masked Scatter intrinsic.
def LLVM_masked_scatter : LLVM_ZeroResultOp<"intr.masked.scatter"> {
def LLVM_masked_scatter : LLVM_Op<"intr.masked.scatter"> {
let arguments = (ins LLVM_Type:$value, LLVM_Type:$ptrs, LLVM_Type:$mask,
I32Attr:$alignment);
let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
string llvmBuilder = [{
builder.CreateMaskedScatter(
$value, $ptrs, llvm::Align($alignment), $mask);
Expand Down Expand Up @@ -1139,11 +1163,11 @@ def AtomicOrdering : I64EnumAttr<

def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>;

def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
Results<(outs LLVM_Type:$res)> {
def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
let arguments = (ins AtomicBinOp:$bin_op,
LLVM_PointerTo<LLVM_AtomicRMWType>:$ptr,
LLVM_AtomicRMWType:$val, AtomicOrdering:$ordering);
let results = (outs LLVM_AtomicRMWType:$res);
let llvmBuilder = [{
$res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val,
getLLVMAtomicOrdering($ordering));
Expand All @@ -1154,12 +1178,24 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
}

def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>;

def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg">, Results<(outs LLVM_Type:$res)> {
def LLVM_AtomicCmpXchgResultType : Type<And<[
LLVM_AnyStruct.predicate,
CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().getBody().size() == 2">,
SubstLeaves<"$_self",
"$_self.cast<::mlir::LLVM::LLVMStructType>().getBody()[0]",
LLVM_AtomicCmpXchgType.predicate>,
SubstLeaves<"$_self",
"$_self.cast<::mlir::LLVM::LLVMStructType>().getBody()[1]",
LLVM_i1.predicate>]>,
"an LLVM struct type with any integer or pointer followed by a single-bit "
"integer">;

def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> {
let arguments = (ins LLVM_PointerTo<LLVM_AtomicCmpXchgType>:$ptr,
LLVM_AtomicCmpXchgType:$cmp, LLVM_AtomicCmpXchgType:$val,
AtomicOrdering:$success_ordering,
AtomicOrdering:$failure_ordering);
let results = (outs LLVM_AtomicCmpXchgResultType:$res);
let llvmBuilder = [{
$res = builder.CreateAtomicCmpXchg($ptr, $cmp, $val,
getLLVMAtomicOrdering($success_ordering),
Expand All @@ -1180,8 +1216,9 @@ def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> {
}];
}

def LLVM_FenceOp : LLVM_ZeroResultOp<"fence", []> {
def LLVM_FenceOp : LLVM_Op<"fence"> {
let arguments = (ins AtomicOrdering:$ordering, StrAttr:$syncscope);
let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder];
let llvmBuilder = [{
llvm::LLVMContext &llvmContext = builder.getContext();
builder.CreateFence(getLLVMAtomicOrdering($ordering),
Expand Down
12 changes: 0 additions & 12 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1510,18 +1510,6 @@ static LogicalResult verify(LLVMFuncOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// Verification for LLVM::NullOp.
//===----------------------------------------------------------------------===//

// Only LLVM pointer types are supported.
static LogicalResult verify(LLVM::NullOp op) {
auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>();
if (!llvmType || !llvmType.isPointerTy())
return op.emitOpError("expected LLVM IR pointer type");
return success();
}

//===----------------------------------------------------------------------===//
// Verification for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func @invalid_vector_type_3(%arg0: !llvm.vec<4 x float>, %arg1: !llvm.i32, %arg2
// -----

func @null_non_llvm_type() {
// expected-error@+1 {{expected LLVM IR pointer type}}
// expected-error@+1 {{must be LLVM pointer type, but got '!llvm.i32'}}
llvm.mlir.null : !llvm.i32
}

Expand Down