Skip to content

Commit

Permalink
[llvm] Expose type and element count-related APIs on TensorSpec
Browse files Browse the repository at this point in the history
Added a mechanism to check the element type, get the total element
count, and the size of an element.

Differential Revision: https://reviews.llvm.org/D85250
  • Loading branch information
mtrofin committed Aug 5, 2020
1 parent ac70b37 commit 90b9c49
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
13 changes: 11 additions & 2 deletions llvm/include/llvm/Analysis/Utils/TFUtils.h
Expand Up @@ -66,10 +66,18 @@ class TensorSpec final {

/// Inequality comparison; defined as the negation of operator==.
bool operator!=(const TensorSpec &Other) const {
  const bool IsEqual = (*this == Other);
  return !IsEqual;
}

/// Get the number of elements in a tensor with this shape: the product of
/// all dimensions, computed once at construction and cached in ElementCount.
size_t getElementCount() const { return ElementCount; }
/// Get the size, in bytes, of one element of this spec's element type.
size_t getElementByteSize() const;

/// Returns true iff this tensor's element type is exactly T, by comparing
/// T's type index against the spec's stored TypeIndex.
template <typename T> bool isElementType() const {
  return TypeIndex == getDataType<T>();
}

private:
TensorSpec(const std::string &Name, int Port, int TypeIndex,
const std::vector<int64_t> &Shape)
: Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape) {}
const std::vector<int64_t> &Shape);

template <typename T> static int getDataType() {
llvm_unreachable("Undefined tensor type");
Expand All @@ -79,6 +87,7 @@ class TensorSpec final {
int Port = 0;
int TypeIndex = 0;
std::vector<int64_t> Shape;
size_t ElementCount = 0;
};

Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Analysis/TFUtils.cpp
Expand Up @@ -24,6 +24,7 @@
#include "tensorflow/c/c_api_experimental.h"

#include <cassert>
#include <numeric>

using namespace llvm;

Expand Down Expand Up @@ -84,6 +85,16 @@ class EvaluationResultImpl {
std::vector<TF_Tensor *> Output;
};

// Size in bytes of one element; TypeIndex holds a TF_DataType value, so the
// lookup is delegated to the TensorFlow C API.
size_t TensorSpec::getElementByteSize() const {
  const auto DType = static_cast<TF_DataType>(TypeIndex);
  return TF_DataTypeSize(DType);
}

// Construct a TensorSpec and cache the total element count (the product of
// all shape dimensions).
TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
                       const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
      // Accumulate with a size_t initial value: std::accumulate's running
      // accumulator takes the type of the init argument, so an `int` 1 here
      // would truncate the partial products (and overflow, which is UB for
      // signed int) once the element count exceeds INT_MAX.
      ElementCount(std::accumulate(Shape.begin(), Shape.end(),
                                   static_cast<size_t>(1),
                                   std::multiplies<size_t>())) {}

Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
const json::Value &Value) {
auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
Expand Down
15 changes: 15 additions & 0 deletions llvm/unittests/Analysis/TFUtilsTest.cpp
Expand Up @@ -123,3 +123,18 @@ TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
auto Spec = getTensorSpecFromJSON(Ctx, *Value);
EXPECT_FALSE(Spec.hasValue());
}

// Verify element-type queries plus the cached element-count and per-element
// byte-size computations across 1-D, 2-D, and 3-D specs.
TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
  // "Hi4": previously this spec reused the name "Hi3" (copy-paste); give each
  // spec a distinct name so failures identify the right one.
  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi4", {2, 4, 10});
  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
  EXPECT_FALSE(Spec3DLarge.isElementType<double>());
  // Element count is the product of all dimensions.
  EXPECT_EQ(Spec1D.getElementCount(), 1);
  EXPECT_EQ(Spec2D.getElementCount(), 1);
  EXPECT_EQ(Spec1DLarge.getElementCount(), 10);
  EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
}

0 comments on commit 90b9c49

Please sign in to comment.