# Convolutional Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: true

## Introduction

In this post, I'll start by implementing a basic convolutional layer in NumPy and validating it against Keras. After that, I'll move on to writing a more efficient one using JAX.

## Import Libraries

For now, I only need NumPy and TensorFlow.

```python
import numpy as np
import tensorflow as tf
```

Here's a single Keras convolutional layer with 4 filters, a 2-by-2 kernel, a stride of 2 in each direction, and no padding (`'valid'`).

```python
layer = tf.keras.layers.Conv2D(filters=4, kernel_size=(2, 2), strides=(2, 2), padding='valid')
```

For this experiment, I don't care what the input to the layer is, so I'll just create a 1-element batch containing a random 28-by-28 image with 3 channels and apply `layer` to it.

```python
inputs = np.random.randn(1, 28, 28, 3)
outputs = layer(inputs)
```



Let's check the shape of `inputs` and `outputs`.  

```python
print(f'Feature Mapping:  {inputs.shape} -> {outputs.shape}')
```

```
Feature Mapping:  (1, 28, 28, 3) -> (1, 14, 14, 4)
```


So `outputs` is a 1-element batch containing a 14-by-14 array with 4 channels, one per filter.
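
As a quick check on that shape: with `'valid'` padding, a window of length `k` moved with stride `s` along an axis of length `n` can start at positions `0, s, 2s, ...` up to `n - k`, which gives `(n - k) // s + 1` positions. The small helper below (a hypothetical name, not part of Keras) computes this and compares it against a direct count of the start positions.

```python
def valid_output_size(n, k, s):
    """Number of window positions along one axis with 'valid' (no) padding.

    n: input length, k: kernel length, s: stride.
    Windows start at 0, s, 2s, ... and the last start must be <= n - k.
    """
    return (n - k) // s + 1


for n, k, s in [(28, 2, 2), (29, 2, 2), (5, 2, 1), (14, 2, 2)]:
    # Count the window start positions directly and compare with the formula
    direct = len(range(0, n - k + 1, s))
    assert valid_output_size(n, k, s) == direct
    print(n, k, s, '->', valid_output_size(n, k, s))
```

For the layer above, `valid_output_size(28, 2, 2)` is 14 along both spatial axes, matching the Keras output shape.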

Because I'm calling the layer directly, `outputs` is already the convolution output. Inside a model that follows the convolution with a max-pooling layer, only the final (pooling) output comes out of the model, and it's not immediately obvious how to get the intermediate convolution output. The recommended way to extract the outputs of all the layers in a model is to use the so-called *Functional* API.
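
Here's a minimal sketch of that approach, assuming a `Sequential` model that wraps the convolution above and follows it with a 2-by-2 `MaxPooling2D` layer (the pooling parameters are my assumption). A second `tf.keras.Model` built from the same inputs and the per-layer outputs shares the original layers and weights, and returns every intermediate feature map in one call.

```python
import numpy as np
import tensorflow as tf

# Assumed model: the convolution from above followed by 2x2 max pooling
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 3)),
    tf.keras.layers.Conv2D(filters=4, kernel_size=(2, 2), strides=(2, 2), padding='valid'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
])

# Functional API: same layers and weights, but one output per layer
extractor = tf.keras.Model(inputs=model.inputs,
                           outputs=[lyr.output for lyr in model.layers])

x = np.random.randn(1, 28, 28, 3).astype(np.float32)
conv_out, pool_out = extractor(x)
print(conv_out.shape, pool_out.shape)
```

The convolution output is 14-by-14 with 4 channels, and the pooling output halves each spatial dimension to 7-by-7.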

## Convolutional Layer

For now, `filter_image_v1` is a function of a single multi-channel image `image`, a single output filter `kernel` with the same number of channels as the image, and `strides`, a pair of integers giving the number of positions the window moves vertically and horizontally at each step.

I don't have a good technical reason for making the output array `y` one-dimensional and reshaping it at the end; it just looks cleaner than some of the alternatives I played around with.

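```python
def filter_image_v1(image, kernel, strides):
    """Correlate one multi-channel kernel with one image ('valid' padding)."""
    xm, xn, _ = image.shape
    km, kn, _ = kernel.shape

    sm, sn = strides
    # Windows start at 0, s, 2s, ..., up to n - k, so there are (n - k)//s + 1 of them
    ym, yn = 1 + (xm - km) // sm, 1 + (xn - kn) // sn

    y = np.zeros(ym * yn)
    k = 0
    for i in range(0, xm - km + 1, sm):
        for j in range(0, xn - kn + 1, sn):
            # Multiply the kernel by the window and sum over rows, columns, and channels
            y[k] = np.sum(kernel * image[i:i+km, j:j+kn, :])
            k += 1

    return np.reshape(y, (ym, yn))
```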
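
The double loop can also be written without Python loops. The sketch below is one way to do it, assuming NumPy 1.20 or later for `numpy.lib.stride_tricks.sliding_window_view`; `filter_image_vectorized` and `filter_image_loop` are hypothetical names of mine, and the loop version performs the same window-by-window computation as `filter_image_v1`. The odd image size and unequal strides are deliberate, since those are the cases where an off-by-one in the output size would show up.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def filter_image_vectorized(image, kernel, strides):
    """Hypothetical loop-free version of filter_image_v1.

    image: (xm, xn, c), kernel: (km, kn, c), strides: (sm, sn).
    """
    km, kn, _ = kernel.shape
    sm, sn = strides
    # Every km-by-kn window at stride 1; shape (xm-km+1, xn-kn+1, c, km, kn)
    windows = sliding_window_view(image, (km, kn), axis=(0, 1))
    # Keep every sm-th window row and every sn-th window column
    windows = windows[::sm, ::sn]
    # windows[i, j, c, a, b] == image[i*sm + a, j*sn + b, c]; sum over a, b, c
    return np.einsum('ijcab,abc->ij', windows, kernel)


def filter_image_loop(image, kernel, strides):
    """Window-by-window reference: the same computation as filter_image_v1."""
    xm, xn, _ = image.shape
    km, kn, _ = kernel.shape
    sm, sn = strides
    return np.array([[np.sum(kernel * image[i:i+km, j:j+kn, :])
                      for j in range(0, xn - kn + 1, sn)]
                     for i in range(0, xm - km + 1, sm)])


rng = np.random.default_rng(0)
image = rng.standard_normal((7, 9, 3))
kernel = rng.standard_normal((3, 2, 3))

fast = filter_image_vectorized(image, kernel, (2, 3))
slow = filter_image_loop(image, kernel, (2, 3))
print(fast.shape)               # (3, 3)
print(np.allclose(fast, slow))  # True
```

`sliding_window_view` returns a read-only view rather than a copy, so the only real work happens in `einsum`.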

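This implementation applies all filters and biases to a single image from the batch, producing a rank-3 array of shape `(ym, yn, no)`. After the same preliminary setup as before, it slides a window over the image and, at each position, correlates every output filter with the window and adds that filter's bias.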

```python
def filter_image_v2(image, filters, biases, strides):
    """Apply every filter and its bias to one image; returns shape (ym, yn, no)."""
    xm, xn, _ = image.shape
    # ni => number of input channels
    # no => number of output channels
    km, kn, ni, no = filters.shape
    assert image.shape[2] == ni, 'image and filters must have the same number of input channels'

    sm, sn = strides
    ym, yn = 1 + (xm - km) // sm, 1 + (xn - kn) // sn

    y = np.zeros((ym, yn, no))

    # (iy, jy) index the output; (ix, jx) is the window's top-left corner in the image
    for iy, ix in enumerate(range(0, xm - km + 1, sm)):
        for jy, jx in enumerate(range(0, xn - kn + 1, sn)):
            # Extract the window and apply every output filter to it
            window = image[ix:ix+km, jx:jx+kn, :]
            for channel in range(no):
                y[iy, jy, channel] = np.sum(filters[..., channel] * window) + biases[channel]

    return y
```
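
The innermost loop over output channels can be folded into one call: `np.tensordot(window, filters, axes=3)` contracts the window's three axes (rows, columns, input channels) with the first three axes of `filters`, leaving one value per output filter. Here's a sketch of that variant under the same shape conventions as `filter_image_v2`; `filter_image_tensordot` is a hypothetical name, and the spot check compares one output entry against the explicit per-filter formula.

```python
import numpy as np


def filter_image_tensordot(image, filters, biases, strides):
    """Hypothetical variant of filter_image_v2 without the per-filter loop.

    image: (xm, xn, ni), filters: (km, kn, ni, no), biases: (no,), strides: (sm, sn).
    Returns an array of shape (ym, yn, no).
    """
    xm, xn, _ = image.shape
    km, kn, _, no = filters.shape
    sm, sn = strides
    ym, yn = 1 + (xm - km) // sm, 1 + (xn - kn) // sn

    y = np.zeros((ym, yn, no))
    for iy in range(ym):
        for jy in range(yn):
            ix, jx = iy * sm, jy * sn                # top-left corner of the window
            window = image[ix:ix+km, jx:jx+kn, :]    # shape (km, kn, ni)
            # Sum over rows, columns, and input channels for all filters at once -> (no,)
            y[iy, jy, :] = np.tensordot(window, filters, axes=3) + biases
    return y


rng = np.random.default_rng(1)
image = rng.standard_normal((28, 28, 3))
filters = rng.standard_normal((2, 2, 3, 4))
biases = rng.standard_normal(4)

y = filter_image_tensordot(image, filters, biases, (2, 2))
print(y.shape)  # (14, 14, 4)

# Spot-check one output entry against the explicit per-filter formula
iy, jy, o = 5, 7, 2
expected = np.sum(filters[..., o] * image[2*iy:2*iy+2, 2*jy:2*jy+2, :]) + biases[o]
print(np.isclose(y[iy, jy, o], expected))  # True
```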

```python
layer = tf.keras.layers.Conv2D(filters=4, kernel_size=(2, 2), strides=(2, 2), padding='valid')
```

```python
image = np.random.randn(1, 28, 28, 3)
keras_output = layer(image)[0, :, :, :]
filters, biases = layer.get_weights()  # filters: (2, 2, 3, 4), biases: (4,)
strides = (2, 2)

# np.squeeze drops the batch axis, leaving the single (28, 28, 3) image
filter_image_v2(np.squeeze(image), filters, biases, strides)
```

```
array([[[-1.08534935e-01,  5.39739621e-02,  5.59589955e-01,
          1.08302593e-01],
        [ 6.58584484e-02,  4.46099085e-01,  7.87501674e-01,
         -7.03558574e-01],
        [-8.00203938e-01, -1.33391277e-01, -1.39527965e-01,
         -5.17312697e-01],
        [-5.95222389e-01, -2.30186106e+00, -1.00016599e+00,
         -3.53730820e+00],
        [ 1.05507175e+00,  2.29927715e-01,  7.53264852e-02,
          1.70476202e+00],
        [ 4.23598228e-01, -2.80074382e-01,  1.63063877e+00,
          8.33309978e-01],
        [ 3.20196650e+00,  1.66784160e-01, -3.63189106e-01,
         -1.00131308e-01],
        [ 2.22388031e-01, -1.71247240e+00,  3.77570452e-01,
         -1.13530512e+00],
        [-1.39182051e+00, -6.61468238e-01, -6.18569571e-01,
         -7.34410004e-01],
        [-5.75423405e-01, -1.75251637e-01, -3.71638919e-01,
         -6.08501428e-01],
        [-2.72515885e-01,  2.60302572e+00, -3.36486769e-01,
         -2.30307188e-01],
        [-1.12954774e+00,  4.51343939e-01,
        ...
```

```python
keras_output
```

```
<tf.Tensor: shape=(14, 14, 4), dtype=float32, numpy=
array([[[-1.08534932e-01,  5.39739281e-02,  5.59589982e-01,
          1.08302608e-01],
        [ 6.58584684e-02,  4.46099073e-01,  7.87501693e-01,
         -7.03558564e-01],
        [-8.00204039e-01, -1.33391246e-01, -1.39527947e-01,
         -5.17312586e-01],
        [-5.95222414e-01, -2.30186105e+00, -1.00016594e+00,
         -3.53730822e+00],
        [ 1.05507171e+00,  2.29927674e-01,  7.53264874e-02,
          1.70476198e+00],
        [ 4.23598200e-01, -2.80074388e-01,  1.63063884e+00,
          8.33309948e-01],
        [ 3.20196629e+00,  1.66784167e-01, -3.63189101e-01,
         -1.00131355e-01],
        [ 2.22388119e-01, -1.71247244e+00,  3.77570391e-01,
         -1.13530517e+00],
        [-1.39182055e+00, -6.61468208e-01, -6.18569613e-01,
         -7.34409988e-01],
        [-5.75423360e-01, -1.75251648e-01, -3.71638924e-01,
         -6.08501554e-01],
        [-2.72515923e-01,  2.60302567e+00, -3.36486816e-01,
         -2.30307
        ...
```

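The two outputs agree to about six or seven significant digits; the small differences come from Keras computing in `float32`, while `filter_image_v2` works in `float64`.

Now I can make a little class that holds the filters, biases, and strides for a convolutional layer. Its `__call__` method uses `filter_image_v2` to correlate every output filter with each image in the batch and add the biases, then stacks the per-image feature maps back into a batch.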

```python
class MyConv2D:
    """A minimal 'valid'-padding convolutional layer built on filter_image_v2."""

    def __init__(self, filters, biases, strides):
        self.filters = filters  # shape (km, kn, ni, no)
        self.biases = biases    # shape (no,)
        self.strides = strides  # (sm, sn)

    def __call__(self, images):
        # Filter each image in the batch (shape (N, xm, xn, ni)) independently,
        # then stack the feature maps into a batch of shape (N, ym, yn, no)
        outputs = [filter_image_v2(image, self.filters, self.biases, self.strides)
                   for image in images]
        return np.stack(outputs, axis=0)
```
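
One detail worth making explicit: multiplying the kernel by each window and summing, with no kernel flip, is cross-correlation, which is what Keras' `Conv2D` computes despite its name. A true convolution flips the kernel first. NumPy's 1-D `np.correlate` and `np.convolve` show the difference on a small, made-up signal.

```python
import numpy as np

signal = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
kernel = np.array([1.0, 0.0, -1.0])

# Slide the kernel along the signal without flipping it (the 1-D analogue of filter_image_v2)
sliding = np.array([np.sum(kernel * signal[i:i+3]) for i in range(len(signal) - 2)])

print(sliding)                                     # [-2. -2. -2.]
print(np.correlate(signal, kernel, mode='valid'))  # [-2. -2. -2.]
print(np.convolve(signal, kernel, mode='valid'))   # [2. 2. 2.]  (kernel flipped first)
```

Flipping the kernel before convolving (`np.convolve(signal, kernel[::-1], mode='valid')`) recovers the sliding result, so the two operations differ only by that flip.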

        

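```python
# np.array turns a list of equal-shape arrays into one array with a new leading axis
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = []

y.append(x)
y.append(x)

z = np.array(y)
z.shape
```

```
(2, 3, 3)
```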
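
That cell checks that converting a list of equal-shape arrays with `np.array` adds a new leading axis. `np.stack`, which `MyConv2D` uses, does the same thing but takes the position of the new axis explicitly, which makes the intent clearer. A short comparison:

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

batch = np.stack([x, x], axis=0)                # new leading (batch) axis
print(batch.shape)                              # (2, 3, 3)
print(np.array_equal(batch, np.array([x, x])))  # True: same as np.array on the list

channels_last = np.stack([x, x], axis=-1)       # new trailing axis instead
print(channels_last.shape)                      # (3, 3, 2)
```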

## Compare to Keras

To compare my convolutional layer to Keras', I need to copy the convolution parameters (kernels, biases, and strides) out of the Keras layer and use them to initialize my layer.

```python
# Copy the convolution parameters out of the Keras layer
kernels, biases = layer.get_weights()
strides = layer.strides
keras_conv_features = layer(inputs).numpy()
```

```python
my_layer = MyConv2D(kernels, biases, strides=strides)
my_conv_features = my_layer(inputs)
```

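```python
assert np.allclose(keras_conv_features, my_conv_features, atol=1e-6)
```

### Convolution in JAX

The same convolution is a single call to `jax.lax.conv_general_dilated`. The `dimension_numbers` argument says the input and output are laid out as NHWC (batch, height, width, channels) and the kernel as HWIO (height, width, input channels, output channels), which matches the Keras layouts, so the Keras weights can be passed in directly. The function doesn't add a bias, so I add it before comparing.

```python
import jax

jax_result = jax.lax.conv_general_dilated(
    lhs=inputs.astype(np.float32),  # (N, H, W, C_in), cast to match the float32 kernels
    rhs=kernels,                    # (H, W, C_in, C_out)
    window_strides=strides,
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)

assert np.allclose(jax_result + biases, my_conv_features, atol=1e-6)
```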
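
To move toward the more efficient JAX layer mentioned in the introduction, the call above can be wrapped in a jit-compiled function that also adds the bias. This is a sketch, not a finished layer: `jax_conv2d` is a hypothetical name, `strides` is marked static so it can be passed as a Python tuple, `precision=HIGHEST` is my choice to keep full float32 accuracy on accelerators that might otherwise use reduced precision, and the check uses a NumPy `sliding_window_view`/`einsum` reference (NumPy 1.20+) on random data rather than the Keras layer.

```python
from functools import partial

import jax
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


@partial(jax.jit, static_argnames=('strides',))
def jax_conv2d(images, filters, biases, strides=(1, 1)):
    """'VALID' 2-D cross-correlation plus bias, using the Keras layouts.

    images: (N, H, W, C_in), filters: (KH, KW, C_in, C_out), biases: (C_out,)
    """
    y = jax.lax.conv_general_dilated(
        lhs=images,
        rhs=filters,
        window_strides=strides,
        padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        precision=jax.lax.Precision.HIGHEST,
    )
    return y + biases  # biases broadcast over the trailing channel axis


rng = np.random.default_rng(0)
images = rng.standard_normal((2, 28, 28, 3)).astype(np.float32)
filters = rng.standard_normal((2, 2, 3, 4)).astype(np.float32)
biases = rng.standard_normal(4).astype(np.float32)

out = jax_conv2d(images, filters, biases, strides=(2, 2))

# NumPy reference: stride-2 windows, shape (N, 14, 14, C_in, 2, 2), contracted with the filters
windows = sliding_window_view(images, (2, 2), axis=(1, 2))[:, ::2, ::2]
ref = np.einsum('nijcab,abco->nijo', windows, filters) + biases

print(out.shape)                         # (2, 14, 14, 4)
print(np.allclose(out, ref, atol=1e-4))  # True
```

The first call traces and compiles the function; later calls with the same shapes, dtypes, and `strides` reuse the compiled version.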