Skip to content

Commit

Permalink
Merge pull request #2345 from iamshnoo/multilabel_softmargin_loss
Browse files Browse the repository at this point in the history
Adding MultiLabel SoftMargin Loss
  • Loading branch information
jeffin143 committed Aug 14, 2021
2 parents 91ab39d + bf61e97 commit 3f82130
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 0 deletions.
3 changes: 3 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
### mlpack ?.?.?
###### ????-??-??
* Added `Multi Label Soft Margin Loss` loss function for neural networks
(#2345).

* Added Decision Tree Regressor (#2905). It can be used via the class
`mlpack::tree::DecisionTreeRegressor`. It is accessible only through C++.

Expand Down
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/loss_functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ set(SOURCES
mean_squared_error_impl.hpp
mean_squared_logarithmic_error.hpp
mean_squared_logarithmic_error_impl.hpp
multilabel_softmargin_loss.hpp
multilabel_softmargin_loss_impl.hpp
negative_log_likelihood.hpp
negative_log_likelihood_impl.hpp
poisson_nll_loss.hpp
Expand Down
114 changes: 114 additions & 0 deletions src/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/**
* @file methods/ann/loss_functions/multilabel_softmargin_loss.hpp
* @author Anjishnu Mukherjee
*
* Definition of the Multi Label Soft Margin Loss function.
*
* It is a criterion that optimizes a multi-label one-versus-all loss based on
* max-entropy, between input x and target y of size (N, C) where N is the
* batch size and C is the number of classes.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_HPP
#define MLPACK_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
 * The Multi Label Soft Margin Loss: a criterion that optimizes a multi-label
 * one-versus-all loss based on max-entropy between an input and a target of
 * the same shape, where each row is one sample of the batch and each column
 * is one class.
 *
 * @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 * @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
 *         arma::sp_mat or arma::cube).
 */
template <
    typename InputDataType = arma::mat,
    typename OutputDataType = arma::mat
>
class MultiLabelSoftMarginLoss
{
 public:
  /**
   * Create the MultiLabelSoftMarginLoss object.
   *
   * @param reduction Specifies the reduction to apply to the output. If false,
   *                  'mean' reduction is used, where sum of the output will be
   *                  divided by the number of elements in the output. If
   *                  true, 'sum' reduction is used and the output will be
   *                  summed. It is set to true by default.
   * @param weights A manual rescaling weight given to each class. It is a
   *                (1, numClasses) row vector. When left empty, every class
   *                gets a weight of one the first time Forward() is called.
   */
  MultiLabelSoftMarginLoss(const bool reduction = true,
                           const arma::rowvec& weights = arma::rowvec());

  /**
   * Computes the Multi Label Soft Margin Loss function.
   *
   * @param input Input data used for evaluating the specified function.
   * @param target The target vector with same shape as input.
   * @return The loss reduced to a single scalar according to Reduction().
   */
  template<typename InputType, typename TargetType>
  typename InputType::elem_type Forward(const InputType& input,
                                        const TargetType& target);

  /**
   * Ordinary feed backward pass of a neural network.
   *
   * @param input The propagated input activation.
   * @param target The target vector.
   * @param output The calculated error.
   */
  template<typename InputType, typename TargetType, typename OutputType>
  void Backward(const InputType& input,
                const TargetType& target,
                OutputType& output);

  //! Get the output parameter.
  // The const overload hands out a const reference: returning a mutable
  // reference to a member from a const method is ill-formed.
  const OutputDataType& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }

  //! Get the weights assigned to each class.
  const arma::rowvec& ClassWeights() const { return classWeights; }
  //! Modify the weights assigned to each class.
  arma::rowvec& ClassWeights() { return classWeights; }

  //! Get the type of reduction used.
  bool Reduction() const { return reduction; }
  //! Modify the type of reduction used.
  bool& Reduction() { return reduction; }

  /**
   * Serialize the layer.
   */
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  //! Locally-stored output parameter object.
  OutputDataType outputParameter;

  //! The boolean value that tells if reduction is sum (true) or mean (false).
  bool reduction;

  //! A (1, numClasses) shaped vector with weights for each class.
  arma::rowvec classWeights;

  //! Whether classWeights has been populated, either by the user at
  //! construction or with the all-ones default.
  bool weighted;
}; // class MultiLabelSoftMarginLoss

} // namespace ann
} // namespace mlpack

// include implementation.
#include "multilabel_softmargin_loss_impl.hpp"

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/**
* @file methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp
* @author Anjishnu Mukherjee
*
* Implementation of the Multi Label Soft Margin Loss function.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_IMPL_HPP
#define MLPACK_METHODS_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_IMPL_HPP

// In case it hasn't been included.
#include "multilabel_softmargin_loss.hpp"

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
 * Store the reduction mode and the (possibly empty) per-class weights. The
 * weights only count as user-provided when they hold at least one element;
 * otherwise they are left empty so that defaults can be filled in later.
 */
template<typename InputDataType, typename OutputDataType>
MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::
MultiLabelSoftMarginLoss(
    const bool reduction,
    const arma::rowvec& weights) :
    reduction(reduction),
    classWeights(weights),
    weighted(weights.n_elem > 0)
{
  // Nothing to do here.
}

/**
 * Compute the multi-label soft margin loss of `input` against `target`.
 *
 * Rows are treated as samples and columns as classes: the per-element loss is
 * summed over the rows, scaled by the class weights and averaged over the
 * columns. With 'sum' reduction that value is returned as is; with 'mean'
 * reduction it is additionally divided by the number of rows.
 *
 * If no class weights were supplied, they are initialised to ones (one per
 * column of `input`) on the first call.
 *
 * @param input Input data (logits) used for evaluating the loss.
 * @param target The target matrix with the same shape as input.
 * @return The scalar loss.
 */
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename TargetType>
typename InputType::elem_type
MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::Forward(
    const InputType& input, const TargetType& target)
{
  if (!weighted)
  {
    classWeights.ones(1, input.n_cols);
    weighted = true;
  }

  // Per-element loss -(t * log(sigmoid(x)) + (1 - t) * log(sigmoid(-x))),
  // rewritten as max(x, 0) - t * x + log(1 + exp(-|x|)). The two forms are
  // algebraically identical, but this one never evaluates exp() on a positive
  // argument, so large-magnitude inputs cannot overflow to log(0) = -inf and
  // poison the result with NaN (0 * -inf).
  const InputType absInput = arma::abs(input);
  const InputType elementLoss = (input + absInput) / 2 - target % input +
      arma::log(1 + arma::exp(-absInput));

  // Sum over samples (rows), weight each class, then average over classes.
  const InputType loss = arma::mean(arma::sum(elementLoss) % classWeights, 1);

  if (reduction)
    return arma::as_scalar(loss);

  return arma::as_scalar(loss / input.n_rows);
}

/**
 * Compute the gradient of the loss with respect to `input`.
 *
 * Each element's gradient is (sigmoid(x) - t) scaled by its class weight and
 * divided by the number of elements of `input`; with 'sum' reduction the
 * result is multiplied back by the number of rows.
 *
 * If no class weights were supplied and Forward() has not run yet, they are
 * initialised to ones here, exactly as Forward() does, so that Backward()
 * can be called on its own.
 *
 * @param input The propagated input activation.
 * @param target The target matrix with the same shape as input.
 * @param output The calculated error; it receives the shape of input.
 */
template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename TargetType, typename OutputType>
void MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::Backward(
    const InputType& input,
    const TargetType& target,
    OutputType& output)
{
  if (!weighted)
  {
    classWeights.ones(1, input.n_cols);
    weighted = true;
  }

  const InputType sigmoid = 1 / (1 + arma::exp(-input));

  // -(t * (1 - s) - (1 - t) * s) simplifies to (s - t).
  output = (sigmoid - target) % arma::repmat(classWeights, input.n_rows, 1) /
      input.n_elem;

  if (reduction)
    output *= input.n_rows;
}

/**
 * Serialize the class weights and the reduction mode.
 *
 * The `weighted` flag is not stored in the archive; when loading it is
 * rebuilt from the restored weights so that they are not mistaken for
 * "not yet initialised" and overwritten with ones on the next Forward().
 */
template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::serialize(
    Archive& ar,
    const unsigned int /* version */)
{
  ar(CEREAL_NVP(classWeights));
  ar(CEREAL_NVP(reduction));

  // Keep the archive layout unchanged: derive the flag instead of storing it.
  if (Archive::is_loading::value)
    weighted = (classWeights.n_elem > 0);
}

} // namespace ann
} // namespace mlpack

#endif
117 changes: 117 additions & 0 deletions src/mlpack/tests/loss_functions_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <mlpack/methods/ann/loss_functions/hinge_embedding_loss.hpp>
#include <mlpack/methods/ann/loss_functions/cosine_embedding_loss.hpp>
#include <mlpack/methods/ann/loss_functions/l1_loss.hpp>
#include <mlpack/methods/ann/loss_functions/multilabel_softmargin_loss.hpp>
#include <mlpack/methods/ann/loss_functions/soft_margin_loss.hpp>
#include <mlpack/methods/ann/loss_functions/mean_absolute_percentage_error.hpp>
#include <mlpack/methods/ann/loss_functions/triplet_margin_loss.hpp>
Expand Down Expand Up @@ -1033,3 +1034,119 @@ TEST_CASE("HingeLossTest", "[LossFunctionsTest]")
REQUIRE(output.n_rows == input.n_rows);
REQUIRE(output.n_cols == input.n_cols);
}

/**
 * Check the unweighted MultiLabel Softmargin Loss function against reference
 * values, once with 'sum' reduction and once with 'mean' reduction.
 */
TEST_CASE("MultiLabelSoftMarginLossTest", "[LossFunctionsTest]")
{
  // The default constructor uses 'sum' reduction; `false` selects 'mean'.
  MultiLabelSoftMarginLoss<> sumLoss;
  MultiLabelSoftMarginLoss<> meanLoss(false);

  arma::mat input("0.1778 0.0957 0.1397 0.1203 0.2403 0.1925 -0.2264 -0.3400 "
      "-0.3336");
  arma::mat target("0 1 0 1 0 0 0 0 1");
  input.reshape(3, 3);
  target.reshape(3, 3);

  arma::mat gradient;

  // ---- 'sum' reduction ----

  // Reference gradient, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='sum').
  arma::mat expectedGradient("0.1814 -0.1587 0.1783 -0.1567 0.1866 0.1827 "
      "0.1479 0.1386 -0.1942");
  expectedGradient.reshape(3, 3);

  // Forward pass: the reference loss is 2.14829, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='sum').
  double loss = sumLoss.Forward(input, target);
  REQUIRE(loss == Approx(2.14829).epsilon(1e-5));

  // Backward pass.
  sumLoss.Backward(input, target, gradient);
  REQUIRE(arma::accu(gradient) == Approx(0.505909).epsilon(1e-5));
  REQUIRE(gradient.n_rows == input.n_rows);
  REQUIRE(gradient.n_cols == input.n_cols);
  CheckMatrices(gradient, expectedGradient, 0.1);

  // ---- 'mean' reduction ----

  // Reference gradient, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='mean').
  expectedGradient = arma::mat("0.0605 -0.0529 0.0594 -0.0522 0.0622 0.0609 "
      "0.0493 0.0462 -0.0647");
  expectedGradient.reshape(3, 3);

  // Forward pass: the reference loss is 0.716095, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='mean').
  loss = meanLoss.Forward(input, target);
  REQUIRE(loss == Approx(0.716095).epsilon(1e-5));

  // Backward pass.
  meanLoss.Backward(input, target, gradient);
  REQUIRE(arma::accu(gradient) == Approx(0.168636).epsilon(1e-5));
  REQUIRE(gradient.n_rows == input.n_rows);
  REQUIRE(gradient.n_cols == input.n_cols);
  CheckMatrices(gradient, expectedGradient, 0.1);
}

/**
 * Check the MultiLabel Softmargin Loss function with explicit per-class
 * weights against reference values, once with 'sum' reduction and once with
 * 'mean' reduction.
 */
TEST_CASE("MultiLabelSoftMarginLossWeightedTest", "[LossFunctionsTest]")
{
  // One weight per class (column).
  const arma::rowvec weights("1 2 3");
  MultiLabelSoftMarginLoss<> sumLoss(true, weights);
  MultiLabelSoftMarginLoss<> meanLoss(false, weights);

  arma::mat input("0.1778 0.0957 0.1397 0.2256 0.1203 0.2403 0.1925 0.3144 "
      "-0.2264 -0.3400 -0.3336 -0.8695");
  arma::mat target("0 1 0 1 1 0 0 0 0 0 1 0");
  input.reshape(4, 3);
  target.reshape(4, 3);

  arma::mat gradient;

  // ---- 'sum' reduction ----

  // Reference gradient, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='sum').
  arma::mat expectedGradient("0.1814 -0.1587 0.1783 -0.1479 -0.3133 0.3732 "
      "0.3653 0.3853 0.4436 0.4158 -0.5826 0.2954");
  expectedGradient.reshape(4, 3);

  // Forward pass: the reference loss is 5.35057, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='sum').
  double loss = sumLoss.Forward(input, target);
  REQUIRE(loss == Approx(5.35057).epsilon(1e-5));

  // Backward pass.
  sumLoss.Backward(input, target, gradient);
  REQUIRE(arma::accu(gradient) == Approx(1.43577).epsilon(1e-5));
  REQUIRE(gradient.n_rows == input.n_rows);
  REQUIRE(gradient.n_cols == input.n_cols);
  CheckMatrices(gradient, expectedGradient, 0.1);

  // ---- 'mean' reduction ----

  // Reference gradient, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='mean').
  expectedGradient = arma::mat("0.0454 -0.0397 0.0446 -0.0370 -0.0783 0.0933 "
      "0.0913 0.0963 0.1109 0.1040 -0.1457 0.0738");
  expectedGradient.reshape(4, 3);

  // Forward pass: the reference loss is 1.33764, calculated using
  // torch.nn.MultiLabelSoftMarginLoss(reduction='mean').
  loss = meanLoss.Forward(input, target);
  REQUIRE(loss == Approx(1.33764).epsilon(1e-5));

  // Backward pass.
  meanLoss.Backward(input, target, gradient);
  REQUIRE(arma::accu(gradient) == Approx(0.358943).epsilon(1e-5));
  REQUIRE(gradient.n_rows == input.n_rows);
  REQUIRE(gradient.n_cols == input.n_cols);
  CheckMatrices(gradient, expectedGradient, 0.1);
}

0 comments on commit 3f82130

Please sign in to comment.