Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding triplet margin loss function #2762

Merged
merged 13 commits into from
Dec 23, 2020
Merged
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* Add `BUILD_DOCS` CMake option to control whether Doxygen documentation is
built (default ON) (#2730).

* Add Triplet Margin Loss function (#2762).

### mlpack 3.4.2
###### 2020-10-26
* Added Mean Absolute Percentage Error.
Expand Down
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/loss_functions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ set(SOURCES
empty_loss_impl.hpp
mean_absolute_percentage_error.hpp
mean_absolute_percentage_error_impl.hpp
triplet_margin_loss.hpp
triplet_margin_loss_impl.hpp
)

# Add directory name to sources.
Expand Down
111 changes: 111 additions & 0 deletions src/mlpack/methods/ann/loss_functions/triplet_margin_loss.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/**
* @file triplet_margin_loss.hpp
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
* @author Prince Gupta
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
*
* Definition of the Triplet Margin Loss function.
*
* For more information, refer the following paper.
*
* @code
* @article{Schroff2015,
* author = {Florian Schroff, Dmitry Kalenichenko, James Philbin},
* title = {FaceNet: A Unified Embedding for Face Recognition and Clustering},
* year = {2015},
* url = {https://arxiv.org/abs/1503.03832},
* }
* @endcode
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_LOSS_HPP
#define MLPACK_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_LOSS_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann /** Artificial Neural Network. */ {

/**
* The TripletMarginLoss function's objective is that the distance from the
* anchor input to the positive input is minimized, and the distance from the
* anchor input to the negative input is maximized.
*
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
*/
template <
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
class TripletMarginLoss
{
public:
/**
* Create the TripletMarginLoss object with Hyperparameter margin.
* Hyperparameter margin defines the minimum value by which the distance
* between Anchor and Negative sample exceeds the distance between
* Anchor and Positive sample.
* The distance between two samples A and B is defined as square of L2 norm
* of A-B.
*/
TripletMarginLoss(const double margin = 1.0);
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Computes the Triplet Margin Loss function.
*
* @param input The propagated input activation. It should be
* concatenated anchor and positive samples.
* @param target The target vector. It should be negative samples.
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
*/
template<typename InputType, typename TargetType>
typename InputType::elem_type Forward(const InputType& input,
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
const TargetType& target);
/**
* Ordinary feed backward pass of a neural network.
*
* @param input The propagated input activation. It should be
* concatenated anchor and positive samples.
* @param target The target vector. It should be negative samples.
* @param output The calculated error.
*/
template<typename InputType, typename TargetType, typename OutputType>
void Backward(const InputType& input,
const TargetType& target,
OutputType& output);

//! Get the output parameter.
OutputDataType& OutputParameter() const { return outputParameter; }
//! Modify the output parameter.
OutputDataType& OutputParameter() { return outputParameter; }

//! Get the output parameter.
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
double Margin() const { return margin; }
//! Modify the output parameter.
double& Margin() { return margin; }

/**
* Serialize the layer.
*/
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);

private:
//! Locally-stored output parameter object.
OutputDataType outputParameter;

//! The margin value used in calculating Triplet Margin Loss.
double margin;
}; // class TripletLossMargin

} // namespace ann
} // namespace mlpack

// include implementation.
#include "triplet_margin_loss_impl.hpp"

#endif
70 changes: 70 additions & 0 deletions src/mlpack/methods/ann/loss_functions/triplet_margin_loss_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/**
* @file triplet_margin_loss_impl.hpp
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
* @author Prince Gupta
*
* Implementation of the Triplet Margin Loss function.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP
#define MLPACK_METHODS_ANN_LOSS_FUNCTION_TRIPLET_MARGIN_IMPL_LOSS_HPP

// In case it hasn't been included.
#include "triplet_margin_loss.hpp"

namespace mlpack {
namespace ann /** Artifical Neural Network. */ {

template<typename InputDataType, typename OutputDataType>
TripletMarginLoss<InputDataType, OutputDataType>::TripletMarginLoss(
const double margin) : margin(margin)
{
// Nothing to do here.
}

template<typename InputDataType, typename OutputDataType>
template<typename InputType, typename TargetType>
typename InputType::elem_type
TripletMarginLoss<InputDataType, OutputDataType>::Forward(
const InputType& input,
const TargetType& target)
{
InputType anchor = input.submat(0, 0, input.n_rows / 2 - 1, input.n_cols - 1);
InputType positive = input.submat(input.n_rows / 2, 0, input.n_rows - 1,
input.n_cols - 1);
return std::max(0.0, arma::accu(arma::pow(anchor - positive, 2)) -
arma::accu(arma::pow(anchor - target, 2)) + margin) / anchor.n_cols;
}

template<typename InputDataType, typename OutputDataType>
template <
typename InputType,
typename TargetType,
typename OutputType
>
void TripletMarginLoss<InputDataType, OutputDataType>::Backward(
const InputType& input,
const TargetType& target,
OutputType& output)
{
InputType positive = input.submat(input.n_rows / 2, 0, input.n_rows - 1,
input.n_cols - 1);
output = 2 * (target - positive) / target.n_cols;
}

template<typename InputDataType, typename OutputDataType>
template<typename Archive>
void TripletMarginLoss<InputDataType, OutputDataType>::serialize(
Archive& ar,
const unsigned int /* version */)
{
ar(CEREAL_NVP(margin));
}

} // namespace ann
} // namespace mlpack

#endif
51 changes: 51 additions & 0 deletions src/mlpack/tests/loss_functions_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <mlpack/methods/ann/loss_functions/l1_loss.hpp>
#include <mlpack/methods/ann/loss_functions/soft_margin_loss.hpp>
#include <mlpack/methods/ann/loss_functions/mean_absolute_percentage_error.hpp>
#include <mlpack/methods/ann/loss_functions/triplet_margin_loss.hpp>
#include <mlpack/methods/ann/init_rules/nguyen_widrow_init.hpp>
#include <mlpack/methods/ann/ffn.hpp>

Expand Down Expand Up @@ -897,3 +898,53 @@ TEST_CASE("MeanAbsolutePercentageErrorTest", "[LossFunctionsTest]")
REQUIRE(output.n_cols == input.n_cols);
CheckMatrices(output, expectedOutput, 0.1);
}

/*
* Simple test for the Triplet Margin Loss function.
*/
BOOST_AUTO_TEST_CASE(TripletMarginLossTest)
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
{
arma::mat anchor, positive, negative;
arma::mat input, target, output;
TripletMarginLoss<> module;

// Test the Forward function on a user generated input and compare it against
// the manually calculated result.
anchor = arma::mat("2 3 5");
positive = arma::mat("10 12 13");
negative = arma::mat("4 5 7");

input = { {2, 3, 5}, {10, 12, 13} };

double error = module.Forward(input, negative);
BOOST_REQUIRE_EQUAL(error, 66);
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved

// Test the Backward function.
module.Backward(input, negative, output);
// According to the used backward formula:
// output = 2 * (negative - positive) / anchor.n_cols,
// output * nofColumns / 2 + positive should be equal to negative.
CheckMatrices(negative, output * output.n_cols / 2 + positive);
BOOST_REQUIRE_EQUAL(output.n_rows, anchor.n_rows);
ayushsingh11 marked this conversation as resolved.
Show resolved Hide resolved
BOOST_REQUIRE_EQUAL(output.n_cols, anchor.n_cols);

// Test the error function on a single input.
anchor = arma::mat("4");
positive = arma::mat("7");
negative = arma::mat("1");

input = arma::mat(2, 1);
input[0] = 4;
input[1] = 7;

error = module.Forward(input, negative);
BOOST_REQUIRE_EQUAL(error, 1.0);

// Test the Backward function on a single input.
module.Backward(input, negative, output);
// Test whether the output is negative.
BOOST_REQUIRE_EQUAL(arma::accu(output), -12);
BOOST_REQUIRE_EQUAL(output.n_elem, 1);
}

BOOST_AUTO_TEST_SUITE_END();