In [1]:
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = hk.nets.MLP([300, 100, 10])
  return mlp(x)

forward = hk.transform(forward)

rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 28 * 28])



In [2]:
x.shape

(8, 784)

In [3]:
temp = jax.random.normal(rng, [3, 3])
temp

DeviceArray([[ 0.36900425, -0.46067554, -0.86509347],
             [ 1.2080882 ,  0.59699154, -0.87080586],
             [-0.3984998 , -0.6670093 ,  0.33689347]], dtype=float32)

In [4]:
hk.max_pool(temp, window_shape=(2,2), strides=[1], padding="VALID")

DeviceArray([[1.2080882 , 0.59699154],
             [1.2080882 , 0.59699154]], dtype=float32)

In [37]:
temp.shape

(3, 3)

In [5]:
import types
from typing import Optional, Sequence, Tuple, Union
import warnings

from haiku._src import module
from jax import lax
import jax.numpy as jnp
import numpy as np

In [6]:
def _infer_shape(
    x: jnp.ndarray,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = -1,
) -> Tuple[int, ...]:
  """Infer shape for pooling window or strides."""
  if isinstance(size, int):
    if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
      raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
    if channel_axis and channel_axis < 0:
      print(channel_axis)
      channel_axis = x.ndim + channel_axis
      print(channel_axis)
    return (1,) + tuple(size if d != channel_axis else 1
                        for d in range(1, x.ndim))
  elif len(size) < x.ndim:
    # Assume additional dimensions are batch dimensions.
    return (1,) * (x.ndim - len(size)) + tuple(size)
  else:
    assert x.ndim == len(size)
    return tuple(size)

In [49]:
def adaptive_pool(
    value: jnp.ndarray,
    padding: str,
    channel_axis: Optional[int] = -1,
    output_size:Union[int, Sequence[int]],
) -> jnp.ndarray:
  """Max pool.
  Args:
    value: Value to pool.
    window_shape: Shape of the pooling window, an int or same rank as value.
    strides: Strides of the pooling window, an int or same rank as value.
    padding: Padding algorithm. Either ``VALID`` or ``SAME``.
    channel_axis: Axis of the spatial channels for which pooling is skipped,
      used to infer ``window_shape`` or ``strides`` if they are an integer.
  Returns:
    Pooled result. Same rank as value.
  """
  input_size = value.shape[-2:]
  strides = (input_size//output_size)  
  window_shape = input_size - (output_size-1)*stride  
  padding = "VALID"

  window_shape = _infer_shape(value, window_shape, channel_axis)
  strides = _infer_shape(value, strides, channel_axis)

  return lax.reduce_window(value, -jnp.inf, lax.max, window_shape, strides,
                           padding)

In [94]:
def adaptive_pool(
    value: jnp.ndarray,
    window_shape: Union[int, Sequence[int]],
    strides: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:
  """Max pool.
  Args:
    value: Value to pool.
    window_shape: Shape of the pooling window, an int or same rank as value.
    strides: Strides of the pooling window, an int or same rank as value.
    padding: Padding algorithm. Either ``VALID`` or ``SAME``.
    channel_axis: Axis of the spatial channels for which pooling is skipped,
      used to infer ``window_shape`` or ``strides`` if they are an integer.
  Returns:
    Pooled result. Same rank as value.
  """

#   window_shape = _infer_shape(value, window_shape, channel_axis)
  input_size = jax.numpy.asarray(value.shape[-2:])
  window_shape = jax.numpy.asarray(window_shape)
  strides = jax.numpy.floor_divide(input_size, window_shape)  
  window_shape = input_size - (window_shape-1)*strides  
  print(f"window shape is {window_shape}     {strides}")

In [95]:
adaptive_pool(temp, window_shape=(2,2), strides=[1], padding="VALID")

window shape is [2 2]     [1 1]


In [33]:
import numpy as np

def adaptive_avg_pool(
    value: jnp.ndarray,
    out_shape: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jnp.ndarray:

  out_shape = _infer_shape(value, out_shape, channel_axis)
  input_size = value.shape[-2:]
  strides = np.array(input_size) // np.array(out_shape)
  assert out_shape[-1] == out_shape[-2] 
  out_shape = input_size - (out_shape[-1] - 1) * strides

  return hk.avg_pool(value, out_shape, strides, padding=padding)

In [34]:
avg_pool(temp, out_shape=(4,4), strides=[1], padding="VALID")

TypeError: reduce_window window_strides must have every element be strictly positive, got (0, 0).

In [26]:
temp

DeviceArray([[ 0.36900425, -0.46067554, -0.86509347],
             [ 1.2080882 ,  0.59699154, -0.87080586],
             [-0.3984998 , -0.6670093 ,  0.33689347]], dtype=float32)

In [114]:
hk.avg_pool(temp, window_shape=(2,2), strides=[1], padding="VALID")

DeviceArray([[ 0.42835212, -0.39989582],
             [ 0.18489267, -0.15098253]], dtype=float32)

In [86]:
type(jnp.ndarray(temp.shape[-2:]))

TypeError: Can't instantiate abstract class ndarray with abstract methods __abs__, __add__, __and__, __bool__, __complex__, __divmod__, __eq__, __float__, __floordiv__, __ge__, __getitem__, __gt__, __index__, __int__, __invert__, __iter__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __radd__, __rand__, __rdivmod__, __reversed__, __rfloordiv__, __rlshift__, __rmatmul__, __rmod__, __rmul__, __ror__, __round__, __rpow__, __rrshift__, __rshift__, __rsub__, __rtruediv__, __rxor__, __setitem__, __sub__, __truediv__, __xor__, all, any, argmax, argmin, argpartition, argsort, astype, at, aval, choose, clip, compress, conj, conjugate, copy, cumprod, cumsum, diagonal, dot, flatten, imag, item, max, mean, min, nbytes, nonzero, prod, ptp, ravel, real, repeat, reshape, round, searchsorted, sort, squeeze, std, sum, swapaxes, take, tobytes, tolist, trace, transpose, var, view, weak_type

In [88]:
jax.numpy.floor_divide(jax.numpy.asarray(temp.shape[-2:]), jax.numpy.asarray(temp.shape[-2:]))

DeviceArray([1, 1], dtype=int32)

ModuleNotFoundError: No module named 'torch'