Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test for deserialization of GPU executables #12161

Conversation

jaro-sevcik
Copy link
Contributor

@jaro-sevcik jaro-sevcik commented May 5, 2024

This patch tests that serialization and deserialization of GPU executables
(e.g., for JAX compilation cache) succeeds.

The test covers the fix in PR 12184.

For completeness, PR 12184 fixes a problem where executable would be serialized
as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
The test in this patch exercises one such example.

@jaro-sevcik
Copy link
Contributor Author

FYI, the misbehaving deserialization fallback in StreamExecutorGpuClient::DeserializeExecutable was introduced by 203ffbd.

@hawkinsp hawkinsp requested a review from jyingl3 May 6, 2024 14:14
@jaro-sevcik jaro-sevcik force-pushed the serialize-as-stream-executor-executable branch from 69d3607 to 8bf3c76 Compare May 6, 2024 17:29
@jaro-sevcik jaro-sevcik changed the title Serialize GPU executables as StreamExecutorExecutableProto Only deserialize GPU executables as ExecutableAndOptionsProto May 6, 2024
@jyingl3
Copy link
Member

jyingl3 commented May 6, 2024

Thanks for fixing it!

To give more context, the main issue is that

StatusOr<std::string> StreamExecutorExecutable::SerializeExecutable() const {
and
StatusOr<std::string> PjRtStreamExecutorClient::SerializeExecutable(
serialize the executables into different protos.

copybara-service bot pushed a commit to google/jax that referenced this pull request May 6, 2024
https://github.com/openxla/xla/blob/e25d99af0ca392afa9cb357eef73880de97edc35/xla/pjrt/stream_executor_executable.cc#L27 and https://github.com/openxla/xla/blob/e25d99af0ca392afa9cb357eef73880de97edc35/xla/pjrt/pjrt_stream_executor_client.cc#L3241 serialize the executables into different protos. Previous attempts to try to deserialize the executable as StreamExecutorExecutableProto first causes issues openxla/xla#12161.

Therefore we will disable AOT GPU test until how executables are serialized is unified.

PiperOrigin-RevId: 631084054
copybara-service bot pushed a commit to google/jax that referenced this pull request May 6, 2024
https://github.com/openxla/xla/blob/e25d99af0ca392afa9cb357eef73880de97edc35/xla/pjrt/stream_executor_executable.cc#L27 and https://github.com/openxla/xla/blob/e25d99af0ca392afa9cb357eef73880de97edc35/xla/pjrt/pjrt_stream_executor_client.cc#L3241 serializes the executables into different protos. Previous attempts to try to deserialize the executable as StreamExecutorExecutableProto first causes issues openxla/xla#12161.

Therefore we will disable AOT GPU test until how executable is serialized is unified.

PiperOrigin-RevId: 631084054
@jaro-sevcik jaro-sevcik marked this pull request as ready for review May 6, 2024 18:01
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 6, 2024
Imported from GitHub PR openxla/xla#12161

This patch avoids deserialization of GPU executables (e.g., for JAX compilation cache)
as StreamExecutorExecutableProto and always deserializes as ExecutableAndOptionsProto.

This fixes a problem where executable would be serialized as ExecutableAndOptionsProto
by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
See the test for an example of an executable that would fail to deserialize without
this patch.

In JAX, the following program fails to use the cache without this patch:

```
import jax
jax.experimental.compilation_cache.compilation_cache.set_cache_dir('/jaxcache')
jax.numpy.array([1234])
```

The output in the console (on the second run) is:

```
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/wire_format_lite.cc:618] String field 'xla.gpu.CompilationResultProto.asm_text' contains invalid UTF-8 data when parsing a protocol buffer. Use the 'bytes' type if you intend to send raw bytes.
/opt/jax/jax/_src/compiler.py:522: UserWarning: Error reading persistent compilation cache entry for 'jit_convert_element_type': XlaRuntimeError: INTERNAL: Failed to parse serialized GpuThunkAotCompilationResult.
```

With this patch, JAX successfully caches the function.

Copybara import of the project:

--
8bf3c7689b8099eecfd70e02de6fb3432669f7b8 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 8bf3c7689b8099eecfd70e02de6fb3432669f7b8
PiperOrigin-RevId: 631112289
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 6, 2024
Imported from GitHub PR openxla/xla#12161

This patch avoids deserialization of GPU executables (e.g., for JAX compilation cache)
as StreamExecutorExecutableProto and always deserializes as ExecutableAndOptionsProto.

This fixes a problem where executable would be serialized as ExecutableAndOptionsProto
by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
See the test for an example of an executable that would fail to deserialize without
this patch.

In JAX, the following program fails to use the cache without this patch:

```
import jax
jax.experimental.compilation_cache.compilation_cache.set_cache_dir('/jaxcache')
jax.numpy.array([1234])
```

The output in the console (on the second run) is:

```
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/wire_format_lite.cc:618] String field 'xla.gpu.CompilationResultProto.asm_text' contains invalid UTF-8 data when parsing a protocol buffer. Use the 'bytes' type if you intend to send raw bytes.
/opt/jax/jax/_src/compiler.py:522: UserWarning: Error reading persistent compilation cache entry for 'jit_convert_element_type': XlaRuntimeError: INTERNAL: Failed to parse serialized GpuThunkAotCompilationResult.
```

With this patch, JAX successfully caches the function.

Copybara import of the project:

--
8bf3c7689b8099eecfd70e02de6fb3432669f7b8 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 8bf3c7689b8099eecfd70e02de6fb3432669f7b8
PiperOrigin-RevId: 631112289
@thomasjoerg
Copy link
Member

This breaks a JAX test
https://github.com/google/jax/blob/main/tests/aot_test.py#L98
fails with
XlaRuntimeError: INTERNAL: PjRtStreamExecutorClient::DeserializeExecutable proto deserialization failed

I'm not an expert in PJRT.
@jaro-sevcik can you tell why the breakage happens?

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 6, 2024
Imported from GitHub PR openxla/xla#12161

This patch avoids deserialization of GPU executables (e.g., for JAX compilation cache)
as StreamExecutorExecutableProto and always deserializes as ExecutableAndOptionsProto.

This fixes a problem where executable would be serialized as ExecutableAndOptionsProto
by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
See the test for an example of an executable that would fail to deserialize without
this patch.

In JAX, the following program fails to use the cache without this patch:

```
import jax
jax.experimental.compilation_cache.compilation_cache.set_cache_dir('/jaxcache')
jax.numpy.array([1234])
```

The output in the console (on the second run) is:

```
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/wire_format_lite.cc:618] String field 'xla.gpu.CompilationResultProto.asm_text' contains invalid UTF-8 data when parsing a protocol buffer. Use the 'bytes' type if you intend to send raw bytes.
/opt/jax/jax/_src/compiler.py:522: UserWarning: Error reading persistent compilation cache entry for 'jit_convert_element_type': XlaRuntimeError: INTERNAL: Failed to parse serialized GpuThunkAotCompilationResult.
```

With this patch, JAX successfully caches the function.

Copybara import of the project:

--
8bf3c7689b8099eecfd70e02de6fb3432669f7b8 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 8bf3c7689b8099eecfd70e02de6fb3432669f7b8
PiperOrigin-RevId: 631112289
@jaro-sevcik
Copy link
Contributor Author

This breaks a JAX test https://github.com/google/jax/blob/main/tests/aot_test.py#L98 fails with XlaRuntimeError: INTERNAL: PjRtStreamExecutorClient::DeserializeExecutable proto deserialization failed

I'm not an expert in PJRT. @jaro-sevcik can you tell why the breakage happens?

Could it be that google/jax#21085 has not landed yet? (I think that JAX PR disabled the test.)

@jaro-sevcik jaro-sevcik force-pushed the serialize-as-stream-executor-executable branch from 8bf3c76 to 818765f Compare May 7, 2024 06:27
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label May 7, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 7, 2024
@jaro-sevcik jaro-sevcik changed the title Only deserialize GPU executables as ExecutableAndOptionsProto Test for deserialization of GPU executables May 7, 2024
@thomasjoerg
Copy link
Member

Could it be that google/jax#21085 has not landed yet? (I think that JAX PR disabled the test.)

That's right. This PR has been reverted.

@jyingl3 I believe you authored the PR. Do you still plan to submit a similar PR?

@jaro-sevcik
Copy link
Contributor Author

jaro-sevcik commented May 7, 2024

Instead of the JAX patch, @jyingl3 landed #12184 (that unifies LoadSerialize and DeserializeExecutable). I have updated this PR to be just a test for the bad case we saw (as suggested by @jyingl3).

@jyingl3
Copy link
Member

jyingl3 commented May 7, 2024

Sorry for the confusion! Comment

Instead of the JAX patch, @jyingl3 landed #12184 (that unifies LoadSerialize and DeserializeExecutable). I have updated this PR to be just a test for the bad case we saw (as suggested by @jyingl3).

Sorry for the confusion! Yes this comment is accurate.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 7, 2024
Imported from GitHub PR openxla/xla#12161

This patch tests that serialization and deserialization of GPU executables
(e.g., for JAX compilation cache) succeeds.

The test covers the fix in [PR 12184](openxla/xla#12184).

For completeness, PR 12184 fixes a problem where executable would be serialized
as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
The test in this patch exercises one such example.
Copybara import of the project:

--
818765f7b1f5f7282d3f2756b2a03ac556201a6c by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c
PiperOrigin-RevId: 631112289
@copybara-service copybara-service bot closed this in 691d487 May 8, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 8, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c
PiperOrigin-RevId: 629936020
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 8, 2024
Dynamic slice requires specific values for the indices that it takes in as operands. However, under computation extraction in the autotuner, all operands are replaced with random parameter values (which then represent incorrect/invalid offsets.) In order to avoid this, we replace dynamic-slice instructions with static slice instructions indexed at 0 along all dimensions.

This CL is needed for a later CL which will add support for fusing dynamic-slice into triton fusions.

The original author of this CL is jvstokes and I rewrote it a bit.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c
PiperOrigin-RevId: 631438546
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 8, 2024
This is consistent with all the callers and simplifies the callers.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#12161 from jaro-sevcik:serialize-as-stream-executor-executable 818765f7b1f5f7282d3f2756b2a03ac556201a6c
PiperOrigin-RevId: 630340797
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request May 8, 2024
Imported from GitHub PR openxla/xla#12161

This patch tests that serialization and deserialization of GPU executables
(e.g., for JAX compilation cache) succeeds.

The test covers the fix in [PR 12184](openxla/xla#12184).

For completeness, PR 12184 fixes a problem where executable would be serialized
as ExecutableAndOptionsProto by PjRtStreamExecutorClient::SerializeExecutable, but
StreamExecutorGpuClient::DeserializeExecutable would first try to deserialize
as StreamExecutorExecutableProto and sometimes succeed, reading mangled data.
The test in this patch exercises one such example.
Copybara import of the project:

--
818765f7b1f5f7282d3f2756b2a03ac556201a6c by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid deserializing GPU executables as StreamExecutorExecutableProto

Merging this change closes #12161

PiperOrigin-RevId: 631740399
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants