Skip to content

Commit

Permalink
add loading TensorFlow/Caffe net from memory buffer
Browse files Browse the repository at this point in the history
add a corresponding test
  • Loading branch information
r2d3 committed Nov 20, 2017
1 parent 6e4f943 commit f723ced
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 43 deletions.
22 changes: 22 additions & 0 deletions modules/dnn/include/opencv2/dnn/dnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,11 +634,33 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*/
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());

/** @brief Reads a network model stored in Caffe model in memory.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferProto buffer containing the content of the .prototxt file
* @param lenProto length of bufferProto
* @param bufferModel buffer containing the content of the .caffemodel file
* @param lenModel length of bufferModel
*/
CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel = NULL, size_t lenModel = 0);

/** @brief Reads a network model stored in Tensorflow model file.
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
*/
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());

/** @brief Reads a network model stored in Tensorflow model in memory.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferModel buffer containing the content of the pb file
* @param lenModel length of bufferModel
* @param bufferConfig buffer containing the content of the pbtxt file
* @param lenConfig length of bufferConfig
*/
CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
const char *bufferConfig = NULL, size_t lenConfig = 0);

/** @brief Reads a network model stored in Torch model file.
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
*/
Expand Down
20 changes: 20 additions & 0 deletions modules/dnn/src/caffe/caffe_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ class CaffeImporter : public Importer
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
}

CaffeImporter(const char *dataProto, size_t lenProto,
const char *dataModel, size_t lenModel)
{
CV_TRACE_FUNCTION();

ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);

if (dataModel != NULL && lenModel > 0)
ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);
}

void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
{
const Reflection *refl = msg.GetReflection();
Expand Down Expand Up @@ -400,6 +411,15 @@ Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String
return net;
}

Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel, size_t lenModel)
{
CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);
Net net;
caffeImporter.populateNet(net);
return net;
}

#endif //HAVE_PROTOBUF

CV__DNN_EXPERIMENTAL_NS_END
Expand Down
45 changes: 34 additions & 11 deletions modules/dnn/src/caffe/caffe_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,28 +1108,37 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {

const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.

bool ReadProtoFromBinary(ZeroCopyInputStream* input, Message *proto) {
CodedInputStream coded_input(input);
coded_input.SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

return proto->ParseFromCodedStream(&coded_input);
}

bool ReadProtoFromTextFile(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
IstreamInputStream input(&fs);
bool success = google::protobuf::TextFormat::Parse(&input, proto);
fs.close();
return success;
return google::protobuf::TextFormat::Parse(&input, proto);
}

bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
IstreamInputStream raw_input(&fs);

return ReadProtoFromBinary(&raw_input, proto);
}

bool ReadProtoFromTextBuffer(const char* data, size_t len, Message* proto) {
ArrayInputStream input(data, len);
return google::protobuf::TextFormat::Parse(&input, proto);
}

bool success = proto->ParseFromCodedStream(coded_input);

delete coded_input;
delete raw_input;
fs.close();
return success;
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, Message* proto) {
ArrayInputStream raw_input(data, len);
return ReadProtoFromBinary(&raw_input, proto);
}

void ReadNetParamsFromTextFileOrDie(const char* param_file,
Expand All @@ -1139,13 +1148,27 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
UpgradeNetAsNeeded(param_file, param);
}

void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
NetParameter* param) {
CHECK(ReadProtoFromTextBuffer(data, len, param))
<< "Failed to parse NetParameter buffer";
UpgradeNetAsNeeded("memory buffer", param);
}

void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
NetParameter* param) {
CHECK(ReadProtoFromBinaryFile(param_file, param))
<< "Failed to parse NetParameter file: " << param_file;
UpgradeNetAsNeeded(param_file, param);
}

void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
NetParameter* param) {
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
<< "Failed to parse NetParameter buffer";
UpgradeNetAsNeeded("memory buffer", param);
}

}
}
#endif
12 changes: 12 additions & 0 deletions modules/dnn/src/caffe/caffe_io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
caffe::NetParameter* param);

// Read parameters from a memory buffer into a NetParammeter proto message.
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
caffe::NetParameter* param);
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
caffe::NetParameter* param);

// Utility functions used internally by Caffe and TensorFlow loaders
bool ReadProtoFromTextFile(const char* filename, ::google::protobuf::Message* proto);
bool ReadProtoFromBinaryFile(const char* filename, ::google::protobuf::Message* proto);
bool ReadProtoFromTextBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);

}
}
#endif
Expand Down
21 changes: 21 additions & 0 deletions modules/dnn/src/tensorflow/tf_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
class TFImporter : public Importer {
public:
TFImporter(const char *model, const char *config = NULL);
TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig = NULL, size_t lenConfig = 0);

void populateNet(Net dstNet);
~TFImporter() {}

Expand Down Expand Up @@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
}

TFImporter::TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig, size_t lenConfig)
{
if (dataModel != NULL && lenModel > 0)
ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
if (dataConfig != NULL && lenConfig > 0)
ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
}

void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
{
MatShape shape;
Expand Down Expand Up @@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
return net;
}

Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
const char* bufferConfig, size_t lenConfig)
{
TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
Net net;
importer.populateNet(net);
return net;
}

CV__DNN_EXPERIMENTAL_NS_END
}} // namespace
44 changes: 16 additions & 28 deletions modules/dnn/src/tensorflow/tf_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Implementation of various functions which are related to Tensorflow models readi

#include "graph.pb.h"
#include "tf_io.hpp"
#include "../caffe/caffe_io.hpp"
#include "../caffe/glog_emulator.hpp"

namespace cv {
Expand All @@ -36,41 +37,28 @@ using namespace ::google::protobuf::io;

const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.

// TODO: remove Caffe duplicate
bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

bool success = proto->ParseFromCodedStream(coded_input);

delete coded_input;
delete raw_input;
fs.close();
return success;
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFile(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
}

bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
IstreamInputStream input(&fs);
bool success = google::protobuf::TextFormat::Parse(&input, proto);
fs.close();
return success;
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
<< "Failed to parse GraphDef buffer";
}

void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
CHECK(ReadProtoFromTextFile(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
}

void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextBuffer(data, len, param))
<< "Failed to parse GraphDef buffer";
}

}
Expand Down
7 changes: 7 additions & 0 deletions modules/dnn/src/tensorflow/tf_io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param);

// Read parameters from a memory buffer into a GraphDef proto message.
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);

void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);

}
}

Expand Down
37 changes: 35 additions & 2 deletions modules/dnn/test/test_caffe_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ static std::string _tf(TString filename)
return (getOpenCVExtraDir() + "/dnn/") + filename;
}

TEST(Test_Caffe, memory_read)
{
const string proto = findDataFile("dnn/bvlc_googlenet.prototxt", false);
const string model = findDataFile("dnn/bvlc_googlenet.caffemodel", false);

string dataProto;
ASSERT_TRUE(readFileInMemory(proto, dataProto));
string dataModel;
ASSERT_TRUE(readFileInMemory(model, dataModel));

Net net = readNetFromCaffe(dataProto.c_str(), dataProto.size());
ASSERT_FALSE(net.empty());

Net net2 = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
dataModel.c_str(), dataModel.size());
ASSERT_FALSE(net2.empty());
}

TEST(Test_Caffe, read_gtsrb)
{
Net net = readNetFromCaffe(_tf("gtsrb.prototxt"));
Expand All @@ -67,13 +85,26 @@ TEST(Test_Caffe, read_googlenet)
ASSERT_FALSE(net.empty());
}

TEST(Reproducibility_AlexNet, Accuracy)
typedef testing::TestWithParam<tuple<bool> > Reproducibility_AlexNet;
TEST_P(Reproducibility_AlexNet, Accuracy)
{
bool readFromMemory = get<0>(GetParam());
Net net;
{
const string proto = findDataFile("dnn/bvlc_alexnet.prototxt", false);
const string model = findDataFile("dnn/bvlc_alexnet.caffemodel", false);
net = readNetFromCaffe(proto, model);
if (readFromMemory)
{
string dataProto;
ASSERT_TRUE(readFileInMemory(proto, dataProto));
string dataModel;
ASSERT_TRUE(readFileInMemory(model, dataModel));

net = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
dataModel.c_str(), dataModel.size());
}
else
net = readNetFromCaffe(proto, model);
ASSERT_FALSE(net.empty());
}

Expand All @@ -86,6 +117,8 @@ TEST(Reproducibility_AlexNet, Accuracy)
normAssert(ref, out);
}

INSTANTIATE_TEST_CASE_P(Test_Caffe, Reproducibility_AlexNet, testing::Values(true, false));

#if !defined(_WIN32) || defined(_WIN64)
TEST(Reproducibility_FCN, Accuracy)
{
Expand Down
19 changes: 19 additions & 0 deletions modules/dnn/test/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,23 @@ inline void normAssert(cv::InputArray ref, cv::InputArray test, const char *comm
EXPECT_LE(normInf, lInf) << comment;
}

inline bool readFileInMemory(const std::string& filename, std::string& content)
{
std::ios::openmode mode = std::ios::in | std::ios::binary;
std::ifstream ifs(filename.c_str(), mode);
if (!ifs.is_open())
return false;

content.clear();

ifs.seekg(0, std::ios::end);
content.reserve(ifs.tellg());
ifs.seekg(0, std::ios::beg);

content.assign((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());

return true;
}

#endif

0 comments on commit f723ced

Please sign in to comment.