/
caffe2model.cc
129 lines (114 loc) · 3.87 KB
/
caffe2model.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/**
* DeepDetect
* Copyright (c) 2018 Jolibrain
* Author: Julien Chicha
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/
#include "backends/caffe2/caffe2model.h"
#include "mllibstrategy.h"
#include "utils/fileops.hpp"
namespace dd {
/**
 * Builds the Caffe2 model description from the API parameters, then
 * completes it with the files found in the model repository.
 *
 * @param ad     model parameters: "repository" is read unconditionally;
 *               "templates", "predictf", "initf", "corresp" and "weights"
 *               are optional
 * @param adg    forwarded to the MLModel base constructor
 * @param logger forwarded to the MLModel base constructor and used while
 *               scanning the repository
 */
Caffe2Model::Caffe2Model(const APIData &ad, APIData &adg,
                         const std::shared_ptr<spdlog::logger> &logger)
  : MLModel(ad, adg, logger)
{
  // API key -> member receiving the value
  const std::map<std::string, std::string *> names =
    {
      { "predictf", &_predict },
      { "initf", &_init },
      { "corresp", &_corresp },
      { "weights", &_weights }
    };

  // Update from API: only the keys actually present override the members
  for (const auto &it : names) {
    if (ad.has(it.first)) {
      *it.second = ad.get(it.first).get<std::string>();
    }
  }

  // Register repositories
  this->_repo = ad.get("repository").get<std::string>();
  this->_mlmodel_template_repo = ad.has("templates") ?
    ad.get("templates").get<std::string>() : "caffe2"; // Default

  // Report through the logger supplied by the caller. spdlog::get() returns
  // an empty pointer when no logger of that name is registered, so the
  // global "api" logger is only used as a fallback.
  update_from_repository(logger ? logger : spdlog::get("api"));
}
/**
 * Scans the model repository and fills in the file members that are
 * still empty, then reloads the correspondence file.
 *
 * A listed entry is attributed to the first known file name (in map order)
 * that it contains as a substring. Members that already hold a value are
 * left unchanged.
 *
 * @param logger logger used to report a repository listing failure
 * @throws MLLibBadParamException if the repository cannot be listed
 */
void Caffe2Model::update_from_repository(const std::shared_ptr<spdlog::logger> &logger) {

  // Known file name -> member that stores the matching path
  std::map<std::string, std::string *> known_files =
    {
      { "predict_net.pb", &_predict },
      { "init_net.pb", &_init },
      { "corresp.txt", &_corresp },
      { "mean.pb", &_meanfile },
      { "init_state.pb", &_init_state },
      { "dbreader_state.pb", &_dbreader_state },
      { "dbreader_train_state.pb", &_dbreader_train_state },
      { "iter_state.pb", &_iter_state },
      { "lr_state.pb", &_lr_state },
    };

  // Collect the repository content; a truthy return value signals a failure
  std::unordered_set<std::string> repo_entries;
  const bool listing_failed =
    fileops::list_directory(_repo, true, false, false, repo_entries);
  if (listing_failed) {
    std::string msg("error reading or listing Caffe2 models in repository " + _repo);
    logger->error(msg);
    throw MLLibBadParamException(msg);
  }

  for (const std::string &path : repo_entries) {
    for (auto &entry : known_files) {
      // Skip the names that do not appear in this path
      if (path.find(entry.first) == std::string::npos) {
        continue;
      }
      // First match wins; an already initialized member is kept as is
      std::string &target = *entry.second;
      if (target.empty()) {
        target = path;
      }
      break;
    }
  }

  read_corresp_file();
}
void Caffe2Model::get_hcorresp(std::vector<std::string> &clnames) {
int i = 0;
for (std::string &name : clnames) {
name = get_hcorresp(i++);
}
}
/**
 * Dumps a training state into the model repository.
 *
 * Each entry of `blobs` is written to "<repo>/<key>_state.pb" and `init`
 * is serialized to "<repo>/init_state.pb". Existing files are overwritten.
 *
 * @param init  protobuf message serialized into init_state.pb
 * @param blobs file name prefix -> raw content to write
 */
void Caffe2Model::write_state(const google::protobuf::Message &init,
                              const std::map<std::string, std::string> &blobs) {
  // Serialized protobuf is binary data: a text-mode stream may translate
  // newline bytes on some platforms and corrupt the files.
  const auto mode = std::ofstream::out | std::ofstream::binary;

  // Iterate by reference so the (possibly large) blob contents are not copied
  for (const auto &blob : blobs) {
    std::ofstream out(_repo + "/" + blob.first + "_state.pb", mode);
    out << blob.second;
  }

  std::ofstream f(_repo + "/init_state.pb", mode);
  init.SerializeToOstream(&f);
}
/**
 * Lists the files needed to instantiate a template, as a mapping from
 * source file to destination file inside the model repository.
 *
 * The predict net always comes from the template directory. The init net
 * comes from the template directory too, unless external weights are
 * requested, in which case the configured weights file is used as source.
 *
 * @param name             template name (sub-directory of the template repository)
 * @param files            receives the source -> destination entries
 * @param external_weights whether the init net must come from `_weights`
 * @throws MLLibBadParamException if external weights are requested but none
 *         is configured (the predict net entry has already been added)
 */
void Caffe2Model::list_template_files(const std::string &name,
                                      std::map<std::string, std::string> &files,
                                      bool external_weights) {

  // Path manipulation
  const std::string template_dir = this->_mlmodel_template_repo + '/' + name;
  auto template_path = [&template_dir](const std::string &net) {
    return template_dir + '/' + net + ".pbtxt";
  };
  auto repo_path = [this](const std::string &net) {
    return _repo + '/' + net + ".pb";
  };

  // The predict net is always taken from the template
  files[template_path("predict_net")] = repo_path("predict_net");

  if (external_weights) {
    if (_weights.empty()) {
      throw MLLibBadParamException("No external weights specified");
    }
    // User-provided weights replace the template init net
    files[_weights] = repo_path("init_net");
    return;
  }

  files[template_path("init_net")] = repo_path("init_net");
}
}