-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2762 from ayushsingh11/master
Adding triplet margin loss function.
- Loading branch information
Showing
5 changed files
with
235 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
src/mlpack/methods/ann/loss_functions/triplet_margin_loss.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/** | ||
* @file methods/ann/loss_functions/triplet_margin_loss.hpp | ||
* @author Prince Gupta | ||
* @author Ayush Singh | ||
* | ||
* Definition of the Triplet Margin Loss function. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_LOSS_HPP | ||
#define MLPACK_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_LOSS_HPP | ||
|
||
#include <mlpack/prereqs.hpp> | ||
|
||
namespace mlpack { | ||
namespace ann /** Artificial Neural Network. */ { | ||
|
||
/** | ||
* The Triplet Margin Loss performance function measures the network's | ||
* performance according to the relative distance from the anchor input | ||
* of the positive (truthy) and negative (falsy) inputs. | ||
* The distance between two samples A and B is defined as square of L2 norm | ||
* of A-B. | ||
* | ||
* For more information, refer the following paper. | ||
* | ||
* @code | ||
* @article{Schroff2015, | ||
* author = {Florian Schroff, Dmitry Kalenichenko, James Philbin}, | ||
* title = {FaceNet: A Unified Embedding for Face Recognition and Clustering}, | ||
* year = {2015}, | ||
* url = {https://arxiv.org/abs/1503.03832}, | ||
* } | ||
* @endcode | ||
* | ||
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat, | ||
* arma::sp_mat or arma::cube). | ||
*/ | ||
template < | ||
typename InputDataType = arma::mat, | ||
typename OutputDataType = arma::mat | ||
> | ||
class TripletMarginLoss | ||
{ | ||
public: | ||
/** | ||
* Create the TripletMarginLoss object. | ||
* | ||
* @param margin The minimum value by which the distance between | ||
* Anchor and Negative sample exceeds the distance | ||
* between Anchor and Positive sample. | ||
*/ | ||
TripletMarginLoss(const double margin = 1.0); | ||
|
||
/** | ||
* Computes the Triplet Margin Loss function. | ||
* | ||
* @param prediction Concatenated anchor and positive sample. | ||
* @param target The negative sample. | ||
*/ | ||
template<typename PredictionType, typename TargetType> | ||
typename PredictionType::elem_type Forward(const PredictionType& prediction, | ||
const TargetType& target); | ||
/** | ||
* Ordinary feed backward pass of a neural network. | ||
* | ||
* @param prediction Concatenated anchor and positive sample. | ||
* @param target The negative sample. | ||
* @param loss The calculated error. | ||
*/ | ||
template<typename PredictionType, typename TargetType, typename LossType> | ||
void Backward(const PredictionType& prediction, | ||
const TargetType& target, | ||
LossType& loss); | ||
|
||
//! Get the output parameter. | ||
OutputDataType& OutputParameter() const { return outputParameter; } | ||
//! Modify the output parameter. | ||
OutputDataType& OutputParameter() { return outputParameter; } | ||
|
||
//! Get the value of margin. | ||
double Margin() const { return margin; } | ||
//! Modify the value of margin. | ||
double& Margin() { return margin; } | ||
|
||
/** | ||
* Serialize the layer. | ||
*/ | ||
template<typename Archive> | ||
void serialize(Archive& ar, const unsigned int /* version */); | ||
|
||
private: | ||
//! Locally-stored output parameter object. | ||
OutputDataType outputParameter; | ||
|
||
//! The margin value used in calculating Triplet Margin Loss. | ||
double margin; | ||
}; // class TripletLossMargin | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
// include implementation. | ||
#include "triplet_margin_loss_impl.hpp" | ||
|
||
#endif |
71 changes: 71 additions & 0 deletions
71
src/mlpack/methods/ann/loss_functions/triplet_margin_loss_impl.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/** | ||
* @file methods/ann/loss_functions/triplet_margin_loss_impl.hpp | ||
* @author Prince Gupta | ||
* @author Ayush Singh | ||
* | ||
* Implementation of the Triplet Margin Loss function. | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP | ||
#define MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP | ||
|
||
// In case it hasn't been included. | ||
#include "triplet_margin_loss.hpp" | ||
|
||
namespace mlpack { | ||
namespace ann /** Artifical Neural Network. */ { | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
TripletMarginLoss<InputDataType, OutputDataType>::TripletMarginLoss( | ||
const double margin) : margin(margin) | ||
{ | ||
// Nothing to do here. | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
template<typename PredictionType, typename TargetType> | ||
typename PredictionType::elem_type | ||
TripletMarginLoss<InputDataType, OutputDataType>::Forward( | ||
const PredictionType& prediction, | ||
const TargetType& target) | ||
{ | ||
PredictionType anchor = prediction.submat(0, 0, prediction.n_rows / 2 - 1, prediction.n_cols - 1); | ||
PredictionType positive = prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1, | ||
prediction.n_cols - 1); | ||
return std::max(0.0, arma::accu(arma::pow(anchor - positive, 2)) - | ||
arma::accu(arma::pow(anchor - target, 2)) + margin) / anchor.n_cols; | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
template < | ||
typename PredictionType, | ||
typename TargetType, | ||
typename LossType | ||
> | ||
void TripletMarginLoss<InputDataType, OutputDataType>::Backward( | ||
const PredictionType& prediction, | ||
const TargetType& target, | ||
LossType& loss) | ||
{ | ||
PredictionType positive = prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1, | ||
prediction.n_cols - 1); | ||
loss = 2 * (target - positive) / target.n_cols; | ||
} | ||
|
||
template<typename InputDataType, typename OutputDataType> | ||
template<typename Archive> | ||
void TripletMarginLoss<InputDataType, OutputDataType>::serialize( | ||
Archive& ar, | ||
const unsigned int /* version */) | ||
{ | ||
ar(CEREAL_NVP(margin)); | ||
} | ||
|
||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters