
[dynamic-shapes] implement bint arrays (opaque dtypes), add padding rules #12707

Merged: 1 commit, merged Oct 9, 2022

@mattjj (Member) commented Oct 8, 2022

This was the last bit we needed to run a batch-shape-polymorphic transformer on XLA.

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
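The PR page does not show how bint arrays or the new padding rules are used, and their interface is not described here. As a rough, hypothetical illustration of the underlying idea only, a ragged batch can be padded up to a static bound and paired with per-example lengths, so a jitted function sees one static shape and masks out the padding. `pad_to_bound` and `masked_mean` are illustrative helpers built on public `jax.numpy`; they are not the bint API added by this PR.

```python
import jax
import jax.numpy as jnp


def pad_to_bound(seqs, bound):
    """Hypothetical helper: pad ragged 1-D sequences to a static length `bound`.

    Returns the padded (batch, bound) array plus the true length of each row.
    """
    lengths = jnp.array([len(s) for s in seqs], dtype=jnp.int32)
    padded = jnp.stack([
        # Append zeros on the right so every row has exactly `bound` elements.
        jnp.pad(jnp.asarray(s, dtype=jnp.float32), (0, bound - len(s)))
        for s in seqs
    ])
    return padded, lengths


@jax.jit
def masked_mean(padded, lengths):
    """Hypothetical helper: per-row mean that ignores the padded positions."""
    idx = jnp.arange(padded.shape[1])
    # mask[i, j] is True only for real elements, i.e. j < lengths[i].
    mask = idx[None, :] < lengths[:, None]
    total = jnp.sum(jnp.where(mask, padded, 0.0), axis=1)
    return total / lengths


padded, lengths = pad_to_bound([[1.0, 2.0, 3.0], [4.0, 5.0]], bound=4)
print(masked_mean(padded, lengths))
```

Because the padded array has a fixed shape, `masked_mean` compiles once for any batch whose sequences fit within the bound; only the `lengths` values change between calls.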

@mattjj requested a review from @sharadmv on October 8, 2022 12:23
@mattjj force-pushed the djax-slice-sick4 branch 3 times, most recently from bcfbca8 to 56bf431 on October 8, 2022 12:51
@google-ml-butler (bot) added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Oct 8, 2022
@mattjj force-pushed the djax-slice-sick4 branch 7 times, most recently from 11c0d77 to ccaf229 on October 9, 2022 05:34
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
@copybara-service (bot) merged commit 25c6ef7 into google:main on Oct 9, 2022
@mattjj deleted the djax-slice-sick4 branch on October 10, 2022 21:57
copybara-service bot pushed a commit that referenced this pull request Oct 10, 2022
…rivial jax.numpy shape validation code failed in some downstream user tests.

PiperOrigin-RevId: 480184773
@froystig (Member) commented:

For reference, the "opaque dtypes" mention in the PR title indicates that this work builds on #11768 and the developments that followed it.

copybara-service bot pushed a commit that referenced this pull request Oct 11, 2022
…rivial jax.numpy shape validation code failed in some downstream user tests.

PiperOrigin-RevId: 480184773
copybara-service bot pushed a commit that referenced this pull request Oct 11, 2022
…rivial jax.numpy shape validation code failed in some downstream user tests.

PiperOrigin-RevId: 480229237
clrpackages pushed a commit to clearlinux-pkgs/pypi-jax that referenced this pull request Oct 13, 2022
… 0.3.22

Adam Paszke (1):
      Support MANUAL collectives in top-level xmaps

Artem Belevich (1):
      Handle FP8 types.

Chao Chen (1):
      testSphHarmOrderZeroDegreeOne and test_custom_linear_solve_cholesky have been fixed in ROCm, no need to skip

George Necula (3):
      Expand support for __jax_array__ in jnp.array.
      [jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations
      [jax2tf] Implement jax2tf(pjit) for experimental_native_lowering

Jake VanderPlas (22):
      pure_callback: fix batching rule for multiple arguments
      Split parts of lax_numpy_test.py into separate test files.
      Re-land google/jax#12588 with minor fixes
      [typing] add full annotations for lax_numpy setops
      Add types to jax/_src/numpy/util.py
      jnp.prod & jnp.sum: promote to default integer type rather than int64/uint64
      jax.jacobian: propagate function signature to transformed function
      [typing] add annotations to jax.numpy.linalg
      [typing] add annotations to numpy.fft
      [typing] add types for jax.numpy.polynomial
      [typing] overloads for jnp.linalg.svd & jnp.linalg.qr
      test: fix LaxNumpyTest:testConcatenate
      Move promote_like_jnp to jax.test_util
      jnp.average: support tuple axis
      [typing] add type annotations to lax.linalg functions
      Update scipy version in jax.scipy.fft
      [typing] annotate lax.slicing
      [typing] add annotations to jax.scipy.fft
      [typing] add annotations to jax.scipy.ndimage
      [typing] add type annotations to jax.scipy.linalg
      changelog: add missing github commit links
      Remove deprecated functionality from jax.test_util

Jason Furmanek (2):
      [ROCM] Add TENSORFLOW_ROCM_COMMIT parameter to ROCM ci build
      Add default setting for TENSORFLOW_ROCM_COMMIT

Krishna Haridasan (1):
      Add unsafe_buffer_pointer to _DeviceArray

Kuangyuan Chen (1):
      Turn on cpp pjit py default

Matthew Johnson (9):
      improve custom_jvp/vjp error messages
      [dynamic-shapes] small fix to einsum (and indexing)
      improve jit(f).lower(duck_args) and pjit(f).lower(duck_args)
      make device_put work with Sharding 2nd arg
      add test, small fixes
      fix -O / PYTHONOPTIMIZE bug
      make device_put(prngkeyarray, sharding) for Array
      implement bint arrays (opaque dtypes), add padding rules
      Rolling forward google/jax#12707 after a rollback, which happened because changes in relatively trivial jax.numpy shape validation code failed some downstream user tests.

Parker Schuh (1):
      Rename Executable to LoadedExecutable within jax.

Peter Hawkins (10):
      Add an internal jtu.sample_product test decorator.
      Migrate api_test, lax_numpy_test, and lax_vmap_test to
      Switch lax_numpy_indexing_test to use jtu.sample_product.
      Use input-output aliasing for jaxlib GPU custom calls.
      Revert: Use input-output aliasing for jaxlib GPU custom calls.
      Reapply: Use input-output aliasing for jaxlib GPU custom calls.
      Copybara import of the project:
      Add changes accidentally omitted from
      Add input-output aliasing annotations for LAPACK calls on CPU.
      Fix compilation failure in lapack kernel under msan.

Ran Ran (5):
      Set JAX_PLATFORMS=tpu,cpu on TPUs
      Add set up message for JAX_PLATFORMS
      Update set up message
      Update message and change log
      Address comments for change log

Rishabh Kabra (1):
      Clarify docs for `fori_loop`, noting that negative or custom increments are not supported.

Rohit Santhanam (1):
      [ROCm] Upgrade to ROCm 5.3 and associated enhancements

Roy Frostig (1):
      add input/output sharding to executable protocol

Serge Durand (1):
      Fix book links

Sharad Vikram (2):
      Test that array layout is preserved in Python callbacks
      Fix collect_profile _src import

Skye Wanderman-Milne (2):
      Update version and CHANGELOG for jax 0.3.21 release
      Update WORKSPACE and setup.py for jaxlib 0.3.22 release

Sudhakar (1):
      Add multihost GPU CI run with last public jaxlib release

Tianjian Lu (4):
      [sparse] Move broadcasting_vmap to sparse util.
      [sparse] Add conversions between BCSR and BCOO.
      [sparse] Bug fix in _validate_bcsr.
      [sparse] BCSR fromdense and todense.

Yash Katariya (11):
      `get_device_buffers()` on ShardedBuffer if config.jax_array is enabled because jax.Array does not work with ShardedBuffer since jax.Array is like a ShardedBuffer.
      Create `Array`s from `__getitem__` and `__iter__`. This is done by `device_put`ting from the host to default device which is suboptimal. But there is a TODO to fix this!
      Add `host_local_array_to_global_array` and `global_array_to_host_local_array` for enabling transition to jax.Array.
      Add sharding to `DeviceArray` and `ShardedDeviceArray` as a compatibility change to rollout `jax.Array`.
      Add `addressable_shards` to SDA and DA as a compatibility API to match with `jax.Array`. This will aid in transition to `jax.Array`.
      Take `shardings` as a parameter to `deserialize` and `run_deserialization` instead of `mesh` and `pspecs`.
      Make `is_fully_replicated` and `is_fully_addressable` a property rather than a method.
      Lift `lambda x: x` to the top level so that we don't recompile on every invocation of `process_allgather`.
      Fix the type annotation of return type of `device_buffer` and `device_buffers` which return `ArrayImpl` instead of DeviceArray.
      Add weak_type attribute to `Array` since it exists on DA (but doesn't exist on SDA).
      Add support for calculating the device_assignment when there are no inputs to `jit` and `pjit`.

dependabot[bot] (1):
      Bump styfle/cancel-workflow-action from 0.10.0 to 0.10.1

jax authors (1):
      Copybara import of the project:
Labels: pull ready (Ready for copybara import and testing)

4 participants