Skip to content

Add tests and fix for PRNG key handling in Orbax + ColocatedPython.#3179

Merged
copybara-service[bot] merged 1 commit into
mainfrom
test_911983646
May 11, 2026
Merged

Add tests and fix for PRNG key handling in Orbax + ColocatedPython.#3179
copybara-service[bot] merged 1 commit into
mainfrom
test_911983646

Conversation

@copybara-service
Copy link
Copy Markdown

Add tests and fix for PRNG key handling in Orbax + ColocatedPython.

Tests are added to orbax/checkpoint/_src/multihost/dispatchers_test.py
to demonstrate a type mismatch issue: Orbax's restore process generates
result_specs with a physical dtype (e.g., uint32) for PRNG keys, but
the deserialization returns PRNGKeyArray objects with a PRNG key
dtype. This mismatch can cause errors in Pathways' IFRT transport
layer. This is now fixed in jax_array_handlers.py by creating the desired result_specs in _get_abstract_arrays.

@copybara-service copybara-service Bot force-pushed the test_911983646 branch 3 times, most recently from 8b37034 to 23de1a0 Compare May 11, 2026 18:31
Tests are added to `orbax/checkpoint/_src/multihost/dispatchers_test.py`
to demonstrate a type mismatch issue: Orbax's restore process generates
`result_specs` with a physical dtype (e.g., uint32) for PRNG keys, but
the deserialization returns `PRNGKeyArray` objects with a PRNG key
dtype.  This mismatch can cause errors in Pathways' IFRT transport
layer. This is now fixed in `jax_array_handlers.py` by creating the desired `result_specs` in `_get_abstract_arrays`.

PiperOrigin-RevId: 913931497
@copybara-service copybara-service Bot merged commit 075b32e into main May 11, 2026
@copybara-service copybara-service Bot deleted the test_911983646 branch May 11, 2026 23:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant