Skip to content

Commit 8425aec

Browse files
committed
add loading TensorFlow net from memory buffer
add a corresponding test
1 parent f9ac166 commit 8425aec

File tree

5 files changed

+129
-7
lines changed

5 files changed

+129
-7
lines changed

modules/dnn/include/opencv2/dnn/dnn.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
639639
*/
640640
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
641641

642+
/** @brief Reads a network model stored in Tensorflow model in memory
643+
* @defails This is an overloaded member function, provided for convenience.
644+
* It differs from the above function only in what argument(s) it accepts.
645+
* @param bufferModel buffer containing the content of the pb file
646+
* @param lenModel len of bufferModel
647+
* @param bufferConfig buffer containing the content of the pbtxt file
648+
* @param lenConfig len of bufferConfig
649+
*/
650+
CV_EXPORTS_W Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
651+
const char *bufferConfig = NULL, size_t lenConfig = 0);
652+
642653
/** @brief Reads a network model stored in Torch model file.
643654
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
644655
*/

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
449449
class TFImporter : public Importer {
450450
public:
451451
TFImporter(const char *model, const char *config = NULL);
452+
TFImporter(const char *dataModel, size_t lenModel,
453+
const char *dataConfig = NULL, size_t lenConfig = 0);
454+
452455
void populateNet(Net dstNet);
453456
~TFImporter() {}
454457

@@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
479482
ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
480483
}
481484

485+
TFImporter::TFImporter(const char *dataModel, size_t lenModel,
486+
const char *dataConfig, size_t lenConfig)
487+
{
488+
if (dataModel != NULL && lenModel > 0)
489+
ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
490+
if (dataConfig != NULL && lenConfig > 0)
491+
ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
492+
}
493+
482494
void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
483495
{
484496
MatShape shape;
@@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
13261338
return net;
13271339
}
13281340

1341+
Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
1342+
const char* bufferConfig, size_t lenConfig)
1343+
{
1344+
TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
1345+
Net net;
1346+
importer.populateNet(net);
1347+
return net;
1348+
}
1349+
13291350
CV__DNN_EXPERIMENTAL_NS_END
13301351
}} // namespace

modules/dnn/src/tensorflow/tf_io.cpp

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,56 @@ bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
5252
return success;
5353
}
5454

55+
bool ReadProtoFromBinaryBufferTF(const char* data, size_t len, Message* proto) {
56+
ZeroCopyInputStream* raw_input = new ArrayInputStream(data, len);
57+
CodedInputStream* coded_input = new CodedInputStream(raw_input);
58+
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
59+
60+
bool success = proto->ParseFromCodedStream(coded_input);
61+
62+
delete coded_input;
63+
delete raw_input;
64+
return success;
65+
}
66+
5567
bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
5668
std::ifstream fs(filename, std::ifstream::in);
5769
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
5870
IstreamInputStream input(&fs);
5971
bool success = google::protobuf::TextFormat::Parse(&input, proto);
6072
fs.close();
73+
74+
return success;
75+
}
76+
77+
bool ReadProtoFromTextBufferTF(const char* data, size_t len, Message* proto) {
78+
ArrayInputStream input(data, len);
79+
bool success = google::protobuf::TextFormat::Parse(&input, proto);
6180
return success;
6281
}
6382

6483
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
65-
tensorflow::GraphDef* param) {
66-
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
67-
<< "Failed to parse GraphDef file: " << param_file;
84+
tensorflow::GraphDef* param) {
85+
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
86+
<< "Failed to parse GraphDef file: " << param_file;
87+
}
88+
89+
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
90+
tensorflow::GraphDef* param) {
91+
CHECK(ReadProtoFromBinaryBufferTF(data, len, param))
92+
<< "Failed to parse GraphDef buffer";
6893
}
6994

7095
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
7196
tensorflow::GraphDef* param) {
72-
CHECK(ReadProtoFromTextFileTF(param_file, param))
73-
<< "Failed to parse GraphDef file: " << param_file;
97+
CHECK(ReadProtoFromTextFileTF(param_file, param))
98+
<< "Failed to parse GraphDef file: " << param_file;
99+
}
100+
101+
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
102+
tensorflow::GraphDef* param) {
103+
CHECK(ReadProtoFromTextBufferTF(data, len, param))
104+
<< "Failed to parse GraphDef buffer";
74105
}
75106

76107
}

modules/dnn/src/tensorflow/tf_io.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
2525
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
2626
tensorflow::GraphDef* param);
2727

28+
// Read parameters from a memory buffer into a GraphDef proto message.
29+
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
30+
tensorflow::GraphDef* param);
31+
32+
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
33+
tensorflow::GraphDef* param);
34+
2835
}
2936
}
3037

modules/dnn/test/test_tf_importer.cpp

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,61 @@ TEST(Test_TensorFlow, inception_accuracy)
6969
normAssert(ref, out);
7070
}
7171

72+
static bool readFileInMemory(const string& filename, char** data, size_t* len)
73+
{
74+
if (len == NULL || data == NULL)
75+
return false;
76+
77+
FILE *f = fopen(filename.c_str(), "rb");
78+
if (f == NULL)
79+
return false;
80+
81+
fseek(f, 0, SEEK_END);
82+
*len = ftell(f);
83+
fseek(f, 0, SEEK_SET);
84+
85+
*data = (char*)malloc(*len + 1);
86+
(void)fread(*data, *len, 1, f);
87+
fclose(f);
88+
89+
return true;
90+
}
91+
7292
static std::string path(const std::string& file)
7393
{
7494
return findDataFile("dnn/tensorflow/" + file, false);
7595
}
7696

7797
static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
78-
double l1 = 1e-5, double lInf = 1e-4)
98+
double l1 = 1e-5, double lInf = 1e-4,
99+
bool memoryLoad = false)
79100
{
80101
std::string netPath = path(prefix + "_net.pb");
81102
std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
82103
std::string inpPath = path(prefix + "_in.npy");
83104
std::string outPath = path(prefix + "_out.npy");
84105

85-
Net net = readNetFromTensorflow(netPath, netConfig);
106+
Net net;
107+
if (memoryLoad)
108+
{
109+
// Load files into a memory buffers
110+
char *dataModel = NULL;
111+
size_t lenModel = 0;
112+
ASSERT_TRUE(readFileInMemory(netPath, &dataModel, &lenModel));
113+
114+
char *dataConfig = NULL;
115+
size_t lenConfig = 0;
116+
if (hasText)
117+
ASSERT_TRUE(readFileInMemory(netConfig, &dataConfig, &lenConfig));
118+
119+
net = readNetFromTensorflow(dataModel, lenModel, dataConfig, lenConfig);
120+
ASSERT_FALSE(net.empty());
121+
122+
free(dataModel);
123+
free(dataConfig);
124+
}
125+
else
126+
net = readNetFromTensorflow(netPath, netConfig);
86127

87128
cv::Mat input = blobFromNPY(inpPath);
88129
cv::Mat target = blobFromNPY(outPath);
@@ -216,4 +257,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
216257
runTensorFlowNet("resize_nearest_neighbor");
217258
}
218259

260+
TEST(Test_TensorFlow, memory_read)
261+
{
262+
double l1 = 1e-5;
263+
double lInf = 1e-4;
264+
runTensorFlowNet("lstm", true, l1, lInf, true);
265+
266+
runTensorFlowNet("batch_norm", false, l1, lInf, true);
267+
runTensorFlowNet("fused_batch_norm", false, l1, lInf, true);
268+
runTensorFlowNet("batch_norm_text", true, l1, lInf, true);
269+
}
270+
219271
}

0 commit comments

Comments
 (0)