Skip to content

Commit

Permalink
fix reverse precision and recall
Browse files Browse the repository at this point in the history
output all precisions recalls and f1 if measure="f1full"
  • Loading branch information
fantes committed Jul 7, 2020
1 parent 6a19b4f commit ecdfad8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
8 changes: 6 additions & 2 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1546,14 +1546,18 @@ namespace dd

for (auto m: meas_str)
{
if (m != "cmdiag" && m != "cmfull" && m != "clacc" && m != "labels" && m!= "cliou") // do not report confusion matrix in server logs

if (m != "cmdiag" && m != "cmfull" && m != "clacc" && m != "labels" && m!= "cliou"
&& m != "precisions" && m != "recalls" && m != "f1s")
// do not report confusion matrix in server logs
{
double mval = meas_obj.get(m).get<double>();
this->_logger->info("{}={}",m,mval);
this->add_meas(m,mval);
this->add_meas_per_iter(m,mval);
}
else if (m == "cmdiag" || m == "clacc" || m == "cliou")
else if (m == "cmdiag" || m == "clacc" || m == "cliou"
|| m == "precisions" || m == "recalls" || m == "f1s")
{
std::vector<double> mdiag = meas_obj.get(m).get<std::vector<double>>();
std::vector<std::string> cnames;
Expand Down
42 changes: 32 additions & 10 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ namespace dd
}
}
bool bf1 = (std::find(measures.begin(),measures.end(),"f1")!=measures.end());
bool bf1full = (std::find(measures.begin(),measures.end(),"f1full")!=measures.end());
bool bmcll = (std::find(measures.begin(),measures.end(),"mcll")!=measures.end());
bool bgini = (std::find(measures.begin(),measures.end(),"gini")!=measures.end());
bool beucll = false;
Expand Down Expand Up @@ -832,16 +833,36 @@ namespace dd
}
}

if (!multilabel && !segmentation && !bbox && bf1)
if (!multilabel && !segmentation && !bbox && (bf1 || bf1full))
{
double f1,precision,recall,acc;
dMat conf_diag,conf_matrix;
f1 = mf1(ad_res,precision,recall,acc,conf_diag,conf_matrix);
dVec precisionV, recallV, f1V;
f1 = mf1(ad_res,precision,recall,acc,precisionV, recallV, f1V, conf_diag,conf_matrix);
meas_out.add("f1",f1);
meas_out.add("precision",precision);
meas_out.add("recall",recall);
meas_out.add("accp",acc);
if (std::find(measures.begin(),measures.end(),"cmdiag")!=measures.end())
if (std::find(measures.begin(), measures.end(), "f1full") != measures.end())
{
std::vector<double> allPrecisions;
std::vector<double> allRecalls;
std::vector<double> allF1s;
for (int i=0; i<precisionV.size(); ++i)
{
allPrecisions.push_back(precisionV(i));
allRecalls.push_back(recallV(i));
allF1s.push_back(f1V(i));
}
meas_out.add("precisions", allPrecisions);
meas_out.add("recalls", allRecalls);
meas_out.add("f1s", allF1s);
if (std::find(measures.begin(),measures.end(),"cmdiag")==measures.end())
meas_out.add("labels",ad_res.get("clnames").get<std::vector<std::string>>());
}

if (std::find(measures.begin(),measures.end(),"cmdiag")!=measures.end())

{
std::vector<double> cmdiagv;
for (int i=0;i<conf_diag.rows();i++)
Expand Down Expand Up @@ -1588,7 +1609,8 @@ namespace dd
}

// measure: F1
static double mf1(const APIData &ad, double &precision, double &recall, double &acc, dMat &conf_diag, dMat &conf_matrix)
static double mf1(const APIData &ad, double &precision, double &recall, double &acc,
dVec &precisionV, dVec &recallV, dVec &f1V, dMat &conf_diag, dMat &conf_matrix)
{
int nclasses = ad.get("nclasses").get<int>();
double f1=0.0;
Expand All @@ -1611,12 +1633,12 @@ namespace dd
dMat conf_rsum = conf_matrix.rowwise().sum();
dMat eps = dMat::Constant(nclasses,1,1e-8);
acc = conf_diag.sum() / conf_matrix.sum();
precision = conf_diag.transpose().cwiseQuotient(conf_csum + eps.transpose()).sum() / static_cast<double>(nclasses);
recall = conf_diag.cwiseQuotient(conf_rsum + eps).sum() / static_cast<double>(nclasses);
if ((precision+recall) > 0)
f1 = (2.0*precision*recall) / (precision+recall);
else
f1 = 0.0;
recallV = (conf_diag.transpose().cwiseQuotient(conf_csum + eps.transpose())).transpose();
recall = recallV.sum() / static_cast<double>(nclasses);
precisionV = conf_diag.cwiseQuotient(conf_rsum + eps);
precision = precisionV.sum() / static_cast<double>(nclasses);
f1V = (2.0*precisionV.cwiseProduct(recallV)).cwiseQuotient(precisionV + recallV + eps);
f1 = f1V.sum() / static_cast<double>(nclasses);
conf_diag = conf_diag.transpose().cwiseQuotient(conf_csum+eps.transpose()).transpose();
for (int i=0;i<conf_matrix.cols();i++)
if (conf_csum(i) > 0)
Expand Down

0 comments on commit ecdfad8

Please sign in to comment.