[dynamic-shapes] implement bint arrays (opaque dtypes), add padding rules #12707
Merged
Conversation
mattjj force-pushed the djax-slice-sick4 branch 3 times, most recently from bcfbca8 to 56bf431, on October 8, 2022 12:51
sharadmv approved these changes on Oct 8, 2022
google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Oct 8, 2022
mattjj force-pushed the djax-slice-sick4 branch 7 times, most recently from 11c0d77 to ccaf229, on October 9, 2022 05:34
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
mattjj force-pushed the djax-slice-sick4 branch from ccaf229 to 6d2aaac on October 9, 2022 05:57
copybara-service bot pushed a commit that referenced this pull request on Oct 10, 2022:

Rolling forward jax-ml/jax#12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests. PiperOrigin-RevId: 480184773
For reference, the "opaque dtypes" mention in the PR title indicates that this builds on #11768 and subsequent developments.
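To give a flavor of the idea, not the PR's actual implementation: a dynamic dimension of size n with a static bound d can be carried as a buffer padded out to d, together with the bounded-integer size n, and operations apply "padding rules" that mask out the padded tail. A minimal NumPy sketch, where `PaddedArray` is a hypothetical name invented for this illustration:

```python
import numpy as np

class PaddedArray:
    """Toy model of a bint-bounded array: data padded to a static bound,
    plus the dynamic size n. For illustration only; not JAX's implementation."""

    def __init__(self, data, n, bound):
        assert n <= bound, "dynamic size must respect the static bound"
        self.n = n                       # dynamic size (a 'bint'-like value <= bound)
        buf = np.zeros(bound, dtype=data.dtype)
        buf[:n] = data[:n]               # payload; trailing entries are padding
        self.buf = buf                   # static shape (bound,) regardless of n

    def sum(self):
        # One padding rule: reductions must ignore the padded tail.
        mask = np.arange(self.buf.shape[0]) < self.n
        return np.where(mask, self.buf, 0).sum()

a = PaddedArray(np.array([1.0, 2.0, 3.0]), n=3, bound=8)
print(a.buf.shape)   # static shape (8,), independent of the dynamic size
print(a.sum())       # 6.0: the five padded zeros are excluded
```

The point of the static bound is that every logically-dynamic array of size n ≤ 8 shares one compiled shape, so a compiler that only understands static shapes can still execute it.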
copybara-service bot pushed a commit that referenced this pull request on Oct 11, 2022:

Rolling forward jax-ml/jax#12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests. PiperOrigin-RevId: 480229237
clrpackages pushed a commit to clearlinux-pkgs/pypi-jax that referenced this pull request on Oct 13, 2022:

… 0.3.22

- Adam Paszke (1): Support MANUAL collectives in top-level xmaps
- Artem Belevich (1): Handle FP8 types.
- Chao Chen (1): testSphHarmOrderZeroDegreeOne and test_custom_linear_solve_cholesky have been fixed in ROCm, no need to skip
- George Necula (3): Expand support for __jax_array__ in jnp.array. [jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations. [jax2tf] Implement jax2tf(pjit) for experimental_native_lowering.
- Jake VanderPlas (22): pure_callback: fix batching rule for multiple arguments. Split parts of lax_numpy_test.py into separate test files. Re-land jax-ml/jax#12588 with minor fixes. [typing] add full annotations for lax_numpy setops. Add types to jax/_src/numpy/util.py. jnp.prod & jnp.sum: promote to default integer type rather than int64/uint64. jax.jacobian: propagate function signature to transformed function. [typing] add annotations to jax.numpy.linalg. [typing] add annotations to numpy.fft. [typing] add types for jax.numpy.polynomial. [typing] overloads for jnp.linalg.svd & jnp.linalg.qr. test: fix LaxNumpyTest:testConcatenate. Move promote_like_jnp to jax.test_util. jnp.average: support tuple axis. [typing] add type annotations to lax.linalg functions. Update scipy version in jax.scipy.fft. [typing] annotate lax.slicing. [typing] add annotations to jax.scipy.fft. [typing] add annotations to jax.scipy.ndimage. [typing] add type annotations to jax.scipy.linalg. changelog: add missing github commit links. Remove deprecated functionality from jax.test_util.
- Jason Furmanek (2): [ROCM] Add TENSORFLOW_ROCM_COMMIT parameter to ROCM ci build. Add default setting for TENSORFLOW_ROCM_COMMIT.
- Krishna Haridasan (1): Add unsafe_buffer_pointer to _DeviceArray
- Kuangyuan Chen (1): Turn on cpp pjit by default
- Matthew Johnson (9): improve custom_jvp/vjp error messages. [dynamic-shapes] small fix to einsum (and indexing). improve jit(f).lower(duck_args) and pjit(f).lower(duck_args). make device_put work with Sharding 2nd arg. add test, small fixes. fix -O / PYTHONOPTIMIZE bug. make device_put(prngkeyarray, sharding) for Array. implement bint arrays (opaque dtypes), add padding rules. Rolling forward jax-ml/jax#12707 after rollback, due to changes in relatively trivial jax.numpy shape validation code failed in some downstream user tests.
- Parker Schuh (1): Rename Executable to LoadedExecutable within jax.
- Peter Hawkins (10): Add an internal jtu.sample_product test decorator. Migrate api_test, lax_numpy_test, and lax_vmap_test to Switch lax_numpy_indexing_test to use jtu.sample_product. Use input-output aliasing for jaxlib GPU custom calls. Revert: Use input-output aliasing for jaxlib GPU custom calls. Reapply: Use input-output aliasing for jaxlib GPU custom calls. Copybara import of the project: Add changes accidentally omitted from Add input-output aliasing annotations for LAPACK calls on CPU. Fix compilation failure in lapack kernel under msan.
- Ran Ran (5): Set JAX_PLATFORMS=tpu,cpu on TPUs. Add set up message for JAX_PLATFORMS. Update set up message. Update message and change log. Address comments for change log.
- Rishabh Kabra (1): Clarify docs for `fori_loop`, noting that negative or custom increments are not supported.
- Rohit Santhanam (1): [ROCm] Upgrade to ROCm 5.3 and associated enhancements
- Roy Frostig (1): add input/output sharding to executable protocol
- Serge Durand (1): Fix book links
- Sharad Vikram (2): Test that array layout is preserved in Python callbacks. Fix collect_profile _src import.
- Skye Wanderman-Milne (2): Update version and CHANGELOG for jax 0.3.21 release. Update WORKSPACE and setup.py for jaxlib 0.3.22 release.
- Sudhakar (1): Add multihost GPU CI run with last public jaxlib release
- Tianjian Lu (4): [sparse] Move broadcasting_vmap to sparse util. [sparse] Add conversions between BCSR and BCOO. [sparse] Bug fix in _validate_bcsr. [sparse] BCSR fromdense and todense.
- Yash Katariya (11): `get_device_buffers()` on ShardedBuffer if config.jax_array is enabled, because jax.Array does not work with ShardedBuffer since jax.Array is like a ShardedBuffer. Create `Array`s from `__getitem__` and `__iter__`; this is done by `device_put`ting from the host to the default device, which is suboptimal, but there is a TODO to fix this. Add `host_local_array_to_global_array` and `global_array_to_host_local_array` for enabling transition to jax.Array. Add sharding to `DeviceArray` and `ShardedDeviceArray` as a compatibility change to roll out `jax.Array`. Add `addressable_shards` to SDA and DA as a compatibility API to match `jax.Array`; this will aid the transition to `jax.Array`. Take `shardings` as a parameter to `deserialize` and `run_deserialization` instead of `mesh` and `pspecs`. Make `is_fully_replicated` and `is_fully_addressble` a property rather than a method. Lift `lambda x: x` to the top level so that we don't recompile on every invocation of `process_allgather`. Fix the type annotation of the return type of `device_buffer` and `device_buffers`, which return `ArrayImpl` instead of DeviceArray. Add a weak_type attribute to `Array` since it exists on DA (but not on SDA). Add support for calculating the device_assignment when there are no inputs to `jit` and `pjit`.
- dependabot[bot] (1): Bump styfle/cancel-workflow-action from 0.10.0 to 0.10.1
- jax authors (1): Copybara import of the project:
This was the last bit we needed to run a batch-shape-polymorphic transformer on XLA.
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
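To illustrate what "batch-shape polymorphism via padding" buys, here is a hedged NumPy sketch (the names `BOUND`, `mean_over_batch`, and `pad_to_bound` are invented for this example, and real JAX compiles with `jit` rather than this toy scheme): one function of a single static shape serves every logical batch size up to the bound, so changing the batch size never triggers a retrace or recompile.

```python
import numpy as np

BOUND = 16  # assumed static bound for the polymorphic batch dimension

def mean_over_batch(padded, n):
    """One fixed-shape (BOUND,) computation that is logically polymorphic
    in the batch size n: the padded tail is masked out of the reduction."""
    mask = np.arange(BOUND) < n
    return np.where(mask, padded, 0.0).sum() / n

def pad_to_bound(x):
    # Embed a length-n batch into the static (BOUND,) buffer.
    out = np.zeros(BOUND, dtype=x.dtype)
    out[: x.shape[0]] = x
    return out

# The same fixed-shape function handles batches of size 3 and 5.
print(mean_over_batch(pad_to_bound(np.array([1.0, 2.0, 3.0])), 3))             # 2.0
print(mean_over_batch(pad_to_bound(np.array([2.0, 4.0, 6.0, 8.0, 10.0])), 5))  # 6.0
```

This is the sense in which a static-shape compiler like XLA can run a batch-shape-polymorphic model: the varying batch dimension is hidden behind one bounded, padded static shape.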