-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added Triplet margin loss function. #2208
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implementation looks fine, I think positiveType and negativeType would always be same to support matrix operations however I'm not sure at this point. I'll take a closer look at the code first thing tomorrow. Nice work.
OutputDataType& OutputParameter() { return outputParameter; } | ||
|
||
//! Get the output parameter. | ||
double& Margin() const { return margin; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There wont be & here. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some small style issues. Rest of it looks great. Thanks for all the work.
* @param input The propagated input activation. | ||
* @param target The target vector. | ||
*/ | ||
template<typename AnchorType, typename PositiveType, typename NegativeType> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should align with * from the comments. This also true in other places as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep it like this only, no need to reduce templates, I checked couple of other programs they use different template even though they must have same datatype.
const PositiveType&& positive, | ||
const NegativeType&& negative, | ||
OutputType&& output | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you place this bracket in line 50.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also resolve this.
Archive& ar, | ||
const unsigned int /* version */) | ||
{ | ||
ar & BOOST_SERIALIZATION_NVP(margin); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, most people missed that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rest of it looks fine. However a member will provide better insight. Thanks.
/** | ||
* Create the TripletMarginLoss object with Hyperparameter margin. | ||
* Hyperparameter margin defines the minimum numeric value by which distance | ||
* between Anchor and Negative sample should be higher than the distance between |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you try doing the following :
- minimum value by which the distance between Anchor and Negative sample must exceed the distance between ...
This would increase readability.
Or if you come up with something that is better worded use that.
const PositiveType&& positive, | ||
const NegativeType&& negative, | ||
OutputType&& output | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also resolve this.
This code touches the ANN code, and there was recently a big refactoring of the ANN code in #2259, so be sure to merge the master branch into your branch here to make sure that nothing will fail if this PR is merged. 👍 (I'm pasting this message into all possibly relevant PRs.) |
anchor = arma::mat("2 3 5"); | ||
positive = arma::mat("10 12 13"); | ||
negative = arma::mat("4 5 7"); | ||
double error = module.Forward(std::move(anchor), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need std::move anywhere in this file anymore
arma::mat anchor, positive, negative, output; | ||
TripletMarginLoss<> module; | ||
|
||
// Test the Forward function on a user generator input and compare it against |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
user generated
Hey Prince, There is a change (maybe), Instead of taking three inputs what do you think about taking two concatenated inputs? |
Hey there, This code touches the loss function portion of mlpack's codebase. Recently #2339 was merged where the switch was made to templated return type instead of just using double. So before we merge this, kindly make the changes for the same. I'm pasting this in all relevant PRs. Thanks a lot 👍 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also add to HISTORY.md? Thanks for keeping up with all the repo changes by the way.
namespace ann /** Artificial Neural Network. */ { | ||
|
||
/** | ||
* The TripletMarginLoss function's objective is that the distance between |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I can understand this. Perhaps you could take cues from the description from Wikipedia or something?
* For more information, see the following paper. | ||
* | ||
* @code | ||
* @article{Janocha2017 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems you have put down the same paper twice
positive = arma::mat("10 12 13"); | ||
negative = arma::mat("4 5 7"); | ||
|
||
input = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can put this in a single line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did it so it is easy to see it's a matrix of dimensions 2x3. I can change that if you say.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these will be my last comments before approval. Could you merge and correct HISTORY.md?
* Computes the Triplet Margin Loss function. | ||
* | ||
* @param input The propagated input activation. It should be | ||
* concatenated anchor and positive samples. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you'll need to tab here
* Ordinary feed backward pass of a neural network. | ||
* | ||
* @param input The propagated input activation. It should be | ||
* concatenated anchor and positive samples. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll need to tab here as well
const InputType& input, | ||
const TargetType& target) | ||
{ | ||
arma::mat anchor = input.submat(0, 0, input.n_rows / 2 - 1, input.n_cols - 1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct me if I'm wrong, but will this not have to be InputType
in case the input is arma::fmat
?
This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions! 👍 |
Hi @prince776 and @favre49, I would like to contribute in addition of Triplet margin loss to the repository. |
Hi, I've added Triplet margin loss function in reference to issue #2200 . Please tell me if any changes are required.
Since triplet loss takes three matrices: anchor, positive and negative, its forward and backward functions are slightly different than other previous functions which took only input and targe.