From 8ac41614c60824ffd654b518c4c21d18f82b2945 Mon Sep 17 00:00:00 2001
From: Roman Novak
Date: Mon, 19 Oct 2020 11:44:12 -0700
Subject: [PATCH] Refactoring: rename `NumericalActivation` to
 `ElementwiseNumerical` to match the `_elementwise` function. Make
 `_elementwise` accept optional arguments. Refactor some related tests. Make
 `supports_masking` public.

PiperOrigin-RevId: 337900281
---
 neural_tangents/stax.py | 77 +++++++++++++++++++++++------------------
 tests/stax_test.py      | 61 ++++++++++++++------------------
 2 files changed, 70 insertions(+), 68 deletions(-)

diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py
index 8bcff8d8..35b53be0 100644
--- a/neural_tangents/stax.py
+++ b/neural_tangents/stax.py
@@ -192,7 +192,7 @@ def new_kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
   return req
 
 
-def _supports_masking(remask_kernel: bool):
+def supports_masking(remask_kernel: bool):
   """Returns a decorator that turns layers into layers supporting masking.
 
   Specifically:
@@ -355,7 +355,7 @@ def kernel_fn(ks: List[Kernel], **kwargs) -> List[Kernel]:
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def DotGeneral(
     *,
     lhs: Union[np.ndarray, float] = None,
@@ -487,7 +487,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def Aggregate(
     aggregate_axis: Axes = None,
     batch_axis: int = 0,
@@ -683,7 +683,7 @@ def kernel_fn(k: NTTree[Kernel],
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def Dense(
     out_dim: int,
     W_std: float = 1.,
@@ -812,7 +812,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def Conv(
     out_chan: int,
     filter_shape: Sequence[int],
@@ -836,7 +836,7 @@ def Conv(
       number of spatial dimensions in `dimension_numbers`.
     strides:
       The stride of the convolution. The shape of the tuple should agree with
-      the number of spatial dimensions in `dimension_nubmers`.
+      the number of spatial dimensions in `dimension_numbers`.
     padding:
       Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
       or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
@@ -860,7 +860,7 @@ def Conv(
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def ConvTranspose(
     out_chan: int,
     filter_shape: Sequence[int],
@@ -908,7 +908,7 @@ def ConvTranspose(
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def ConvLocal(
     out_chan: int,
     filter_shape: Sequence[int],
@@ -934,7 +934,7 @@ def ConvLocal(
       number of spatial dimensions in `dimension_numbers`.
     strides:
       The stride of the convolution. The shape of the tuple should agree with
-      the number of spatial dimensions in `dimension_nubmers`.
+      the number of spatial dimensions in `dimension_numbers`.
     padding:
       Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
       or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
@@ -982,7 +982,7 @@ def _Conv(
       in `dimension_numbers`.
     strides:
       The stride of the convolution. The shape of the tuple should agree with
-      the number of spatial dimensions in `dimension_nubmers`.
+      the number of spatial dimensions in `dimension_numbers`.
     padding:
       Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
       or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
@@ -1276,7 +1276,7 @@ def FanOut(num: int) -> InternalLayer:
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def FanInSum() -> InternalLayer:
   """Layer construction function for a fan-in sum layer.
 
@@ -1335,7 +1335,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def FanInProd() -> InternalLayer:
   """Layer construction function for a fan-in product layer.
 
@@ -1401,7 +1401,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def FanInConcat(axis: int = -1) -> InternalLayer:
   """Layer construction function for a fan-in concatenation layer.
 
@@ -1511,7 +1511,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def AvgPool(
     window_shape: Sequence[int],
     strides: Sequence[int] = None,
@@ -1547,7 +1547,7 @@ def AvgPool(
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def SumPool(
     window_shape: Sequence[int],
     strides: Sequence[int] = None,
@@ -1697,7 +1697,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def GlobalSumPool(batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer:
   """Layer construction function for a global sum pooling layer.
 
@@ -1718,7 +1718,7 @@ def GlobalSumPool(batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer:
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def GlobalAvgPool(batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer:
   """Layer construction function for a global average pooling layer.
 
@@ -1824,7 +1824,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def Flatten(batch_axis: int = 0, batch_axis_out: int = 0) -> InternalLayer:
   """Layer construction function for flattening all non-batch dimensions.
 
@@ -1915,7 +1915,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def Identity() -> InternalLayer:
   """Layer construction function for an identity layer.
 
@@ -1953,7 +1953,7 @@ def fn(self):
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def GlobalSelfAttention(
     n_chan_out: int,
     n_chan_key: int,
@@ -2439,7 +2439,7 @@ def mask_fn(mask, input_shape):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def LayerNorm(
     axis: Axes = -1,
     eps: float = 1e-12,
@@ -2547,7 +2547,7 @@ def prepare_mask(m):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def Dropout(rate: float, mode: str = 'train') -> InternalLayer:
   """Dropout layer.
 
@@ -2604,7 +2604,7 @@ def kernel_fn_train(k: Kernel, **kwargs):
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def Erf(
     a: float = 1.,
     b: float = 1.,
@@ -2684,7 +2684,7 @@ def Sigmoid_like():
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def Gelu(
     do_backprop: bool = False) -> InternalLayer:
   """Gelu function.
@@ -2757,7 +2757,7 @@ def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def Sin(
     a: float = 1.,
     b: float = 1.,
@@ -2813,7 +2813,7 @@ def nngp_fn_diag(nngp):
 
 
 @layer
-@_supports_masking(remask_kernel=True)
+@supports_masking(remask_kernel=True)
 def Rbf(
     gamma: float = 1.0) -> InternalLayer:
   """Dual activation function for normalized RBF or squared exponential kernel.
@@ -2871,7 +2871,7 @@ def nngp_fn_diag(nngp):
 
 
 @layer
-@_supports_masking(remask_kernel=False)
+@supports_masking(remask_kernel=False)
 def ABRelu(
     a: float,
     b: float,
@@ -2998,8 +2998,8 @@ def Abs(
 
 
 @layer
-@_supports_masking(remask_kernel=True)
-def NumericalActivation(
+@supports_masking(remask_kernel=True)
+def ElementwiseNumerical(
     fn: Callable[[float], float],
     deg: int,
     df: Callable[[float], float] = None,
@@ -3035,6 +3035,9 @@ def NumericalActivation(
   quad_points = osp.special.roots_hermite(deg)
 
   if df is None:
+    warnings.warn(
+        'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
+        'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.')
     df = np.vectorize(grad(fn))
 
   _requires(diagonal_spatial=_Diagonal())  # pytype:disable=wrong-keyword-args
@@ -3745,17 +3748,25 @@ def kernel_fn_any(x1_or_kernel: Union[NTTree[np.ndarray], NTTree[Kernel]],
   return kernel_fn_any
 
 
-def _elementwise(fn: Callable[[float], float],
+def _elementwise(fn: Optional[Callable[[float], float]],
                  name: str,
-                 kernel_fn: LayerKernelFn) -> InternalLayer:
-  init_fn, apply_fn = ostax.elementwise(fn)
+                 kernel_fn: Optional[LayerKernelFn]) -> InternalLayer:
+  init_fn = lambda rng, input_shape: (input_shape, ())
+
+  def apply_fn(params, inputs, **kwargs):
+    if fn is None:
+      raise NotImplementedError(fn)
+    return fn(inputs)  # pytype:disable=not-callable
 
   def new_kernel_fn(k: Kernel, **kwargs) -> Kernel:
+    if kernel_fn is None:
+      raise NotImplementedError(kernel_fn)
+
     if not k.is_gaussian:
       raise ValueError('The input to the activation function must be Gaussian, '
                        'i.e. a random affine transform is required before the '
                        'activation function.')
-    k = kernel_fn(k)
+    k = kernel_fn(k)  # pytype:disable=not-callable
     return k.replace(is_gaussian=False)
 
   init_fn.__name__ = apply_fn.__name__ = new_kernel_fn.__name__ = name
diff --git a/tests/stax_test.py b/tests/stax_test.py
index 91f0e027..c006cd58 100644
--- a/tests/stax_test.py
+++ b/tests/stax_test.py
@@ -952,9 +952,30 @@ def test_rbf(self, same_inputs, model, get, gamma):
                           rbf_gamma=gamma)
 
 
-class NumericalActivationTest(test_utils.NeuralTangentsTestCase):
+class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):
 
-  def _test_activation(self, activation, same_inputs, model, get):
+  @jtu.parameterized.named_parameters(
+      jtu.cases_from_list({
+          'testcase_name':
+              '_{}_{}_{}_{}'.format(
+                  model,
+                  phi[0].__name__,
+                  'Same_inputs' if same_inputs else 'Different_inputs',
+                  get),
+          'model': model,
+          'phi': phi,
+          'same_inputs': same_inputs,
+          'get': get,
+      }
+          for model in ['fc', 'conv-pool', 'conv-flatten']
+          for phi in [
+              stax.Erf(),
+              stax.Gelu(),
+              stax.Sin(),
+          ]
+          for same_inputs in [False, True]
+          for get in ['nngp', 'ntk']))
+  def test_elementwise_numerical(self, same_inputs, model, phi, get):
     platform = xla_bridge.get_backend().platform
     if platform == 'cpu' and 'conv' in model:
       raise absltest.SkipTest('Not running CNNs on CPU to save time.')
@@ -986,47 +1007,17 @@ def _test_activation(self, activation, same_inputs, model, get):
                             stax.Dense(output_dim))
 
     depth = 2
-    _, _, kernel_fn = stax.serial(*[affine, activation]*depth, readout)
+    _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
     analytic_kernel = kernel_fn(X0_1, X0_2, get)
 
-    fn = lambda x: activation[1]((), x)
+    fn = lambda x: phi[1]((), x)
     _, _, kernel_fn = stax.serial(
-        *[affine, stax.NumericalActivation(fn, deg=deg)]*depth, readout)
+        *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
     numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)
 
     test_utils.assert_close_matrices(self, analytic_kernel,
                                      numerical_activation_kernel, rtol)
 
-  @jtu.parameterized.named_parameters(
-      jtu.cases_from_list({
-          'testcase_name':
-              '_{}_{}_{}_{}'.format(
-                  model,
-                  phi_name,
-                  'Same_inputs' if same_inputs else 'Different_inputs',
-                  get),
-          'model': model,
-          'phi_name': phi_name,
-          'same_inputs': same_inputs,
-          'get': get,
-      }
-          for model in ['fc', 'conv-pool', 'conv-flatten']
-          for phi_name in ['Erf', 'Gelu', 'Sin', 'Cos']
-          for same_inputs in [False, True]
-          for get in ['nngp', 'ntk']))
-  def test_numerical_activation(self, same_inputs, model, phi_name, get):
-    if phi_name == 'Erf':
-      activation = stax.Erf()
-    elif phi_name == 'Gelu':
-      activation = stax.Gelu()
-    elif phi_name == 'Sin':
-      activation = stax.Sin()
-    elif phi_name == 'Cos':
-      activation = stax.Sin(c=np.pi/2)
-    else:
-      raise NotImplementedError(f'Activation {phi_name} is not implemented.')
-    self._test_activation(activation, same_inputs, model, get)
-
   @jtu.parameterized.parameters([
       {