Skip to content

Commit

Permalink
Merge pull request #2762 from ayushsingh11/master
Browse files Browse the repository at this point in the history
Adding triplet margin loss function.
  • Loading branch information
zoq committed Dec 23, 2020
2 parents 14e1201 + 0e6e82f commit 7da22c5
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 0 deletions.
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* Add `BUILD_DOCS` CMake option to control whether Doxygen documentation is
built (default ON) (#2730).

* Add Triplet Margin Loss function (#2762).

* Add finalizers to Julia binding model types to fix memory handling (#2756).

### mlpack 3.4.2
Expand Down
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/loss_functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ set(SOURCES
empty_loss_impl.hpp
mean_absolute_percentage_error.hpp
mean_absolute_percentage_error_impl.hpp
triplet_margin_loss.hpp
triplet_margin_loss_impl.hpp
)

# Add directory name to sources.
Expand Down
111 changes: 111 additions & 0 deletions src/mlpack/methods/ann/loss_functions/triplet_margin_loss.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/**
* @file methods/ann/loss_functions/triplet_margin_loss.hpp
* @author Prince Gupta
* @author Ayush Singh
*
* Definition of the Triplet Margin Loss function.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_LOSS_HPP
#define MLPACK_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_LOSS_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
* The Triplet Margin Loss performance function measures the network's
* performance according to the relative distance from the anchor input
* of the positive (truthy) and negative (falsy) inputs.
* The distance between two samples A and B is defined as square of L2 norm
* of A-B.
*
* For more information, refer the following paper.
*
* @code
* @article{Schroff2015,
* author = {Florian Schroff, Dmitry Kalenichenko, James Philbin},
* title = {FaceNet: A Unified Embedding for Face Recognition and Clustering},
* year = {2015},
* url = {https://arxiv.org/abs/1503.03832},
* }
* @endcode
*
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
*/
template <
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
class TripletMarginLoss
{
public:
/**
* Create the TripletMarginLoss object.
*
* @param margin The minimum value by which the distance between
* Anchor and Negative sample exceeds the distance
* between Anchor and Positive sample.
*/
TripletMarginLoss(const double margin = 1.0);

/**
* Computes the Triplet Margin Loss function.
*
* @param prediction Concatenated anchor and positive sample.
* @param target The negative sample.
*/
template<typename PredictionType, typename TargetType>
typename PredictionType::elem_type Forward(const PredictionType& prediction,
const TargetType& target);
/**
* Ordinary feed backward pass of a neural network.
*
* @param prediction Concatenated anchor and positive sample.
* @param target The negative sample.
* @param loss The calculated error.
*/
template<typename PredictionType, typename TargetType, typename LossType>
void Backward(const PredictionType& prediction,
const TargetType& target,
LossType& loss);

//! Get the output parameter.
OutputDataType& OutputParameter() const { return outputParameter; }
//! Modify the output parameter.
OutputDataType& OutputParameter() { return outputParameter; }

//! Get the value of margin.
double Margin() const { return margin; }
//! Modify the value of margin.
double& Margin() { return margin; }

/**
* Serialize the layer.
*/
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);

private:
//! Locally-stored output parameter object.
OutputDataType outputParameter;

//! The margin value used in calculating Triplet Margin Loss.
double margin;
}; // class TripletLossMargin

} // namespace ann
} // namespace mlpack

// include implementation.
#include "triplet_margin_loss_impl.hpp"

#endif
71 changes: 71 additions & 0 deletions src/mlpack/methods/ann/loss_functions/triplet_margin_loss_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/**
* @file methods/ann/loss_functions/triplet_margin_loss_impl.hpp
* @author Prince Gupta
* @author Ayush Singh
*
* Implementation of the Triplet Margin Loss function.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP
#define MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP

// In case it hasn't been included.
#include "triplet_margin_loss.hpp"

namespace mlpack {
namespace ann /** Artifical Neural Network. */ {

template<typename InputDataType, typename OutputDataType>
TripletMarginLoss<InputDataType, OutputDataType>::TripletMarginLoss(
const double margin) : margin(margin)
{
// Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
template<typename PredictionType, typename TargetType>
typename PredictionType::elem_type
TripletMarginLoss<InputDataType, OutputDataType>::Forward(
const PredictionType& prediction,
const TargetType& target)
{
PredictionType anchor = prediction.submat(0, 0, prediction.n_rows / 2 - 1, prediction.n_cols - 1);
PredictionType positive = prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1,
prediction.n_cols - 1);
return std::max(0.0, arma::accu(arma::pow(anchor - positive, 2)) -
arma::accu(arma::pow(anchor - target, 2)) + margin) / anchor.n_cols;
}

template<typename InputDataType, typename OutputDataType>
template <
typename PredictionType,
typename TargetType,
typename LossType
>
void TripletMarginLoss<InputDataType, OutputDataType>::Backward(
const PredictionType& prediction,
const TargetType& target,
LossType& loss)
{
PredictionType positive = prediction.submat(prediction.n_rows / 2, 0, prediction.n_rows - 1,
prediction.n_cols - 1);
loss = 2 * (target - positive) / target.n_cols;
}

template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void TripletMarginLoss<InputDataType, OutputDataType>::serialize(
Archive& ar,
const unsigned int /* version */)
{
ar(CEREAL_NVP(margin));
}

} // namespace ann
} // namespace mlpack

#endif
49 changes: 49 additions & 0 deletions src/mlpack/tests/loss_functions_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <mlpack/methods/ann/loss_functions/l1_loss.hpp>
#include <mlpack/methods/ann/loss_functions/soft_margin_loss.hpp>
#include <mlpack/methods/ann/loss_functions/mean_absolute_percentage_error.hpp>
#include <mlpack/methods/ann/loss_functions/triplet_margin_loss.hpp>
#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
#include <mlpack/methods/ann/ffn.hpp>

Expand Down Expand Up @@ -897,3 +898,51 @@ TEST_CASE("MeanAbsolutePercentageErrorTest", "[LossFunctionsTest]")
REQUIRE(output.n_cols == input.n_cols);
CheckMatrices(output, expectedOutput, 0.1);
}

/*
* Simple test for the Triplet Margin Loss function.
*/
TEST_CASE("TripletMarginLossTest")
{
arma::mat anchor, positive, negative;
arma::mat input, target, output;
TripletMarginLoss<> module;

// Test the Forward function on a user generated input and compare it against
// the manually calculated result.
anchor = arma::mat("2 3 5");
positive = arma::mat("10 12 13");
negative = arma::mat("4 5 7");

input = { {2, 3, 5}, {10, 12, 13} };

double loss = module.Forward(input, negative);
REQUIRE(loss == 66);

// Test the Backward function.
module.Backward(input, negative, output);
// According to the used backward formula:
// output = 2 * (negative - positive) / anchor.n_cols,
// output * nofColumns / 2 + positive should be equal to negative.
CheckMatrices(negative, output * output.n_cols / 2 + positive);
REQUIRE(output.n_rows == anchor.n_rows);
REQUIRE(output.n_cols == anchor.n_cols);

// Test the loss function on a single input.
anchor = arma::mat("4");
positive = arma::mat("7");
negative = arma::mat("1");

input = arma::mat(2, 1);
input[0] = 4;
input[1] = 7;

loss = module.Forward(input, negative);
REQUIRE(loss == 1.0);

// Test the Backward function on a single input.
module.Backward(input, negative, output);
// Test whether the output is negative.
REQUIRE(arma::accu(output) == -12);
REQUIRE(output.n_elem == 1);
}

0 comments on commit 7da22c5

Please sign in to comment.