Skip to content

Commit

Permalink
Remove unused duplicate code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 319712862
  • Loading branch information
romanngg committed Jul 6, 2020
1 parent 24549f3 commit 7fdb1e9
Showing 1 changed file with 0 additions and 46 deletions.
46 changes: 0 additions & 46 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3051,52 +3051,6 @@ def _conv_kernel_diagonal_spatial(
return mat


def _conv_kernel_over_spatial(
mat: Optional[np.ndarray],
filter_shape: Tuple[int, ...],
strides: Tuple[int, ...],
padding: Padding,
batch_ndim: int
) -> Optional[np.ndarray]:
"""Compute covariance of the CNN outputs given inputs with covariance `mat`.
Used when `kernel.diagonal_spatial == True`.
Args:
mat: an `(S+batch_ndim)`-dimensional `np.ndarray` containing
sample-sample-(same position) covariances of CNN inputs. Has `batch_ndim`
batch and `S` spatial dimensions with the shape of
`(batch_size_1, [batch_size_2,] height, width, depth, ...)`.
filter_shape: tuple of positive integers, the convolutional filters spatial
shape (e.g. `(3, 3)` for a 2D convolution).
strides: tuple of positive integers, the CNN strides (e.g. `(1, 1)` for a
2D convolution).
padding: a `Padding` enum, e.g. `Padding.CIRCULAR`.
batch_ndim: integer, number of leading batch dimensions, 1 or 2.
Returns:
an `(S+batch_ndim)`-dimensional `np.ndarray` containing
sample-sample-(same position) covariances of CNN outputs. Has `batch_ndim`
batch and `S` spatial dimensions with the shape of
`(batch_size_1, [batch_size_2,] new_height, new_width, new_depth, ...)`.
"""
if mat is None or mat.ndim == 0:
return mat

if padding == Padding.CIRCULAR:
spatial_axes = tuple(range(mat.ndim)[batch_ndim:])
mat = _same_pad_for_filter_shape(mat, filter_shape, strides,
spatial_axes, 'wrap')
padding = Padding.VALID

filter_size = utils.size_at(filter_shape)
filter_shape = (1,) * batch_ndim + filter_shape
strides = (1,) * batch_ndim + strides
mat = lax._reduce_window_sum(mat, filter_shape, strides, padding.name)
mat /= filter_size
return mat


def _pool_kernel(
mat: Optional[np.ndarray],
pool_type: Pooling,
Expand Down

0 comments on commit 7fdb1e9

Please sign in to comment.