Skip to content

Commit

Permalink
Merge pull request #399 from hawkinsp/types
Browse files Browse the repository at this point in the history
Add {float16,uint16,uint8,int16,int8} types to abstract_arrays.
  • Loading branch information
hawkinsp committed Feb 18, 2019
2 parents 848b769 + 0129e94 commit 99abdf9
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions jax/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ def zeros_like_array(x):
dtype = xla_bridge.canonicalize_dtype(onp.result_type(x))
return onp.broadcast_to(onp.array(0, dtype), onp.shape(x))

array_types = [onp.ndarray, onp.float64, onp.float32, onp.complex64,
onp.complex128, onp.int64, onp.int32, onp.bool_, onp.uint64,
onp.uint32, complex, float, int, bool]
array_types = [onp.ndarray, onp.float64, onp.float32, onp.float16,
onp.complex64, onp.complex128,
onp.int64, onp.int32, onp.int16, onp.int8,
onp.bool_, onp.uint64, onp.uint32, onp.uint16, onp.uint8,
complex, float, int, bool]

for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray
Expand Down

0 comments on commit 99abdf9

Please sign in to comment.