Skip to content

Commit

Permalink
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 511475596
Change-Id: I6808a939cdd21acb9b2e9b912575118afd249349
  • Loading branch information
hawkinsp authored and lanctot committed Feb 22, 2023
1 parent 9c11edc commit 490695b
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def cfr_br_meta_data(
player_2_last_best_response_values[-1],
)

return (
return ( # pytype: disable=bad-return-type # jax-ndarray
counterfactual_values_player1,
counterfactual_values_player2,
player_2_last_best_response_values,
Expand Down Expand Up @@ -440,7 +440,7 @@ def training_optimizer(self):
cfvalues = cfvalues_per_player[player_ix][infoset.infostate_string]
train_dataset.append((cfvalues, infoset))

dataset = dataset_generator.Dataset(train_dataset, FLAGS.batch_size)
dataset = dataset_generator.Dataset(train_dataset, FLAGS.batch_size) # pytype: disable=wrong-arg-types # jax-ndarray
data_loader = dataset.get_batch()
for _ in range(FLAGS.num_batches):
batch = next(data_loader)
Expand Down

0 comments on commit 490695b

Please sign in to comment.