292 changes: 292 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrAttrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
//===-- PtrAttrDefs.td - Ptr Attributes definition file ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PTR_ATTRDEFS
#define PTR_ATTRDEFS

include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/IR/AttrTypeBase.td"

// All of the attributes will extend this class.
class Ptr_Attr<string name, string attrMnemonic,
list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<Ptr_Dialect, name, traits, baseCppClass> {
let mnemonic = attrMnemonic;
}

//===----------------------------------------------------------------------===//
// AliasScopeDomainAttr
//===----------------------------------------------------------------------===//

def Ptr_AliasScopeDomainAttr : Ptr_Attr<"AliasScopeDomain",
"alias_scope_domain"> {
let parameters = (ins
"DistinctAttr":$id,
OptionalParameter<"StringAttr">:$description
);

let builders = [
AttrBuilder<(ins CArg<"StringAttr", "{}">:$description), [{
return $_get($_ctxt,
DistinctAttr::create(UnitAttr::get($_ctxt)), description);
}]>
];

let summary = "Ptr dialect alias scope domain metadata";

let description = [{
Defines a domain that may be associated with an alias scope.

See the following link for more details:
https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata
}];

let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// AliasScopeAttr
//===----------------------------------------------------------------------===//

def Ptr_AliasScopeAttr : Ptr_Attr<"AliasScope", "alias_scope"> {
let parameters = (ins
"DistinctAttr":$id,
"AliasScopeDomainAttr":$domain,
OptionalParameter<"StringAttr">:$description
);

let builders = [
AttrBuilderWithInferredContext<(ins
"AliasScopeDomainAttr":$domain,
CArg<"StringAttr", "{}">:$description
), [{
MLIRContext *ctx = domain.getContext();
return $_get(ctx, DistinctAttr::create(UnitAttr::get(ctx)), domain, description);
}]>
];

let description = [{
Defines an alias scope that can be attached to a memory-accessing operation.
Such scopes can be used in combination with `noalias` metadata to indicate
that sets of memory-affecting operations in one scope do not alias with
memory-affecting operations in another scope.

Example:
```mlir
#domain = #ptr.alias_scope_domain<id = distinct[1]<>, description = "Optional domain description">
#scope1 = #ptr.alias_scope<id = distinct[2]<>, domain = #domain>
#scope2 = #ptr.alias_scope<id = distinct[3]<>, domain = #domain, description = "Optional scope description">
ptr.func @foo(%ptr1 : !ptr.ptr) {
%c0 = llvm.mlir.constant(0 : i32) : i32
%c4 = llvm.mlir.constant(4 : i32) : i32
%1 = ptr.ptrtoint %ptr1 : !ptr.ptr to i32
%2 = llvm.add %1, %c1 : i32
%ptr2 = ptr.inttoptr %2 : i32 to !ptr.ptr
ptr.store %c0, %ptr1 { alias_scopes = [#scope1], ptr.noalias = [#scope2] } : i32, !ptr.ptr
ptr.store %c4, %ptr2 { alias_scopes = [#scope2], ptr.noalias = [#scope1] } : i32, !ptr.ptr
llvm.return
}
```

See the following link for more details:
https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata
}];

let summary = "Ptr dialect alias scope";

let assemblyFormat = "`<` struct(params) `>`";
}

def Ptr_AliasScopeArrayAttr
: TypedArrayAttrBase<Ptr_AliasScopeAttr,
Ptr_AliasScopeAttr.summary # " array"> {
let constBuilderCall = ?;
}

//===----------------------------------------------------------------------===//
// AccessGroupAttr
//===----------------------------------------------------------------------===//

def Ptr_AccessGroupAttr : Ptr_Attr<"AccessGroup", "access_group"> {

let parameters = (ins "DistinctAttr":$id);

let builders = [
AttrBuilder<(ins), [{
return $_get($_ctxt, DistinctAttr::create(UnitAttr::get($_ctxt)));
}]>
];

let summary = "Ptr dialect access group metadata";

let description = [{
Defines an access group metadata that can be set on any instruction
that potentially accesses memory via the `AccessGroupOpInterface` or on
branch instructions in the loop latch block via the `parallelAccesses`
parameter of `LLVM::LoopAnnotationAttr`.

See the following link for more details:
https://llvm.org/docs/LangRef.html#llvm-access-group-metadata
}];

let assemblyFormat = "`<` struct(params) `>`";
}

def Ptr_AccessGroupArrayAttr
: TypedArrayAttrBase<Ptr_AccessGroupAttr,
Ptr_AccessGroupAttr.summary # " array"> {
let constBuilderCall = ?;
}

//===----------------------------------------------------------------------===//
// TBAARootAttr
//===----------------------------------------------------------------------===//

def Ptr_TBAARootAttr : Ptr_Attr<"TBAARoot", "tbaa_root", [], "TBAANodeAttr"> {
let parameters = (ins OptionalParameter<"StringAttr">:$id);

let summary = "Ptr dialect TBAA root metadata";
let description = [{
Defines a TBAA root node.

Example:
```mlir
#cpp_root = #ptr.tbaa_root<identity = "Simple C/C++ TBAA">
#other_root = #ptr.tbaa_root
```

See the following link for more details:
https://llvm.org/docs/LangRef.html#tbaa-metadata
}];

let assemblyFormat = "(`<` struct(params)^ `>`)?";
}

//===----------------------------------------------------------------------===//
// TBAATypeDescriptorAttr
//===----------------------------------------------------------------------===//

def Ptr_TBAAMemberAttr : Ptr_Attr<"TBAAMember", "tbaa_member"> {
let parameters = (ins
"TBAANodeAttr":$typeDesc,
"int64_t":$offset
);

let builders = [
AttrBuilderWithInferredContext<(ins "TBAANodeAttr":$typeDesc,
"int64_t":$offset), [{
return $_get(typeDesc.getContext(), typeDesc, offset);
}]>
];

let assemblyFormat = "`<` params `>`";
}

def Ptr_TBAAMemberAttrArray : ArrayRefParameter<"TBAAMemberAttr"> {
let printer = [{
$_printer << '{';
llvm::interleaveComma($_self, $_printer, [&](TBAAMemberAttr attr) {
$_printer.printStrippedAttrOrType(attr);
});
$_printer << '}';
}];

let parser = [{
[&]() -> FailureOr<SmallVector<TBAAMemberAttr>> {
using Result = SmallVector<TBAAMemberAttr>;
if ($_parser.parseLBrace())
return failure();
FailureOr<Result> result = FieldParser<Result>::parse($_parser);
if (failed(result))
return failure();
if ($_parser.parseRBrace())
return failure();
return result;
}()
}];
}

def Ptr_TBAATypeDescriptorAttr : Ptr_Attr<"TBAATypeDescriptor",
"tbaa_type_desc", [], "TBAANodeAttr"> {
let parameters = (ins
StringRefParameter<>:$id,
Ptr_TBAAMemberAttrArray:$members
);

let summary = "Ptr dialect TBAA type metadata";

let description = [{
Defines a TBAA node describing a type.

Example:
```mlir
#tbaa_root = #ptr.tbaa_root<identity = "Simple C/C++ TBAA">
#tbaa_type_desc1 = #ptr.tbaa_type_desc<id = "omnipotent char", members = {<#tbaa_root, 0>}>
#tbaa_type_desc2 = #ptr.tbaa_type_desc<id = "long long", members = {<#tbaa_root, 0>}>
#tbaa_type_desc3 = #ptr.tbaa_type_desc<id = "agg2_t", members = {<#tbaa_type_desc2, 0>, <#tbaa_type_desc2, 8>}>
#tbaa_type_desc4 = #ptr.tbaa_type_desc<id = "int", members = {<#tbaa_type_desc1, 0>}>
#tbaa_type_desc5 = #ptr.tbaa_type_desc<id = "agg1_t", members = {<#tbaa_type_desc4, 0>, <#tbaa_type_desc4, 4>}>
```

See the following link for more details:
https://llvm.org/docs/LangRef.html#tbaa-metadata
}];

let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// TBAATagAttr
//===----------------------------------------------------------------------===//

def Ptr_TBAATagAttr : Ptr_Attr<"TBAATag", "tbaa_tag"> {
let parameters = (ins
"TBAATypeDescriptorAttr":$base_type,
"TBAATypeDescriptorAttr":$access_type,
"int64_t":$offset,
DefaultValuedParameter<"bool", "false">:$constant
);

let builders = [
AttrBuilderWithInferredContext<(ins "TBAATypeDescriptorAttr":$baseType,
"TBAATypeDescriptorAttr":$accessType,
"int64_t":$offset), [{
return $_get(baseType.getContext(), baseType, accessType, offset,
/*constant=*/false);
}]>
];

let summary = "Ptr dialect TBAA tag metadata";

let description = [{
Defines a TBAA node describing a memory access.

Example:
```mlir
#tbaa_root = #ptr.tbaa_root<identity = "Simple C/C++ TBAA">
#tbaa_type_desc1 = #ptr.tbaa_type_desc<id = "omnipotent char", members = {<#tbaa_root, 0>}>
#tbaa_type_desc2 = #ptr.tbaa_type_desc<id = "int", members = {<#tbaa_type_desc1, 0>}>
#tbaa_type_desc3 = #ptr.tbaa_type_desc<id = "agg1_t", members = {<#tbaa_type_desc4, 0>, <#tbaa_type_desc4, 4>}>
#tbaa_tag = #ptr.tbaa_tag<base_type = #tbaa_type_desc3, access_type = #tbaa_type_desc2, offset = 0, constant = true>
```

See the following link for more details:
https://llvm.org/docs/LangRef.html#tbaa-metadata
}];

let assemblyFormat = "`<` struct(params) `>`";
}

def Ptr_TBAATagArrayAttr
: TypedArrayAttrBase<Ptr_TBAATagAttr,
Ptr_TBAATagAttr.summary # " array"> {
let constBuilderCall = ?;
}

#endif // PTR_ATTRDEFS
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- PointerDialect.h - Pointer dialect -----------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Ptr dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PTR_IR_PTRDIALECT_H
#define MLIR_DIALECT_PTR_IR_PTRDIALECT_H

#include "mlir/IR/Dialect.h"

#include "mlir/Dialect/Ptr/IR/PtrOpsDialect.h.inc"

#endif // MLIR_DIALECT_PTR_IR_PTRDIALECT_H
89 changes: 89 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===- PointerDialect.td - Pointer dialect -----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PTR_DIALECT
#define PTR_DIALECT

include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/AsmInterfaces.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Pointer dialect definition.
//===----------------------------------------------------------------------===//

def Ptr_Dialect : Dialect {
let name = "ptr";
let summary = "Pointer dialect";
let cppNamespace = "::mlir::ptr";
let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
// Pointer type definitions
//===----------------------------------------------------------------------===//

class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Ptr_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries"]>,
DeclareTypeInterfaceMethods<TypeAsmAliasTypeInterface, [
"getAliasDialect"]>
]> {
let summary = "pointer type";
let description = [{
The `ptr` type is an opaque pointer type. This type typically represents
a reference to an object in memory. Pointers are optionally parameterized
by a memory space.
Syntax:

```mlir
pointer ::= `ptr` (`<` memory-space `>`)?
memory-space ::= attribute-value
```
}];
let parameters = (ins OptionalParameter<"Attribute">:$memorySpace);
let assemblyFormat = "(`<` $memorySpace^ `>`)?";
let builders = [
TypeBuilder<(ins CArg<"Attribute", "nullptr">:$addressSpace), [{
return $_get($_ctxt, addressSpace);
}]>,
TypeBuilder<(ins CArg<"unsigned">:$addressSpace), [{
return $_get($_ctxt, IntegerAttr::get(IntegerType::get($_ctxt, 32),
addressSpace));
}]>
];
let skipDefaultBuilders = 1;
let extraClassDeclaration = [{
/// Returns the default memory space.
Attribute getDefaultMemorySpace() const;

/// Returns the memory space as an unsigned number.
int64_t getAddressSpace() const;

/// Returns the memory space attribute wrapped in the `MemoryModel` class.
MemoryModel getMemoryModel() const;
}];
}

//===----------------------------------------------------------------------===//
// Base address operation definition.
//===----------------------------------------------------------------------===//

class Pointer_Op<string mnemonic, list<Trait> traits = []> :
Op<Ptr_Dialect, mnemonic, traits>;

#endif // PTR_DIALECT
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrEnums.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//===-- PtrEnums.td - Ptr dialect enumerations -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PTR_ENUMS
#define PTR_ENUMS

include "mlir/IR/EnumAttr.td"

//===----------------------------------------------------------------------===//
// Atomic binary op enum attribute
//===----------------------------------------------------------------------===//

def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0, "xchg">;
def AtomicBinOpAdd : I64EnumAttrCase<"add", 1, "add">;
def AtomicBinOpSub : I64EnumAttrCase<"sub", 2, "sub">;
def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3, "_and">;
def AtomicBinOpNand : I64EnumAttrCase<"nand", 4, "nand">;
def AtomicBinOpOr : I64EnumAttrCase<"_or", 5, "_or">;
def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6, "_xor">;
def AtomicBinOpMax : I64EnumAttrCase<"max", 7, "max">;
def AtomicBinOpMin : I64EnumAttrCase<"min", 8, "min">;
def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9, "umax">;
def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10, "umin">;
def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11, "fadd">;
def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12, "fsub">;
def AtomicBinOpFMax : I64EnumAttrCase<"fmax", 13, "fmax">;
def AtomicBinOpFMin : I64EnumAttrCase<"fmin", 14, "fmin">;
def AtomicBinOpUIncWrap : I64EnumAttrCase<"uinc_wrap", 15, "uinc_wrap">;
def AtomicBinOpUDecWrap : I64EnumAttrCase<"udec_wrap", 16, "udec_wrap">;

def AtomicBinOp : I64EnumAttr<
"AtomicBinOp",
"ptr.atomicrmw binary operations",
[AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
AtomicBinOpFSub, AtomicBinOpFMax, AtomicBinOpFMin, AtomicBinOpUIncWrap,
AtomicBinOpUDecWrap]> {
let cppNamespace = "::mlir::ptr";
}

//===----------------------------------------------------------------------===//
// Atomic ordering enum attribute
//===----------------------------------------------------------------------===//

def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0, "not_atomic">;
def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1, "unordered">;
def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2, "monotonic">;
def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 3, "acquire">;
def AtomicOrderingRelease : I64EnumAttrCase<"release", 4, "release">;
def AtomicOrderingAcqRel : I64EnumAttrCase<"acq_rel", 5, "acq_rel">;
def AtomicOrderingSeqCst : I64EnumAttrCase<"seq_cst", 6, "seq_cst">;

def AtomicOrdering : I64EnumAttr<
"AtomicOrdering",
"Atomic ordering for LLVM's memory model",
[AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcqRel,
AtomicOrderingSeqCst
]> {
let cppNamespace = "::mlir::ptr";
}

#endif // PTR_ENUMS
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- PtrInterfaces.h - Ptr Interfaces -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines op interfaces for the Ptr dialect in MLIR.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PTR_IR_PTRINTERFACES_H
#define MLIR_DIALECT_PTR_IR_PTRINTERFACES_H

#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"

namespace mlir {
namespace ptr {
namespace detail {
/// Verifies the access groups attribute of memory operations that implement the
/// access group interface.
LogicalResult verifyAccessGroupOpInterface(Operation *op);

/// Verifies the alias analysis attributes of memory operations that implement
/// the alias analysis interface.
LogicalResult verifyAliasAnalysisOpInterface(Operation *op);
} // namespace detail
} // namespace ptr
} // namespace mlir

#include "mlir/Dialect/Ptr/IR/PtrInterfaces.h.inc"

#endif // MLIR_DIALECT_PTR_IR_PTRINTERFACES_H
152 changes: 152 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//===-- PtrInterfaces.td - Ptr dialect interfaces ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines ptr dialect interfaces.
//
//===----------------------------------------------------------------------===//

#ifndef PTR_INTERFACES
#define PTR_INTERFACES

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// Access group interface.
//===----------------------------------------------------------------------===//

def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
let description = [{
An interface for memory operations that can carry access groups metadata.
It provides setters and getters for the operation's access groups attribute.
The default implementations of the interface methods expect the operation
to have an attribute of type ArrayAttr named access_groups.
}];

let cppNamespace = "::mlir::ptr";
let verify = [{ return detail::verifyAccessGroupOpInterface($_op); }];

let methods = [
InterfaceMethod<
/*desc=*/ "Returns the access groups attribute or nullptr",
/*returnType=*/ "ArrayAttr",
/*methodName=*/ "getAccessGroupsOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getAccessGroupsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the access groups attribute",
/*returnType=*/ "void",
/*methodName=*/ "setAccessGroups",
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setAccessGroupsAttr(attr);
}]
>
];
}

//===----------------------------------------------------------------------===//
// Alias analysis interface.
//===----------------------------------------------------------------------===//

def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
let description = [{
An interface for memory operations that can carry alias analysis metadata.
It provides setters and getters for the operation's alias analysis
attributes. The default implementations of the interface methods expect
the operation to have attributes of type ArrayAttr named alias_scopes,
noalias_scopes, and tbaa.
}];

let cppNamespace = "::mlir::ptr";
let verify = [{ return detail::verifyAliasAnalysisOpInterface($_op); }];

let methods = [
InterfaceMethod<
/*desc=*/ "Returns the alias scopes attribute or nullptr",
/*returnType=*/ "ArrayAttr",
/*methodName=*/ "getAliasScopesOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getAliasScopesAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the alias scopes attribute",
/*returnType=*/ "void",
/*methodName=*/ "setAliasScopes",
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setAliasScopesAttr(attr);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns the noalias scopes attribute or nullptr",
/*returnType=*/ "ArrayAttr",
/*methodName=*/ "getNoAliasScopesOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getNoaliasScopesAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the noalias scopes attribute",
/*returnType=*/ "void",
/*methodName=*/ "setNoAliasScopes",
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setNoaliasScopesAttr(attr);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns the tbaa attribute or nullptr",
/*returnType=*/ "ArrayAttr",
/*methodName=*/ "getTBAATagsOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getTbaaAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the tbaa attribute",
/*returnType=*/ "void",
/*methodName=*/ "setTBAATags",
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setTbaaAttr(attr);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns a list of all pointer operands accessed by the "
"operation",
/*returnType=*/ "::llvm::SmallVector<::mlir::Value>",
/*methodName=*/ "getAccessedOperands",
/*args=*/ (ins)
>
];
}

#endif // PTR_MEMORYSPACEINTERFACES
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- PointerDialect.h - Pointer dialect -----------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Pointer dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PTR_IR_PTROPS_H
#define MLIR_DIALECT_PTR_IR_PTROPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Ptr/IR/PtrAttrs.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrInterfaces.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOps.h.inc"

#endif // MLIR_DIALECT_PTR_IR_PTROPS_H
411 changes: 411 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- PointerTypes.h - Pointer types ---------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Pointer dialect types.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_PTR_IR_PTRTYPES_H
#define MLIR_DIALECT_PTR_IR_PTRTYPES_H

#include "mlir/Dialect/Ptr/IR/MemoryModel.h"
#include "mlir/IR/AsmInterfaces.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"

namespace mlir {
namespace ptr {
/// The positions of different values in the data layout entry for pointers.
enum class PtrDLEntryPos { Size = 0, Abi = 1, Preferred = 2, Index = 3 };

/// Returns the value that corresponds to named position `pos` from the
/// data layout entry `attr` assuming it's a dense integer elements attribute.
/// Returns `std::nullopt` if `pos` is not present in the entry.
/// Currently only `PtrDLEntryPos::Index` is optional, and all other positions
/// may be assumed to be present.
std::optional<uint64_t> extractPointerSpecValue(Attribute attr,
PtrDLEntryPos pos);
} // namespace ptr
} // namespace mlir

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Ptr/IR/PtrOpsTypes.h.inc"

#endif // MLIR_DIALECT_PTR_IR_PTRTYPES_H
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/AsmInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- AsmInterfaces.h - Asm Interfaces -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_ASMINTERFACES_H
#define MLIR_IR_ASMINTERFACES_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"

#include "mlir/IR/AsmAttrInterfaces.h.inc"

#include "mlir/IR/AsmTypeInterfaces.h.inc"

#endif // MLIR_IR_ASMINTERFACES_H
60 changes: 60 additions & 0 deletions mlir/include/mlir/IR/AsmInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===- AsmInterfaces.td - Asm Interfaces -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces and other utilities for interacting with the
// AsmParser and AsmPrinter.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_ASMINTERFACES_TD
#define MLIR_IR_ASMINTERFACES_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// AttrAsmAliasAttrInterface
//===----------------------------------------------------------------------===//

def AttrAsmAliasAttrInterface : AttrInterface<"AttrAsmAliasAttrInterface"> {
let cppNamespace = "::mlir";
let description = [{
This interface allows aliasing an attribute between dialects, allowing
custom printing of an attribute by an external dialect.
}];
let methods = [
InterfaceMethod<[{
Returns the dialect responsible for printing and parsing the attribute
instance.
}],
"Dialect*", "getAliasDialect", (ins), [{}], [{}]
>
];
}

//===----------------------------------------------------------------------===//
// TypeAsmAliasTypeInterface
//===----------------------------------------------------------------------===//

def TypeAsmAliasTypeInterface : TypeInterface<"TypeAsmAliasTypeInterface"> {
let cppNamespace = "::mlir";
let description = [{
This interface allows aliasing a type between dialects, allowing custom
printing of a type by an external dialect.
}];
let methods = [
InterfaceMethod<[{
Returns the dialect responsible for printing and parsing the type
instance.
}],
"Dialect*", "getAliasDialect", (ins), [{}], [{}]
>
];
}

#endif // MLIR_IR_ASMINTERFACES_TD
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ add_mlir_interface(OpAsmInterface)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)

set(LLVM_TARGET_DEFINITIONS AsmInterfaces.td)
mlir_tablegen(AsmAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(AsmAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(AsmTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(AsmTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRAsmInterfacesIncGen)

set(LLVM_TARGET_DEFINITIONS BuiltinAttributes.td)
mlir_tablegen(BuiltinAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(BuiltinAttributes.cpp.inc -gen-attrdef-defs)
Expand Down
28 changes: 27 additions & 1 deletion mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1715,7 +1715,10 @@ class OpAsmDialectInterface
OverridableAlias,
/// An alias was provided and it should be used
/// (no other hooks will be checked).
FinalAlias
FinalAlias,
/// A dialect alias was provided and it will be used
/// (no other hooks will be checked).
DialectAlias
};

/// Hooks for getting an alias identifier alias for a given symbol, that is
Expand All @@ -1729,6 +1732,29 @@ class OpAsmDialectInterface
return AliasResult::NoAlias;
}

/// Hooks for parsing a dialect alias. The method returns success if the
/// dialect has an alias for the symbol, otherwise it must return failure.
/// If there was an error during parsing, this method should return success
/// and set the attribute to null.
virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
Attribute &attr, Type type) const {
return failure();
}
virtual LogicalResult parseDialectAlias(DialectAsmParser &parser,
Type &type) const {
return failure();
}
/// Hooks for printing a dialect alias. The method returns success if the
/// dialect has an alias for the symbol, otherwise it must return failure.
virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
Attribute attr) const {
return failure();
}
virtual LogicalResult printDialectAlias(DialectAsmPrinter &printer,
Type type) const {
return failure();
}

//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
Expand Down Expand Up @@ -131,6 +132,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
omp::OpenMPDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
ptr::PtrDialect,
quant::QuantizationDialect,
ROCDL::ROCDLDialect,
scf::SCFDialect,
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,27 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
(ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<[{
Checks whether the operation requires visiting the mutated
definitions by a store operation.
}], "bool", "requiresVisitingMutatedDefs", (ins), [{}],
[{ return false; }]
>,
InterfaceMethod<[{
Visits all the mutated definitions by a store operation.

This method will only be called after all blocking issues haven been
scheduled for removal and if `requiresVisitingMutatedDefs` returned
true.

The rewriter is located after the promotable operation on call. All IR
mutations must happen through the rewriter. During the transformation,
*no operation should be deleted*.
}],
"void", "visitMutatedDefs",
(ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
"::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
>,
];
}

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/Dialect/All.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/Ptr/LLVMIRToPtrTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
Expand All @@ -47,6 +49,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
registerNVVMDialectTranslation(registry);
registerOpenACCDialectTranslation(registry);
registerOpenMPDialectTranslation(registry);
registerPtrDialectTranslation(registry);
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);
registerVCIXDialectTranslation(registry);
Expand All @@ -65,6 +68,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry &registry) {
registerGPUDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
registerNVVMDialectTranslation(registry);
registerPtrDialectTranslation(registry);
registerROCDLDialectTranslation(registry);
registerSPIRVDialectTranslation(registry);

Expand All @@ -77,6 +81,7 @@ registerAllGPUToLLVMIRTranslations(DialectRegistry &registry) {
static inline void
registerAllFromLLVMIRTranslations(DialectRegistry &registry) {
registerLLVMDialectImport(registry);
registerPtrDialectImport(registry);
registerNVVMDialectImport(registry);
}
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- LLVMIRToPtrTranslation.h - LLVM IR to Ptr Dialect --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for LLVM IR to Ptr dialect translation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TARGET_LLVMIR_DIALECT_PTR_LLVMIRTOPTRTRANSLATION_H
#define MLIR_TARGET_LLVMIR_DIALECT_PTR_LLVMIRTOPTRTRANSLATION_H

namespace mlir {

class DialectRegistry;
class MLIRContext;

/// Registers the Ptr dialect and its import from LLVM IR in the given
/// registry.
void registerPtrDialectImport(DialectRegistry &registry);

/// Registers the Ptr dialect and its import from LLVM IR with the given
/// context.
void registerPtrDialectImport(MLIRContext &context);

} // namespace mlir

#endif // MLIR_TARGET_LLVMIR_DIALECT_PTR_LLVMIRTOPTRTRANSLATION_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- PtrToLLVMIRTranslation.h - Ptr Dialect to LLVM IR --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This provides registration calls for Ptr dialect to LLVM IR translation.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
#define MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H

namespace mlir {

class DialectRegistry;
class MLIRContext;

/// Register the Ptr dialect and the translation from it to the LLVM IR in
/// the given registry;
void registerPtrDialectTranslation(DialectRegistry &registry);

/// Register the Ptr dialect and the translation from it in the registry
/// associated with the given context.
void registerPtrDialectTranslation(MLIRContext &context);

} // namespace mlir

#endif // MLIR_TARGET_LLVMIR_DIALECT_PTR_PTRTOLLVMIRTRANSLATION_H
53 changes: 53 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/LLVMImportInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ class LLVMImportDialectInterface
return failure();
}

/// Hook for derived dialect interfaces to implement the import of
/// instructions into MLIR.
virtual LogicalResult
convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
ArrayRef<llvm::Value *> llvmOperands,
LLVM::ModuleImport &moduleImport) const {
return failure();
}

/// Hook for derived dialect interfaces to implement the import of metadata
/// into MLIR. Attaches the converted metadata kind and node to the provided
/// operation.
Expand All @@ -66,6 +75,11 @@ class LLVMImportDialectInterface
/// returns the list of supported intrinsic identifiers.
virtual ArrayRef<unsigned> getSupportedIntrinsics() const { return {}; }

/// Hook for derived dialect interfaces to publish the supported instructions.
/// As every LLVM IR instructions has a unique integer identifier, the
/// function returns the list of supported instructions identifiers.
virtual ArrayRef<unsigned> getSupportedInstructions() const { return {}; }

/// Hook for derived dialect interfaces to publish the supported metadata
/// kinds. As every metadata kind has a unique integer identifier, the
/// function returns the list of supported metadata identifiers.
Expand Down Expand Up @@ -100,9 +114,27 @@ class LLVMImportInterface
*it, iface.getDialect()->getNamespace(),
intrinsicToDialect.lookup(*it)->getNamespace()));
}
const auto *instIt =
llvm::find_if(iface.getSupportedInstructions(), [&](unsigned id) {
return instructionToDialect.count(id);
});
if (instIt != iface.getSupportedInstructions().end()) {
return emitError(
UnknownLoc::get(iface.getContext()),
llvm::formatv(
"expected unique conversion for instruction ({0}), but "
"got conflicting {1} and {2} conversions",
*it, iface.getDialect()->getNamespace(),
instructionToDialect.lookup(*it)
->getDialect()
->getNamespace()));
}
// Add a mapping for all supported intrinsic identifiers.
for (unsigned id : iface.getSupportedIntrinsics())
intrinsicToDialect[id] = iface.getDialect();
// Add a mapping for all supported instruction identifiers.
for (unsigned id : iface.getSupportedInstructions())
instructionToDialect[id] = &iface;
// Add a mapping for all supported metadata kinds.
for (unsigned kind : iface.getSupportedMetadata())
metadataToDialect[kind].push_back(iface.getDialect());
Expand Down Expand Up @@ -132,6 +164,26 @@ class LLVMImportInterface
return intrinsicToDialect.count(id);
}

/// Converts the LLVM instruction to an MLIR operation if a conversion exists.
/// Returns failure otherwise.
LogicalResult convertInstruction(OpBuilder &builder, llvm::Instruction *inst,
ArrayRef<llvm::Value *> llvmOperands,
LLVM::ModuleImport &moduleImport) const {
// Lookup the dialect interface for the given instruction.
const LLVMImportDialectInterface *iface =
instructionToDialect.lookup(inst->getOpcode());
if (!iface)
return failure();

return iface->convertInstruction(builder, inst, llvmOperands, moduleImport);
}

/// Returns true if the given LLVM IR instruction is convertible to an MLIR
/// operation.
bool isConvertibleInstruction(unsigned id) {
return instructionToDialect.count(id);
}

/// Attaches the given LLVM metadata to the imported operation if a conversion
/// to one or more MLIR dialect attributes exists and succeeds. Returns
/// success if at least one of the conversions is successful and failure if
Expand Down Expand Up @@ -166,6 +218,7 @@ class LLVMImportInterface

private:
DenseMap<unsigned, Dialect *> intrinsicToDialect;
DenseMap<unsigned, const LLVMImportDialectInterface *> instructionToDialect;
DenseMap<unsigned, SmallVector<Dialect *, 1>> metadataToDialect;
};

Expand Down
21 changes: 15 additions & 6 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,22 @@ class ModuleImport {
/// Provides write-once access to store the MLIR value corresponding to the
/// given LLVM value.
Value &mapValue(llvm::Value *value) {
Value &mlir = valueMapping[value];
Value &mlir = valueMapping[value].first;
assert(mlir == nullptr &&
"attempting to map a value that is already mapped");
return mlir;
}
std::pair<Value, Operation *> &mapOp(llvm::Value *value) {
auto &mlir = valueMapping[value];
assert(mlir.first == nullptr &&
"attempting to map a value that is already mapped");
return mlir;
}

/// Returns the MLIR value mapped to the given LLVM value.
Value lookupValue(llvm::Value *value) { return valueMapping.lookup(value); }
Value lookupValue(llvm::Value *value) {
return valueMapping.lookup(value).first;
}

/// Stores a mapping between an LLVM instruction and the imported MLIR
/// operation if the operation returns no result. Asserts if the operation
Expand All @@ -107,9 +115,10 @@ class ModuleImport {
/// Returns the MLIR operation mapped to the given LLVM instruction. Queries
/// valueMapping and noResultOpMapping to support operations with and without
/// result.
Operation *lookupOperation(llvm::Instruction *inst) {
if (Value value = lookupValue(inst))
return value.getDefiningOp();
Operation *lookupOperation(llvm::Instruction *inst, bool getOp = false) {
if (std::pair<Value, Operation *> value = valueMapping.lookup(inst);
value.first)
return getOp && value.second ? value.second : value.first.getDefiningOp();
return noResultOpMapping.lookup(inst);
}

Expand Down Expand Up @@ -376,7 +385,7 @@ class ModuleImport {
/// Function-local mapping between original and imported block.
DenseMap<llvm::BasicBlock *, Block *> blockMapping;
/// Function-local mapping between original and imported values.
DenseMap<llvm::Value *, Value> valueMapping;
DenseMap<llvm::Value *, std::pair<Value, Operation *>> valueMapping;
/// Function-local mapping between original instructions and imported
/// operations for all operations that return no result. All operations that
/// return a result have a valueMapping entry instead.
Expand Down
9 changes: 6 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,19 @@ namespace mlir {
class Attribute;
class Block;
class Location;
namespace ptr {
class AliasScopeAttr;
class AliasScopeDomainAttr;
} // namespace ptr

namespace LLVM {

namespace detail {
class DebugTranslation;
class LoopAnnotationTranslation;
} // namespace detail

class AliasScopeAttr;
class AliasScopeDomainAttr;
using AliasScopeAttr = ::mlir::ptr::AliasScopeAttr;
using AliasScopeDomainAttr = ::mlir::ptr::AliasScopeDomainAttr;
class DINodeAttr;
class LLVMFuncOp;
class ComdatSelectorOp;
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/AsmParser/DialectSymbolParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ Attribute Parser::parseExtendedAttr(Type type) {

// Parse the attribute.
CustomDialectAsmParser customParser(symbolData, *this);
if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
Attribute attr{};
if (succeeded(iface->parseDialectAlias(customParser, attr, type)))
return attr;
resetToken(symbolData.data());
}
Attribute attr = dialect->parseAttribute(customParser, attrType);
resetToken(curLexerPos);
return attr;
Expand Down Expand Up @@ -310,6 +316,12 @@ Type Parser::parseExtendedType() {

// Parse the type.
CustomDialectAsmParser customParser(symbolData, *this);
if (auto iface = dyn_cast<OpAsmDialectInterface>(dialect)) {
Type type{};
if (succeeded(iface->parseDialectAlias(customParser, type)))
return type;
resetToken(symbolData.data());
}
Type type = dialect->parseType(customParser);
resetToken(curLexerPos);
return type;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalOp<arith::ConstantOp, func::ConstantOp,
UnrealizedConversionCastOp>();
target.addLegalDialect<ptr::PtrDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();

// All operations from Async dialect must be lowered to the runtime API and
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class ConvertToLLVMPass
LogicalResult initialize(MLIRContext *context) final {
RewritePatternSet tempPatterns(context);
auto target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<ptr::PtrDialect>();
target->addLegalDialect<LLVM::LLVMDialect>();
auto typeConverter = std::make_shared<LLVMTypeConverter>(context);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ void GpuToLLVMConversionPass::runOnOperation() {
options.useBarePtrCallConv = hostBarePtrCallConv;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<ptr::PtrDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
LLVMTypeConverter converter(context, options);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/LLVMCommon/ConversionTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using namespace mlir;

mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
this->addLegalDialect<ptr::PtrDialect>();
this->addLegalDialect<LLVM::LLVMDialect>();
this->addLegalOp<UnrealizedConversionCastOp>();
}
11 changes: 6 additions & 5 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,17 +449,18 @@ struct GenericAtomicRMWOpLowering
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);
// Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
// Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

// Conditionally branch to the end or back to the loop depending on %ok.
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);
rewriter.create<LLVM::CondBrOp>(loc, cmpxchg.getStatus(), endBlock,
ArrayRef<Value>(), loopBlock,
cmpxchg.getRes());

rewriter.setInsertionPointToEnd(endBlock);

// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(atomicOp, {newLoaded});
rewriter.replaceOp(atomicOp, {cmpxchg.getRes()});

return success();
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ struct ConvertNVGPUToNVVMPass
});
populateNVGPUToNVVMConversionPatterns(converter, patterns);
LLVMConversionTarget target(getContext());
target.addLegalDialect<ptr::PtrDialect>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::arith::ArithDialect>();
target.addLegalDialect<::mlir::memref::MemRefDialect>();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
target.addLegalDialect<omp::OpenMPDialect, ptr::PtrDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();

RewritePatternSet patterns(module.getContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ class LowerHostCodeToLLVM
populateSPIRVToLLVMTypeConversion(typeConverter);

ConversionTarget target(*context);
target.addLegalDialect<ptr::PtrDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
signalPassFailure();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
ConversionTarget target(*context);
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<ptr::PtrDialect>();

if (clientAPI != spirv::ClientAPI::OpenCL &&
clientAPI != spirv::ClientAPI::Unknown)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
add_subdirectory(PDLInterp)
add_subdirectory(Ptr)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
MLIRFunctionInterfaces
MLIRInferTypeOpInterface
MLIRIR
MLIRPtrDialect
MLIRMemorySlotInterfaces
MLIRSideEffectInterfaces
MLIRSupport
Expand Down
161 changes: 153 additions & 8 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,159 @@ void LLVMDialect::registerAttributes() {
>();
}

//===----------------------------------------------------------------------===//
// AddressSpaceAttr
//===----------------------------------------------------------------------===//

static bool isLoadableType(Type type) {
return /*LLVM_PrimitiveType*/ (
LLVM::isCompatibleOuterType(type) &&
!isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) &&
/*LLVM_OpaqueStruct*/
!(isa<LLVM::LLVMStructType>(type) &&
cast<LLVM::LLVMStructType>(type).isOpaque()) &&
/*LLVM_AnyTargetExt*/
!(isa<LLVM::LLVMTargetExtType>(type) &&
!cast<LLVM::LLVMTargetExtType>(type).supportsMemOps());
}

/// Returns true if the given type is supported by atomic operations. All
/// integer and float types with limited bit width are supported. Additionally,
/// depending on the operation pointers may be supported as well.
static bool isTypeCompatibleWithAtomicOp(Type type) {
if (llvm::isa<LLVMPointerType>(type))
return true;

std::optional<unsigned> bitWidth;
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
if (!isCompatibleFloatingPointType(type))
return false;
bitWidth = floatType.getWidth();
}
if (auto integerType = llvm::dyn_cast<IntegerType>(type))
bitWidth = integerType.getWidth();
// The type is neither an integer, float, or pointer type.
if (!bitWidth)
return false;
return *bitWidth == 8 || *bitWidth == 16 || *bitWidth == 32 ||
*bitWidth == 64;
}

Dialect *AddressSpaceAttr::getMemorySpaceDialect() const {
return &getDialect();
}

Attribute AddressSpaceAttr::getDefaultMemorySpace() const {
return AddressSpaceAttr::get(getContext(), 0);
}

unsigned AddressSpaceAttr::getAddressSpace() const { return getAs(); }

LogicalResult AddressSpaceAttr::isValidLoad(Type type,
mlir::ptr::AtomicOrdering ordering,
IntegerAttr alignment,
Operation *diagnosticOp) const {
if (!isLoadableType(type))
return diagnosticOp ? diagnosticOp->emitError(
"type must be LLVM type with size, but got ")
<< type
: failure();
if (ordering != ptr::AtomicOrdering::not_atomic &&
!isTypeCompatibleWithAtomicOp(type))
return diagnosticOp ? diagnosticOp->emitError("unsupported type ")
<< type << " for atomic access"
: failure();
return success();
}

LogicalResult AddressSpaceAttr::isValidStore(Type type,
mlir::ptr::AtomicOrdering ordering,
IntegerAttr alignment,
Operation *diagnosticOp) const {
if (!isLoadableType(type))
return diagnosticOp ? diagnosticOp->emitError(
"type must be LLVM type with size, but got ")
<< type
: failure();
if (ordering != ptr::AtomicOrdering::not_atomic &&
!isTypeCompatibleWithAtomicOp(type))
return diagnosticOp ? diagnosticOp->emitError("unsupported type ")
<< type << " for atomic access"
: failure();
return success();
}

LogicalResult AddressSpaceAttr::isValidAtomicOp(
mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering,
IntegerAttr alignment, Operation *diagnosticOp) const {
if (binOp == ptr::AtomicBinOp::fadd || binOp == ptr::AtomicBinOp::fsub ||
binOp == ptr::AtomicBinOp::fmin || binOp == ptr::AtomicBinOp::fmax) {
if (!mlir::LLVM::isCompatibleFloatingPointType(type))
return diagnosticOp ? diagnosticOp->emitError(
"expected LLVM IR floating point type")
: failure();
} else if (binOp == ptr::AtomicBinOp::xchg) {
if (!isTypeCompatibleWithAtomicOp(type))
return diagnosticOp ? diagnosticOp->emitError(
"unexpected LLVM IR type for 'xchg' bin_op")
: failure();
} else {
auto intType = llvm::dyn_cast<IntegerType>(type);
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64)
return diagnosticOp
? diagnosticOp->emitError("expected LLVM IR integer type")
: failure();
}
return success();
}

LogicalResult AddressSpaceAttr::isValidAtomicXchg(
Type type, mlir::ptr::AtomicOrdering successOrdering,
mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
Operation *diagnosticOp) const {
if (!isLoadableType(type))
return diagnosticOp ? diagnosticOp->emitError(
"type must be LLVM type with size, but got ")
<< type
: failure();
if (!isTypeCompatibleWithAtomicOp(type))
return diagnosticOp ? diagnosticOp->emitError("unexpected LLVM IR type")
: failure();
return success();
}

template <typename Ty>
static bool isScalarOrVectorOf(Type ty) {
return isa<Ty>(ty) || (LLVM::isCompatibleVectorType(ty) &&
isa<Ty>(LLVM::getVectorElementType(ty)));
}

LogicalResult
AddressSpaceAttr::isValidAddrSpaceCast(Type tgt, Type src,
Operation *diagnosticOp) const {
if (!isScalarOrVectorOf<LLVMPointerType>(tgt))
return diagnosticOp ? diagnosticOp->emitError("invalid ptr-like operand")
: failure();
if (!isScalarOrVectorOf<LLVMPointerType>(src))
return diagnosticOp ? diagnosticOp->emitError("invalid ptr-like operand")
: failure();
return success();
}

LogicalResult
AddressSpaceAttr::isValidPtrIntCast(Type intLikeTy, Type ptrLikeTy,
Operation *diagnosticOp) const {
if (!isScalarOrVectorOf<IntegerType>(intLikeTy))
return diagnosticOp ? diagnosticOp->emitError("invalid int-like type")
: failure();
if (!isScalarOrVectorOf<LLVMPointerType>(ptrLikeTy))
return diagnosticOp ? diagnosticOp->emitError("invalid ptr-like type")
: failure();
return success();
}

//===----------------------------------------------------------------------===//
// DINodeAttr
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -91,14 +244,6 @@ bool DITypeAttr::classof(Attribute attr) {
DIDerivedTypeAttr, DISubroutineTypeAttr>(attr);
}

//===----------------------------------------------------------------------===//
// TBAANodeAttr
//===----------------------------------------------------------------------===//

bool TBAANodeAttr::classof(Attribute attr) {
return llvm::isa<TBAATypeDescriptorAttr, TBAARootAttr>(attr);
}

//===----------------------------------------------------------------------===//
// MemoryEffectsAttr
//===----------------------------------------------------------------------===//
Expand Down
265 changes: 50 additions & 215 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -45,6 +46,12 @@ using namespace mlir::LLVM;
using mlir::LLVM::cconv::getMaxEnumValForCConv;
using mlir::LLVM::linkage::getMaxEnumValForLinkage;

static ParseResult parseAtomicOrdering(OpAsmParser &parser,
ptr::AtomicOrderingAttr &attr);

static void printAtomicOrdering(OpAsmPrinter &p, Operation *op,
ptr::AtomicOrderingAttr attr);

#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"

static constexpr const char kElemTypeAttrName[] = "elem_type";
Expand Down Expand Up @@ -750,127 +757,6 @@ Type GEPOp::getResultPtrElementType() {
return selectedType;
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

void LoadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), getAddr());
// Volatile operations can have target-specific read-write effects on
// memory besides the one referred to by the pointer operand.
// Similarly, atomic operations that are monotonic or stricter cause
// synchronization that from a language point-of-view, are arbitrary
// read-writes into memory.
if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
getOrdering() != AtomicOrdering::unordered)) {
effects.emplace_back(MemoryEffects::Write::get());
effects.emplace_back(MemoryEffects::Read::get());
}
}

/// Returns true if the given type is supported by atomic operations. All
/// integer and float types with limited bit width are supported. Additionally,
/// depending on the operation pointers may be supported as well.
static bool isTypeCompatibleWithAtomicOp(Type type, bool isPointerTypeAllowed) {
if (llvm::isa<LLVMPointerType>(type))
return isPointerTypeAllowed;

std::optional<unsigned> bitWidth;
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
if (!isCompatibleFloatingPointType(type))
return false;
bitWidth = floatType.getWidth();
}
if (auto integerType = llvm::dyn_cast<IntegerType>(type))
bitWidth = integerType.getWidth();
// The type is neither an integer, float, or pointer type.
if (!bitWidth)
return false;
return *bitWidth == 8 || *bitWidth == 16 || *bitWidth == 32 ||
*bitWidth == 64;
}

/// Verifies the attributes and the type of atomic memory access operations.
template <typename OpTy>
LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
ArrayRef<AtomicOrdering> unsupportedOrderings) {
if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
if (!isTypeCompatibleWithAtomicOp(valueType,
/*isPointerTypeAllowed=*/true))
return memOp.emitOpError("unsupported type ")
<< valueType << " for atomic access";
if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
return memOp.emitOpError("unsupported ordering '")
<< stringifyAtomicOrdering(memOp.getOrdering()) << "'";
if (!memOp.getAlignment())
return memOp.emitOpError("expected alignment for atomic access");
return success();
}
if (memOp.getSyncscope())
return memOp.emitOpError(
"expected syncscope to be null for non-atomic access");
return success();
}

LogicalResult LoadOp::verify() {
Type valueType = getResult().getType();
return verifyAtomicMemOp(*this, valueType,
{AtomicOrdering::release, AtomicOrdering::acq_rel});
}

void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal, bool isInvariant,
AtomicOrdering ordering, StringRef syncscope) {
build(builder, state, type, addr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal, isInvariant, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
/*tbaa=*/nullptr);
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

void StoreOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Write::get(), getAddr());
// Volatile operations can have target-specific read-write effects on
// memory besides the one referred to by the pointer operand.
// Similarly, atomic operations that are monotonic or stricter cause
// synchronization that from a language point-of-view, are arbitrary
// read-writes into memory.
if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
getOrdering() != AtomicOrdering::unordered)) {
effects.emplace_back(MemoryEffects::Write::get());
effects.emplace_back(MemoryEffects::Read::get());
}
}

LogicalResult StoreOp::verify() {
Type valueType = getValue().getType();
return verifyAtomicMemOp(*this, valueType,
{AtomicOrdering::acquire, AtomicOrdering::acq_rel});
}

void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
Value addr, unsigned alignment, bool isVolatile,
bool isNonTemporal, AtomicOrdering ordering,
StringRef syncscope) {
build(builder, state, value, addr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
isNonTemporal, ordering,
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2600,93 +2486,39 @@ ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// AtomicRMWOp
// FenceOp
//===----------------------------------------------------------------------===//

void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
AtomicBinOp binOp, Value ptr, Value val,
AtomicOrdering ordering, StringRef syncscope,
unsigned alignment, bool isVolatile) {
build(builder, state, val.getType(), binOp, ptr, val, ordering,
!syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
/*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
}

LogicalResult AtomicRMWOp::verify() {
auto valType = getVal().getType();
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
return emitOpError("expected LLVM IR floating point type");
} else if (getBinOp() == AtomicBinOp::xchg) {
if (!isTypeCompatibleWithAtomicOp(valType, /*isPointerTypeAllowed=*/true))
return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
auto intType = llvm::dyn_cast<IntegerType>(valType);
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64)
return emitOpError("expected LLVM IR integer type");
}

if (static_cast<unsigned>(getOrdering()) <
static_cast<unsigned>(AtomicOrdering::monotonic))
return emitOpError() << "expected at least '"
<< stringifyAtomicOrdering(AtomicOrdering::monotonic)
<< "' ordering";

static ParseResult parseAtomicOrdering(OpAsmParser &parser,
ptr::AtomicOrderingAttr &attr) {
StringRef orderingStr;
auto loc = parser.getCurrentLocation();
if (failed(parser.parseOptionalKeyword(
&orderingStr, {"not_atomic", "unordered", "monotonic", "acquire",
"release", "acq_rel", "seq_cst"}))) {
return parser.emitError(
loc, "expected string or keyword containing one of the following "
"enum values for attribute 'ordering' [not_atomic, unordered, "
"monotonic, acquire, release, acq_rel, seq_cst]");
}
auto ordering = ptr::symbolizeAtomicOrdering(orderingStr);
if (!ordering)
return parser.emitError(loc, "invalid ")
<< "ordering attribute specification: \"" << orderingStr << '"';
attr =
ptr::AtomicOrderingAttr::get(parser.getBuilder().getContext(), *ordering);
return success();
}

//===----------------------------------------------------------------------===//
// AtomicCmpXchgOp
//===----------------------------------------------------------------------===//

/// Returns an LLVM struct type that contains a value type and a boolean type.
static LLVMStructType getValAndBoolStructType(Type valType) {
auto boolType = IntegerType::get(valType.getContext(), 1);
return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
}

void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
Value ptr, Value cmp, Value val,
AtomicOrdering successOrdering,
AtomicOrdering failureOrdering, StringRef syncscope,
unsigned alignment, bool isWeak, bool isVolatile) {
build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
successOrdering, failureOrdering,
!syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
isVolatile, /*access_groups=*/nullptr,
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
static void printAtomicOrdering(OpAsmPrinter &p, Operation *op,
ptr::AtomicOrderingAttr attr) {
p << attr.getValue();
}

LogicalResult AtomicCmpXchgOp::verify() {
auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
if (!ptrType)
return emitOpError("expected LLVM IR pointer type for operand #0");
auto valType = getVal().getType();
if (!isTypeCompatibleWithAtomicOp(valType,
/*isPointerTypeAllowed=*/true))
return emitOpError("unexpected LLVM IR type");
if (getSuccessOrdering() < AtomicOrdering::monotonic ||
getFailureOrdering() < AtomicOrdering::monotonic)
return emitOpError("ordering must be at least 'monotonic'");
if (getFailureOrdering() == AtomicOrdering::release ||
getFailureOrdering() == AtomicOrdering::acq_rel)
return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
return success();
}

//===----------------------------------------------------------------------===//
// FenceOp
//===----------------------------------------------------------------------===//

void FenceOp::build(OpBuilder &builder, OperationState &state,
AtomicOrdering ordering, StringRef syncscope) {
build(builder, state, ordering,
build(builder, state,
ptr::AtomicOrderingAttr::get(builder.getContext(), ordering),
syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
}

Expand Down Expand Up @@ -2803,26 +2635,11 @@ LogicalResult LLVM::BitcastOp::verify() {
// 'llvm.addrspacecast' must be used for this purpose instead.
if (resultType.getAddressSpace() != sourceType.getAddressSpace())
return emitOpError("cannot cast pointers of different address spaces, "
"use 'llvm.addrspacecast' instead");
"use 'ptr.addrspacecast' instead");

return success();
}

//===----------------------------------------------------------------------===//
// Folder for LLVM::AddrSpaceCastOp
//===----------------------------------------------------------------------===//

OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
// addrcast(x : T0, T0) -> x
if (getArg().getType() == getType())
return getArg();
// addrcast(addrcast(x : T0, T1), T0) -> x
if (auto prev = getArg().getDefiningOp<AddrSpaceCastOp>())
if (prev.getArg().getType() == getType())
return prev.getArg();
return {};
}

//===----------------------------------------------------------------------===//
// Folder for LLVM::GEPOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2945,6 +2762,24 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
})
.Default([](Attribute) { return AliasResult::NoAlias; });
}

AliasResult getAlias(Type type, raw_ostream &os) const override {
return TypeSwitch<Type, AliasResult>(type)
.Case<LLVMPointerType>(
[&](LLVMPointerType type) { return AliasResult::DialectAlias; })
.Default([](Type) { return AliasResult::NoAlias; });
}

LogicalResult printDialectAlias(DialectAsmPrinter &printer,
Type type) const override {
return TypeSwitch<Type, LogicalResult>(type)
.Case<LLVMPointerType>([&](LLVMPointerType type) {
printer << "ptr";
type.print(printer);
return success();
})
.Default([](Type) { return failure(); });
}
};
} // namespace

Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "LLVMInlining.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
Expand Down Expand Up @@ -188,7 +189,7 @@ deepCloneAliasScopes(iterator_range<Region::iterator> inlinedBlocks) {

for (Block &block : inlinedBlocks) {
for (Operation &op : block) {
if (auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
if (auto aliasInterface = dyn_cast<ptr::AliasAnalysisOpInterface>(op)) {
aliasInterface.setAliasScopes(
convertScopeList(aliasInterface.getAliasScopesOrNull()));
aliasInterface.setNoAliasScopes(
Expand Down Expand Up @@ -358,7 +359,7 @@ static void createNewAliasScopesFromNoAliasParameter(
// it is definitely based on and definitely not based on.
for (Block &inlinedBlock : inlinedBlocks) {
for (auto aliasInterface :
inlinedBlock.getOps<LLVM::AliasAnalysisOpInterface>()) {
inlinedBlock.getOps<ptr::AliasAnalysisOpInterface>()) {

// Collect the pointer arguments affected by the alias scopes.
SmallVector<Value> pointerArgs = aliasInterface.getAccessedOperands();
Expand Down Expand Up @@ -458,7 +459,7 @@ static void createNewAliasScopesFromNoAliasParameter(
static void
appendCallOpAliasScopes(Operation *call,
iterator_range<Region::iterator> inlinedBlocks) {
auto callAliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
auto callAliasInterface = dyn_cast<ptr::AliasAnalysisOpInterface>(call);
if (!callAliasInterface)
return;

Expand All @@ -472,7 +473,7 @@ appendCallOpAliasScopes(Operation *call,
// Simply append the call op's alias and noalias scopes to any operation
// implementing AliasAnalysisOpInterface.
for (Block &block : inlinedBlocks) {
for (auto aliasInterface : block.getOps<LLVM::AliasAnalysisOpInterface>()) {
for (auto aliasInterface : block.getOps<ptr::AliasAnalysisOpInterface>()) {
if (aliasScopes)
aliasInterface.setAliasScopes(concatArrayAttr(
aliasInterface.getAliasScopesOrNull(), aliasScopes));
Expand All @@ -496,7 +497,7 @@ static void handleAliasScopes(Operation *call,
/// operation.
static void handleAccessGroups(Operation *call,
iterator_range<Region::iterator> inlinedBlocks) {
auto callAccessGroupInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
auto callAccessGroupInterface = dyn_cast<ptr::AccessGroupOpInterface>(call);
if (!callAccessGroupInterface)
return;

Expand All @@ -508,7 +509,7 @@ static void handleAccessGroups(Operation *call,
// AccessGroupOpInterface.
for (Block &block : inlinedBlocks)
for (auto accessGroupOpInterface :
block.getOps<LLVM::AccessGroupOpInterface>())
block.getOps<ptr::AccessGroupOpInterface>())
accessGroupOpInterface.setAccessGroups(concatArrayAttr(
accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups));
}
Expand Down
52 changes: 0 additions & 52 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,6 @@ static LogicalResult isArrayOf(Operation *op, ArrayAttr array) {
return success();
}

//===----------------------------------------------------------------------===//
// AccessGroupOpInterface
//===----------------------------------------------------------------------===//

LogicalResult mlir::LLVM::detail::verifyAccessGroupOpInterface(Operation *op) {
auto iface = cast<AccessGroupOpInterface>(op);
ArrayAttr accessGroups = iface.getAccessGroupsOrNull();
if (!accessGroups)
return success();

return isArrayOf<AccessGroupAttr>(op, accessGroups);
}

//===----------------------------------------------------------------------===//
// AliasAnalysisOpInterface
//===----------------------------------------------------------------------===//

LogicalResult
mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
auto iface = cast<AliasAnalysisOpInterface>(op);

if (auto aliasScopes = iface.getAliasScopesOrNull())
if (failed(isArrayOf<AliasScopeAttr>(op, aliasScopes)))
return failure();

if (auto noAliasScopes = iface.getNoAliasScopesOrNull())
if (failed(isArrayOf<AliasScopeAttr>(op, noAliasScopes)))
return failure();

ArrayAttr tags = iface.getTBAATagsOrNull();
if (!tags)
return success();

return isArrayOf<TBAATagAttr>(op, tags);
}

SmallVector<Value> mlir::LLVM::AtomicCmpXchgOp::getAccessedOperands() {
return {getPtr()};
}

SmallVector<Value> mlir::LLVM::AtomicRMWOp::getAccessedOperands() {
return {getPtr()};
}

SmallVector<Value> mlir::LLVM::LoadOp::getAccessedOperands() {
return {getAddr()};
}

SmallVector<Value> mlir::LLVM::StoreOp::getAccessedOperands() {
return {getAddr()};
}

SmallVector<Value> mlir::LLVM::MemcpyOp::getAccessedOperands() {
return {getDst(), getSrc()};
}
Expand Down
207 changes: 12 additions & 195 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,189 +100,6 @@ void LLVM::AllocaOp::handleDestructuringComplete(
rewriter.eraseOp(*this);
}

//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//

bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}

bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }

Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
llvm_unreachable("getStored should not be called on LoadOp");
}

bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}

Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
return getValue();
}

bool LLVM::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
// If the blocking use is the slot ptr itself, there will be enough
// context to reconstruct the result of the load at removal time, so it can
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getResult().getType() == slot.elemType && !getVolatile_();
}

DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
return DeletionKind::Delete;
}

bool LLVM::StoreOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
// If the blocking use is the slot ptr itself, dropping the store is
// fine, provided we are currently promoting its target value. Don't allow a
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr && getValue().getType() == slot.elemType &&
!getVolatile_();
}

DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the stored slot
// pointer.
for (Operation *user : slot.ptr.getUsers())
if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
declareOp.getVarInfo(),
declareOp.getLocationExpr());
return DeletionKind::Delete;
}

/// Checks if `slot` can be accessed through the provided access type.
static bool isValidAccessType(const MemorySlot &slot, Type accessType,
const DataLayout &dataLayout) {
return dataLayout.getTypeSize(accessType) <=
dataLayout.getTypeSize(slot.elemType);
}

LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr ||
isValidAccessType(slot, getType(), dataLayout));
}

LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
return success(getAddr() != slot.ptr ||
isValidAccessType(slot, getValue().getType(), dataLayout));
}

/// Returns the subslot's type at the requested index.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
Attribute index) {
auto subelementIndexMap =
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
if (!subelementIndexMap)
return {};
assert(!subelementIndexMap->empty());

// Note: Returns a null-type when no entry was found.
return subelementIndexMap->lookup(index);
}

bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (getVolatile_())
return false;

// A load always accesses the first element of the destructured slot.
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
Type subslotType = getTypeAtIndex(slot, index);
if (!subslotType)
return false;

// The access can only be replaced when the subslot is read within its bounds.
if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
return false;

usedIndices.insert(index);
return true;
}

DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter,
const DataLayout &dataLayout) {
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
auto it = subslots.find(index);
assert(it != subslots.end());

rewriter.modifyOpInPlace(
*this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
return DeletionKind::Keep;
}

bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (getVolatile_())
return false;

// Storing the pointer to memory cannot be dealt with.
if (getValue() == slot.ptr)
return false;

// A store always accesses the first element of the destructured slot.
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
Type subslotType = getTypeAtIndex(slot, index);
if (!subslotType)
return false;

// The access can only be replaced when the subslot is read within its bounds.
if (dataLayout.getTypeSize(getValue().getType()) >
dataLayout.getTypeSize(subslotType))
return false;

usedIndices.insert(index);
return true;
}

DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
DenseMap<Attribute, MemorySlot> &subslots,
RewriterBase &rewriter,
const DataLayout &dataLayout) {
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
auto it = subslots.find(index);
assert(it != subslots.end());

rewriter.modifyOpInPlace(
*this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
return DeletionKind::Keep;
}

//===----------------------------------------------------------------------===//
// Interfaces for discardable OPs
//===----------------------------------------------------------------------===//
Expand All @@ -308,18 +125,6 @@ DeletionKind LLVM::BitcastOp::removeBlockingUses(
return DeletionKind::Delete;
}

bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
return DeletionKind::Delete;
}

bool LLVM::LifetimeStartOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
Expand Down Expand Up @@ -380,6 +185,18 @@ DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
return DeletionKind::Delete;
}

bool LLVM::DbgDeclareOp::requiresVisitingMutatedDefs() { return true; }

void LLVM::DbgDeclareOp::visitMutatedDefs(
ArrayRef<std::pair<Operation *, Value>> definitions,
RewriterBase &rewriter) {
for (auto [op, value] : definitions) {
rewriter.setInsertionPointAfter(op);
rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
getLocationExpr());
}
}

bool LLVM::DbgValueOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
Expand Down
178 changes: 29 additions & 149 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,162 +254,42 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
}

//===----------------------------------------------------------------------===//
// DataLayoutTypeInterface

constexpr const static uint64_t kDefaultPointerSizeBits = 64;
constexpr const static uint64_t kDefaultPointerAlignment = 8;

std::optional<uint64_t> mlir::LLVM::extractPointerSpecValue(Attribute attr,
PtrDLEntryPos pos) {
auto spec = cast<DenseIntElementsAttr>(attr);
auto idx = static_cast<int64_t>(pos);
if (idx >= spec.size())
return std::nullopt;
return spec.getValues<uint64_t>()[idx];
}
// LLVMPointerType
//===----------------------------------------------------------------------===//

/// Returns the part of the data layout entry that corresponds to `pos` for the
/// given `type` by interpreting the list of entries `params`. For the pointer
/// type in the default address space, returns the default value if the entries
/// do not provide a custom one, for other address spaces returns std::nullopt.
static std::optional<uint64_t>
getPointerDataLayoutEntry(DataLayoutEntryListRef params, LLVMPointerType type,
PtrDLEntryPos pos) {
// First, look for the entry for the pointer in the current address space.
Attribute currentEntry;
for (DataLayoutEntryInterface entry : params) {
if (!entry.isTypeEntry())
continue;
if (cast<LLVMPointerType>(entry.getKey().get<Type>()).getAddressSpace() ==
type.getAddressSpace()) {
currentEntry = entry.getValue();
break;
mlir::ptr::PtrType LLVMPointerType::get(MLIRContext *context,
unsigned addressSpace) {
return ptr::PtrType::get(context,
AddressSpaceAttr::get(context, addressSpace));
}

Type LLVMPointerType::parse(AsmParser &odsParser) {
FailureOr<unsigned> addressSpace;
// Parse literal '<'
if (!odsParser.parseOptionalLess()) {
if (failed(addressSpace = FieldParser<unsigned>::parse(odsParser))) {
odsParser.emitError(odsParser.getCurrentLocation(),
"failed to parse LLVMPtrType parameter 'memorySpace' "
"which is to be a `unsigned`");
return {};
}
// Parse literal '>'
if (odsParser.parseGreater())
return {};
}
if (currentEntry) {
std::optional<uint64_t> value = extractPointerSpecValue(currentEntry, pos);
// If the optional `PtrDLEntryPos::Index` entry is not available, use the
// pointer size as the index bitwidth.
if (!value && pos == PtrDLEntryPos::Index)
value = extractPointerSpecValue(currentEntry, PtrDLEntryPos::Size);
bool isSizeOrIndex =
pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
return *value / (isSizeOrIndex ? 1 : kBitsInByte);
}

// If not found, and this is the pointer to the default memory space, assume
// 64-bit pointers.
if (type.getAddressSpace() == 0) {
bool isSizeOrIndex =
pos == PtrDLEntryPos::Size || pos == PtrDLEntryPos::Index;
return isSizeOrIndex ? kDefaultPointerSizeBits : kDefaultPointerAlignment;
}

return std::nullopt;
}

llvm::TypeSize
LLVMPointerType::getTypeSizeInBits(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
if (std::optional<uint64_t> size =
getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Size))
return llvm::TypeSize::getFixed(*size);

// For other memory spaces, use the size of the pointer to the default memory
// space.
return dataLayout.getTypeSizeInBits(get(getContext()));
}

uint64_t LLVMPointerType::getABIAlignment(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
if (std::optional<uint64_t> alignment =
getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Abi))
return *alignment;

return dataLayout.getTypeABIAlignment(get(getContext()));
}

uint64_t
LLVMPointerType::getPreferredAlignment(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
if (std::optional<uint64_t> alignment =
getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Preferred))
return *alignment;

return dataLayout.getTypePreferredAlignment(get(getContext()));
}

std::optional<uint64_t>
LLVMPointerType::getIndexBitwidth(const DataLayout &dataLayout,
DataLayoutEntryListRef params) const {
if (std::optional<uint64_t> indexBitwidth =
getPointerDataLayoutEntry(params, *this, PtrDLEntryPos::Index))
return *indexBitwidth;

return dataLayout.getTypeIndexBitwidth(get(getContext()));
return LLVMPointerType::get(odsParser.getContext(), addressSpace.value_or(0));
}

bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
DataLayoutEntryListRef newLayout) const {
for (DataLayoutEntryInterface newEntry : newLayout) {
if (!newEntry.isTypeEntry())
continue;
uint64_t size = kDefaultPointerSizeBits;
uint64_t abi = kDefaultPointerAlignment;
auto newType = llvm::cast<LLVMPointerType>(newEntry.getKey().get<Type>());
const auto *it =
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() ==
newType.getAddressSpace();
}
return false;
});
if (it == oldLayout.end()) {
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = llvm::dyn_cast_if_present<Type>(entry.getKey())) {
return llvm::cast<LLVMPointerType>(type).getAddressSpace() == 0;
}
return false;
});
}
if (it != oldLayout.end()) {
size = *extractPointerSpecValue(*it, PtrDLEntryPos::Size);
abi = *extractPointerSpecValue(*it, PtrDLEntryPos::Abi);
}

Attribute newSpec = llvm::cast<DenseIntElementsAttr>(newEntry.getValue());
uint64_t newSize = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Size);
uint64_t newAbi = *extractPointerSpecValue(newSpec, PtrDLEntryPos::Abi);
if (size != newSize || abi < newAbi || abi % newAbi != 0)
return false;
}
return true;
void LLVMPointerType::print(AsmPrinter &odsPrinter) const {
if (unsigned as = getAddressSpace(); as != 0)
odsPrinter << "<" << as << ">";
}

LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
Location loc) const {
for (DataLayoutEntryInterface entry : entries) {
if (!entry.isTypeEntry())
continue;
auto key = entry.getKey().get<Type>();
auto values = llvm::dyn_cast<DenseIntElementsAttr>(entry.getValue());
if (!values || (values.size() != 3 && values.size() != 4)) {
return emitError(loc)
<< "expected layout attribute for " << key
<< " to be a dense integer elements attribute with 3 or 4 "
"elements";
}
if (!values.getElementType().isInteger(64))
return emitError(loc) << "expected i64 parameters for " << key;

if (extractPointerSpecValue(values, PtrDLEntryPos::Abi) >
extractPointerSpecValue(values, PtrDLEntryPos::Preferred)) {
return emitError(loc) << "preferred alignment is expected to be at least "
"as large as ABI alignment";
}
}
return success();
bool mlir::LLVM::isLLVMPointerType(Type type) {
if (auto ptrTy = mlir::dyn_cast<ptr::PtrType>(type))
return ptrTy.getMemorySpace() &&
mlir::isa<AddressSpaceAttr>(ptrTy.getMemorySpace());
return false;
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct MemRefPointerLikeModel

struct LLVMPointerPointerLikeModel
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
LLVM::LLVMPointerType> {
ptr::PtrType> {
Type getElementType(Type pointer) const { return Type(); }
};
} // namespace
Expand All @@ -65,8 +65,7 @@ void OpenACCDialect::initialize() {
// the other dialects. This is probably better than having dialects like LLVM
// and memref be dependent on OpenACC.
MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
ptr::PtrType::attachInterface<LLVMPointerPointerLikeModel>(*getContext());
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct MemRefPointerLikeModel

struct LLVMPointerPointerLikeModel
: public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
LLVM::LLVMPointerType> {
ptr::PtrType> {
Type getElementType(Type pointer) const { return Type(); }
};

Expand Down Expand Up @@ -82,8 +82,7 @@ void OpenMPDialect::initialize() {

addInterface<OpenMPDialectFoldInterface>();
MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
*getContext());
ptr::PtrType::attachInterface<LLVMPointerPointerLikeModel>(*getContext());

// Attach default offload module interface to module op to access
// offload functionality through
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Ptr/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_dialect_library(
MLIRPtrDialect
PtrTypes.cpp
PtrDialect.cpp
PtrMemorySlot.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/mlir/Dialect/Pointer
DEPENDS
MLIRPtrOpsIncGen
MLIRPtrOpsEnumsGen
MLIRPtrOpsAttributesIncGen
MLIRPtrMemorySpaceInterfacesIncGen
LINK_LIBS
PUBLIC
MLIRIR
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
)
Loading