Skip to content

Commit

Permalink
Make some //tf/compiler/mlir/tensorflow/utils modules build with --co…
Browse files Browse the repository at this point in the history
…nfig android_arm64.

PiperOrigin-RevId: 291246582
Change-Id: Idf13e25462a9722ea70bcc48a9c7b9091bcba8a5
  • Loading branch information
impjdi authored and tensorflower-gardener committed Jan 23, 2020
1 parent b840bf5 commit 29c1d30
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 65 deletions.
22 changes: 16 additions & 6 deletions tensorflow/compiler/mlir/tensorflow/BUILD
Expand Up @@ -383,15 +383,24 @@ cc_library(
)

cc_library(
name = "import_utils",
srcs = [
"utils/import_utils.cc",
],
hdrs = [
"utils/import_utils.h",
name = "parse_text_proto",
srcs = ["utils/parse_text_proto.cc"],
hdrs = ["utils/parse_text_proto.h"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)

cc_library(
name = "import_utils",
srcs = ["utils/import_utils.cc"],
hdrs = ["utils/import_utils.h"],
deps = [
":error_util",
":parse_text_proto",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -598,6 +607,7 @@ cc_library(
srcs = ["utils/mangling_util.cc"],
hdrs = ["utils/mangling_util.h"],
deps = [
":parse_text_proto",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
Expand Down
43 changes: 13 additions & 30 deletions tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc
Expand Up @@ -19,59 +19,42 @@ limitations under the License.
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace {
// Error collector that simply ignores errors reported.
class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector {
public:
void AddError(int line, int column, const std::string& message) override {}
};

inline llvm::StringRef StringViewToRef(absl::string_view view) {
return {view.data(), view.size()};
}
} // namespace

namespace tensorflow {

Status LoadProtoFromBuffer(absl::string_view input,
tensorflow::protobuf::Message* proto) {
tensorflow::protobuf::TextFormat::Parser parser;
// Don't produce errors when attempting to parse text format as it would fail
// when the input is actually a binary file.
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
protobuf::MessageLite* proto) {
// Attempt to parse as text.
tensorflow::protobuf::io::ArrayInputStream input_stream(input.data(),
input.size());
if (parser.Parse(&input_stream, proto)) {
return Status::OK();
}
if (ParseTextProto(input, "", proto).ok()) return Status::OK();

// Else attempt to parse as binary.
proto->Clear();
tensorflow::protobuf::io::ArrayInputStream binary_stream(input.data(),
input.size());
if (proto->ParseFromZeroCopyStream(&binary_stream)) {
return Status::OK();
}
protobuf::io::ArrayInputStream binary_stream(input.data(), input.size());
if (proto->ParseFromZeroCopyStream(&binary_stream)) return Status::OK();

LOG(ERROR) << "Error parsing Protobuf";
return errors::InvalidArgument("Could not parse input proto");
}

Status LoadProtoFromFile(absl::string_view input_filename,
tensorflow::protobuf::Message* proto) {
auto file_or_err =
protobuf::MessageLite* proto) {
const auto file_or_err =
llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename));
if (std::error_code error = file_or_err.getError())
if (std::error_code error = file_or_err.getError()) {
return errors::InvalidArgument("Could not open input file");
}

auto& input_file = *file_or_err;
const auto& input_file = *file_or_err;
absl::string_view content(input_file->getBufferStart(),
input_file->getBufferSize());

return LoadProtoFromBuffer(content, proto);
}

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/tensorflow/utils/import_utils.h
Expand Up @@ -25,12 +25,12 @@ namespace tensorflow {
// Reads text (.pbtext) or binary (.pb) format of a proto message from the given
// buffer. Returns error status of the file is not found or malformed proto.
Status LoadProtoFromBuffer(absl::string_view input,
tensorflow::protobuf::Message* proto);
tensorflow::protobuf::MessageLite* proto);

// Reads text (.pbtext) or binary (.pb) format of a proto message from the given
// file path. Returns error status of the file is not found or malformed proto.
Status LoadProtoFromFile(absl::string_view input_filename,
tensorflow::protobuf::Message* proto);
tensorflow::protobuf::MessageLite* proto);

} // namespace tensorflow

Expand Down
31 changes: 4 additions & 27 deletions tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
Expand All @@ -26,21 +27,12 @@ limitations under the License.
namespace tensorflow {
namespace mangling_util {
namespace {

const char kAttributePrefix[] = "tf.";
const char kDataTypePrefix[] = "tfdtype$";
const char kTensorShapePrefix[] = "tfshape$";
const char kTensorPrefix[] = "tftensor$";

// Sets output to the given input with 'prefix' stripped, or return an error if
// the prefix did not exist.
Status ConsumePrefix(absl::string_view str, absl::string_view prefix,
absl::string_view* output) {
if (absl::StartsWith(str, prefix)) {
*output = str.substr(prefix.size());
return Status::OK();
}
return errors::FailedPrecondition("Not a mangled string");
}
} // namespace

string MangleAttributeName(absl::string_view str) {
Expand Down Expand Up @@ -73,30 +65,15 @@ string MangleShape(const TensorShapeProto& shape) {
}

Status DemangleShape(absl::string_view str, TensorShapeProto* proto) {
absl::string_view pbtxt;
TF_RETURN_IF_ERROR(ConsumePrefix(str, kTensorShapePrefix, &pbtxt));
tensorflow::protobuf::io::ArrayInputStream input_stream(pbtxt.data(),
pbtxt.size());
if (!tensorflow::protobuf::TextFormat::Parse(&input_stream, proto)) {
return errors::FailedPrecondition(
"Could not parse TFTensorShape mangled proto");
}
return Status::OK();
return ParseTextProto(str, kTensorShapePrefix, proto);
}

string MangleTensor(const TensorProto& tensor) {
return absl::StrCat(kTensorPrefix, tensor.ShortDebugString());
}

Status DemangleTensor(absl::string_view str, TensorProto* proto) {
absl::string_view pbtxt;
TF_RETURN_IF_ERROR(ConsumePrefix(str, kTensorPrefix, &pbtxt));
tensorflow::protobuf::io::ArrayInputStream input_stream(pbtxt.data(),
pbtxt.size());
if (!tensorflow::protobuf::TextFormat::Parse(&input_stream, proto)) {
return errors::FailedPrecondition("Could not parse TFTensor mangled proto");
}
return Status::OK();
return ParseTextProto(str, kTensorPrefix, proto);
}

string MangleDataType(const DataType& dtype) {
Expand Down
74 changes: 74 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.cc
@@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h"

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

#ifndef TENSORFLOW_LITE_PROTOS
namespace {
// Error collector that simply ignores errors reported.
class NoOpErrorCollector : public protobuf::io::ErrorCollector {
public:
void AddError(int line, int column, const std::string& message) override {}
};
} // namespace
#endif // TENSORFLOW_LITE_PROTOS

Status ConsumePrefix(absl::string_view str, absl::string_view prefix,
absl::string_view* output) {
if (absl::StartsWith(str, prefix)) {
*output = str.substr(prefix.size());
return Status::OK();
}
return errors::NotFound("No prefix \"", prefix, "\" in \"", str, "\"");
}

Status ParseTextProto(absl::string_view text_proto,
absl::string_view prefix_to_strip,
protobuf::MessageLite* parsed_proto) {
#ifndef TENSORFLOW_LITE_PROTOS
protobuf::TextFormat::Parser parser;
// Don't produce errors when attempting to parse text format as it would fail
// when the input is actually a binary file.
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
// Attempt to parse as text.
absl::string_view text_proto_without_prefix = text_proto;
if (!prefix_to_strip.empty()) {
TF_RETURN_IF_ERROR(
ConsumePrefix(text_proto, prefix_to_strip, &text_proto_without_prefix));
}
protobuf::io::ArrayInputStream input_stream(text_proto_without_prefix.data(),
text_proto_without_prefix.size());
if (parser.Parse(&input_stream,
tensorflow::down_cast<protobuf::Message*>(parsed_proto))) {
return Status::OK();
}
parsed_proto->Clear();
return errors::InvalidArgument("Could not parse text proto: ", text_proto);
#else
return errors::Unavailable("Cannot parse text protos on mobile.");
#endif // TENSORFLOW_LITE_PROTOS
}

} // namespace tensorflow
39 changes: 39 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h
@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARSE_TEXT_PROTO_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARSE_TEXT_PROTO_H_

#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

// Sets output to the given input with `prefix` stripped, or returns an error if
// the prefix doesn't exist.
Status ConsumePrefix(absl::string_view str, absl::string_view prefix,
absl::string_view* output);

// Strips `prefix_to_strip` from `text_proto`, parses, and returns the parsed
// proto.
Status ParseTextProto(absl::string_view text_proto,
absl::string_view prefix_to_strip,
protobuf::MessageLite* parsed_proto);

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARSE_TEXT_PROTO_H_

0 comments on commit 29c1d30

Please sign in to comment.