One-vs-all cross-entropy loss
Summary: The new option for the `loss` parameter makes it possible to compute the loss as the sum of the cross-entropies of the independent output units, i.e. one binary classifier per label.
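Usage sketch (the file names are illustrative, not part of this commit): the new loss is selected on the command line under either spelling,

    ./fasttext supervised -input train.txt -output model -loss one-vs-all
    ./fasttext supervised -input train.txt -output model -loss ova

both of which map to loss_name::ova, as parsed in src/args.cc below.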

Reviewed By: EdouardGrave

Differential Revision: D10853638

fbshipit-source-id: dc4c56e25c89c9da1a33bda1b29db781080794fd
Celebio authored and facebook-github-bot committed Nov 27, 2018
1 parent 71b4101 commit 8850c51
Showing 8 changed files with 100 additions and 20 deletions.
1 change: 1 addition & 0 deletions python/fastText/pybind/fasttext_pybind.cc
@@ -91,6 +91,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
       .value("hs", fasttext::loss_name::hs)
       .value("ns", fasttext::loss_name::ns)
       .value("softmax", fasttext::loss_name::softmax)
+      .value("ova", fasttext::loss_name::ova)
       .export_values();

   m.def(
7 changes: 6 additions & 1 deletion src/args.cc
@@ -53,6 +53,8 @@ std::string Args::lossToString(loss_name ln) const {
       return "ns";
     case loss_name::softmax:
       return "softmax";
+    case loss_name::ova:
+      return "one-vs-all";
   }
   return "Unknown loss!"; // should never happen
 }
@@ -129,6 +131,9 @@ void Args::parseArgs(const std::vector<std::string>& args) {
       loss = loss_name::ns;
     } else if (args.at(ai + 1) == "softmax") {
       loss = loss_name::softmax;
+    } else if (
+        args.at(ai + 1) == "one-vs-all" || args.at(ai + 1) == "ova") {
+      loss = loss_name::ova;
     } else {
       std::cerr << "Unknown loss: " << args.at(ai + 1) << std::endl;
       printHelp();
@@ -229,7 +234,7 @@ void Args::printTrainingHelp() {
       << "  -ws size of the context window [" << ws << "]\n"
       << "  -epoch number of epochs [" << epoch << "]\n"
       << "  -neg number of negatives sampled [" << neg << "]\n"
-      << "  -loss loss function {ns, hs, softmax} ["
+      << "  -loss loss function {ns, hs, softmax, one-vs-all} ["
       << lossToString(loss) << "]\n"
       << "  -thread number of threads [" << thread << "]\n"
       << "  -pretrainedVectors pretrained word vectors for supervised learning ["
2 changes: 1 addition & 1 deletion src/args.h
@@ -17,7 +17,7 @@
 namespace fasttext {

 enum class model_name : int { cbow = 1, sg, sup };
-enum class loss_name : int { hs = 1, ns, softmax };
+enum class loss_name : int { hs = 1, ns, softmax, ova };

 class Args {
  protected:
14 changes: 9 additions & 5 deletions src/fasttext.cc
@@ -344,9 +344,13 @@ void FastText::supervised(
   if (labels.size() == 0 || line.size() == 0) {
     return;
   }
-  std::uniform_int_distribution<> uniform(0, labels.size() - 1);
-  int32_t i = uniform(model.rng);
-  model.update(line, labels[i], lr);
+  if (args_->loss == loss_name::ova) {
+    model.update(line, labels, Model::kAllLabelsAsTarget, lr);
+  } else {
+    std::uniform_int_distribution<> uniform(0, labels.size() - 1);
+    int32_t i = uniform(model.rng);
+    model.update(line, labels, i, lr);
+  }
 }

 void FastText::cbow(Model& model, real lr, const std::vector<int32_t>& line) {
@@ -361,7 +365,7 @@ void FastText::cbow(Model& model, real lr, const std::vector<int32_t>& line) {
         bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend());
       }
     }
-    model.update(bow, line[w], lr);
+    model.update(bow, line, w, lr);
   }
 }
@@ -375,7 +379,7 @@ void FastText::skipgram(
     const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w]);
     for (int32_t c = -boundary; c <= boundary; c++) {
       if (c != 0 && w + c >= 0 && w + c < line.size()) {
-        model.update(ngrams, line[w + c], lr);
+        model.update(ngrams, line, w + c, lr);
       }
     }
   }
4 changes: 2 additions & 2 deletions src/meter.cc
@@ -8,6 +8,7 @@
  */

 #include "meter.h"
+#include "utils.h"

 #include <algorithm>
 #include <cmath>
@@ -26,8 +27,7 @@ void Meter::log(
   for (const auto& prediction : predictions) {
     labelMetrics_[prediction.second].predicted++;

-    if (std::find(labels.begin(), labels.end(), prediction.second) !=
-        labels.end()) {
+    if (utils::contains(labels, prediction.second)) {
       labelMetrics_[prediction.second].predictedGold++;
       metrics_.predictedGold++;
     }
71 changes: 61 additions & 10 deletions src/model.cc
@@ -8,6 +8,7 @@
  */

 #include "model.h"
+#include "utils.h"

 #include <assert.h>
 #include <algorithm>
@@ -90,12 +91,23 @@ real Model::hierarchicalSoftmax(int32_t target, real lr) {
   return loss;
 }

-void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
+void Model::computeOutput(Vector& hidden, Vector& output) const {
   if (quant_ && args_->qout) {
     output.mul(*qwo_, hidden);
   } else {
     output.mul(*wo_, hidden);
   }
+}
+
+void Model::computeOutputSigmoid(Vector& hidden, Vector& output) const {
+  computeOutput(hidden, output);
+  for (int32_t i = 0; i < osz_; i++) {
+    output[i] = sigmoid(output[i]);
+  }
+}
+
+void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
+  computeOutput(hidden, output);
   real max = output[0], z = 0.0;
   for (int32_t i = 0; i < osz_; i++) {
     max = std::max(output[i], max);
@@ -125,6 +137,16 @@ real Model::softmax(int32_t target, real lr) {
   return -log(output_[target]);
 }

+real Model::oneVsAll(const std::vector<int32_t>& targets, real lr) {
+  real loss = 0.0;
+
+  for (int32_t i = 0; i < osz_; i++) {
+    bool isMatch = utils::contains(targets, i);
+    loss += binaryLogistic(i, isMatch, lr);
+  }
+
+  return loss;
+}
+
 void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden)
     const {
   assert(hidden.size() == hsz_);

Inline review comment from @emesday (Dec 14, 2018), on the line `real loss = 0.0;`: grad_.zero(); is not required?
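For reference, oneVsAll accumulates one binary cross-entropy term per output unit. A minimal standalone sketch of the same quantity, outside the Model class (the function name and signature are illustrative; unlike binaryLogistic it computes only the loss value, not the gradient update the real method also performs):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Sum of binary cross-entropies: output unit i is treated as an
    // independent binary classifier whose target is 1 iff i is among
    // the gold labels of the example.
    double oneVsAllLoss(const std::vector<double>& scores,
                        const std::vector<int>& goldLabels) {
      double loss = 0.0;
      for (int i = 0; i < static_cast<int>(scores.size()); i++) {
        bool isMatch = std::find(goldLabels.begin(), goldLabels.end(), i) !=
            goldLabels.end();
        double p = 1.0 / (1.0 + std::exp(-scores[i])); // sigmoid
        loss += isMatch ? -std::log(p) : -std::log(1.0 - p);
      }
      return loss;
    }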
@@ -184,7 +206,11 @@ void Model::findKBest(
     std::vector<std::pair<real, int32_t>>& heap,
     Vector& hidden,
     Vector& output) const {
-  computeOutputSoftmax(hidden, output);
+  if (args_->loss == loss_name::ova) {
+    computeOutputSigmoid(hidden, output);
+  } else {
+    computeOutputSoftmax(hidden, output);
+  }
   for (int32_t i = 0; i < osz_; i++) {
     if (output[i] < threshold) {
       continue;
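Why sigmoid rather than softmax here: after one-vs-all training each output unit is an independent binary classifier, so its score converts to a probability on its own and the probabilities need not sum to one; several labels can clear the threshold for the same input, which is what multi-label prediction needs. A quick standalone illustration with made-up scores:

    #include <cmath>
    #include <cstdio>

    double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

    int main() {
      // Illustrative raw output scores for three labels of one document.
      double scores[] = {2.0, 1.5, -3.0};
      for (double s : scores) {
        std::printf("%.3f\n", sigmoid(s)); // ~0.881, ~0.818, ~0.047
      }
      // Two labels exceed 0.5 at once; softmax would instead force the
      // three probabilities to compete and sum to 1.
      return 0;
    }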
@@ -237,20 +263,45 @@ void Model::dfs(
   dfs(k, threshold, tree[node].right, score + std_log(f), heap, hidden);
 }

-void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
-  assert(target >= 0);
-  assert(target < osz_);
+real Model::computeLoss(
+    const std::vector<int32_t>& targets,
+    int32_t targetIndex,
+    real lr) {
+  real loss = 0.0;
+
+  if (args_->loss == loss_name::ns) {
+    loss = negativeSampling(targets[targetIndex], lr);
+  } else if (args_->loss == loss_name::hs) {
+    loss = hierarchicalSoftmax(targets[targetIndex], lr);
+  } else if (args_->loss == loss_name::softmax) {
+    loss = softmax(targets[targetIndex], lr);
+  } else if (args_->loss == loss_name::ova) {
+    loss = oneVsAll(targets, lr);
+  } else {
+    throw std::invalid_argument("Unhandled loss function for this model.");
+  }
+
+  return loss;
+}
+
+void Model::update(
+    const std::vector<int32_t>& input,
+    const std::vector<int32_t>& targets,
+    int32_t targetIndex,
+    real lr) {
   if (input.size() == 0) {
     return;
   }
   computeHidden(input, hidden_);
-  if (args_->loss == loss_name::ns) {
-    loss_ += negativeSampling(target, lr);
-  } else if (args_->loss == loss_name::hs) {
-    loss_ += hierarchicalSoftmax(target, lr);
+
+  if (targetIndex == kAllLabelsAsTarget) {
+    loss_ += computeLoss(targets, -1, lr);
   } else {
-    loss_ += softmax(target, lr);
+    assert(targetIndex >= 0);
+    assert(targetIndex < osz_);
+    loss_ += computeLoss(targets, targetIndex, lr);
   }
+
   nexamples_ += 1;

   if (args_->model == model_name::sup) {
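The kAllLabelsAsTarget sentinel makes the two training modes explicit at the call site; a sketch of the two patterns, mirroring the call sites in src/fasttext.cc above:

    // Single sampled target (ns, hs, softmax): targetIndex selects one
    // entry of the targets vector.
    model.update(line, labels, i, lr);

    // One-vs-all: every entry of targets is a positive example,
    // signalled by the kAllLabelsAsTarget sentinel (-1).
    model.update(line, labels, Model::kAllLabelsAsTarget, lr);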
11 changes: 10 additions & 1 deletion src/model.h
@@ -61,6 +61,7 @@ class Model {
   int32_t getNegative(int32_t target);
   void initSigmoid();
   void initLog();
+  void computeOutput(Vector&, Vector&) const;

   static const int32_t NEGATIVE_TABLE_SIZE = 10000000;

@@ -75,6 +76,7 @@ class Model {
   real negativeSampling(int32_t, real);
   real hierarchicalSoftmax(int32_t, real);
   real softmax(int32_t, real);
+  real oneVsAll(const std::vector<int32_t>&, real);

   void predict(
       const std::vector<int32_t>&,
@@ -101,8 +103,14 @@ class Model {
       std::vector<std::pair<real, int32_t>>&,
       Vector&,
       Vector&) const;
-  void update(const std::vector<int32_t>&, int32_t, real);
+  void update(
+      const std::vector<int32_t>&,
+      const std::vector<int32_t>&,
+      int32_t,
+      real);
+  real computeLoss(const std::vector<int32_t>&, int32_t, real);
   void computeHidden(const std::vector<int32_t>&, Vector&) const;
+  void computeOutputSigmoid(Vector&, Vector&) const;
   void computeOutputSoftmax(Vector&, Vector&) const;
   void computeOutputSoftmax();

@@ -120,6 +128,7 @@ class Model {
   setQuantizePointer(std::shared_ptr<QMatrix>, std::shared_ptr<QMatrix>, bool);

   static const int32_t kUnlimitedPredictions = -1;
+  static const int32_t kAllLabelsAsTarget = -1;
 };

 } // namespace fasttext
10 changes: 10 additions & 0 deletions src/utils.h
@@ -9,7 +9,9 @@

 #pragma once

+#include <algorithm>
 #include <fstream>
+#include <vector>

 #if defined(__clang__) || defined(__GNUC__)
 #define FASTTEXT_DEPRECATED(msg) __attribute__((__deprecated__(msg)))
@@ -24,7 +26,15 @@ namespace fasttext {
 namespace utils {

 int64_t size(std::ifstream&);

 void seek(std::ifstream&, int64_t);

+template <typename T>
+bool contains(const std::vector<T>& container, const T& value) {
+  return std::find(container.begin(), container.end(), value) !=
+      container.end();
+}
+
 } // namespace utils

 } // namespace fasttext
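The contains helper replaces the verbose std::find idiom previously spelled out in meter.cc; a usage sketch (values are illustrative):

    #include "utils.h"

    #include <cstdint>
    #include <vector>

    // Works for any element type providing operator==.
    std::vector<int32_t> labels = {4, 9, 23};
    bool hit = fasttext::utils::contains(labels, int32_t(9));  // true
    bool miss = fasttext::utils::contains(labels, int32_t(5)); // false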
