# Two Dimensional Pooling Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, pooling]
- hide: false

## Introduction

In this post, I'll implement a 2D max-pooling layer twice: first from scratch in NumPy and then in JAX. As before, we'll compare each version to Keras.

## Purpose of Pooling Layers

Pooling layers reduce the number of parameters in convolutional neural networks by downsampling the feature maps generated by the convolutional layers. Under the hood, they are very similar to convolutional layers. Both layers work by stepping through an input image based on a stride, extracting a chunk of a specified size, and compressing that 3D chunk with some sort of computation. The main difference between the two layers is the computation. Recall from the last post that a convolutional layer compresses a chunk to a single number by multiplying its elements with a set of filter coefficients, summing the products, and adding a bias term. Pooling layers have neither filters nor bias terms. They use simple functions (e.g. `max`, `mean`, etc.) that don't depend on any learnable parameters to compress each chunk.

Unlike convolutional layers, pooling layers retain the number of channels in the input image. So you can think of these *pooling functions* as being applied to each channel separately, with the results merged back together.
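
To make the per-channel behaviour concrete, here is a minimal sketch on a hypothetical toy image (the array and the reshape trick are illustrative and not part of the implementation that follows). It applies non-overlapping 2x2 max-pooling and checks that pooling each channel separately gives the same result.

```python
import numpy as np

# Hypothetical toy input: one 4x4 image with 2 channels, shape (height, width, channels).
image = np.arange(32, dtype=float).reshape(4, 4, 2)

# Non-overlapping 2x2 max-pooling (pool size 2, stride 2) done with a reshape:
# each spatial axis is split into (block index, position within block).
blocks = image.reshape(2, 2, 2, 2, 2)
pooled = blocks.max(axis=(1, 3))

# Pooling each channel on its own and stacking the results gives the same array.
per_channel = np.stack(
    [image[:, :, c].reshape(2, 2, 2, 2).max(axis=(1, 3)) for c in range(2)],
    axis=-1,
)

print(image.shape, "->", pooled.shape)      # (4, 4, 2) -> (2, 2, 2)
print(np.array_equal(pooled, per_channel))  # True
```

The height and width are halved, while the number of channels is unchanged.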

## Implementation from First Principles

In [211]:
import numpy as np
import tensorflow as tf

Here's an implementation of a pooling layer. You can configure how each chunk gets downsampled by choosing a different pooling function, `pool_fn`. Because `pool_fn` defaults to `np.max` in the argument list, max-pooling is executed unless you pass something else.

In [212]:
def downsample_images(input_batch, pool_size, strides, pool_fn=np.max):
    # input_batch has shape (batch, height, width, channels)
    batch_size, xm, xn, num_channels = input_batch.shape

    pm, pn = pool_size
    sm, sn = strides

    # Output size without padding: one entry per window start visited by the loops below
    ym, yn = 1 + (xm - pm) // sm, 1 + (xn - pn) // sn
    y = np.zeros((batch_size, ym, yn, num_channels))

    for h, ix in enumerate(range(0, xm-pm+1, sm)):
        for w, jx in enumerate(range(0, xn-pn+1, sn)):
            chunk = input_batch[:, ix:ix+pm, jx:jx+pn, :]
            # Reduce over height and width; keep the batch and channel axes
            y[:, h, w, :] = pool_fn(chunk, axis=(1,2))
    return y
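
The output height and width without padding can be checked by hand. The sketch below uses a hypothetical helper, `pooled_length`, which is not part of the post's code. It compares the floor-division formula against the number of window positions that the two `range` loops actually visit.

```python
def pooled_length(input_length: int, pool_length: int, stride: int) -> int:
    """Number of windows that fit along one axis when no padding is used."""
    return 1 + (input_length - pool_length) // stride

# The loops visit range(0, input_length - pool_length + 1, stride),
# so the formula and the loop count should always agree.
for xm, pm, sm in [(28, 4, 2), (28, 3, 2), (7, 2, 3)]:
    n_windows = len(range(0, xm - pm + 1, sm))
    print((xm, pm, sm), pooled_length(xm, pm, sm), n_windows)
```

For a 28-pixel axis with a pool size of 4 and a stride of 2, both counts come out to 13. When the stride does not divide `xm - pm` evenly (as in the second and third cases), the leftover pixels at the edge are simply dropped.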

`downsample_images` looks a lot like the convolution functions implemented in the last post. It extracts a chunk from the 4D input array and applies `pool_fn` to every batch item and channel independently. The line
```python
y[:, h, w, :] = pool_fn(chunk, axis=(1,2))       
```
takes advantage of NumPy's vectorization capabilities. It says that `pool_fn` is evaluated over the height and width axes. As far as the calculation goes, the batch and channel dimensions are left alone and only serve as placeholders for the result. This means
that for each `chunk`, the value `pool_fn(chunk, axis=(1,2))` is a 2D array where the number of rows is equal to the number of items in the batch and the number of columns is equal to the number of channels.
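
As a quick illustration of that shape claim, here is a small sketch with a hypothetical chunk (the sizes are chosen only for the example):

```python
import numpy as np

# Hypothetical chunk: 32 items, a 4x4 window, 3 channels.
chunk = np.random.randn(32, 4, 4, 3)

# Reducing over axes 1 and 2 (height, width) leaves (batch, channels).
reduced = np.max(chunk, axis=(1, 2))
print(reduced.shape)  # (32, 3)

# Entry [i, c] is the maximum of the 4x4 window for item i and channel c.
print(np.isclose(reduced[5, 1], chunk[5, :, :, 1].max()))  # True

# Other reductions such as np.mean accept the same axis argument.
print(np.mean(chunk, axis=(1, 2)).shape)  # (32, 3)
```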

Here's a snippet of code that computes the same per-item, per-channel maximum, just very slowly:

In [213]:
def slow_max_pool(chunk):
    num_items, height, width, num_channels = chunk.shape
    result = np.zeros((num_items, num_channels))
    for item in range(num_items):
        for chan in range(num_channels):
            # Start below every possible value so the first comparison always wins
            max_val = -np.inf
            for h in range(height):
                for w in range(width):
                    max_val = max(chunk[item, h, w, chan], max_val)
            result[item, chan] = max_val
    return result
                    

Just to make sure, let's compare the two versions. We'll define a random chunk, check that the vectorized and non-vectorized chunk-pooling give the same answer, and see how long each one takes to execute.

In [214]:
chunk = np.random.randn(32, 28, 28, 3)
result_vectorized = np.max(chunk, axis=(1,2))
result_non_vectorized = slow_max_pool(chunk)

assert np.all(np.isclose(result_non_vectorized, result_vectorized, atol=1e-6))

%timeit -n 100 np.max(chunk, axis=(1,2))
%timeit -n 100 slow_max_pool(chunk)

474 µs ± 5.76 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 ms ± 393 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


As expected, `slow_max_pool` is significantly slower than the vectorized version: roughly 40 times slower in this run (19.7 ms versus 474 µs).

## Compare to Keras

In [215]:
layer_keras = tf.keras.layers.MaxPool2D( 
    pool_size=(4, 4), 
    strides=(2,2),
    padding='valid'
)

In [216]:
input_batch = np.random.randn(32,28,28,3)

In [217]:
output_batch_keras = layer_keras(input_batch)

In [218]:
output_batch_numpy = downsample_images(
    input_batch, 
    pool_size=layer_keras.pool_size,
    strides=layer_keras.strides,
    pool_fn = np.max
)

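To check that the Keras and NumPy outputs are approximately equal, I use `np.isclose`, which tests each pair of elements against a combined absolute and relative tolerance. With `atol=1e-16`, the default relative tolerance (`rtol=1e-5`) does most of the work.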

In [219]:
assert np.all(np.isclose(output_batch_keras,output_batch_numpy, atol=1e-16))

Although the outputs match, the Keras version runs roughly six times faster (695 µs versus 4.03 ms), as the following benchmarks show.

In [220]:
%%timeit -n 100
downsample_images(
    input_batch, 
    pool_size=layer_keras.pool_size,
    strides=layer_keras.strides,
)

4.03 ms ± 17.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [221]:
%%timeit -n 100
layer_keras(input_batch)

695 µs ± 41.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Pooling Layer in JAX

In [222]:
import jax
import jax.numpy as jnp
from typing import Tuple, List, Callable
from fastcore.basics import patch

In [223]:
class MaxPool2D: 
    pool_shape: Tuple[int,int]
    strides: Tuple[int,int]
    padding: str

#### Constructor


In [224]:
@patch
def __init__(
    self: MaxPool2D, 
    pool_shape=(2,2), 
    strides=(1,1), 
    padding='valid'):
    
    self.pool_shape = pool_shape
    self.strides = strides 
    self.padding = padding
         

#### `__call__` Method

In [227]:
@patch
def __call__(self: MaxPool2D, batch: jnp.ndarray):
    
    outputs = jax.lax.reduce_window(
        batch,
        -jnp.inf,
        jax.lax.max,
        window_dimensions=(1, ) + self.pool_shape + (1, ),
        window_strides= (1, ) + self.strides + (1, ) ,
        padding=self.padding
    )

    return outputs 
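
To see what `jax.lax.reduce_window` does with these arguments, here is a minimal sketch on a hypothetical 4x4 single-channel image. The window and stride tuples get a leading and trailing `1` so that the window never spans more than one batch item or one channel.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy batch: 1 image, 4x4 pixels, 1 channel (NHWC layout).
batch = jnp.arange(16.0).reshape(1, 4, 4, 1)

pooled = jax.lax.reduce_window(
    batch,
    -jnp.inf,                         # identity element for max
    jax.lax.max,                      # reduction applied inside each window
    window_dimensions=(1, 2, 2, 1),   # 1 item, 2x2 pixels, 1 channel per window
    window_strides=(1, 2, 2, 1),      # step 2 pixels at a time in height and width
    padding="VALID",
)

print(pooled.shape)        # (1, 2, 2, 1)
print(pooled[0, :, :, 0])  # maxima of the four 2x2 blocks: 5, 7, 13, 15
```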

#### Registration

In [228]:
@patch
def tree_flatten(self: MaxPool2D) -> Tuple[List[jnp.ndarray], dict]:
    params = (None,)
    metadata = {k: v for k, v in self.__dict__.items()}
    return params, metadata

In [229]:
@patch(cls_method=True)
def tree_unflatten(cls: MaxPool2D, metadata: dict, params:List[jnp.ndarray]):
    # This assumes that each key-value pair in the metadata dict corresponds to a constructor argument.
    layer = cls(**metadata)
    return layer

In [230]:
_ = jax.tree_util.register_pytree_node_class(MaxPool2D)
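
The point of registering the class is that JAX can then take the layer apart and rebuild it, for example when it is passed as an argument to a transformed function. The sketch below shows that round trip with a hypothetical stand-in class, `TinyPool`, written without fastcore so that it is self-contained. One assumption differs from the code above: it stores the configuration in a tuple rather than a dict, because JAX expects the auxiliary data to be hashable when a pytree goes through `jax.jit`.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class TinyPool:
    """Hypothetical stand-in for the layer above, written without fastcore."""

    def __init__(self, pool_shape=(2, 2), strides=(1, 1), padding="VALID"):
        self.pool_shape = pool_shape
        self.strides = strides
        self.padding = padding

    def __call__(self, batch):
        return jax.lax.reduce_window(
            batch, -jnp.inf, jax.lax.max,
            window_dimensions=(1,) + self.pool_shape + (1,),
            window_strides=(1,) + self.strides + (1,),
            padding=self.padding,
        )

    def tree_flatten(self):
        children = ()  # no learnable parameters, so no array leaves
        aux = (self.pool_shape, self.strides, self.padding)  # hashable static config
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*aux)

layer = TinyPool(pool_shape=(4, 4), strides=(2, 2))

# Round trip: flatten to (leaves, treedef) and rebuild an equivalent layer.
leaves, treedef = jax.tree_util.tree_flatten(layer)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(leaves, rebuilt.pool_shape, rebuilt.strides)

# Because the layer is a pytree, it can be an argument of a jitted function.
@jax.jit
def apply(layer, batch):
    return layer(batch)

print(apply(layer, jnp.ones((2, 8, 8, 3))).shape)  # (2, 3, 3, 3)
```

Since a pooling layer has nothing to learn, the list of leaves is empty and the whole configuration travels as static metadata.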

## Compare JAX and Keras

In [231]:
layer_jax = MaxPool2D(
    pool_shape=layer_keras.pool_size,
    strides=layer_keras.strides,
    padding=layer_keras.padding
)

output_batch_jax = layer_jax(input_batch)

Because the following assertion passes, we can be reassured that both layers are calculating the same result.

In [232]:
assert np.all(np.isclose(output_batch_keras, output_batch_jax, atol=1e-6))

What about the calculation time?

In [233]:
%%timeit -n 100
layer_keras(input_batch)

687 µs ± 68.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [234]:
%%timeit -n 100
layer_jax(input_batch)

249 µs ± 24.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
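
One caveat when timing JAX code: JAX dispatches work asynchronously, so a fair measurement should wait for the result, and a NumPy input is also copied to the device on each call. The sketch below is one way to tighten the benchmark, not the setup used for the numbers above. The function `max_pool` is a hypothetical free-function version of the layer, and no timing result is claimed here.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

def max_pool(batch):
    # Same reduce_window call as in the layer, with a 4x4 window and stride 2.
    return jax.lax.reduce_window(
        batch, -jnp.inf, jax.lax.max,
        window_dimensions=(1, 4, 4, 1),
        window_strides=(1, 2, 2, 1),
        padding="VALID",
    )

max_pool_jit = jax.jit(max_pool)

# Move the data to the default device once so the transfer is not timed.
batch = jnp.asarray(np.random.randn(32, 28, 28, 3), dtype=jnp.float32)

# Warm-up call: triggers compilation, which should not be part of the timing.
max_pool_jit(batch).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish.
seconds = timeit.timeit(lambda: max_pool_jit(batch).block_until_ready(), number=100)
print(f"{seconds / 100 * 1e6:.1f} µs per call")
```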


## Conclusion