From 780ad0ce22d482bcefd12f4d3390090de7206da5 Mon Sep 17 00:00:00 2001 From: Roman Novak Date: Tue, 8 Sep 2020 23:23:46 -0700 Subject: [PATCH] 1) Support transpose convolution 2) Allow circular padding with masking 3) Minor doc tweaks 4) deflake tests by raising tolerance 5) Relax some typing requirements from Tuples to Sequences. Co-authored-by: Kayhan Batmanghelich PiperOrigin-RevId: 330662849 --- neural_tangents/predict.py | 10 +- neural_tangents/stax.py | 841 +++++++++++++++++++++++--------- neural_tangents/utils/kernel.py | 4 +- neural_tangents/utils/utils.py | 51 +- tests/stax_test.py | 341 +++++++++---- 5 files changed, 895 insertions(+), 352 deletions(-) diff --git a/neural_tangents/predict.py b/neural_tangents/predict.py index ce48cef2..22dd76a8 100644 --- a/neural_tangents/predict.py +++ b/neural_tangents/predict.py @@ -37,7 +37,7 @@ from neural_tangents.utils import utils, dataclasses import scipy as osp from neural_tangents.utils.typing import KernelFn, Axes, Get -from typing import Union, Tuple, Callable, Iterable, Optional, Dict, NamedTuple, Generator +from typing import Union, Tuple, Callable, Iterable, Optional, Dict, NamedTuple, Sequence, Generator from functools import lru_cache @@ -741,7 +741,7 @@ def gradient_descent_mse_ensemble( k_dd_cache = {} - def get_k_train_train(get: Tuple[str, ...]) -> _Kernel: + def get_k_train_train(get: Sequence[str]) -> _Kernel: if len(get) == 1: get = get[0] if get not in k_dd_cache: @@ -1000,7 +1000,11 @@ def _get_fns_in_eigenbasis( Args: k_train_train: - an n x n matrix + an n x n matrix. + diag_reg: + diagonal regularizer strength. + diag_reg_absolute_scale: + `True` to use absolute (vs relative to mean trace) regulatization. fns: a sequence of functions that add on the eigenvalues (evals, dt) -> modified_evals. diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index fad3091b..2fccda46 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -68,7 +68,7 @@ import functools import operator as op import string -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Sequence, TypeVar import warnings import frozendict @@ -387,6 +387,8 @@ def apply_fn(params, pattern: A 3D tensor of shape (batch, nodes, nodes), whose dimension order is fixed. + **kwargs: + unused. Returns: A 3D tensor of shape `(batch, nodes, channels)` which is equal to @@ -434,7 +436,7 @@ def kernel_fn(k: Kernels, pattern1, pattern2 = pattern def full_conjugate(p1, mat, p2): - if mat is None or mat.ndim ==0 or (p1 is None and p2 is None): + if mat is None or mat.ndim == 0 or (p1 is None and p2 is None): return mat elif p2 is None: return np.einsum('bli,bcij->bclj', p1, mat, optimize=True) @@ -594,8 +596,8 @@ def mask_fn(mask, input_shape): def GeneralConv( dimension_numbers: Optional[Tuple[str, str, str]], out_chan: int, - filter_shape: Tuple[int, ...], - strides: Tuple[int, ...] = None, + filter_shape: Sequence[int], + strides: Sequence[int] = None, padding: str = Padding.VALID.name, W_std: float = 1.0, b_std: float = 0.0, @@ -604,6 +606,54 @@ def GeneralConv( Based on `jax.experimental.stax.GeneralConv`. + Args: + dimension_numbers: Specifies which axes should be convolved over. Should + match the specification in `jax.lax.conv_general_dilated`. + out_chan: The number of output channels / features of the + convolution. This is ignored in by the `kernel_fn` in NTK + parameterization. + filter_shape: The shape of the filter. 
The shape of the tuple should agree + with the number of spatial dimensions in `dimension_numbers`. + strides: The stride of the convolution. The shape of the tuple should agree + with the number of spatial dimensions in `dimension_nubmers`. + padding: Specifies padding for the convolution. Can be one of `"VALID"`, + `"SAME"`, or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions. + W_std: The standard deviation of the weights. + b_std: The standard deviation of the biases. + parameterization: Either `"ntk"` or `"standard"`. These parameterizations + are the direct analogues for convolution of the corresponding + parameterizations for `Dense` layers. + + Returns: + `(init_fn, apply_fn, kernel_fn)`. + """ + return _GeneralConv(dimension_numbers, + out_chan, + filter_shape, + strides, + padding, + W_std, + b_std, + False, + parameterization) + + +@layer +@_supports_masking(remask_kernel=True) +def GeneralConvTranspose( + dimension_numbers: Optional[Tuple[str, str, str]], + out_chan: int, + filter_shape: Sequence[int], + strides: Sequence[int] = None, + padding: str = Padding.VALID.name, + W_std: float = 1.0, + b_std: float = 0.0, + parameterization: str = 'ntk' +) -> InternalLayer: + """Layer construction function for a general transpose convolution layer. + + Based on `jax.experimental.stax.GeneralConvTranspose`. + Args: dimension_numbers: Specifies which axes should be convolved over. Should match the @@ -639,6 +689,7 @@ def GeneralConv( padding, W_std, b_std, + True, parameterization) @@ -646,8 +697,8 @@ def GeneralConv( @_supports_masking(remask_kernel=True) def Conv( out_chan: int, - filter_shape: Tuple[int, ...], - strides: Optional[Tuple[int, ...]] = None, + filter_shape: Sequence[int], + strides: Sequence[int] = None, padding: str = Padding.VALID.name, W_std: float = 1.0, b_std: float = 0.0, @@ -689,39 +740,102 @@ def Conv( padding, W_std, b_std, + False, parameterization) -def _GeneralConv( - dimension_numbers: Optional[Tuple[str, str, str]], +@layer +@_supports_masking(remask_kernel=True) +def ConvTranspose( out_chan: int, - filter_shape: Tuple[int, ...], - strides: Optional[Tuple[int, ...]] = None, + filter_shape: Sequence[int], + strides: Sequence[int] = None, padding: str = Padding.VALID.name, W_std: float = 1.0, b_std: float = 0.0, - parameterization: str = 'ntk') -> InternalLayer: + parameterization: str = 'ntk' +) -> InternalLayer: + """Layer construction function for a general transpose convolution layer. + + Based on `jax.experimental.stax.ConvTranspose`. + + Args: + out_chan: + The number of output channels / features of the convolution. This is + ignored in by the `kernel_fn` in NTK parameterization. + filter_shape: + The shape of the filter. The shape of the tuple should agree with the + number of spatial dimensions in `dimension_numbers`. + strides: + The stride of the convolution. The shape of the tuple should agree with + the number of spatial dimensions in `dimension_nubmers`. + padding: + Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`, + or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions. + W_std: + The standard deviation of the weights. + b_std: + The standard deviation of the biases. + parameterization: + Either `"ntk"` or `"standard"`. These parameterizations are the direct + analogues for convolution of the corresponding parameterizations for + `Dense` layers. + + Returns: + `(init_fn, apply_fn, kernel_fn)`. 
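+
+  Example:
+    A minimal usage sketch; input sizes and widths below are illustrative:
+
+    >>> from jax import random
+    >>> from neural_tangents import stax
+    >>>
+    >>> x = random.normal(random.PRNGKey(1), (4, 16, 3))  # (batch, length, channels)
+    >>> init_fn, apply_fn, kernel_fn = stax.ConvTranspose(
+    ...     out_chan=256, filter_shape=(3,), strides=(2,), padding='CIRCULAR')
+    >>> nngp = kernel_fn(x, None, 'nngp')  # closed-form NNGP kernel of the layer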
+  """
+  return _GeneralConv(None,
+                      out_chan,
+                      filter_shape,
+                      strides,
+                      padding,
+                      W_std,
+                      b_std,
+                      True,
+                      parameterization)
+
+
+def _GeneralConv(
+    dimension_numbers: Optional[Tuple[str, str, str]],
+    out_chan: int,
+    filter_shape: Sequence[int],
+    strides: Optional[Sequence[int]],
+    padding: str,
+    W_std: float,
+    b_std: float,
+    transpose: bool,
+    parameterization: str
+) -> InternalLayer:
   """Layer construction function for a general convolution layer.
 
   Based on `jax.experimental.stax.GeneralConv`.
 
   Args:
-    dimension_numbers: Specifies which axes should be convolved over. Should
-      match the specification in `jax.lax.conv_general_dilated`.
-    out_chan: The number of output channels / features of the
-      convolution. This is ignored in by the `kernel_fn` in NTK
-      parameterization.
-    filter_shape: The shape of the filter. The shape of the tuple should agree
-      with the number of spatial dimensions in `dimension_numbers`.
-    strides: The stride of the convolution. The shape of the tuple should agree
-      with the number of spatial dimensions in `dimension_nubmers`.
-    padding: Specifies padding for the convolution. Can be one of `"VALID"`,
-      `"SAME"`, or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
-    W_std: The standard deviation of the weights.
-    b_std: The standard deviation of the biases.
-    parameterization: Either `"ntk"` or `"standard"`. These parameterizations
-      are the direct analogues for convolution of the corresponding
-      parameterizations for `Dense` layers.
+    dimension_numbers:
+      Specifies which axes should be convolved over. Should match the
+      specification in `jax.lax.conv_general_dilated`.
+    out_chan:
+      The number of output channels / features of the convolution. This is
+      ignored by the `kernel_fn` in NTK parameterization.
+    filter_shape:
+      The shape of the filter. The shape of the tuple should agree with the
+      number of spatial dimensions in `dimension_numbers`.
+    strides:
+      The stride of the convolution. The shape of the tuple should agree with
+      the number of spatial dimensions in `dimension_numbers`.
+    padding:
+      Specifies padding for the convolution. Can be one of `"VALID"`, `"SAME"`,
+      or `"CIRCULAR"`. `"CIRCULAR"` uses periodic convolutions.
+    W_std:
+      The standard deviation of the weights.
+    b_std:
+      The standard deviation of the biases.
+    transpose:
+      `True` to use transpose convolution.
+    parameterization:
+      Either `"ntk"` or `"standard"`. These parameterizations are the direct
+      analogues for convolution of the corresponding parameterizations for
+      `Dense` layers.
 
   Returns:
     `(init_fn, apply_fn, kernel_fn)`.
@@ -730,35 +844,41 @@ def _GeneralConv( parameterization = parameterization.lower() if dimension_numbers is None: - spatial_dims = ''.join(c for c in string.ascii_uppercase - if c not in ('N', 'C', 'I', 'O'))[:len(filter_shape)] - lhs_spec = 'N' + spatial_dims + 'C' - dimension_numbers = (lhs_spec, spatial_dims + 'IO', lhs_spec) + dimension_numbers = _get_dimension_numbers(len(filter_shape), False) - lhs_spec = dimension_numbers[0] + lhs_spec, rhs_spec, out_spec = dimension_numbers one = (1,) * len(filter_shape) strides = strides or one padding = Padding(padding) - init_padding = padding if padding == Padding.CIRCULAR: - init_padding = Padding.SAME + apply_padding = Padding.VALID + init_padding = padding.SAME + else: + init_padding = apply_padding = padding - def input_total_dim(input_shape): + if transpose: + stax_conv = ostax.GeneralConvTranspose + lax_conv = lax.conv_transpose + else: + stax_conv = ostax.GeneralConv + lax_conv = lax.conv_general_dilated + + ntk_init_fn, _ = stax_conv(dimension_numbers=dimension_numbers, + out_chan=out_chan, + filter_shape=filter_shape, + strides=strides, + padding=init_padding.name, + W_init=random.normal, + b_init=random.normal) + + def get_fan_in(input_shape): return input_shape[lhs_spec.index('C')] * onp.prod(filter_shape) - ntk_init_fn, _ = ostax.GeneralConv(dimension_numbers, - out_chan, - filter_shape, - strides, - init_padding.name, - random.normal, - random.normal) - def standard_init_fn(rng, input_shape): output_shape, (W, b) = ntk_init_fn(rng, input_shape) - norm = W_std / np.sqrt(input_total_dim(input_shape)) + norm = W_std / np.sqrt(get_fan_in(input_shape)) return output_shape, (W * norm, b * b_std) if parameterization == 'ntk': @@ -766,43 +886,54 @@ def standard_init_fn(rng, input_shape): elif parameterization == 'standard': init_fn = standard_init_fn else: - raise ValueError('Parameterization not supported: %s' % parameterization) + raise ValueError(f'Parameterization not supported: {parameterization}.') def apply_fn(params, inputs, **kwargs): W, b = params if parameterization == 'ntk': - norm = W_std / np.sqrt(input_total_dim(inputs.shape)) + norm = W_std / np.sqrt(get_fan_in(inputs.shape)) b_rescale = b_std elif parameterization == 'standard': norm = 1. b_rescale = 1. 
- apply_padding = padding - if padding == Padding.CIRCULAR: - apply_padding = Padding.VALID - spatial_axes = tuple(dimension_numbers[0].index(c) - for c in dimension_numbers[1] - if c not in ('I', 'O')) + if padding == Padding.CIRCULAR and not transpose: + spatial_axes = tuple(lhs_spec.index(c) + for c in rhs_spec if c not in ('I', 'O')) inputs = _same_pad_for_filter_shape(inputs, filter_shape, strides, - spatial_axes, 'wrap') + spatial_axes) - return norm * lax.conv_general_dilated( + res = norm * lax_conv( inputs, W, strides, apply_padding.name, - dimension_numbers=dimension_numbers) + b_rescale * b - - @_requires(batch_axis=dimension_numbers[0].index('N'), - channel_axis=dimension_numbers[0].index('C')) + dimension_numbers=dimension_numbers) + + if padding == Padding.CIRCULAR and transpose: + out_shape = eval_shape(lambda x: lax.conv_transpose( + lhs=x, + rhs=W, + strides=strides, + padding=Padding.SAME.name, + dimension_numbers=dimension_numbers + ), inputs).shape + spatial_axes = tuple(out_spec.index(c) + for c in rhs_spec if c not in ('I', 'O')) + res = _same_pad_for_filter_shape_transpose(res, spatial_axes, out_shape) + + return res + b_rescale * b + + @_requires(batch_axis=lhs_spec.index('N'), + channel_axis=lhs_spec.index('C')) def kernel_fn(k: Kernel, **kwargs): """Compute the transformed kernels after a conv layer.""" cov1, nngp, cov2, ntk, is_reversed = (k.cov1, k.nngp, k.cov2, k.ntk, k.is_reversed) - input_spec = tuple(c for c in dimension_numbers[0] if c not in ('N', 'C')) - conv_spec = tuple(c for c in dimension_numbers[1] if c not in ('I', 'O')) + input_spec = tuple(c for c in lhs_spec if c not in ('N', 'C')) + conv_spec = tuple(c for c in rhs_spec if c not in ('I', 'O')) input_to_filter_permutation = tuple(conv_spec.index(c) for c in input_spec) filter_shape_kernel = tuple(filter_shape[p] for p in @@ -811,10 +942,8 @@ def kernel_fn(k: Kernel, **kwargs): input_to_filter_permutation) if k.diagonal_spatial: - def conv_unscaled(x, batch_ndim): - x = _conv_kernel_diagonal_spatial( - x, filter_shape_kernel, strides_kernel, padding, batch_ndim) - return x + conv_kernel = (_conv_kernel_diagonal_spatial_transpose + if transpose else _conv_kernel_diagonal_spatial) else: if is_reversed: @@ -823,14 +952,20 @@ def conv_unscaled(x, batch_ndim): is_reversed = not is_reversed - def conv_unscaled(x, batch_ndim): - x = _conv_kernel_full_spatial( - x, filter_shape_kernel, strides_kernel, padding, batch_ndim) - return x + conv_kernel = (_conv_kernel_full_spatial_transpose + if transpose else _conv_kernel_full_spatial) - def conv(x, batch_ndim): - x = conv_unscaled(x, batch_ndim) - return _affine(x, W_std, b_std) + def conv_unscaled(lhs, batch_ndim): + lhs = conv_kernel(lhs, + filter_shape_kernel, + strides_kernel, + padding, + batch_ndim) + return lhs + + def conv(lhs, batch_ndim): + lhs = conv_unscaled(lhs, batch_ndim) + return _affine(lhs, W_std, b_std) cov1 = conv(cov1, 1 if k.diagonal_batch else 2) cov2 = conv(cov2, 1 if k.diagonal_batch else 2) @@ -843,7 +978,7 @@ def conv(x, batch_ndim): nngp_unscaled = conv_unscaled(nngp, 2) if ntk is not None: ntk = ( - input_total_dim(k.shape1) * nngp_unscaled + 1. + + get_fan_in(k.shape1) * nngp_unscaled + 1. 
+ W_std ** 2 * conv_unscaled(ntk, 2)) nngp = _affine(nngp_unscaled, W_std, b_std) @@ -853,28 +988,44 @@ def conv(x, batch_ndim): ntk=ntk, is_gaussian=True, is_reversed=is_reversed, - batch_axis=dimension_numbers[2].index('N'), - channel_axis=dimension_numbers[2].index('C'), + batch_axis=out_spec.index('N'), + channel_axis=out_spec.index('C'), is_input=False) # Reorder output spatial dimensions if the finite layer does so. # TODO(romann): make more efficient / lazy. - out_spec = tuple(c for c in dimension_numbers[2] if c not in ('N', 'C')) - in_to_out_permutation = tuple(out_spec.index(c) for c in input_spec) + out_spec_kernel = tuple(c for c in out_spec if c not in ('N', 'C')) + in_to_out_permutation = tuple(out_spec_kernel.index(c) for c in input_spec) res = res.transpose(in_to_out_permutation) return res def mask_fn(mask, input_shape): - batch_axis = dimension_numbers[0].index('N') - channel_axis = dimension_numbers[0].index('C') + batch_axis, channel_axis = lhs_spec.index('N'), lhs_spec.index('C') # Collapse channel dimension of masks, since an FC layer is applied at each # spatial location. mask = np.all(mask, axis=channel_axis, keepdims=True) - _check_is_implemented(mask, padding, channel_axis) - return _pool_mask(mask, filter_shape, strides, padding, - batch_axis, channel_axis) + + if transpose: + rhs_shape = list(filter_shape) + for c in ('O', 'I'): + rhs_shape.insert(rhs_spec.index(c), 1) + rhs = np.ones(rhs_shape) + # TODO(romann): revisit after https://github.com/google/jax/issues/4012. + mask = lax.conv_transpose( + mask.astype(rhs.dtype), + rhs, + strides, + init_padding.name, + dimension_numbers=dimension_numbers).astype(mask.dtype) + + else: + mask = _pool_mask(mask, filter_shape, strides, init_padding, + batch_axis, channel_axis) + mask = np.transpose(mask, (out_spec.index(c) for c in lhs_spec)) + + return mask return init_fn, apply_fn, kernel_fn, mask_fn @@ -1090,7 +1241,8 @@ def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel: channel_axis: -1, }, **{ - spatial_axis: idx + 1 for idx, spatial_axis in enumerate(spatial_axes) + spatial_axis: idx + 1 + for idx, spatial_axis in enumerate(spatial_axes) } } @@ -1131,8 +1283,8 @@ def mask_fn(mask, input_shape): @layer @_supports_masking(remask_kernel=True) def AvgPool( - window_shape: Tuple[int, ...], - strides: Tuple[int, ...] = None, + window_shape: Sequence[int], + strides: Sequence[int] = None, padding: str = Padding.VALID.name, normalize_edges: bool = True, batch_axis: int = 0, @@ -1167,8 +1319,8 @@ def AvgPool( @layer @_supports_masking(remask_kernel=True) def SumPool( - window_shape: Tuple[int, ...], - strides: Tuple[int, ...] 
= None, + window_shape: Sequence[int], + strides: Sequence[int] = None, padding: str = Padding.VALID.name, batch_axis: int = 0, channel_axis: int = -1) -> InternalLayer: @@ -1197,8 +1349,8 @@ def SumPool( def _Pool( pool_type: Pooling, - window_shape: Tuple[int, ...], - strides: Union[None, Tuple[int, ...]], + window_shape: Sequence[int], + strides: Optional[Sequence[int]], padding: str, normalize_edges: bool, batch_axis: int, @@ -1256,7 +1408,7 @@ def apply_fn(params, inputs, **kwargs): spatial_axes = tuple(i for i in range(inputs.ndim) if i not in non_spatial_axes) inputs = _same_pad_for_filter_shape(inputs, window_shape, strides, - spatial_axes, 'wrap') + spatial_axes) res = apply_fn_0(params, inputs, **kwargs) return res @@ -1295,7 +1447,7 @@ def kernel_fn(k: Kernel, **kwargs): return k.replace(cov1=cov1, nngp=nngp, cov2=cov2, ntk=ntk) def mask_fn(mask, input_shape): - _check_is_implemented(mask, padding, channel_axis) + _check_is_implemented(mask, channel_axis) return _pool_mask(mask, window_shape, strides, padding, batch_axis, channel_axis) @@ -1417,7 +1569,7 @@ def _pool(ker_mat, batch_ndim, mask=None): is_reversed=False) def mask_fn(mask, input_shape): - _check_is_implemented(mask, None, channel_axis) + _check_is_implemented(mask, channel_axis) non_spatial_axes = (batch_axis % mask.ndim, channel_axis % mask.ndim) spatial_axes = tuple(i for i in range(mask.ndim) if i not in non_spatial_axes) @@ -2747,7 +2899,10 @@ def _get_input_req_attr(kernel_fns: List[LayerKernelFn]) -> Dict[str, bool]: return req -def _double_tuple(x: tuple) -> tuple: +_T = TypeVar('_T') + + +def _double_tuple(x: Iterable[_T]) -> Tuple[_T, ...]: return tuple(v for v in x for _ in range(2)) @@ -3271,7 +3426,7 @@ def _get_diagonal_outer_prods(cov1: np.ndarray, diagonal_batch: bool, diagonal_spatial: bool, operation: Callable[[float, float], float], - axis: Tuple[int, ...] = (), + axis: Sequence[int] = (), mask1: Optional[np.ndarray] = None, mask2: Optional[np.ndarray] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: @@ -3462,11 +3617,11 @@ def _fan_in_kernel_fn_concat(ks: List[Kernel], axis: int) -> Kernel: def _concat_kernels( - mats: List[Optional[np.ndarray]], + mats: Sequence[Optional[np.ndarray]], axis: int, diagonal_batch: bool, diagonal_spatial: bool, - widths: List[int]) -> Optional[np.ndarray]: + widths: Sequence[int]) -> Optional[np.ndarray]: """Compute the covariance of concatenated activations with given covariances. Args: @@ -3528,10 +3683,11 @@ def _concat_kernels( def _same_pad_for_filter_shape( x: np.ndarray, - filter_shape: Tuple[int, ...], - strides: Tuple[int, ...], - axes: Tuple[int, ...], - mode: str) -> np.ndarray: + filter_shape: Sequence[int], + strides: Sequence[int], + axes: Sequence[int], + mode: str = 'wrap', +) -> np.ndarray: """Pad an array to imitate `SAME` padding with `VALID`. See `Returns` section for details. This function is usually needed to @@ -3549,81 +3705,176 @@ def _same_pad_for_filter_shape( https://docs.scipy.org/doc/numpy/reference/generated/numpy.pad.html. Returns: A `np.ndarray` of the same dimensionality as `x` padded to a potentially - larger shape such that a `VALID` convolution with `filter_shape` applied + larger shape such that a `"VALID"` convolution with `filter_shape` applied to `x` over `axes` outputs an array of the same shape as `x`. 
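+
+  Example:
+    An illustrative 1D `NHC` case with a size-3 filter and unit strides:
+
+    >>> x = np.zeros((10, 5, 3))                             # `H == 5`.
+    >>> y = _same_pad_for_filter_shape(x, (3,), (1,), (1,))
+    >>> y.shape                                              # `H` is grown to 7.
+    (10, 7, 3)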
""" axes_shape = tuple(np.size(x, axis) for axis in axes) axes_pads = lax.padtype_to_pads(axes_shape, filter_shape, strides, Padding.SAME.name) - - pads = [(0, 0),] * x.ndim + pads = [(0, 0)] * x.ndim for i, axis in enumerate(axes): pads[axis] = axes_pads[i] - x = np.pad(x, pads, mode) return x +def _same_pad_for_filter_shape_transpose( + x: np.ndarray, + axes: Sequence[int], + out_shape: Sequence[int] +) -> np.ndarray: + """Transpose of the `_same_pad_for_filter_shape` function. + + Unpads (crops) the array and fills each coordinate with the sum of all + elements at positions where the current element would appear during + `CIRCULAR` padding. + + Args: + x: + `np.ndarray` to pad, e.g. a 4D `NHWC` image. + axes: + non-negative integers, the spatial axes to apply convolution + over (e.g. `(1, 2)` for an `NHWC` image). + out_shape: + target shape after cropping. + + Returns: + A `np.ndarray` of shape `output_shape`. + """ + window_dimensions = tuple( + int(onp.ceil(x.shape[i] / out_shape[i])) // 2 * 2 + 1 + if i in axes else 1 for i in range(x.ndim)) + + dilation = tuple(out_shape[i] if i in axes else 1 for i in range(x.ndim)) + + x = lax.reduce_window( + operand=x, + init_value=np.zeros((), x.dtype), + computation=lax.add, + window_dimensions=window_dimensions, + window_strides=(1,) * x.ndim, + padding=Padding.SAME.name, + window_dilation=dilation + ) + + if x.shape != out_shape: + pads = [((x.shape[i] - out_shape[i]) // 2, + (x.shape[i] - out_shape[i]) - (x.shape[i] - out_shape[i]) // 2) + for i in range(x.ndim)] + slices = [] + for axis in range(x.ndim): + if axis in axes: + slices += [slice(pads[axis][0], x.shape[axis] - pads[axis][1])] + else: + slices += [slice(None)] + x = x[slices] + return x + + +def _pool_transpose( + x: np.ndarray, + filter_shape: Sequence[int], + strides: Sequence[int], + axes: Sequence[int], + padding: Padding +) -> np.ndarray: + """Transpose convolution with an all-ones filter.""" + n_spatial = len(axes) + x = np.moveaxis(x, axes, range(-n_spatial, 0)) + split = -n_spatial or x.ndim + x_preshape = x.shape[:split] + x = x.reshape((-1, 1) + x.shape[split:]) + rhs = np.ones(tuple(filter_shape) + (1, 1), x.dtype) + x = lax.conv_transpose(x, + rhs, + strides, + padding.name, + dimension_numbers=_get_dimension_numbers(n_spatial)) + x = x.reshape(x_preshape + x.shape[2:]) + x = np.moveaxis(x, range(-n_spatial, 0), axes) + return x + + +def _get_dimension_numbers( + n: int, + channels_first: bool = True +) -> Tuple[str, str, str]: + spatial_dims = ''.join(c for c in string.ascii_uppercase + if c not in ('N', 'C', 'I', 'O'))[:n] + if channels_first: + lhs_spec = 'NC' + spatial_dims + else: + lhs_spec = 'N' + spatial_dims + 'C' + dimension_numbers = (lhs_spec, spatial_dims + 'IO', lhs_spec) + return dimension_numbers + + def _conv_kernel_full_spatial( - mat: Optional[np.ndarray], - filter_shape: Tuple[int, ...], - strides: Tuple[int, ...], + lhs: Optional[np.ndarray], + filter_shape: Sequence[int], + strides: Sequence[int], padding: Padding, batch_ndim: int ) -> Optional[np.ndarray]: - """Compute covariance of the CNN outputs given inputs with covariance `mat`. + """Compute covariance of the CNN outputs given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == False`. Args: - mat: a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing + lhs: + a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing sample-[sample-]position-position covariances of CNN inputs, where `S` is the number of spatial dimensions (e.g. 2 for images). 
Has shape - `(batch_size_1, [batch_size_2,] - height, height, width, width, depth, depth, ...)`. - filter_shape: tuple of positive integers, the convolutional filters spatial - shape (e.g. `(3, 3)` for a 2D convolution). - strides: tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a - 2D convolution). - padding: a `Padding` enum, e.g. `Padding.CIRCULAR`. - batch_ndim: integer, number of batch dimensions, 1 or 2. + `(batch_size_1, [batch_size_2,] height, height, width, width, depth, + depth, ...)`. + filter_shape: + positive integers, the convolutional filters spatial shape + (e.g. `(3, 3)` for a 2D convolution). + strides: + positive integers, the CNN strides (e.g. `(1, 1)` for a 2D + convolution). + padding: + a `Padding` enum, e.g. `Padding.CIRCULAR`. + batch_ndim: + number of batch dimensions, 1 or 2. Returns: a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing sample-[sample-]position-position covariances of CNN outputs, where `S` is the number of spatial dimensions (e.g. 2 for images). Has shape - `(batch_size_1, [batch_size_2,] new_width, new_width, - new_height, new_height, new_depth, new_depth, ...)`. + `(batch_size_1, [batch_size_2,] new_width, new_width, new_height, + new_height, new_depth, new_depth, ...)`. """ - if mat is None or mat.ndim == 0: - return mat + if lhs is None or lhs.ndim == 0: + return lhs if padding == Padding.CIRCULAR: - spatial_axes = tuple(range(batch_ndim, mat.ndim)) - mat = _same_pad_for_filter_shape( - mat, - _double_tuple(filter_shape), - _double_tuple(strides), - spatial_axes, - 'wrap' - ) - padding = Padding.VALID - - for i in range(mat.ndim - 1, batch_ndim, -2): - spatial_i = (i - batch_ndim) // 2 - filter_i = filter_shape[spatial_i] - stride_i = strides[spatial_i] - size_i = mat.shape[i] + spatial_axes = tuple(range(batch_ndim, lhs.ndim)) + total_filter_shape = _double_tuple(filter_shape) + total_strides = _double_tuple(strides) + lhs = _same_pad_for_filter_shape(lhs, + total_filter_shape, + total_strides, + spatial_axes) + + def lax_conv(lhs, rhs, strides, padding): + return lax.conv_general_dilated( + lhs, rhs, strides, padding, + dimension_numbers=_CONV_KERNEL_DIMENSION_NUMBERS, + feature_group_count=lhs.shape[ + _CONV_KERNEL_DIMENSION_NUMBERS[0].index('C')]) - mat = np.moveaxis(mat, (i - 1, i), (-2, -1)) - mat_preshape = mat.shape[:-2] + def get_n_channels(batch_and_channels: int) -> int: + """Get the hardware-friendly channel size for depthwise convolution. - rhs = np.diag(np.full((filter_i,), 1. / filter_i, mat.dtype)) - rhs_shape = () + Args: + batch_and_channels: total size of non-spatial dimensions. + Returns: + Suggested number of channels for depthwise-separable convolution. + """ platform = xla_bridge.get_backend().platform if platform in ['gpu', 'tpu']: - batch_and_channels = utils.size_at(mat_preshape) n_channels = batch_and_channels # Find smallest `n_channels > 1` that divides `batch_and_features`; use @@ -3632,7 +3883,7 @@ def _conv_kernel_full_spatial( # in any other case (`conv2d_c1_k1_nchw_hw_packed_kernel`), and the latter # seems many-fold faster. # For TPU, start with `n_channels >= 128`. Beware of precision errors: - # TODO(romann): revisit based on b/154160868, b/154165148. + # TODO(romann): revisit based on b/154160868. n_channels_min = 2 if platform == 'gpu' else 128 for n_c in range(n_channels_min, batch_and_channels): @@ -3641,96 +3892,247 @@ def _conv_kernel_full_spatial( break elif platform == 'cpu': - # For CPU minimal channels seems best. + # For CPU minimal channels seems best. 
Transpose convolution does not + # support depthwise operations. n_channels = 1 else: raise NotImplementedError(platform) + return n_channels + + lhs = _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding, + lax_conv, get_n_channels) + return lhs - mat = mat.reshape((-1, n_channels, size_i, size_i)) +def _conv_kernel_full_spatial_transpose( + lhs: Optional[np.ndarray], + filter_shape: Sequence[int], + strides: Sequence[int], + padding: Padding, + batch_ndim: int +) -> Optional[np.ndarray]: + """Compute covariance of the CNN transpose given inputs with covariance `lhs`. + + Used when `kernel.diagonal_spatial == False`. + + Args: + lhs: + a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing + sample-[sample-]position-position covariances of CNN inputs, where `S` is + the number of spatial dimensions (e.g. 2 for images). Has shape + `(batch_size_1, [batch_size_2,] height, height, width, width, depth, + depth, ...)`. + filter_shape: + positive integers, the convolutional filters spatial shape + (e.g. `(3, 3)` for a 2D convolution). + strides: + positive integers, the CNN strides (e.g. `(1, 1)` for a 2D + convolution). + padding: + a `Padding` enum, e.g. `Padding.CIRCULAR`. + batch_ndim: + number of batch dimensions, 1 or 2. + + Returns: + a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing + sample-[sample-]position-position covariances of CNN outputs, where `S` is + the number of spatial dimensions (e.g. 2 for images). Has shape + `(batch_size_1, [batch_size_2,] new_width, new_width, new_height, + new_height, new_depth, new_depth, ...)`. + """ + if lhs is None or lhs.ndim == 0: + return lhs + + def lax_conv(lhs, rhs, strides, padding): + return lax.conv_transpose( + lhs, rhs, strides, padding, + dimension_numbers=_CONV_KERNEL_DIMENSION_NUMBERS) + + def get_n_channels(batch_and_channels: int) -> int: + """Transpose convolution does not support depthwise separable filters.""" + return 1 + + out = _conv_kernel_full_spatial_loop(lhs, filter_shape, strides, padding, + lax_conv, get_n_channels) + + if padding == Padding.CIRCULAR: + spatial_axes = tuple(range(batch_ndim, out.ndim)) + total_filter_shape = _double_tuple(filter_shape) + total_strides = _double_tuple(strides) + out_shape = eval_shape(lambda x: _pool_transpose(x, + total_filter_shape, + total_strides, + spatial_axes, + Padding.SAME), lhs).shape + out = _same_pad_for_filter_shape_transpose( + x=out, + axes=spatial_axes, + out_shape=utils.reverse_zipped(out_shape, batch_ndim)) + return out + + +def _conv_kernel_full_spatial_loop( + lhs: np.ndarray, + filter_shape: Sequence[int], + strides: Sequence[int], + padding: Padding, + lax_conv: Callable, + get_n_channels: Callable[[int], int] +) -> np.ndarray: + padding = Padding.VALID if padding == Padding.CIRCULAR else padding + + def get_rhs(n_channels: int, filter_size: int) -> np.ndarray: + rhs = np.diag(np.full((filter_size,), 1. 
/ filter_size, lhs.dtype)) + rhs_shape = () for c in _CONV_KERNEL_DIMENSION_NUMBERS[1]: if c == 'O': rhs_shape += (n_channels,) elif c == 'I': rhs_shape += (1,) else: - rhs_shape += (filter_i,) - + rhs_shape += (filter_size,) rhs = np.broadcast_to(rhs, rhs_shape) + return rhs - mat = lax.conv_general_dilated( - lhs=mat, - rhs=rhs, - window_strides=(stride_i, stride_i), - padding=padding.name, - dimension_numbers=_CONV_KERNEL_DIMENSION_NUMBERS, - feature_group_count=n_channels) - mat = mat.reshape(mat_preshape + mat.shape[-2:]) + batch_ndim = lhs.ndim - len(filter_shape) * 2 + for i in range(lhs.ndim - 1, batch_ndim, -2): + spatial_i = (i - batch_ndim) // 2 - return mat + lhs = np.moveaxis(lhs, (i - 1, i), (-2, -1)) + preshape = lhs.shape[:-2] + n_channels = get_n_channels(utils.size_at(preshape)) + lhs = lhs.reshape((-1, n_channels, lhs.shape[-2], lhs.shape[-1])) + + rhs = get_rhs(n_channels, filter_shape[spatial_i]) + lhs = lax_conv(lhs, rhs, (strides[spatial_i],) * 2, padding.name) + lhs = lhs.reshape(preshape + lhs.shape[-2:]) + + return lhs def _conv_kernel_diagonal_spatial( - mat: Optional[np.ndarray], - filter_shape: Tuple[int, ...], - strides: Tuple[int, ...], + lhs: Optional[np.ndarray], + filter_shape: Sequence[int], + strides: Sequence[int], padding: Padding, batch_ndim: int - ) -> Optional[np.ndarray]: - """Compute covariance of the CNN outputs given inputs with covariance `mat`. +) -> Optional[np.ndarray]: + """Compute covariance of the CNN outputs given inputs with covariance `lhs`. Used when `kernel.diagonal_spatial == True`. Args: - mat: an `(S+batch_ndim)`-dimensional `np.ndarray` containing + lhs: + an `(S+batch_ndim)`-dimensional `np.ndarray` containing sample-sample-(same position) covariances of CNN inputs. Has `batch_ndim` - batch and `S` spatial dimensions with the shape of - `(batch_size_1, [batch_size_2,] height, width, depth, ...)`. - filter_shape: tuple of positive integers, the convolutional filters spatial - shape (e.g. `(3, 3)` for a 2D convolution). - strides: tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a - 2D convolution). - padding: a `Padding` enum, e.g. `Padding.CIRCULAR`. - batch_ndim: integer, number of leading batch dimensions, 1 or 2. + batch and `S` spatial dimensions with the shape of `(batch_size_1, + [batch_size_2,] height, width, depth, ...)`. + filter_shape: + tuple of positive integers, the convolutional filters spatial shape + (e.g. `(3, 3)` for a 2D convolution). + strides: + tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a 2D + convolution). + padding: + a `Padding` enum, e.g. `Padding.CIRCULAR`. + batch_ndim: + number of leading batch dimensions, 1 or 2. Returns: an `(S+batch_ndim)`-dimensional `np.ndarray` containing sample-sample-(same position) covariances of CNN outputs. Has `batch_ndim` - batch and `S` spatial dimensions with the shape of - `(batch_size_1, [batch_size_2,] new_height, new_width, new_depth, ...)`. + batch and `S` spatial dimensions with the shape of `(batch_size_1, + [batch_size_2,] new_height, new_width, new_depth, ...)`. 
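+
+  Example:
+    An illustrative 1D case: with `filter_shape=(2,)`, `strides=(1,)`,
+    `padding=Padding.VALID` and `batch_ndim=1`, each output entry is
+    `out[n, i] = (lhs[n, i] + lhs[n, i + 1]) / 2`.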
""" - if mat is None or mat.ndim == 0: - return mat + if lhs is None or lhs.ndim == 0: + return lhs + + spatial_axes = tuple(range(batch_ndim, lhs.ndim)) + apply_padding = Padding.VALID if padding == Padding.CIRCULAR else padding if padding == Padding.CIRCULAR: - spatial_axes = tuple(range(mat.ndim)[batch_ndim:]) - mat = _same_pad_for_filter_shape(mat, filter_shape, strides, - spatial_axes, 'wrap') - padding = Padding.VALID + lhs = _same_pad_for_filter_shape(lhs, filter_shape, strides, spatial_axes) + + lhs = lax.reduce_window( + operand=lhs, + init_value=np.zeros((), lhs.dtype), + computation=lax.add, + window_dimensions=(1,) * batch_ndim + tuple(filter_shape), + window_strides=(1,) * batch_ndim + tuple(strides), + padding=apply_padding.name) filter_size = functools.reduce(op.mul, filter_shape, 1) - filter_shape = (1,) * batch_ndim + filter_shape - strides = (1,) * batch_ndim + strides - padding_vals = lax.padtype_to_pads( - mat.shape, filter_shape, strides, padding.name) - mat = lax._reduce_window_sum(mat, filter_shape, strides, padding_vals) - mat /= filter_size - return mat + return lhs / filter_size + + +def _conv_kernel_diagonal_spatial_transpose( + lhs: Optional[np.ndarray], + filter_shape: Sequence[int], + strides: Sequence[int], + padding: Padding, + batch_ndim: int +) -> Optional[np.ndarray]: + """Compute covariance of the CNN transpose given inputs with covariance `lhs`. + + Used when `kernel.diagonal_spatial == True`. + + Args: + lhs: + an `(S+batch_ndim)`-dimensional `np.ndarray` containing + sample-sample-(same position) covariances of CNN inputs. Has `batch_ndim` + batch and `S` spatial dimensions with the shape of `(batch_size_1, + [batch_size_2,] height, width, depth, ...)`. + filter_shape: + tuple of positive integers, the convolutional filters spatial shape + (e.g. `(3, 3)` for a 2D convolution). + strides: + tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a 2D + convolution). + padding: + a `Padding` enum, e.g. `Padding.CIRCULAR`. + batch_ndim: + number of leading batch dimensions, 1 or 2. + + Returns: + an `(S+batch_ndim)`-dimensional `np.ndarray` containing + sample-sample-(same position) covariances of CNN outputs. Has `batch_ndim` + batch and `S` spatial dimensions with the shape of `(batch_size_1, + [batch_size_2,] new_height, new_width, new_depth, ...)`. + """ + if lhs is None or lhs.ndim == 0: + return lhs + + spatial_axes = tuple(range(batch_ndim, lhs.ndim)) + apply_padding = Padding.VALID if padding == Padding.CIRCULAR else padding + + out = _pool_transpose(lhs, filter_shape, strides, spatial_axes, apply_padding) + + if padding == Padding.CIRCULAR: + out_shape = eval_shape(lambda x: _pool_transpose( + x, + filter_shape, + strides, + spatial_axes, + padding.SAME), lhs).shape + out = _same_pad_for_filter_shape_transpose(out, spatial_axes, out_shape) + + filter_size = functools.reduce(op.mul, filter_shape, 1) + return out / filter_size def _pool_kernel( - mat: Optional[np.ndarray], + lhs: Optional[np.ndarray], pool_type: Pooling, - window_shape: Tuple[int, ...], - strides: Tuple[int, ...], + window_shape: Sequence[int], + strides: Sequence[int], padding: Padding, normalize_edges: bool, batch_ndim: int) -> Optional[np.ndarray]: - """Get covariances of pooling outputs given inputs covariances `mat`. + """Get covariances of pooling outputs given inputs covariances `lhs`. 
Args: - mat: a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing + lhs: a `(2*S+batch_ndim)`-dimensional `np.ndarray` containing sample-[sample-]position-position covariances of pooling inputs, where `S` is the number of spatial dimensions (e.g. 2 for images). Has shape `(batch_size_1, [batch_size_2,] @@ -3753,34 +4155,33 @@ def _pool_kernel( `(batch_size_1, [batch_size_2,] height, height, width, width, depth, depth, ...)`. """ - if mat is None or mat.ndim == 0: - return mat + if lhs is None or lhs.ndim == 0: + return lhs if padding == Padding.CIRCULAR: - spatial_axes = tuple(range(batch_ndim, mat.ndim)) - mat = _same_pad_for_filter_shape(mat, _double_tuple(window_shape), - _double_tuple(strides), spatial_axes, - 'wrap') + spatial_axes = tuple(range(batch_ndim, lhs.ndim)) + lhs = _same_pad_for_filter_shape(lhs, _double_tuple(window_shape), + _double_tuple(strides), spatial_axes) padding = Padding.VALID window_shape = (1,) * batch_ndim + _double_tuple(window_shape) strides = (1,) * batch_ndim + _double_tuple(strides) - nngp_out = lax.reduce_window(mat, 0., lax.add, window_shape, strides, - padding.name) + mat_out = lax.reduce_window(lhs, 0., lax.add, window_shape, strides, + padding.name) if pool_type == Pooling.AVG: if padding == Padding.SAME and normalize_edges: # `SAME` padding in `jax.experimental.stax.AvgPool` normalizes by actual # window size, which is smaller at the edges. - one = np.ones_like(mat, mat.dtype) + one = np.ones_like(lhs, lhs.dtype) window_sizes = lax.reduce_window(one, 0., lax.add, window_shape, strides, padding.name) - nngp_out /= window_sizes + mat_out /= window_sizes else: - nngp_out /= onp.prod(window_shape) + mat_out /= onp.prod(window_shape) - return nngp_out + return mat_out def _diag_mul_full_spatial( @@ -3843,14 +4244,7 @@ def _diag_mul( _NEG_INF = -1e20 # softmax raises an error if all entries are -np.inf -def _check_is_implemented( - mask: np.ndarray, - padding: Optional[Padding], - channel_axis: int) -> None: - if padding == Padding.CIRCULAR: - raise NotImplementedError(f'{padding} padding is not implemented for ' - f'masked inputs.') - +def _check_is_implemented(mask: np.ndarray, channel_axis: int) -> None: if mask.shape[channel_axis] != 1: raise NotImplementedError( 'Different channel-wise masks as inputs to ' @@ -3921,7 +4315,7 @@ def _map_tuples(fn: Callable, tuples: Iterable[Tuple]) -> Tuple: def _concat_masks( masks: List[Optional[np.ndarray]], - input_shapes: List[Tuple[int, ...]], + input_shapes: Sequence[Sequence[int]], axis: int) -> Optional[np.ndarray]: """Returns a mask which is a concatenation of `masks`. @@ -3950,7 +4344,9 @@ def _concat_masks( # Expand the concatenation dimension of each mask. masks = [m if m is None else np.broadcast_to( m, - m.shape[:axis] + input_shapes[i][axis: axis + 1] + m.shape[axis + 1:]) + (m.shape[:axis] + + tuple(input_shapes[i][axis: axis + 1]) + + m.shape[axis + 1:])) for i, m in enumerate(masks)] # Max shape to broadcast all masks to along non-concat dimension. @@ -3973,8 +4369,8 @@ def _concat_masks( def _pool_mask( mask: np.ndarray, - window_shape: Union[List[int], Tuple[int, ...]], - strides: Union[List[int], Tuple[int, ...]], + window_shape: Sequence[int], + strides: Sequence[int], padding: Padding, batch_axis: int, channel_axis: int) -> np.ndarray: @@ -3986,17 +4382,14 @@ def _pool_mask( window_shape.insert(i, 1) strides.insert(i, 1) - padding_vals = lax.padtype_to_pads( - mask.shape, window_shape, strides, padding.name) - # Get the output shape. 
out_shape = eval_shape(lambda x: lax.reduce_window( - x, - False, - op.or_, - window_shape, - strides, - padding_vals + operand=x, + init_value=np.zeros((), x.dtype), + computation=op.or_, + window_dimensions=window_shape, + window_strides=strides, + padding=padding.name ), mask).shape # If shapes don't match, stride through the mask. @@ -4019,14 +4412,14 @@ def _pool_mask( # POSITIONAL EMBEDDINGS -def _pos_emb_identity(shape: Tuple[int, ...]) -> np.ndarray: +def _pos_emb_identity(shape: Sequence[int]) -> np.ndarray: size = utils.size_at(shape) - R = np.eye(size).reshape(shape * 2) + R = np.eye(size).reshape(tuple(shape) * 2) R = utils.zip_axes(R) return R -def _pos_emb_pdist(shape: Tuple[int, ...], +def _pos_emb_pdist(shape: Sequence[int], pos_emb_p_norm: Optional[float], pos_emb_decay_fn: Optional[Callable[[float], float]] ) -> np.ndarray: diff --git a/neural_tangents/utils/kernel.py b/neural_tangents/utils/kernel.py index e8a15a7b..0207e092 100644 --- a/neural_tangents/utils/kernel.py +++ b/neural_tangents/utils/kernel.py @@ -16,7 +16,7 @@ import operator as op -from typing import Dict, Tuple, Optional, Callable, Any +from typing import Dict, Tuple, Optional, Callable, Any, Sequence import jax.numpy as np from neural_tangents.utils import dataclasses @@ -171,7 +171,7 @@ def reverse(self) -> 'Kernel': ntk=ntk, is_reversed=not self.is_reversed) - def transpose(self, axes: Tuple[int, ...] = None) -> 'Kernel': + def transpose(self, axes: Sequence[int] = None) -> 'Kernel': """Permute spatial dimensions of the `Kernel` according to `axes`. Follows diff --git a/neural_tangents/utils/utils.py b/neural_tangents/utils/utils.py index b298e8b8..49558fa5 100644 --- a/neural_tangents/utils/utils.py +++ b/neural_tangents/utils/utils.py @@ -163,6 +163,25 @@ def x1_is_x2(x1: np.ndarray, return np.all(np.abs(x1 - x2) < eps) +def _get_ndim(x: Union[int, Sized, np.ndarray]) -> int: + if hasattr(x, 'ndim'): + n = x.ndim + elif hasattr(x, '__len__'): + n = len(x) + elif isinstance(x, int): + n = x + else: + raise TypeError(x, type(x)) + return n + + +def mod(axis: Axes, x: Union[int, Sized, np.ndarray]) -> List[int]: + n = _get_ndim(x) + if isinstance(axis, int): + axis = [axis] + return [i % n for i in axis] + + def canonicalize_axis(axis: Axes, x: Union[int, Sized, np.ndarray]) -> List[int]: """Converts axis into a sorted non-negative list. @@ -175,15 +194,8 @@ def canonicalize_axis(axis: Axes, A sorted list of integer axes. 
""" axis = [axis] if isinstance(axis, int) else list(axis) - if hasattr(x, 'ndim'): - ndim = x.ndim - elif hasattr(x, '__len__'): - ndim = len(x) - elif isinstance(x, int): - ndim = x - else: - raise TypeError(x, type(x)) - return list(set(onp.arange(ndim)[axis])) + n = _get_ndim(x) + return list(set(onp.arange(n)[axis])) def zip_axes(x: np.ndarray, @@ -314,15 +326,20 @@ def outer_prod(x, y, start_axis, end_axis, prod_op): return prod_op(x, y) -def reverse_zipped(mat: np.ndarray, start_axis: int = 0) -> np.ndarray: +def reverse_zipped( + mat: Union[np.ndarray, Sequence[int]], + start_axis: int = 0) -> Union[np.ndarray, Sequence[int]]: if mat is not None: + ndim = _get_ndim(mat) source_axes = tuple(j - for i in range(mat.ndim - 2, start_axis - 1, -2) + for i in range(ndim - 2, start_axis - 1, -2) for j in (i, i + 1)) - target_axes = range(start_axis, mat.ndim) - mat = np.moveaxis(mat, source_axes, target_axes) - + if isinstance(mat, np.ndarray): + target_axes = range(start_axis, ndim) + mat = np.moveaxis(mat, source_axes, target_axes) + else: + mat = mat[:start_axis] + type(mat)(mat[i] for i in source_axes) return mat @@ -368,7 +385,7 @@ def get_masked_array(x: ArrayOrList, mask_mat = None else: mask_mat = lax.cond(np.isnan(mask_constant), - lambda x: np.isnan(x), + np.isnan, lambda x: x == mask_constant, x) else: @@ -407,8 +424,8 @@ def shape_and_axes( return shape, axes -def get_res_batch_dims(contracting_dims: List[int], - batch_dims: List[int]) -> List[int]: +def get_res_batch_dims(contracting_dims: Iterable[int], + batch_dims: Iterable[int]) -> List[int]: res_batch_dims = [2 * b - i for i, b in enumerate(batch_dims)] for i, b in enumerate(batch_dims): for c in contracting_dims: diff --git a/tests/stax_test.py b/tests/stax_test.py index 28dbf7ca..c0c0e843 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -17,15 +17,15 @@ import functools import itertools -import logging import random as prandom import string from typing import Tuple from absl.testing import absltest +from jax import lax from jax import ops from jax import test_util as jtu -from jax.api import jit +from jax.api import jit, vjp from jax.config import config as jax_config from jax.lib import xla_bridge import jax.numpy as np @@ -53,7 +53,7 @@ N_SAMPLES = 100 -RTOL = 0.04 +RTOL = 0.041 FILTER_SHAPES = [ (2, 1), @@ -165,8 +165,6 @@ def _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, if layer_norm: layer_norm = tuple(spec.index(c) for c in layer_norm) - logging.warning(f'DIMENSION NUMBERS: {dimension_numbers}') - def fc(out_dim): return stax.Dense( out_dim=out_dim, @@ -296,8 +294,34 @@ def _get_net_pool(width, is_ntk, pool_type, padding, fc(1 if is_ntk else width)), INPUT_SHAPE, -1, -1 +def _mask(x, mask_constant, mask_axis, key, p): + if mask_constant is not None: + mask_shape = [1 if i in mask_axis else s + for i, s in enumerate(x.shape)] + mask = random.bernoulli(key, p=p, shape=mask_shape) + x = np.where(mask, mask_constant, x) + x = np.sort(x, 1) + return x + + class StaxTest(test_utils.NeuralTangentsTestCase): + def _skip_test(self, filter_shape, is_conv, is_res, padding, proj_into_2d, + strides, use_pooling): + if is_conv: + if xla_bridge.get_backend().platform == 'cpu': + raise absltest.SkipTest('Not running CNN models on CPU to save time.') + + if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or + (padding == 'VALID' and filter_shape != + (1, 1)))): + raise absltest.SkipTest('Different paths in a residual models need to ' + 'return outputs of the same 
shape.') + elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or + strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or + use_pooling): + raise absltest.SkipTest('FC models do not have these parameters.') + @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': @@ -346,19 +370,8 @@ def test_exact(self, model, width, strides, padding, phi, same_inputs, is_conv = 'conv' in model # Check for duplicate / incorrectly-shaped NN configs / wrong backend. - if is_conv: - if xla_bridge.get_backend().platform == 'cpu': - raise absltest.SkipTest('Not running CNN models on CPU to save time.') - - if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or - (padding == 'VALID' and filter_shape != - (1, 1)))): - raise absltest.SkipTest('Different paths in a residual models need to ' - 'return outputs of the same shape.') - elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or - strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or - use_pooling): - raise absltest.SkipTest('FC models do not have these parameters.') + self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d, + strides, use_pooling) pool_type = 'AVG' W_std, b_std = 2.**0.5, 0.5**0.5 @@ -620,23 +633,12 @@ def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides, use_dropout = True is_conv = 'conv' in model is_res = False - # Check for duplicate / incorrectly-shaped NN configs / wrong backend. W_std, b_std = 2.**0.5, 0.5**0.5 layer_norm = None parameterization = 'ntk' - if is_conv: - if xla_bridge.get_backend().platform == 'cpu': - raise absltest.SkipTest('Not running CNN models on CPU to save time.') - - if (is_res and is_conv and ((strides is not None and strides != (1, 1)) or - (padding == 'VALID' and filter_shape != - (1, 1)))): - raise absltest.SkipTest('Different paths in a residual models need to ' - 'return outputs of the same shape.') - elif (filter_shape != FILTER_SHAPES[0] or padding != PADDINGS[0] or - strides != STRIDES[0] or proj_into_2d != PROJECTIONS[0] or - use_pooling): - raise absltest.SkipTest('FC models do not have these parameters.') + # Check for duplicate / incorrectly-shaped NN configs / wrong backend. 
+ self._skip_test(filter_shape, is_conv, is_res, padding, proj_into_2d, + strides, use_pooling) net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res, padding, phi, strides, width, is_ntk, proj_into_2d, @@ -1246,7 +1248,7 @@ def _get_phi(cls, i): } for same_inputs in [False] for axis in [0, 1] - for n_branches in [2, 3] for get in ['nngp', 'ntk'] + for n_branches in [2, 3] for get in ['ntk'] for branch_in in ['dense_before_branch_in', 'dense_after_branch_in'] for fan_in_mode in ['FanInSum', 'FanInConcat', 'FanInProd'])) @@ -1262,9 +1264,10 @@ def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in, raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` ' 'require `is_gaussian`.') - if (axis == 1 or fan_in_mode == 'FanInProd') and branch_in == 'dense_before_branch_in': + if ((axis == 1 or fan_in_mode == 'FanInProd') and + branch_in == 'dense_before_branch_in'): raise absltest.SkipTest( - '`FanInConcat` or `FanInProd` on feature axis requires a dense layer' + '`FanInConcat` or `FanInProd` on feature axis requires a dense layer ' 'after concatenation or Hadamard product.') if fan_in_mode == 'FanInSum': fan_in_layer = stax.FanInSum() @@ -1356,7 +1359,7 @@ def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in, } for same_inputs in [False] for axis in [0, 1, 2, 3] - for n_branches in [2, 3] for get in ['nngp', 'ntk'] + for n_branches in [2, 3] for get in ['ntk'] for branch_in in ['dense_before_branch_in', 'dense_after_branch_in'] for readout in ['pool', 'flatten'] for fan_in_mode in ['FanInSum', 'FanInConcat', 'FanInProd'])) @@ -1491,7 +1494,7 @@ class ConvNDTest(test_utils.NeuralTangentsTestCase): } for same_inputs in [False] for n in [0, 1, 2, 3] - for get in ['nngp', 'ntk'] + for get in ['ntk'] for proj in ['flatten', 'pool'] for use_attn in [True] for channels_first in [True, False] @@ -1595,10 +1598,10 @@ def test_conv_nd(self, same_inputs, n, get, proj, use_attn, channels_first, { 'testcase_name': ' [{}_out={}_in={}]'.format( - 'same_inputs' if same_inputs else 'different_inputs', - readout[0].__name__, - readin[0].__name__ - ), + 'same_inputs' if same_inputs else 'different_inputs', + readout[0].__name__, + readin[0].__name__ + ), 'same_inputs': same_inputs, 'readout': @@ -1792,12 +1795,11 @@ class MaskingTest(test_utils.NeuralTangentsTestCase): 'p': p, } - for same_inputs in [False] for get in ['ntk', 'nngp'] + for same_inputs in [False] for get in ['ntk'] for concat in [None, 0, 1] for p in [0.5] for mask_axis in [(), (0,), (1, 3), - (0, 2, 3), (0, 1, 2, 3)] for mask_constant in [10.])) def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant): @@ -1806,22 +1808,14 @@ def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant): tol = 0.03 key = random.PRNGKey(1) - def apply_mask(x): - if mask_constant is not None: - mask_shape = [1 if i in mask_axis else s - for i, s in enumerate(x.shape)] - mask = random.bernoulli(key, p=p, shape=mask_shape) - x = np.where(mask, mask_constant, x) - return x - x1 = random.normal(key, (4, 6, 5, 7)) - x1 = apply_mask(x1) + x1 = _mask(x1, mask_constant, mask_axis, key, p) if same_inputs: x2 = None else: x2 = random.normal(key, (2, 6, 5, 7)) - x2 = apply_mask(x2) + x2 = _mask(x2, mask_constant, mask_axis, key, p) nn = stax.serial( stax.Flatten(), @@ -1867,7 +1861,8 @@ def apply_mask(x): @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': - ' [{}_get={}_axis={}_mask={}_concat={}_{}_p={}_attn={}_n={}]'.format( + ' 
[{}_get={}_axis={}_mask={}_concat={}_{}_p={}_n={}_{}]' + ''.format( 'same_inputs' if same_inputs else 'different_inputs', get, mask_axis, @@ -1875,8 +1870,8 @@ def apply_mask(x): concat, proj, p, - use_attn, - n + n, + 'transpose' if transpose else '' ), 'same_inputs': same_inputs, 'get': get, @@ -1885,61 +1880,50 @@ def apply_mask(x): 'concat': concat, 'proj': proj, 'p': p, - 'use_attn': use_attn, - 'n': n + 'n': n, + 'transpose': transpose } - for proj in ['avg'] - for use_attn in [True] + for proj in ['flatten', 'avg'] for same_inputs in [False] - for get in ['nngp', 'ntk'] - for n in [2] + for get in ['ntk'] + for n in [0, 1, 2] for concat in [None] + list(range(n + 1)) for mask_constant in [10.] for p in [0.5] + for transpose in [True, False] for mask_axis in [(), (0,), - (1,), - (2, 3), - (0, 1, 3), - (0, 1, 2, 3)] + (0, 1, 2, 3) + ] )) def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat, - proj, p, use_attn, n): + proj, p, n, transpose): if xla_bridge.get_backend().platform == 'cpu': raise absltest.SkipTest('Skipping CNN tests on CPU for speed.') elif xla_bridge.get_backend().platform == 'gpu' and n > 3: raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.') - width = 1024 - n_samples = 128 - tol = 0.025 + width = 128 + n_samples = 256 + tol = 0.05 key = random.PRNGKey(1) - spatial_shape = (15, 8, 9)[:n] - filter_shape = (7, 2, 3)[:n] - strides = (2, 3, 1)[:n] - spatial_spec = 'HWD'[:n] + spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n] + filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n] + strides = (2, 1, 3, 2, 3)[:n] + spatial_spec = 'HWDZX'[:n] dimension_numbers = ('N' + spatial_spec + 'C', 'OI' + spatial_spec, 'N' + spatial_spec + 'C') - def apply_mask(x): - if mask_constant is not None: - mask_shape = [1 if i in mask_axis else s - for i, s in enumerate(x.shape)] - mask = random.bernoulli(key, p=p, shape=mask_shape) - x = np.where(mask, mask_constant, x) - x = np.sort(x, 1) - return x - - x1 = random.normal(key, (4,) + spatial_shape + (3,)) - x1 = apply_mask(x1) + x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,))) + x1 = _mask(x1, mask_constant, mask_axis, key, p) if same_inputs: x2 = None else: - x2 = random.normal(key, (2,) + spatial_shape + (3,)) - x2 = apply_mask(x2) + x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,))) + x2 = _mask(x2, mask_constant, mask_axis, key, p) def get_attn(): return stax.GlobalSelfAttention( @@ -1947,23 +1931,25 @@ def get_attn(): n_chan_key=width, n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))), n_heads=int(np.sqrt(width)), - ) if use_attn else stax.Identity() + ) if proj == 'avg' else stax.Identity() + + conv = stax.GeneralConvTranspose if transpose else stax.GeneralConv nn = stax.serial( stax.FanOut(3), stax.parallel( stax.serial( - stax.GeneralConv( + conv( dimension_numbers=dimension_numbers, out_chan=width, strides=strides, filter_shape=filter_shape, - padding='SAME', - W_std=1.1, - b_std=0.1), + padding='CIRCULAR', + W_std=1.5, + b_std=0.2), stax.LayerNorm(axis=(1, -1)), stax.Abs(), - stax.GeneralConv( + conv( dimension_numbers=dimension_numbers, out_chan=width, strides=strides, @@ -1973,7 +1959,7 @@ def get_attn(): b_std=0.1), ), stax.serial( - stax.GeneralConv( + conv( dimension_numbers=dimension_numbers, out_chan=width, strides=strides, @@ -1983,7 +1969,7 @@ def get_attn(): b_std=0.3), stax.Relu(), stax.Dropout(0.7), - stax.GeneralConv( + conv( dimension_numbers=dimension_numbers, out_chan=width, strides=strides, @@ -1994,17 +1980,17 @@ def 
get_attn(): ), stax.serial( get_attn(), - stax.GeneralConv( + conv( dimension_numbers=dimension_numbers, out_chan=width, strides=strides, filter_shape=filter_shape, - padding='SAME', + padding='CIRCULAR', W_std=1., b_std=0.1), stax.Erf(), stax.Dropout(0.2), - stax.GeneralConv( + conv( dimension_numbers=dimension_numbers, out_chan=width, strides=strides, @@ -2075,7 +2061,6 @@ class AttentionTest(test_utils.NeuralTangentsTestCase): False ] for get in [ - 'nngp', 'ntk' ] for n in [ @@ -2137,12 +2122,7 @@ def test_attention( def get_x0(batch_size): x0 = random.normal(key, (batch_size,) + spatial_shape + (n_chan_in,)) - if mask_constant is not None: - mask_shape = [1 if i in mask_axis else s - for i, s in enumerate(x0.shape)] - mask = random.bernoulli(key, p=p, shape=mask_shape) - x0 = np.where(mask, mask_constant, x0) - x0 = np.sort(x0, 1) + x0 = _mask(x0, mask_constant, mask_axis, key, p) return x0 X0_1 = get_x0(2) @@ -2210,7 +2190,7 @@ class AggregateTest(test_utils.NeuralTangentsTestCase): 'same_input': same_input, 'activation': activation, 'test_mask': test_mask, - } for get in ['ntk', 'nngp'] + } for get in ['ntk'] for name, readout in [ ('Flattten', stax.Flatten()), ('Pooling', stax.GlobalAvgPool())] @@ -2265,5 +2245,154 @@ def test_aggregate(self, get, readout, same_input, activation, test_mask): test_utils.assert_close_matrices(self, exact, empirical, rtol) +class ConvTransposeTest(test_utils.NeuralTangentsTestCase): + + @jtu.parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + f'_same_inputs={same_inputs}_{padding}_size={size}_' + f'strides={strides}_filter={filter_shape}_' + f'diag_batch={diagonal_batch}_diag_spatial={diagonal_spatial}', + 'padding': padding, + 'size': size, + 'same_inputs': same_inputs, + 'filter_shape': filter_shape, + 'strides': strides, + 'diagonal_batch': diagonal_batch, + 'diagonal_spatial': diagonal_spatial + } + for padding in ['CIRCULAR', 'SAME', 'VALID'] + for same_inputs in [True, False] + for filter_shape in range(2, 5) + for strides in range(2, 5) + for size in range(2, 5) + for diagonal_batch in [True, False] + for diagonal_spatial in [True, False])) + def test_conv_transpose(self, same_inputs, padding, filter_shape, strides, + size, diagonal_batch, diagonal_spatial): + platform = xla_bridge.get_backend().platform + if platform == 'cpu' and size > 2: + raise absltest.SkipTest('Skipping large tests on CPU for speed.') + + width = 512 + tol = 0.01 + n_samples = 512 + filter_shape = (filter_shape,) + strides = (strides,) + + init_fn, apply_fn, kernel_fn = stax.ConvTranspose(width, + filter_shape, + strides, + padding, + b_std=0.1) + + key = random.PRNGKey(1) + shape = (size, 1) + x1 = random.normal(key, (2,) + shape) + x2 = random.normal(key, (3,) + shape) if not same_inputs else None + + k = kernel_fn(x1, x2, + diagonal_batch=diagonal_batch, + diagonal_spatial=diagonal_spatial, + get='cov1' if diagonal_batch else 'nngp') + + diagonal_axes = () + if diagonal_batch: + diagonal_axes += (0,) + if diagonal_spatial: + diagonal_axes += (1,) + + kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( + init_fn, apply_fn, key, n_samples, diagonal_axes=diagonal_axes, + device_count=0) + k_mc = kernel_fn_mc(x1, None if diagonal_batch else x2, 'nngp') + + test_utils.assert_close_matrices(self, k_mc, k, tol) + + @classmethod + def _conv_transpose_circular_via_grad(cls, + lhs, + params, + strides, + padding, + dimension_numbers): + """Helper method: calculates conv transpose via grad for testing. + + Adapted from `jax.tests.lax_test`. 
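+
+    The check relies on transpose convolution being the adjoint
+    (vector-Jacobian product) of the forward convolution: the `CIRCULAR`
+    transpose convolution of `lhs` with a given filter should match the VJP
+    of a `CIRCULAR` forward convolution with the spatially-flipped,
+    `I`/`O`-swapped filter, evaluated at an output-shaped placeholder and
+    applied to `lhs`.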
+ """ + rhs = params[0] + rhs = np.swapaxes(rhs, dimension_numbers[1].index('O'), + dimension_numbers[1].index('I')) + rhs = np.flip(rhs, dimension_numbers[1].index('H')) + assert len(lhs.shape) == len(rhs.shape) + nspatial = len(lhs.shape) - 2 + dn = lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers) + in_shape = np.take(lhs.shape, dn.lhs_spec) + in_sdims = in_shape[2:] + k_shape = np.take(rhs.shape, dn.rhs_spec) + o_sdims = [in_sdims[i]*strides[i] for i in range(nspatial)] + o_shape = [in_shape[0], k_shape[1]] + o_sdims + out_spec_inv = [x[0] for x in + sorted(enumerate(dn.out_spec), key=lambda x: x[1])] + o_layout = np.take(np.array(o_shape), out_spec_inv) + placeholder = np.ones(o_layout, lhs.dtype) + + _, apply_fn, _ = stax.GeneralConv( + dimension_numbers=dimension_numbers, + out_chan=rhs.shape[dimension_numbers[1].index('I')], + filter_shape=(rhs.shape[dimension_numbers[1].index('H')],), + strides=strides, + padding=padding, + parameterization='standard') + conv = lambda x: apply_fn((rhs, 0.), x) + _, g = vjp(conv, placeholder) + return g(lhs)[0] + + @classmethod + def _conv_transpose_circular(cls, + lhs, + params, + strides, + padding, + dimension_numbers): + """Helper method: calculates conv transpose.""" + _, apply_fn, _ = stax.GeneralConvTranspose( + dimension_numbers=dimension_numbers, + out_chan=params[0].shape[dimension_numbers[1].index('O')], + filter_shape=(params[0].shape[dimension_numbers[1].index('H')],), + strides=strides, + padding=padding, + parameterization='standard') + return apply_fn((params[0], 0.), lhs) + + @jtu.parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': + f'size={size}_strides={strides}_filter={filter_shape}', + 'size': size, + 'filter_shape': filter_shape, + 'strides': strides, + } + for filter_shape in range(1, 5) + for strides in range(1, 5) + for size in range(1, 5))) + def test_conv_transpose_circular(self, size, filter_shape, strides): + if xla_bridge.get_backend().platform == 'cpu' and size > 2: + raise absltest.SkipTest('Skipping large tests on CPU for speed.') + + x = random.normal(random.PRNGKey(1), (2, size, 3)) + dn = ('NHC', 'HIO', 'NHC') + padding = 'CIRCULAR' + filter_shape = (filter_shape,) + strides = (strides,) + + init_fn, _, _ = stax.ConvTranspose(4, filter_shape, strides, padding) + _, params = init_fn(random.PRNGKey(2), x.shape) + f_conv = self._conv_transpose_circular(x, params, strides, padding, dn) + f_adj = self._conv_transpose_circular_via_grad(x, params, strides, padding, + dn) + self.assertAllClose(f_adj, f_conv) + + if __name__ == '__main__': absltest.main()
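
Usage sketch (not part of the patch; input sizes, widths and layer choices
below are illustrative):

    from jax import random
    from neural_tangents import stax

    # An infinite-width model using the new transpose convolution with
    # periodic ("CIRCULAR") padding added in this change.
    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.ConvTranspose(out_chan=256, filter_shape=(3,), strides=(2,),
                           padding='CIRCULAR', W_std=1.5, b_std=0.1),
        stax.Relu(),
        stax.GlobalAvgPool())

    x1 = random.normal(random.PRNGKey(1), (4, 16, 3))  # (batch, length, channels)
    x2 = random.normal(random.PRNGKey(2), (8, 16, 3))

    nngp = kernel_fn(x1, x2, 'nngp')  # closed-form NNGP kernel, shape (4, 8)
    ntk = kernel_fn(x1, x2, 'ntk')    # closed-form NTK, shape (4, 8)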