
ENH: make flax_resnet.py faster and more accurate across the board #119

Merged
merged 1 commit on Dec 9, 2021

Conversation

@fabianp fabianp commented Dec 7, 2021

Most changes are geared towards ensuring we give strong and
efficient baselines by default, without parameter tuning. The most
dramatic gains are in the CIFAR* datasets.

Currently, the train/test accuracies for the different datasets with the
default flags are:

  • MNIST: 0.99/0.99 (this one didn't change much).
  • FASHION_MNIST: 0.94/0.88 with an elapsed time of 0:03:09
    (vs 0.71/0.69 for the previous defaults).
  • E_MNIST: 0.85/0.85 with an elapsed time of 0:25:56
    (vs 0.31/0.37 for the previous defaults).
  • CIFAR10: 0.87/0.75 with an elapsed time of 0:02:36
    (vs 0.23/0.22 for the previous default architecture and
    0.29/0.32 for resnet18).
  • CIFAR100: 0.72/0.39 with an elapsed time of 0:02:31
    (vs 0.07/0.06 for the previous default architecture and
    0.11/0.09 for resnet18).

Here, the elapsed time is measured on a workstation with a GeForce GTX 1080
GPU. It will certainly be slower on a CPU, although I expect most people
aiming to train large ResNets to have access to a GPU.

More precisely, the main changes are:

  • Refactored the printing statements so that the update rule can be
    jit-compiled (see the sketch after this list). It now runs very
    efficiently on GPU.

  • Replaced maxiter by epochs. A maxiter of 100 is nowhere near enough
    to reach reasonable accuracy on CIFAR10, and I believe that setting
    the limit in terms of epochs is both more common and easier to choose
    when considering several datasets in the same example.

  • Use resnet18 as the default architecture. resnet1 gives 0.4 test
    accuracy on CIFAR10 after 30 epochs, while resnet18 goes up to 0.7.

  • Print the running time. It has been useful for me to compare the
    efficiency of different approaches.
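
For illustration, a minimal sketch of that first point (not the actual code in
flax_resnet.py; `solver`, `train_iterator` and `eval_every` are placeholder
names, and a jaxopt-style solver with a purely functional `update` is assumed):

```python
import jax

@jax.jit
def train_step(params, state, batch):
    # `solver` is assumed to be a jaxopt-style stochastic solver whose
    # `update` has no Python side effects, so the whole step can be traced.
    # The loss is assumed to take the mini-batch through a `data` argument.
    return solver.update(params, state, data=batch)

# `params` and `state` are assumed to have been initialized by the solver.
for it, batch in enumerate(train_iterator):
    params, state = train_step(params, state, batch)
    if it % eval_every == 0:
        # Printing stays on the host, outside the compiled function.
        # `state.value` is assumed to carry the last objective value.
        print(f"[iter {it}] objective: {float(state.value):.4f}")
```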

@fabianp fabianp requested a review from mblondel December 7, 2021 17:53

@mblondel mblondel left a comment

LGTM, thanks a lot Fabian!

@fabianp fabianp commented Dec 7, 2021

@mblondel: do you think I should switch maxiter -> epochs in the related examples for consistency? I'm thinking of flax_image_classif.py and perhaps the haiku* ones.

@mblondel mblondel commented Dec 7, 2021

@fabianp That could be nice. That said, I would like to keep an example that uses run_iterator. Maybe we should try id_print within pre_update in these examples.

@fabianp fabianp commented Dec 7, 2021

> @fabianp That could be nice. That said, I would like to keep an example that uses run_iterator. Maybe we should try id_print within pre_update in these examples.

Should be doable since I'm just computing the max_iter necessary to reach X epochs, not changing the structure of the for loop.
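
Roughly like this (hypothetical variable names, just to show the arithmetic):

```python
# Hypothetical names (train_ds_size, batch_size, FLAGS.epochs); only the
# arithmetic matters: one epoch is one full pass over the training set.
steps_per_epoch = train_ds_size // batch_size
maxiter = FLAGS.epochs * steps_per_epoch  # handed to the solver as before
```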

Regarding id_print, I agree we should use it in the other examples.
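
For reference, a rough sketch of what that could look like (the exact
`pre_update` signature and the `id_print` usage are assumptions, not checked
against the current jaxopt API; `loss_fun`, `init_params`, `train_iterator`
and `maxiter` are placeholders):

```python
import optax
from jax.experimental import host_callback
from jaxopt import OptaxSolver

def pre_update(params, state, *args, **kwargs):
    # host_callback.id_print can print from inside traced/jitted code, so
    # run_iterator can stay fully compiled while still reporting progress.
    host_callback.id_print(state.iter_num)
    host_callback.id_print(state.value)
    return params, state

solver = OptaxSolver(opt=optax.adam(1e-3), fun=loss_fun, maxiter=maxiter,
                     pre_update=pre_update)
params, state = solver.run_iterator(init_params, iterator=train_iterator)
```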

@fabianp fabianp commented Dec 9, 2021

I added the jitted_update to avoid recompilation. I'll open an issue for renaming maxiter -> epochs in the other examples.
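
Sketch of the idea (placeholder names, same assumptions as the sketch above):

```python
import jax

# The jitted callable is created once, outside the loop; wrapping
# solver.update with jax.jit inside the loop would produce a fresh callable
# (and a fresh trace/compilation) on every iteration.
jitted_update = jax.jit(solver.update)

for batch in train_iterator:
    params, state = jitted_update(params, state, data=batch)
```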
