In [1]:
import torch
from torch import nn

In [2]:
def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with kernel ``K``.

    "Convolution" in deep learning is a misnomer: what layers such as
    ``nn.Conv2d`` compute is cross-correlation (the kernel is not flipped).
    See https://d2l.ai/chapter_convolutional-neural-networks/conv-layer.html.
    This is equivalent to ``nn.Conv2d`` with
    ``in_channels == out_channels == 1``, ``stride == 1``, no padding and
    no bias.

    Args:
        X: 2-D input tensor of shape (h_X, w_X).
        K: 2-D kernel tensor of shape (h_K, w_K).

    Returns:
        Tensor of shape (h_X - h_K + 1, w_X - w_K + 1) on ``X``'s device.
        For floating inputs its dtype is the promoted dtype of ``X`` and
        ``K``; otherwise it is the default floating dtype.
    """
    assert X.dim() == 2 and K.dim() == 2

    h, w = K.shape
    # Keep e.g. float64 inputs in float64 instead of silently downcasting to
    # the default dtype; integer inputs still produce a floating result.
    dtype = torch.result_type(X, K)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    Y = torch.zeros((X.shape[0] - h + 1, X.shape[1] - w + 1),
                    dtype=dtype, device=X.device)
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Multiply the window currently under K elementwise, then sum.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()

    return Y

In [3]:
def get_sparse_kernel_matrix(K, h_X, w_X):
    """Unroll kernel ``K`` into the matrix form of a stride-1 cross-correlation.

    Returns ``W`` of shape (h_Y * w_Y, h_X * w_X), where
    ``h_Y = h_X - h_K + 1`` and ``w_Y = w_X - w_K + 1``. For an input ``X``
    of shape (h_X, w_X), ``W @ X.reshape(-1)`` is the row-major flattening of
    the (h_Y, w_Y) cross-correlation of ``X`` with ``K``. It assumes a single
    channel, stride 1 and no padding. ``W`` is dense but mostly zeros:
    each row holds exactly one copy of ``K``.

    Args:
        K: 2-D kernel tensor of shape (h_K, w_K).
        h_X: height of the input the kernel will be applied to.
        w_X: width of the input the kernel will be applied to.

    Returns:
        Tensor with ``K``'s dtype and device, so it can be multiplied with
        inputs of that dtype without a dtype-mismatch error.
    """
    h_K, w_K = K.shape

    h_Y, w_Y = h_X - h_K + 1, w_X - w_K + 1

    W = torch.zeros((h_Y * w_Y, h_X * w_X), dtype=K.dtype, device=K.device)
    # View row (i * w_Y + j) of W as an (h_X, w_X) image. The output pixel
    # (i, j) reads the input window whose top-left corner is (i, j), so K is
    # written into that window. The view shares storage with W.
    W_images = W.view(h_Y, w_Y, h_X, w_X)
    for i in range(h_Y):
        for j in range(w_Y):
            W_images[i, j, i:i + h_K, j:j + w_K] = K

    return W

In [4]:
def conv2d_as_matrix_mul(X, K):
    """Cross-correlate ``X`` with ``K`` via one matrix-vector product.

    Assumes a single channel, stride 1 and no padding. The kernel is unrolled
    into a mostly-zero matrix with one row per output pixel and one column
    per input pixel. The flattened input is multiplied by that matrix, and
    the product is reshaped to (h_X - h_K + 1, w_X - w_K + 1).
    """
    (h_K, w_K), (h_X, w_X) = K.shape, X.shape
    out_rows = h_X - h_K + 1
    out_cols = w_X - w_K + 1

    unrolled = get_sparse_kernel_matrix(K=K, h_X=h_X, w_X=w_X)
    flat_input = X.reshape(-1)
    flat_output = torch.matmul(unrolled, flat_input)
    return flat_output.reshape(out_rows, out_cols)

In [5]:
def conv_transposed_2d_as_matrix_mul(X, K):
    """Transposed 2-D convolution of ``X`` with ``K`` via a matrix product.

    Assumes a single channel, stride 1 and no padding. The output has shape
    (h_X + h_K - 1, w_X + w_K - 1). The function builds the unrolled matrix
    of the *forward* cross-correlation that would map an output-sized image
    down to ``X``'s size. It then multiplies the flattened ``X`` by the
    transpose of that matrix.
    """
    (h_K, w_K), (h_X, w_X) = K.shape, X.shape
    out_rows = h_X + h_K - 1
    out_cols = w_X + w_K - 1

    # The kernel is unrolled as if it were applied to the (larger) output.
    forward = get_sparse_kernel_matrix(K=K, h_X=out_rows, w_X=out_cols)
    flat_output = torch.matmul(forward.transpose(0, 1), X.reshape(-1))
    return flat_output.reshape(out_rows, out_cols)

In [6]:
# Small integer-valued toy input and kernel, stored as float32.
X = torch.arange(30, dtype=torch.float32).reshape(5, 6)
K = torch.arange(8, dtype=torch.float32).reshape(2, 4)

print("X:", X, sep="\n")
print("K:", K, sep="\n")
print("Cross-Correlation:")
Y = corr2d(X=X, K=K)
print(Y)

X:
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.],
        [12., 13., 14., 15., 16., 17.],
        [18., 19., 20., 21., 22., 23.],
        [24., 25., 26., 27., 28., 29.]])
K:
tensor([[0., 1., 2., 3.],
        [4., 5., 6., 7.]])
Cross-Correlation:
tensor([[184., 212., 240.],
        [352., 380., 408.],
        [520., 548., 576.],
        [688., 716., 744.]])


In [7]:
# Reference check: nn.Conv2d (1 channel in/out, stride 1, no padding, no bias)
# computes the same cross-correlation as corr2d.
conv = nn.Conv2d(in_channels=1,
                 out_channels=1,
                 kernel_size=K.shape,
                 padding=0,
                 stride=1,
                 bias=False)
with torch.no_grad():
    # Copy in place instead of rebinding `.data`: this keeps the Parameter
    # object and checks that K fits the (out_ch, in_ch, h_K, w_K) weight shape.
    conv.weight.copy_(K[None, None])
    # Add batch and channel axes for the layer, then strip them again.
    Z1 = conv(X[None, None])[0, 0]
print("Convolution:")
print(Z1)
# Tolerant comparison: the two implementations sum in different orders, so
# exact equality only holds for "nice" inputs like these integer values.
torch.testing.assert_close(Z1, Y)

Convolution:
tensor([[184., 212., 240.],
        [352., 380., 408.],
        [520., 548., 576.],
        [688., 716., 744.]])


In [8]:
print("Convolution as Matrix Multiplication:")
Z2 = conv2d_as_matrix_mul(X=X, K=K)
print(Z2)
# Tolerant comparison: matmul and the explicit loop accumulate in different
# orders, so bit-exact equality is not guaranteed for general float inputs.
torch.testing.assert_close(Z2, Y)

Convolution as Matrix Multiplication:
tensor([[184., 212., 240.],
        [352., 380., 408.],
        [520., 548., 576.],
        [688., 716., 744.]])


In [9]:
# Apply a transposed convolution with the same kernel K to the
# cross-correlation result Y.
conv_transposed = nn.ConvTranspose2d(in_channels=1,
                                     out_channels=1,
                                     kernel_size=K.shape,
                                     padding=0,
                                     stride=1,
                                     bias=False)
with torch.no_grad():
    # Copy in place instead of rebinding `.data`: this keeps the Parameter
    # object and checks that K fits the (in_ch, out_ch, h_K, w_K) weight shape.
    conv_transposed.weight.copy_(K[None, None])
    # Add batch and channel axes for the layer, then strip them again.
    Z3 = conv_transposed(Y[None, None])[0, 0]
print("Transposed Convolution:")
print(Z3)
# With stride 1 and no padding, the transposed convolution restores X's
# spatial *shape*. It is not an inverse: the values of X are not recovered.
assert Z3.shape == X.shape

Transposed Convolution:
tensor([[    0.,   184.,   580.,  1216.,  1116.,   720.],
        [  736.,  2120.,  4208.,  5984.,  4880.,  2904.],
        [ 1408.,  3800.,  7232., 10016.,  7904.,  4584.],
        [ 2080.,  5480., 10256., 14048., 10928.,  6264.],
        [ 2752.,  6304., 10684., 12832.,  9476.,  5208.]])


In [10]:
print("Transposed Convolution as Matrix Multiplication:")
Z4 = conv_transposed_2d_as_matrix_mul(X=Y, K=K)
print(Z4)
# Tolerant comparison: the nn.ConvTranspose2d kernel and the explicit matmul
# may accumulate in different orders, so bit-exact equality is not guaranteed.
torch.testing.assert_close(Z4, Z3)
assert Z4.shape == X.shape

Transposed Convolution as Matrix Multiplication:
tensor([[    0.,   184.,   580.,  1216.,  1116.,   720.],
        [  736.,  2120.,  4208.,  5984.,  4880.,  2904.],
        [ 1408.,  3800.,  7232., 10016.,  7904.,  4584.],
        [ 2080.,  5480., 10256., 14048., 10928.,  6264.],
        [ 2752.,  6304., 10684., 12832.,  9476.,  5208.]])
