Skip to content

Commit

Permalink
Remove deprecated unsafe_raw_array method from PRNG keys
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595190146
  • Loading branch information
Jake VanderPlas authored and jax authors committed Jan 2, 2024
1 parent e6c8901 commit fff5ea5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Remember to align the itemized text with the first line of an item within a list
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.
* Deprecations
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
This includes:
Expand All @@ -34,6 +34,8 @@ Remember to align the itemized text with the first line of an item within a list
* from {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`,
`trapz`, and `in1d`.
* from {mod}`jax.scipy.linalg`: `tril` and `triu`.
* The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been
removed. Use {func}`jax.random.key_data` instead.

## jaxlib 0.4.24

Expand Down
8 changes: 0 additions & 8 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import math
import operator as op
from typing import Any, Callable, NamedTuple
import warnings

import numpy as np

Expand Down Expand Up @@ -308,13 +307,6 @@ def itemsize(self):
on_device_size_in_bytes = property(op.attrgetter('_base_array.on_device_size_in_bytes')) # type: ignore[assignment]
unsafe_buffer_pointer = property(op.attrgetter('_base_array.unsafe_buffer_pointer')) # type: ignore[assignment]

def unsafe_raw_array(self):
# deprecated on 13 Sept 2023
raise warnings.warn(
'The `unsafe_raw_array` method of PRNG key arrays is deprecated. '
'Use `jax.random.key_data` instead.', DeprecationWarning, stacklevel=2)
return self._base_array

def addressable_data(self, index: int) -> PRNGKeyArrayImpl:
return PRNGKeyArrayImpl(self._impl, self._base_array.addressable_data(index))

Expand Down

0 comments on commit fff5ea5

Please sign in to comment.