Refactoring: rename `NumericalActivation` to `ElementwiseNumerical` to make it match the `_elementwise` function. Make `_elementwise` accept optional arguments. Refactor some related tests. Make `supports_masking` public.

PiperOrigin-RevId: 337900281
romanngg committed Oct 20, 2020
1 parent 513df1a commit 8ac4161
Showing 2 changed files with 70 additions and 68 deletions.
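
For orientation before the diff: a minimal, hypothetical usage sketch of the renamed layer. Only `stax.ElementwiseNumerical(fn, deg, df=None)` and its Gauss-Hermite quadrature are taken from the commit itself; the surrounding network, `deg=25`, and the input shapes are illustrative.

import jax.numpy as np
from neural_tangents import stax

# Any scalar nonlinearity; its dual kernel is approximated with
# Gauss-Hermite quadrature of degree `deg` (25 is an arbitrary choice here).
fn = np.sin

_, _, kernel_fn = stax.serial(
    stax.Dense(512),
    # Formerly `stax.NumericalActivation(fn, deg=25)`. With `df=None`, the
    # derivative needed for the NTK is taken with JAX autodiff, which now
    # triggers the warning added in this commit.
    stax.ElementwiseNumerical(fn, deg=25),
    stax.Dense(1))

x1, x2 = np.ones((3, 8)), np.ones((5, 8))
ntk = kernel_fn(x1, x2, 'ntk')  # quadrature-approximated NTK of the network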
77 changes: 44 additions & 33 deletions neural_tangents/stax.py
@@ -192,7 +192,7 @@ def new_kernel_fn(k: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
return req


def _supports_masking(remask_kernel: bool):
def supports_masking(remask_kernel: bool):
"""Returns a decorator that turns layers into layers supporting masking.
Specifically:
@@ -355,7 +355,7 @@ def kernel_fn(ks: List[Kernel], **kwargs) -> List[Kernel]:


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def DotGeneral(
*,
lhs: Union[np.ndarray, float] = None,
@@ -487,7 +487,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def Aggregate(
aggregate_axis: Axes = None,
batch_axis: int = 0,
@@ -683,7 +683,7 @@ def kernel_fn(k: NTTree[Kernel],


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def Dense(
out_dim: int,
W_std: float = 1.,
@@ -812,7 +812,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def Conv(
out_chan: int,
filter_shape: Sequence[int],
@@ -836,7 +836,7 @@ def Conv(
number of spatial dimensions in `dimension_numbers`.
strides:
The stride of the convolution. The shape of the tuple should agree with
the number of spatial dimensions in `dimension_nubmers`.
the number of spatial dimensions in `dimension_numbers`.
padding:
Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
@@ -860,7 +860,7 @@ def Conv(


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def ConvTranspose(
out_chan: int,
filter_shape: Sequence[int],
@@ -908,7 +908,7 @@ def ConvTranspose(


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def ConvLocal(
out_chan: int,
filter_shape: Sequence[int],
@@ -934,7 +934,7 @@ def ConvLocal(
number of spatial dimensions in `dimension_numbers`.
strides:
The stride of the convolution. The shape of the tuple should agree with
the number of spatial dimensions in `dimension_nubmers`.
the number of spatial dimensions in `dimension_numbers`.
padding:
Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
@@ -982,7 +982,7 @@ def _Conv(
in `dimension_numbers`.
strides:
The stride of the convolution. The shape of the tuple should agree with
the number of spatial dimensions in `dimension_nubmers`.
the number of spatial dimensions in `dimension_numbers`.
padding:
Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
@@ -1276,7 +1276,7 @@ def FanOut(num: int) -> InternalLayer:


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def FanInSum() -> InternalLayer:
"""Layer construction function for a fan-in sum layer.
@@ -1335,7 +1335,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def FanInProd() -> InternalLayer:
"""Layer construction function for a fan-in product layer.
@@ -1401,7 +1401,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def FanInConcat(axis: int = -1) -> InternalLayer:
"""Layer construction function for a fan-in concatenation layer.
@@ -1511,7 +1511,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def AvgPool(
window_shape: Sequence[int],
strides: Sequence[int] = None,
@@ -1547,7 +1547,7 @@ def AvgPool(


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def SumPool(
window_shape: Sequence[int],
strides: Sequence[int] = None,
@@ -1697,7 +1697,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def GlobalSumPool(batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer:
"""Layer construction function for a global sum pooling layer.
@@ -1718,7 +1718,7 @@ def GlobalSumPool(batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer:


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def GlobalAvgPool(batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer:
"""Layer construction function for a global average pooling layer.
@@ -1824,7 +1824,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def Flatten(batch_axis: int = 0, batch_axis_out: int = 0) -> InternalLayer:
"""Layer construction function for flattening all non-batch dimensions.
@@ -1915,7 +1915,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def Identity() -> InternalLayer:
"""Layer construction function for an identity layer.
@@ -1953,7 +1953,7 @@ def fn(self):


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def GlobalSelfAttention(
n_chan_out: int,
n_chan_key: int,
@@ -2439,7 +2439,7 @@ def mask_fn(mask, input_shape):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def LayerNorm(
axis: Axes = -1,
eps: float = 1e-12,
@@ -2547,7 +2547,7 @@ def prepare_mask(m):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def Dropout(rate: float, mode: str = 'train') -> InternalLayer:
"""Dropout layer.
@@ -2604,7 +2604,7 @@ def kernel_fn_train(k: Kernel, **kwargs):


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def Erf(
a: float = 1.,
b: float = 1.,
@@ -2684,7 +2684,7 @@ def Sigmoid_like():


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def Gelu(
do_backprop: bool = False) -> InternalLayer:
"""Gelu function.
@@ -2757,7 +2757,7 @@ def nngp_fn_diag(nngp: np.ndarray) -> np.ndarray:


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def Sin(
a: float = 1.,
b: float = 1.,
@@ -2813,7 +2813,7 @@ def nngp_fn_diag(nngp):


@layer
@_supports_masking(remask_kernel=True)
@supports_masking(remask_kernel=True)
def Rbf(
gamma: float = 1.0) -> InternalLayer:
"""Dual activation function for normalized RBF or squared exponential kernel.
@@ -2871,7 +2871,7 @@ def nngp_fn_diag(nngp):


@layer
@_supports_masking(remask_kernel=False)
@supports_masking(remask_kernel=False)
def ABRelu(
a: float,
b: float,
@@ -2998,8 +2998,8 @@ def Abs(


@layer
@_supports_masking(remask_kernel=True)
def NumericalActivation(
@supports_masking(remask_kernel=True)
def ElementwiseNumerical(
fn: Callable[[float], float],
deg: int,
df: Callable[[float], float] = None,
@@ -3035,6 +3035,9 @@ def NumericalActivation(
quad_points = osp.special.roots_hermite(deg)

if df is None:
warnings.warn(
'Using JAX autodiff to compute the `fn` derivative for NTK. Beware of '
'https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where.')
df = np.vectorize(grad(fn))

_requires(diagonal_spatial=_Diagonal()) # pytype:disable=wrong-keyword-args
@@ -3745,7 +3748,25 @@ def kernel_fn_any(x1_or_kernel: Union[NTTree[np.ndarray], NTTree[Kernel]],
return kernel_fn_any


def _elementwise(fn: Callable[[float], float],
def _elementwise(fn: Optional[Callable[[float], float]],
name: str,
kernel_fn: LayerKernelFn) -> InternalLayer:
init_fn, apply_fn = ostax.elementwise(fn)
kernel_fn: Optional[LayerKernelFn]) -> InternalLayer:
init_fn = lambda rng, input_shape: (input_shape, ())

def apply_fn(params, inputs, **kwargs):
if fn is None:
raise NotImplementedError(fn)
return fn(inputs) # pytype:disable=not-callable

def new_kernel_fn(k: Kernel, **kwargs) -> Kernel:
if kernel_fn is None:
raise NotImplementedError(kernel_fn)

if not k.is_gaussian:
raise ValueError('The input to the activation function must be Gaussian, '
'i.e. a random affine transform is required before the '
'activation function.')
k = kernel_fn(k)
k = kernel_fn(k) # pytype:disable=not-callable
return k.replace(is_gaussian=False)

init_fn.__name__ = apply_fn.__name__ = new_kernel_fn.__name__ = name
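
Because `supports_masking` is public after this change, code outside `stax.py` can decorate custom layers the same way the in-tree layers above do. The sketch below is hypothetical (the layer name and function bodies are made up) and assumes `layer` and `supports_masking` are both reachable as attributes of `neural_tangents.stax`:

from neural_tangents import stax

@stax.layer
@stax.supports_masking(remask_kernel=False)
def MyIdentity():
  # Hypothetical user-defined layer mirroring `Identity` above: no
  # parameters, inputs and kernels pass through unchanged.
  init_fn = lambda rng, input_shape: (input_shape, ())
  apply_fn = lambda params, inputs, **kwargs: inputs
  kernel_fn = lambda k, **kwargs: k
  return init_fn, apply_fn, kernel_fn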
61 changes: 26 additions & 35 deletions tests/stax_test.py
@@ -952,9 +952,30 @@ def test_rbf(self, same_inputs, model, get, gamma):
rbf_gamma=gamma)


class NumericalActivationTest(test_utils.NeuralTangentsTestCase):
class ElementwiseNumericalTest(test_utils.NeuralTangentsTestCase):

def _test_activation(self, activation, same_inputs, model, get):
@jtu.parameterized.named_parameters(
jtu.cases_from_list({
'testcase_name':
'_{}_{}_{}_{}'.format(
model,
phi[0].__name__,
'Same_inputs' if same_inputs else 'Different_inputs',
get),
'model': model,
'phi': phi,
'same_inputs': same_inputs,
'get': get,
}
for model in ['fc', 'conv-pool', 'conv-flatten']
for phi in [
stax.Erf(),
stax.Gelu(),
stax.Sin(),
]
for same_inputs in [False, True]
for get in ['nngp', 'ntk']))
def test_elementwise_numerical(self, same_inputs, model, phi, get):
platform = xla_bridge.get_backend().platform
if platform == 'cpu' and 'conv' in model:
raise absltest.SkipTest('Not running CNNs on CPU to save time.')
@@ -986,47 +1007,17 @@ def _test_activation(self, activation, same_inputs, model, get):
stax.Dense(output_dim))
depth = 2

_, _, kernel_fn = stax.serial(*[affine, activation]*depth, readout)
_, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
analytic_kernel = kernel_fn(X0_1, X0_2, get)

fn = lambda x: activation[1]((), x)
fn = lambda x: phi[1]((), x)
_, _, kernel_fn = stax.serial(
*[affine, stax.NumericalActivation(fn, deg=deg)]*depth, readout)
*[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

test_utils.assert_close_matrices(self, analytic_kernel,
numerical_activation_kernel, rtol)

@jtu.parameterized.named_parameters(
jtu.cases_from_list({
'testcase_name':
'_{}_{}_{}_{}'.format(
model,
phi_name,
'Same_inputs' if same_inputs else 'Different_inputs',
get),
'model': model,
'phi_name': phi_name,
'same_inputs': same_inputs,
'get': get,
}
for model in ['fc', 'conv-pool', 'conv-flatten']
for phi_name in ['Erf', 'Gelu', 'Sin', 'Cos']
for same_inputs in [False, True]
for get in ['nngp', 'ntk']))
def test_numerical_activation(self, same_inputs, model, phi_name, get):
if phi_name == 'Erf':
activation = stax.Erf()
elif phi_name == 'Gelu':
activation = stax.Gelu()
elif phi_name == 'Sin':
activation = stax.Sin()
elif phi_name == 'Cos':
activation = stax.Sin(c=np.pi/2)
else:
raise NotImplementedError(f'Activation {phi_name} is not implemented.')
self._test_activation(activation, same_inputs, model, get)


@jtu.parameterized.parameters([
{
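
A note on the parameterized test above: each `phi` is itself a stax layer, i.e. an `(init_fn, apply_fn, kernel_fn)` tuple, which is why the scalar function passed to `ElementwiseNumerical` is pulled out of the tuple. A small sketch of that pattern, restating what the test already does:

from neural_tangents import stax

phi = stax.Erf()
# phi[1] is the activation's apply_fn; activations carry no parameters,
# hence the empty `()` passed as params.
fn = lambda x: phi[1]((), x)
# `fn` is then fed to stax.ElementwiseNumerical(fn, deg=...), whose
# quadrature kernel the test compares against phi's closed-form kernel.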
