Skip to content

Commit

Permalink
Promote numerical arbitrary activation layer(NumericalActivation) as …
Browse files Browse the repository at this point in the history
…a public method.

example usage: `stax.serial(stax.Dense(1024), stax.NumericalActivation(jax.nn.swish, deg=25), stax.Dense(1))`

PiperOrigin-RevId: 322887050
  • Loading branch information
jaehlee authored and romanngg committed Jul 23, 2020
1 parent a10b604 commit 04c7c29
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,11 +1521,11 @@ def Abs(do_backprop: bool = False, do_stabilize: bool = False) -> InternalLayer:

@layer
@_supports_masking(remask_kernel=True)
def _NumericalActivation(fn: Callable[[float], float],
deg: int,
df: Callable[[float], float] = None,
do_backprop: bool = False) -> InternalLayer:
"""Activation function using numerical integegration.
def NumericalActivation(fn: Callable[[float], float],
deg: int,
df: Callable[[float], float] = None,
do_backprop: bool = False) -> InternalLayer:
"""Activation function using numerical integration.
Supports general activation functions using Gauss-Hermite quadrature.
Expand Down Expand Up @@ -1555,7 +1555,7 @@ def _NumericalActivation(fn: Callable[[float], float],
quad_points = osp.special.roots_hermite(deg)
if df is None:
df = np.vectorize(grad(fn))
return _elementwise(fn, f'_NumericalActivation({fn},deg={deg})', df=df,
return _elementwise(fn, f'NumericalActivation({fn},deg={deg})', df=df,
quad_points=quad_points, do_backprop=do_backprop)


Expand Down
2 changes: 1 addition & 1 deletion tests/stax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def _test_activation(self, activation, fn, same_inputs, model, get):
analytic_kernel = kernel_fn(X0_1, X0_2, get)

_, _, kernel_fn = stax.serial(
*[affine, stax._NumericalActivation(fn, deg=deg)]*depth, readout)
*[affine, stax.NumericalActivation(fn, deg=deg)]*depth, readout)
numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

test_utils.assert_close_matrices(self, analytic_kernel,
Expand Down

0 comments on commit 04c7c29

Please sign in to comment.