## How to construct the Toeplitz matrix?
https://www.baeldung.com/cs/convolution-matrix-multiplication

## How to compute the actual convolution?


1.   Matrix multiplication along the N*N dimension (the Toeplitz matrix T times the flattened input)
2.   Summation over the input channels
3.   Repeat the operation for each output channel, and concatenate the results



## Putting it together:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import linalg

In [None]:
# equivalent to scipy.linalg.toeplitz for stride=1; larger strides skip rows of that matrix
def toeplitz_perfilter(row, N, K, stride):
    """Build the banded (Toeplitz-like) matrix for a single kernel row.

    Row ``i`` of the result contains ``row`` starting at column
    ``i * stride`` and zeros everywhere else, so multiplying the result with
    a length-``N`` signal slides ``row`` over that signal with the given
    stride.

    Args:
        row: 1D tensor holding the values of one kernel row.
        N: length of the input signal (padding already included).
        K: kernel size, used to compute the number of output positions.
        stride: step between two consecutive output positions.

    Returns:
        Tensor of shape ``(O, N)`` with
        ``O = floor((N - (K - 1) - 1) / stride + 1)``, created with the same
        dtype and device as ``row``.
    """
    # N is already padded
    O = int(np.floor(((N - (K - 1) - 1) / stride) + 1))
    width = row.size(0)

    # Start from zeros and write the taps in place; this avoids building
    # O rolled copies of the matrix and masking the wrapped entries.
    output = row.new_zeros((O, N))
    for i in range(O):
        start = i * stride
        # Taps that would run past the last column are dropped
        stop = min(start + width, N)
        output[i, start:stop] = row[: stop - start]
    return output


def toeplitz_perchannel(kernel, input_size, stride=1):
    """Build the doubly blocked Toeplitz matrix of one 2D kernel.

    Multiplying the result with the row-major flattened (already padded)
    input gives the row-major flattened output of the strided sliding-window
    product of ``kernel`` with that input.

    Args:
        kernel: 2D tensor; its first dimension is used as the kernel size K.
        input_size: size of the padded input; only ``input_size[0]`` is read
            and used as N for both spatial dimensions.
        stride: step between two consecutive output positions.

    Returns:
        Tensor of shape ``(O**2, N**2)`` with
        ``O = floor((N - (K - 1) - 1) / stride + 1)``, created with the same
        dtype and device as ``kernel``.
    """
    # shapes
    K = kernel.shape[0]
    N = input_size[0]
    O = int(np.floor(((N - (K - 1) - 1) / stride) + 1))

    # Indexed as (output row, output column, input row, input column).
    # Allocating from the kernel keeps its dtype and device instead of
    # falling back to the default float type on the CPU.
    W_conv = kernel.new_zeros((O, O, N, N))
    for r in range(K):
        # One banded matrix per kernel row: (output column, input column)
        toeplitz = toeplitz_perfilter(kernel[r], N, K, stride)
        for c in range(O):
            # Output row c reads kernel row r from input row r + c * stride
            W_conv[c, :, r + c * stride, :] = toeplitz

    return W_conv.reshape(O * O, N * N)


def toeplitz_multichannel(kernel, input_size, padding=0, stride=1):
    """Compute the Toeplitz matrices of a 2D conv with several in/out channels.

    Every (output channel, input channel) pair is reduced to the
    single-channel case. No batch dimension is handled here.

    Input dim: (in_ch, N, N)
    Kernel dim: (out_ch, in_ch, K, K)
    Output dim: (out_ch, in_ch, O**2, (N + 2*padding)**2)
    where O = floor((N + 2*padding - (K - 1) - 1) / stride + 1)

    Args:
        kernel: 4D weight tensor laid out as described above.
        input_size: size of the unpadded input without the batch dimension;
            ``input_size[0]`` is the channel count and ``input_size[-1]``
            the spatial size N.
        padding: number of padded entries on each side of the input.
        stride: step between two consecutive output positions.

    Returns:
        Tensor with the output dim given above, created with the same dtype
        and device as ``kernel``.
    """
    N = input_size[-1]
    N_padded = N + 2 * padding
    K = kernel.shape[-1]
    O = int(np.floor(((N + 2 * padding - (K - 1) - 1) / stride) + 1))
    In_Ch = input_size[0]
    Out_Ch = kernel.shape[0]

    # Allocate from the kernel so dtype and device carry over
    T = kernel.new_zeros((Out_Ch, In_Ch, O**2, N_padded**2))
    for i, ks in enumerate(kernel):  # loop over output channels
        for j, k in enumerate(ks):  # loop over input channels
            T[i, j, :, :] = toeplitz_perchannel(k, (N_padded, N_padded), stride)

    return T


def toeplitz_multiply(W, B, X, output_dim):
    """Apply per-channel Toeplitz matrices to a flattened input and add a bias.

    Args:
        W: tensor iterated first over output channels, then over input
            channels; each innermost 2D slice is matrix-multiplied with the
            matching row of ``X``.
        B: bias tensor, reshaped to ``(-1, 1, 1)`` before it is added.
        X: 2D tensor indexed by input channel along its first dimension.
        output_dim: side length used to reshape each output channel into a
            square ``(output_dim, output_dim)`` map.

    Returns:
        The channel maps stacked along the first dimension plus the bias.
        The shapes of ``W`` and of the reshaped bias are printed as a side
        effect.
    """
    per_output = []
    for channel_mats in W:  # one entry per output channel
        # matmul as Conv2d for each single input channel
        partials = torch.stack(
            [mat @ X[idx, :] for idx, mat in enumerate(channel_mats)]
        )
        # collapse the input-channel axis
        per_output.append(partials.sum(dim=0))

    # stack the output channels and restore the 2D layout
    stacked = torch.stack(per_output).reshape(-1, output_dim, output_dim)
    B = B.reshape(-1, 1, 1)
    print(W.shape)
    print(B.shape)
    return stacked + B  # add the bias


### Multi channel case

In [None]:
# Reference layer: 3 -> 16 channels, 2x2 kernel, strided and zero-padded
padding = 1
stride = 2
conv_layer = nn.Conv2d(
    in_channels=3,
    out_channels=16,
    kernel_size=2,
    stride=stride,
    padding=padding,
)
# Pull the parameters out as plain tensors
weights = conv_layer.weight.data
bias = conv_layer.bias.data

# Random input laid out as (batch_size, channels, height, width)
input_size = (1, 3, 3, 3)
I = torch.randn(input_size)
# Side length of the output map (integer form of the floor formula)
output_dim = (I.shape[-1] + 2 * padding - weights.shape[-1]) // stride + 1

# Toeplitz operator for every (output channel, input channel) pair
W = toeplitz_multichannel(weights, I.shape[1:], padding=padding, stride=stride)

# Zero-pad the input the same way and flatten each channel 2D -> 1D
X = F.pad(I, (padding,) * 4, mode="constant", value=0).view(input_size[1], -1)

# Conv2d expressed as matrix multiplication
r1 = toeplitz_multiply(W, bias, X, output_dim)

# Reference result from the layer itself
r2 = conv_layer(I)

print("Differenence between the outputs: ", torch.norm(r1-r2))

torch.Size([16, 3, 4, 25])
torch.Size([16, 1, 1])
Differenence between the outputs:  tensor(2.9013e-07, grad_fn=<LinalgVectorNormBackward0>)


### Single channel case

In [None]:
# Reference layer: a single 2x2 filter, unit stride, no padding
conv_layer = nn.Conv2d(
    in_channels=1,
    out_channels=1,
    kernel_size=2,
    stride=1,
    padding=0,
)
# Pull the parameters out as plain tensors
weights = conv_layer.weight.data
bias = conv_layer.bias.data

# Random input laid out as (batch_size, channels, height, width)
input_size = (1, 1, 3, 3)
I = torch.randn(input_size)
# Side length of the output map for stride 1 without padding
output_dim = 1 + I.shape[-1] - weights.shape[-1]

# Toeplitz operator (default padding and stride) and the flattened input
W = toeplitz_multichannel(weights, I.shape[1:])
X = I.view(input_size[1], -1)

# Evaluate both paths and report how far apart they are
r1 = toeplitz_multiply(W, bias, X, output_dim)
r2 = conv_layer(I)
print("Difference between the two outputs: ", torch.norm(r1-r2))

torch.Size([1, 1, 4, 9])
torch.Size([1, 1, 1])
Difference between the two outputs:  tensor(4.2147e-08, grad_fn=<LinalgVectorNormBackward0>)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
