# Two-Dimensional Pooling Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, pooling]
- hide: false

## Introduction

In this post, I'll implement a 2D max-pooling layer, first from scratch in NumPy and then in JAX. As before, we'll compare each version to Keras.

## Purpose of Pooling Layers

Pooling layers downsample the feature maps produced by convolutional layers. Shrinking the height and width of these maps reduces the computation in later layers and the number of parameters in any fully connected layers that follow. Under the hood, pooling layers are very similar to convolutional layers. Both step through an input image according to a stride, extract a chunk of a specified size, and compress each 3D chunk with some computation. The main difference between the two layers is that computation. Recall from the last post that a convolutional layer compresses a chunk to a single number by multiplying its elements by a set of filter coefficients, summing the products, and adding a bias term. Pooling layers have no filters or bias terms. Instead, they compress each chunk with a simple function (e.g. `max` or `mean`) that doesn't depend on any learnable parameters.
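Here is a tiny worked example to make this concrete. It is my own illustration, not from the original code. It applies max pooling with a 2x2 window and stride 2 to a single-channel 4x4 image, so each non-overlapping 2x2 chunk is replaced by its largest value. The helper name `pooled_size` is hypothetical. It encodes the usual output-size rule for unpadded ("valid") pooling, `1 + (x - p) // s`.

```python
import numpy as np

def pooled_size(x, p, s):
    """Output length along one axis for 'valid' (unpadded) pooling.

    x: input length, p: pool length, s: stride.
    """
    return 1 + (x - p) // s

image = np.array([
    [1, 3, 2, 0],
    [4, 2, 1, 5],
    [0, 1, 7, 2],
    [6, 2, 3, 1],
])

pm = pn = 2  # pool height and width
sm = sn = 2  # vertical and horizontal strides
ym, yn = pooled_size(4, pm, sm), pooled_size(4, pn, sn)

pooled = np.zeros((ym, yn), dtype=image.dtype)
for i in range(ym):
    for j in range(yn):
        # Top-left corner of the chunk in the input image
        r, c = i * sm, j * sn
        pooled[i, j] = image[r:r + pm, c:c + pn].max()

print(pooled)
# [[4 5]
#  [6 7]]
```

Note that the size rule rounds down. A 27-pixel-tall input with a 4-pixel pool and stride 2 yields 12 output rows, because a 13th window would run past the edge of the image.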

Unlike convolutional layers, pooling layers preserve the number of channels in the input image. You can therefore think of these *pooling functions* as being applied to each channel separately, with the results stacked back together.
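A quick NumPy check of this per-channel view follows. It uses a hypothetical helper, `max_pool_2d`, that max-pools over the first two axes of an array and carries any trailing channel axis along. Pooling each channel on its own and stacking the results gives the same array as pooling all channels at once. In both cases the channel count is unchanged.

```python
import numpy as np

def max_pool_2d(x, p=2, s=2):
    """Max-pool a 2D or 3D array over its first two axes ('valid' padding).

    Any trailing axes (e.g. channels) are carried along untouched.
    """
    ym = 1 + (x.shape[0] - p) // s
    yn = 1 + (x.shape[1] - p) // s
    out = np.empty((ym, yn) + x.shape[2:], dtype=x.dtype)
    for i in range(ym):
        for j in range(yn):
            out[i, j] = x[i * s:i * s + p, j * s:j * s + p].max(axis=(0, 1))
    return out

rng = np.random.default_rng(0)
image = rng.standard_normal((6, 6, 3))  # height, width, channels

# Pool every channel at once...
all_at_once = max_pool_2d(image)
# ...or pool each channel on its own and stack the results along the last axis.
per_channel = np.stack(
    [max_pool_2d(image[:, :, c]) for c in range(image.shape[2])], axis=-1
)

print(all_at_once.shape)  # (3, 3, 3): spatial size halved, 3 channels kept
assert np.array_equal(all_at_once, per_channel)
```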

## Implementation from First Principles

```python
import numpy as np
import tensorflow as tf
```

Here's an implementation of a pooling layer. You can configure how each chunk gets downsampled by passing a different pooling function as `pool_fn`. Because `pool_fn` defaults to `np.max`, the function performs max pooling unless told otherwise.

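```python
def downsample_images(input_batch, pool_size, strides, pool_fn=np.max):
    """Pool a batch of images with shape (batch, height, width, channels).

    Uses 'valid' padding: only windows that fit entirely inside the image are pooled.
    """
    batch_size, xm, xn, num_channels = input_batch.shape

    pm, pn = pool_size  # pool height and width
    sm, sn = strides    # vertical and horizontal strides

    # Output height and width for 'valid' padding (rounding down)
    ym, yn = 1 + (xm - pm) // sm, 1 + (xn - pn) // sn
    y = np.zeros((batch_size, ym, yn, num_channels))

    # (ix, jx) is the top-left corner of each chunk in the input;
    # (iy, jy) is the matching position in the output.
    for iy, ix in enumerate(range(0, xm - pm + 1, sm)):
        for jy, jx in enumerate(range(0, xn - pn + 1, sn)):
            chunk = input_batch[:, ix:ix + pm, jx:jx + pn, :]
            # Reduce over height and width only, for every image and channel
            y[:, iy, jy, :] = pool_fn(chunk, axis=(1, 2))
    return y
```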

This looks a lot like the convolution functions implemented in the last post. It extracts a chunk from the 4D array according to the pooling specification and then applies `pool_fn` to each channel of every image in the batch. The line
```python
y[:, iy, jy, :] = pool_fn(chunk, axis=(1,2))       
```

deserves a little explanation. Although I didn't mention it explicitly earlier, `pool_fn` is a vectorized function that delegates its work to optimized low-level code. Without vectorized functions, we'd need nested loops, which we already know are very slow in Python. The argument `axis=(1,2)` tells `pool_fn` to reduce only over the height and width of the chunk, separately for each image and each channel. Pooling never mixes values across the images in a batch or across channels.
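Here is a small check of what `axis=(1,2)` does, assuming NumPy. For a chunk of shape `(batch, height, width, channels)`, the reduction collapses only the two spatial axes and leaves one value per image and per channel. Any NumPy reduction that accepts a tuple `axis` behaves the same way. `np.mean` is one example, which is why it could be passed as `pool_fn` to get average pooling.

```python
import numpy as np

rng = np.random.default_rng(42)
chunk = rng.standard_normal((32, 4, 4, 3))  # (batch, pool height, pool width, channels)

max_vals = np.max(chunk, axis=(1, 2))
mean_vals = np.mean(chunk, axis=(1, 2))
print(max_vals.shape, mean_vals.shape)  # (32, 3) (32, 3)

# Entry [b, c] depends only on image b, channel c.
b, c = 5, 2
assert max_vals[b, c] == chunk[b, :, :, c].max()
assert np.isclose(mean_vals[b, c], chunk[b, :, :, c].mean())
```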

Here's a function that computes the same thing as `np.max(chunk, axis=(1,2))`, only much more slowly:

```python
def slow_max_pool(chunk):
    """Loop-based equivalent of np.max(chunk, axis=(1, 2))."""
    num_items, height, width, num_channels = chunk.shape
    result = np.zeros((num_items, num_channels))
    for item in range(num_items):
        for chan in range(num_channels):
            # Scan every pixel of this image/channel, tracking the largest value
            max_val = -np.inf
            for h in range(height):
                for w in range(width):
                    max_val = max(chunk[item, h, w, chan], max_val)
            result[item, chan] = max_val
    return result
```

Just to be sure, let's compare them. We'll define a random chunk, check that the vectorized and loop-based versions give the same answer, and time each one.

```python
chunk = np.random.randn(32, 28, 28, 3)
result_vectorized = np.max(chunk, axis=(1,2))
result_non_vectorized = slow_max_pool(chunk)

assert np.allclose(result_non_vectorized, result_vectorized, atol=1e-6)

%timeit -n 100 np.max(chunk, axis=(1,2))
%timeit -n 100 slow_max_pool(chunk)
```

```
476 µs ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
18.5 ms ± 486 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

The vectorized version is roughly 40 times faster.



## Compare to Keras

```python
layer_keras = tf.keras.layers.MaxPool2D(
    pool_size=(4, 4),
    strides=(2, 2),
    padding='valid'
)
```

```python
input_batch = np.random.randn(32, 28, 28, 3)

output_batch_keras = layer_keras(input_batch)

output_batch_numpy = downsample_images(
    input_batch,
    pool_size=layer_keras.pool_size,
    strides=layer_keras.strides,
    pool_fn=np.max
)
```

To check that the Keras and NumPy outputs are approximately equal, I make sure they agree elementwise to within a small tolerance.

```python
assert np.allclose(output_batch_keras, output_batch_numpy, atol=1e-6)
```
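Why use a tolerance instead of exact equality? Keras layers compute in `float32` by default, while `input_batch` is `float64`. Rounding to `float32` is monotonic, so the maximum of the rounded values equals the rounded maximum. The two outputs should therefore differ only by `float32` rounding error. The sketch below checks that argument with NumPy alone, without Keras, on random data shaped like `input_batch`.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 28, 28, 3))  # float64, like input_batch

max64 = np.max(x, axis=(1, 2))                     # pooled in float64
max32 = np.max(x.astype(np.float32), axis=(1, 2))  # pooled after casting to float32

# Casting to float32 preserves order, so pooling and casting commute exactly...
assert np.array_equal(max32, max64.astype(np.float32))
# ...and the only gap between the two results is float32 rounding error.
print(np.max(np.abs(max64 - max32)))  # tiny, well below 1e-6 for N(0, 1) data
assert np.allclose(max64, max32, rtol=0, atol=1e-6)
```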

Although the outputs agree, the Keras version runs about six times faster, as the following benchmarks show.

```python
%%timeit -n 100
downsample_images(
    input_batch,
    pool_size=layer_keras.pool_size,
    strides=layer_keras.strides,
)
```

```
3.97 ms ± 216 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


```python
%%timeit -n 100
layer_keras(input_batch)
```

```
666 µs ± 52.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


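### Pooling Layer in JAX

Next, we'll build the same layer in JAX as a small class. Instead of looping over chunks in Python, its `__call__` method hands the whole job to `jax.lax.reduce_window`, which applies a reduction (here `jax.lax.max`, starting from `-jnp.inf`) to every window of the input. I use fastcore's `patch` decorator to add methods to the class one cell at a time, and I register the class as a pytree so that JAX knows how to take it apart and rebuild it.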

```python
import jax
import jax.numpy as jnp
from typing import Tuple, List, Callable
from fastcore.basics import patch
```

```python
class MaxPool2D:
    pool_shape: Tuple[int, int]
    strides: Tuple[int, int]
    padding: str
```

#### Constructor


```python
@patch
def __init__(
    self: MaxPool2D,
    pool_shape=(2, 2),
    strides=(1, 1),
    padding='valid'):
    # A pooling layer has no learnable parameters, only its configuration
    self.pool_shape = pool_shape
    self.strides = strides
    self.padding = padding
```

#### `__call__` Method

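```python
@patch
def __call__(self: MaxPool2D, batch: jnp.ndarray):
    # Pool over height and width only: window size and stride are 1 along
    # the batch and channel axes of an NHWC batch.
    pool_shape = (1, ) + self.pool_shape + (1, )
    strides = (1, ) + self.strides + (1, )
    outputs = jax.lax.reduce_window(
        batch,
        -jnp.inf,     # initial value: the identity element for max
        jax.lax.max,  # reduction applied inside each window
        window_dimensions=pool_shape,
        window_strides=strides,
        padding=self.padding
    )
    return outputs
```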
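For intuition about the `reduce_window` call, here is a NumPy analogue. It is a sketch and not how JAX implements it. With window dimensions `(1, pm, pn, 1)` and strides `(1, sm, sn, 1)`, the window slides over height and width only and takes the max inside each window. With "valid" padding, it keeps only windows that fit entirely inside the input. The sketch assumes NumPy 1.20 or newer for `sliding_window_view`, and the function name `max_pool_windows` is hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def max_pool_windows(batch, pool_shape, strides):
    """Max pooling over axes 1 and 2 of an NHWC batch, 'valid' padding."""
    pm, pn = pool_shape
    sm, sn = strides
    # Every pm x pn window at every position (stride 1); the window contents
    # become two trailing axes: (B, H-pm+1, W-pn+1, C, pm, pn).
    windows = sliding_window_view(batch, (pm, pn), axis=(1, 2))
    # Keep only the window positions that fall on the stride grid.
    windows = windows[:, ::sm, ::sn]
    # Reduce each window to its maximum, like jax.lax.max in reduce_window.
    return windows.max(axis=(-2, -1))

rng = np.random.default_rng(1)
batch = rng.standard_normal((2, 7, 7, 3))
out = max_pool_windows(batch, pool_shape=(4, 4), strides=(2, 2))
print(out.shape)  # (2, 2, 2, 3)
```

The strided slice discards window positions that the stride skips. That is wasteful for large strides, but it keeps the correspondence with `window_dimensions` and `window_strides` easy to see.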

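#### Adding to the pytree Registry

Registering `MaxPool2D` as a pytree tells JAX how to flatten a layer into leaves (arrays) plus static auxiliary data, and how to rebuild it from them, so that layers can be passed through transformations such as `jax.jit`. A pooling layer has no learnable parameters, so its whole configuration is stored as auxiliary data.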

```python
@patch
def tree_flatten(self: MaxPool2D) -> Tuple[tuple, tuple]:
    # No learnable parameters, so there are no array leaves.
    params = ()
    # JAX requires the auxiliary data to be hashable, so store the
    # configuration as a tuple of (name, value) pairs rather than a dict.
    metadata = tuple(self.__dict__.items())
    return params, metadata
```

```python
@patch(cls_method=True)
def tree_unflatten(cls: MaxPool2D, metadata: tuple, params: tuple):
    # Each (name, value) pair in metadata corresponds to a constructor argument.
    layer = cls(**dict(metadata))
    return layer
```

```python
_ = jax.tree_util.register_pytree_node_class(MaxPool2D)
```

## Compare JAX and Keras

```python
layer_jax = MaxPool2D(
    pool_shape=layer_keras.pool_size,
    strides=layer_keras.strides,
    padding=layer_keras.padding
)

output_batch_jax = layer_jax(input_batch)
```

The following assertion passes, so both layers compute the same result up to floating-point rounding.

```python
assert np.max(np.abs(output_batch_keras - output_batch_jax)) < 1e-6
```

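What about the calculation time? One caveat: JAX dispatches work asynchronously, so timing `layer_jax(input_batch)` directly may measure little more than the dispatch. Calling `.block_until_ready()` on the result forces the computation to finish and gives a fairer comparison.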

```python
%%timeit -n 100
layer_keras(input_batch)
```

```
669 µs ± 40 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


```python
%%timeit -n 100
layer_jax(input_batch)
```

```
249 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```


## Conclusion