In [2]:
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Sequential
import numpy as np
import tensorflow as tf
import torch

In [72]:
def create_random_input(shape):
    return np.random.normal(size=shape)

def ref_conv2d(X, w, b, kernel_size, filters, strides, padding):
    return Conv2D(input_shape=X.shape, weights=[w, b], kernel_size=kernel_size, filters=filters, strides=strides, padding=padding)(X).numpy()
    
def check(fn, shape, kernel_size, filters, strides, padding):
    np.random.seed(3)
    X = create_random_input(shape)
    w = create_random_input((kernel_size[0], kernel_size[1], shape[3], filters))
    b = create_random_input((filters,))
    
    
    #X = np.ones(X.shape)
    #b = np.ones(b.shape)
    w = np.ones(w.shape)
    
    ref = ref_conv2d(X, w, b, kernel_size, filters, strides, padding)
    res = fn(X, w, b, kernel_size, filters, strides, padding)
    
    print(res)
    print(ref)
    
    return np.allclose(res, ref)

In [73]:
def res_conv2d(X, w, b, kernel_size, filters, strides, padding, groups = 1):
    # extract parameters
    (channels_in, kernel_rows, kernel_cols, channels_out) = w.shape
    (batch, rows_in, cols_in, channels_in_) = X.shape
    grouped_channels_out = channels_out//groups
    
    # check parameter compatibility
    assert channels_in * groups == channels_in_
    assert channels_out % groups == 0
    
    if padding == "same":
        rows_offset = kernel_rows//2
        cols_offset = kernel_cols//2
    
        # calculate output dimensions
        rows_out = (rows_in + strides[0] - 1) // strides[0]
        cols_out = (cols_in + strides[1] - 1) // strides[1]
        
    
        # create output buffer
        out = np.zeros((int(batch), int(rows_out), int(cols_out), int(groups*channels_out)))

        # prefill the output with bias
        for i in range(batch):
            for y in range(rows_out):
                for x in range(cols_out):
                    for g in range(groups):
                        gc = g * grouped_channels_out
                        for co in range(grouped_channels_out):
                            out[i, y, x, gc + co] = b[co]

        # convolute
        for g in range(groups):
            for i in range(batch):
                for co in range(grouped_channels_out):
                    for y in range(rows_out):
                        for x in range(cols_out):
                            sy = y * strides[0] - rows_offset
                            sx = x * strides[1] - rows_offset
                            for ci in range(channels_in):
                                for ky in range(sy, sy + kernel_rows):
                                    for kx in range(sx, sx + kernel_cols):
                                        if ky >= 0 and ky < rows_in and kx >= 0 and kx < cols_in:
                                            gc = g * grouped_channels_out
                                            out[i, y, x, gc + co] += X[i, ky , kx, gc + ci] * w[ci, (ky - sy), (kx - sx), co]


        #TODO: try indexing (g, b, co, ci, y, x, ky, kx)
        return out
    
    else:
        rows_out = (rows_in - (kernel_rows - strides[0])) // strides[0]
        cols_out = (cols_in - (kernel_cols - strides[1])) // strides[1]
        
    
        # create output buffer
        out = np.zeros((int(batch), int(rows_out), int(cols_out), int(groups*channels_out)))

        # prefill the output with bias
        for i in range(batch):
            for y in range(rows_out):
                for x in range(cols_out):
                    for g in range(groups):
                        gc = g * grouped_channels_out
                        for co in range(grouped_channels_out):
                            out[i, y, x, gc + co] = b[co]

        # convolute
        for g in range(groups):
            for i in range(batch):
                for co in range(grouped_channels_out):
                    for y in range(rows_out):
                        for x in range(cols_out):
                            sy = y * strides[0]
                            sx = x * strides[1]
                            for ci in range(channels_in):
                                for ky in range(sy, sy + kernel_rows):
                                    for kx in range(sx, sx + kernel_cols):
                                        gc = g * grouped_channels_out
                                        out[i, y, x, gc + co] += X[i, ky , kx, gc + ci] * w[ci, (ky - sy), (kx - sx), co]


        #TODO: try indexing (g, b, co, ci, y, x, ky, kx)
        return out

#print(check(res_conv2d, (4, 128, 128, 3), (3, 3), 5, (1, 1), "valid"))
print(check(res_conv2d, (4, 128, 128, 1), (1, 1), 1, (1, 1), "valid"))
#print(check(res_conv2d, (1, 4, 4, 1), (1, 1), 1, (1, 1), "valid"))

[[[[ 2.71607689]
   [ 1.36395826]
   [ 1.02394588]
   ...
   [ 1.21448358]
   [ 0.85000745]
   [ 1.20351691]]

  [[ 0.27903753]
   [ 0.18998358]
   [ 0.75935832]
   ...
   [ 1.19323546]
   [ 0.0153575 ]
   [ 0.77139001]]

  [[ 0.28865752]
   [ 0.2730332 ]
   [ 3.63937475]
   ...
   [ 1.18183599]
   [-0.95938851]
   [ 1.02404215]]

  ...

  [[ 0.44786289]
   [ 1.6655292 ]
   [ 2.23878455]
   ...
   [ 0.59076755]
   [ 1.78174174]
   [-0.50496107]]

  [[ 0.54481616]
   [ 1.7672007 ]
   [ 0.19824551]
   ...
   [ 0.43240879]
   [ 2.69003693]
   [-0.0678365 ]]

  [[ 0.9901682 ]
   [ 0.34041118]
   [ 0.94198364]
   ...
   [-0.11965981]
   [ 1.78009468]
   [ 1.84959651]]]


 [[[ 0.26683737]
   [ 3.24982502]
   [ 0.70195748]
   ...
   [ 1.65677624]
   [ 1.66203106]
   [ 1.77892073]]

  [[ 2.54803867]
   [-0.17601135]
   [ 1.46385325]
   ...
   [ 2.59416126]
   [ 0.0729981 ]
   [ 0.62784676]]

  [[ 1.6048177 ]
   [ 1.10351478]
   [ 1.02011333]
   ...
   [ 0.53801618]
   [ 1.82717448]
   [ 1.8853