Skip to content

Commit

Permalink
Minor fixes to make NT build at OSS head.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569578070
  • Loading branch information
romanngg committed Sep 30, 2023
1 parent 271dcbe commit 136338d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions neural_tangents/_src/utils/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ class Kernel:
diagonal_batch: bool = dataclasses.field(pytree_node=False)
diagonal_spatial: bool = dataclasses.field(pytree_node=False)

shape1: tuple[int, ...] = dataclasses.field(pytree_node=False)
shape2: tuple[int, ...] = dataclasses.field(pytree_node=False)
shape1: Optional[tuple[int, ...]] = dataclasses.field(pytree_node=False)
shape2: Optional[tuple[int, ...]] = dataclasses.field(pytree_node=False)

batch_axis: int = dataclasses.field(pytree_node=False)
channel_axis: int = dataclasses.field(pytree_node=False)
Expand Down
6 changes: 4 additions & 2 deletions neural_tangents/experimental/empirical_tf/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@
import tf2jax


# TODO(romann): update to PolymorphicFunction with tf 2.15 release
def empirical_ntk_fn_tf(
f: Union[tf.Module, tf.types.experimental.PolymorphicFunction],
f: Union[tf.Module, tf.types.experimental.GenericFunction],
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None,
Expand Down Expand Up @@ -243,7 +244,8 @@ def empirical_ntk_fn_tf(
if isinstance(f, tf.Module):
apply_fn, _ = get_apply_fn_and_params(f)

elif isinstance(f, tf.types.experimental.PolymorphicFunction):
# TODO(romann): update to PolymorphicFunction with tf 2.15 release
elif isinstance(f, tf.types.experimental.GenericFunction):
apply_fn = tf2jax.convert_functional(f, *f.input_signature)

else:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
INSTALL_REQUIRES = [
'jax>=0.4.14',
'frozendict>=2.3.8',
'tensorflow>=2.14.0',
'tf2jax>=0.3.5',
]

Expand Down

0 comments on commit 136338d

Please sign in to comment.