In [None]:

import os

# Pick loguru's level before loguru is imported: DEBUG when the DEBUG
# environment variable is set to any non-empty value, INFO otherwise.
# setdefault leaves an already-exported LOGURU_LEVEL untouched.
os.environ.setdefault("LOGURU_LEVEL", "DEBUG" if os.getenv("DEBUG") else "INFO")

In [None]:

from timeit import timeit

from loguru import logger
import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy import signal as libsignal

In [None]:




def conv2d_vectorized(
    signal: np.ndarray, kernel: np.ndarray, out_h: int, out_w: int
) -> np.ndarray:
    """Cross-correlate a 2-D ``signal`` with a 2-D ``kernel`` via a strided view.

    Every kernel-sized window of ``signal`` is exposed through ``as_strided``
    (no data is copied for the view itself); the windows are then flattened and
    contracted against the flattened kernel in a single ``einsum`` call. The
    kernel is applied as-is, i.e. it is not flipped.

    Args:
        signal: 2-D input array.
        kernel: 2-D kernel array.
        out_h: Number of window positions along the rows of ``signal``.
        out_w: Number of window positions along the columns of ``signal``.

    Returns:
        Array of shape ``(out_h, out_w)`` whose ``[i, j]`` entry is the sum of
        ``signal[i:i + kernel_h, j:j + kernel_w] * kernel``.

    Raises:
        ValueError: If the requested number of windows does not fit inside
            ``signal``. ``as_strided`` performs no bounds checking, so without
            this guard the view would read memory outside the array.
    """
    kernel_h, kernel_w = kernel.shape
    max_out_h = max(signal.shape[0] - kernel_h + 1, 0)
    max_out_w = max(signal.shape[1] - kernel_w + 1, 0)
    if out_h > max_out_h or out_w > max_out_w:
        raise ValueError(
            f"output size ({out_h}, {out_w}) exceeds the maximum "
            f"({max_out_h}, {max_out_w}) for signal shape {signal.shape} "
            f"and kernel shape {kernel.shape}"
        )

    view_shape = (out_h, out_w, kernel_h, kernel_w)
    view_strides = (
        signal.strides[0],  # jump to next row in signal,
        signal.strides[1],  # jump to next col in signal,
        signal.strides[0],  # jump to next row in window,
        signal.strides[1],  # jump to next col in window,
    )
    # Read-only view: overlapping windows alias the same memory, so writes
    # through the view would be unsafe and are never needed here.
    views = as_strided(signal, view_shape, view_strides, writeable=False)
    # One row per window, one column per kernel element.
    reshaped_view = views.reshape(-1, kernel.size)
    flattened_kernel = kernel.reshape(-1)
    # Example: for the signal [[0,0,0,0],[1,2,2,1],[0,0,0,0]] and a 2x2 kernel
    # (out_h=2, out_w=3) the view has shape (2,3,2,2) and reshaped_view is
    # [[0 0 1 2]
    #  [0 0 2 2]
    #  [0 0 2 1]
    #  [1 2 0 0]
    #  [2 2 0 0]
    #  [2 1 0 0]]
    return np.einsum("ji,i->j", reshaped_view, flattened_kernel).reshape(out_h, out_w)


def conv2d_naive(
    signal: np.ndarray, kernel: np.ndarray, out_h: int, out_w: int
) -> np.ndarray:
    """Reference 2-D cross-correlation computed one output element at a time.

    Args:
        signal: 2-D input array.
        kernel: 2-D kernel array; applied as-is (not flipped).
        out_h: Number of output rows to compute.
        out_w: Number of output columns to compute.

    Returns:
        Array of shape ``(out_h, out_w)`` created with ``np.zeros`` (so it has
        numpy's default float dtype), where entry ``[i, j]`` is the sum of the
        element-wise product of ``kernel`` and the signal window whose
        top-left corner is ``(i, j)``.
    """
    win_rows, win_cols = kernel.shape
    result = np.zeros((out_h, out_w))
    # Visit every output position in row-major order.
    for top, left in np.ndindex(out_h, out_w):
        window = signal[top : top + win_rows, left : left + win_cols]
        result[top, left] = np.sum(window * kernel)
    return result


# Signal
# [[0,0,0,0],
#  [1,2,2,1],
#  [0,0,0,0]]
signal: np.ndarray = np.array([
    [0] * 4,
    [1, 2, 2, 1],
    [0] * 4,
])
# Kernel
# [[-1,-1],
#  [1,1]]
kernel: np.ndarray = np.asarray([-1, -1, 1, 1]).reshape(2, 2)

# Transpose by swapping strides
transposed_signal = as_strided(
    signal,
    (signal.shape[1], signal.shape[0]),
    strides=(signal.strides[1], signal.strides[0]),
)
logger.debug(transposed_signal)
# [[0 1 0]
#  [0 2 0]
#  [0 2 0]
#  [0 1 0]]

# Flatten kernel
flattened_kernel = kernel.reshape(-1)
logger.debug(flattened_kernel)
# [-1,-1,1,1]
# We would like to create a view of the signal which is nx4 to allow for easy
# broadcasting of the multiplication operation with the flattened_kernel.
# We do this with striding.

# Output size per axis when the kernel must fit entirely inside the signal:
# signal_dim - kernel_dim + 1.
out_h, out_w = (s - k + 1 for s, k in zip(signal.shape, kernel.shape))

conv = conv2d_vectorized(signal, kernel, out_h, out_w)
# Time with the precomputed output size so that this measurement covers the
# same work as the naive benchmark below (no shape arithmetic in the loop).
logger.info(
    "conv2d_vectorized: "
    + "{:.5e}".format(
        timeit(
            "conv2d_vectorized(signal, kernel, out_h, out_w)",
            globals=globals(),
            number=1000,
        )
    ),
)
conv_2 = conv2d_naive(signal, kernel, out_h, out_w)
logger.info(
    "conv2d_naive: "
    + "{:.5e}".format(
        timeit(
            "conv2d_naive(signal, kernel, out_h, out_w)",
            globals=globals(),
            number=1000,
        )
    ),
)
assert np.array_equal(conv, conv_2)

conv_3 = libsignal.correlate2d(signal, kernel, mode="valid")
# Benchmark the same mode="valid" call that is validated below; the default
# mode ("full") computes a larger output and is not the same computation.
logger.info(
    "signal.correlate2d: "
    + "{:.5e}".format(
        timeit(
            "libsignal.correlate2d(signal, kernel, mode='valid')",
            globals=globals(),
            number=1000,
        )
    ),
)
assert np.array_equal(conv, conv_3)


def conv2d_vectorized_batched_channelled(
    signal: np.ndarray, kernel: np.ndarray, out_h: int, out_w: int
) -> np.ndarray:
    """Cross-correlate every (batch, channel) plane of ``signal`` with a kernel.

    A 6-D strided view ``(batch, channel, out_h, out_w, kernel_h, kernel_w)``
    exposes each kernel-sized window of each plane; ``einsum`` then contracts
    the two window axes against the matching channel's kernel. Channels are
    processed independently (they are not summed together).

    Args:
        signal: 4-D array indexed as ``(batch, channel, row, col)``.
        kernel: 3-D array indexed as ``(channel, kernel_row, kernel_col)``.
        out_h: Number of window positions along the row axis.
        out_w: Number of window positions along the column axis.

    Returns:
        Array of shape ``(batch, channel, out_h, out_w)``.

    Raises:
        ValueError: If the requested number of windows does not fit inside a
            signal plane. ``as_strided`` performs no bounds checking, so
            without this guard the view would read memory outside the array.
    """
    batch, channel = signal.shape[:2]
    kernel_h, kernel_w = kernel.shape[1:]
    max_out_h = max(signal.shape[2] - kernel_h + 1, 0)
    max_out_w = max(signal.shape[3] - kernel_w + 1, 0)
    if out_h > max_out_h or out_w > max_out_w:
        raise ValueError(
            f"output size ({out_h}, {out_w}) exceeds the maximum "
            f"({max_out_h}, {max_out_w}) for signal shape {signal.shape} "
            f"and kernel shape {kernel.shape}"
        )

    view_shape = (batch, channel, out_h, out_w, kernel_h, kernel_w)
    view_strides = (
        signal.strides[0],  # jump to next batch
        signal.strides[1],  # jump to next channel
        signal.strides[2],  # jump to next row in signal,
        signal.strides[3],  # jump to next col in signal,
        signal.strides[2],  # jump to next row in window,
        signal.strides[3],  # jump to next col in window,
    )
    # Read-only view: overlapping windows alias the same memory, so writes
    # through the view would be unsafe and are never needed here.
    views = as_strided(signal, view_shape, view_strides, writeable=False)
    # b=batch, c=channel, h/w=output position, k/l=position inside the window.
    # The trailing reshape pins the result to the signal's batch/channel sizes.
    return np.einsum("bchwkl,ckl->bchw", views, kernel).reshape(
        batch, channel, out_h, out_w
    )


batches = 1000
channels = 3
batched_channelled_signal = np.random.rand(batches, channels, *signal.shape)
# Give the 2-D kernel a leading channel axis of length 1.
# NOTE(review): the signal has `channels` channels but the kernel only one;
# this relies on einsum broadcasting the length-1 channel axis - confirm.
channelled_kernel = kernel.reshape(1, *kernel.shape)
result = conv2d_vectorized_batched_channelled(
    batched_channelled_signal, channelled_kernel, out_h, out_w
)
logger.debug(result)

# Spot-check one (batch, channel) plane against the naive reference.
assert np.allclose(
    result[0, 0],
    conv2d_naive(batched_channelled_signal[0, 0], kernel, out_h, out_w),
)

# Time the batched implementation itself (the label below names it), not
# scipy's correlate2d on the small 2-D example.
tot = timeit(
    "conv2d_vectorized_batched_channelled("
    "batched_channelled_signal, channelled_kernel, out_h, out_w)",
    globals=globals(),
    number=1000,
)
logger.info("conv2d_vectorized_batched_channelled: " + "{:.5e}".format(tot))
# Total time divided by the number of 2-D planes handled per call, which makes
# it comparable with the single-plane timings (also totals over 1000 runs).
logger.info(f"Per batch/channel: {tot / (batches * channels):.5e}")
