From be1e597cb67c069ba9940ff241d9aad38ccd37da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Onur=20=C3=87elebi?= Date: Wed, 24 Oct 2018 06:24:06 -0700 Subject: [PATCH] Compute precision/recall for each label Summary: This diff adds a new command to fasttext to display precision/recall score for each individual label : `print-label-scores` It will get predicted labels above given threshold, and compute scores. For example, the question "vinegar softens the bite of raw onions ?" has two labels : "vinegar" and "onions". It will ask fastText to predict labels above given threshold. If there are two such labels : "pickling", "onions", we will obtain : "onions" will have a precision of 100%, "pickling" a precision of 0%, "onions" will have a recall of 100%, "vinegar" will have a recall of 0%. Reviewed By: EdouardGrave Differential Revision: D9991570 fbshipit-source-id: 63cff90f57659d51f5aa1f10243d40e253445aa6 --- src/fasttext.cc | 93 ++++++++++++++++++++++++++++++++++++++++--------- src/fasttext.h | 17 ++++++--- src/main.cc | 42 ++++++++++++++++++++++ src/model.cc | 4 ++- src/model.h | 2 ++ 5 files changed, 136 insertions(+), 22 deletions(-) diff --git a/src/fasttext.cc b/src/fasttext.cc index 78e265770..e6fba1629 100644 --- a/src/fasttext.cc +++ b/src/fasttext.cc @@ -397,26 +397,16 @@ FastText::test(std::istream& in, int32_t k, real threshold) { } void FastText::predict( - std::istream& in, int32_t k, - std::vector>& predictions, + const std::vector& words, + std::vector>& predictions, real threshold) const { - std::vector words, labels; - predictions.clear(); - dict_->getLine(in, words, labels); - predictions.clear(); if (words.empty()) { return; } Vector hidden(args_->dim); Vector output(dict_->nlabels()); - std::vector> modelPredictions; - model_->predict(words, k, threshold, modelPredictions, hidden, output); - for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); - it++) { - predictions.push_back( - std::make_pair(it->first, dict_->getLabel(it->second))); - } + model_->predict(words, k, threshold, predictions, hidden, output); } void FastText::predict( @@ -424,10 +414,12 @@ void FastText::predict( int32_t k, bool print_prob, real threshold) { - std::vector> predictions; + std::vector> predictions; while (in.peek() != EOF) { + std::vector words, labels; + dict_->getLine(in, words, labels); predictions.clear(); - predict(in, k, predictions, threshold); + predict(k, words, predictions, threshold); if (predictions.empty()) { std::cout << std::endl; continue; @@ -436,7 +428,7 @@ void FastText::predict( if (it != predictions.cbegin()) { std::cout << " "; } - std::cout << it->second; + std::cout << dict_->getLabel(it->second); if (print_prob) { std::cout << " " << std::exp(it->first); } @@ -445,6 +437,75 @@ void FastText::predict( } } +void FastText::printLabelStats( + const std::vector& labelStats) const { + const static double kUnknownValue = -1.0; + auto computeF1Score = [](double precision, double recall) -> double { + if (precision == kUnknownValue || recall == kUnknownValue) { + return kUnknownValue; + } + if (precision != 0 && recall != 0) { + return 2 * precision * recall / (precision + recall); + } + return 0.; + }; + auto displayScore = [](double value) { + std::cout << std::fixed; + std::cout.precision(6); + if (value == kUnknownValue) { + std::cout << "--------"; + } else { + std::cout << value; + } + }; + + for (size_t labelId = 0; labelId < labelStats.size(); labelId++) { + const auto& labelStat = labelStats[labelId]; + double precision = labelStat.predicted + ? ((double)labelStat.predictedGold / labelStat.predicted) + : kUnknownValue; + double recall = labelStat.gold + ? ((double)labelStat.predictedGold / labelStat.gold) + : kUnknownValue; + double f1score = computeF1Score(precision, recall); + std::cout << "F1-Score : "; + displayScore(f1score); + std::cout << " Precision : "; + displayScore(precision); + std::cout << " Recall : "; + displayScore(recall); + std::cout << " " << dict_->getLabel(labelId) << std::endl; + } +} + +void FastText::printLabelStats(std::istream& in, int32_t k, real threshold) + const { + std::vector> predictions; + size_t labelsSize = dict_->nlabels(); + std::vector labelStats(labelsSize); + while (in.peek() != EOF) { + std::vector words, gold; + dict_->getLine(in, words, gold); + predictions.clear(); + predict(k, words, predictions, threshold); + for (const auto& goldLabelId : gold) { + assert(goldLabelId < labelsSize); + labelStats[goldLabelId].gold++; + } + for (const auto& predictedLabel : predictions) { + int32_t predictedLabelId = predictedLabel.second; + assert(predictedLabelId < labelsSize); + labelStats[predictedLabelId].predicted++; + if (auto itFound = + std::find(gold.begin(), gold.end(), predictedLabelId) != + gold.end()) { + labelStats[predictedLabelId].predictedGold++; + } + } + } + printLabelStats(labelStats); +} + void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) { svec.zero(); if (args_->model == model_name::sup) { diff --git a/src/fasttext.h b/src/fasttext.h index 92646617f..18daa6c6a 100644 --- a/src/fasttext.h +++ b/src/fasttext.h @@ -53,7 +53,18 @@ class FastText { bool quant_; int32_t version; + struct LabelStats { + int32_t gold, predicted, predictedGold; + LabelStats() : gold(0), predicted(0), predictedGold(0) {} + }; + void startThreads(); + void predict( + int32_t, + const std::vector&, + std::vector>&, + real = 0.0) const; + void printLabelStats(const std::vector& labelStats) const; public: FastText(); @@ -95,11 +106,7 @@ class FastText { void quantize(const Args); std::tuple test(std::istream&, int32_t, real = 0.0); void predict(std::istream&, int32_t, bool, real = 0.0); - void predict( - std::istream&, - int32_t, - std::vector>&, - real = 0.0) const; + void printLabelStats(std::istream&, int32_t, real = 0.0) const; void ngramVectors(std::string); void precomputeWordVectors(Matrix&); void findNN( diff --git a/src/main.cc b/src/main.cc index 9bf621936..717c02e71 100644 --- a/src/main.cc +++ b/src/main.cc @@ -22,6 +22,7 @@ void printUsage() { << " supervised train a supervised classifier\n" << " quantize quantize a model to reduce the memory usage\n" << " test evaluate a supervised classifier\n" + << " test-label print labels with precision and recall scores\n" << " predict predict most likely labels\n" << " predict-prob predict most likely labels with probabilities\n" << " skipgram train a skipgram model\n" @@ -59,6 +60,16 @@ void printPredictUsage() { << std::endl; } +void printPrintLabelStatsUsage() { + std::cerr + << "usage: fasttext test-label [] []\n\n" + << " model filename\n" + << " test data filename\n" + << " (optional; 1 by default) predict top k labels\n" + << " (optional; 0.0 by default) probability threshold\n" + << std::endl; +} + void printPrintWordVectorsUsage() { std::cerr << "usage: fasttext print-word-vectors \n\n" << " model filename\n" @@ -186,6 +197,35 @@ void predict(const std::vector& args) { exit(0); } +void printLabelStats(const std::vector& args) { + if (args.size() < 4 || args.size() > 6) { + printPrintLabelStatsUsage(); + exit(EXIT_FAILURE); + } + int32_t k = 1; + real threshold = 0.0; + if (args.size() > 4) { + k = std::stoi(args[4]); + if (args.size() > 5) { + threshold = std::stof(args[5]); + } + } + + FastText fasttext; + fasttext.loadModel(std::string(args[2])); + + std::string infile(args[3]); + std::ifstream ifs(infile); + if (!ifs.is_open()) { + std::cerr << "Input file cannot be opened!" << std::endl; + exit(EXIT_FAILURE); + } + fasttext.printLabelStats(ifs, k, threshold); + ifs.close(); + + exit(0); +} + void printWordVectors(const std::vector args) { if (args.size() != 3) { printPrintWordVectorsUsage(); @@ -355,6 +395,8 @@ int main(int argc, char** argv) { analogies(args); } else if (command == "predict" || command == "predict-prob") { predict(args); + } else if (command == "test-label") { + printLabelStats(args); } else if (command == "dump") { dump(args); } else { diff --git a/src/model.cc b/src/model.cc index 7ffb4ba21..da98b6fa5 100644 --- a/src/model.cc +++ b/src/model.cc @@ -152,7 +152,9 @@ void Model::predict( std::vector>& heap, Vector& hidden, Vector& output) const { - if (k <= 0) { + if (k == Model::kUnlimitedPredictions) { + k = osz_; + } else if (k <= 0) { throw std::invalid_argument("k needs to be 1 or higher!"); } if (args_->model != model_name::sup) { diff --git a/src/model.h b/src/model.h index 3dd605c3f..1e3d62a6c 100644 --- a/src/model.h +++ b/src/model.h @@ -118,6 +118,8 @@ class Model { bool quant_; void setQuantizePointer(std::shared_ptr, std::shared_ptr, bool); + + static const int32_t kUnlimitedPredictions = -1; }; } // namespace fasttext