From fb4bf52aa59d5f8665053491c523f3bcb2321759 Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Fri, 1 Sep 2017 23:07:08 -0700 Subject: [PATCH] Update the speed benchmark code Summary: (for TIR demo cases) Closes https://github.com/caffe2/caffe2/pull/1160 Differential Revision: D5761679 Pulled By: Yangqing fbshipit-source-id: 53b6c7fd098a394eba51baeac1e70371bcddf360 --- CMakeLists.txt | 4 +- caffe2/binaries/speed_benchmark.cc | 93 +++++++++++++++++++++++++++++- caffe2/core/common.h | 10 ++++ caffe2/core/flags.cc | 2 +- cmake/ProtoBuf.cmake | 1 + 5 files changed, 105 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8054d98f167..95f2482d012 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -121,7 +121,9 @@ endif() # added to the list of include directories, prefixing # PROJECT_SOURCE_DIR means this source tree always takes precedence. include_directories(BEFORE ${PROJECT_SOURCE_DIR}) -include_directories(BEFORE ${PROJECT_SOURCE_DIR}/build_host_protoc/include) +if (Caffe2_IS_CUSTOM_PROTOBUF) + include_directories(BEFORE ${PROJECT_SOURCE_DIR}/build_host_protoc/include) +endif() # Prefix path to generated Caffe2 headers. # These need to take precedence over their empty counterparts located diff --git a/caffe2/binaries/speed_benchmark.cc b/caffe2/binaries/speed_benchmark.cc index 5f58169f67e..e5b0e3d3b43 100644 --- a/caffe2/binaries/speed_benchmark.cc +++ b/caffe2/binaries/speed_benchmark.cc @@ -1,27 +1,114 @@ +#include + #include "caffe2/core/init.h" #include "caffe2/core/operator.h" #include "caffe2/proto/caffe2.pb.h" #include "caffe2/utils/proto_utils.h" +#include "caffe2/utils/string_utils.h" #include "caffe2/core/logging.h" CAFFE2_DEFINE_string(net, "", "The given net to benchmark."); CAFFE2_DEFINE_string(init_net, "", "The given net to initialize any parameters."); +CAFFE2_DEFINE_string(input, "", + "Input that is needed for running the network. If " + "multiple input needed, use comma separated string."); +CAFFE2_DEFINE_string(input_file, "", + "Input file that contain the serialized protobuf for " + "the input blobs. If multiple input needed, use comma " + "separated string. Must have the same number of items " + "as input does."); +CAFFE2_DEFINE_string(input_dims, "", + "Alternate to input_files, if all inputs are simple " + "float TensorCPUs, specify the dimension using comma " + "separated numbers. If multiple input needed, use " + "semicolon to separate the dimension of different " + "tensors."); +CAFFE2_DEFINE_string(output, "", + "Output that should be dumped after the execution " + "finishes. If multiple outputs are needed, use comma " + "separated string. If you want to dump everything, pass " + "'*' as the output value."); +CAFFE2_DEFINE_string(output_folder, "", + "The folder that the output should be written to. This " + "folder must already exist in the file system."); CAFFE2_DEFINE_int(warmup, 0, "The number of iterations to warm up."); CAFFE2_DEFINE_int(iter, 10, "The number of iterations to run."); CAFFE2_DEFINE_bool(run_individual, false, "Whether to benchmark individual operators."); +using std::string; +using std::unique_ptr; +using std::vector; + int main(int argc, char** argv) { caffe2::GlobalInit(&argc, &argv); - std::unique_ptr workspace(new caffe2::Workspace()); + unique_ptr workspace(new caffe2::Workspace()); + // Load input. + if (caffe2::FLAGS_input.size()) { + vector input_names = caffe2::split(',', caffe2::FLAGS_input); + if (caffe2::FLAGS_input_file.size()) { + vector input_files = caffe2::split(',', caffe2::FLAGS_input_file); + CAFFE_ENFORCE_EQ( + input_names.size(), input_files.size(), + "Input name and file should have the same number."); + for (int i = 0; i < input_names.size(); ++i) { + caffe2::BlobProto blob_proto; + CAFFE_ENFORCE(caffe2::ReadProtoFromFile(input_files[i], &blob_proto)); + workspace->GetBlob(input_names[i])->Deserialize(blob_proto); + } + } else if (caffe2::FLAGS_input_dims.size()) { + vector input_dims_list = caffe2::split(';', caffe2::FLAGS_input_dims); + CAFFE_ENFORCE_EQ( + input_names.size(), input_dims_list.size(), + "Input name and dims should have the same number of items."); + for (int i = 0; i < input_names.size(); ++i) { + vector input_dims_str = caffe2::split(',', input_dims_list[i]); + vector input_dims; + for (const string& s : input_dims_str) { + input_dims.push_back(caffe2::stoi(s)); + } + caffe2::TensorCPU* tensor = + workspace->GetBlob(input_names[i])->GetMutable(); + tensor->Reshape(input_dims); + tensor->mutable_data(); + } + } else { + CAFFE_THROW("You requested input tensors, but neither input_file nor " + "input_dims is set."); + } + } // Run initialization network. caffe2::NetDef net_def; CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_init_net, &net_def)); CAFFE_ENFORCE(workspace->RunNetOnce(net_def)); + // Run main network. CAFFE_ENFORCE(ReadProtoFromFile(caffe2::FLAGS_net, &net_def)); caffe2::NetBase* net = workspace->CreateNet(net_def); CHECK_NOTNULL(net); - CAFFE_ENFORCE(net->Run()); - net->TEST_Benchmark(caffe2::FLAGS_warmup, caffe2::FLAGS_iter, caffe2::FLAGS_run_individual); + net->TEST_Benchmark( + caffe2::FLAGS_warmup, + caffe2::FLAGS_iter, + caffe2::FLAGS_run_individual); + + string output_prefix = + caffe2::FLAGS_output_folder.size() + ? caffe2::FLAGS_output_folder + "/" + : ""; + if (caffe2::FLAGS_output.size()) { + vector output_names = caffe2::split(',', caffe2::FLAGS_output); + if (caffe2::FLAGS_output == "*") { + output_names = workspace->Blobs(); + } + for (const string& name : output_names) { + CAFFE_ENFORCE( + workspace->HasBlob(name), + "You requested a non-existing blob: ", + name); + string serialized = workspace->GetBlob(name)->Serialize(name); + string output_filename = output_prefix + name; + caffe2::WriteStringToFile(serialized, output_filename.c_str()); + } + } + return 0; } diff --git a/caffe2/core/common.h b/caffe2/core/common.h index aa0e0bbdb37..4bd2235b693 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -159,6 +159,7 @@ make_unique(Args&&...) = delete; // to_string implementation for Android related stuff. #ifndef __ANDROID__ using std::to_string; +using std::stoi; #else template std::string to_string(T value) @@ -167,6 +168,15 @@ std::string to_string(T value) os << value; return os.str(); } + +inline int stoi(const string& str) +{ + std::stringstream ss; + int n = 0; + ss << str; + ss >> n; + return n; +} #endif // dynamic cast reroute: if RTTI is disabled, go to reinterpret_cast diff --git a/caffe2/core/flags.cc b/caffe2/core/flags.cc index dc0f0e93b7e..2df6edc7848 100644 --- a/caffe2/core/flags.cc +++ b/caffe2/core/flags.cc @@ -63,7 +63,7 @@ bool ParseCaffeCommandLineFlags(int* pargc, char*** pargv) { for (int i = 1; i < *pargc; ++i) { string arg(argv[i]); - if (arg == "--help") { + if (arg.find("--help") != string::npos) { // Print the help message, and quit. std::cout << UsageMessage() << std::endl; std::cout << "Arguments: " << std::endl; diff --git a/cmake/ProtoBuf.cmake b/cmake/ProtoBuf.cmake index 20640676537..17b9eba1686 100644 --- a/cmake/ProtoBuf.cmake +++ b/cmake/ProtoBuf.cmake @@ -26,6 +26,7 @@ function(custom_protobuf_find) message(STATUS "Using protobuf compiler ${PROTOBUF_PROTOC_EXECUTABLE}.") endif() set(Protobuf_FOUND TRUE PARENT_SCOPE) + set(Caffe2_IS_CUSTOM_PROTOBUF TRUE PARENT_SCOPE) endfunction() if (WIN32)