[NumPy] Remove references to deprecated NumPy type aliases.
This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacements (bool, int, float, complex, object, str).

NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy.
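
For reference, a minimal sketch (not part of this commit) of the substitution being applied; these aliases have emitted a DeprecationWarning since NumPy 1.20 and are removed in NumPy 1.24, while the sized NumPy scalar types such as np.bool_ remain available:

import numpy as np

# Deprecated since NumPy 1.20 and removed in 1.24:
#   mask = np.ones((2, 3)).astype(np.bool)

# Recommended replacement: the Python builtin type.
mask = np.ones((2, 3)).astype(bool)

# When an explicit NumPy scalar type is wanted, the
# trailing-underscore names still work:
mask_explicit = np.ones((2, 3), dtype=np.bool_)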

PiperOrigin-RevId: 496484564
hawkinsp authored and Copybara-Service committed Dec 19, 2022
1 parent 0b3ed2b commit a8fc2ce
Showing 5 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion in trax/layers/research/efficient_attention_test.py
@@ -249,7 +249,7 @@ def test_lsh_self_attention_masked_non_causal(self):
hidden = 8

x = np.random.uniform(size=(batch, max_len, hidden))
- mask = np.ones((batch, max_len)).astype(np.bool)
+ mask = np.ones((batch, max_len)).astype(bool)
rngs = jax.random.randint(
jax.random.PRNGKey(0), (batch,), minval=1, maxval=max_len - 1)

4 changes: 2 additions & 2 deletions in trax/models/research/rse_test.py
@@ -51,7 +51,7 @@ def test_shuffle_layer_log_times_is_identity(self):
x = _input_with_indice_as_values(seq_len, d_model)
_, _ = shuffle_layer.init(shapes.signature(x))
y = x
- for _ in range(np.int(np.log2(seq_len))):
+ for _ in range(int(np.log2(seq_len))):
y = shuffle_layer(y)
self._assert_equal_tensors(x, y)

@@ -71,7 +71,7 @@ def test_reverse_shuffle_layer_log_times_is_identity(self):
x = _input_with_indice_as_values(seq_len, d_model)
_, _ = reverse_shuffle_layer.init(shapes.signature(x))
y = x
- for _ in range(np.int(np.log2(seq_len))):
+ for _ in range(int(np.log2(seq_len))):
y = reverse_shuffle_layer(y)
self._assert_equal_tensors(x, y)

2 changes: 1 addition & 1 deletion in trax/models/research/terraformer_test.py
@@ -104,7 +104,7 @@ def test_terraformer_quick(self, backend, encoder_attention_type, preembed):

if preembed:
model_inputs = [np.ones((1, max_len, 3)).astype(np.float32),
- np.ones((1, max_len)).astype(np.bool)]
+ np.ones((1, max_len)).astype(bool)]
else:
model_inputs = [np.ones((1, max_len)).astype(np.int32)]
x = model_inputs + [np.ones((1, max_len)).astype(np.int32)]
6 changes: 3 additions & 3 deletions in trax/rl/advantages_test.py
@@ -61,7 +61,7 @@ def estimate_advantage_bias_and_variance(
else:
values = np.zeros_like(returns)

- dones = np.zeros_like(returns, dtype=np.bool)
+ dones = np.zeros_like(returns, dtype=bool)
adv = advantage_fn(rewards, returns, values, dones, discount_mask)
if discount_true_return:
mean_return = true_returns[0, 0]
@@ -219,8 +219,8 @@ def test_future_return_is_zero_iff_discount_mask_is_on(self, advantage_fn):
# (... when gamma=0)
rewards = np.array([[1, 2, 3, 4]], dtype=np.float32)
values = np.array([[5, 6, 7, 8]], dtype=np.float32)
- dones = np.zeros_like(rewards, dtype=np.bool)
- discount_mask = np.array([[1, 0, 1, 0]], dtype=np.bool)
+ dones = np.zeros_like(rewards, dtype=bool)
+ discount_mask = np.array([[1, 0, 1, 0]], dtype=bool)
gammas = advantages.mask_discount(0.0, discount_mask)
returns = advantages.discounted_returns(rewards, gammas)
adv = advantage_fn(gamma=0.0, margin=1)(
2 changes: 1 addition & 1 deletion in trax/tf_numpy/numpy_impl/tests/array_ops_test.py
@@ -558,7 +558,7 @@ def run_test(condition, arr, *args, **kwargs):
self.match(
array_ops.compress(arg1, arg2, *args, **kwargs),
np.compress(
- np.asarray(arg1).astype(np.bool), arg2, *args, **kwargs))
+ np.asarray(arg1).astype(bool), arg2, *args, **kwargs))

run_test([True], 5)
run_test([False], 5)
