In [4]:
import torch

$$
X = \begin{bmatrix}
0 & 1 \\
2 & 3
\end{bmatrix}
$$

$$
K = \begin{bmatrix}
0 & 1 \\
2 & 3
\end{bmatrix}
$$

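Transposed convolution (stride 1) of $X$ with $K$: each entry $X_{ij}$ scales the whole kernel $K$, the scaled copy is placed with its top-left corner at position $(i, j)$ of a $3 \times 3$ output, and overlapping copies are summed. (This is not the matrix product $XK$.)

$$Y = 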
\begin{bmatrix}
0 & 0 & \\
0 & 0 & \\
  &   & 
\end{bmatrix}
+
\begin{bmatrix}
& 0 & 1 \\
& 2 & 3 \\
& & 
\end{bmatrix}
+
\begin{bmatrix}
& & \\
0 & 2 & \\
4 & 6 &
\end{bmatrix}
+
\begin{bmatrix}
& & \\
& 0 & 3 \\
& 6 & 9
\end{bmatrix}
=
\begin{bmatrix}
0 & 0 & 1 \\
0 & 4 & 6 \\
4 & 12 & 9
\end{bmatrix}
$$

In [3]:
def transpose_conv(X, K, stride=1):
    """Naive 2-D transposed convolution of a single-channel input.

    Every input entry ``X[i, j]`` scales the whole kernel ``K``; the scaled copy
    is added into the output with its top-left corner at
    ``(i * stride, j * stride)``. Overlapping copies are summed.

    Args:
        X: 2-D input tensor of shape ``(xh, xw)``.
        K: 2-D kernel tensor of shape ``(kh, kw)``.
        stride: Step between consecutive kernel placements. The default of 1
            is the classic "full" transposed convolution; ``stride=2``
            reproduces ``torch.nn.ConvTranspose2d(..., stride=2)``.

    Returns:
        Tensor of shape ``((xh - 1) * stride + kh, (xw - 1) * stride + kw)``,
        i.e. ``(xh + kh - 1, xw + kw - 1)`` for ``stride=1``.

    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")
    kh, kw = K.shape
    xh, xw = X.shape
    # Accumulate in a floating dtype wide enough for both inputs (float32 for
    # float32/integer inputs as before, float64 if either input is float64)
    # and on X's device, so non-CPU inputs don't hit a device mismatch.
    dtype = torch.promote_types(torch.result_type(X, K), torch.get_default_dtype())
    Y = torch.zeros(((xh - 1) * stride + kh, (xw - 1) * stride + kw),
                    dtype=dtype, device=X.device)
    for i in range(xh):
        for j in range(xw):
            r, c = i * stride, j * stride
            Y[r:r + kh, c:c + kw] += X[i, j] * K
    return Y


# Same 2x2 input and kernel as in the worked example above.
X = torch.tensor([[0., 1.],
                  [2., 3.]])
K = X.clone()
transpose_conv(X, K)

tensor([[ 0.,  0.,  1.],
        [ 0.,  4.,  6.],
        [ 4., 12.,  9.]])

In [44]:
X = torch.tensor([[0., 1.], [2., 3.]])
K = torch.tensor([[0., 1.], [2., 3.]])
# nn layers take inputs as (N, C, H, W); ConvTranspose2d stores its weight as
# (in_channels, out_channels, kH, kW). With one channel each, both are (1, 1, 2, 2).
X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
# stride=2 places the scaled copies of K two cells apart, so with a 2x2 kernel
# they no longer overlap (compare the stride-1 result of transpose_conv above).
tconv = torch.nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
# Inject K by copying it into the existing parameter; rebinding `.data` would
# make the layer's weight share storage with K.
with torch.no_grad():
    tconv.weight.copy_(K)
tconv(X)

tensor([[[[0., 0., 0., 1.],
          [0., 0., 2., 3.],
          [0., 2., 0., 3.],
          [4., 6., 6., 9.]]]], grad_fn=<ConvolutionBackward0>)

## Multi-Channels

Let $X$ be an input with 2 channels.

$$
X = [X_1, X_2] = \begin{bmatrix}
\begin{bmatrix}
0 & 1 \\ 2 & 3
\end{bmatrix}
&
\begin{bmatrix}
0 & 2 \\ 4 & 6
\end{bmatrix}
\end{bmatrix}
$$

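Let $K$ be a kernel with 2 input channels and 1 output channel: $K_1$ is applied to $X_1$ and $K_2$ to $X_2$, and the two results are summed.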

$$
K = [K_1, K_2] = \begin{bmatrix}
\begin{bmatrix}
1 & 1 \\ 1 & 1
\end{bmatrix}
&
\begin{bmatrix}
0.5 & 0.5 \\ 0.5 & 0.5
\end{bmatrix}
\end{bmatrix}
$$

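The top-left entries of $X_1$ and $X_2$ are both 0, so the top-left $2 \times 2$ block of the output is all zeros whatever the kernel. The top-right entries are 1 (in $X_1$) and 2 (in $X_2$). With stride 2 each input position produces its own non-overlapping $2 \times 2$ block, so we expect the top-right block of the output to be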

$$
1 * \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} + 2 * \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} = 
\begin{bmatrix} 2 & 2 \\ 2 & 2 \end{bmatrix} 
$$

Repeating the same logic for the bottom-left entries (2 and 4) and the bottom-right entries (3 and 6) gives:

$$
2 * \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} + 4 * \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} = 
\begin{bmatrix} 4 & 4 \\ 4 & 4 \end{bmatrix} 
$$

$$
3 * \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix} + 6 * \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} = 
\begin{bmatrix} 6 & 6 \\ 6 & 6 \end{bmatrix} 
$$

In [125]:
X1 = torch.tensor([[0., 1.], [2., 3.]])
X2 = torch.tensor([[0., 2.], [4., 6.]])
X = torch.stack([X1, X2])

# Kernel with 2 input channels: one 2x2 slice per input channel.
K1 = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
K2 = torch.tensor([[0.5, 0.5], [0.5, 0.5]])
K = torch.stack([K1, K2])

# Input layout is (N, C_in, H, W) = (1, 2, 2, 2). ConvTranspose2d's weight
# layout is (C_in, C_out, kH, kW) = (2, 1, 2, 2): each input channel is spread
# by its own kernel slice and the results are summed into the one output channel.
X, K = X.reshape(1, 2, 2, 2), K.reshape(2, 1, 2, 2)
tconv = torch.nn.ConvTranspose2d(2, 1, kernel_size=2, stride=2, bias=False)
# Inject K by copying into the parameter rather than aliasing it via `.data =`.
with torch.no_grad():
    tconv.weight.copy_(K)
tconv(X)

tensor([[[[0., 0., 2., 2.],
          [0., 0., 2., 2.],
          [4., 4., 6., 6.],
          [4., 4., 6., 6.]]]], grad_fn=<ConvolutionBackward0>)

In [53]:
# ConvTranspose2d keeps its weight as (in_channels, out_channels / groups, kH, kW),
# so 2 input channels -> 1 output channel gives a (2, 1, 2, 2) weight.
tconv = torch.nn.ConvTranspose2d(in_channels=2, out_channels=1, kernel_size=2, stride=2, bias=False)
print(tconv.weight.shape)

torch.Size([2, 1, 2, 2])


## Connection to Matrix Transposition

In [104]:
X = torch.arange(9.0).view(3, 3)        # 3x3 input with entries 0..8, row-major
K = torch.arange(1.0, 5.0).view(2, 2)   # 2x2 kernel [[1, 2], [3, 4]]

In [112]:
def kernel2matrix(K, input_shape=(3, 3)):
    """Build the matrix ``W`` such that ``W @ X.reshape(-1)`` equals the valid
    (unpadded, stride-1) cross-correlation of ``X`` with ``K``, flattened.

    Row ``i * ow + j`` of ``W`` holds ``K`` laid over the row-major flattening
    of the input at offset ``(i, j)``, so its dot product with
    ``X.reshape(-1)`` is output entry ``(i, j)``. ``W.T`` then performs the
    corresponding transposed convolution.

    Args:
        K: 2-D kernel of shape ``(kh, kw)``.
        input_shape: ``(xh, xw)`` of the input ``W`` will be applied to.
            Defaults to ``(3, 3)``, the case this function originally hard-coded.

    Returns:
        Tensor of shape ``(oh * ow, xh * xw)`` where ``oh = xh - kh + 1`` and
        ``ow = xw - kw + 1``.

    Raises:
        ValueError: if the kernel is larger than the input in either dimension.
    """
    kh, kw = K.shape
    xh, xw = input_shape
    oh, ow = xh - kh + 1, xw - kw + 1
    if oh < 1 or ow < 1:
        raise ValueError(
            f"kernel of shape {tuple(K.shape)} does not fit in input of shape {tuple(input_shape)}"
        )
    # Floating dtype as before (float32 for float32/integer kernels), widened
    # to float64 for float64 kernels instead of silently downcasting.
    dtype = torch.promote_types(K.dtype, torch.get_default_dtype())
    W = torch.zeros((oh * ow, xh * xw), dtype=dtype, device=K.device)
    for i in range(oh):
        for j in range(ow):
            # View this row as an (xh, xw) grid and drop K in at offset (i, j);
            # the assignment writes through the view into W.
            W[i * ow + j].view(xh, xw)[i:i + kh, j:j + kw] = K
    return W

# W has one row per entry of the 2x2 convolution output and one column per
# entry of the row-major flattened 3x3 X (shape 4 x 9).
W = kernel2matrix(K)
W

tensor([[1., 2., 0., 3., 4., 0., 0., 0., 0.],
        [0., 1., 2., 0., 3., 4., 0., 0., 0.],
        [0., 0., 0., 1., 2., 0., 3., 4., 0.],
        [0., 0., 0., 0., 1., 2., 0., 3., 4.]])

In [111]:
# Display X again: the 9 columns of W correspond to X's entries in row-major order.
X

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

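Perform the convolution as a matrix–vector product: flatten $X$ row by row into a 9-vector, multiply by $W$, and reshape the resulting 4-vector to $2 \times 2$.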

In [115]:
# Convolution as a matrix-vector product: W (4 x 9) times the row-major
# flattening of X (9 entries) gives the 4 output entries, reshaped to 2 x 2.
Y = (W @ X.reshape(-1)).reshape((2, 2))
# Sanity check against PyTorch's valid (unpadded, stride-1) conv2d, which, like
# kernel2matrix, is a cross-correlation (the kernel is not flipped).
assert torch.allclose(Y, torch.nn.functional.conv2d(X[None, None], K[None, None])[0, 0])
Y

tensor([[27., 37.],
        [57., 67.]])

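Multiplying by $W^\top$ maps the 4-vector $Y$ back to 9 entries, i.e. back to the $3 \times 3$ *shape* of $X$; this is exactly the transposed convolution of $Y$ with $K$. It does **not** restore the original *values*: $W$ is a $4 \times 9$ matrix and has no inverse, so in general $W^\top Y \neq X$, as the result below shows.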

In [118]:
# W.T maps the 4 entries of Y back to 9 entries, i.e. back to X's 3x3 shape.
Y_transposed = (W.T @ Y.reshape(-1)).reshape(3, 3)
# This is exactly the transposed convolution of Y with K ...
assert torch.allclose(
    Y_transposed,
    torch.nn.functional.conv_transpose2d(Y[None, None], K[None, None])[0, 0],
)
# ... and since W (4 x 9) has no inverse, it does not recover the values of X.
assert not torch.allclose(Y_transposed, X)
Y_transposed

tensor([[ 27.,  91.,  74.],
        [138., 400., 282.],
        [171., 429., 268.]])

## as_strided

$$x = \begin{bmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9
\end{bmatrix}
$$

In [72]:
# The 3x3 matrix shown above, stored contiguously in row-major order.
x = torch.tensor([
    [1., 2., 3.],
    [4., 5., 6.],
    [7., 8., 9.],
])
N, D = x.size()  # N rows, D columns; for contiguous x, x.stride() == (D, 1)

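`x` is stored as the contiguous buffer `[1, 2, 3, 4, 5, 6, 7, 8, 9]` with strides `(3, 1)`: stepping to the next row skips 3 buffer elements (the row length $D$), and stepping to the next column skips 1. A view made by `as_strided(x, size, stride, offset)` therefore reads `out[i][j] = buffer[offset + stride[0]*i + stride[1]*j]`.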

In [85]:
# Row stride must be the row length D (columns), not N (rows); the two only
# coincide here because x is square.
torch.as_strided(x, (2, 2), (D, 1))

tensor([[1., 2.],
        [4., 5.]])

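Use the storage offset to slide the $2 \times 2$ window: offset 1 moves it one column right, offset $D = 3$ moves it one row down, and offset $D + 1 = 4$ moves it down and right.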

In [65]:
torch.as_strided(x, (2, 2), (D, 1), 1)  # offset 1 = row 0, column 1; row stride is D

tensor([[2., 3.],
        [5., 6.]])

In [83]:
torch.as_strided(x, (2, 2), (D, 1), D)  # offset D = row 1, column 0

tensor([[4., 5.],
        [7., 8.]])

In [71]:
torch.as_strided(x, (2, 2), (D, 1), D + 1)  # offset D + 1 = row 1, column 1

tensor([[5., 6.],
        [8., 9.]])