Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[memprof] Add access checks to PortableMemInfoBlock::get* #90121

Merged
merged 2 commits into from
Apr 28, 2024

Conversation

kazutakahirata
Copy link
Contributor

commit 4c8ec8f
Author: Kazu Hirata kazu@google.com
Date: Wed Apr 24 16:25:35 2024 -0700

introduced the idea of serializing/deserializing a subset of the
fields in PortableMemInfoBlock. While it reduces the size of the
indexed MemProf profile file, we now could inadvertently access
unavailable fields and go without noticing.

To protect ourselves from the risk, this patch adds access checks to
PortableMemInfoBlock::get* methods by embedding a bit set representing
available fields into PortableMemInfoBlock.

@llvmbot llvmbot added the PGO Profile Guided Optimizations label Apr 25, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-pgo

Author: Kazu Hirata (kazutakahirata)

Changes

commit 4c8ec8f
Author: Kazu Hirata <kazu@google.com>
Date: Wed Apr 24 16:25:35 2024 -0700

introduced the idea of serializing/deserializing a subset of the
fields in PortableMemInfoBlock. While it reduces the size of the
indexed MemProf profile file, we now could inadvertently access
unavailable fields and go without noticing.

To protect ourselves from the risk, this patch adds access checks to
PortableMemInfoBlock::get* methods by embedding a bit set representing
available fields into PortableMemInfoBlock.


Full diff: https://github.com/llvm/llvm-project/pull/90121.diff

3 Files Affected:

  • (modified) llvm/include/llvm/ProfileData/MemProf.h (+27-7)
  • (modified) llvm/unittests/ProfileData/InstrProfTest.cpp (+4-4)
  • (modified) llvm/unittests/ProfileData/MemProfTest.cpp (+1-1)
diff --git a/llvm/include/llvm/ProfileData/MemProf.h b/llvm/include/llvm/ProfileData/MemProf.h
index d378c3696f8d0b..e59c6b6b02f141 100644
--- a/llvm/include/llvm/ProfileData/MemProf.h
+++ b/llvm/include/llvm/ProfileData/MemProf.h
@@ -2,6 +2,7 @@
 #define LLVM_PROFILEDATA_MEMPROF_H_
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLForwardCompat.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/GlobalValue.h"
@@ -10,6 +11,7 @@
 #include "llvm/Support/EndianStream.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <bitset>
 #include <cstdint>
 #include <optional>
 
@@ -55,7 +57,10 @@ MemProfSchema getHotColdSchema();
 // deserialize methods.
 struct PortableMemInfoBlock {
   PortableMemInfoBlock() = default;
-  explicit PortableMemInfoBlock(const MemInfoBlock &Block) {
+  explicit PortableMemInfoBlock(const MemInfoBlock &Block,
+                                const MemProfSchema &IncomingSchema) {
+    for (const Meta Id : IncomingSchema)
+      Schema.set(llvm::to_underlying(Id));
 #define MIBEntryDef(NameTag, Name, Type) Name = Block.Name;
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
@@ -67,10 +72,12 @@ struct PortableMemInfoBlock {
 
   // Read the contents of \p Ptr based on the \p Schema to populate the
   // MemInfoBlock member.
-  void deserialize(const MemProfSchema &Schema, const unsigned char *Ptr) {
+  void deserialize(const MemProfSchema &IncomingSchema,
+                   const unsigned char *Ptr) {
     using namespace support;
 
-    for (const Meta Id : Schema) {
+    Schema.reset();
+    for (const Meta Id : IncomingSchema) {
       switch (Id) {
 #define MIBEntryDef(NameTag, Name, Type)                                       \
   case Meta::Name: {                                                           \
@@ -82,6 +89,8 @@ struct PortableMemInfoBlock {
         llvm_unreachable("Unknown meta type id, is the profile collected from "
                          "a newer version of the runtime?");
       }
+
+      Schema.set(llvm::to_underlying(Id));
     }
   }
 
@@ -116,15 +125,22 @@ struct PortableMemInfoBlock {
 
   // Define getters for each type which can be called by analyses.
 #define MIBEntryDef(NameTag, Name, Type)                                       \
-  Type get##Name() const { return Name; }
+  Type get##Name() const {                                                     \
+    assert(Schema[llvm::to_underlying(Meta::Name)]);                           \
+    return Name;                                                               \
+  }
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
 
   void clear() { *this = PortableMemInfoBlock(); }
 
   bool operator==(const PortableMemInfoBlock &Other) const {
+    if (Other.Schema != Schema)
+      return false;
+
 #define MIBEntryDef(NameTag, Name, Type)                                       \
-  if (Other.get##Name() != get##Name())                                        \
+  if (Schema[llvm::to_underlying(Meta::Name)] &&                               \
+      Other.get##Name() != get##Name())                                        \
     return false;
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
@@ -155,6 +171,9 @@ struct PortableMemInfoBlock {
   }
 
 private:
+  // The set of available fields, indexed by Meta::Name.
+  std::bitset<llvm::to_underlying(Meta::Size)> Schema;
+
 #define MIBEntryDef(NameTag, Name, Type) Type Name = Type();
 #include "llvm/ProfileData/MIBEntryDef.inc"
 #undef MIBEntryDef
@@ -296,8 +315,9 @@ struct IndexedAllocationInfo {
 
   IndexedAllocationInfo() = default;
   IndexedAllocationInfo(ArrayRef<FrameId> CS, CallStackId CSId,
-                        const MemInfoBlock &MB)
-      : CallStack(CS.begin(), CS.end()), CSId(CSId), Info(MB) {}
+                        const MemInfoBlock &MB,
+                        const MemProfSchema &Schema = getFullSchema())
+      : CallStack(CS.begin(), CS.end()), CSId(CSId), Info(MB, Schema) {}
 
   // Returns the size in bytes when this allocation info struct is serialized.
   size_t serializedSize(const MemProfSchema &Schema,
diff --git a/llvm/unittests/ProfileData/InstrProfTest.cpp b/llvm/unittests/ProfileData/InstrProfTest.cpp
index edc427dcbc4540..1b0ee6b8cdab98 100644
--- a/llvm/unittests/ProfileData/InstrProfTest.cpp
+++ b/llvm/unittests/ProfileData/InstrProfTest.cpp
@@ -407,13 +407,13 @@ IndexedMemProfRecord makeRecord(
 IndexedMemProfRecord
 makeRecordV2(std::initializer_list<::llvm::memprof::CallStackId> AllocFrames,
              std::initializer_list<::llvm::memprof::CallStackId> CallSiteFrames,
-             const MemInfoBlock &Block) {
+             const MemInfoBlock &Block, const memprof::MemProfSchema &Schema) {
   llvm::memprof::IndexedMemProfRecord MR;
   for (const auto &CSId : AllocFrames)
     // We don't populate IndexedAllocationInfo::CallStack because we use it only
     // in Version0 and Version1.
     MR.AllocSites.emplace_back(::llvm::SmallVector<memprof::FrameId>(), CSId,
-                               Block);
+                               Block, Schema);
   for (const auto &CSId : CallSiteFrames)
     MR.CallSiteIds.push_back(CSId);
   return MR;
@@ -544,7 +544,7 @@ TEST_F(InstrProfTest, test_memprof_v2_full_schema) {
 
   const IndexedMemProfRecord IndexedMR = makeRecordV2(
       /*AllocFrames=*/{0x111, 0x222},
-      /*CallSiteFrames=*/{0x333}, MIB);
+      /*CallSiteFrames=*/{0x333}, MIB, memprof::getFullSchema());
   const FrameIdMapTy IdToFrameMap = getFrameMapping();
   const auto CSIdToCallStackMap = getCallStackMapping();
   for (const auto &I : IdToFrameMap) {
@@ -584,7 +584,7 @@ TEST_F(InstrProfTest, test_memprof_v2_partial_schema) {
 
   const IndexedMemProfRecord IndexedMR = makeRecordV2(
       /*AllocFrames=*/{0x111, 0x222},
-      /*CallSiteFrames=*/{0x333}, MIB);
+      /*CallSiteFrames=*/{0x333}, MIB, memprof::getHotColdSchema());
   const FrameIdMapTy IdToFrameMap = getFrameMapping();
   const auto CSIdToCallStackMap = getCallStackMapping();
   for (const auto &I : IdToFrameMap) {
diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp
index 503901094ba9a5..2e881adcefcec5 100644
--- a/llvm/unittests/ProfileData/MemProfTest.cpp
+++ b/llvm/unittests/ProfileData/MemProfTest.cpp
@@ -241,7 +241,7 @@ TEST(MemProf, PortableWrapper) {
                     /*dealloc_cpu=*/4);
 
   const auto Schema = llvm::memprof::getFullSchema();
-  PortableMemInfoBlock WriteBlock(Info);
+  PortableMemInfoBlock WriteBlock(Info, Schema);
 
   std::string Buffer;
   llvm::raw_string_ostream OS(Buffer);

Copy link
Contributor

@snehasish snehasish left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also extend the new unit tests to check that the expected bits were set?

Copy link
Contributor

@teresajohnson teresajohnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! lgtm with Snehasish's suggestion

@kazutakahirata
Copy link
Contributor Author

Can you also extend the new unit tests to check that the expected bits were set?

I just added a unit test. Please take a look. Thanks!

Copy link
Contributor

@snehasish snehasish left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

  commit 4c8ec8f
  Author: Kazu Hirata <kazu@google.com>
  Date:   Wed Apr 24 16:25:35 2024 -0700

introduced the idea of serializing/deserializing a subset of the
fields in PortableMemInfoBlock.  While it reduces the size of the
indexed MemProf profile file, we now could inadvertently access
unavailable fields and go without noticing.

To protect ourselves from the risk, this patch adds access checks to
PortableMemInfoBlock::get* methods by embedding a bit set representing
available fields into PortableMemInfoBlock.
@kazutakahirata kazutakahirata merged commit c9dae43 into llvm:main Apr 28, 2024
3 of 4 checks passed
@kazutakahirata kazutakahirata deleted the pr_memprof_mib_get branch April 28, 2024 19:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
PGO Profile Guided Optimizations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants