Test for deserialization of GPU executables #12161
FYI, the misbehaving deserialization fallback in StreamExecutorGpuClient::DeserializeExecutable was introduced by 203ffbd.
69d3607 to 8bf3c76
Thanks for fixing it! To give more context, the main issue is that
xla/xla/pjrt/pjrt_stream_executor_client.cc Line 3241 in e25d99a
https://github.com/openxla/xla/blob/e25d99af0ca392afa9cb357eef73880de97edc35/xla/pjrt/stream_executor_executable.cc#L27 and https://github.com/openxla/xla/blob/e25d99af0ca392afa9cb357eef73880de97edc35/xla/pjrt/pjrt_stream_executor_client.cc#L3241 serialize the executables into different protos. Previous attempts to deserialize the executable as StreamExecutorExecutableProto first caused issues (openxla/xla#12161). Therefore we will disable the AOT GPU test until the way executables are serialized is unified. PiperOrigin-RevId: 631084054
Imported from GitHub PR openxla/xla#12161

This patch avoids deserialization of GPU executables (e.g., for JAX compilation cache) as StreamExecutorExecutableProto and always deserializes as ExecutableAndOptionsProto.

This fixes a problem where executable would be serialized as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize as StreamExecutorExecutableProto and sometimes succeed, reading mangled data. See the test for an example of an executable that would fail to deserialize without this patch.

In JAX, the following program fails to use the cache without this patch:

```
import jax
jax.experimental.compilation_cache.compilation_cache.set_cache_dir('/jaxcache')
jax.numpy.array([1234])
```

The output in the console (on the second run) is:

```
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/wire_format_lite.cc:618] String field 'xla.gpu.CompilationResultProto.asm_text' contains invalid UTF-8 data when parsing a protocol buffer. Use the 'bytes' type if you intend to send raw bytes.
/opt/jax/jax/_src/compiler.py:522: UserWarning: Error reading persistent compilation cache entry for 'jit_convert_element_type': XlaRuntimeError: INTERNAL: Failed to parse serialized GpuThunkAotCompilationResult.
```

With this patch, JAX successfully caches the function.

Copybara import of the project:

--
8bf3c7689b8099eecfd70e02de6fb3432669f7b8 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 8bf3c7689b8099eecfd70e02de6fb3432669f7b8
PiperOrigin-RevId: 631112289
This breaks a JAX test. I'm not an expert in PJRT.
Could it be that google/jax#21085 has not landed yet? (I think that JAX PR disabled the test.)
8bf3c76 to 818765f
That's right. This PR has been reverted. @jyingl3 I believe you authored the PR. Do you still plan to submit a similar PR?
Imported from GitHub PR openxla/xla#12161

This patch tests that serialization and deserialization of GPU executables (e.g., for JAX compilation cache) succeeds. The test covers the fix in [PR 12184](openxla/xla#12184).

For completeness, PR 12184 fixes a problem where executable would be serialized as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize as StreamExecutorExecutableProto and sometimes succeed, reading mangled data. The test in this patch exercises one such example.

Copybara import of the project:

--
818765f7b1f5f7282d3f2756b2a03ac556201a6c by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c
PiperOrigin-RevId: 631112289
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c PiperOrigin-RevId: 629936020
Dynamic slice requires specific values for the indices that it takes in as operands. However, under computation extraction in the autotuner, all operands are replaced with random parameter values (which then represent incorrect/invalid offsets.) In order to avoid this, we replace dynamic-slice instructions with static slice instructions indexed at 0 along all dimensions. This CL is needed for a later CL which will add support for fusing dynamic-slice into triton fusions. The original author of this CL is jvstokes and I rewrote it a bit. FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c PiperOrigin-RevId: 631438546
This is consistent with all the callers and simplifies the callers. FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c PiperOrigin-RevId: 630340797
Imported from GitHub PR openxla/xla#12161

This patch tests that serialization and deserialization of GPU executables (e.g., for JAX compilation cache) succeeds. The test covers the fix in [PR 12184](openxla/xla#12184).

For completeness, PR 12184 fixes a problem where executable would be serialized as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize as StreamExecutorExecutableProto and sometimes succeed, reading mangled data. The test in this patch exercises one such example.

Copybara import of the project:

--
818765f7b1f5f7282d3f2756b2a03ac556201a6c by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

PiperOrigin-RevId: 631740399
This patch tests that serialization and deserialization of GPU executables
(e.g., for JAX compilation cache) succeeds.
The test covers the fix in PR 12184.
For completeness, PR 12184 fixes a problem where an executable would be serialized
as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
The test in this patch exercises one such example.
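
To illustrate why a "try one proto, then fall back to the other" deserializer can sometimes succeed on the wrong format, here is a minimal sketch in plain Python. It is not the XLA code: it hand-rolls a tiny subset of the protobuf wire format (varint keys and length-delimited fields only), and the two "schemas", their field numbers, and all function names are hypothetical. The only assumption taken from the discussion above is that a string field rejects invalid UTF-8 while a bytes field accepts anything.

```python
from typing import Dict, Tuple


def encode_varint(value: int) -> bytes:
    """Encode a non-negative integer as a base-128 varint."""
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow
        else:
            out.append(byte)
            return bytes(out)


def decode_varint(buf: bytes, pos: int) -> Tuple[int, int]:
    """Decode a varint at buf[pos:], returning (value, next position)."""
    result, shift = 0, 0
    while True:
        if pos >= len(buf):
            raise ValueError("truncated varint")
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def encode_bytes_field(field_number: int, payload: bytes) -> bytes:
    """Encode one length-delimited field (wire type 2)."""
    key = (field_number << 3) | 2
    return encode_varint(key) + encode_varint(len(payload)) + payload


def parse_fields(buf: bytes) -> Dict[int, bytes]:
    """Parse a message made only of length-delimited fields."""
    fields: Dict[int, bytes] = {}
    pos = 0
    while pos < len(buf):
        key, pos = decode_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 7
        if wire_type != 2:
            raise ValueError("unsupported wire type in this sketch")
        length, pos = decode_varint(buf, pos)
        if pos + length > len(buf):
            raise ValueError("truncated field")
        fields[field_number] = buf[pos:pos + length]
        pos += length
    return fields


def parse_as_writer_format(buf: bytes) -> bytes:
    # Hypothetical schema used by the serializer: field 1 is opaque bytes.
    return parse_fields(buf)[1]


def parse_as_other_format(buf: bytes) -> str:
    # Hypothetical second schema: field 1 is a UTF-8 string.
    # Raises UnicodeDecodeError (a ValueError) on invalid UTF-8.
    return parse_fields(buf)[1].decode("utf-8")


def deserialize_with_fallback(buf: bytes):
    """Mimic the fallback pattern: try the other format first."""
    try:
        return ("other", parse_as_other_format(buf))
    except (ValueError, KeyError):
        return ("writer", parse_as_writer_format(buf))


# Both inputs were written in the "writer" format.
print(deserialize_with_fallback(encode_bytes_field(1, b"ptx code")))
print(deserialize_with_fallback(encode_bytes_field(1, b"\xff\xfe\x00")))
```

In this sketch, the first input is silently accepted as the wrong format because its payload happens to be valid UTF-8, while the second falls back to the format it was actually written in. The wire format carries field numbers and wire types but no message type, so a successful parse is not evidence that the right schema was chosen.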