Skip to content

Commit

Permalink
Merge pull request #688 from fantes/detection_calibration
Browse files Browse the repository at this point in the history
Detection calibration
  • Loading branch information
beniz committed Feb 5, 2020
2 parents 7d47af6 + fa96e93 commit 65fbb36
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 14 deletions.
85 changes: 75 additions & 10 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,18 @@ namespace dd
ad_res.add("learning_rate",this->get_meas("learning_rate"));
APIData ad_out = ad.getobj("parameters").getobj("output");

bool output_logits = false;
if (ad_out.has("logits"))
output_logits = ad_out.get("logits").get<bool>();

boost::shared_ptr<Blob<float>> logits_blob;
if (ad_out.has("logits_blob")) {
std::string logits_blob_name = ad_out.get("logits_blob").get<std::string>();
logits_blob = findBlobByName(net,logits_blob_name);
output_logits = true;
}


if (ad.getobj("parameters").getobj("mllib").has("ignore_label"))
{
int ignore_label = ad.getobj("parameters").getobj("mllib").get("ignore_label").get<int>();
Expand All @@ -1681,6 +1693,7 @@ namespace dd
inputc.reset_dv_test();
std::map<int,std::map<int,std::vector<std::pair<float,int>>>> all_true_pos;
std::map<int,std::map<int,std::vector<std::pair<float,int>>>> all_false_pos;
std::map<int,std::map<int,std::vector<std::vector<float>>>> all_logits_pos;
std::map<int,std::map<int,int>> all_num_pos;
while(true)
{
Expand Down Expand Up @@ -1869,33 +1882,49 @@ namespace dd
else if (inputc._bbox)
{
boost::shared_ptr<Blob<float>> detection_eval = findBlobByName(net,"detection_eval");
if (detection_eval->width() != 5)
int evalWidth = 5;

if ((detection_eval->width() != 5 && !output_logits)
|| ((detection_eval->width() != 5 + _nclasses) && output_logits))
throw MLLibBadParamException("wrong width in bbox result");
if (output_logits && (detection_eval->width() == 5 + _nclasses))
{
this->_logger->info("will ouput raw logits");
evalWidth = 5 + _nclasses;
}

int pos = tresults;
const float *result_vec = detection_eval->cpu_data();
int num_det = detection_eval->height();
for (int k=0;k<num_det;k++)
{
int item_id = static_cast<int>(result_vec[k * 5]);
int label = static_cast<int>(result_vec[k * 5 + 1]);
int item_id = static_cast<int>(result_vec[k * evalWidth]);
int label = static_cast<int>(result_vec[k * evalWidth + 1]);
if (item_id == -1) {
// Special row of storing number of positives for a label.
if (all_num_pos[pos].find(label) == all_num_pos[pos].end()) {
all_num_pos[pos][label] = static_cast<int>(result_vec[k * 5 + 2]);
all_num_pos[pos][label] = static_cast<int>(result_vec[k * evalWidth + 2]);
} else {
all_num_pos[pos][label] += static_cast<int>(result_vec[k * 5 + 2]);
all_num_pos[pos][label] += static_cast<int>(result_vec[k * evalWidth + 2]);
}
} else {
// Normal row storing detection status.
float score = result_vec[k * 5 + 2];
int tp = static_cast<int>(result_vec[k * 5 + 3]);
int fp = static_cast<int>(result_vec[k * 5 + 4]);
float score = result_vec[k * evalWidth + 2];
int tp = static_cast<int>(result_vec[k * evalWidth + 3]);
int fp = static_cast<int>(result_vec[k * evalWidth + 4]);
if (tp == 0 && fp == 0) {
// Ignore such case. It happens when a detection bbox is matched to
// a difficult gt bbox and we don't evaluate on difficult gt bbox.
this->_logger->warn("skipping bbox");
continue;
}
if (output_logits)
{
std::vector<float> logits;
for (int l = 0; l< _nclasses; ++l)
logits.push_back(result_vec[k*evalWidth + 5 + l]);
all_logits_pos[pos][label].push_back(logits);
}
all_true_pos[pos][label].push_back(std::make_pair(score, tp));
all_false_pos[pos][label].push_back(std::make_pair(score, fp));
}
Expand All @@ -1921,6 +1950,14 @@ namespace dd
throw MLLibInternalException("Missing output_blob num_pos: " + std::to_string(i));
const std::map<int, int>& num_pos = all_num_pos.find(i)->second;

std::map<int, std::vector<std::vector<float>>> logits_pos;
if (output_logits)
{
if (all_logits_pos.find(i) == all_logits_pos.end())
throw MLLibInternalException("Missing output_blob logits_pos: " + std::to_string(i));
logits_pos = all_logits_pos.find(i)->second;
}

// Sort true_pos and false_pos with descend scores.
std::vector<APIData> vbad;
for (std::map<int, int>::const_iterator it = num_pos.begin();
Expand All @@ -1941,6 +1978,25 @@ namespace dd
}
const std::vector<std::pair<float, int> >& label_false_pos =
false_pos.find(label)->second;

if (output_logits)
{
const std::vector<std::vector<float>> & label_logits =
logits_pos.find(label)->second;

std::vector<APIData> ll;
for (size_t v = 0; v< label_logits.size(); ++v)
{
APIData l;
std::vector<double> lldbl;
lldbl.insert(lldbl.end(), label_logits[v].begin(), label_logits[v].end());
l.add("logits", lldbl);
ll.push_back(l);
}
lbad.add("all_logits", ll);
}



//XXX: AP computed here, store in apidata instead
std::vector<double> tp_d;
Expand Down Expand Up @@ -2059,6 +2115,16 @@ namespace dd
else if ((!_regression && !_autoencoder)|| _ntargets == 1)
{
double target = dv_labels.at(j);
if (output_logits)
{
std::vector<double> logits;
for (int k=0;k<nout;k++)
{
logits.push_back(logits_blob->cpu_data()[j*scperel+k]);
}
bad.add("logits",logits);
}
std::vector<double> logits;
for (int k=0;k<nout;k++)
{
predictions.push_back(lresults[slot]->cpu_data()[j*scperel+k]);
Expand Down Expand Up @@ -4056,10 +4122,9 @@ namespace dd
template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
boost::shared_ptr<Blob<float>> CaffeLib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::findBlobByName(const caffe::Net<float> *net, const std::string blob_name)
{
const std::vector<int> output_blob_indices = net->output_blob_indices();
for (unsigned int i =0; i<net->blob_names().size(); ++i)
{
if (net->blob_names()[i] == blob_name)
if (net->blob_names()[i] == blob_name)
return net->blobs()[i];
}
return nullptr;
Expand Down
39 changes: 35 additions & 4 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1542,11 +1542,13 @@ namespace dd
std::vector<std::string> preds;
std::vector<std::string> targets;
std::vector<double> confs;
std::vector<std::vector<double>> logits;
for (int i=0;i<batch_size;i++)
{
APIData bad = ad.getobj(std::to_string(i));
std::vector<double> predictions = bad.get("pred").get<std::vector<double>>();
double target = bad.get("target").get<double>();
logits.push_back(bad.get("logits").get<std::vector<double>>());
if (target < 0)
throw OutputConnectorBadParamException("negative supervised discrete target (e.g. wrong use of label_offset ?");
else if (target >= nclasses)
Expand All @@ -1569,6 +1571,18 @@ namespace dd
raw_res.add("truths",targets);
raw_res.add("estimations",preds);
raw_res.add("confidences",confs);
if (logits.size() >0)
{
std::vector<APIData> adlogit;
for (std::vector<double> l: logits)
{
APIData lad;
lad.add("logits",l);
adlogit.push_back(lad);
}
raw_res.add("all_logits",adlogit);
}

return raw_res;
}

Expand Down Expand Up @@ -1632,6 +1646,8 @@ namespace dd
std::vector<std::string> preds;
std::vector<std::string> targets;
std::vector<double> confs;
std::vector<APIData> really_all_logits;
bool output_logits = false;
APIData bad = ad.getobj("0");
int pos_count = ad.get("pos_count").get<int>();
for (int i=0;i<pos_count;i++)
Expand All @@ -1655,10 +1671,6 @@ namespace dd
preds.push_back(clnames[label]);
confs.push_back(tp_d[k]);
}
}
//below false positives
for (unsigned int k = 0; k<fp_d.size(); ++k)
{
if (fp_i[k] == 1)
{
preds.push_back(clnames[label]);
Expand All @@ -1677,11 +1689,30 @@ namespace dd
confs.push_back(1.0);
preds.push_back("NO_DETECTION");
}

if (vbad.at(j).has("all_logits"))
{
output_logits = true;
const std::vector<APIData>& logits = vbad.at(j).getv("all_logits");
really_all_logits.insert(really_all_logits.end(),logits.begin(), logits.end());
for (int k= 0; k<(num_pos - ntp); ++k)
{
APIData background_logits_ad;
std::vector<double> background_logits;
background_logits.push_back(0.5+0.5/(float)clnames.size());
for (unsigned int c = 1; c<clnames.size(); ++c)
background_logits.push_back(1.0/(float)clnames.size()/2.0);
background_logits_ad.add("logits",background_logits);
really_all_logits.push_back(background_logits_ad);
}
}
}
}
raw_res.add("truths", targets);
raw_res.add("estimations", preds);
raw_res.add("confidences", confs);
if (output_logits)
raw_res.add("all_logits", really_all_logits);
return raw_res;
}

Expand Down

0 comments on commit 65fbb36

Please sign in to comment.