Skip to content

Commit

Permalink
Ignore incorrect type annotations related to jax dtypes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571882947
  • Loading branch information
Jake VanderPlas authored and Rax Developers committed Oct 9, 2023
1 parent ca9ec51 commit 34825d3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/flax_integration/main.py
Expand Up @@ -151,7 +151,7 @@ def _loss_fn(params):
scores = model.apply(
flax.core.copy(model_state, {"params": params}), inputs
)
loss = loss_fn(scores, labels, where=mask, reduce_fn=jnp.mean)
loss = loss_fn(scores, labels, where=mask, reduce_fn=jnp.mean) # pytype: disable=wrong-arg-types # jnp-type
return loss

params = model_state["params"]
Expand Down

0 comments on commit 34825d3

Please sign in to comment.