Add support for PolyLoss (ICLR 2022) #467
Conversation
Thanks a lot for this contribution - this looks great! Thanks especially for adding the maths and recommended parameters.
I've added a few comments. I'll do another pass to compare the maths to the paper later.
optax/_src/loss.py
Outdated
logits: Unnormalized log probabilities, with shape `[..., num_classes]`.
labels: Valid probability distributions (non-negative, sum to 1), e.g a
  one hot encoding specifying the correct class for each input;
  must have a shape broadcastable to `[..., num_classes]``
Typo: ` instead of . at the end.
optax/_src/loss.py
Outdated
labels: Valid probability distributions (non-negative, sum to 1), e.g a
  one hot encoding specifying the correct class for each input;
  must have a shape broadcastable to `[..., num_classes]``
epsilon: The coefficient of the first polynomial term (default = 2.0).
Nit: Let's remove the default value from here as it is already visible to the user in the function signature.
optax/_src/loss.py
Outdated
- For the ImageNet 2d image classification, epsilon = 2.0
- For the 2d Instance Segmentation and object detection, epsilon = -1.0
- It is also recommended to adjust this value
  based on the task and dataset at hand. For example, one can use
Nit: let's make this slightly more concise, e.g.:
"... adjust this value based on the task, e.g. using grid search."
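As an aside, the grid search mentioned here can be very simple. The sketch below is purely illustrative and not part of the PR: `validation_score` is a hypothetical stand-in for "train with this epsilon, evaluate on a validation set", replaced by a dummy curve so the snippet runs on its own.

```python
import numpy as np

def validation_score(epsilon):
    # Hypothetical stand-in for training a model with this epsilon and
    # returning its validation accuracy; here a dummy concave curve
    # peaking at epsilon = 2.0, purely for illustration.
    return -(epsilon - 2.0) ** 2

# Grid of candidate epsilon values with a step of 0.5.
candidates = np.linspace(-1.0, 3.0, 9)
best_epsilon = max(candidates, key=validation_score)
```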
optax/_src/loss.py
Outdated
  must have a shape broadcastable to `[..., num_classes]``
epsilon: The coefficient of the first polynomial term (default = 2.0).
  According to the paper, the following values are recommended:
  - For the ImageNet 2d image classification, epsilon = 2.0
Nit: Let's add a full stop at the end of each bullet.
optax/_src/loss.py
Outdated
- It is also recommended to adjust this value
  based on the task and dataset at hand. For example, one can use
  simple grid search to achieve it.
alpha: The smoothing factor, the greedy category with be assigned
I don't understand this part "the greedy category with be assigned" ... is there a typo and it should be "will be"?
My bad, I copy-pasted the description from here without going through it. Since I removed the smoothing itself, this issue is now fixed in this PR. Should I submit a separate PR to fix this line for the smooth_labels as well? It looks like everything except for the "alpha: The smoothing factor" part can be removed.
Since I removed the smoothing itself, this issue is now fixed in this PR. Should I submit a separate PR to fix this line for the smooth_labels as well?
Yes, please, that would be great - thanks a lot! I also saw some other style guide violations in the docstrings above (e.g. capitalization) - if you have time to fix these too in the same or a different PR that would be great but no worries if not of course!
optax/_src/loss.py
Outdated
  probability `(1-alpha) + alpha / num_categories` (default = 0.0)

Returns:
  poly loss between each prediction and the corresponding target
Please capitalize the first letter: Poly loss.
optax/_src/loss_test.py
Outdated
)
def test_equals_to_cross_entropy_when_eps0(self, logits, labels):
  np.testing.assert_allclose(
      self.variant(loss.poly_loss_cross_entropy)(logits, labels, 0., 0.),
Nit: please use a keyword argument for epsilon here so that it is immediately obvious to the reader that epsilon is being set to zero (you can then just leave alpha at the default or set it using a kwarg too - or we might remove it anyway).
optax/_src/loss_test.py
Outdated

@chex.all_variants
@parameterized.parameters(
    dict(eps=2, alpha=0, expected=4.531657285679147),
Since we are testing to atol 1e-4 we can remove some of the figures from the expected results.
Thanks for the review, I updated the code!
Hi! Sorry for dropping the ball on this! I've added a few small nits, but we can also address those during merging so I've also approved.
Thanks a lot!
optax/_src/loss_test.py
Outdated
dict(eps=0, expected=2.8990),
dict(eps=-0.5, expected=2.4908),
dict(eps=1.15, expected=3.8378),
dict(eps=2, expected=4.5317),
I think this is now a duplicate of line 514.
optax/_src/loss_test.py
Outdated
dict(eps=0, expected=np.array([0.1698, 0.8247])),
dict(eps=-0.5, expected=np.array([0.0917, 0.7168])),
dict(eps=1.15, expected=np.array([0.3495, 1.0731])),
dict(eps=2, expected=np.array([0.4823, 1.2567])),
(as above, this is now a duplicate of the first parameter combination)
optax/_src/loss.py
Outdated
\frac{1}{N + 1} \cdot (1 - P_t)^{N + 1} + \ldots = \\
- \log(P_t) + \sum_{j = 1}^N \epsilon_j \cdot (1 - P_t)^j

This function provides a simplified version of the :math:`L_{Poly-N}`
typo/nit: I'd personally remove the "the" before :math:`L_{Poly-N}`, i.e.
"This function provides a simplified version of :math:`L_{Poly-N}`".
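For readers following the thread, the Poly-1 simplification discussed above can be sketched outside optax as follows. This is a hedged illustration of the formula, not the PR's final code; the function and variable names here are mine.

```python
import jax
import jax.numpy as jnp

def poly1_cross_entropy(logits, labels, epsilon=2.0):
    """Illustrative Poly-1 loss: cross entropy plus the first
    polynomial correction term epsilon * (1 - P_t)."""
    # P_t: probability mass the model assigns to the target labels.
    p_t = jnp.sum(labels * jax.nn.softmax(logits, axis=-1), axis=-1)
    cross_entropy = -jnp.sum(
        labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
    return cross_entropy + epsilon * (1.0 - p_t)
```

With `epsilon=0.0` this reduces exactly to the softmax cross entropy, which is what the PR's equivalence test checks.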
optax/_src/loss.py
Outdated

Args:
  logits: Unnormalized log probabilities, with shape `[..., num_classes]`.
  labels: Valid probability distributions (non-negative, sum to 1), e.g a
Typo: `e.g` -> `e.g.`. This is also incorrect in the softmax_cross_entropy - I can correct it there or you can include it in this PR too.
optax/_src/loss.py
Outdated
According to the paper, the following values are recommended:
- For the ImageNet 2d image classification, epsilon = 2.0.
- For the 2d Instance Segmentation and object detection, epsilon = -1.0.
- It is also recommended to adjust this value
Nit: this line break looks early to me, let's break as close to 80 characters as possible.
optax/_src/loss.py
Outdated
cross_entropy = softmax_cross_entropy(logits=logits, labels=labels)
poly_loss = cross_entropy + epsilon * one_minus_pt

return poly_loss
Nit: let's return the result of the calculation directly in the line above - in this case there is no readability reason to assign to a named variable since the function name and docstring already document what is being returned.
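Applied to the snippet above, that suggestion amounts to the following sketch. `softmax_cross_entropy` and `one_minus_pt` are reconstructed from context here rather than copied from the PR, so treat this as an illustration.

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
    # Per-example cross entropy for probability-distribution labels.
    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)

def poly_loss_cross_entropy(logits, labels, epsilon=2.0):
    # For labels that sum to 1, this equals 1 - P_t.
    one_minus_pt = jnp.sum(
        labels * (1.0 - jax.nn.softmax(logits, axis=-1)), axis=-1)
    # Return the expression directly; the function name and docstring
    # already document what is being returned.
    return (softmax_cross_entropy(logits=logits, labels=labels)
            + epsilon * one_minus_pt)
```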
@mkunesch addressed the comments. The tests seem to be failing locally though, but it looks like the problem is in an unrelated part of the codebase. I'll do another pass in a separate PR to fix the style in other docstrings, if that's ok.
Hi! Thanks a lot for the changes and sorry for the delay!
Yes, that was an unrelated problem. We've fixed that error now - could you sync with master? That should make the checks pass so that we can merge.
Sounds good, thanks a lot!
Hi, no worries!
Hi @mkunesch, thanks for the review! Should I also update https://github.com/deepmind/optax/blob/master/docs/api.rst#L494 and https://github.com/deepmind/optax/blob/master/optax/__init__.py#L96?
Ah, yes please - I forgot to check that in my review. Thanks a lot!
Added here: #537
Closes #466
I focused on implementing the Poly-1 cross-entropy loss with alpha label smoothing since, as shown in the paper, the first term plays the most important role. I also crafted a small test showing that the behavior is indeed the same as that of softmax_cross_entropy when \epsilon = 0.

I didn't add it to the optax/__init__.py and docs/api.rst files yet (would it be ok to do that in a separate PR?). Please let me know what you think!