Skip to content

Commit

Permalink
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 511322372
  • Loading branch information
hawkinsp authored and romanngg committed Mar 9, 2023
1 parent c5f8eb9 commit 5be5afd
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion neural_tangents/_src/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _reshape_kernel_for_pmap(k: Kernel,
@utils.nt_tree_fn()
def _set_cov2_to_none(k: _ArrayOrKernel) -> _ArrayOrKernel:
  """Returns `k` with its `cov2` field set to `None` if `k` is a `Kernel`.

  Plain arrays are passed through unchanged. Applied leaf-wise over the
  kernel pytree via `utils.nt_tree_fn`.
  """
  if isinstance(k, Kernel):
    # `replace` is a NamedTuple method; pytype conflates the union with
    # jax.Array, hence the suppression.
    k = k.replace(cov2=None)  # pytype: disable=attribute-error  # jax-ndarray
  return k


Expand Down
4 changes: 2 additions & 2 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2047,7 +2047,7 @@ def _trim_invals(
if isinstance(trimmed_invals[i], UndefinedPrimal):
trimmed_invals[i] = _trim_axis(trimmed_invals[i], in_d)

return trimmed_invals
return trimmed_invals # pytype: disable=bad-return-type # jax-ndarray


def _trim_eqn(
Expand Down Expand Up @@ -2259,7 +2259,7 @@ def _write_primal(
val: Union[np.ndarray, UndefinedPrimal]
):
if not ad.is_undefined_primal(val):
env[v] = val
env[v] = val # pytype: disable=container-type-mismatch # jax-ndarray


def _get_fwd(
Expand Down
8 changes: 4 additions & 4 deletions neural_tangents/_src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def get_k_train_train(get: Tuple[str, ...]) -> _Kernel:
if not any(g in k_dd_cache for g in get):
k_dd_cache.update(
kernel_fn(x_train, None, get,
**kernel_fn_train_train_kwargs)._asdict())
**kernel_fn_train_train_kwargs)._asdict()) # pytype: disable=attribute-error # jax-ndarray
else:
for g in get:
if g not in k_dd_cache:
Expand Down Expand Up @@ -1130,12 +1130,12 @@ def max_learning_rate(
The maximal feasible learning rate for infinite width NNs.
"""
ntk_train_train = utils.make_2d(ntk_train_train)
factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size
factor = ntk_train_train.shape[0] if y_train_size is None else y_train_size # pytype: disable=attribute-error # jax-ndarray

if _is_on_cpu(ntk_train_train):
max_eva = osp.linalg.eigvalsh(ntk_train_train,
eigvals=(ntk_train_train.shape[0] - 1,
ntk_train_train.shape[0] - 1))[-1]
eigvals=(ntk_train_train.shape[0] - 1, # pytype: disable=attribute-error # jax-ndarray
ntk_train_train.shape[0] - 1))[-1] # pytype: disable=attribute-error # jax-ndarray
else:
max_eva = np.linalg.eigvalsh(ntk_train_train)[-1]
lr = 2 * (1 + momentum) * factor / (max_eva + eps)
Expand Down
2 changes: 1 addition & 1 deletion neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3758,7 +3758,7 @@ def _pos_emb_pdist(shape: Sequence[int],
R += np.abs(pd) ** pos_emb_p_norm

R = pos_emb_decay_fn(R)
return R
return R # pytype: disable=bad-return-type # jax-ndarray


def _get_all_pos_emb(k: Kernel,
Expand Down
2 changes: 1 addition & 1 deletion neural_tangents/_src/stax/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def unmask_fn(fn: ApplyFn) -> ApplyFn:
def unmask(x: Union[MaskedArray, np.ndarray]) -> np.ndarray:
  """Converts a `MaskedArray` into a plain array by applying its mask.

  Non-`MaskedArray` inputs are returned unchanged.
  """
  if isinstance(x, MaskedArray):
    x = utils.mask(x.masked_value, x.mask)
  return x  # pytype: disable=bad-return-type  # jax-ndarray

def is_leaf(x) -> bool:
  """A pytree leaf for unmasking purposes: an array or a `MaskedArray`."""
  return isinstance(x, np.ndarray) or isinstance(x, MaskedArray)
Expand Down
2 changes: 1 addition & 1 deletion neural_tangents/_src/utils/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def _mul_j(
other = np.ones((), other.dtype) / other

if inval.ndim == 0:
return other
return other # pytype: disable=bad-return-type # jax-ndarray

if other.ndim == 0:
other = np.broadcast_to(other, inval.shape)
Expand Down
2 changes: 1 addition & 1 deletion tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def testPredictND(self):
p_train_mse, p_test_mse = predict_fn_mse(
ts, fx_train_0, fx_test_0, ntk_test_train)
self.assertAllClose(y_test_shape, p_test_mse.shape)
self.assertAllClose(y_train_shape, p_train_mse.shape)
self.assertAllClose(y_train_shape, p_train_mse.shape) # pytype: disable=attribute-error # jax-ndarray

p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
ts, x, ('nngp', 'ntk'), compute_cov=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/stax/stax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_inputs(
batch_axis = shape.index(BATCH_SIZE)
shape = shape[:batch_axis] + (2 * BATCH_SIZE,) + shape[batch_axis + 1:]
x2 = None if same_inputs else fn(random.normal(split, shape)) * 2
return x1, x2
return x1, x2 # pytype: disable=bad-return-type # jax-ndarray


def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding,
Expand Down

0 comments on commit 5be5afd

Please sign in to comment.