Add jax.default_device context manager #9118
Conversation
Sorry this is so huge; I can break it up into smaller changes if that'd be helpful. I'll also create a CL if you'd prefer to review it that way. EDIT: once tensorflow/tensorflow#53656 goes in, which is a prerequisite. I'll also need to build a new jaxlib with that change before this can go in.
Nice! This wasn't too long at all.
Left a couple optional comments.
👀
This currently only supports setting a specific Device object, not a platform like "cpu". That should be added in the future. Bumps the minimum jaxlib version in order to include tensorflow/tensorflow#53656
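A minimal usage sketch of the context manager described above. It assumes a CPU device is available (the CPU backend always is) and reflects the behavior noted here: you pass a concrete Device object, not a platform string like "cpu".

```python
import jax
import jax.numpy as jnp

# Pick a concrete Device object; at this point jax.default_device
# does not accept platform strings like "cpu".
cpu0 = jax.devices("cpu")[0]

with jax.default_device(cpu0):
    # Arrays created inside the block land on cpu0 instead of the
    # usual default device (e.g. a GPU or TPU if one is present).
    x = jnp.arange(4)

# The array stays where it was placed after the block exits.
print(next(iter(x.devices())).platform)
```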
Fixed in google/jax#9118
* Update names and README
* Add masking algorithm (finally solved after 6 months)
* Update mask algorithm
* Update mask algorithm
* Update mask.py
* Update mask.py
* Fix masking
* Update random wrapper (fixed in google/jax#9118)
* Add tests for random.wrapper
* Update dependencies
* Masking using the wakong algorithm
* Adapt to TPU Pods (#5)
* Adapt to multihost environment
* Fix multihost
* Fix bug
* Fix TPU initialisation (fix bug introduced by google/jax#12642, because the script is supposed to automatically fall back to CPU in subprocesses)
* Fix multihost environment (avoid all global variables)
* Fix dataloader
* Fix bug
* Fix bug
* Use jax-smi
* Scheduler
* Use batch_size_per_device
* Scale loss
* Fix bug: no device_split
* Chore
* Undo scheduler and loss scaling
* Tune lr
* Re-apply loss scaling
* Re-apply scheduler, use AdamW
* Chore
* Increase batch_size_per_device to 64
* Try out lamb optimiser
* Increase batch size and learning rate
* Decrease batch size
* Formatting
* Format code