
Use optax's losses and schedules in mnist and imagenet flax examples.#1286

Merged
copybara-service[bot] merged 1 commit into master from test_370893380 on Jun 10, 2021

Conversation

@copybara-service

Use optax's losses and schedules in mnist and imagenet flax examples.

@avital
Contributor

avital commented Apr 30, 2021

@mtthss let's review here. LGTM, I also asked @andsteing to review as he's the person with the most holistic view on our examples (e.g. should we make sure to port all of our example to use Optax losses and schedules?)

@avital self-requested a review on April 30, 2021
@avital
Contributor

avital commented Apr 30, 2021

(Oh and thanks!)

@marcvanzee
Contributor

It seems the latest release of Optax doesn't yet support linear_schedule (and perhaps other things you are using here), so could you please release a new PyPI version?

@8bitmp3
Contributor

8bitmp3 commented May 3, 2021

Thanks @avital 👍 (Note to self: update the Flax Linen with MNIST tutorial to use a loss from Optax)

@mtthss
Contributor

mtthss commented May 4, 2021

> It seems the latest release of Optax doesn't yet support linear_schedule (and perhaps other things you are using here), so could you please release a new PyPI version?

Releasing a new version now

@marcvanzee added the "Priority: P2 - no schedule" label (best effort response and resolution; no plan to work on this at the moment) on May 4, 2021

@andsteing left a comment


Thanks for the change!

(And sorry for my delayed reply...)

@@ -1,4 +1,5 @@
clu==0.0.1a2
ml-collections>=0.1.0
optax

Maybe specify a version that includes optax.linear_schedule()?


Done

warmup_fn = optax.linear_schedule(
    init_value=0., end_value=base_learning_rate,
    transition_steps=config.warmup_epochs * steps_per_epoch)
cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1)

Why is this max(..., 1) and not max(..., 0)?
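For context on this question: the example joins a linear warmup with a cosine decay. A plain-Python sketch of the same logic (the helper name and defaults are ours; it only mirrors the semantics of optax.linear_schedule, optax.cosine_decay_schedule, and optax.join_schedules) shows why max(..., 1) matters: it keeps the cosine phase's step count positive, so the decay never divides by zero when warmup_epochs >= num_epochs.

```python
import math

def create_learning_rate_fn(base_lr, warmup_epochs, num_epochs, steps_per_epoch):
    """Sketch of linear warmup followed by cosine decay (not optax itself)."""
    warmup_steps = warmup_epochs * steps_per_epoch
    # max(..., 1) keeps the cosine phase at least one epoch long, so
    # decay_steps below can never be 0 (which would divide by zero).
    cosine_epochs = max(num_epochs - warmup_epochs, 1)
    decay_steps = cosine_epochs * steps_per_epoch

    def schedule(step):
        if step < warmup_steps:
            # Linear ramp from 0 to base_lr over the warmup phase.
            return base_lr * step / warmup_steps
        # Cosine decay from base_lr down to 0 over the remaining steps.
        t = min((step - warmup_steps) / decay_steps, 1.0)
        return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))

    return schedule
```

With max(..., 0) the degenerate case num_epochs == warmup_epochs would make decay_steps zero; clamping at 1 simply gives that edge case a one-epoch decay instead of a crash.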

  return -jnp.sum(
      common_utils.onehot(labels, num_classes=1000) * logits) / labels.size
xentropy = optax.softmax_cross_entropy(
    logits, common_utils.onehot(labels, num_classes=NUM_CLASSES))

Could you specify these by keyword?
(It's all too easy inverting them, like thinking of H(P, Q) with P~labels and Q~softmax(logits))


Done
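As a sanity check on the argument order being discussed, here is a plain-NumPy sketch of what optax.softmax_cross_entropy computes (the helper is ours, not optax's; we make the arguments keyword-only to illustrate why keywords prevent the H(P, Q) mix-up):

```python
import numpy as np

def softmax_cross_entropy(*, logits, labels):
    """Per-example cross entropy between one-hot `labels` and softmax(`logits`).

    Keyword-only arguments make it impossible to silently swap the two,
    which is the failure mode the review comment is guarding against.
    """
    m = logits.max(axis=-1, keepdims=True)
    log_z = np.log(np.exp(logits - m).sum(axis=-1, keepdims=True)) + m
    log_probs = logits - log_z                 # log_softmax(logits)
    return -(labels * log_probs).sum(axis=-1)  # H(labels, softmax(logits))

logits = np.array([[1.0, 2.0, 3.0]])
labels = np.array([[0.0, 0.0, 1.0]])           # one-hot, class 2
loss = softmax_cross_entropy(logits=logits, labels=labels)
```

Calling with positional arguments swapped would quietly compute a different (and meaningless) quantity, which is exactly why spelling out logits=... and labels=... is worth the extra characters.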

def cross_entropy_loss(logits, labels):
  return -jnp.sum(
      common_utils.onehot(labels, num_classes=1000) * logits) / labels.size
xentropy = optax.softmax_cross_entropy(

In imagenet/models.py we return nn.log_softmax(x):

x = nn.log_softmax(x)

In terms of numerics, do you know if it makes any difference computing that twice?

Should we remove it from models.py in any case?


Removed nn.log_softmax(x)
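On the numerics question above: log_softmax is mathematically idempotent (its outputs are already normalized log-probabilities, so their logsumexp is 0), which means applying it twice mostly wastes compute rather than changing values. A quick NumPy check, with our own helper rather than the flax one:

```python
import numpy as np

def log_softmax(x):
    # Shift by the max for numerical stability, then subtract logsumexp.
    m = x.max(axis=-1, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=-1, keepdims=True))

x = np.array([1.0, 2.0, 3.0])
once = log_softmax(x)
twice = log_softmax(once)
# logsumexp(once) == log(sum(softmax(x))) == log(1) == 0, so `twice`
# matches `once` up to floating point: the second application is a no-op.
```

So removing nn.log_softmax(x) from models.py is the cleaner fix: it avoids the redundant pass and keeps the model returning raw logits, which is what optax.softmax_cross_entropy expects.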

@@ -1,6 +1,7 @@
clu
flax
optax

should we specify version?


Done


def compute_metrics(logits, labels):
  loss = cross_entropy_loss(logits, labels)
  loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot(labels)))

(Same comments as above: should we maybe remove nn.log_softmax(x)? And would it make sense to specify arguments by keyword for extra safety?)


Removed nn.log_softmax(x)

@mtthss
Copy link
Contributor

mtthss commented May 14, 2021

Addressed comments.


@andsteing left a comment


There are still two minor comments open but otherwise LGTM.


def compute_metrics(logits, labels):
  loss = cross_entropy_loss(logits, labels)
  loss = jnp.mean(optax.softmax_cross_entropy(logits, onehot(labels)))

Could you also use keyword arguments here?


Done
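One detail worth noting in the line under review: optax.softmax_cross_entropy returns one loss value per example, which is why the mnist code wraps it in jnp.mean. A NumPy sketch of that reduction (helper names are ours, mirroring the example's onehot and the optax loss semantics):

```python
import numpy as np

def onehot(labels, num_classes=10):
    # Integer class ids -> one-hot rows, as the mnist example's helper does.
    return np.eye(num_classes)[labels]

def softmax_cross_entropy(*, logits, labels):
    # Returns shape (batch,): one unreduced loss per example, like optax.
    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    return -(labels * log_probs).sum(axis=-1)

logits = np.zeros((4, 10))  # all-zero logits == uniform predictions
per_example = softmax_cross_entropy(logits=logits,
                                    labels=onehot(np.array([0, 1, 2, 3])))
loss = per_example.mean()   # the jnp.mean in compute_metrics does this reduction
# uniform logits over 10 classes -> every per-example loss equals log(10)
```

Keeping the reduction explicit at the call site (rather than inside the loss) is the optax convention, and it also makes per-example weighting trivial later.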

copybara-service[bot] force-pushed the test_370893380 branch 2 times, most recently from 353fab2 to 51e59fc on May 18, 2021
copybara-service[bot] force-pushed the test_370893380 branch 2 times, most recently from 041147f to ba04ee7 on May 24, 2021
copybara-service[bot] force-pushed the test_370893380 branch 3 times, most recently from 5246e4e to b0a274b on June 10, 2021
@codecov-commenter

Codecov Report

Merging #1286 (b0a274b) into master (8846461) will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##           master    #1286   +/-   ##
=======================================
  Coverage   82.34%   82.34%           
=======================================
  Files          65       65           
  Lines        5318     5318           
=======================================
  Hits         4379     4379           
  Misses        939      939           


@copybara-service copybara-service bot force-pushed the test_370893380 branch 2 times, most recently from e4d0282 to 0fc721a Compare June 10, 2021 13:35
copybara-service[bot] merged commit 44ee6f2 into master on Jun 10, 2021
copybara-service[bot] deleted the test_370893380 branch on June 10, 2021

Labels

cla: yes · Priority: P2 - no schedule (best effort response and resolution; no plan to work on this at the moment) · pull ready


7 participants