Skip to content

Commit

Permalink
add loading TensorFlow net from memory buffer
Browse files Browse the repository at this point in the history
add a corresponding test
  • Loading branch information
r2d3 committed Nov 2, 2017
1 parent f9ac166 commit 8425aec
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 7 deletions.
11 changes: 11 additions & 0 deletions modules/dnn/include/opencv2/dnn/dnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*/
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());

/** @brief Reads a network model stored in Tensorflow model in memory
* @defails This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferModel buffer containing the content of the pb file
* @param lenModel len of bufferModel
* @param bufferConfig buffer containing the content of the pbtxt file
* @param lenConfig len of bufferConfig
*/
CV_EXPORTS_W Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
const char *bufferConfig = NULL, size_t lenConfig = 0);

/** @brief Reads a network model stored in Torch model file.
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
*/
Expand Down
21 changes: 21 additions & 0 deletions modules/dnn/src/tensorflow/tf_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
class TFImporter : public Importer {
public:
TFImporter(const char *model, const char *config = NULL);
TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig = NULL, size_t lenConfig = 0);

void populateNet(Net dstNet);
~TFImporter() {}

Expand Down Expand Up @@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
}

TFImporter::TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig, size_t lenConfig)
{
if (dataModel != NULL && lenModel > 0)
ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
if (dataConfig != NULL && lenConfig > 0)
ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
}

void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
{
MatShape shape;
Expand Down Expand Up @@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
return net;
}

Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
const char* bufferConfig, size_t lenConfig)
{
TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
Net net;
importer.populateNet(net);
return net;
}

CV__DNN_EXPERIMENTAL_NS_END
}} // namespace
41 changes: 36 additions & 5 deletions modules/dnn/src/tensorflow/tf_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,56 @@ bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
return success;
}

bool ReadProtoFromBinaryBufferTF(const char* data, size_t len, Message* proto) {
ZeroCopyInputStream* raw_input = new ArrayInputStream(data, len);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

bool success = proto->ParseFromCodedStream(coded_input);

delete coded_input;
delete raw_input;
return success;
}

bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
IstreamInputStream input(&fs);
bool success = google::protobuf::TextFormat::Parse(&input, proto);
fs.close();

return success;
}

bool ReadProtoFromTextBufferTF(const char* data, size_t len, Message* proto) {
ArrayInputStream input(data, len);
bool success = google::protobuf::TextFormat::Parse(&input, proto);
return success;
}

void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
}

void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryBufferTF(data, len, param))
<< "Failed to parse GraphDef buffer";
}

void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
CHECK(ReadProtoFromTextFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
}

void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextBufferTF(data, len, param))
<< "Failed to parse GraphDef buffer";
}

}
Expand Down
7 changes: 7 additions & 0 deletions modules/dnn/src/tensorflow/tf_io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param);

// Read parameters from a memory buffer into a GraphDef proto message.
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);

void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);

}
}

Expand Down
56 changes: 54 additions & 2 deletions modules/dnn/test/test_tf_importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,61 @@ TEST(Test_TensorFlow, inception_accuracy)
normAssert(ref, out);
}

static bool readFileInMemory(const string& filename, char** data, size_t* len)
{
if (len == NULL || data == NULL)
return false;

FILE *f = fopen(filename.c_str(), "rb");
if (f == NULL)
return false;

fseek(f, 0, SEEK_END);
*len = ftell(f);
fseek(f, 0, SEEK_SET);

*data = (char*)malloc(*len + 1);
(void)fread(*data, *len, 1, f);
fclose(f);

return true;
}

static std::string path(const std::string& file)
{
return findDataFile("dnn/tensorflow/" + file, false);
}

static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
double l1 = 1e-5, double lInf = 1e-4)
double l1 = 1e-5, double lInf = 1e-4,
bool memoryLoad = false)
{
std::string netPath = path(prefix + "_net.pb");
std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
std::string inpPath = path(prefix + "_in.npy");
std::string outPath = path(prefix + "_out.npy");

Net net = readNetFromTensorflow(netPath, netConfig);
Net net;
if (memoryLoad)
{
// Load files into a memory buffers
char *dataModel = NULL;
size_t lenModel = 0;
ASSERT_TRUE(readFileInMemory(netPath, &dataModel, &lenModel));

char *dataConfig = NULL;
size_t lenConfig = 0;
if (hasText)
ASSERT_TRUE(readFileInMemory(netConfig, &dataConfig, &lenConfig));

net = readNetFromTensorflow(dataModel, lenModel, dataConfig, lenConfig);
ASSERT_FALSE(net.empty());

free(dataModel);
free(dataConfig);
}
else
net = readNetFromTensorflow(netPath, netConfig);

cv::Mat input = blobFromNPY(inpPath);
cv::Mat target = blobFromNPY(outPath);
Expand Down Expand Up @@ -216,4 +257,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
runTensorFlowNet("resize_nearest_neighbor");
}

TEST(Test_TensorFlow, memory_read)
{
double l1 = 1e-5;
double lInf = 1e-4;
runTensorFlowNet("lstm", true, l1, lInf, true);

runTensorFlowNet("batch_norm", false, l1, lInf, true);
runTensorFlowNet("fused_batch_norm", false, l1, lInf, true);
runTensorFlowNet("batch_norm_text", true, l1, lInf, true);
}

}

0 comments on commit 8425aec

Please sign in to comment.