-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2345 from iamshnoo/multilabel_softmargin_loss
Adding MultiLabel SoftMargin Loss
- Loading branch information
Showing
5 changed files
with
324 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
src/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/** | ||
* @file methods/ann/loss_functions/multilabel_softmargin_loss.hpp | ||
* @author Anjishnu Mukherjee | ||
* | ||
* Definition of the Multi Label Soft Margin Loss function. | ||
* | ||
* It is a criterion that optimizes a multi-label one-versus-all loss based on | ||
* max-entropy, between input x and target y of size (N, C) where N is the | ||
* batch size and C is the number of classes. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_HPP | ||
#define MLPACK_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
/** | ||
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
*/ | ||
template < | ||
typename InputDataType = arma::mat, | ||
typename OutputDataType = arma::mat | ||
> | ||
class MultiLabelSoftMarginLoss | ||
{ | ||
public: | ||
/** | ||
* Create the MultiLabelSoftMarginLoss object. | ||
* | ||
* @param reduction Specifies the reduction to apply to the output. If false, | ||
* 'mean' reduction is used, where sum of the output will be | ||
* divided by the number of elements in the output. If | ||
* true, 'sum' reduction is used and the output will be | ||
* summed. It is set to true by default. | ||
* @param weights A manual rescaling weight given to each class. It is a | ||
* (1, numClasses) row vector. | ||
*/ | ||
MultiLabelSoftMarginLoss(const bool reduction = true, | ||
const arma::rowvec& weights = arma::rowvec()); | ||
|
||
/** | ||
* Computes the Multi Label Soft Margin Loss function. | ||
* | ||
* @param input Input data used for evaluating the specified function. | ||
* @param target The target vector with same shape as input. | ||
*/ | ||
template<typename InputType, typename TargetType> | ||
typename InputType::elem_type Forward(const InputType& input, | ||
const TargetType& target); | ||
|
||
/** | ||
* Ordinary feed backward pass of a neural network. | ||
* | ||
* @param input The propagated input activation. | ||
* @param target The target vector. | ||
* @param output The calculated error. | ||
*/ | ||
template<typename InputType, typename TargetType, typename OutputType> | ||
void Backward(const InputType& input, | ||
const TargetType& target, | ||
OutputType& output); | ||
|
||
//! Get the output parameter. | ||
OutputDataType& OutputParameter() const { return outputParameter; } | ||
//! Modify the output parameter. | ||
OutputDataType& OutputParameter() { return outputParameter; } | ||
|
||
//! Get the weights assigned to each class. | ||
const arma::rowvec& ClassWeights() const { return classWeights; } | ||
//! Modify the weights assigned to each class. | ||
arma::rowvec& ClassWeights() { return classWeights; } | ||
|
||
//! Get the type of reduction used. | ||
bool Reduction() const { return reduction; } | ||
//! Modify the type of reduction used. | ||
bool& Reduction() { return reduction; } | ||
|
||
/** | ||
* Serialize the layer. | ||
*/ | ||
template<typename Archive> | ||
void serialize(Archive& ar, const unsigned int /* version */); | ||
|
||
private: | ||
//! Locally-stored output parameter object. | ||
OutputDataType outputParameter; | ||
|
||
//! The boolean value that tells if reduction is sum or mean. | ||
bool reduction; | ||
|
||
//! A (1, numClasses) shaped vector with weights for each class. | ||
arma::rowvec classWeights; | ||
|
||
// An internal parameter used during initialisation of class weights. | ||
bool weighted; | ||
}; // class MultiLabelSoftMarginLoss | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
// include implementation. | ||
#include "multilabel_softmargin_loss_impl.hpp" | ||
|
||
#endif |
88 changes: 88 additions & 0 deletions
88
src/mlpack/methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/** | ||
* @file methods/ann/loss_functions/multilabel_softmargin_loss_impl.hpp | ||
* @author Anjishnu Mukherjee | ||
* | ||
* Implementation of the Multi Label Soft Margin Loss function. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_IMPL_HPP | ||
#define MLPACK_METHODS_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_IMPL_HPP | ||
|
||
// In case it hasn't been included. | ||
#include "multilabel_softmargin_loss.hpp" | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
MultiLabelSoftMarginLoss<InputDataType, OutputDataType>:: | ||
MultiLabelSoftMarginLoss( | ||
const bool reduction, | ||
const arma::rowvec& weights) : | ||
reduction(reduction), | ||
weighted(false) | ||
{ | ||
if (weights.n_elem) | ||
{ | ||
classWeights = weights; | ||
weighted = true; | ||
} | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
template<typename InputType, typename TargetType> | ||
typename InputType::elem_type | ||
MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::Forward( | ||
const InputType& input, const TargetType& target) | ||
{ | ||
if (!weighted) | ||
{ | ||
classWeights.ones(1, input.n_cols); | ||
weighted = true; | ||
} | ||
|
||
InputType logSigmoid = arma::log((1 / (1 + arma::exp(-input)))); | ||
InputType logSigmoidNeg = arma::log(1 / (1 + arma::exp(input))); | ||
InputType loss = arma::mean(arma::sum(-(target % logSigmoid + | ||
(1 - target) % logSigmoidNeg)) % classWeights, 1); | ||
|
||
if (reduction) | ||
return arma::as_scalar(loss); | ||
|
||
return arma::as_scalar(loss / input.n_rows); | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
template<typename InputType, typename TargetType, typename OutputType> | ||
void MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::Backward( | ||
const InputType& input, | ||
const TargetType& target, | ||
OutputType& output) | ||
{ | ||
output.set_size(size(input)); | ||
InputType sigmoid = (1 / (1 + arma::exp(-input))); | ||
output = -(target % (1 - sigmoid) - (1 - target) % sigmoid) % | ||
arma::repmat(classWeights, target.n_rows, 1) / output.n_elem; | ||
|
||
if (reduction) | ||
output = output * input.n_rows; | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
template<typename Archive> | ||
void MultiLabelSoftMarginLoss<InputDataType, OutputDataType>::serialize( | ||
Archive& ar, | ||
const unsigned int /* version */) | ||
{ | ||
ar(CEREAL_NVP(classWeights)); | ||
ar(CEREAL_NVP(reduction)); | ||
} | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters