# Convolutional Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: true

## Introduction

In this post, I'll start by implementing a basic convolutional layer using numpy and validate it against Keras.  After this, I'll write a more efficient one using JAX.

## How Convolutional Layers work

## Import Libraries

For now, I only need numpy and tensorflow.

In [18]:
import numpy as np
import tensorflow as tf

## Implementation from First Principles

This function filters a single image with every output filter and adds the bias term, resulting in a rank 3 array.  The first two levels of the nested loop extract a rank 3 chunk from `image`, while the third level of the nested loop performs the filtering and biasing.  After a chunk is processed and the results placed in the output array `y`, the filter shape and stride is used to calculate the next chunk position.  

In [331]:
def filter_image(image, filters, strides):
    
    xm, xn, _  = image.shape 
    
    km, kn, ni, no = filters.shape 
    
    
    sm, sn = strides
    #ym, yn = 1 + ((xm - km + 1)//sm), 1 + ((xn - kn + 1)//sn)
    ym, yn = 1 + ((xm - km)//sm), 1 + ((xn - kn)//sn) 
    y = np.zeros((ym, yn, no))

    for iy, ix in enumerate(range(0, xm-km+1, sm)):
        for jy, jx in enumerate(range(0, xn-kn+1, sn)):
            # Apply each output filter and bias term to this chunk
            chunk = image[ix:ix+km,jx:jx+kn,:]
            for channel in range(no):
                y[iy,jy,channel] = np.sum(filters[:,:,:,channel] * chunk)# + biases[channel]
            
    return y

Once we have an algorithm to filter a single image, it's very simple to filter a batch of images.  Here's the code:

In [332]:
def filter_image_batch(batch, filters, biases, strides):
    outputs = [filter_image(image, filters, strides) for image in batch]
    outputs = np.array(outputs)
    outputs = outputs + biases
    return outputs

First, the list comprehension passed to the `np.array` function, i.e.

```python
[filter_image(image, filters, biases, strides) for image in batch]
```
applies `filter_image` function definied above to each image in the batch, resulting in a list of filtered images.  By passing this list to the `np.array` function, it's converted to an `ndarray` with a leading batch dimension.
    

## Compare to Keras

In [408]:
layer_keras = tf.keras.layers.Conv2D(filters=4, kernel_size=(4, 4), strides=(1,1), bias_initializer='he_uniform', padding='valid')

In [398]:
input_batch = np.random.randn(2,28,28,3)

In [399]:
output_batch_keras = layer_keras(input_batch)

In [400]:
filters, biases = layer_keras.get_weights()
strides = layer_keras.strides

In [402]:
output_batch_numpy = filter_image_batch(input_batch, filters, biases, strides)

In [406]:
assert np.max(np.abs(output_batch_keras - output_batch_numpy)) < 1e-6 

### Convolutional Layer in JAX

In [226]:
import jax
import jax.numpy as jnp
from fastcore.basics import patch
from typing import Tuple

In [442]:
class Conv2D: 
    filters: jnp.ndarray 
    biases: jnp.ndarray
    input_channels: int 
    output_channels: int 
    filter_shape: Tuple[int,int]
    strides: Tuple[int,int]
    padding: str
    seed: int

In [441]:
jax.tree_util.register_pytree_node_class(Conv2D)

ValueError: Duplicate custom PyTreeDef type registration for <class '__main__.Conv2D'>.

In [444]:
@patch
def __init__(self: Conv2D, input_channels, output_channels, filter_shape=(2,2), strides=(1,1), padding='valid', build=True, seed=1234):
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.strides = strides 
    self.padding = padding
    self.build = build 
    self.seed = seed
        
    if build:
        key = jax.random.PRNGKey(seed)
        fkey, bkey = jax.random.split(key)
            
        # kaiming/he uniform...
        K = input_channels * filter_shape[0] * filter_shape[1]
        sqrtK = jnp.sqrt(K)
        self.filters = jax.random.uniform(fkey, (*filter_shape, input_channels, output_channels), minval=-sqrtK, maxval=+sqrtK)
        self.biases = jax.random.uniform(bkey, (output_channels,), minval=-sqrtK, maxval=+sqrtK)   

In [411]:
@patch
def __call__(self: Conv2D, batch: jnp.ndarray):
    filtered_outputs = jax.lax.conv_general_dilated(
        lhs=batch,
        rhs=self.filters,
        window_strides=self.strides,
        padding=self.padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC')
    )   
    
    # This uses the broadcasting rules.
    outputs = filtered_outputs + biases
        
    # Need to add biases...
    return filtered_outputs 

In [437]:
@patch
def tree_flatten(self: Conv2D):
    params = (self.filters, self.biases)
    metadata = {
        'input_channels': self.input_channels,
        'output_channels': self.output_channels,
        'filter_shape': self.filter_shape,
        'strides': self.strides,
        'padding': self.padding,
        'seed': self.seed,
        'build': False
    }
    return params, metadata

In [438]:
@patch(cls_method=True)
def tree_unflatten(cls: Conv2D, metadata, params):
    
    layer = cls(**metadata)
    layer.filters, layers.biases = layer.params
    
    return layer

In [449]:
cc = Conv2D(2,3)

In [451]:
a, b = jax.tree_flatten(cc)

In [452]:
dd = jax.tree_unflatten(b,a)

In [453]:
cc.__dict__

{'input_channels': 2,
 'output_channels': 3,
 'strides': (1, 1),
 'padding': 'valid',
 'build': True,
 'seed': 1234,
 'filters': DeviceArray([[[[ 2.0073094 ,  1.6691824 , -0.52152306],
                [ 0.01462327,  0.3469043 ,  2.7000263 ]],
 
               [[-2.5442414 ,  2.7617264 , -0.4976167 ],
                [ 0.39192185, -0.4488464 , -1.8725433 ]]],
 
 
              [[[ 0.5440059 ,  1.5025641 , -2.4879715 ],
                [-1.8504584 , -0.22452806,  2.825495  ]],
 
               [[-0.43859494, -1.9360245 , -1.3612974 ],
                [-0.29709414, -1.0776254 ,  0.91766566]]]], dtype=float32),
 'biases': DeviceArray([-1.5805513,  1.9187174,  2.2878523], dtype=float32)}

In [455]:
dd == cc

True

In [432]:
jax.tree_util.register_pytree_node(Conv2D, op.methodcaller('tree_flatten'), Conv2D.tree_unflatten)

ValueError: Duplicate custom PyTreeDef type registration for <class '__main__.Conv2D'>.

In [430]:
import operator as op

In [431]:
op.methodcaller??

[0;31mInit signature:[0m [0mop[0m[0;34m.[0m[0mmethodcaller[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m/[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
methodcaller(name, ...) --> methodcaller object

Return a callable object that calls the given method on its operand.
After f = methodcaller('name'), the call f(r) returns r.name().
After g = methodcaller('name', 'date', foo=1), the call g(r) returns
r.name('date', foo=1).
[0;31mSource:[0m        
[0;32mclass[0m [0mmethodcaller[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34m"""[0m
[0;34m    Return a callable object that calls the given method on its operand.[0m
[0;34m    After f = methodcaller('name'), the call f(r) returns r.name().[0m
[0;34m    After g = methodcaller('name', 'date', foo=1), the call g(r) returns[0m
[0;34m    r.name('date', foo=1).[0m
[0;34m    """[0m[0;34m[0m
[0;34m[0m    [0m__slots__[0m

In [427]:
??jax.tree_util.register_pytree_node_class

[0;31mSignature:[0m [0mjax[0m[0;34m.[0m[0mtree_util[0m[0;34m.[0m[0mregister_pytree_node_class[0m[0;34m([0m[0mcls[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;32mdef[0m [0mregister_pytree_node_class[0m[0;34m([0m[0mcls[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""Extends the set of types that are considered internal nodes in pytrees.[0m
[0;34m[0m
[0;34m  This function is a thin wrapper around ``register_pytree_node``, and provides[0m
[0;34m  a class-oriented interface::[0m
[0;34m[0m
[0;34m    @register_pytree_node_class[0m
[0;34m    class Special:[0m
[0;34m      def __init__(self, x, y):[0m
[0;34m        self.x = x[0m
[0;34m        self.y = y[0m
[0;34m      def tree_flatten(self):[0m
[0;34m        return ((self.x, self.y), None)[0m
[0;34m      @classmethod[0m
[0;34m      def tree_unflatten(cls, aux_data, children):[0m
[0;34m        return cls(*children)[0m
[0;34m  """[0m[0;34m[0m
[0;34m[0m  [

In [416]:
layer = Conv2D(3,4)

In [417]:
layer.tree_flatten()

((DeviceArray([[[[ 3.4430575 , -0.5913736 ,  1.7550349 ,  2.4153824 ],
                 [ 2.8320181 , -0.85370946, -1.1039739 , -0.12423529],
                 [ 2.7139978 , -0.70104486, -3.4376502 ,  2.6779363 ]],
  
                [[-2.0729692 ,  0.3460869 ,  1.6827095 , -1.5708281 ],
                 [ 3.45853   , -0.8395171 , -1.9809046 ,  2.2974694 ],
                 [ 0.7820349 ,  1.3051779 , -0.6155949 , -1.9725654 ]]],
  
  
               [[[ 0.04273734,  0.6269528 ,  0.33181936, -0.38829234],
                 [-0.02845082, -0.33282036, -3.160122  ,  1.999361  ],
                 [-0.4157628 ,  1.7358813 ,  2.6174173 ,  3.3147407 ]],
  
                [[-2.8882236 ,  0.46306247, -2.8370042 ,  3.4490387 ],
                 [ 2.8144124 , -0.44645762, -1.2920336 ,  2.0018265 ],
                 [ 3.3088658 ,  2.9831839 , -0.35555673,  1.3202416 ]]]],            dtype=float32),
  DeviceArray([-1.9357721 ,  0.32821512,  2.8020353 , -3.1447244 ], dtype=float32)),
 (3, 4, (1, 1), '

In [268]:
x = jnp.zeros((2,4,4,2))
b = jnp.array([1,2])
y = x + b

In [271]:
y[1,:,:,1]

DeviceArray([[2., 2., 2., 2.],
             [2., 2., 2., 2.],
             [2., 2., 2., 2.],
             [2., 2., 2., 2.]], dtype=float32)

In [354]:
??tf.keras.layers.Conv2D

[0;31mInit signature:[0m [0mtf[0m[0;34m.[0m[0mkeras[0m[0;34m.[0m[0mlayers[0m[0;34m.[0m[0mConv2D[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m        
[0;32mclass[0m [0mConv2D[0m[0;34m([0m[0mConv[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""2D convolution layer (e.g. spatial convolution over images).[0m
[0;34m[0m
[0;34m  This layer creates a convolution kernel that is convolved[0m
[0;34m  with the layer input to produce a tensor of[0m
[0;34m  outputs. If `use_bias` is True,[0m
[0;34m  a bias vector is created and added to the outputs. Finally, if[0m
[0;34m  `activation` is not `None`, it is applied to the outputs as well.[0m
[0;34m[0m
[0;34m  When using this layer as the first layer in a model,[0m
[0;34m  provide the keyword argument `input_shape`[0m
[0;34m  (tuple of integers or `None`, does not include the sample axis),[0m
[0;34m  e.g.