Skip to content

Commit

Permalink
[AMDGPU] Switched HSA metadata to use MsgPackDocument
Browse files Browse the repository at this point in the history
Summary:
MsgPackDocument is the lighter-weight replacement for MsgPackTypes. This
commit switches AMDGPU HSA metadata processing to use MsgPackDocument
instead of MsgPackTypes.

Differential Revision: https://reviews.llvm.org/D57024

Change-Id: I0751668013abe8c87db01db1170831a76079b3a6
llvm-svn: 356081
  • Loading branch information
Tim Renouf committed Mar 13, 2019
1 parent 4ced8de commit ed0b9af
Show file tree
Hide file tree
Showing 17 changed files with 1,730 additions and 1,714 deletions.
30 changes: 15 additions & 15 deletions llvm/include/llvm/BinaryFormat/AMDGPUMetadataVerifier.h
Expand Up @@ -16,7 +16,7 @@
#ifndef LLVM_BINARYFORMAT_AMDGPUMETADATAVERIFIER_H
#define LLVM_BINARYFORMAT_AMDGPUMETADATAVERIFIER_H

#include "llvm/BinaryFormat/MsgPackTypes.h"
#include "llvm/BinaryFormat/MsgPackDocument.h"

namespace llvm {
namespace AMDGPU {
Expand All @@ -33,22 +33,22 @@ namespace V3 {
class MetadataVerifier {
bool Strict;

bool verifyScalar(msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind,
function_ref<bool(msgpack::ScalarNode &)> verifyValue = {});
bool verifyInteger(msgpack::Node &Node);
bool verifyArray(msgpack::Node &Node,
function_ref<bool(msgpack::Node &)> verifyNode,
bool verifyScalar(msgpack::DocNode &Node, msgpack::Type SKind,
function_ref<bool(msgpack::DocNode &)> verifyValue = {});
bool verifyInteger(msgpack::DocNode &Node);
bool verifyArray(msgpack::DocNode &Node,
function_ref<bool(msgpack::DocNode &)> verifyNode,
Optional<size_t> Size = None);
bool verifyEntry(msgpack::MapNode &MapNode, StringRef Key, bool Required,
function_ref<bool(msgpack::Node &)> verifyNode);
bool verifyEntry(msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
function_ref<bool(msgpack::DocNode &)> verifyNode);
bool
verifyScalarEntry(msgpack::MapNode &MapNode, StringRef Key, bool Required,
msgpack::ScalarNode::ScalarKind SKind,
function_ref<bool(msgpack::ScalarNode &)> verifyValue = {});
bool verifyIntegerEntry(msgpack::MapNode &MapNode, StringRef Key,
verifyScalarEntry(msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
msgpack::Type SKind,
function_ref<bool(msgpack::DocNode &)> verifyValue = {});
bool verifyIntegerEntry(msgpack::MapDocNode &MapNode, StringRef Key,
bool Required);
bool verifyKernelArgs(msgpack::Node &Node);
bool verifyKernel(msgpack::Node &Node);
bool verifyKernelArgs(msgpack::DocNode &Node);
bool verifyKernel(msgpack::DocNode &Node);

public:
/// Construct a MetadataVerifier, specifying whether it will operate in \p
Expand All @@ -58,7 +58,7 @@ class MetadataVerifier {
/// Verify given HSA metadata.
///
/// \returns True when successful, false when metadata is invalid.
bool verify(msgpack::Node &HSAMetadataRoot);
bool verify(msgpack::DocNode &HSAMetadataRoot);
};

} // end namespace V3
Expand Down
152 changes: 72 additions & 80 deletions llvm/lib/BinaryFormat/AMDGPUMetadataVerifier.cpp
Expand Up @@ -20,98 +20,92 @@ namespace HSAMD {
namespace V3 {

bool MetadataVerifier::verifyScalar(
msgpack::Node &Node, msgpack::ScalarNode::ScalarKind SKind,
function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
auto ScalarPtr = dyn_cast<msgpack::ScalarNode>(&Node);
if (!ScalarPtr)
return false;
auto &Scalar = *ScalarPtr;
// Do not output extraneous tags for types we know from the spec.
Scalar.IgnoreTag = true;
if (Scalar.getScalarKind() != SKind) {
msgpack::DocNode &Node, msgpack::Type SKind,
function_ref<bool(msgpack::DocNode &)> verifyValue) {
if (!Node.isScalar())
return false;
if (Node.getKind() != SKind) {
if (Strict)
return false;
// If we are not strict, we interpret string values as "implicitly typed"
// and attempt to coerce them to the expected type here.
if (Scalar.getScalarKind() != msgpack::ScalarNode::SK_String)
if (Node.getKind() != msgpack::Type::String)
return false;
std::string StringValue = Scalar.getString();
Scalar.setScalarKind(SKind);
if (Scalar.inputYAML(StringValue) != StringRef())
StringRef StringValue = Node.getString();
Node.fromString(StringValue);
if (Node.getKind() != SKind)
return false;
}
if (verifyValue)
return verifyValue(Scalar);
return verifyValue(Node);
return true;
}

bool MetadataVerifier::verifyInteger(msgpack::Node &Node) {
if (!verifyScalar(Node, msgpack::ScalarNode::SK_UInt))
if (!verifyScalar(Node, msgpack::ScalarNode::SK_Int))
bool MetadataVerifier::verifyInteger(msgpack::DocNode &Node) {
if (!verifyScalar(Node, msgpack::Type::UInt))
if (!verifyScalar(Node, msgpack::Type::Int))
return false;
return true;
}

bool MetadataVerifier::verifyArray(
msgpack::Node &Node, function_ref<bool(msgpack::Node &)> verifyNode,
msgpack::DocNode &Node, function_ref<bool(msgpack::DocNode &)> verifyNode,
Optional<size_t> Size) {
auto ArrayPtr = dyn_cast<msgpack::ArrayNode>(&Node);
if (!ArrayPtr)
if (!Node.isArray())
return false;
auto &Array = *ArrayPtr;
auto &Array = Node.getArray();
if (Size && Array.size() != *Size)
return false;
for (auto &Item : Array)
if (!verifyNode(*Item.get()))
if (!verifyNode(Item))
return false;

return true;
}

bool MetadataVerifier::verifyEntry(
msgpack::MapNode &MapNode, StringRef Key, bool Required,
function_ref<bool(msgpack::Node &)> verifyNode) {
msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
function_ref<bool(msgpack::DocNode &)> verifyNode) {
auto Entry = MapNode.find(Key);
if (Entry == MapNode.end())
return !Required;
return verifyNode(*Entry->second.get());
return verifyNode(Entry->second);
}

bool MetadataVerifier::verifyScalarEntry(
msgpack::MapNode &MapNode, StringRef Key, bool Required,
msgpack::ScalarNode::ScalarKind SKind,
function_ref<bool(msgpack::ScalarNode &)> verifyValue) {
return verifyEntry(MapNode, Key, Required, [=](msgpack::Node &Node) {
msgpack::MapDocNode &MapNode, StringRef Key, bool Required,
msgpack::Type SKind,
function_ref<bool(msgpack::DocNode &)> verifyValue) {
return verifyEntry(MapNode, Key, Required, [=](msgpack::DocNode &Node) {
return verifyScalar(Node, SKind, verifyValue);
});
}

bool MetadataVerifier::verifyIntegerEntry(msgpack::MapNode &MapNode,
bool MetadataVerifier::verifyIntegerEntry(msgpack::MapDocNode &MapNode,
StringRef Key, bool Required) {
return verifyEntry(MapNode, Key, Required, [this](msgpack::Node &Node) {
return verifyEntry(MapNode, Key, Required, [this](msgpack::DocNode &Node) {
return verifyInteger(Node);
});
}

bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
auto ArgsMapPtr = dyn_cast<msgpack::MapNode>(&Node);
if (!ArgsMapPtr)
bool MetadataVerifier::verifyKernelArgs(msgpack::DocNode &Node) {
if (!Node.isMap())
return false;
auto &ArgsMap = *ArgsMapPtr;
auto &ArgsMap = Node.getMap();

if (!verifyScalarEntry(ArgsMap, ".name", false,
msgpack::ScalarNode::SK_String))
msgpack::Type::String))
return false;
if (!verifyScalarEntry(ArgsMap, ".type_name", false,
msgpack::ScalarNode::SK_String))
msgpack::Type::String))
return false;
if (!verifyIntegerEntry(ArgsMap, ".size", true))
return false;
if (!verifyIntegerEntry(ArgsMap, ".offset", true))
return false;
if (!verifyScalarEntry(ArgsMap, ".value_kind", true,
msgpack::ScalarNode::SK_String,
[](msgpack::ScalarNode &SNode) {
msgpack::Type::String,
[](msgpack::DocNode &SNode) {
return StringSwitch<bool>(SNode.getString())
.Case("by_value", true)
.Case("global_buffer", true)
Expand All @@ -131,8 +125,8 @@ bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
}))
return false;
if (!verifyScalarEntry(ArgsMap, ".value_type", true,
msgpack::ScalarNode::SK_String,
[](msgpack::ScalarNode &SNode) {
msgpack::Type::String,
[](msgpack::DocNode &SNode) {
return StringSwitch<bool>(SNode.getString())
.Case("struct", true)
.Case("i8", true)
Expand All @@ -152,8 +146,8 @@ bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
if (!verifyIntegerEntry(ArgsMap, ".pointee_align", false))
return false;
if (!verifyScalarEntry(ArgsMap, ".address_space", false,
msgpack::ScalarNode::SK_String,
[](msgpack::ScalarNode &SNode) {
msgpack::Type::String,
[](msgpack::DocNode &SNode) {
return StringSwitch<bool>(SNode.getString())
.Case("private", true)
.Case("global", true)
Expand All @@ -165,8 +159,8 @@ bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
}))
return false;
if (!verifyScalarEntry(ArgsMap, ".access", false,
msgpack::ScalarNode::SK_String,
[](msgpack::ScalarNode &SNode) {
msgpack::Type::String,
[](msgpack::DocNode &SNode) {
return StringSwitch<bool>(SNode.getString())
.Case("read_only", true)
.Case("write_only", true)
Expand All @@ -175,8 +169,8 @@ bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
}))
return false;
if (!verifyScalarEntry(ArgsMap, ".actual_access", false,
msgpack::ScalarNode::SK_String,
[](msgpack::ScalarNode &SNode) {
msgpack::Type::String,
[](msgpack::DocNode &SNode) {
return StringSwitch<bool>(SNode.getString())
.Case("read_only", true)
.Case("write_only", true)
Expand All @@ -185,36 +179,35 @@ bool MetadataVerifier::verifyKernelArgs(msgpack::Node &Node) {
}))
return false;
if (!verifyScalarEntry(ArgsMap, ".is_const", false,
msgpack::ScalarNode::SK_Boolean))
msgpack::Type::Boolean))
return false;
if (!verifyScalarEntry(ArgsMap, ".is_restrict", false,
msgpack::ScalarNode::SK_Boolean))
msgpack::Type::Boolean))
return false;
if (!verifyScalarEntry(ArgsMap, ".is_volatile", false,
msgpack::ScalarNode::SK_Boolean))
msgpack::Type::Boolean))
return false;
if (!verifyScalarEntry(ArgsMap, ".is_pipe", false,
msgpack::ScalarNode::SK_Boolean))
msgpack::Type::Boolean))
return false;

return true;
}

bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
auto KernelMapPtr = dyn_cast<msgpack::MapNode>(&Node);
if (!KernelMapPtr)
bool MetadataVerifier::verifyKernel(msgpack::DocNode &Node) {
if (!Node.isMap())
return false;
auto &KernelMap = *KernelMapPtr;
auto &KernelMap = Node.getMap();

if (!verifyScalarEntry(KernelMap, ".name", true,
msgpack::ScalarNode::SK_String))
msgpack::Type::String))
return false;
if (!verifyScalarEntry(KernelMap, ".symbol", true,
msgpack::ScalarNode::SK_String))
msgpack::Type::String))
return false;
if (!verifyScalarEntry(KernelMap, ".language", false,
msgpack::ScalarNode::SK_String,
[](msgpack::ScalarNode &SNode) {
msgpack::Type::String,
[](msgpack::DocNode &SNode) {
return StringSwitch<bool>(SNode.getString())
.Case("OpenCL C", true)
.Case("OpenCL C++", true)
Expand All @@ -226,41 +219,41 @@ bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
}))
return false;
if (!verifyEntry(
KernelMap, ".language_version", false, [this](msgpack::Node &Node) {
KernelMap, ".language_version", false, [this](msgpack::DocNode &Node) {
return verifyArray(
Node,
[this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
[this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
}))
return false;
if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::Node &Node) {
return verifyArray(Node, [this](msgpack::Node &Node) {
if (!verifyEntry(KernelMap, ".args", false, [this](msgpack::DocNode &Node) {
return verifyArray(Node, [this](msgpack::DocNode &Node) {
return verifyKernelArgs(Node);
});
}))
return false;
if (!verifyEntry(KernelMap, ".reqd_workgroup_size", false,
[this](msgpack::Node &Node) {
[this](msgpack::DocNode &Node) {
return verifyArray(Node,
[this](msgpack::Node &Node) {
[this](msgpack::DocNode &Node) {
return verifyInteger(Node);
},
3);
}))
return false;
if (!verifyEntry(KernelMap, ".workgroup_size_hint", false,
[this](msgpack::Node &Node) {
[this](msgpack::DocNode &Node) {
return verifyArray(Node,
[this](msgpack::Node &Node) {
[this](msgpack::DocNode &Node) {
return verifyInteger(Node);
},
3);
}))
return false;
if (!verifyScalarEntry(KernelMap, ".vec_type_hint", false,
msgpack::ScalarNode::SK_String))
msgpack::Type::String))
return false;
if (!verifyScalarEntry(KernelMap, ".device_enqueue_symbol", false,
msgpack::ScalarNode::SK_String))
msgpack::Type::String))
return false;
if (!verifyIntegerEntry(KernelMap, ".kernarg_segment_size", true))
return false;
Expand All @@ -286,29 +279,28 @@ bool MetadataVerifier::verifyKernel(msgpack::Node &Node) {
return true;
}

bool MetadataVerifier::verify(msgpack::Node &HSAMetadataRoot) {
auto RootMapPtr = dyn_cast<msgpack::MapNode>(&HSAMetadataRoot);
if (!RootMapPtr)
bool MetadataVerifier::verify(msgpack::DocNode &HSAMetadataRoot) {
if (!HSAMetadataRoot.isMap())
return false;
auto &RootMap = *RootMapPtr;
auto &RootMap = HSAMetadataRoot.getMap();

if (!verifyEntry(
RootMap, "amdhsa.version", true, [this](msgpack::Node &Node) {
RootMap, "amdhsa.version", true, [this](msgpack::DocNode &Node) {
return verifyArray(
Node,
[this](msgpack::Node &Node) { return verifyInteger(Node); }, 2);
[this](msgpack::DocNode &Node) { return verifyInteger(Node); }, 2);
}))
return false;
if (!verifyEntry(
RootMap, "amdhsa.printf", false, [this](msgpack::Node &Node) {
return verifyArray(Node, [this](msgpack::Node &Node) {
return verifyScalar(Node, msgpack::ScalarNode::SK_String);
RootMap, "amdhsa.printf", false, [this](msgpack::DocNode &Node) {
return verifyArray(Node, [this](msgpack::DocNode &Node) {
return verifyScalar(Node, msgpack::Type::String);
});
}))
return false;
if (!verifyEntry(RootMap, "amdhsa.kernels", true,
[this](msgpack::Node &Node) {
return verifyArray(Node, [this](msgpack::Node &Node) {
[this](msgpack::DocNode &Node) {
return verifyArray(Node, [this](msgpack::DocNode &Node) {
return verifyKernel(Node);
});
}))
Expand Down

0 comments on commit ed0b9af

Please sign in to comment.