# Convolution Demo

In [None]:
import itertools
import torch
import matplotlib.pyplot as plt

# Print tensors in fixed-point notation with three decimal places so that the
# small lattices below are easy to compare by eye.
torch.set_printoptions(sci_mode=False, precision=3)

## Definition of convolution

Let $\phi(x)$ and $w(x)$ be real-valued scalar functions defined on $\mathbb{R}^2$ which decay to zero as $|x| \to \infty$.
Their convolution is another function defined by the integral transform,

\begin{align}
    (w \ast \phi)(x) = \int_{\mathbb{R}^2} \mathrm{d} y \, w(y) \phi(x - y) \, .
\end{align}

We will take the view that $\phi(x)$ is the function being transformed, and $w(y)$ is the **kernel** of the transform.
The discretized version of this operation is

\begin{align}
    (w \ast \phi)(n) = \sum_{m \in \mathbb{Z}^2} w(m) \phi(n - m) \, .
\end{align}

We are interested in the situation where the domain is a square lattice $\Lambda \subset a \mathbb{N}^2$, which is periodic in both dimensions with period $aL$, where $a$ is the lattice spacing.
On this domain the convolution can be written

\begin{align}
    (w \ast \phi)(an) = \sum_{m_1, m_2=0}^{L-1} w(am) \phi\big(a(n-m)_{\text{mod} L}\big) \, .
\end{align}

Let us set the lattice spacing to $a = 1$ so that $x \equiv n$ and $y \equiv m$.

\begin{align} \label{eq:conv_layer}
    (w \ast \phi)(x) = \sum_{y_1, y_2=0}^{L-1} w(y) \phi\big((x-y)_{\text{mod} L}\big) \, .
\end{align}

As it happens, the conventional convolutional neural network performs a cross-correlation '$\star$' rather than a convolution '$\ast$'.
For real-valued functions this differs by nothing more than a reflection $y \to -y$ in the argument of $\phi$ (equivalently, a reflection of the kernel);

\begin{align}
    (w \star \phi)(x) = \sum_{y_1, y_2=0}^{L-1} w(y) \phi\big((x+y)_{\text{mod} L}\big) \, .
\end{align}

Usually we want a convolutional layer to encode some local structure in the data.
This leads to kernels that are nonzero only within some radius $K$ of $y=(0, 0)$.
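For example, on a square lattice, a kernel with radius $K=1$ is represented by the following matrix, where the index label $L$ is shorthand for $-1 \equiv L-1 \pmod{L}$, i.e. the site one step behind the origin on the periodic lattice: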

\begin{align}
w(y) = \begin{pmatrix}
\ddots \\
& 0 & 0 & 0 & 0 & 0 \\
& 0 & 0 & w(L, 0) & 0 & 0 \\
& 0 & w(0, L) & w(0, 0) & w(0, 1) & 0 \\
& 0 & 0 & w(1, 0) & 0 & 0 \\
& 0 & 0 & 0 & 0 & 0 & \\
&&&&&&\ddots
\end{pmatrix}
\end{align}

Discrete convolution or cross-correlation can be expressed as a matrix-vector product

$$
w \star \phi \equiv W \cdot \Phi
$$

where the vector $\Phi$ is the flattened $\phi$ and the matrix $W$ is a representation of $w$ as a **Toeplitz matrix** (a matrix with constant diagonals).
For brevity, let $w(y_1, y_2)$ be written $w_{y_1y_2}$.
The $K=1$ kernel described above now becomes

\begin{align}
    W = \begin{pmatrix}
    w_{00} & w_{01} & 0 & \ldots & 0 & w_{0L} & w_{10} & 0 & \ldots & 0 & w_{LL} \\
    w_{LL} & w_{00} & w_{01} & 0 & \ldots & 0 & w_{0L} & w_{10} &  0 & \ldots & 0\\
    0 & w_{LL} & w_{00} & w_{01} & 0 & \ldots & 0 & w_{0L} & w_{10} & 0 & \ldots \\
    & & & & & \ddots\\
    w_{01} & 0 & \ldots & 0 & w_{0L} & w_{10} & 0 & \ldots & 0 & w_{LL} & w_{00}
    \end{pmatrix}
\end{align}


## PyTorch implementation: `conv2d`

In [None]:
def conv2d_inputs(phi, w):
    """Reshape and periodically pad a 2-d field and kernel for ``conv2d``.

    ``torch.nn.functional.conv2d`` expects 4-d tensors:

    - input (phi) : (batch size, input channels, height, width)
    - weight (w)  : (output channels, input channels, kernel height, kernel width)

    Both arguments are therefore given two leading singleton dimensions.
    Because the lattice is periodic, ``phi`` is also wrapped around
    (``mode="circular"``) by the kernel radius along each axis. For odd
    kernel sizes the un-padded ``conv2d`` output then has the same spatial
    shape as the original ``phi``.

    Args:
        phi: 2-d tensor holding the field, indexed ``phi[x1, x2]``.
        w: 2-d tensor holding the kernel, indexed ``w[y1, y2]``; its central
            element is the ``(0, 0)`` tap.

    Returns:
        Tuple ``(phi, w)`` where ``phi`` has shape
        ``(1, 1, phi.shape[0] + 2 * K1, phi.shape[1] + 2 * K2)`` and ``w`` has
        shape ``(1, 1, *w.shape)``, with ``K1, K2`` the kernel radii along the
        first and second axes.
    """
    assert phi.dim() == 2
    assert w.dim() == 2
    # Kernel radius along the first (row) and second (column) axes.
    K1, K2 = [(k - 1) // 2 for k in w.shape]
    phi = phi.view(1, 1, *phi.shape)
    w = w.view(1, 1, *w.shape)
    phi = torch.nn.functional.pad(
        phi,  # (n_batch, n_channels, height, width)
        # F.pad lists (before, after) pairs starting from the LAST dimension,
        # so the column radius K2 comes first and the row radius K1 second.
        pad=(K2, K2, K1, K1),
        mode="circular",
    )
    return phi, w

### Identity transformation

Note that `conv2d` treats the matrix `w` as though it is centered on $(0, 0)$.
Hence, the identity kernel should have a one in the middle position.

In [None]:
# Identity check: a kernel whose only non-zero tap is the central one should
# hand back the input field unchanged.
L = 6
phi = torch.empty(L, L).normal_()
w = torch.tensor(
    [
        [0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0],
    ]
)
padded_phi, kernel = conv2d_inputs(phi, w)
conv_out = torch.nn.functional.conv2d(padded_phi, kernel)
assert torch.allclose(conv_out, phi)

### Shift

In [None]:
# Shift check: a kernel with a single 1 one column to the right of centre
# should shift the field by one site along the second axis.
L = 4
phi = torch.arange(L * L, dtype=torch.float32).reshape(L, L)
w = torch.tensor(
    [
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0],
    ]
)
padded_phi, kernel = conv2d_inputs(phi, w)
conv_out = torch.nn.functional.conv2d(padded_phi, kernel)

# Confirm the sign convention of torch.roll: a shift of -1 brings the element
# at index 1 to index 0.
rolled = phi.roll(shifts=(-1, -1), dims=(0, 1))
assert rolled[0, 0] == phi[1, 1]

assert torch.allclose(conv_out, phi.roll(-1, dims=1))
print(phi.int().numpy())
print(conv_out.squeeze().int().numpy())

### Verifying that `conv2d` implements cross-correlation

We now check that conv2d implements the cross-correlation calculation defined earlier in this notebook. This requires us to modify the kernel so that it looks like a full L1 x L2 matrix with the (0, 0) element in the top-left corner.
    
For example, conv2d interprets the following tensor as having 'b' in the (0, 0) position, 'a' in the (-1, -1) position (which wraps around to (L-1, L-1) on the periodic lattice) and 'c' in the (1, 1) position:
    
```python
torch.Tensor([
    [a, 0, 0],
    [0, b, 0],
    [0, 0, c]
])
```
However, to match our cross-correlation calculation we require an L1 x L2 tensor that looks like the following:
    
```python
torch.Tensor([
    [b, 0, ..., 0],
    [0, c, ..., 0],
         ...
    [0, 0, ..., a]
])
```
    
Therefore, we need to:
1. Pad the tensor with zeros to make it the same shape as the input
2. Roll the kernel so that the (0, 0) element is in the top-left corner

In [None]:
def pad_and_roll_kernel(w, L1, L2):
    """Embed a centred kernel in a full ``L1 x L2`` periodic lattice.

    ``conv2d`` treats the central element of ``w`` as the ``(0, 0)`` tap. To
    compare against the explicit cross-correlation sum, the kernel must
    instead be an ``L1 x L2`` tensor whose ``[0, 0]`` element is the
    ``(0, 0)`` tap, with negative offsets wrapped around to the far edge.

    For example, conv2d interprets the following tensor as having 'b' at
    offset (0, 0), 'a' at offset (-1, -1) and 'c' at offset (1, 1)::

        [[a, 0, 0],
         [0, b, 0],
         [0, 0, c]]

    whereas the cross-correlation sum needs the ``L1 x L2`` tensor::

        [[b, 0, ..., 0],
         [0, c, ..., 0],
         ...
         [0, 0, ..., a]]

    This is obtained by

    1) zero-padding the kernel at the bottom/right up to shape ``(L1, L2)``;
    2) rolling it by minus the kernel radius along each axis so that the
       central tap lands in the top-left corner.

    Args:
        w: 2-d kernel tensor of shape ``(k1, k2)``.
        L1: Lattice extent along the first axis (rows).
        L2: Lattice extent along the second axis (columns).

    Returns:
        Tensor of shape ``(L1, L2)`` holding the padded, rolled kernel.

    Raises:
        ValueError: If the kernel is larger than the lattice along either
            axis (negative padding would otherwise silently crop it).
    """
    k1, k2 = w.shape
    if k1 > L1 or k2 > L2:
        raise ValueError(
            f"kernel of shape ({k1}, {k2}) does not fit in a ({L1}, {L2}) lattice"
        )
    K1, K2 = (k1 - 1) // 2, (k2 - 1) // 2

    # Pad with zeros to make it the same size as phi. F.pad lists
    # (before, after) pairs starting from the LAST dimension, so the column
    # deficit comes first and the row deficit second.
    w = torch.nn.functional.pad(
        w, (0, L2 - k2, 0, L1 - k1), mode="constant", value=0
    )

    # Roll the kernel so that its central tap moves to index [0, 0].
    w = w.roll((-K1, -K2), (0, 1))

    return w

# Sanity check of pad_and_roll_kernel on a 6 x 6 lattice: the central tap (2)
# must end up top-left, the (1, 1) tap (3) just below/right of it, and the
# (-1, -1) tap (1) wrapped round to the bottom-right corner.
L1, L2 = 6, 6
w = torch.tensor(
    [
        [1, 0, 0],
        [0, 2, 0],
        [0, 0, 3],
    ],
    dtype=torch.int32,
)
expected = torch.tensor(
    [
        [2, 0, 0, 0, 0, 0],
        [0, 3, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1],
    ],
    dtype=torch.int32,
)
result = pad_and_roll_kernel(w, L1, L2)
assert torch.equal(result, expected)

In [None]:
# Compare conv2d against the cross-correlation sum on a random field/kernel.
L = 4
phi = torch.empty(L, L).normal_()
w = torch.empty(L - 1, L - 1).uniform_()

conv_out = torch.nn.functional.conv2d(*conv2d_inputs(phi, w))

# From here on w is the full L x L kernel with its (0, 0) tap at index [0, 0].
w = pad_and_roll_kernel(w, L, L)

# Vectorised form: one rolled copy of phi per kernel offset, weighted and summed.
offsets = itertools.product(range(L), repeat=2)
terms = [w[y1, y2] * phi.roll((-y1, -y2), (0, 1)) for y1, y2 in offsets]
cross_corr = torch.stack(terms, dim=0).sum(dim=0)
assert torch.allclose(cross_corr, conv_out)

# To be explicit: the same sum written out site by site.
cross_corr_explicit = torch.zeros_like(phi)
for x1, x2, y1, y2 in itertools.product(range(L), repeat=4):
    cross_corr_explicit[x1, x2] += w[y1, y2] * phi[(x1 + y1) % L, (x2 + y2) % L]

assert torch.allclose(cross_corr_explicit, conv_out)