Skip to content

Commit

Permalink
Add complex number DLPack support to JAX and TensorFlow.
Browse files Browse the repository at this point in the history
Fixes #9497

PiperOrigin-RevId: 427579098
  • Loading branch information
hawkinsp authored and jax authors committed Feb 9, 2022
1 parent 74506c7 commit 8af0d8d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions jax/_src/dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
from jax import numpy as jnp
from jax._src import device_array
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib import xla_bridge

SUPPORTED_DTYPES = set([jnp.int8, jnp.int16, jnp.int32, jnp.int64,
jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64,
jnp.float16, jnp.bfloat16, jnp.float32, jnp.float64])
SUPPORTED_DTYPES = frozenset({
jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
jnp.float64})

if xla_extension_version >= 58:
SUPPORTED_DTYPES = SUPPORTED_DTYPES | {jnp.complex64, jnp.complex128}


def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False):
Expand Down

0 comments on commit 8af0d8d

Please sign in to comment.