Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.ndarray.view: implement all dtypes #14526

Merged
merged 1 commit into from
Feb 16, 2023

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Feb 16, 2023

This simpler and more complete implementation is enabled by the enhancements to lax.bitcast_convert_type in #14501

@jakevdp jakevdp self-assigned this Feb 16, 2023
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Feb 16, 2023
@copybara-service copybara-service bot merged commit 66e7c0c into google:main Feb 16, 2023
@jakevdp jakevdp deleted the ndarray-view branch February 16, 2023 19:15
copybara-service bot pushed a commit that referenced this pull request Feb 17, 2023
copybara-service bot pushed a commit that referenced this pull request Feb 17, 2023
copybara-service bot pushed a commit that referenced this pull request Feb 17, 2023
jakevdp added a commit to jakevdp/jax that referenced this pull request Feb 17, 2023
Re-land google#14526 with fixes to scalar views
clrpackages pushed a commit to clearlinux-pkgs/pypi-jax that referenced this pull request Mar 9, 2023
…0.4.5

Adam Paszke (1):
      Slightly increase the tolerance in sparse tests to avoid flakiness

Anish Tondwalkar (2):
      [mhlo] Use XLA pretty-printed format for shardingattr
      [CHLO] Add erf_inv and lowering to mhlo

Brennan Saeta (1):
      Export the `Shard` type.

Chao (1):
      Update Dockerfile.ms

Chao Chen (1):
      jax-rocm runtime/ci dockerfile multistages

Eugene Burmako (1):
      Remove *_mhlo compatibility shims from jaxlib

Frederic Bastien (2):
      Small fix as the module name changed.
      Add a new link instead of a TODO.

Geoffrey Martin-Noble (1):
      Update links from iree-org/iree to openxla/iree

George Necula (8):
      [jax2tf] Customize limitations for native lowering mode
      [jax2tf] Enable strict platform checking for native serialized modules.
      Copybara import of the project:
      [jax2tf] Use CUDA and ROCM instead of GPU for XlaCallModuleOp platforms
      [jax2tf] Enable the strict platform checks for native serialization
      [jax2tf] Include more sharding annotations in the TF graph
      [jax2tf] Add support for cross-platform lowering in native serialization
      A different way to achieve cross-platform lowering, withouth any

Gijs Koning (1):
      Small update to jax profiling docs

Ikko Eltociear Ashimine (1):
      Fix typo in maps.py

Jake VanderPlas (30):
      lax.bitcast_convert_type: support casting between types of different width
      jnp.ndarray.view: implement all dtypes
      lax.bitcast_convert_type: better input validation
      [sparse] bring sparse.csr API in line with sparse.coo
      [sparse] test coo/csr extra nse
      Roll-back google/jax#14526 because it breaks `view()` on scalar inputs
      DOC: add alternative for pytree initialization
      BUG: avoid passing functions directly to abstractmethod
      TMP
      jnp.ndarray.view: implement all dtypes
      DOC: improve usage recommendation in jax.typing
      [sparse] bcoo_dot_general_sampled: faster special case
      DOC: add is_ready() to CHANGELOG
      [sparse] remove handling of padded indices from COO/CSR
      [sparse] fix bug in oob index correction
      Deprecate three jax.Array methods:
      [sparse] bcoo_dot_general_sampled: another special case
      BUG: raise error when shaped_abstractify is called on JAX scalar types
      DOC: improve lax.dot_general documentation
      DOC: fix jax.numpy.Array discussion
      [sparse] use precision=HIGHEST in bcoo_dot_general_sampled
      [sparse] temporarily disable bcoo_dot_general_sampled fast cases test on GPU
      DOC: mention scale/rate parameter in random.gamma
      Refactor bcoo_dot_general GPU lowering
      [sparse] fix dot_general precision in test
      [sparse] add low-level primitives wrapping cuda SpMV & SpMM
      DOC: fix readthedocs for sphinx-book-theme=1.0
      [sparse] fix expected warning in batched_matmat test case
      DOC: update sphinx & sphinx-autodoc-typehints
      [sparse] adjust tolerance on bcoo_dot_general_sampled

Johannes Reifferscheid (1):
      Don't create invalid bools in lax_numpy_test/testView.

John QiangZhang (1):
      Add tf.convert_to_tensor for call_tf gradient outputs.

Lena Martens (2):
      Checkify: Remove stray `raise_as_much_as_possible`.
      Update api_benchmark to not use any deprecated APIs.

Matthew Johnson (7):
      remove accidental re-export of __future__.annotations from jax/core.py
      [shard-map] add systematic tests for eager, jit, autodiff
      [shard-map] add annotations and notes to shard_map_test.py
      add remat tutorial docs
      fixes from reviewers
      custom_jvp symbolic zeros support
      [custom_vjp] bwd function should not be WrappedFun, may run multiple times

Parker Schuh (4):
      [Rollforward] Convert _arrays to return PyArray instead of PyBuffer.
      Rollback of array fix again for perf regression.
      Add PyArrayResultHandler which behaves like
      Rollforward of Add a fastpath to pmap_lib for sharding np.ndarray directly in c++.

Peter Hawkins (21):
      Remove more exported names from jax.interpreters.xla.
      Move global_device_array into its own BUILD target.
      Add keep_dep tag to :global_device_array build target to hint that it should be kept.
      Move contents of jax.experimental.global_device_array to jax._src.global_device_array.
      Reexport jax.interpreters.mlir.token_type.
      Limit visibility of Bazel target jax:global_device_array.
      Remove global_device_array from shared jax bazel library.
      Hide accidental exports from jax.core.
      Fix __module__ of jax.nn.initializers.* to be jax.nn.initializers.
      Export device_put_p as jax.lax.device_put_p.
      Update conda GPU install instructions.
      Remove jax._src deletion.
      Remove pytype suppression for jax/_src/config.py
      Remove circular dependency between source_info_util and util.
      [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
      Replace jax._src.util.prod with math.prod.
      Remove internal ndarray type name. Use Array throughout.
      [XLA:Python] Add buffer protocol support to jax.Array.
      Shorten alias chains for names exported in jax. namespace.
      Add "Open in Kaggle" buttons to Jupyter notebooks.
      Fix stale reference to util.prod.

Roy Frostig (4):
      remove several exported symbols from `jax.core`
      remove some exports from `jax.core`
      remove several symbols from `jax.core`
      remove several more symbols from `jax.core`

Sharad Vikram (6):
      Add some info in the docs about using `jax.debug.print` with f-strings
      Refactor effects system to use effect types, not objects
      Add `JaxprInputEffect` and refactor `StateEffect`s to use it
      Fix nondeterminism issue with ordered effects
      Refactor `Ref` abstract type to contain other `AbstractValue`s
      Remove TokenSet needing to have effects in a certain order

Skye Wanderman-Milne (3):
      Add `--pre` to nightly libtpu pip install command.
      Add back opt-barrier fallback, since the fallback sometimes prevents OOMs.
      Update Cloud TPU install command to be simpler.

Tianjian Lu (1):
      [sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.

Yash Katariya (23):
      Catch ImportError when importing `tf` instead of a broad exception catch. If not, this leads to weird errors in the other tests down the line.
      Finish jax and jaxlib 0.4.4 release
      Create a jax.Array from make_sharded_device_array since SDA is deprecated.
      Return jax.Array from GDA's callback APIs if jax.Array is True.
      Remove gda_benchmark file as GDA is deprecated.
      Make the _pjit_jaxpr cache more by not depending on the out_shardings. So if out_shardings argument of pjit changes, it should affect the jaxpr created because jaxpr creation is not dependent on out_shardings.
      Remove _ListWithW since it is not needed anymore
      Rename `jax.sharding.OpShardingSharding` to `jax.sharding.GSPMDSharding`. `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
      Replace `op_sharding_sharding` with `gspmd_sharding`. This is purely an internal change.
      Rename `in_axis_resources` and `out_axis_resources` with `in_shardings` and `out_shardings`. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
      Return jax.Array with a single device sharding from make_device_array.
      Delete `PyShardedBuffer` since it was only used for GDAs and GDA is deprecated.
      Only check for `_device` if `device_buffer` is `xc.Buffer`.
      Use the standard jtu.create_global_mesh instead of creating a mesh from scratch.
      Add a helpful error message when device_putting with a Sharding that is incompatible with the shape of the input
      Create avals and pass them to _check_sharding rather than the actual value.
      Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
      Remove use_stablehlo as minimum mlir_api_version >= 43
      Mention that Pspecs are not allowed to be passed to jax.jit
      Replace usage of {in|out}_axis_resources with {in|out}_shardings
      Use math.prod instead of util.prod
      Pass the `jaxpr` from `pjit` since there is no need to trace it again in lower_sharding_computation. It also helps in preserving debug_info that already exists on the jaxpr to surface it in MHLO eventually.
      Use in_shardings and out_shardings since those are the new arguments that pjit has

jax authors (2):
      Add support for XLAGather with 2-D Batch Dimensions for `enable_xla=False`
      [XLA:Python] Add buffer protocol support to jax.Array.

jiayaobo (1):
      add random.chisquare and random.f

pizzud (3):
      lax_test: Create a separate module for lax-specific test utils in a new package.
      lax_scipy_test: Split into three so that each target is small enough to fit within a medium timeout.
      lax_scipy_test: Revert split into three targets.

tennessee_wallaceh (1):
      Update student-t sampling to use correct key for gamma

Øyvind Sigmundson Schøyen (1):
      DOC: fix typo in sph_harm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants