Skip to content

Commit

Permalink
Migrate GenericFunction usages to PolymorphicFunction
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 564483413
  • Loading branch information
faizan-m authored and romanngg committed Sep 26, 2023
1 parent 3c3dc9f commit 8064529
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions neural_tangents/experimental/empirical_tf/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@


def empirical_ntk_fn_tf(
f: Union[tf.Module, tf.types.experimental.GenericFunction],
f: Union[tf.Module, tf.types.experimental.PolymorphicFunction],
trace_axes: Axes = (-1,),
diagonal_axes: Axes = (),
vmap_axes: VMapAxes = None,
Expand Down Expand Up @@ -243,7 +243,7 @@ def empirical_ntk_fn_tf(
if isinstance(f, tf.Module):
apply_fn, _ = get_apply_fn_and_params(f)

elif isinstance(f, tf.types.experimental.GenericFunction):
elif isinstance(f, tf.types.experimental.PolymorphicFunction):
apply_fn = tf2jax.convert_functional(f, *f.input_signature)

else:
Expand Down

0 comments on commit 8064529

Please sign in to comment.