# Checking Understanding of Convolutional and Pooling Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: true

## Introduction

The purpose of this post is to make sure I understand how convolutional and pooling layers work. Once again, I'll use Keras to double-check all my work.

## Import Libraries

In [264]:
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import pandas as pd

Here's a small sequential model consisting of a convolutional layer and a max-pooling layer.

In [265]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=4, kernel_size=(2, 2), strides=(2,2), padding='VALID'),    
    tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2))
])

For this experiment, I don't care what the model input is, so I'll apply `model` to a 1-element batch consisting of a random 28-by-28 array with 3 channels. As the printout below shows, `outputs` is a 1-element batch consisting of a 7-by-7 array with 4 channels: the stride-2 convolution reduces 28 to 14, and the stride-2 pooling reduces 14 to 7.

In [266]:
inputs = np.random.randn(1,28,28,3)
outputs = model(inputs)

print(f'Feature Mapping:  {inputs.shape} -> {outputs.shape}')

Feature Mapping:  (1, 28, 28, 3) -> (1, 7, 7, 4)
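The shape change can be checked by hand. Here is a minimal sketch, assuming no padding (`'VALID'`), of the usual output-size rule `(n - k) // s + 1`; the helper name `valid_output_size` is mine and is not part of Keras.

```python
def valid_output_size(n, k, s):
    """Number of window positions along one axis when no padding is used."""
    return (n - k) // s + 1

# 28 -> 14 through the 2x2, stride-2 convolution
after_conv = valid_output_size(28, k=2, s=2)
# 14 -> 7 through the 2x2, stride-2 max pooling
after_pool = valid_output_size(after_conv, k=2, s=2)

print(after_conv, after_pool)  # 14 7
```

The number of channels is set separately by `filters=4`, which is why the last dimension goes from 3 to 4.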


Getting the pooling layer output is easy; it's the same as `outputs`.  However, it's not immediately obvious how to get the intermediate convolution output. The recommended way to extract the outputs from all the layers in your model is to use the so-called *Functional* API.  

In [214]:
outputs = [layer.output for layer in model.layers]
layer_output_model = tf.keras.Model(inputs=model.input, outputs=outputs)
keras_features = layer_output_model(inputs)

In [215]:
kernels, biases = model.layers[0].get_weights()
print(f'kernels shape = {kernels.shape}, biases shape = {biases.shape}')

kernels shape = (2, 2, 3, 4), biases shape = (4,)
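The kernel array is laid out as (height, width, input channels, output channels), so slicing the last axis picks out one filter. A small sketch with stand-in random weights of the same shapes (these are not the model's actual weights):

```python
import numpy as np

# Stand-in weights with the shapes printed above; the values are random.
kernels = np.random.randn(2, 2, 3, 4)   # (height, width, in_channels, out_channels)
biases = np.zeros(4)                    # one bias per output channel

# One filter spans the full 2x2 window across all 3 input channels.
first_filter = kernels[..., 0]
print(first_filter.shape)               # (2, 2, 3)

# Total trainable parameters in the layer: 2*2*3*4 weights + 4 biases.
num_params = kernels.size + biases.size
print(num_params)                       # 52
```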


Here's a fairly inefficient way to duplicate the evaluation of the convolutional layer defined above.

## Convolutional Layer

Convolution is essentially a rolling dot-product, and this is exactly how the implementation below works. For now, `conv2d` is a function of a single, multi-channel image `x`, `kernel` is one of the output filters, and `strides` is a pair of integers defining the number of vertical and horizontal positions the kernel moves between successive dot products.

I don't have a good technical reason for making the output array `y` one dimensional and reshaping at the end; it just looks cleaner than the alternatives.

In [216]:
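def conv2d(x, kernel, strides):
    xm, xn, _ = x.shape       # input height and width
    km, kn, _ = kernel.shape  # kernel height and width
    
    sm, sn = strides
    # number of valid window positions along each axis
    ym, yn = 1 + (xm - km)//sm, 1 + (xn - kn)//sn
    
    y = np.zeros(ym*yn)
    k = 0
    for i in range(0, xm-km+1, sm):
        for j in range(0, xn-kn+1, sn):
            # dot product of the kernel with the window at (i, j), over all channels
            y[k] = np.sum(kernel * x[i:i+km,j:j+kn,:])
            k += 1
            
    return np.reshape(y, (ym, yn))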
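For comparison, the same rolling dot-product can be sketched without explicit loops. This is an alternative under my own assumptions, not the approach used in this post: it relies on `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20 or later), and `conv2d_windows` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d_windows(x, kernel, strides):
    """Valid (unpadded) convolution of one multi-channel image with one filter."""
    km, kn, _ = kernel.shape
    sm, sn = strides
    # All km-by-kn windows: shape (H-km+1, W-kn+1, C, km, kn)
    windows = sliding_window_view(x, (km, kn), axis=(0, 1))
    # Keep only the windows that start on a stride boundary
    windows = windows[::sm, ::sn]
    # Dot product of every window with the kernel, summed over m, n and channels
    return np.einsum('ijcmn,mnc->ij', windows, kernel)

# A 4x4 single-channel image and an all-ones 2x2 kernel, so each output
# entry is just the sum of one 2x2 block.
x = np.arange(16, dtype=float).reshape(4, 4, 1)
kernel = np.ones((2, 2, 1))
y = conv2d_windows(x, kernel, strides=(2, 2))
print(y)
```

With stride 2 the four windows do not overlap, and the block sums are 10, 18, 42 and 50.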

In [217]:
class MyConv2D:
    def __init__(self, w, b, strides):
        self.w = w
        self.b = b 
        self.strides = strides 
        
    def __call__(self, inputs):
        biases, kernels, strides = self.b, self.w, self.strides
        
        num_output_channels = len(biases)
        
        # get the first image
        inputs = inputs[0,...]
        
        # get the list of output feature maps
        outputs = [conv2d(inputs, kernels[...,i], strides) + biases[i] for i in range(num_output_channels)]
        
        # stack the 2D feature maps along a new last (channel) axis
        outputs = np.stack(outputs, axis=-1) 
        
        # Add the batch dimension
        outputs = outputs[np.newaxis,...]
         
        return outputs
        

### Check

In [269]:
kernels, biases = model.layers[0].get_weights()
strides = model.layers[0].strides
keras_conv_features = keras_features[0]

In [270]:
my_layer = MyConv2D(kernels, biases, strides=strides)
my_conv_features = my_layer(inputs)

In [220]:
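jax_result = jax.lax.conv_general_dilated(
    lhs=inputs,
    rhs=kernels,
    window_strides=strides,
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC')
)
# add the per-channel biases (all zero right after Keras initialization)
jax_result = jax_result + biases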



In [227]:
assert np.all(np.isclose(keras_conv_features, my_conv_features, atol=1e-6))
assert np.all(np.isclose(jax_result, my_conv_features, atol=1e-6))

### Convolution in JAX

In [117]:
a = jnp.zeros((2,3))
np.array(a)

array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [145]:
jax_result.shape

(1, 14, 14, 4)

## Pooling Layer

In [593]:
pooling_layer = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2))

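In [594]:
# a random 3-by-3 input with 2 channels (assumed; any tensor of this shape works)
feature_maps = tf.random.normal((1, 3, 3, 2))
yy = pooling_layer(feature_maps)
print(f'{feature_maps.shape} -> {yy.shape}')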

(1, 3, 3, 2) -> (1, 1, 1, 2)


In [540]:
yy

<tf.Tensor: shape=(1, 1, 1, 2), dtype=float32, numpy=array([[[[0.9654248, 1.2988867]]]], dtype=float32)>

In [274]:
print(feature_maps)

tf.Tensor(
[[[[ 0.89900464 -1.0616233 ]
   [-0.10247962  1.2988867 ]
   [-0.8515937   0.6712394 ]]

  [[ 0.9654248   0.23607667]
   [-0.43404913  0.27193436]
   [-0.5712364  -0.73882663]]

  [[ 0.6847748  -0.8419535 ]
   [ 0.46285468  0.7197448 ]
   [-1.3353215   0.47066653]]]], shape=(1, 3, 3, 2), dtype=float32)


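The 2-by-2 result printed next is what max pooling with `pool_size=(2,2)` and `strides=(1,1)` produces on `feature_maps`. With `strides=(2,2)`, only the top-left window fits in a 3-by-3 input and the result is 1-by-1, as shown above.

In [275]: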
print(yy)

tf.Tensor(
[[[[ 0.9654248   1.2988867 ]
   [-0.10247962  1.2988867 ]]

  [[ 0.9654248   0.7197448 ]
   [ 0.46285468  0.7197448 ]]]], shape=(1, 2, 2, 2), dtype=float32)


In [279]:
print(feature_maps[0,:,:,1])

tf.Tensor(
[[-1.0616233   1.2988867   0.6712394 ]
 [ 0.23607667  0.27193436 -0.73882663]
 [-0.8419535   0.7197448   0.47066653]], shape=(3, 3), dtype=float32)


In [280]:
yy[0,:,:,1]

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1.2988867, 1.2988867],
       [0.7197448, 0.7197448]], dtype=float32)>
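The numbers above can be reproduced by hand. The sketch below copies channel 1 of `feature_maps` from the printout and pools it with both stride settings; `max_pool2d` is an illustrative helper of my own, not a Keras function.

```python
import numpy as np

def max_pool2d(x, pool_size=(2, 2), strides=(1, 1)):
    """Max-pool a single-channel 2D array with no padding."""
    pm, pn = pool_size
    sm, sn = strides
    # number of valid window positions along each axis
    ym, yn = (x.shape[0] - pm) // sm + 1, (x.shape[1] - pn) // sn + 1
    y = np.zeros((ym, yn), dtype=x.dtype)
    for ii in range(ym):
        for jj in range(yn):
            i, j = ii * sm, jj * sn
            y[ii, jj] = x[i:i + pm, j:j + pn].max()
    return y

# channel 1 of `feature_maps`, copied from the printout above
channel_1 = np.array([
    [-1.0616233,   1.2988867,   0.6712394],
    [ 0.23607667,  0.27193436, -0.73882663],
    [-0.8419535,   0.7197448,   0.47066653],
], dtype=np.float32)

stride_1 = max_pool2d(channel_1, strides=(1, 1))  # four overlapping windows
stride_2 = max_pool2d(channel_1, strides=(2, 2))  # only the top-left window fits
print(stride_1)
print(stride_2)
```

With stride 1 this gives the 2-by-2 array shown above. With stride 2 it gives the single value 1.2988867, which matches the second channel of the 1-by-1 Keras output.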

In [304]:
input_batch.shape

(1, 4, 4, 3)

In [365]:
list(range(0,10,2))

[0, 2, 4, 6, 8]

In [684]:
v = np.zeros((3,3))
v[0]

array([0., 0., 0.])

In [80]:
def pool2D(x, pool_size=(2,2), strides=(1,1), fn=np.max):
    xm, xn = x.shape 
    pm, pn = pool_size 
    sm, sn = strides
    
    # number of valid window positions along each axis
    ym, yn = 1 + (xm-pm) // sm, 1 + (xn-pn) // sn

    y = np.zeros((ym, yn))
    
    ii = 0
    for i in range(0, xm-pm+1, sm):
        jj = 0
        for j in range(0, xn-pn+1, sn):
            # reduce the window whose top-left corner is (i, j)
            y[ii,jj] = fn(x[i:i+pm,j:j+pn])
            jj += 1
        ii += 1
    return y

In [81]:
x = np.random.randn(3,3)
b = pool2D(x, strides=(2,2))
print(b)
print(x)

In [82]:
def pooling(features, pool_size=(2,2), strides=(2,2)):
    
    pm, pn = pool_size
    sm, sn = strides
    height, width, chans = features.shape 
    
    # number of valid window positions along each axis
    m, n = 1 + (height - pm) // sm, 1 + (width - pn) // sn
    
    features_ = np.zeros((m, n, chans))

    # Note that we're not changing the number of channels
    for chan in range(chans):
        features_[:,:,chan] = pool2D(features[:,:,chan], pool_size, strides)
    
    return features_

In [598]:
pooling(feature_maps[0,:,:,:])

array([[[ 1.3011235 , -0.24245605]]])

In [249]:
np.stack([np.random.randn(10,10), np.random.randn(10,10)], axis=-1).shape

(10, 10, 2)

In [271]:
jax.lax.reduce_window(
    jnp.array(my_conv_features),
    -jnp.inf,
    jax.lax.max,
    window_dimensions=(1,2,2,1),
    window_strides=(1,2,2,1),
    padding='valid')

DeviceArray([[[[ 2.2203097 ,  0.7514108 ,  0.6305484 ,  1.5815421 ],
               [ 1.2177465 ,  0.9271074 ,  0.7391334 ,  0.9822045 ],
               [ 0.29114679,  1.8811291 ,  0.5059719 ,  0.63093317],
               [ 0.20470184,  1.3739707 ,  0.3607968 ,  1.2883999 ],
               [ 0.59057814,  1.5924685 ,  1.1464765 ,  1.4917585 ],
               [ 0.16934961,  0.64135337,  1.8001381 ,  0.16465424],
               [ 0.3894469 ,  0.6464122 ,  1.6553746 ,  1.2135262 ]],

              [[ 2.3377476 ,  0.51156694,  1.0020787 ,  1.3904204 ],
               [ 0.652942  ,  2.2939618 ,  1.4073436 ,  2.0440257 ],
               [ 0.7191202 ,  0.9767988 ,  0.668246  , -0.3064434 ],
               [ 1.3608527 ,  0.58269393,  0.35336065,  1.8263001 ],
               [ 0.78622943,  1.713314  ,  1.7906963 , -0.02068998],
               [ 0.76750404,  0.81492865,  0.57862407,  0.7561941 ],
               [ 0.25453514,  1.0671004 ,  1.1293529 ,  0.5066892 ]],

              [[ 0.6206944 ,  

In [272]:
outputs

<tf.Tensor: shape=(1, 7, 7, 4), dtype=float32, numpy=
array([[[[ 2.2203097 ,  0.75141084,  0.6305484 ,  1.581542  ],
         [ 1.2177465 ,  0.92710733,  0.7391335 ,  0.98220456],
         [ 0.29114673,  1.881129  ,  0.5059719 ,  0.6309332 ],
         [ 0.20470184,  1.3739707 ,  0.36079678,  1.2883998 ],
         [ 0.59057814,  1.5924685 ,  1.1464764 ,  1.4917585 ],
         [ 0.16934964,  0.64135337,  1.8001381 ,  0.16465425],
         [ 0.3894469 ,  0.6464122 ,  1.6553746 ,  1.2135262 ]],

        [[ 2.3377476 ,  0.511567  ,  1.0020788 ,  1.3904203 ],
         [ 0.652942  ,  2.2939618 ,  1.4073437 ,  2.0440257 ],
         [ 0.7191202 ,  0.97679865,  0.6682459 , -0.30644342],
         [ 1.3608527 ,  0.58269376,  0.35336065,  1.8263003 ],
         [ 0.7862294 ,  1.7133139 ,  1.7906963 , -0.02068999],
         [ 0.7675041 ,  0.81492865,  0.57862407,  0.7561942 ],
         [ 0.25453514,  1.0671003 ,  1.1293529 ,  0.5066892 ]],

        [[ 0.6206945 ,  0.09209803,  1.6532174 ,  2.0552828 
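In the rows shown, the `jax.lax.reduce_window` result and the Keras `outputs` agree to within float32 rounding. As a third cross-check, non-overlapping 2-by-2 max pooling can also be sketched in plain NumPy with a reshape. This shortcut assumes the window equals the stride and that the height and width are even; `max_pool_2x2` is a hypothetical helper name.

```python
import numpy as np

def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling over an NHWC array (H and W must be even)."""
    n, h, w, c = x.shape
    # Split each spatial axis into (block index, position inside the block)
    blocks = x.reshape(n, h // 2, 2, w // 2, 2, c)
    # Take the maximum over the two within-block axes
    return blocks.max(axis=(2, 4))

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
pooled = max_pool_2x2(x)
print(pooled[0, :, :, 0])
```

For this 4-by-4 input the four block maxima are 5, 7, 13 and 15. Applied to a (1, 14, 14, 4) array such as `my_conv_features`, the same function would return a (1, 7, 7, 4) array.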

In [228]:
??jax.lax.reduce_window

Signature:
jax.lax.reduce_window(
    operand,
    init_value,
    computation: Callable,
    window_dimensions: Sequence[Union[int, Any]],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    base_dilation: Union[Sequence[int],

In [238]:
??jax.lax.max

Signature: jax.lax.max(x: Any, y: Any) -> Any
Source:   
def max(x: Array, y: Array) -> Array:
  r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`

  For complex numbers, uses a lexicographic comparison on the
  `(real, imaginary)` pairs."""
  return max_p.bind(x, y)
File:      ~/anaconda3/lib/python3.8/site-packages/jax/_src/lax/lax.py
Type:      function
