Skip to content

Commit

Permalink
Merge pull request #401 from hawkinsp/master
Browse files Browse the repository at this point in the history
Remove type conversion table from xla_bridge.py.
  • Loading branch information
hawkinsp committed Feb 18, 2019
2 parents 27bedc2 + 2738acc commit b322833
Showing 1 changed file with 4 additions and 29 deletions.
33 changes: 4 additions & 29 deletions jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,41 +210,15 @@ def device_put(pyval, replica=0):

### utility functions

# Similar or identical dtype-to-etype conversion tables exist in the XLA
# clients, but because their organization hasn't been made consistent across
# clients yet, we repeat the information here.
_etype_to_dtype = {
xla_data_pb2.PRED: onp.dtype('bool'),
xla_data_pb2.S8: onp.dtype('int8'),
xla_data_pb2.S16: onp.dtype('int16'),
xla_data_pb2.S32: onp.dtype('int32'),
xla_data_pb2.S64: onp.dtype('int64'),
xla_data_pb2.U8: onp.dtype('uint8'),
xla_data_pb2.U16: onp.dtype('uint16'),
xla_data_pb2.U32: onp.dtype('uint32'),
xla_data_pb2.U64: onp.dtype('uint64'),
xla_data_pb2.F16: onp.dtype('float16'),
xla_data_pb2.F32: onp.dtype('float32'),
xla_data_pb2.F64: onp.dtype('float64'),
xla_data_pb2.C64: onp.dtype('complex64'),
xla_data_pb2.C128: onp.dtype('complex128'),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
_dtype_to_etype = {str(dt): et for et, dt in _etype_to_dtype.items()}


@memoize
def dtype_to_etype(dtype):
"""Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
return _dtype_to_etype[canonicalize_dtype(dtype)]
return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[canonicalize_dtype(dtype)]

@memoize
def dtype_to_etype_exact(dtype):
"""Convert from dtype to exact etype (ignoring FLAGS.jax_enable_x64)."""
return _dtype_to_etype[str(onp.dtype(dtype))]
return xla_client.dtype_to_etype(dtype)


_dtype_to_32bit_dtype = {
Expand All @@ -268,7 +242,8 @@ def canonicalize_dtype(dtype):

@memoize_thunk
def supported_numpy_dtypes():
return {canonicalize_dtype(dtype) for dtype in _etype_to_dtype.values()}
return {canonicalize_dtype(dtype)
for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}


def canonicalize_shape(shape):
Expand Down

0 comments on commit b322833

Please sign in to comment.