# Convolutional Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: true

## Introduction

In this post, I'll start by implementing a basic convolutional layer in numpy and validating it against Keras. After that, I'll write a more efficient one using JAX.

## How Convolutional Layers Work

## Import Libraries

For now, I only need numpy and tensorflow.

In [18]:
import numpy as np
import tensorflow as tf

## Implementation from First Principles

This function filters a single image with every output filter, producing a rank 3 array. The bias term is not added here; it is added once for the whole batch in `filter_image_batch`. The two outer loops extract a rank 3 chunk from `image`, and the inner loop applies each output filter to that chunk. Once a chunk has been processed and the results placed in the output array `y`, the stride determines the position of the next chunk.

In [331]:
def filter_image(image, filters, strides):
    # image: (height, width, input channels)
    xm, xn, _ = image.shape
    # filters: (height, width, input channels, output channels)
    km, kn, ni, no = filters.shape
    sm, sn = strides

    # Output size when no padding is applied ('valid' padding)
    ym, yn = 1 + (xm - km) // sm, 1 + (xn - kn) // sn
    y = np.zeros((ym, yn, no))

    for iy, ix in enumerate(range(0, xm - km + 1, sm)):
        for jy, jx in enumerate(range(0, xn - kn + 1, sn)):
            # Apply each output filter to this chunk
            chunk = image[ix:ix + km, jx:jx + kn, :]
            for channel in range(no):
                y[iy, jy, channel] = np.sum(filters[:, :, :, channel] * chunk)

    return y
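
As a quick sanity check on the output-size arithmetic, here is a minimal sketch that compares the closed-form expression used for `ym` and `yn` with the number of positions the loops actually visit. The helper name `valid_output_size` is my own and is not part of the code above.

```python
def valid_output_size(x, k, s):
    """Output positions along one axis for input size x, filter size k, stride s (no padding)."""
    return 1 + (x - k) // s

# Compare the closed form with the number of steps the loop would take.
for x, k, s in [(28, 4, 1), (28, 4, 2), (7, 3, 2), (10, 3, 3)]:
    loop_count = len(range(0, x - k + 1, s))
    assert valid_output_size(x, k, s) == loop_count
    print(x, k, s, "->", loop_count)
```

For a 28 pixel axis, a 4 pixel filter, and a stride of 1, both expressions give 25 output positions.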

Once we have an algorithm to filter a single image, it can easily be extended to a batch of images.

In [332]:
def filter_image_batch(batch, filters, biases, strides):
    outputs = [filter_image(image, filters, strides) for image in batch]
    outputs = np.array(outputs)
    outputs = outputs + biases
    return outputs

First, a list comprehension applies the `filter_image` function defined above to each image in the batch. Next, `np.array` converts the resulting list to a rank 4 array. The line preceding the `return` statement,
```python
outputs = outputs + biases
```
looks like it shouldn't work, because the ranks don't match. Fortunately, numpy's broadcasting rules come to the rescue and do what we want: `biases` holds one value per output channel, so it is matched against the last axis of `outputs` and added at every batch and spatial position.
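
Here is a small standalone sketch of that broadcasting step. The shapes and bias values are made up for illustration and are not taken from the layer used later in this post.

```python
import numpy as np

outputs = np.zeros((2, 25, 25, 4))        # (batch, height, width, output channels)
biases = np.array([1.0, 2.0, 3.0, 4.0])   # one bias per output channel

# Shapes are aligned from the right, so (4,) is treated as (1, 1, 1, 4).
shifted = outputs + biases

assert shifted.shape == (2, 25, 25, 4)
# Every position in output channel c received biases[c].
assert np.all(shifted[..., 2] == 3.0)
```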
    

## Compare to Keras

To compare the numpy version to Keras, I'm going to create a `Conv2D` layer:

In [408]:
layer_keras = tf.keras.layers.Conv2D(filters=4, kernel_size=(4, 4), strides=(1,1), bias_initializer='he_uniform', padding='valid')

initialize a random batch of fake images:

In [398]:
input_batch = np.random.randn(2,28,28,3)

and filter the batch with the layer:

In [399]:
output_batch_keras = layer_keras(input_batch)

Next, the filters, biases, and strides are extracted from the layer. Strictly speaking, `strides` doesn't need to be read from the layer, since it was passed to the `Conv2D` constructor, but reading it back is less error-prone than repeating the value.

In [400]:
filters, biases = layer_keras.get_weights()
strides = layer_keras.strides

In [402]:
output_batch_numpy = filter_image_batch(input_batch, filters, biases, strides)

In [406]:
assert np.max(np.abs(output_batch_keras - output_batch_numpy)) < 1e-6 
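
The comparison above needs TensorFlow. As a second, independent check, the loop-based approach can also be compared against a vectorized numpy version. The sketch below is self-contained, so it repeats the loop logic under the name `filter_image_loop`. `filter_image_vectorized` is a hypothetical helper that I'm introducing only for this check; it relies on `np.lib.stride_tricks.sliding_window_view`, which needs numpy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def filter_image_loop(image, filters, strides):
    """Loop-based filtering, following the same approach as filter_image."""
    xm, xn, _ = image.shape
    km, kn, _, no = filters.shape
    sm, sn = strides
    y = np.zeros((1 + (xm - km) // sm, 1 + (xn - kn) // sn, no))
    for iy, ix in enumerate(range(0, xm - km + 1, sm)):
        for jy, jx in enumerate(range(0, xn - kn + 1, sn)):
            chunk = image[ix:ix + km, jx:jx + kn, :]
            for channel in range(no):
                y[iy, jy, channel] = np.sum(filters[:, :, :, channel] * chunk)
    return y

def filter_image_vectorized(image, filters, strides):
    """Same result without Python loops (hypothetical helper for this check)."""
    km, kn, _, _ = filters.shape
    sm, sn = strides
    # windows has shape (xm-km+1, xn-kn+1, ni, km, kn)
    windows = sliding_window_view(image, (km, kn), axis=(0, 1))
    windows = windows[::sm, ::sn]  # keep only the positions the strides visit
    # Sum over filter rows (k), filter columns (l), and input channels (c).
    return np.einsum('ijckl,klco->ijo', windows, filters)

rng = np.random.default_rng(0)
image = rng.standard_normal((9, 8, 3))
filters = rng.standard_normal((3, 2, 3, 4))

for strides in [(1, 1), (2, 1), (2, 3)]:
    slow = filter_image_loop(image, filters, strides)
    fast = filter_image_vectorized(image, filters, strides)
    np.testing.assert_allclose(slow, fast, atol=1e-10)
```

Both versions run in float64, so they should agree much more closely than the float32 Keras layer does.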

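### Convolutional Layer in JAX

In addition to the standard set of imports, I decided to import a function called `patch` from fast.ai's fastcore library. `patch` lets me attach methods to a class after it has been defined, so each method can live in its own notebook cell.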

In [623]:
import jax
import jax.numpy as jnp
from typing import Tuple, List
from fastcore.basics import patch
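
For readers who haven't used `patch`: it attaches a function to an existing class as a method, taking the target class from the type annotation on the first parameter. The plain-Python sketch below gets a similar effect by assigning the functions to the class by hand. `Greeter`, `init`, and `greet` are made-up names, and this only illustrates the effect; it is not fastcore's implementation.

```python
class Greeter:
    name: str

# Define the constructor and a method outside the class body...
def init(self, name):
    self.name = name

def greet(self):
    return f"hello, {self.name}"

# ...and attach them afterwards, which is the effect @patch provides.
Greeter.__init__ = init
Greeter.greet = greet

assert Greeter("JAX").greet() == "hello, JAX"
```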

In [601]:
class Conv2D: 
    filters: jnp.ndarray 
    biases: jnp.ndarray
    input_channels: int 
    output_channels: int 
    filter_shape: Tuple[int,int]
    strides: Tuple[int,int]
    padding: str
    seed: int

#### Constructor

In [602]:
@patch
def __init__(self: Conv2D, input_channels, output_channels, filter_shape=(2,2), strides=(1,1), padding='valid', seed=1234, build=True):
    
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.filter_shape = filter_shape
    self.strides = strides 
    self.padding = padding
    self.seed = seed
        
    if build:
        key = jax.random.PRNGKey(seed)
        fkey, bkey = jax.random.split(key)
            
        # Kaiming/He uniform, following the PyTorch Conv2d documentation:
        # sample from U(-sqrt(K), +sqrt(K)) with K = 1 / (input_channels * filter height * filter width)
        K = 1.0 / (input_channels * filter_shape[0] * filter_shape[1])
        sqrtK = jnp.sqrt(K)
        self.filters = jax.random.uniform(fkey, (*filter_shape, input_channels, output_channels), minval=-sqrtK, maxval=+sqrtK)
        self.biases = jax.random.uniform(bkey, (output_channels,), minval=-sqrtK, maxval=+sqrtK)
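
For reference, here is a minimal standalone sketch of the initialization rule as I read it in the PyTorch `Conv2d` documentation, where weights and biases are drawn from U(-sqrt(k), sqrt(k)) with k = 1 / (input channels × filter height × filter width). The variable names `fan_in` and `bound` are my own, and the channel counts and filter shape are arbitrary.

```python
import jax
import jax.numpy as jnp

input_channels, output_channels = 3, 4
filter_shape = (4, 4)

# Bound from the PyTorch Conv2d documentation: sqrt(k), with k = 1 / fan_in.
fan_in = input_channels * filter_shape[0] * filter_shape[1]
bound = 1.0 / jnp.sqrt(fan_in)

key = jax.random.PRNGKey(1234)
fkey, bkey = jax.random.split(key)
filters = jax.random.uniform(
    fkey, (*filter_shape, input_channels, output_channels), minval=-bound, maxval=bound
)
biases = jax.random.uniform(bkey, (output_channels,), minval=-bound, maxval=bound)

assert filters.shape == (4, 4, 3, 4)
assert biases.shape == (4,)
# Every sample lies inside [-bound, bound].
assert float(jnp.max(jnp.abs(filters))) <= float(bound)
```

The bound shrinks as the filter gets larger or the number of input channels grows, which keeps the scale of the layer's outputs roughly independent of its fan-in.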

#### The `__call__` Method

To implement `__call__`, we use the JAX function `jax.lax.conv_general_dilated`. Except for the `dimension_numbers` argument, it's pretty easy to figure out what each argument does (I'm still not clear on how the function works internally, so I'll save that for another post). Like the Keras `Conv2D` layer, it accepts additional arguments that give you further control over the convolution. I'm leaving them out here to keep things as simple as possible.

In [603]:
@patch
def __call__(self: Conv2D, batch: jnp.ndarray):
    outputs = jax.lax.conv_general_dilated(
        lhs=batch,
        rhs=self.filters,
        window_strides=self.strides,
        padding=self.padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )   
    
    # Add the biases; broadcasting matches them against the channel axis.
    outputs += self.biases

    return outputs

The `dimension_numbers` argument is a three-element tuple that describes the axis layout of the input batch, the filters, and the output batch, respectively. We've adopted the default Keras layout: for a batch of images, the batch dimension comes first, the image height second, the image width third, and the number of channels fourth. This layout is written as `'NHWC'`. The filters are arranged in a similar way, although there is no batch dimension: the filter height comes first, the filter width second, the input channel count third, and the output channel count last. As you can see, the layout string for this is `'HWIO'`.

Because `conv_general_dilated` does not handle biases, they must be added to the convolution outputs separately. Like numpy, JAX has broadcasting rules that make this mixed-rank addition work as expected.
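
To make the layouts concrete, here is a minimal standalone sketch that calls `jax.lax.conv_general_dilated` directly and checks the shapes. The batch size, image size, and bias values are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
bkey, fkey = jax.random.split(key)

batch = jax.random.normal(bkey, (2, 28, 28, 3))    # N, H, W, C
filters = jax.random.normal(fkey, (4, 4, 3, 4))    # H, W, I, O
biases = jnp.arange(4.0)                           # one per output channel

outputs = jax.lax.conv_general_dilated(
    lhs=batch,
    rhs=filters,
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
# 28 - 4 + 1 = 25 positions along each spatial axis, 4 output channels.
assert outputs.shape == (2, 25, 25, 4)

# Mixed-rank addition: (2, 25, 25, 4) + (4,)
with_bias = outputs + biases
assert with_bias.shape == outputs.shape
# Output channel 3 was shifted by biases[3] == 3.0 everywhere.
assert jnp.allclose(with_bias[..., 3] - outputs[..., 3], 3.0, atol=1e-4)
```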

#### Adding to pytree Registry

In [619]:
@patch
def tree_flatten(self: Conv2D) -> Tuple[List[jnp.ndarray], dict]:
    params = (self.filters, self.biases)
    metadata = {k: v for k, v in self.__dict__.items() if k not in {'biases', 'filters'}}
    return params, metadata

In [622]:
@patch(cls_method=True)
def tree_unflatten(cls: Conv2D, metadata: dict, params:List[jnp.ndarray]):
    # This assumes that each key-value pair in the metadata dict corresponds to a constructor argument.
    layer = cls(**metadata, build=False)
    layer.filters, layer.biases = params
    
    return layer

Finally, `Conv2D` can be added to the pytree registry with the following line of code:

In [606]:
_ = jax.tree_util.register_pytree_node_class(Conv2D)

You'll get an exception complaining about duplicate registration if you run this cell twice. Until I figure out how to remove a class from the pytree registry (if that's even possible), my workaround is to re-run, in order, the cells containing the class declaration, the methods, and finally the class registration.

In [610]:
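# jax.tree_flatten has been removed from recent JAX releases; use jax.tree_util instead.
leaves, treedef = jax.tree_util.tree_flatten(Conv2D(1,3))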
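
To show what the registration buys us, here is a minimal self-contained sketch using a toy class. `Scale` is a hypothetical layer that I made up for this example; it follows the same `tree_flatten` / `tree_unflatten` pattern, with the parameters as leaves and everything else as metadata.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scale:
    """Toy layer (hypothetical): multiplies its input by a learned factor."""

    def __init__(self, factor, name="scale"):
        self.factor = factor   # parameter: becomes a pytree leaf
        self.name = name       # metadata: stored alongside the tree structure

    def tree_flatten(self):
        return (self.factor,), (self.name,)

    @classmethod
    def tree_unflatten(cls, metadata, params):
        (factor,) = params
        (name,) = metadata
        return cls(factor, name=name)

layer = Scale(jnp.array([1.0, 2.0]), name="demo")

leaves, treedef = jax.tree_util.tree_flatten(layer)
assert len(leaves) == 1

# Rebuild the layer from its leaves and structure.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
assert rebuilt.name == "demo"

# tree_map transforms the parameters and leaves the metadata alone.
doubled = jax.tree_util.tree_map(lambda p: 2 * p, layer)
assert jnp.allclose(doubled.factor, jnp.array([2.0, 4.0]))
```

I used a tuple for the metadata in this sketch because the JAX documentation asks for the auxiliary data to be hashable, which can matter once the object passes through transformations such as `jax.jit`.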

## Conclusion

In [565]:
import inspect

signature = inspect.signature(Conv2D)
for name in signature.parameters:
    print(name)
    print(type(name))

input_channels
<class 'str'>
output_channels
<class 'str'>
filter_shape
<class 'str'>
strides
<class 'str'>
padding
<class 'str'>
seed
<class 'str'>
build
<class 'str'>
