# Convolutional Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: true

## Introduction

In this post, I'll start by implementing a basic convolutional layer using numpy and validate it against Keras.  After this, I'll write a more efficient one using JAX.

## How Convolutional Layers work

## Import Libraries

For now, I only need numpy and tensorflow.

In [18]:
import numpy as np
import tensorflow as tf

## Implementation from First Principles

This function filters a single image with every output filter, producing a rank 3 array. The bias term is not added here; it is added later, once for the whole batch. The first two levels of the nested loop extract a rank 3 chunk from `image`, while the third level applies each output filter to that chunk. After a chunk is processed and the results are placed in the output array `y`, the strides determine the position of the next chunk.

In [331]:
def filter_image(image, filters, strides):
    
    xm, xn, _ = image.shape           # image height, width, channels
    km, kn, ni, no = filters.shape    # filter height, width, input channels, output channels
    sm, sn = strides                  # vertical and horizontal stride

    # Output size when no padding is used ('valid' padding)
    ym, yn = 1 + (xm - km) // sm, 1 + (xn - kn) // sn
    y = np.zeros((ym, yn, no))

    for iy, ix in enumerate(range(0, xm - km + 1, sm)):
        for jy, jx in enumerate(range(0, xn - kn + 1, sn)):
            # Apply each output filter to this chunk (biases are added per batch)
            chunk = image[ix:ix + km, jx:jx + kn, :]
            for channel in range(no):
                y[iy, jy, channel] = np.sum(filters[:, :, :, channel] * chunk)
            
    return y
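
The output size `1 + (x - k)//s` used above can be checked against the number of positions the `range` loops actually visit. This is only a small sketch: the helper name `output_size` is mine and isn't used anywhere else in the post, and it assumes the image is at least as large as the filter.

```python
def output_size(x, k, s):
    """Number of window positions along one axis when no padding is used."""
    return 1 + (x - k) // s

# filter_image visits range(0, x - k + 1, s) along each axis,
# so the two counts should agree for any image size, filter size and stride.
for x, k, s in [(28, 4, 1), (28, 4, 2), (7, 3, 2), (10, 3, 3)]:
    positions = list(range(0, x - k + 1, s))
    assert output_size(x, k, s) == len(positions)
```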

Once we have an algorithm to filter a single image, it's very simple to filter a batch of images.  Here's the code:

In [332]:
def filter_image_batch(batch, filters, biases, strides):
    outputs = [filter_image(image, filters, strides) for image in batch]
    outputs = np.array(outputs)
    outputs = outputs + biases
    return outputs

First, a list comprehension applies the `filter_image` function defined above to each image in the batch.  Next, the list returned by the 
list comprehension is converted to a rank 4 array with the `np.array` function.  The line preceding the `return` statement, 
```python
outputs = outputs + biases
```
seems like it shouldn't work, because the ranks don't match.  Fortunately, numpy's broadcasting rules come to the rescue and do what we want.
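
Here is a minimal sketch of that broadcasting step on its own. The shapes and bias values are made up for illustration; the point is only that a rank 1 array of length `channels` lines up with the last axis of a rank 4 array.

```python
import numpy as np

outputs = np.zeros((2, 25, 25, 4))        # rank 4: (batch, height, width, channels)
biases = np.array([1.0, 2.0, 3.0, 4.0])   # rank 1: (channels,)

# Shapes are aligned from the right, so biases behaves like shape (1, 1, 1, 4)
# and biases[c] is added to every element of channel c.
shifted = outputs + biases

assert shifted.shape == (2, 25, 25, 4)
assert np.all(shifted[..., 2] == 3.0)
```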
    

## Compare to Keras

In [408]:
layer_keras = tf.keras.layers.Conv2D(filters=4, kernel_size=(4, 4), strides=(1,1), bias_initializer='he_uniform', padding='valid')

In [398]:
input_batch = np.random.randn(2,28,28,3)

In [399]:
output_batch_keras = layer_keras(input_batch)

In [400]:
filters, biases = layer_keras.get_weights()
strides = layer_keras.strides

In [402]:
output_batch_numpy = filter_image_batch(input_batch, filters, biases, strides)

In [406]:
assert np.max(np.abs(output_batch_keras - output_batch_numpy)) < 1e-6 
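
If TensorFlow isn't available, a similar sanity check can be done with NumPy alone by comparing the loop algorithm against a second, independent formulation. The sketch below is not a replacement for the Keras comparison. `conv_loop` restates the loop from `filter_image`, while `conv_windows` is a hypothetical helper built on `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20 or later) and `np.einsum`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv_loop(image, filters, strides):
    """Reference: explicit loops over output positions and output channels."""
    xm, xn, _ = image.shape
    km, kn, _, no = filters.shape
    sm, sn = strides
    y = np.zeros((1 + (xm - km) // sm, 1 + (xn - kn) // sn, no))
    for iy, ix in enumerate(range(0, xm - km + 1, sm)):
        for jy, jx in enumerate(range(0, xn - kn + 1, sn)):
            chunk = image[ix:ix + km, jx:jx + kn, :]
            for c in range(no):
                y[iy, jy, c] = np.sum(filters[:, :, :, c] * chunk)
    return y

def conv_windows(image, filters, strides):
    """Same computation using window views and a single einsum."""
    km, kn, _, _ = filters.shape
    sm, sn = strides
    # All stride-1 windows have shape (H-km+1, W-kn+1, C, km, kn);
    # slicing with the strides keeps only the positions the loop visits.
    windows = sliding_window_view(image, (km, kn), axis=(0, 1))[::sm, ::sn]
    return np.einsum('yxcij,ijco->yxo', windows, filters)

rng = np.random.default_rng(0)
image = rng.standard_normal((9, 8, 3))
filters = rng.standard_normal((3, 2, 3, 4))

for strides in [(1, 1), (2, 1), (2, 3)]:
    a = conv_loop(image, filters, strides)
    b = conv_windows(image, filters, strides)
    assert a.shape == b.shape
    assert np.allclose(a, b)
```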

## Convolutional Layer in JAX

In [456]:
import jax
import jax.numpy as jnp
from fastcore.basics import patch, store_attr
from typing import Tuple

In [526]:
class Conv2D: 
    filters: jnp.ndarray 
    biases: jnp.ndarray
    input_channels: int 
    output_channels: int 
    filter_shape: Tuple[int,int]
    strides: Tuple[int,int]
    padding: str
    seed: int

In [527]:
@patch
def __init__(self: Conv2D, input_channels, output_channels, filter_shape=(2,2), strides=(1,1), padding='valid', seed=1234, build=True):
    
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.filter_shape = filter_shape
    self.strides = strides 
    self.padding = padding
    self.seed = seed
        
    if build:
        key = jax.random.PRNGKey(seed)
        fkey, bkey = jax.random.split(key)
            
        # Uniform init with bound sqrt(K), where K is the fan-in.
        # Note: the PyTorch Conv2d documentation uses the bound 1/sqrt(K) instead.
        K = input_channels * filter_shape[0] * filter_shape[1]
        sqrtK = jnp.sqrt(K)
        self.filters = jax.random.uniform(fkey, (*filter_shape, input_channels, output_channels), minval=-sqrtK, maxval=+sqrtK)
        self.biases = jax.random.uniform(bkey, (output_channels,), minval=-sqrtK, maxval=+sqrtK)   
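
For comparison, the PyTorch `Conv2d` documentation describes the weights as drawn from a uniform distribution on `(-sqrt(k), sqrt(k))` with `k = groups / (C_in * kernel_height * kernel_width)`. With `groups=1` that bound is `1/sqrt(fan_in)`, whereas the constructor above uses `sqrt(fan_in)`. Here is a minimal NumPy sketch of the documented bound; `conv_init_bound` is a hypothetical helper, and I'm assuming `groups=1`.

```python
import numpy as np

def conv_init_bound(input_channels, filter_shape):
    """Bound sqrt(k) with k = 1 / fan_in (PyTorch Conv2d docs, assuming groups=1)."""
    fan_in = input_channels * filter_shape[0] * filter_shape[1]
    return 1.0 / np.sqrt(fan_in)

bound = conv_init_bound(input_channels=2, filter_shape=(2, 2))

# Sample filters in HWIO layout: (height, width, input channels, output channels)
rng = np.random.default_rng(1234)
sample_filters = rng.uniform(-bound, bound, size=(2, 2, 2, 3))

assert np.all(np.abs(sample_filters) <= bound)
```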

The `__call__` method uses a JAX function called `conv_general_dilated`.  Except for the `dimension_numbers` argument, it's pretty easy to figure out what it's doing.  Like the Keras `Conv2D` layer, it has additional input arguments that give you further control over the convolution.  I'm not including these additional arguments here because I'm trying to keep things as simple as possible.  

In [542]:
@patch
def __call__(self: Conv2D, batch: jnp.ndarray):
    outputs = jax.lax.conv_general_dilated(
        lhs=batch,
        rhs=self.filters,
        window_strides=self.strides,
        padding=self.padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )   
    
    # conv_general_dilated does not add biases; broadcasting adds biases[c] to channel c.
    outputs = outputs + self.biases

    return outputs

The `dimension_numbers` argument is a three-element tuple that defines the axis layout of the input batch, the filters, and the output batch, respectively.  We've adopted the default Keras layout: for an input batch of images, the batch dimension comes first, the image height second, the image width third, and the number of input channels fourth.  This layout is written as `'NHWC'`.  The filters are arranged in a similar way, although there is no batch dimension: the filter height comes first, the filter width second, the input channel count third, and the output channel count last.  The dimension string for this layout is `'HWIO'`.
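
To make the layout strings concrete, here is a small NumPy sketch with made-up shapes. It also shows how the same arrays would look in the channels-first layouts `'NCHW'` and `'OIHW'`, which this post doesn't use, purely to illustrate what the letters mean.

```python
import numpy as np

batch_nhwc = np.zeros((2, 28, 28, 3))    # N=2 images, H=28, W=28, C=3 channels
filters_hwio = np.zeros((4, 4, 3, 5))    # H=4, W=4, I=3 input channels, O=5 output channels

# The same data rearranged into channels-first layouts
batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)        # NHWC -> NCHW
filters_oihw = filters_hwio.transpose(3, 2, 0, 1)    # HWIO -> OIHW

assert batch_nchw.shape == (2, 3, 28, 28)
assert filters_oihw.shape == (5, 3, 4, 4)

# The channel axis of the batch must match the input-channel axis of the filters
assert batch_nhwc.shape[3] == filters_hwio.shape[2]
```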

Because `conv_general_dilated` does not handle biases, they must be added to the convolution outputs afterwards.  Like numpy, JAX has broadcasting rules that make this mixed-rank addition work as expected.
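
The following sketch checks `jax.lax.conv_general_dilated` plus a broadcast bias against explicit NumPy loops. The shapes, the random data and the `1e-4` tolerance are my assumptions (float32 on CPU), and it calls the JAX function directly instead of going through the `Conv2D` class.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
batch = rng.standard_normal((2, 8, 8, 3)).astype(np.float32)     # NHWC
filters = rng.standard_normal((3, 3, 3, 4)).astype(np.float32)   # HWIO
biases = rng.standard_normal(4).astype(np.float32)

outputs_jax = jax.lax.conv_general_dilated(
    lhs=jnp.asarray(batch),
    rhs=jnp.asarray(filters),
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
) + jnp.asarray(biases)  # rank 1 biases broadcast over the channel axis

# Reference: explicit loops with stride 1 and no padding (8 - 3 + 1 = 6 positions)
outputs_ref = np.zeros((2, 6, 6, 4), dtype=np.float32)
for n in range(2):
    for i in range(6):
        for j in range(6):
            chunk = batch[n, i:i + 3, j:j + 3, :]
            for c in range(4):
                outputs_ref[n, i, j, c] = np.sum(filters[:, :, :, c] * chunk) + biases[c]

assert outputs_jax.shape == (2, 6, 6, 4)
assert np.allclose(np.asarray(outputs_jax), outputs_ref, atol=1e-4)
```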

In [530]:
@patch
def tree_flatten(self: Conv2D):
    params = (self.filters, self.biases)
    metadata = {k: v for k, v in self.__dict__.items() if k not in {'biases', 'filters'}}
    return params, metadata

In [531]:
@patch(cls_method=True)
def tree_unflatten(cls: Conv2D, metadata, params):
    
    layer = cls(**metadata, build=False)
    layer.filters, layer.biases = params
    
    return layer

Finally, register the `Conv2D` class as a pytree.  `register_pytree_node_class` expects the class you pass in to have a `tree_flatten` method and a `tree_unflatten` class method.  If it doesn't, JAX will raise an exception.  You'll get a different exception if you try to register the same class twice.  I haven't figured out how to remove a class from the pytree registry yet.  When working in Jupyter notebooks, my workaround is to re-run the cell that has the class definition.  My guess is that this error doesn't show up when working in a text editor or IDE.
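
To see the moving parts in isolation, here is a minimal sketch with a toy class. `Scale` is a hypothetical example and not part of this post's layer. It uses `register_pytree_node_class` as a decorator and keeps the metadata in a tuple so that it is hashable.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Scale:
    """Toy layer: one array parameter plus one piece of static metadata."""

    def __init__(self, weights, name):
        self.weights = weights
        self.name = name

    def tree_flatten(self):
        children = (self.weights,)   # array leaves that JAX transforms operate on
        aux_data = (self.name,)      # static metadata, kept hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)

layer = Scale(jnp.ones(3), name='scale')

# Round trip: flatten into leaves plus a tree definition, then rebuild
leaves, treedef = jax.tree_util.tree_flatten(layer)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# tree_map applies a function to every leaf and returns a new Scale
doubled = jax.tree_util.tree_map(lambda w: 2 * w, layer)

assert isinstance(rebuilt, Scale) and rebuilt.name == 'scale'
assert isinstance(doubled, Scale) and float(doubled.weights[0]) == 2.0
```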

In [532]:
jax.tree_util.register_pytree_node_class(Conv2D)

__main__.Conv2D

In [533]:
cc = Conv2D(2,3)

In [534]:
a, b = jax.tree_util.tree_flatten(cc)

In [535]:
dd = jax.tree_util.tree_unflatten(b,a)

In [536]:
cc.__dict__

{'input_channels': 2,
 'output_channels': 3,
 'filter_shape': (2, 2),
 'strides': (1, 1),
 'padding': 'valid',
 'seed': 1234,
 'filters': DeviceArray([[[[ 2.0073094 ,  1.6691824 , -0.52152306],
                [ 0.01462327,  0.3469043 ,  2.7000263 ]],
 
               [[-2.5442414 ,  2.7617264 , -0.4976167 ],
                [ 0.39192185, -0.4488464 , -1.8725433 ]]],
 
 
              [[[ 0.5440059 ,  1.5025641 , -2.4879715 ],
                [-1.8504584 , -0.22452806,  2.825495  ]],
 
               [[-0.43859494, -1.9360245 , -1.3612974 ],
                [-0.29709414, -1.0776254 ,  0.91766566]]]], dtype=float32),
 'biases': DeviceArray([-1.5805513,  1.9187174,  2.2878523], dtype=float32)}

In [537]:
dd.__dict__

{'input_channels': 2,
 'output_channels': 3,
 'filter_shape': (2, 2),
 'strides': (1, 1),
 'padding': 'valid',
 'seed': 1234,
 'filters': DeviceArray([[[[ 2.0073094 ,  1.6691824 , -0.52152306],
                [ 0.01462327,  0.3469043 ,  2.7000263 ]],
 
               [[-2.5442414 ,  2.7617264 , -0.4976167 ],
                [ 0.39192185, -0.4488464 , -1.8725433 ]]],
 
 
              [[[ 0.5440059 ,  1.5025641 , -2.4879715 ],
                [-1.8504584 , -0.22452806,  2.825495  ]],
 
               [[-0.43859494, -1.9360245 , -1.3612974 ],
                [-0.29709414, -1.0776254 ,  0.91766566]]]], dtype=float32),
 'biases': DeviceArray([-1.5805513,  1.9187174,  2.2878523], dtype=float32)}

In [416]:
layer = Conv2D(3,4)

In [417]:
layer.tree_flatten()

((DeviceArray([[[[ 3.4430575 , -0.5913736 ,  1.7550349 ,  2.4153824 ],
                 [ 2.8320181 , -0.85370946, -1.1039739 , -0.12423529],
                 [ 2.7139978 , -0.70104486, -3.4376502 ,  2.6779363 ]],
  
                [[-2.0729692 ,  0.3460869 ,  1.6827095 , -1.5708281 ],
                 [ 3.45853   , -0.8395171 , -1.9809046 ,  2.2974694 ],
                 [ 0.7820349 ,  1.3051779 , -0.6155949 , -1.9725654 ]]],
  
  
               [[[ 0.04273734,  0.6269528 ,  0.33181936, -0.38829234],
                 [-0.02845082, -0.33282036, -3.160122  ,  1.999361  ],
                 [-0.4157628 ,  1.7358813 ,  2.6174173 ,  3.3147407 ]],
  
                [[-2.8882236 ,  0.46306247, -2.8370042 ,  3.4490387 ],
                 [ 2.8144124 , -0.44645762, -1.2920336 ,  2.0018265 ],
                 [ 3.3088658 ,  2.9831839 , -0.35555673,  1.3202416 ]]]],            dtype=float32),
  DeviceArray([-1.9357721 ,  0.32821512,  2.8020353 , -3.1447244 ], dtype=float32)),
 (3, 4, (1, 1), '

In [268]:
x = jnp.zeros((2,4,4,2))
b = jnp.array([1,2])
y = x + b

In [271]:
y[1,:,:,1]

DeviceArray([[2., 2., 2., 2.],
             [2., 2., 2., 2.],
             [2., 2., 2., 2.],
             [2., 2., 2., 2.]], dtype=float32)

In [354]:
??tf.keras.layers.Conv2D

Init signature: tf.keras.layers.Conv2D(*args, **kwargs)
Source:        
class Conv2D(Conv):
  """2D convolution layer (e.g. spatial convolution over images).

  This layer creates a convolution kernel that is convolved
  with the layer input to produce a tensor of
  outputs. If `use_bias` is True,
  a bias vector is created and added to the outputs. Finally, if
  `activation` is not `None`, it is applied to the outputs as well.

  When using this layer as the first layer in a model,
  provide the keyword argument `input_shape`
  (tuple of integers or `None`, does not include the sample axis),
  e.g.