# Two Dimensional Convolutional Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: false

## Introduction

In this post, I'll start by implementing a basic convolutional layer using numpy and validating it against Keras.  After this, I'll write a more efficient one using JAX.

## How Convolutional Layers work

A two dimensional convolutional layer consists of several randomly initialized filters, biases, and a rule for moving the filter across the input array.  Typically, the input array is rank 4, meaning that the shape has 4 components.  One of these components represents the number of individual images in the input,  two of them tell you the size of each of the images, and the fourth tells you the number of *input channels*.  Interpret the number of channels as the number of components describing a pixel.  For instance, a pixel in a gray-scale image is a number between 0 and 255 and therefore has one channel.  However, an RGB color image has 3 channels because it has a red component, green component, and blue component.

The number of filters and biases in the layer determines the number of channels each image has after the layer is applied.  When defining a layer, you generally specify how many output channels, or filters, you want.  In Keras, the parameter is the number of filters; in PyTorch, it's the number of output channels.  Either way, the layer ends up with one rank 3 filter and one scalar bias per output channel.
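To make the ranks concrete, here is a small sketch of the arrays involved.  The sizes are made up for illustration, and the channels-last layout is an assumption that happens to match the Keras convention used later in the post.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, height, width, in_channels = 2, 28, 28, 3
out_channels, kernel_h, kernel_w = 4, 5, 5

# Rank 4 input: (images, rows, columns, input channels).
batch = np.zeros((batch_size, height, width, in_channels))

# One rank 3 filter per output channel, stacked along the last axis.
filters = np.zeros((kernel_h, kernel_w, in_channels, out_channels))

# One scalar bias per output channel.
biases = np.zeros(out_channels)

single_filter = filters[:, :, :, 0]
print(batch.ndim, single_filter.shape, biases.shape)
```

Slicing out one output channel leaves a rank 3 filter that spans every input channel of the image.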

## Import Libraries

For now, I only need numpy and tensorflow.

In [686]:
import numpy as np
import tensorflow as tf

## Implementation from First Principles

This function filters a single image with every output filter, resulting in a rank 3 array.  The first two levels of the nested loop extract a rank 3 chunk from `image`, while the third level applies each output filter to the chunk.  After a chunk is processed and the results are placed in the output array `y`, the window advances by the stride to the next chunk position.

In [687]:
def filter_image(image, filters, strides):
    
    xm, xn, _  = image.shape 
    
    km, kn, ni, no = filters.shape 
    
    
    sm, sn = strides
    # Number of valid window positions along each axis (no padding)
    ym, yn = 1 + ((xm - km)//sm), 1 + ((xn - kn)//sn) 
    y = np.zeros((ym, yn, no))

    for iy, ix in enumerate(range(0, xm-km+1, sm)):
        for jy, jx in enumerate(range(0, xn-kn+1, sn)):
            # Apply each output filter to this chunk (biases are added later, per batch)
            chunk = image[ix:ix+km,jx:jx+kn,:]
            for channel in range(no):
                y[iy,jy,channel] = np.sum(filters[:,:,:,channel] * chunk)
            
    return y
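The output size used above, `1 + (x - k)//s` per axis, can be cross-checked against the loop bounds.  The helper name `output_size` below is my own, introduced only for this sketch.

```python
def output_size(input_size, kernel_size, stride):
    """Number of valid window positions along one axis (no padding)."""
    return 1 + (input_size - kernel_size) // stride

# Cross-check against the loop bounds used in filter_image.
for input_size, kernel_size, stride in [(28, 4, 1), (28, 4, 2), (7, 3, 2)]:
    positions = list(range(0, input_size - kernel_size + 1, stride))
    assert len(positions) == output_size(input_size, kernel_size, stride)

print(output_size(28, 4, 1))  # 25
print(output_size(28, 4, 2))  # 13
print(output_size(7, 3, 2))   # 3
```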

Once we have an algorithm to filter a single image, it can easily be extended to a batch of images.

In [688]:
def filter_image_batch(batch, filters, biases, strides):
    outputs = [filter_image(image, filters, strides) for image in batch]
    outputs = np.array(outputs)
    outputs = outputs + biases
    return outputs

First, a list comprehension applies the `filter_image` function defined above to each image in the batch.  Next, `np.array` converts the resulting list to a rank 4 array.  The line preceding the `return` statement,
```python
outputs = outputs + biases
```
seems like it shouldn't work, because `outputs` is rank 4 while `biases` is rank 1.  Fortunately, numpy's broadcasting rules come to the rescue: the bias vector is matched against the last (channel) axis and added at every batch and pixel position.
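A tiny sketch of that broadcasting step, with made-up shapes and bias values:

```python
import numpy as np

outputs = np.zeros((2, 25, 25, 4))        # rank 4: (batch, rows, cols, channels)
biases = np.array([0.1, 0.2, 0.3, 0.4])   # rank 1: one value per channel

# The trailing axis of `outputs` has length 4, so `biases` lines up with it.
shifted = outputs + biases

print(shifted.shape)
print(shifted[1, 7, 3])
```

Every pixel of every image receives the same per-channel offsets, which is exactly what a bias term should do.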
    

## Compare to Keras

To compare the numpy version to Keras, I'm going to create a `Conv2D` layer:

In [689]:
layer_keras = tf.keras.layers.Conv2D(
    filters=4, 
    kernel_size=(4, 4), 
    strides=(1,1), 
    bias_initializer='he_uniform', 
    padding='valid'
)

initialize a random batch of fake images:

In [690]:
input_batch = np.random.randn(2,28,28,3)

and filter the batch with the layer:

In [691]:
output_batch_keras = layer_keras(input_batch)

Next, the filters, biases and strides are extracted from the layer.  Note that `strides` doesn't really need to be read from the layer, since it's passed to the `Conv2D` constructor after all.  Reading it from the layer is just less error-prone.

In [692]:
filters, biases = layer_keras.get_weights()
strides = layer_keras.strides

Now all the inputs can be passed to the `filter_image_batch` implemented earlier.

In [693]:
output_batch_numpy = filter_image_batch(input_batch, filters, biases, strides)

To check that the Keras and numpy outputs are approximately equal, I make sure that the maximum absolute error is below a threshold.

In [694]:
assert np.max(np.abs(output_batch_keras - output_batch_numpy)) < 1e-6 
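The same kind of check can also be written with `np.testing.assert_allclose`, which prints a mismatch report if it fails.  The arrays below are stand-ins I made up, not the actual Keras and numpy outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.standard_normal((2, 25, 25, 4))
candidate = reference + 1e-7   # hypothetical tiny discrepancy

# rtol=0 makes this a pure absolute-error check, like the assertion above.
np.testing.assert_allclose(candidate, reference, rtol=0, atol=1e-6)

max_abs_error = np.max(np.abs(candidate - reference))
print(max_abs_error < 1e-6)
```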

Although the outputs are about the same, the Keras version runs much faster, as the following benchmarks show.    

In [695]:
%%timeit
filter_image_batch(input_batch, filters, biases, strides)

23 ms ± 888 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [696]:
%%timeit
layer_keras(input_batch)

264 µs ± 683 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


I can't say that these timing results are a surprise.  Remember how my numpy version of convolution has a three-level nested loop?  Well, this leads to very poor performance.  Unfortunately, to write fast Python programs, a lot of the language's syntax and functionality (like loops) must be avoided in favor of calling wrappers for optimized C code.  This is precisely what Keras does.  Later on, we'll see that the JAX version is competitive with Keras.
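To illustrate what "avoiding loops" can look like in plain numpy, here is one possible loop-free sketch built on `sliding_window_view` and `einsum`.  The function name `filter_image_batch_vectorized` is mine, it needs numpy 1.20 or newer, and I haven't benchmarked it here, so treat it as an illustration rather than a tuned implementation.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def filter_image_batch_vectorized(batch, filters, biases, strides):
    km, kn, _, _ = filters.shape
    sm, sn = strides
    # Every (km, kn) window over the two spatial axes:
    # shape (N, H-km+1, W-kn+1, C, km, kn).
    windows = sliding_window_view(batch, (km, kn), axis=(1, 2))
    # Keep only the window positions the strides would visit.
    windows = windows[:, ::sm, ::sn]
    # Sum over window rows (k), window columns (l) and input channels (c).
    return np.einsum('nijckl,klco->nijo', windows, filters) + biases

rng = np.random.default_rng(0)
demo_batch = rng.standard_normal((2, 28, 28, 3))
demo_filters = rng.standard_normal((4, 4, 3, 4))
demo_biases = rng.standard_normal(4)

demo_out = filter_image_batch_vectorized(demo_batch, demo_filters, demo_biases, (1, 1))
print(demo_out.shape)  # (2, 25, 25, 4)
```

The explicit loops are gone from the Python code; the window extraction and the summation both happen inside numpy.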

### Convolutional Layer in JAX

In [623]:
import jax
import jax.numpy as jnp
from typing import Tuple, List
from fastcore.basics import patch

In addition to the standard set of imports, I'm also importing `patch` from [fastcore](https://fastcore.fast.ai/).  Its selling point is that it contains

> Python goodies to make your coding faster, easier, and more maintainable.

The nice thing about `patch` is that it allows you to write methods outside of class definitions.  This is particularly useful if you're interested in incrementally developing, and explaining, class functionality in notebooks.  Without `patch`, you'd have to either put an entire class implementation in a single cell, or abandon classes altogether and use functions.
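The underlying idea can be sketched in plain Python.  The `attach_to` decorator below is a hypothetical stand-in of my own, not fastcore's implementation; the real `patch` works out the target class from the type annotation on `self` instead of taking it as an argument.

```python
def attach_to(cls):
    """Hypothetical stand-in for the idea behind `patch`: add a function to an existing class."""
    def decorator(func):
        setattr(cls, func.__name__, func)
        return func
    return decorator

class Counter:
    def __init__(self):
        self.count = 0

# Defined outside the class body, e.g. in a later notebook cell.
@attach_to(Counter)
def increment(self, step=1):
    self.count += step
    return self.count

counter = Counter()
counter.increment()
counter.increment(step=4)
print(counter.count)  # 5
```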


In [672]:
class Conv2D: 
    filters: jnp.ndarray 
    biases: jnp.ndarray
    input_channels: int 
    output_channels: int 
    filter_shape: Tuple[int,int]
    strides: Tuple[int,int]
    padding: str
    seed: int

#### Constructor

The constructor is pretty self-explanatory.  As with the layers implemented in the previous post, the `build` parameter determines whether or not the filters and biases get initialized.  For now, the filters and biases are drawn from a Kaiming-uniform distribution.  This is the default initializer in PyTorch, so I figured it would be effective.  Can't say I know why at this point, but Keras has a different default initialization approach.

In [673]:
@patch
def __init__(
    self: Conv2D, 
    input_channels, 
    output_channels, 
    filter_shape=(2,2), 
    strides=(1,1), 
    padding='valid', 
    seed=1234, 
    build=True):
    
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.filter_shape = filter_shape
    self.strides = strides 
    self.padding = padding
    self.seed = seed
        
    if build:
        key = jax.random.PRNGKey(seed)
        fkey, bkey = jax.random.split(key)
            
        # kaiming/he uniform, using Pytorch documentation
        K = 1 / (input_channels * filter_shape[0] * filter_shape[1])
        sqrtK = jnp.sqrt(K)
        self.filters = jax.random.uniform(
            fkey, 
            (*filter_shape, input_channels, output_channels), 
            minval=-sqrtK, maxval=+sqrtK
        )
        
        self.biases = jax.random.uniform(
            bkey, 
            (output_channels,), 
            minval=-sqrtK, maxval=+sqrtK
        )   
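To get a feel for the numbers, here is the same bound worked out in plain numpy for a layer with 3 input channels and 4x4 filters.  The sizes are just an example, and numpy's generator stands in for `jax.random` here.

```python
import numpy as np

input_channels, output_channels = 3, 4
filter_shape = (4, 4)

# Number of input values feeding each output value.
fan_in = input_channels * filter_shape[0] * filter_shape[1]   # 48
bound = np.sqrt(1.0 / fan_in)                                  # about 0.144

rng = np.random.default_rng(1234)
filters = rng.uniform(-bound, bound, size=(*filter_shape, input_channels, output_channels))
biases = rng.uniform(-bound, bound, size=(output_channels,))

print(fan_in, round(float(bound), 3), filters.shape, biases.shape)
```

Larger filters or more input channels shrink the bound, so the initial weights get smaller as each output sums over more inputs.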

#### `__call__` Method

To implement `__call__`, we use the JAX built-in function `jax.lax.conv_general_dilated`.  Except for the `dimension_numbers` argument, it's pretty easy to figure out what it's doing (although I'm still not clear on how it works internally; maybe I'll save that for another post).  Like the Keras `Conv2D` layer, it has additional input arguments that give you further control over the convolution.  I'm not including these additional arguments here because I'm trying to keep things as simple as possible.

In [674]:
@patch
def __call__(self: Conv2D, batch: jnp.ndarray):
    outputs = jax.lax.conv_general_dilated(
        lhs=batch,
        rhs=self.filters,
        window_strides=self.strides,
        padding=self.padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )   
    
    # Broadcasting adds one bias per output channel.
    outputs = outputs + self.biases

    return outputs

The `dimension_numbers` field is a three element tuple describing the shape of the input batch, the filters, and the output batch respectively.  We've adopted the default Keras layout, which means that for an input batch of images, the batch dimension is listed first, the image height second, the image width third, and the number of input channels fourth.  The dimension number for this is represented as `'NHWC'`.  

By default, the filters are arranged in a similar way although there is no batch dimension: the filter height comes first, the filter width second, the input channel count third, and the output channel count last.  The dimension number for this is `'HWIO'`.

Because `conv_general_dilated` does not handle the biases, they must be added in separately.  Like numpy, JAX has broadcasting rules that make this mixed-rank addition work properly.
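Here is a small standalone sketch of those layouts in action.  The all-ones arrays and the bias values are arbitrary choices of mine, picked so the result is easy to verify by hand: each output value sums 4 * 4 * 3 = 48 ones before the bias is added.

```python
import jax
import jax.numpy as jnp

batch = jnp.ones((2, 28, 28, 3))    # NHWC: 2 images, 28x28 pixels, 3 channels
filters = jnp.ones((4, 4, 3, 5))    # HWIO: 4x4 window, 3 inputs, 5 outputs
biases = jnp.arange(5.0)            # one bias per output channel

out = jax.lax.conv_general_dilated(
    lhs=batch,
    rhs=filters,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
out = out + biases                  # broadcast over the trailing channel axis

print(out.shape)        # (2, 25, 25, 5)
print(out[0, 0, 0])     # 48 plus each bias
```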

#### Adding to pytree Registry

Like the `Linear` and `Function` layers defined in my last post, `tree_flatten` and `tree_unflatten` methods must be defined.  

In [675]:
@patch
def tree_flatten(self: Conv2D) -> Tuple[List[jnp.ndarray], dict]:
    params = (self.filters, self.biases)
    metadata = {k: v for k, v in self.__dict__.items() if k not in {'biases', 'filters'}}
    return params, metadata

In [676]:
@patch(cls_method=True)
def tree_unflatten(cls: Conv2D, metadata: dict, params:List[jnp.ndarray]):
    # This assumes that each key-value pair in the metadata dict corresponds to a constructor argument.
    layer = cls(**metadata, build=False)
    layer.filters, layer.biases = params
    return layer

You'll notice that I'm trying to be a little more generic here.  Rather than use a tuple to store the layer's metadata, I'm using a dictionary that contains every data attribute, except the biases and filters.  Because these attributes correspond to arguments to the constructor, `metadata` is passed to `Conv2D`'s constructor in `tree_unflatten`.  

Now that JAX knows how to flatten and unflatten a `Conv2D` layer, it can formally be added to the pytree registry: 

In [677]:
_ = jax.tree_util.register_pytree_node_class(Conv2D)

Last time, `register_pytree_node_class` was used as a class decorator.  We can get away with using it as a function because that's all decorators are: special types of functions.  The reason the decorator approach could not be used here is that the cell with the class definition didn't include the `tree_flatten` and `tree_unflatten` methods.  If you add the decorator and run the cell, JAX will complain and point out that these methods are not defined.
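The same pattern can be sketched on a toy class: attach the two methods after the class body, register with a plain function call, and round-trip an instance through flatten and unflatten.  The `Scale` class is hypothetical and exists only for this example; I attach the methods by assignment here instead of using `patch`.

```python
import jax
import jax.numpy as jnp

class Scale:
    """Hypothetical toy layer: one array parameter plus one piece of metadata."""
    def __init__(self, factor, name="scale"):
        self.factor = factor
        self.name = name

def tree_flatten(self):
    # (array parameters, everything else)
    return (self.factor,), {"name": self.name}

def tree_unflatten(cls, metadata, params):
    (factor,) = params
    return cls(factor, **metadata)

# Attach the methods after the class body, then register with a plain call.
Scale.tree_flatten = tree_flatten
Scale.tree_unflatten = classmethod(tree_unflatten)
jax.tree_util.register_pytree_node_class(Scale)

layer = Scale(jnp.array([1.0, 2.0]), name="demo")
leaves, treedef = jax.tree_util.tree_flatten(layer)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

print(len(leaves), rebuilt.name)
```

Only the array ends up in `leaves`; the name travels with the tree definition and is handed back to `tree_unflatten` when the object is rebuilt.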

## Compare JAX and Keras

To compare the JAX and Keras convolutional layer implementations, the JAX `Conv2D` layer is initialized with data from the Keras layer, and applied to the same input batch that was used previously.

In [728]:
layer_jax = Conv2D(
    input_channels=layer_keras.input_spec.axes[-1],
    output_channels=layer_keras.filters,
    filter_shape=layer_keras.kernel_size,
    strides=layer_keras.strides,
    padding=layer_keras.padding
)

layer_jax.filters, layer_jax.biases = layer_keras.get_weights()

output_batch_jax = layer_jax(input_batch)

Because the following assertion passes, we can be reassured that both layers are calculating the same result.

In [729]:
assert np.max(np.abs(output_batch_keras - output_batch_jax)) < 1e-6

What about the calculation time?

In [730]:
%%timeit
layer_keras(input_batch)

272 µs ± 8.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [731]:
%%timeit
layer_jax(input_batch)

362 µs ± 5.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Not too bad: the JAX version is only about 90 microseconds slower than Keras.  Fortunately, the JAX version can go even faster after applying the `jax.jit` transformation.

In [732]:
layer_jax_jitted = jax.jit(layer_jax)

In [733]:
%%timeit
layer_jax_jitted(input_batch)

223 µs ± 1.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Now it's faster than Keras.  I realize you can't trust micro-benchmarks, but it's good to know that with respect to performance, JAX and Keras seem to be in the same ballpark.
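One more reason to be careful with JAX timings in particular: JAX dispatches work asynchronously, so a call can return before the result has actually been computed.  A sketch of a more careful measurement is below.  It assumes a standalone `convolve` function with all-ones inputs rather than the `Conv2D` class above, and the number it prints will depend entirely on your hardware.

```python
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def convolve(batch, filters):
    return jax.lax.conv_general_dilated(
        lhs=batch,
        rhs=filters,
        window_strides=(1, 1),
        padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )

batch = jnp.ones((2, 28, 28, 3))
filters = jnp.ones((4, 4, 3, 4))

# The first call triggers compilation, so keep it out of the measurement.
convolve(batch, filters).block_until_ready()

# block_until_ready() waits for the computation to finish before the clock stops.
seconds = timeit.timeit(
    lambda: convolve(batch, filters).block_until_ready(), number=100
)
print(f"{seconds / 100 * 1e6:.1f} µs per call")
```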

## Conclusion

In this post, I implemented the mechanics of a 2D convolutional layer in numpy, from first principles.  After this proved to be very inefficient compared to Keras, I built a simple layer in JAX and showed that it performed as well as the Keras version on a sample input batch.