Comparison of physical convolution and Fourier convolution from: https://stackoverflow.com/a/60584560/

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from timeit import default_timer
import pandas as pd
import warnings

In [2]:
def conv2d_pyt(f, g):
    assert len(f.size()) == 3
    assert len(g.size()) == 3

    f_new = f.unsqueeze(0)
    g_new = g.unsqueeze(0)

    pad_y = (g.size(1) - 1) // 2
    pad_x = (g.size(2) - 1) // 2

    fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
    return fcg[0, 0, :, :]

def conv2d_fft(f_new, F_g):
    assert len(f.size()) == 3
    assert len(g.size()) == 3

    # take fft of both f and g
    F_f = torch.fft.rfft2(f_new, dim=[-1,-2])

    # complex multiply
    FxG = F_f * F_g

    # sum over channels
    FxG = FxG.sum(0)

    # inverse fft
    fcg = torch.fft.irfft2(FxG, f_new.shape[1:], dim=[-1, -2])

    f_pad_y = (f_new.size(1) - f.size(1)) // 2
    f_pad_x = (f_new.size(2) - f.size(2)) // 2

    # crop center before returning
    return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]

# function for pre-processing the input to conv2d_fft so that the output is
# identical to conv2d_pyt
# also transforms the kernel, since we directly learn the transform of the
# kernel
def fft_preprocess(f, g):
    assert len(f.size()) == 3
    assert len(g.size()) == 3

    device = f.device

    size_y = f.size(1) + g.size(1) - 1
    size_x = f.size(2) + g.size(2) - 1

    f_new = torch.zeros((f.shape[0], size_y, size_x)).to(device)
    g_new = torch.zeros((f.shape[0], size_y, size_x)).to(device)

    # copy f to center
    f_pad_y = (f_new.size(1) - f.size(1)) // 2
    f_pad_x = (f_new.size(2) - f.size(2)) // 2
    f_new[:, f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f

    # anchor of g is 0,0 (flip g and wrap circular)
    g_center_y = g.size(1) // 2
    g_center_x = g.size(2) // 2
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        g_y, g_x = torch.meshgrid(torch.arange(g.size(1)), torch.arange(g.size(2)))
    g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(1)
    g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(2)
    g_new[:, g_new_y, g_new_x] = g[:, g_y, g_x]

    # transform the kernel
    F_g = torch.fft.rfft2(g_new, dim=[-1,-2])

    return f_new, F_g

The output for these convolutions is identical:

In [3]:
# calculate f*g
torch.manual_seed(1)
num_channels = 20
H = W = 64
kernel_size = 3
f = torch.randn(num_channels, H, W).cuda()
g = torch.randn(num_channels, kernel_size, kernel_size).cuda()

fcg_pyt = conv2d_pyt(f, g)
f_new, F_g = fft_preprocess(f, g)
fcg_fft = conv2d_fft(f_new, F_g)

loss = lambda x, y: ((x - y).norm() / x.norm()).item()

avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()

print('Average difference:', avg_diff)
print('L2 relative loss:', loss(fcg_pyt, fcg_fft))

Average difference: 4.117644039070001e-06
L2 relative loss: 4.2855637616412423e-07


Time complexity analysis:

In [4]:
def time_conv(f, g):

    start = default_timer()
    fcg_pyt = conv2d_pyt(f, g)
    time_pyt = (default_timer() - start) * 1000

    f_new, F_g = fft_preprocess(f, g)
    start = default_timer()
    fcg_fft = conv2d_fft(f_new, F_g)
    time_fft = (default_timer() - start) * 1000

    diff = loss(fcg_pyt, fcg_fft)

    return {"Physical Convolution":time_pyt, "Fourier Convolution":time_fft, "diff":diff}

In [5]:
torch.manual_seed(1)
times = {"Physical Convolution":[], "Fourier Convolution":[]}
n = 10000
max_diff = 0
for _ in range(n):

    f = torch.randn(num_channels, H, W).cuda()
    g = torch.randn(num_channels, kernel_size, kernel_size).cuda()
    trial = time_conv(f, g)
    times["Physical Convolution"].append(trial["Physical Convolution"])
    times["Fourier Convolution"].append(trial["Fourier Convolution"])
    if trial["diff"] > max_diff:
        max_diff = trial["diff"]

for conv in times.keys():
    times[conv] = np.array(times[conv])
    mu = times[conv].mean()
    sigma = times[conv].std()
    times[conv] = [f"{mu.round(3)}({sigma.round(3)})"]

times, max_diff

({'Physical Convolution': ['0.131(0.029)'],
  'Fourier Convolution': ['0.143(0.025)']},
 4.6969438471933245e-07)

In [6]:
df = pd.DataFrame(times)
df.index = ["Time (ms)"]
print(df.to_markdown())

|           | Physical Convolution   | Fourier Convolution   |
|:----------|:-----------------------|:----------------------|
| Time (ms) | 0.131(0.029)           | 0.143(0.025)          |
