Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jax.default_device context manager #9118

Merged
merged 1 commit into from Jun 2, 2022

Conversation

skye
Copy link
Collaborator

@skye skye commented Jan 7, 2022

This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
tensorflow/tensorflow#53656

@skye skye requested review from hawkinsp and mattjj January 7, 2022 01:12
@skye
Copy link
Collaborator Author

skye commented Jan 7, 2022

Sorry this is so huge; I can break up into smaller changes if that'd be helpful. I'll also create a CL if you prefer to review that. EDIT: once tensorflow/tensorflow#53656 goes in, which is a prerequisite. I also need to build a new jaxlib with that before this can go in.

Copy link
Member

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This wasn't too long at all.

Left a couple optional comments.

jax/_src/config.py Outdated Show resolved Hide resolved
jax/_src/config.py Outdated Show resolved Hide resolved
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 7, 2022
@skye skye force-pushed the device_context_manager branch 2 times, most recently from 35c1952 to 1c890d9 Compare January 11, 2022 23:09
@skye skye force-pushed the device_context_manager branch 2 times, most recently from d9a322e to 819824c Compare February 23, 2022 00:39
@cgarciae
Copy link
Collaborator

👀

This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
tensorflow/tensorflow#53656
@copybara-service copybara-service bot merged commit ea54754 into google:main Jun 2, 2022
ayaka14732 added a commit to ayaka14732/bart-base-jax that referenced this pull request Oct 10, 2022
ayaka14732 added a commit to ayaka14732/bart-base-jax that referenced this pull request Oct 13, 2022
* Update names and README

* Add masking algorithm

Finally solved after 6 months

* Update mask algorithm

* Update mask algorithm

* Update mask.py

* Update mask.py

* Fix masking

* Update random wrapper

Fixed in google/jax#9118

* Add tests for random.wrapper

* Update dependences

* Masking using the wakong algorithm

* Adapt to TPU Pods (#5)

* Adapt to multihost environment

* Fix multihost

* Fix bug

* Fix TPU initialisation

Fix bug introduced by google/jax#12642, because
the script is supposed to automatically fallback to CPU in subprocesses

* Fix multihost environment

Avoid all global variables

* Fix dataloader

* Fix bug

* Fix bug

* Use jax-smi

* Scheduler

* Use batch_size_per_device

* Scale loss

* Fix bug: no device_split

* Chore

* Undo scheduler and loss scaling

* Tune lr

* Re-apply loss scaling

* Re-apply scheduler, use AdamW

* Chore

* Increase batch_size_per_device to 64

* Try out lamb optimiser

* Increase batch size and learning rate

* Decrease batch size

* Formatting

* Format code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants