# Two Dimensional Pooling Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, pooling]
- hide: false

## Introduction

In the style of my last post on 2D convolutional layers, 

## Purpose of Pooling Layers


## Import Libraries

For now, I only need numpy and tensorflow.

In [4]:
import numpy as np
import tensorflow as tf

## Implementation from First Principles

In [76]:
import numba

In [101]:
def downsample_images(input_batch, pool_size, strides, pool_fn=np.max):
    batch_size, xm, xn, num_channels  = input_batch.shape 
    
    pm, pn = pool_size 
    sm, sn = strides
    
    ym, yn = 1 + ((xm - pm)//sm), 1 + ((xn - pn)//sn) 
    y = np.zeros((batch_size, ym, yn, num_channels))

    for iy, ix in enumerate(range(0, xm-pm+1, sm)):
        for jy, jx in enumerate(range(0, xn-pn+1, sn)):
            y[:, iy, jy, :] = pool_fn(
                input_batch[:, ix:ix+pm, jx:jx+pn, :],
                axis=(1,2)
            )         
    return y    

## Compare to Keras

In [102]:
layer_keras = tf.keras.layers.MaxPool2D( 
    pool_size=(4, 4), 
    strides=(2,2),
    padding='valid'
)

In [103]:
input_batch = np.random.randn(32,28,28,3)

In [104]:
output_batch_keras = layer_keras(input_batch)

In [107]:
output_batch_numpy = downsample_images(
    input_batch, 
    pool_size=layer_keras.pool_size,
    strides=layer_keras.strides,
    pool_fn = np.max
)

To check that the Keras output and numpy outputs are approximately equal, I make sure that the absolute error is below a threshold.

In [109]:
assert np.max(np.abs(output_batch_keras - output_batch_numpy)) < 1e-16

Although the outputs are about the same, the Keras version runs much faster, as the following benchmarks show.    

In [110]:
%%timeit -n 100
downsample_images(
    input_batch, 
    pool_size=layer_keras.pool_size,
    strides=layer_keras.strides,
)

3.61 ms ± 47.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [111]:
%%timeit -n 100
layer_keras(input_batch)

666 µs ± 52.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Pooling Layer in JAX

In [112]:
import jax
import jax.numpy as jnp
from typing import Tuple, List, Callable
from fastcore.basics import patch

In [116]:
class MaxPool2D: 
    pool_shape: Tuple[int,int]
    strides: Tuple[int,int]
    padding: str

#### Constructor


In [124]:
@patch
def __init__(
    self: MaxPool2D, 
    pool_shape=(2,2), 
    strides=(1,1), 
    padding='valid'):
    
    self.pool_shape = pool_shape
    self.strides = strides 
    self.padding = padding
         

#### `__call__` Method

To implement `__call__` we use the JAX builtin function `conv_general_dilated`.  Except for the `dimension_numbers` argument, it's pretty easy to figure out what it's doing (but I'm still not clear on how it works - maybe save that for another post).  Like the Keras `Conv2D` layer, it has additional input arguments that give you further control over the convolution.  I'm not including these additional arguments here because I'm trying to keep things as simple as possible.  

In [125]:
@patch
def __call__(self: MaxPool2D, batch: jnp.ndarray):
    
    pool_shape = (1, ) + self.pool_shape + (1, )
    strides = (1, ) + self.strides + (1, ) 
    outputs = jax.lax.reduce_window(
        batch,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=pool_shape,
        window_strides=strides,
        padding=self.padding
    )

    return outputs 

#### Adding to pytree Registry
 

In [128]:
@patch
def tree_flatten(self: MaxPool2D) -> Tuple[List[jnp.ndarray], dict]:
    params = (None,)
    metadata = {k: v for k, v in self.__dict__.items()}
    return params, metadata

In [129]:
@patch(cls_method=True)
def tree_unflatten(cls: MaxPool2D, metadata: dict, params:List[jnp.ndarray]):
    # This assumes that each key-value pair in the metadata dict corresponds to a constructor argument.
    layer = cls(**metadata)
    return layer

In [131]:
_ = jax.tree_util.register_pytree_node_class(MaxPool2D)

ValueError: Duplicate custom PyTreeDef type registration for <class '__main__.MaxPool2D'>.

## Compare JAX and Keras

In [758]:
layer_jax = Conv2D(
    input_channels=layer_keras.input_spec.axes[-1],
    output_channels=layer_keras.filters,
    filter_shape=layer_keras.kernel_size,
    strides=layer_keras.strides,
    padding=layer_keras.padding
)

layer_jax.filters, layer_jax.biases = layer_keras.get_weights()

output_batch_jax = layer_jax(input_batch)

Because the following assertion passes, we can be reassured that both layers are calculating the same result.

In [759]:
assert np.max(np.abs(output_batch_keras - output_batch_jax)) < 1e-6

What about the calculation time?

In [760]:
%%timeit
layer_keras(input_batch)

270 µs ± 8.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [761]:
%%timeit
layer_jax(input_batch)

360 µs ± 8.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Not too bad, the JAX version only about 100 microseconds slower than Keras.  Fortunately, the JAX version can go even faster after applying the `jax.jit` transformation.

In [762]:
layer_jax_jitted = jax.jit(layer_jax)

In [763]:
%%timeit
layer_jax_jitted(input_batch)

224 µs ± 1.78 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Now it's faster than Keras.  I realize you can't trust micro-benchmarks, but it's good to know that with respect to performance, JAX and Keras seem to be in the same ballpark.

## Conclusion

In the post, I implemented the mechanics of a 2D convolutional layer in numpy, from first principles.  After this was proven to be very inefficient compare to Keras, I built a simple layer in JAX and showed that it performed as well as the Keras version on a sample input batch.