Skip to content

Commit

Permalink
Project import generated by Copybara.
Browse files Browse the repository at this point in the history
GitHub commits, most recent first:
  - f78ccf1c392c865434118403719ecbc3206db2ad Fix tuple() in reduce_window padding (#3748) by James Bradbury <jekbradbury@google.com>
  - a017c1088c6ab7c9230e7ee6ad3aef4ac7017d92 Implement nearest neighbor image resizes. (#3743) by Peter Hawkins <phawkins@google.com>
  - 9a867ca658b89c8147f93120c8afc442f9d05f92 Remove unused private function (#3744) by Jake Vanderplas <jakevdp@google.com>
  - a6ab742f3635ed7c6d8c499c4be8aeb2a4181922 Improve np.intersect1d (#3739) by Alex Dragan <35031007+aldragan0@users.noreply.github.com>
  - 6017205ceaa279b91f9c9589c05a1e93fe2b1150 Add defensive tuple() in lax.reduce_window (#3741) by James Bradbury <jekbradbury@google.com>
  - 3c6cd5fb946678e1ec42b1f1b1a23b178a0906b6 Implement complex convolutions on CPU and GPU. (#3735) by Peter Hawkins <phawkins@google.com>
  - 6391cfe7d0ef9b6a03a33ff82157e83d5a88e58a [jax2tf] First draft of converting sort_p using TF2XLA (#... by SIben <3920784+SIben@users.noreply.github.com>
  - 71253ac4c1d1ef4e7c20f665c37f622936378eac Generalize reduce-window padding to support (lo, hi) pair... by Peter Hawkins <phawkins@google.com>
  - a9da06ce75f07a1fd3f90a784710c12cfc161900 Fix shape error when taking JVP of reduce-prod over size ... by Peter Hawkins <phawkins@google.com>
  - 0d81e988d8884a95ced928b026fec003bb34af57 Implement np.intersect1d (#3726) by Alex Dragan <35031007+aldragan0@users.noreply.github.com>

PiperOrigin-RevId: 321077183
  • Loading branch information
jekbradbury authored and sschoenholz committed Jul 14, 2020
1 parent d5eb2af commit 93dca5e
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions neural_tangents/stax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3046,7 +3046,9 @@ def _conv_kernel_diagonal_spatial(
filter_size = functools.reduce(op.mul, filter_shape, 1)
filter_shape = (1,) * batch_ndim + filter_shape
strides = (1,) * batch_ndim + strides
mat = lax._reduce_window_sum(mat, filter_shape, strides, padding.name)
padding_vals = lax.padtype_to_pads(
mat.shape, filter_shape, strides, padding.name)
mat = lax._reduce_window_sum(mat, filter_shape, strides, padding_vals)
mat /= filter_size
return mat

Expand Down Expand Up @@ -3318,12 +3320,15 @@ def _pool_mask(
window_shape.insert(i, 1)
strides.insert(i, 1)

padding_vals = lax.padtype_to_pads(
mask.shape, window_shape, strides, padding.name)

# Get the output shape.
out_shape = lax.reduce_window_shape_tuple(
mask.shape,
window_shape,
strides,
padding.name
padding_vals
)

# If shapes don't match, stride through the mask.
Expand Down

0 comments on commit 93dca5e

Please sign in to comment.