Make CNN w/pooling propagation use a depthwise separable CNN with a small (but > 1) number of channels. This triggers a different CuDNN routine that leads to a >4X speedup!

For TPU, use 128 channels; the speedup is ~3X. TPU remains slow though, due to padding elsewhere. Namely, as Blake explained to me in the chat, for TPU we need a >=128-sized dimension in the covariance tensor _throughout_ the code (at least; there might be other padding issues). This should be doable by using batches of size e.g. (128, 4), but I think batching currently only works with square batches.

For CPU no change.
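
A minimal sketch of the trick, assuming jax's `lax.conv_general_dilated` with its `feature_group_count` argument (the sizes below are made up): the same diagonal averaging filter applied either as a batch of single-channel convolutions or as one grouped ("depthwise separable") convolution over a few channels. The two are numerically identical; the grouped form is what now triggers the faster CuDNN routine on GPU.

import jax.numpy as np
from jax import lax, random

batch_and_channels, n_channels = 8, 2  # hypothetical sizes; 2 is the smallest divisor > 1
size, filter_i = 5, 3
dn = ('NCHW', 'OIHW', 'NCHW')

x = random.normal(random.PRNGKey(0), (batch_and_channels, 1, size, size))
rhs = np.diag(np.full((filter_i,), 1. / filter_i))  # diagonal averaging filter

# Single-channel convolution: the batch dimension carries everything.
out_1ch = lax.conv_general_dilated(
    lhs=x,
    rhs=np.broadcast_to(rhs, (1, 1, filter_i, filter_i)),
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=dn)

# Grouped convolution: fold part of the batch into `n_channels` channels and
# convolve each channel with its own (identical) filter.
out_grouped = lax.conv_general_dilated(
    lhs=x.reshape((-1, n_channels, size, size)),
    rhs=np.broadcast_to(rhs, (n_channels, 1, filter_i, filter_i)),
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=dn,
    feature_group_count=n_channels)

assert np.allclose(out_1ch.reshape(out_grouped.shape), out_grouped)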

New benchmark for float32, 21-layer 3x3-SAME-CNN-ReLU-GAP: 0.0027001 seconds per NTK entry per V100 (0.0029 w/ 8-GPU batching), which brings us down to a theoretical 937 hours for the lower triangle of a 50K x 50K kernel (+ batching/beam overhead).

float64 is ~0.0065 seconds per entry (2.3X slower than float32).
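
A quick check of the 937-hour figure above (one V100, diagonal included, batching/beam overhead ignored):

n = 50_000
entries = n * (n + 1) // 2        # lower triangle of the 50K x 50K kernel
hours = entries * 0.0027001 / 3600
print(entries, hours)             # ~1.25e9 entries, ~937.6 hours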

PiperOrigin-RevId: 307117825
romanngg committed Apr 17, 2020
1 parent 3aa09df commit 100afac
80 changes: 55 additions & 25 deletions neural_tangents/stax.py
@@ -81,6 +81,7 @@
from jax.api_util import flatten_fun
import jax.experimental.stax as ostax
import jax.interpreters.partial_eval as pe
from jax.lib import xla_bridge
from jax.scipy.special import erf
from jax.tree_util import tree_map, tree_flatten, tree_unflatten
from neural_tangents.utils import utils
@@ -480,9 +481,9 @@ def GeneralConv(dimension_numbers,
be used instead.
Args:
:padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
not available in `jax.experimental.stax.GeneralConv`.
:parameterization: Either "ntk" or "standard". These parameterizations are
parameterization: Either "ntk" or "standard". These parameterizations are
the direct analogues for convolution of the corresponding
parameterizations for `Dense` layers.
"""
@@ -521,10 +522,9 @@ def Conv(out_chan,
be used instead.
Args:
:padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
not available in `jax.experimental.stax.GeneralConv`.
:parameterization: Either "ntk" or "standard". These parameterizations are
parameterization: Either "ntk" or "standard". These parameterizations are
the direct analogues for convolution of the corresponding
parameterizations for `Dense` layers.
"""
@@ -559,9 +559,9 @@ def _GeneralConv(dimension_numbers,
be used instead.
Args:
:padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
not available in `jax.experimental.stax.GeneralConv`.
:parameterization: Either "ntk" or "standard". These parameterizations are
parameterization: Either "ntk" or "standard". These parameterizations are
the direct analogues for convolution of the corresponding
parameterizations for `Dense` layers.
"""
@@ -1580,7 +1580,7 @@ def kernel_fn_train(kernels):
# INTERNAL UTILITIES


_CONV_KERNEL_DIMENSION_NUMBERS = ('NCHW', 'HWIO', 'NCHW')
_CONV_KERNEL_DIMENSION_NUMBERS = ('NCHW', 'OIHW', 'NCHW')


_INPUT_REQ = 'input_req'
@@ -2522,11 +2522,7 @@ def _pad_one_side(x, pads, axes, mode):
return x


def _conv_kernel_full_spatial(mat,
filter_shape,
strides,
padding,
batch_ndim):
def _conv_kernel_full_spatial(mat, filter_shape, strides, padding, batch_ndim):
"""Compute covariance of the CNN outputs given inputs with covariance `mat`.
Used when `kernel.diagonal_spatial == False`.
@@ -2569,25 +2565,59 @@ def _conv_kernel_full_spatial(mat,
spatial_i = (i - batch_ndim) // 2
filter_i = filter_shape[spatial_i]
stride_i = strides[spatial_i]

ker = np.diag(np.full((filter_i,), 1. / filter_i, mat.dtype))
for c in _CONV_KERNEL_DIMENSION_NUMBERS[1]:
if c in ('I', 'O'):
ker = np.expand_dims(ker, _CONV_KERNEL_DIMENSION_NUMBERS[1].index(c))

size_i = mat.shape[i]

mat = np.moveaxis(mat, (i - 1, i), (-2, -1))
mat_preshape = mat.shape[:-2]
mat = np.expand_dims(mat.reshape((-1, size_i, size_i)),
_CONV_KERNEL_DIMENSION_NUMBERS[0].index('C'))

rhs = np.diag(np.full((filter_i,), 1. / filter_i, mat.dtype))
rhs_shape = ()

platform = xla_bridge.get_backend().platform
if platform in ['gpu', 'tpu']:
batch_and_channels = functools.reduce(op.mul, mat_preshape, 1)
n_channels = batch_and_channels

# Find smallest `n_channels > 1` that divides `batch_and_channels`; use
# depthwise-separable CNN. For `n_channels == 1` CuDNN appears to invoke a
# different algorithm (`void cudnn::detail::implicit_convolve_sgemm`) than
# in any other case (`conv2d_c1_k1_nchw_hw_packed_kernel`), and the latter
# seems many-fold faster.
# For TPU, start with `n_channels >= 128`. Beware of precision errors:
# TODO(romann): revisit based on b/154160868, b/154165148.
n_channels_min = 2 if platform == 'gpu' else 128

for n_c in range(n_channels_min, batch_and_channels):
if batch_and_channels % n_c == 0:
n_channels = n_c
break

elif platform == 'cpu':
# For CPU minimal channels seems best.
n_channels = 1

else:
raise NotImplementedError(platform)

mat = mat.reshape((-1, n_channels, size_i, size_i))

for c in _CONV_KERNEL_DIMENSION_NUMBERS[1]:
if c == 'O':
rhs_shape += (n_channels,)
elif c == 'I':
rhs_shape += (1,)
else:
rhs_shape += (filter_i,)

rhs = np.broadcast_to(rhs, rhs_shape)

mat = lax.conv_general_dilated(
lhs=mat,
rhs=ker,
rhs=rhs,
window_strides=(stride_i, stride_i),
padding=padding.name,
dimension_numbers=_CONV_KERNEL_DIMENSION_NUMBERS)
mat = np.squeeze(mat,
_CONV_KERNEL_DIMENSION_NUMBERS[2].index('C'))
dimension_numbers=_CONV_KERNEL_DIMENSION_NUMBERS,
feature_group_count=n_channels)
mat = mat.reshape(mat_preshape + mat.shape[-2:])

return mat
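
For reference, a standalone sketch of the channel-count selection in the new code (`pick_n_channels` is a hypothetical helper, not part of the library): pick the smallest divisor of `batch_and_channels` that is > 1 on GPU (>= 128 on TPU), fall back to a fully depthwise convolution if no such divisor exists, and keep a single channel on CPU.

def pick_n_channels(batch_and_channels, platform):
  """Sketch of the divisor search in `_conv_kernel_full_spatial` above."""
  if platform == 'cpu':
    return 1                       # minimal channels seems best on CPU
  if platform not in ('gpu', 'tpu'):
    raise NotImplementedError(platform)
  n_channels_min = 2 if platform == 'gpu' else 128
  for n_c in range(n_channels_min, batch_and_channels):
    if batch_and_channels % n_c == 0:
      return n_c                   # smallest admissible divisor
  return batch_and_channels        # no divisor found: fully depthwise

print(pick_n_channels(12, 'gpu'))   # 2
print(pick_n_channels(7, 'gpu'))    # 7 (prime, so fully depthwise)
print(pick_n_channels(512, 'tpu'))  # 128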
