Use optax's losses and schedules in mnist and imagenet flax examples. #1286
copybara-service[bot] merged 1 commit into master.
Conversation
Force-pushed 917666c to 7e61722.
@mtthss let's review here. LGTM. I also asked @andsteing to review, as he's the person with the most holistic view of our examples (e.g. should we make sure to port all of our examples to use Optax losses and schedules?)
(Oh, and thanks!)
Force-pushed 7e61722 to 4fdbe9b.
It seems the latest release of Optax doesn't yet support …
Thanks @avital 👍 (Note to self: update the Flax Linen with MNIST tutorial to use a loss from Optax.)
Releasing a new version now |
andsteing left a comment:
Thanks for the change!
(And sorry for my delayed reply...)
examples/imagenet/requirements.txt (Outdated)

@@ -1,4 +1,5 @@
 clu==0.0.1a2
 ml-collections>=0.1.0
+optax
Maybe specify a version that includes optax.linear_schedule()?
warmup_fn = optax.linear_schedule(
    init_value=0., end_value=base_learning_rate,
    transition_steps=config.warmup_epochs * steps_per_epoch)
cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1)
Why is this max(..., 1) and not max(..., 0)?
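For context, the warmup-then-cosine schedule under discussion can be sketched in pure Python. This is a minimal sketch of the assumed semantics of optax.linear_schedule and optax.cosine_decay_schedule, not the library code; note that the cosine phase divides by its step count, which is one reason clamping it to at least 1 (rather than 0) is needed:

```python
import math

def linear_schedule(init_value, end_value, transition_steps):
    # Sketch of optax.linear_schedule semantics (assumed): a linear ramp,
    # clipped so steps outside [0, transition_steps] saturate at the endpoints.
    def schedule(step):
        frac = min(max(step / transition_steps, 0.0), 1.0)
        return init_value + frac * (end_value - init_value)
    return schedule

def cosine_decay_schedule(init_value, decay_steps):
    # Sketch of optax.cosine_decay_schedule semantics (assumed): half-cosine
    # decay from init_value down to 0 over decay_steps. decay_steps must be
    # >= 1 or the division below is undefined, which motivates max(..., 1).
    def schedule(step):
        frac = min(step / decay_steps, 1.0)
        return init_value * 0.5 * (1.0 + math.cos(math.pi * frac))
    return schedule

# Hypothetical numbers for illustration: 100 warmup steps, 900 decay steps.
warmup = linear_schedule(init_value=0.0, end_value=0.1, transition_steps=100)
cosine = cosine_decay_schedule(init_value=0.1, decay_steps=900)
```

In the real examples, the two phases would be stitched together (optax provides join_schedules for this) so warmup runs first and the cosine decay takes over afterwards.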
examples/imagenet/train.py (Outdated)

-  return -jnp.sum(
-      common_utils.onehot(labels, num_classes=1000) * logits) / labels.size
+  xentropy = optax.softmax_cross_entropy(
+      logits, common_utils.onehot(labels, num_classes=NUM_CLASSES))
Could you specify these by keyword?
(It's all too easy to invert them, like thinking of H(P, Q) with P ~ labels and Q ~ softmax(logits).)
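The asymmetry the reviewer is pointing at can be made concrete with a pure-Python sketch of the loss semantics (an illustration, not optax's implementation — and the keyword-only `*` marker below is my addition for the sake of the argument, not part of optax's actual signature): H(P, Q) is not symmetric, so silently swapping logits and labels gives a wrong answer, while keyword arguments catch the mix-up.

```python
import math

def softmax_cross_entropy(*, logits, labels):
    # Sketch of the assumed loss semantics: labels is a one-hot (or soft)
    # distribution P, logits parameterize Q = softmax(logits); returns H(P, Q).
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return -sum(p * (x - log_z) for p, x in zip(labels, logits))

loss = softmax_cross_entropy(logits=[2.0, 0.0], labels=[1.0, 0.0])
# A positional call such as softmax_cross_entropy([2.0, 0.0], [1.0, 0.0])
# raises TypeError here, which is exactly the safety keywords buy you.
```

For this two-class case the loss reduces to log(1 + exp(-2)), the usual logistic loss on the logit gap.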
examples/imagenet/train.py (Outdated)

 def cross_entropy_loss(logits, labels):
-  return -jnp.sum(
-      common_utils.onehot(labels, num_classes=1000) * logits) / labels.size
+  xentropy = optax.softmax_cross_entropy(
In imagenet/models.py we return nn.log_softmax(x):
flax/examples/imagenet/models.py, line 117 in d804b90
In terms of numerics, do you know if it makes any difference computing that twice?
Should we remove it from models.py in any case?
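On the numerics question: log_softmax is idempotent, because its output already has log-sum-exp equal to zero, so the second normalization subtracts log(1) = 0 and returns the same values. Applying it twice is therefore wasted compute rather than a numerical hazard. A quick pure-Python check (an illustrative sketch, not the jax implementation):

```python
import math

def log_softmax(xs):
    # Numerically stable log-softmax: subtract the max, then log-sum-exp.
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]

once = log_softmax([1.0, 2.0, 3.0])
twice = log_softmax(once)
# exp(once) sums to 1, so the second pass subtracts log(1) = 0: twice == once.
```

That said, removing the extra log_softmax from models.py still seems worthwhile, since optax.softmax_cross_entropy expects raw logits and applying the normalization twice obscures that contract.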
examples/mnist/requirements.txt (Outdated)

@@ -1,6 +1,7 @@
 clu
 flax
+optax
Should we specify a version?
 def compute_metrics(logits, labels):
-  loss = cross_entropy_loss(logits, labels)
+  loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot(labels)))
(Same comments as above: should we maybe remove nn.log_softmax(x)? And would it make sense to specify arguments by keyword for extra safety?)
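For reference, the onehot helper used in the diff above can be sketched as follows (the assumed semantics of the example's common_utils.onehot, in pure Python for illustration):

```python
def onehot(labels, num_classes):
    # Map integer class labels to one-hot rows, e.g. label 1 with
    # num_classes=3 becomes [0.0, 1.0, 0.0].
    return [[1.0 if j == label else 0.0 for j in range(num_classes)]
            for label in labels]
```

With labels encoded this way, optax.softmax_cross_entropy reduces to picking out the log-probability of the true class for each row.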
Force-pushed 4fdbe9b to 2e04d64, then to 162369d, then to 1078af5.
Addressed comments.
andsteing left a comment:
There are still two minor comments open but otherwise LGTM.
 def compute_metrics(logits, labels):
-  loss = cross_entropy_loss(logits, labels)
+  loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot(labels)))
Could you also use keyword arguments here?
Force-pushed 353fab2 to 51e59fc, 041147f to ba04ee7, and 5246e4e to b0a274b.
Codecov Report

@@           Coverage Diff           @@
##           master    #1286   +/-   ##
=======================================
  Coverage   82.34%   82.34%
=======================================
  Files          65       65
  Lines        5318     5318
=======================================
  Hits         4379     4379
  Misses        939      939
=======================================

Continue to review the full report at Codecov.
Force-pushed e4d0282 to 0fc721a.
PiperOrigin-RevId: 378640649
Force-pushed 0fc721a to 44ee6f2.
Use optax's losses and schedules in mnist and imagenet flax examples.