171 changes: 171 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,175 @@ def LLVMArrayType : LLVMType<"LLVMArray", "array", [
}];
}

//===----------------------------------------------------------------------===//
// LLVMFunctionType
//===----------------------------------------------------------------------===//

def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
let summary = "LLVM function type";
let description = [{
The `!llvm.func` is a function type. It consists of a single return type
(unlike MLIR which can have multiple), a list of parameter types and can
optionally be variadic.

Example:

```mlir
!llvm.func<i32 (i32)>
```
}];

let parameters = (ins "Type":$returnType, ArrayRefParameter<"Type">:$params,
"bool":$varArg);
let assemblyFormat = [{
`<` custom<PrettyLLVMType>($returnType) ` ` `(`
custom<FunctionTypes>($params, $varArg) `>`
}];

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins
"Type":$result, "ArrayRef<Type>":$arguments,
CArg<"bool", "false">:$isVarArg)>
];

let extraClassDeclaration = [{
/// Checks if the given type can be used an argument in a function type.
static bool isValidArgumentType(Type type);

/// Checks if the given type can be used as a result in a function type.
static bool isValidResultType(Type type);

/// Returns whether the function is variadic.
bool isVarArg() const { return getVarArg(); }

/// Returns a clone of this function type with the given argument
/// and result types.
LLVMFunctionType clone(TypeRange inputs, TypeRange results) const;

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
ArrayRef<Type> getReturnTypes() const;

/// Returns the number of arguments to the function.
unsigned getNumParams() const { return getParams().size(); }

/// Returns `i`-th argument of the function. Asserts on out-of-bounds.
Type getParamType(unsigned i) { return getParams()[i]; }
}];
}

//===----------------------------------------------------------------------===//
// LLVMPointerType
//===----------------------------------------------------------------------===//

def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "verifyEntries"]>,
DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
let summary = "LLVM pointer type";
let description = [{
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
a reference to an object in memory. Pointers may be opaque or parameterized
by the element type. Both opaque and non-opaque pointers are additionally
parameterized by the address space.

Example:

```mlir
!llvm.ptr<i8>
!llvm.ptr
```
}];

let parameters = (ins DefaultValuedParameter<"Type", "Type()">:$elementType,
DefaultValuedParameter<"unsigned", "0">:$addressSpace);
let assemblyFormat = [{
(`<` custom<Pointer>($elementType, $addressSpace)^ `>`)?
}];

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
CArg<"unsigned", "0">:$addressSpace)>,
TypeBuilder<(ins CArg<"unsigned", "0">:$addressSpace), [{
return $_get($_ctxt, Type(), addressSpace);
}]>
];

let extraClassDeclaration = [{
/// Returns `true` if this type is the opaque pointer type, i.e., it has no
/// pointed-to type.
bool isOpaque() const { return !getElementType(); }

/// Checks if the given type can have a pointer type pointing to it.
static bool isValidElementType(Type type);
}];
}

//===----------------------------------------------------------------------===//
// LLVMFixedVectorType
//===----------------------------------------------------------------------===//

def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
let summary = "LLVM fixed vector type";
let description = [{
LLVM dialect scalable vector type, represents a sequence of elements of
unknown length that is known to be divisible by some constant. These
elements can be processed as one in SIMD context.
}];

let parameters = (ins "Type":$elementType, "unsigned":$numElements);
let assemblyFormat = [{
`<` $numElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
}];

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
"unsigned":$numElements)>
];

let extraClassDeclaration = [{
/// Checks if the given type can be used in a vector type.
static bool isValidElementType(Type type);
}];
}

//===----------------------------------------------------------------------===//
// LLVMScalableVectorType
//===----------------------------------------------------------------------===//

def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
let summary = "LLVM scalable vector type";
let description = [{
LLVM dialect scalable vector type, represents a sequence of elements of
unknown length that is known to be divisible by some constant. These
elements can be processed as one in SIMD context.
}];

let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
let assemblyFormat = [{
`<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
}];

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType,
"unsigned":$minNumElements)>
];

let extraClassDeclaration = [{
/// Checks if the given type can be used in a vector type.
static bool isValidElementType(Type type);
}];
}

#endif // LLVMTYPES_TD
61 changes: 32 additions & 29 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def NVVM_VoteBallotOp :
let hasCustomAssemblyFormat = 1;
}

def NVVM_SyncWarpOp :
def NVVM_SyncWarpOp :
NVVM_Op<"bar.warp.sync">,
Arguments<(ins LLVM_Type:$mask)> {
string llvmBuilder = [{
Expand Down Expand Up @@ -534,9 +534,9 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
);
}

// Returns a list of operation suffixes corresponding to possible b1
// multiply-and-accumulate operations for all fragments which have a
// b1 type. For all other fragments, the list returned holds a list
// Returns a list of operation suffixes corresponding to possible b1
// multiply-and-accumulate operations for all fragments which have a
// b1 type. For all other fragments, the list returned holds a list
// containing the empty string.
class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
list<string> ret = !cond(
Expand All @@ -555,7 +555,7 @@ class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
# "_" # ALayout
# "_" # BLayout
# !if(Satfinite, "_satfinite", "")
# signature;
# signature;
}

/// Helper to create the mapping between the configuration and the mma.sync
Expand All @@ -572,21 +572,21 @@ class MMA_SYNC_INTR {
"if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && "
" m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
" && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
# op[1].ptx_elt_type # "\" == eltypeB && "
# op[1].ptx_elt_type # "\" == eltypeB && "
# " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
# " \"" # op[3].ptx_elt_type # "\" == eltypeD "
# " && (sat.has_value() ? " # sat # " == static_cast<int>(*sat) : true)"
# !if(!ne(b1op, ""), " && (b1Op.has_value() ? MMAB1Op::" # b1op # " == b1Op.value() : true)", "") # ")\n"
# " return " #
MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
"") // if supported
) // b1op
) // sat
) // layoutB
) // layoutA
); // all_mma_sync_ops
list<list<list<string>>> f1 = !foldl([[[""]]],
!foldl([[[[""]]]], cond0, acc, el,
!foldl([[[[""]]]], cond0, acc, el,
!listconcat(acc, el)),
acc1, el1, !listconcat(acc1, el1));
list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
Expand Down Expand Up @@ -776,7 +776,10 @@ def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
```
}];

let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)";
let assemblyFormat = [{
$ptr `,` $stride `,` $args attr-dict `:` qualified(type($ptr)) `,`
type($args)
}];
let hasVerifier = 1;
}

Expand Down Expand Up @@ -884,32 +887,32 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {

let description = [{
The `nvvm.mma.sync` operation collectively performs the operation
`D = matmul(A, B) + C` using all threads in a warp.
`D = matmul(A, B) + C` using all threads in a warp.

All the threads in the warp must execute the same `mma.sync` operation.

For each possible multiplicand PTX data type, there are one or more possible
instruction shapes given as "mMnNkK". The below table describes the posssibilities
as well as the types required for the operands. Note that the data type for
C (the accumulator) and D (the result) can vary independently when there are
as well as the types required for the operands. Note that the data type for
C (the accumulator) and D (the result) can vary independently when there are
multiple possibilities in the "C/D Type" column.

When an optional attribute cannot be immediately inferred from the types of
the operands and the result during parsing or validation, an error will be
raised.

`b1Op` is only relevant when the binary (b1) type is given to
`b1Op` is only relevant when the binary (b1) type is given to
`multiplicandDataType`. It specifies how the multiply-and-acumulate is
performed and is either `xor_popc` or `and_poc`. The default is `xor_popc`.

`intOverflowBehavior` is only relevant when the `multiplicandType` attribute
`intOverflowBehavior` is only relevant when the `multiplicandType` attribute
is one of `u8, s8, u4, s4`, this attribute describes how overflow is handled
in the accumulator. When the attribute is `satfinite`, the accumulator values
are clamped in the int32 range on overflow. This is the default behavior.
Alternatively, accumulator behavior `wrapped` can also be specified, in
Alternatively, accumulator behavior `wrapped` can also be specified, in
which case overflow wraps from one end of the range to the other.

`layoutA` and `layoutB` are required and should generally be set to
`layoutA` and `layoutB` are required and should generally be set to
`#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
combinations are possible for certain layouts according to the table below.

Expand Down Expand Up @@ -938,33 +941,33 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
Example:
```mlir

%128 = nvvm.mma.sync A[%120, %121, %122, %123]
B[%124, %125]
C[%126, %127]
{layoutA = #nvvm.mma_layout<row>,
layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
%128 = nvvm.mma.sync A[%120, %121, %122, %123]
B[%124, %125]
C[%126, %127]
{layoutA = #nvvm.mma_layout<row>,
layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>)
-> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
```
}];

let results = (outs LLVM_AnyStruct:$res);
let arguments = (ins NVVM_MMAShapeAttr:$shape,
OptionalAttr<MMAB1OpAttr>:$b1Op,
OptionalAttr<MMAB1OpAttr>:$b1Op,
OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
MMALayoutAttr:$layoutA,
MMALayoutAttr:$layoutB,
OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
Variadic<LLVM_Type>:$operandA,
Variadic<LLVM_Type>:$operandB,
Variadic<LLVM_Type>:$operandC);
Variadic<LLVM_Type>:$operandC);

let extraClassDeclaration = !strconcat([{
static llvm::Intrinsic::ID getIntrinsicID(
int64_t m, int64_t n, uint64_t k,
llvm::Optional<MMAB1Op> b1Op,
llvm::Optional<MMAB1Op> b1Op,
llvm::Optional<MMAIntOverflow> sat,
mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
Expand All @@ -988,7 +991,7 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
}]);

let builders = [
OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
"ValueRange":$operandB, "ValueRange":$operandC,
"ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
"Optional<MMAIntOverflow>":$intOverflow,
Expand All @@ -999,12 +1002,12 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
$shape.getM(), $shape.getN(), $shape.getK(),
$shape.getM(), $shape.getN(), $shape.getK(),
$b1Op, $intOverflowBehavior,
$layoutA, $layoutB,
$multiplicandAPtxType.value(),
$multiplicandAPtxType.value(),
$multiplicandBPtxType.value(),
op.accumPtxType(),
op.accumPtxType(),
op.resultPtxType());

$res = createIntrinsicCall(
Expand Down
7 changes: 6 additions & 1 deletion mlir/include/mlir/IR/SubElementInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
would replace the very first attribute given by `walkImmediateSubElements`.
On success, the new instance with the values replaced is returned. If replacement
fails, nullptr is returned.

Note that replacing the sub-elements of mutable types or attributes is
not currently supported by the interface. If an implementing type or
attribute is mutable, it should return `nullptr` if it has no mechanism
for replacing sub elements.
}], attrOrType, "replaceImmediateSubElements", (ins
"::llvm::ArrayRef<::mlir::Attribute>":$replAttrs,
"::llvm::ArrayRef<::mlir::Type>":$replTypes
Expand Down Expand Up @@ -106,7 +111,7 @@ class SubElementInterfaceBase<string interfaceName, string attrOrType,
void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
}

/// Recursively replace all of the nested sub-attributes using the provided
/// map function. Returns nullptr in the case of failure.
}] # attrOrType # [{ replaceSubElements(
Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2573,10 +2573,6 @@ void LLVMDialect::initialize() {
LLVMTokenType,
LLVMLabelType,
LLVMMetadataType,
LLVMFunctionType,
LLVMPointerType,
LLVMFixedVectorType,
LLVMScalableVectorType,
LLVMStructType>();
// clang-format on
registerTypes();
Expand Down
137 changes: 9 additions & 128 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,30 +99,6 @@ static void printStructType(AsmPrinter &printer, LLVMStructType type) {
printer << '>';
}

/// Prints a type containing a fixed number of elements.
template <typename TypeTy>
static void printVectorType(AsmPrinter &printer, TypeTy type) {
printer << '<' << type.getNumElements() << " x ";
dispatchPrint(printer, type.getElementType());
printer << '>';
}

/// Prints a function type.
static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) {
printer << '<';
dispatchPrint(printer, funcType.getReturnType());
printer << " (";
llvm::interleaveComma(
funcType.getParams(), printer.getStream(),
[&printer](Type subtype) { dispatchPrint(printer, subtype); });
if (funcType.isVarArg()) {
if (funcType.getNumParams() != 0)
printer << ", ";
printer << "...";
}
printer << ")>";
}

/// Prints the given LLVM dialect type recursively. This leverages closedness of
/// the LLVM dialect type system to avoid printing the dialect prefix
/// repeatedly. For recursive structures, only prints the name of the structure
Expand All @@ -140,38 +116,13 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {

printer << getTypeKeyword(type);

if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
if (ptrType.isOpaque()) {
if (ptrType.getAddressSpace() != 0)
printer << '<' << ptrType.getAddressSpace() << '>';
return;
}

printer << '<';
dispatchPrint(printer, ptrType.getElementType());
if (ptrType.getAddressSpace() != 0)
printer << ", " << ptrType.getAddressSpace();
printer << '>';
return;
}

if (auto arrayType = type.dyn_cast<LLVMArrayType>())
return arrayType.print(printer);
if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
return printVectorType(printer, vectorType);

if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
printer << "<? x " << vectorType.getMinNumElements() << " x ";
dispatchPrint(printer, vectorType.getElementType());
printer << '>';
return;
}

if (auto structType = type.dyn_cast<LLVMStructType>())
return printStructType(printer, structType);

if (auto funcType = type.dyn_cast<LLVMFunctionType>())
return printFunctionType(printer, funcType);
llvm::TypeSwitch<Type>(type)
.Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
LLVMScalableVectorType, LLVMFunctionType>(
[&](auto type) { type.print(printer); })
.Case([&](LLVMStructType structType) {
printStructType(printer, structType);
});
}

//===----------------------------------------------------------------------===//
Expand All @@ -180,76 +131,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {

static ParseResult dispatchParse(AsmParser &parser, Type &type);

/// Parses an LLVM dialect function type.
/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
static LLVMFunctionType parseFunctionType(AsmParser &parser) {
SMLoc loc = parser.getCurrentLocation();
Type returnType;
if (parser.parseLess() || dispatchParse(parser, returnType) ||
parser.parseLParen())
return LLVMFunctionType();

// Function type without arguments.
if (succeeded(parser.parseOptionalRParen())) {
if (succeeded(parser.parseGreater()))
return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
/*isVarArg=*/false);
return LLVMFunctionType();
}

// Parse arguments.
SmallVector<Type, 8> argTypes;
do {
if (succeeded(parser.parseOptionalEllipsis())) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
/*isVarArg=*/true);
}

Type arg;
if (dispatchParse(parser, arg))
return LLVMFunctionType();
argTypes.push_back(arg);
} while (succeeded(parser.parseOptionalComma()));

if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
/*isVarArg=*/false);
}

/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
/// | `ptr` (`<` integer `>`)?
static LLVMPointerType parsePointerType(AsmParser &parser) {
SMLoc loc = parser.getCurrentLocation();
Type elementType;
if (parser.parseOptionalLess()) {
return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
/*addressSpace=*/0);
}

unsigned addressSpace = 0;
OptionalParseResult opr = parser.parseOptionalInteger(addressSpace);
if (opr.has_value()) {
if (failed(*opr) || parser.parseGreater())
return LLVMPointerType();
return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
addressSpace);
}

if (dispatchParse(parser, elementType))
return LLVMPointerType();

if (succeeded(parser.parseOptionalComma()) &&
failed(parser.parseInteger(addressSpace)))
return LLVMPointerType();
if (failed(parser.parseGreater()))
return LLVMPointerType();
return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
}

/// Parses an LLVM dialect vector type.
/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
/// Supports both fixed and scalable vectors.
Expand Down Expand Up @@ -445,8 +326,8 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
.Case("token", [&] { return LLVMTokenType::get(ctx); })
.Case("label", [&] { return LLVMLabelType::get(ctx); })
.Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
.Case("func", [&] { return parseFunctionType(parser); })
.Case("ptr", [&] { return parsePointerType(parser); })
.Case("func", [&] { return LLVMFunctionType::parse(parser); })
.Case("ptr", [&] { return LLVMPointerType::parse(parser); })
.Case("vec", [&] { return parseVectorType(parser); })
.Case("array", [&] { return LLVMArrayType::parse(parser); })
.Case("struct", [&] { return parseStructType(parser); })
Expand Down
202 changes: 120 additions & 82 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -27,6 +28,105 @@ using namespace mlir::LLVM;

constexpr const static unsigned kBitsInByte = 8;

//===----------------------------------------------------------------------===//
// custom<FunctionTypes>
//===----------------------------------------------------------------------===//

static ParseResult parseFunctionTypes(AsmParser &p,
FailureOr<SmallVector<Type>> &params,
FailureOr<bool> &isVarArg) {
params.emplace();
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
return success();

// `(` `...` `)`
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
return p.parseRParen();
}

// type (`,` type)* (`,` `...`)?
FailureOr<Type> type;
if (parsePrettyLLVMType(p, type))
return failure();
params->push_back(*type);
while (succeeded(p.parseOptionalComma())) {
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
return p.parseRParen();
}
if (parsePrettyLLVMType(p, type))
return failure();
params->push_back(*type);
}
return p.parseRParen();
}

static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
bool isVarArg) {
llvm::interleaveComma(params, p,
[&](Type type) { printPrettyLLVMType(p, type); });
if (isVarArg) {
if (!params.empty())
p << ", ";
p << "...";
}
p << ')';
}

//===----------------------------------------------------------------------===//
// custom<Pointer>
//===----------------------------------------------------------------------===//

static ParseResult parsePointer(AsmParser &p, FailureOr<Type> &elementType,
FailureOr<unsigned> &addressSpace) {
addressSpace = 0;
// `<` addressSpace `>`
OptionalParseResult result = p.parseOptionalInteger(*addressSpace);
if (result.has_value()) {
if (failed(result.value()))
return failure();
elementType = Type();
return success();
}

if (parsePrettyLLVMType(p, elementType))
return failure();
if (succeeded(p.parseOptionalComma()))
return p.parseInteger(*addressSpace);

return success();
}

static void printPointer(AsmPrinter &p, Type elementType,
unsigned addressSpace) {
if (elementType)
printPrettyLLVMType(p, elementType);
if (addressSpace != 0) {
if (elementType)
p << ", ";
p << addressSpace;
}
}

//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//

/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// LLVMArrayType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -130,25 +230,8 @@ LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs,
return get(results[0], llvm::to_vector(inputs), isVarArg());
}

Type LLVMFunctionType::getReturnType() const {
return getImpl()->getReturnType();
}
ArrayRef<Type> LLVMFunctionType::getReturnTypes() const {
return getImpl()->getReturnType();
}

unsigned LLVMFunctionType::getNumParams() {
return getImpl()->getArgumentTypes().size();
}

Type LLVMFunctionType::getParamType(unsigned i) {
return getImpl()->getArgumentTypes()[i];
}

bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); }

ArrayRef<Type> LLVMFunctionType::getParams() const {
return getImpl()->getArgumentTypes();
return static_cast<detail::LLVMFunctionTypeStorage *>(getImpl())->returnType;
}

LogicalResult
Expand All @@ -164,10 +247,14 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

//===----------------------------------------------------------------------===//
// SubElementTypeInterface

void LLVMFunctionType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
for (Type type : llvm::concat<const Type>(getReturnTypes(), getParams()))
walkTypesFn(getReturnType());
for (Type type : getParams())
walkTypesFn(type);
}

Expand All @@ -177,7 +264,7 @@ Type LLVMFunctionType::replaceImmediateSubElements(
}

//===----------------------------------------------------------------------===//
// Pointer type.
// LLVMPointerType
//===----------------------------------------------------------------------===//

bool LLVMPointerType::isValidElementType(Type type) {
Expand All @@ -195,32 +282,6 @@ LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
return Base::get(pointee.getContext(), pointee, addressSpace);
}

LLVMPointerType LLVMPointerType::get(MLIRContext *context,
unsigned addressSpace) {
return Base::get(context, Type(), addressSpace);
}

LLVMPointerType
LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type pointee, unsigned addressSpace) {
return Base::getChecked(emitError, pointee.getContext(), pointee,
addressSpace);
}

LLVMPointerType
LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, unsigned addressSpace) {
return Base::getChecked(emitError, context, Type(), addressSpace);
}

Type LLVMPointerType::getElementType() const { return getImpl()->pointeeType; }

bool LLVMPointerType::isOpaque() const { return !getImpl()->pointeeType; }

unsigned LLVMPointerType::getAddressSpace() const {
return getImpl()->addressSpace;
}

LogicalResult
LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
Type pointee, unsigned) {
Expand All @@ -229,6 +290,9 @@ LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

//===----------------------------------------------------------------------===//
// DataLayoutTypeInterface

constexpr const static unsigned kDefaultPointerSizeBits = 64;
constexpr const static unsigned kDefaultPointerAlignment = 8;

Expand Down Expand Up @@ -375,6 +439,9 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
return success();
}

//===----------------------------------------------------------------------===//
// SubElementTypeInterface

void LLVMPointerType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
Expand All @@ -383,7 +450,7 @@ void LLVMPointerType::walkImmediateSubElements(

Type LLVMPointerType::replaceImmediateSubElements(
ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
return get(replTypes.front(), getAddressSpace());
return get(getContext(), replTypes.front(), getAddressSpace());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -631,9 +698,12 @@ void LLVMStructType::walkImmediateSubElements(

Type LLVMStructType::replaceImmediateSubElements(
ArrayRef<Attribute> replAttrs, ArrayRef<Type> replTypes) const {
// TODO: It's not clear how we support replacing sub-elements of mutable
// types.
return nullptr;
if (isIdentified()) {
// TODO: It's not clear how we support replacing sub-elements of mutable
// types.
return nullptr;
}
return getLiteral(getContext(), replTypes, isPacked());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -668,14 +738,6 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
numElements);
}

Type LLVMFixedVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}

unsigned LLVMFixedVectorType::getNumElements() const {
return getImpl()->numElements;
}

bool LLVMFixedVectorType::isValidElementType(Type type) {
return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
}
Expand Down Expand Up @@ -716,14 +778,6 @@ LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
minNumElements);
}

Type LLVMScalableVectorType::getElementType() const {
return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
}

unsigned LLVMScalableVectorType::getMinNumElements() const {
return getImpl()->numElements;
}

bool LLVMScalableVectorType::isValidElementType(Type type) {
if (auto intType = type.dyn_cast<IntegerType>())
return intType.isSignless();
Expand Down Expand Up @@ -1005,22 +1059,6 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
});
}

//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//

/// These are unused for now.
/// TODO: Move over to these once more types have been migrated to TypeDef.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMTypes.cpp.inc"

//===----------------------------------------------------------------------===//
// LLVMDialect
//===----------------------------------------------------------------------===//
Expand Down
83 changes: 0 additions & 83 deletions mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,89 +321,6 @@ struct LLVMStructTypeStorage : public TypeStorage {
unsigned identifiedBodySizeAndFlags = 0;
};

//===----------------------------------------------------------------------===//
// LLVMFunctionTypeStorage.
//===----------------------------------------------------------------------===//

/// Type storage for LLVM dialect function types. These are uniqued using the
/// list of types they contain and the vararg bit.
struct LLVMFunctionTypeStorage : public TypeStorage {
using KeyTy = std::tuple<Type, ArrayRef<Type>, bool>;

/// Construct a storage from the given components. The list is expected to be
/// allocated in the context.
LLVMFunctionTypeStorage(Type result, ArrayRef<Type> arguments, bool variadic)
: resultType(result), isVariadicFlag(variadic),
numArguments(arguments.size()), argumentTypes(arguments.data()) {}

/// Hook into the type uniquing infrastructure.
static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<LLVMFunctionTypeStorage>())
LLVMFunctionTypeStorage(std::get<0>(key),
allocator.copyInto(std::get<1>(key)),
std::get<2>(key));
}

static unsigned hashKey(const KeyTy &key) {
// LLVM doesn't like hashing bools in tuples.
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
static_cast<int>(std::get<2>(key)));
}

bool operator==(const KeyTy &key) const {
return std::make_tuple(getReturnType(), getArgumentTypes(), isVariadic()) ==
key;
}

/// Returns the list of function argument types.
ArrayRef<Type> getArgumentTypes() const {
return ArrayRef<Type>(argumentTypes, numArguments);
}

/// Checks whether the function type is variadic.
bool isVariadic() const { return isVariadicFlag; }

/// Returns the function result type.
const Type &getReturnType() const { return resultType; }

private:
/// The result type of the function.
Type resultType;
/// Flag indicating if the function is variadic.
bool isVariadicFlag;
/// The argument types of the function.
unsigned numArguments;
const Type *argumentTypes;
};

//===----------------------------------------------------------------------===//
// LLVMPointerTypeStorage.
//===----------------------------------------------------------------------===//

/// Storage type for LLVM dialect pointer types. These are uniqued by a pair of
/// element type and address space. The element type may be null indicating that
/// the pointer is opaque.
struct LLVMPointerTypeStorage : public TypeStorage {
using KeyTy = std::tuple<Type, unsigned>;

LLVMPointerTypeStorage(const KeyTy &key)
: pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {}

static LLVMPointerTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<LLVMPointerTypeStorage>())
LLVMPointerTypeStorage(key);
}

bool operator==(const KeyTy &key) const {
return std::make_tuple(pointeeType, addressSpace) == key;
}

Type pointeeType;
unsigned addressSpace;
};

//===----------------------------------------------------------------------===//
// LLVMTypeAndSizeStorage.
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 0 additions & 14 deletions mlir/lib/IR/SubElementInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,6 @@ void SubElementTypeInterface::walkSubElements(
//===----------------------------------------------------------------------===//
// ReplaceSubElements

/// Return if the given element is mutable.
static bool isMutable(Attribute attr) {
return attr.hasTrait<AttributeTrait::IsMutable>();
}
static bool isMutable(Type type) {
return type.hasTrait<TypeTrait::IsMutable>();
}

template <typename InterfaceT, typename T, typename ReplaceSubElementFnT>
static void updateSubElementImpl(
T element, function_ref<std::pair<T, WalkResult>(T)> walkFn,
Expand Down Expand Up @@ -187,12 +179,6 @@ replaceSubElementsImpl(InterfaceT interface,
if (!*changed)
return interface;

// If this element is mutable, we don't support changing its sub elements, the
// sub element walk doesn't give us a valid ordering for what we need here. If
// we want to support mutable elements, we'll need something more.
if (isMutable(interface))
return {};

// Use the new elements during the replacement.
return interface.replaceImmediateSubElements(newAttrs, newTypes);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ func.func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32
// -----

func.func @null_non_llvm_type() {
// expected-error@+1 {{custom op 'llvm.mlir.null' invalid kind of type specified}}
// expected-error@+1 {{'llvm.mlir.null' op result #0 must be LLVM pointer type, but got 'i32'}}
llvm.mlir.null : i32
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/LLVMIR/layout.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
// CHECK: size = 8
"test.data_layout_query"() : () -> !llvm.ptr<i8, 5>
// CHECK: alignment = 4
// CHECK: bitsize = 32
// CHECK: bitsize = 32
// CHECK: preferred = 8
// CHECK: size = 4
"test.data_layout_query"() : () -> !llvm.ptr<3>
// CHECK: alignment = 8
// CHECK: bitsize = 32
// CHECK: bitsize = 32
// CHECK: preferred = 8
// CHECK: size = 4
"test.data_layout_query"() : () -> !llvm.ptr<4>
Expand Down