Skip to content

Commit

Permalink
Fix a pytype error that slipped past the cleanup resulting this LSC: …
Browse files Browse the repository at this point in the history
…go/lsc-jnp-ndarray-types

This is breaking upstream builds that depend on neural_tangents, such as         //learning/deepmind/public/tools/ml_python:core_deps

PiperOrigin-RevId: 513279197
  • Loading branch information
erikfrey authored and romanngg committed Mar 9, 2023
1 parent 6681c30 commit d225284
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3732,7 +3732,7 @@ def apply_fun(params, inputs, **kwargs):


def _pos_emb_identity(shape: Sequence[int]) -> np.ndarray:
size = utils.size_at(shape)
size = utils.size_at(shape) # pytype: disable=wrong-arg-types # jax-ndarray
R = np.eye(size).reshape(tuple(shape) * 2)
R = utils.zip_axes(R)
return R
Expand Down

0 comments on commit d225284

Please sign in to comment.