Add Hinge Loss #409
Conversation
Hi!
Thanks a lot for working on this! I know this is still in draft mode, but I thought it would be useful to provide some high-level comments in the meantime:
Please provide unit tests for the function, wrapped in chex.all_variants and using self.variant to run the loss under the different jax transforms (jax.jit etc.). You can take a look at the other tests to see how that's done.

The list comprehensions with if statements don't work with jax.jit. Please write them in terms of batch operations using jax functions and operators.
Thanks a lot again!
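To make the second point concrete, here is a minimal sketch of replacing a Python-level comprehension with a batch operation. It uses plain numpy for illustration (the PR itself would use jax.numpy so the expression traces under jax.jit); the variable names are hypothetical, not from the PR:

```python
import numpy as np

predictions = np.array([0.5, -1.2, 2.0, -0.3])
targets = np.array([1.0, -1.0, -1.0, 1.0])

# Not jit-friendly: a Python-level comprehension with per-element control flow.
losses_listcomp = [max(0.0, 1.0 - p * t) for p, t in zip(predictions, targets)]

# jit-friendly equivalent: one vectorized batch operation.
# (With jax.numpy as jnp this would be jnp.maximum(0.0, 1.0 - predictions * targets).)
losses_batch = np.maximum(0.0, 1.0 - predictions * targets)

print(np.allclose(losses_listcomp, losses_batch))  # True
```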
Hi @mkunesch, a quick question: should I make two different functions for hinge loss (one for binary and one for multiclass) or coalesce them into one function?
Having taken a look at sklearn, the execution paths look quite different, so I think I'd separate them into two functions.
Thanks a lot! I added a more detailed review for the binary version - as I wrote earlier, let's merge this first and then consider the multi-class version separately.
Hi @mkunesch! I was just in the process of committing an updated version. A question, though: can I use black, or any other formatter, to format the code?
I hope I've added all the requested changes. I also ran
Thanks a lot! This is looking great! Just a few very minor comments.
Re the formatting: I am working on better instructions for this in the Contributing.md file, but your code looks great, so if you are ok with it you can just leave the formatting as is and I'll run the internal linter over it as it's being merged (it will only make very minor changes).

Also, just to say that we will only be able to merge this after the ICLR paper submission deadline on the 28th of September, but we can approve it and do everything else beforehand so that it can be merged immediately on the 29th.
Thanks a lot!
optax/_src/loss.py (Outdated)

    """Computes the hinge loss for binary classification.

    Args:
      predictor_outputs: Outputs of the decision function with shape [...].
Should we remove the bit about the shape? I don't think we have to write anything, but if we do, perhaps we could write explicitly that targets should be broadcastable to the same shape as predictor_outputs (I'm not sure this is clear from shape [...])?
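As an illustration of the broadcasting behaviour the reviewer suggests documenting, a small numpy sketch (the names, shapes, and values are hypothetical, not taken from the PR):

```python
import numpy as np

# predictor_outputs with shape (2, 3); targets with shape (3,) broadcasts against it.
predictor_outputs = np.array([[0.5, -1.0, 2.0],
                              [1.5, 0.0, -2.0]])
targets = np.array([1.0, -1.0, 1.0])  # broadcast to shape (2, 3)

# Binary hinge loss; the result takes the broadcasted shape.
loss = np.maximum(0.0, 1.0 - predictor_outputs * targets)
print(loss.shape)  # (2, 3)
```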
optax/_src/loss.py (Outdated)

      Target values should be strictly in the set {-1, 1}.

    Returns:
      Binary Hinge Loss with shape [...].
(as above, I'd probably just remove the bit about the shape here - the returned broadcasted shape could be a bit more complicated)
optax/_src/loss_test.py (Outdated)

    self.ys = np.array([1, 0, 1, 0, 0, 1, 0, 0, 1, 0])
    self.ts = np.array([-1, -1, -1, -1, -1, 1, -1, 1, -1, 1])
    # computed expected outputs.
    self.exp = np.maximum(0, 1 - self.ys * self.ts)
Nit: could you please rename self.exp to self.correct_result or something? I know the rest of the file uses exp too, but it is a confusing name given that exp usually means the exponential (we should really change the name in the rest of the file too, but we can do that in a separate PR).

Nit: if you could set the values rather than calculating them, that would be nicer (in terms of doing as little computation in the test as possible, to keep it isolated). Doing the calculation by hand once also helps ensure that all aspects (like the maximum behavior) are tested.
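One way to follow the second nit, sketched with hypothetical values: the expected hinge losses are computed by hand once and written out literally, rather than recomputed inside the test.

```python
import numpy as np

ys = np.array([-0.5, 1.2, -2.0, 0.3, 2.5])   # predictor outputs (arbitrary floats)
ts = np.array([-1.0, 1.0, 1.0, -1.0, 1.0])   # targets in {-1, 1}
# Hand-computed max(0, 1 - ys * ts) for each pair, set literally.
correct_result = np.array([0.5, 0.0, 3.0, 1.3, 0.0])

# Sanity check that the hand computation matches the formula.
assert np.allclose(correct_result, np.maximum(0.0, 1.0 - ys * ts))
```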
optax/_src/loss_test.py
Outdated
|
||
def setUp(self): | ||
super().setUp() | ||
self.ys = np.array([1, 0, 1, 0, 0, 1, 0, 0, 1, 0]) |
Do I understand correctly that ys, the predictor output, can be in the range (-Inf, Inf)? I think it would be important to test this, also as a way of documenting the function. In particular, I think your tests don't test the maximum behaviour, in that they would also pass for the implementation return delta - predictor_outputs * targets.
Thanks for the review! If I understand correctly, I should add a different example here that tests the maximum behavior?
Hi @mkunesch, can you please clarify this?
Yes: you could add a test for the maximum behaviour, or you could just change some numbers here so that the maximum behaviour is tested. Also, currently the ys are only 0 and 1, but if I understand correctly they can be any float in the range (-Inf, Inf) in practice, so the test input should reflect this.
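A sketch of what such test inputs might look like (numpy for illustration; the values are hypothetical): floats across the real line, chosen so that the clipping in the maximum actually fires and the broken implementation delta - predictor_outputs * targets gives a different answer.

```python
import numpy as np

delta = 1.0
ys = np.array([-2.6, 0.3, 1.9, -0.4])   # arbitrary floats, not just 0/1
ts = np.array([-1.0, 1.0, 1.0, -1.0])

correct = np.maximum(0.0, delta - ys * ts)   # hinge loss, clipped at zero
broken = delta - ys * ts                     # missing the maximum

# Where ys * ts > delta the two disagree, so the test would catch the bug.
print(np.allclose(correct, broken))  # False
```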
optax/_src/loss.py (Outdated)

      predictor_outputs: Outputs of the decision function.
      targets: Target values.
        Target values should be strictly in the set {-1, 1}.
      delta: Margin
Thanks for adding the section for delta! If there is a one-word description that better describes the parameter, I think we should actually make that the name. So I'd call this variable margin and add a one-sentence description of the argument here.

We should also avoid confusion: other code, e.g. sklearn, refers to predictor_outputs * targets as the margin. This will be clearer once you have added a description of what the margin means in your case.
@mkunesch Thanks for the suggestions, but I just realised that, like sklearn, most frameworks with loss functions actually only use 1 as the value for delta (or margin), so maybe it's better if we remove this parameter altogether and just set it to 1 in the function. That might also remove some potential for confusion if we were to go ahead and name it margin.
That sounds good - I think we can remove it. As you said that's in line with sklearn and pytorch. Thanks!
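For illustration, with delta removed and the margin fixed at 1, the binary hinge loss might look roughly like this sketch (numpy here; the PR itself would use jax.numpy, and the exact signature is an assumption):

```python
import numpy as np

def hinge_loss(predictor_outputs: np.ndarray, targets: np.ndarray) -> np.ndarray:
  """Binary hinge loss with the margin fixed at 1 (illustrative sketch).

  Targets are expected to be strictly in the set {-1, 1}.
  """
  return np.maximum(0.0, 1.0 - predictor_outputs * targets)

print(hinge_loss(np.array([0.5, -2.0]), np.array([1.0, 1.0])))  # [0.5 3. ]
```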
Looks great!
Could you please add hinge_loss to the API docs, similar to huber_loss (both the summary and the detailed docs)?
Also, the license check doesn't pass - you might need to make sure the commits use the right e-mail address or sign the contributor agreement again with the new email.
Other than that looks good to me!
Thanks!
I messed up the whole thing by committing without verifying from my MacBook. Will doing a new commit with my Gmail ID (which is CLA-signed) fix that, or do I need to close this PR and make a new one? Edit: Never mind, some StackOverflow and I fixed it. I hope it's ready to be merged :)
(Force-pushed from 5a76050 to 86cbd34.)
Hi! Thanks a lot for fixing it! I think it's ready to be merged, but the checks don't pass at the moment - could it be because you forgot to add the intended changes to commit 2d92948? (The description says that you changed the tolerance, but the commit only includes the removal of the binary file.)
@mkunesch Yes! No idea why the changes were not reflected here. Let me commit and push it now :)
The line is <= 80 characters with the line break removed.
Looks great - thanks a lot!
I might have to make some very minor formatting changes as I merge as the internal linter is slightly stricter but I'll take it from here!
Thanks again for this contribution!
* Need 2 blank lines before function def.
* Fixed indentation of the function arguments.
* Auto-formatting changes. Start comment with capital letter.
Hi, this PR intends to add Hinge Loss to Optax as referenced in #403.