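# FFT Convolution Demo

This notebook compares `fft_conv_nd` with PyTorch's direct `conv1d`/`conv2d`/`conv3d` for runtime and numerical agreement. It also checks that stacked `FFTConv1d` layers support backpropagation.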

In [1]:
import torch
import torch.nn.functional as f

from fft_conv import fft_conv_nd, FFTConv1d

### Example Usage -- 1D Convolutions

In [2]:
# Create dummy data.
#     Data shape: (batch, channels, length)
#     Kernel shape: (out_channels, in_channels, kernel_size)
# For a single (unbatched) signal, use batch=1.
torch.manual_seed(0)  # make the dummy data reproducible across runs
signal = torch.randn(3, 3, 1024 * 1024)
kernel = torch.randn(2, 3, 128)
bias = torch.randn(2)

# Functional execution.  (Easiest for generic use cases.)
out_functional = fft_conv_nd(signal, kernel, bias=bias)

# Object-oriented execution.  (Requires some extra work, since the
# defined classes were designed for use in neural networks: we replace
# the layer's weight and bias parameters with our own kernel and bias.)
# Named `fft_conv_layer` so it does not read like the `fft_conv` package.
fft_conv_layer = FFTConv1d(3, 2, 128, bias=True)
fft_conv_layer.weight = torch.nn.Parameter(kernel)
fft_conv_layer.bias = torch.nn.Parameter(bias)
out_module = fft_conv_layer(signal)

# Both APIs received the same kernel and bias, so their outputs should match.
# `detach()` because the module output is tracked by autograd.
torch.testing.assert_close(out_module.detach(), out_functional, rtol=1e-5, atol=1e-4)

### 1D Test

In [3]:
# Create dummy data
kernel_size = 255
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size)
signal = torch.randn(3, 3, 1024 * 1024)
bias = torch.randn(2)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv1d(signal, kernel, bias=bias, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv_nd(signal, kernel, bias=bias, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

# Compute relative error
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 4.18 s, sys: 11.7 ms, total: 4.19 s
Wall time: 715 ms
--- FFT Convolution ---
CPU times: user 2.06 s, sys: 98.8 ms, total: 2.16 s
Wall time: 417 ms

Input shape: torch.Size([3, 3, 1048576])
Output shape: torch.Size([3, 2, 1048576])

Abs Error Mean: 9.603E-06
Abs Error Std Dev: 7.366E-06


### 2D Test

In [4]:
# Create dummy data
kernel_size = 11
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size, kernel_size)
signal = torch.randn(3, 3, 1024, 1024)
bias = torch.randn(2)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv2d(signal, kernel, bias=bias, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv_nd(signal, kernel, bias=bias, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

# Compute relative error
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 2.01 s, sys: 2.7 ms, total: 2.01 s
Wall time: 347 ms
--- FFT Convolution ---
CPU times: user 736 ms, sys: 60 ms, total: 796 ms
Wall time: 136 ms

Input shape: torch.Size([3, 3, 1024, 1024])
Output shape: torch.Size([3, 2, 1024, 1024])

Abs Error Mean: 4.937E-06
Abs Error Std Dev: 3.854E-06


### 3D Test

In [7]:
# Create dummy data
kernel_size = 3
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size, kernel_size, kernel_size)
signal = torch.randn(3, 3, 100, 300, 300)
bias = torch.randn(2)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv3d(signal, kernel, bias=bias, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv_nd(signal, kernel, bias=bias, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 11.1 s, sys: 2.85 s, total: 14 s
Wall time: 2.55 s
--- FFT Convolution ---
CPU times: user 5.18 s, sys: 853 ms, total: 6.03 s
Wall time: 1.02 s

Input shape: torch.Size([3, 3, 100, 300, 300])
Output shape: torch.Size([3, 2, 100, 300, 300])

Abs Error Mean: 3.606E-06
Abs Error Std Dev: 2.794E-06


In [6]:
# Gradient check: stacked FFTConv1d layers inside an ordinary torch.nn model
# should run forward and support backpropagation.
torch.manual_seed(0)  # reproducible layer initialization and input
net = torch.nn.Sequential(
    FFTConv1d(1, 3, 101, padding=50),  # padding=50 with kernel 101 preserves length
    FFTConv1d(3, 3, 101, padding=50),
    FFTConv1d(3, 1, 101, padding=50),
)
output = net(torch.randn(1, 1, 1024))
loss = output.sum()
loss.backward()

# Verify the length is preserved end-to-end and every parameter received a gradient.
assert output.shape == (1, 1, 1024)
assert all(p.grad is not None for p in net.parameters()), "some parameters got no gradient"