Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][mlir][llvm] support new-struct-path-tbaa #119698

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

PikachuHyA
Copy link
Contributor

No description provided.

@PikachuHyA PikachuHyA marked this pull request as draft December 12, 2024 13:01
@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: PikachuHy (PikachuHyA)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/119698.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td (+77-1)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+1-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+2-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp (+9-1)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+36-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index e8eeafd09a9cba..198e1f8982ef14 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1080,8 +1080,84 @@ def LLVM_TBAATagAttr : LLVM_Attr<"TBAATag", "tbaa_tag"> {
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def LLVM_TBAAStructFieldAttr : LLVM_Attr<"TBAAStructField", "tbaa_struct_field"> {
+  let parameters = (ins
+    "TBAANodeAttr":$typeDesc,
+    "int64_t":$offset,
+    "int64_t":$size
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+
+def LLVM_TBAAStructFieldAttrArray : ArrayRefParameter<"TBAAStructFieldAttr"> {
+  let printer = [{
+    $_printer << '{';
+    llvm::interleaveComma($_self, $_printer, [&](TBAAStructFieldAttr attr) {
+        $_printer.printStrippedAttrOrType(attr);
+    });
+    $_printer << '}';
+  }];
+
+  let parser = [{
+    [&]() -> FailureOr<SmallVector<TBAAStructFieldAttr>> {
+        using Result = SmallVector<TBAAStructFieldAttr>;
+        if ($_parser.parseLBrace())
+            return failure();
+        FailureOr<Result> result = FieldParser<Result>::parse($_parser);
+        if (failed(result))
+            return failure();
+        if ($_parser.parseRBrace())
+            return failure();
+        return result;
+    }()
+  }];
+}
+
+def LLVM_TBAATypeNodeAttr : LLVM_Attr<"TBAATypeNode", "tbaa_type_node", [], "TBAANodeAttr"> {
+  let parameters = (ins
+    "TBAANodeAttr":$parent,
+    "int64_t":$size,
+    StringRefParameter<>:$id,
+    LLVM_TBAAStructFieldAttrArray:$fields
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagAttr : LLVM_Attr<"TBAAAccessTag", "tbaa_access_tag"> {
+  let parameters = (ins
+    "TBAATypeNodeAttr":$base_type,
+    "TBAATypeNodeAttr":$access_type,
+    "int64_t":$offset,
+    "int64_t":$size
+  );
+  let builders = [
+    AttrBuilderWithInferredContext<(ins "TBAATypeNodeAttr":$baseType,
+                                        "TBAATypeNodeAttr":$accessType,
+                                        "int64_t":$offset,
+                                        "int64_t":$size), [{
+      return $_get(baseType.getContext(), baseType, accessType, offset, size);
+    }]>
+  ];
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def LLVM_TBAAAccessTagArrayAttr
+    : TypedArrayAttrBase<LLVM_TBAAAccessTagAttr,
+                         LLVM_TBAAAccessTagAttr.summary # " array"> {
+  let constBuilderCall = ?;
+}
+
+// def LLVM_TBAATagAttr2 : AnyAttrOf<[
+//   LLVM_TBAATagAttr,
+//   LLVM_TBAAAccessTagAttr
+// ]>;
+
 def LLVM_TBAATagArrayAttr
-    : TypedArrayAttrBase<LLVM_TBAATagAttr,
+    : TypedArrayAttrBase<AnyAttrOf<[
+  LLVM_TBAATagAttr,
+  LLVM_TBAAAccessTagAttr
+]>,
                          LLVM_TBAATagAttr.summary # " array"> {
   let constBuilderCall = ?;
 }
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..c7a79aa330d3da 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -323,7 +323,7 @@ class ModuleTranslation {
 
   /// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
   /// TBAATagAttr.
-  llvm::MDNode *getTBAANode(TBAATagAttr tbaaAttr) const;
+  llvm::MDNode *getTBAANode(Attribute tbaaAttr) const;
 
   /// Process tbaa LLVM Metadata operations and create LLVM
   /// metadata nodes for them.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6b2d8943bf4885..b2b0b9b331e0b4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3401,7 +3401,8 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
               LoopVectorizeAttr, LoopInterleaveAttr, LoopUnrollAttr,
               LoopUnrollAndJamAttr, LoopLICMAttr, LoopDistributeAttr,
               LoopPipelineAttr, LoopPeeledAttr, LoopUnswitchAttr, TBAARootAttr,
-              TBAATagAttr, TBAATypeDescriptorAttr>([&](auto attr) {
+              TBAATagAttr, TBAATypeDescriptorAttr, TBAAAccessTagAttr,
+              TBAATypeNodeAttr>([&](auto attr) {
           os << decltype(attr)::getMnemonic();
           return AliasResult::OverridableAlias;
         })
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
index cff16afc73af3f..6a9395b1f4a26e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInterfaces.cpp
@@ -58,7 +58,15 @@ mlir::LLVM::detail::verifyAliasAnalysisOpInterface(Operation *op) {
   ArrayAttr tags = iface.getTBAATagsOrNull();
   if (!tags)
     return success();
-
+  if (tags.size() > 0) {
+    if (mlir::isa<TBAATagAttr>(tags[0])) {
+      return isArrayOf<TBAATagAttr>(op, tags);
+    }
+
+    if (mlir::isa<TBAAAccessTagAttr>(tags[0])) {
+      return isArrayOf<TBAAAccessTagAttr>(op, tags);
+    }
+  }
   return isArrayOf<TBAATagAttr>(op, tags);
 }
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..6a6c29869ba805 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1766,7 +1766,8 @@ void ModuleTranslation::setAliasScopeMetadata(AliasAnalysisOpInterface op,
                         llvm::LLVMContext::MD_noalias);
 }
 
-llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
+// llvm::MDNode *ModuleTranslation::getTBAANode(TBAATagAttr tbaaAttr) const {
+llvm::MDNode *ModuleTranslation::getTBAANode(Attribute tbaaAttr) const {
   return tbaaMetadataMapping.lookup(tbaaAttr);
 }
 
@@ -1786,7 +1787,8 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
     return;
   }
 
-  llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
+  // llvm::MDNode *node = getTBAANode(cast<TBAATagAttr>(tagRefs[0]));
+  llvm::MDNode *node = getTBAANode(tagRefs[0]);
   inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
 }
 
@@ -1806,6 +1808,7 @@ void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
 LogicalResult ModuleTranslation::createTBAAMetadata() {
   llvm::LLVMContext &ctx = llvmModule->getContext();
   llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
+  llvm::IntegerType *sizeTy = llvm::IntegerType::get(ctx, 64);
 
   // Walk the entire module and create all metadata nodes for the TBAA
   // attributes. The code below relies on two invariants of the
@@ -1833,6 +1836,23 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
     tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
   });
 
+  walker.addWalk([&](TBAATypeNodeAttr descriptor) {
+    SmallVector<llvm::Metadata *> operands;
+    operands.push_back(tbaaMetadataMapping.lookup(descriptor.getParent()));
+    operands.push_back(llvm::ConstantAsMetadata::get(
+        llvm::ConstantInt::get(sizeTy, descriptor.getSize())));
+    operands.push_back(llvm::MDString::get(ctx, descriptor.getId()));
+    for (auto field : descriptor.getFields()) {
+      operands.push_back(tbaaMetadataMapping.lookup(field.getTypeDesc()));
+      operands.push_back(llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(offsetTy, field.getOffset())));
+      operands.push_back(llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(sizeTy, field.getSize())));
+    }
+
+    tbaaMetadataMapping.insert({descriptor, llvm::MDNode::get(ctx, operands)});
+  });
+
   walker.addWalk([&](TBAATagAttr tag) {
     SmallVector<llvm::Metadata *> operands;
 
@@ -1848,6 +1868,20 @@ LogicalResult ModuleTranslation::createTBAAMetadata() {
     tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
   });
 
+  walker.addWalk([&](TBAAAccessTagAttr tag) {
+    SmallVector<llvm::Metadata *> operands;
+
+    operands.push_back(tbaaMetadataMapping.lookup(tag.getBaseType()));
+    operands.push_back(tbaaMetadataMapping.lookup(tag.getAccessType()));
+
+    operands.push_back(llvm::ConstantAsMetadata::get(
+        llvm::ConstantInt::get(offsetTy, tag.getOffset())));
+    operands.push_back(llvm::ConstantAsMetadata::get(
+        llvm::ConstantInt::get(sizeTy, tag.getSize())));
+
+    tbaaMetadataMapping.insert({tag, llvm::MDNode::get(ctx, operands)});
+  });
+
   mlirModule->walk([&](AliasAnalysisOpInterface analysisOpInterface) {
     if (auto attr = analysisOpInterface.getTBAATagsOrNull())
       walker.walk(attr);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants