# Checking Understanding of Convolutional and Pooling Layers

- toc: true
- badges: true
- comments: false
- categories: [jax, convolution, pooling]
- hide: true

## Introduction

The purpose of this post is to make sure I understand how convolutional and pooling layers work. Once again, I'll use Keras to double-check all my work, and I'll also compare the convolution against JAX's `jax.lax.conv_general_dilated`.

## Import Libraries

In [86]:
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import pandas as pd

Here's a small sequential model consisting of a convolutional layer and max-pooling layer.

In [180]:
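model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28, 3)),  # fixing the input shape lets Keras build the model (and model.input) up front
    tf.keras.layers.Conv2D(filters=4, kernel_size=(2, 2), strides=(2, 2), padding='valid'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))
])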

For this experiment, I don't care what the model input is.  So I'll apply `model` to a 1-element batch consisting of a random 28-by-28 array with 3 channels.  As you can see, `outputs` is a 1-element batch consisting of a 7-by-7 array with 4 channels.  

In [181]:
inputs = np.random.randn(1,28,28,3)
outputs = model(inputs)
print(f'Feature Mapping:  {inputs.shape} -> {outputs.shape}')

Feature Mapping:  (1, 28, 28, 3) -> (1, 7, 7, 4)


We definitely don't want `model` to change the number of items in the batch, so it makes sense that the first dimension (the batch dimension) of both `inputs` and `outputs` is 1. It also makes sense that `outputs` has 4 channels. The `filters=4` argument in `Conv2D` means that no matter how large or small the input is, and no matter how many channels it has, `Conv2D` will generate a feature map with 4 channels. We'll see how this works in more detail later on. Because pooling layers preserve the number of channels between inputs and outputs, `outputs` also has 4 channels. The spatial size, on the other hand, is halved twice: the 2-by-2 convolution with stride 2 takes 28 to 14, and the 2-by-2 pooling with stride 2 takes 14 to 7.
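The spatial sizes follow the usual rule for 'valid' (unpadded) windows: a window of size k moved with stride s across n positions fits ⌊(n − k)/s⌋ + 1 times. Here is a minimal sketch of that arithmetic for this particular model; the helper names `valid_output_size` and `model_output_shape` are my own, not Keras functions.

```python
def valid_output_size(n, k, s):
    """Number of positions a length-k window with stride s fits into n cells, with no padding."""
    return (n - k) // s + 1


def model_output_shape(batch, height, width, filters=4):
    """Shape after Conv2D(2x2 kernel, stride 2, 'valid') followed by MaxPool2D(2x2, stride 2)."""
    # Conv2D: the spatial size shrinks and the channel count becomes `filters`;
    # the number of input channels is not even an argument here.
    h = valid_output_size(height, 2, 2)
    w = valid_output_size(width, 2, 2)
    # MaxPool2D: the spatial size shrinks again and the channel count is unchanged.
    h = valid_output_size(h, 2, 2)
    w = valid_output_size(w, 2, 2)
    return (batch, h, w, filters)


print(valid_output_size(28, 2, 2))     # 14
print(model_output_shape(1, 28, 28))  # (1, 7, 7, 4)
```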

One way to get the features from every layer in the model (the feature-extraction pattern shown in the Keras guide to Sequential models) is to use Keras' functional API to build a model that returns the output of each layer for a given input batch.

In [182]:
layer_outputs = [layer.output for layer in model.layers]  # one symbolic output per layer
layer_output_model = tf.keras.Model(inputs=model.input, outputs=layer_outputs)
keras_features = layer_output_model(inputs)

Because `model` has two layers, the `keras_features` list has two elements. Here are their shapes:

In [183]:
print(f'Conv2D output shape = {keras_features[0].shape}')
print(f'MaxPool2D output shape = {keras_features[1].shape}')

Conv2D output shape = (1, 14, 14, 4)
MaxPool2D output shape = (1, 7, 7, 4)


In [184]:
kernels, biases = model.layers[0].get_weights()
print(f'kernels shape = {kernels.shape}, biases shape = {biases.shape}')

kernels shape = (2, 2, 3, 4), biases shape = (4,)
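Keras stores the `Conv2D` kernel in (kernel height, kernel width, input channels, output channels) order, with one bias per output channel. So each of the 4 filters is a 2 × 2 × 3 block that spans every input channel, and the layer has 2·2·3·4 + 4 = 52 trainable parameters. A small NumPy sketch of that bookkeeping, with random arrays standing in for the real weights:

```python
import numpy as np

kh, kw, c_in, c_out = 2, 2, 3, 4
kernels = np.random.randn(kh, kw, c_in, c_out)  # stand-in for the Conv2D kernel (HWIO layout)
biases = np.zeros(c_out)                        # Conv2D biases start at zero by default

# The filter that produces output channel 0 covers all 3 input channels.
filter_0 = kernels[..., 0]
print(filter_0.shape)  # (2, 2, 3)

num_params = kernels.size + biases.size
print(num_params)  # 52
```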


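## Convolutional Layer

Here's a fairly inefficient way to reproduce what the `Conv2D` layer in `model` (that is, `model.layers[0]`) computes. `conv2d` produces a single 2D feature map from one image and one filter, and `MyConv2D` calls it once per output channel, adding that channel's bias.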

In [185]:
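def conv2d(x, kernel, strides):
    """'Valid' (no padding), strided 2D convolution of one (H, W, C_in) image
    with one (kh, kw, C_in) filter, producing a single 2D feature map."""
    xm, xn, _ = x.shape
    km, kn, _ = kernel.shape
    sm, sn = strides

    # number of window positions along each axis
    ym, yn = (xm - km) // sm + 1, (xn - kn) // sn + 1

    y = np.zeros((ym, yn))
    for ii, i in enumerate(range(0, xm - km + 1, sm)):
        for jj, j in enumerate(range(0, xn - kn + 1, sn)):
            # multiply the filter with the window elementwise, then sum over rows, columns and input channels
            y[ii, jj] = np.sum(kernel * x[i:i+km, j:j+kn, :])
    return y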
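The double loop in `conv2d` visits one window at a time. The same 'valid', strided operation (strictly a cross-correlation, since the kernel is not flipped) can be written without Python loops by building every window at once with NumPy's `sliding_window_view` (available in NumPy 1.20 and later) and contracting the windows with the kernels via `einsum`. This is a sketch of an alternative formulation, not what Keras does internally; `conv2d_vectorized` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_vectorized(x, kernels, biases, strides):
    """Valid, strided 2D convolution of a whole batch.

    x:       (N, H, W, C_in)        NHWC, like the Keras inputs
    kernels: (kh, kw, C_in, C_out)  HWIO, like Conv2D.get_weights()[0]
    biases:  (C_out,)
    returns: (N, H_out, W_out, C_out)
    """
    kh, kw, _, _ = kernels.shape
    sh, sw = strides
    # windows[n, i, j, c, a, b] == x[n, i + a, j + b, c] for every valid top-left corner (i, j)
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))
    # keep only the corners reached with the given strides
    windows = windows[:, ::sh, ::sw]
    # multiply each window by each filter and sum over kernel rows, kernel columns and input channels
    return np.einsum('nhwcab,abco->nhwo', windows, kernels) + biases


x = np.random.randn(1, 28, 28, 3)
kernels = np.random.randn(2, 2, 3, 4)
biases = np.random.randn(4)
y = conv2d_vectorized(x, kernels, biases, strides=(2, 2))
print(y.shape)  # (1, 14, 14, 4)
```

The trade-off is memory for speed: `sliding_window_view` returns a view, but `einsum` still touches every window element, which is exactly the work the loop does one window at a time.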

In [186]:
class MyConv2D:
    def __init__(self, w, b, strides):
        self.w = w          # kernels, shape (kh, kw, C_in, C_out)
        self.b = b          # biases, shape (C_out,)
        self.strides = strides

    def __call__(self, inputs):
        biases, kernels, strides = self.b, self.w, self.strides
        num_output_channels = len(biases)

        outputs = []
        for image in inputs:  # loop over the batch dimension
            # one 2D feature map per output channel; kernels[..., i] has shape (kh, kw, C_in)
            maps = [conv2d(image, kernels[..., i], strides) + biases[i]
                    for i in range(num_output_channels)]
            # stack the 2D maps along a new last (channel) axis -> (H_out, W_out, C_out)
            outputs.append(np.stack(maps, axis=-1))

        # stack the images back into a batch -> (N, H_out, W_out, C_out)
        return np.stack(outputs, axis=0)
        

### Check

In [187]:
kernels, biases = model.layers[0].get_weights()
strides = model.layers[0].strides
keras_conv_features = keras_features[0]

In [188]:
my_layer = MyConv2D(kernels, biases, strides=strides)
my_conv_features = my_layer(inputs)

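In [195]:
jax_conv2d = jax.lax.conv_general_dilated

jax_result = jax_conv2d(
    lhs=inputs.astype(np.float32),  # NHWC input; cast to match the float32 Keras kernel
    rhs=kernels,                    # HWIO kernel, the layout Keras uses
    window_strides=strides,
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC')
) + biases                          # conv_general_dilated has no bias term, so add it here
jax_result.shape

(1, 14, 14, 4)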

In [196]:
assert np.all(np.isclose(keras_conv_features, my_conv_features, atol=1e-6))
assert np.all(np.isclose(jax_result, my_conv_features, atol=1e-6))
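`conv_general_dilated` has to be told how each array is laid out. The triple `dimension_numbers=('NHWC', 'HWIO', 'NHWC')` says the input is batch/height/width/channels, the kernel is height/width/in-channels/out-channels (the layout Keras uses), and the output should again be batch/height/width/channels. A minimal sketch that checks one output value against a hand-computed window sum, using random stand-in arrays and an arbitrary bias of my choosing; as in the cell above, the bias is a separate broadcast add because the JAX function has no bias argument.

```python
import jax
import jax.numpy as jnp
import numpy as np

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 6, 6, 3))       # NHWC input
kernel = jax.random.normal(key_k, (2, 2, 3, 4))  # HWIO kernel, as in Keras Conv2D
bias = jnp.arange(4.0)                           # one (arbitrary) bias per output channel

y = jax.lax.conv_general_dilated(
    lhs=x,
    rhs=kernel,
    window_strides=(2, 2),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
) + bias  # bias added separately, broadcast over the channel axis

print(y.shape)  # (1, 3, 3, 4)

# With stride 2, output position (1, 2) sees the 2x2 window whose top-left corner is (2, 4).
window = x[0, 2:4, 4:6, :]
manual = jnp.sum(window[..., None] * kernel, axis=(0, 1, 2)) + bias
print(np.allclose(y[0, 1, 2], manual, atol=1e-4))  # True
```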

### Convolution in JAX

In [117]:
a = jnp.zeros((2,3))
np.array(a)

array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [145]:
jax_result.shape

(1, 14, 14, 4)

In [104]:
jax_conv2d = jax.lax.conv_general_dilated
jax_conv2d??

Signature:
jax_conv2d(
    lhs: Any,
    rhs: Any,
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Union[Sequence[int], NoneType] = None,
    rhs_dilation: Union[Sequence[int], NoneType] = None,

## Pooling Layer

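Next, a max-pooling layer on its own. The cells below apply it to `feature_maps`, a small batch of shape (1, 3, 3, 2) whose values are printed further down.

In [593]: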
pooling_layer = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(2,2))

In [594]:
yy = pooling_layer(feature_maps)
print(f'{feature_maps.shape} -> {yy.shape}')

(1, 3, 3, 2) -> (1, 1, 1, 2)


In [540]:
yy

<tf.Tensor: shape=(1, 1, 1, 2), dtype=float32, numpy=array([[[[0.9654248, 1.2988867]]]], dtype=float32)>

In [274]:
print(feature_maps)

tf.Tensor(
[[[[ 0.89900464 -1.0616233 ]
   [-0.10247962  1.2988867 ]
   [-0.8515937   0.6712394 ]]

  [[ 0.9654248   0.23607667]
   [-0.43404913  0.27193436]
   [-0.5712364  -0.73882663]]

  [[ 0.6847748  -0.8419535 ]
   [ 0.46285468  0.7197448 ]
   [-1.3353215   0.47066653]]]], shape=(1, 3, 3, 2), dtype=float32)
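In JAX, max pooling can be expressed with `jax.lax.reduce_window`, which slides a window over every axis and reduces it with `jax.lax.max`; using a window and stride of 1 on the batch and channel axes makes it act per image and per channel. A sketch using the `feature_maps` values printed above (rounded as displayed); `max_pool_nhwc` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

# The (1, 3, 3, 2) feature_maps tensor printed above.
feature_maps = jnp.array([[
    [[ 0.89900464, -1.0616233 ], [-0.10247962,  1.2988867 ], [-0.8515937 ,  0.6712394 ]],
    [[ 0.9654248 ,  0.23607667], [-0.43404913,  0.27193436], [-0.5712364 , -0.73882663]],
    [[ 0.6847748 , -0.8419535 ], [ 0.46285468,  0.7197448 ], [-1.3353215 ,  0.47066653]],
]])


def max_pool_nhwc(x, pool_size=(2, 2), strides=(2, 2)):
    """'Valid' max pooling over the H and W axes of an NHWC array."""
    return jax.lax.reduce_window(
        x,
        -jnp.inf,                           # identity element for max
        jax.lax.max,
        window_dimensions=(1, *pool_size, 1),
        window_strides=(1, *strides, 1),
        padding='VALID',
    )


pooled = max_pool_nhwc(feature_maps)
print(pooled.shape)  # (1, 1, 1, 2)
print(pooled)
```

With a 3-by-3 input, a 2-by-2 window and stride 2 fit only once, so each channel keeps a single value; these agree with the Keras `yy` shown above.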


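In [275]:
# A 2x2 max pool with strides=(1,1): the windows overlap, so each channel yields a 2x2 output
yy = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=(1,1))(feature_maps)
print(yy)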

tf.Tensor(
[[[[ 0.9654248   1.2988867 ]
   [-0.10247962  1.2988867 ]]

  [[ 0.9654248   0.7197448 ]
   [ 0.46285468  0.7197448 ]]]], shape=(1, 2, 2, 2), dtype=float32)


In [279]:
print(feature_maps[0,:,:,1])

tf.Tensor(
[[-1.0616233   1.2988867   0.6712394 ]
 [ 0.23607667  0.27193436 -0.73882663]
 [-0.8419535   0.7197448   0.47066653]], shape=(3, 3), dtype=float32)


In [280]:
yy[0,:,:,1]

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[1.2988867, 1.2988867],
       [0.7197448, 0.7197448]], dtype=float32)>
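With `strides=(1,1)` the 2-by-2 windows overlap, so one large value can win more than one window: 1.2988867 sits in the top row, middle column of channel 1 and is the maximum of both top windows. A small NumPy sketch that lists each window's top-left corner and maximum for the channel-1 values printed above; `window_maxima` is a hypothetical helper.

```python
import numpy as np

# Channel 1 of feature_maps, as printed above.
channel_1 = np.array([
    [-1.0616233 ,  1.2988867 ,  0.6712394 ],
    [ 0.23607667,  0.27193436, -0.73882663],
    [-0.8419535 ,  0.7197448 ,  0.47066653],
])


def window_maxima(x, size=2, stride=1):
    """Return {(row, col): max} for every 'valid' size x size window, keyed by its top-left corner."""
    out = {}
    for i in range(0, x.shape[0] - size + 1, stride):
        for j in range(0, x.shape[1] - size + 1, stride):
            out[(i, j)] = x[i:i + size, j:j + size].max()
    return out


for corner, value in window_maxima(channel_1).items():
    print(corner, value)

# With stride 2 only the top-left window fits in a 3x3 input.
print(window_maxima(channel_1, stride=2))
```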

In [304]:
input_batch.shape

(1, 4, 4, 3)

In [365]:
list(range(0,10,2))

[0, 2, 4, 6, 8]

In [684]:
v = np.zeros((3,3))
v[0]

array([0., 0., 0.])

In [80]:
def pool2D(x, pool_size=(2,2), strides=(1,1), fn=np.max):
    """'Valid' (no padding) pooling of a single 2D array x with the reduction fn."""
    xm, xn = x.shape
    pm, pn = pool_size
    sm, sn = strides

    # number of window positions along each axis
    ym, yn = (xm - pm) // sm + 1, (xn - pn) // sn + 1

    y = np.zeros((ym, yn))
    for ii, i in enumerate(range(0, xm - pm + 1, sm)):
        for jj, j in enumerate(range(0, xn - pn + 1, sn)):
            y[ii, jj] = fn(x[i:i+pm, j:j+pn])
    return y

In [81]:
x = np.random.randn(3,3)
b = pool2D(x, strides=(2,2))
print(x)
print(b)  # 1x1: the max of the top-left 2x2 window of x

In [82]:
def pooling(features, pool_size=(2,2), strides=(2,2), fn=np.max):
    """Pool each channel of an (H, W, C) array independently."""
    features = np.asarray(features)
    _, _, chans = features.shape

    # Note that we're not changing the number of features (channels):
    # each channel is pooled on its own and the results are stacked back together.
    pooled = [pool2D(features[:, :, chan], pool_size, strides, fn) for chan in range(chans)]
    return np.stack(pooled, axis=-1)

In [598]:
my_pooled = pooling(feature_maps[0, :, :, :])
keras_pooled = pooling_layer(feature_maps)[0]  # 2x2 windows with stride 2, as defined above
assert np.allclose(my_pooled, keras_pooled)
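The per-channel loop in `pooling` can also be replaced by a single vectorized call. A sketch, assuming NumPy 1.20 or later for `sliding_window_view`; `pool_vectorized` is a hypothetical name, and passing `np.mean` instead of `np.max` turns it into average pooling.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def pool_vectorized(features, pool_size=(2, 2), strides=(2, 2), fn=np.max):
    """'Valid' pooling of an (H, W, C) array; the channel count is unchanged."""
    pm, pn = pool_size
    sm, sn = strides
    # windows[i, j, c, a, b] == features[i + a, j + b, c], keeping only strided corners
    windows = sliding_window_view(features, (pm, pn), axis=(0, 1))[::sm, ::sn]
    # reduce each window over its two spatial axes
    return fn(windows, axis=(-2, -1))


features = np.arange(32, dtype=float).reshape(4, 4, 2)
print(pool_vectorized(features).shape)                  # (2, 2, 2)
print(pool_vectorized(features)[..., 0])                # max of each 2x2 block of channel 0
print(pool_vectorized(features, fn=np.mean)[..., 0])    # mean of each 2x2 block of channel 0
```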

In [249]:
np.stack([np.random.randn(10,10), np.random.randn(10,10)], axis=-1).shape

(10, 10, 2)