Skip to content

Commit

Permalink
Merge pull request #405 from hawkinsp/fixes
Browse files Browse the repository at this point in the history
Remove obsolete workarounds for bugs that seem fixed.
  • Loading branch information
hawkinsp committed Feb 18, 2019
2 parents 5180549 + 32be966 commit a44d248
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 15 deletions.
12 changes: 1 addition & 11 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,17 +262,7 @@ def __init__(self, device_buffer, shape, dtype, ndim, size):
def _value(self):
if self._npy_value is None:
self._npy_value = self.device_buffer.to_py()
try:
self._npy_value.flags.writeable = False
except AttributeError:
# TODO(mattjj): bug with C64 on TPU backend, C64 values returned as pair
if onp.issubdtype(self.dtype, onp.complexfloating):
a, b = self._npy_value
npy_value = onp.stack([a, b], -1).view(self.dtype).reshape(self.shape)
npy_value.flags.writeable = False
self._npy_value = npy_value
else:
raise
self._npy_value.flags.writeable = False
return self._npy_value

def copy(self):
Expand Down
4 changes: 0 additions & 4 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,10 +1243,6 @@ def testRoll(self, shape, dtype, shifts, axis, rng):
for rng_indices in [jtu.rand_int(-5, 5)]))
def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode, rng,
rng_indices):
if (FLAGS.jax_test_dut.startswith("tpu")
and onp.issubdtype(dtype, onp.complexfloating)):
self.skipTest("skipping complex dtype on TPU") # TODO(mattjj): investigate failures

def args_maker():
x = rng(shape, dtype)
i = rng_indices(index_shape, index_dtype)
Expand Down

0 comments on commit a44d248

Please sign in to comment.