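# FFT Convolution Demo

This notebook shows how to use the `fft_conv` package's 1D, 2D, and 3D FFT-based convolutions. It also compares their runtime and numerical accuracy against PyTorch's direct `conv1d` / `conv2d` / `conv3d`.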

In [1]:
import torch
import torch.nn.functional as f

from fft_conv import fft_conv1d, fft_conv2d, fft_conv3d, FFTConv1d

### Example Usage -- 1D Convolutions

In [2]:
# Create dummy data.
#     Data shape: (batch, channels, length)
#     Kernel shape: (out_channels, in_channels, kernel_size)
# To convolve a single signal, set batch=1 (and channels=1 for a one-channel signal).
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same data
signal = torch.randn(3, 3, 1024 * 1024)
kernel = torch.randn(2, 3, 128)
bias = torch.randn(2)

# Functional execution.  (Easiest for generic use cases.)
out_functional = fft_conv1d(signal, kernel, bias=bias)

# Object-oriented execution.  (Requires some extra work, since the
# defined classes were designed for use in neural networks.)
# The layer is deliberately not named `fft_conv`, to avoid confusion with the package name.
fft_conv_layer = FFTConv1d(3, 2, 128, bias=True)
fft_conv_layer.weight = torch.nn.Parameter(kernel)
fft_conv_layer.bias = torch.nn.Parameter(bias)
out = fft_conv_layer(signal)

# Both interfaces were given identical weights and bias, so their results should agree.
# NOTE(review): assumes FFTConv1d's default padding/stride match fft_conv1d's defaults — confirm.
assert out_functional.shape == out.shape, f'shape mismatch: {out_functional.shape} vs {out.shape}'
assert torch.allclose(out_functional, out, atol=1e-4), 'functional and module outputs differ'

### 1D Test

In [3]:
# Create dummy data
kernel_size = 255
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size)
signal = torch.randn(3, 3, 1024 * 1024)
bias = torch.randn(2)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv1d(signal, kernel, bias=bias, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv1d(signal, kernel, bias=bias, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

# Compute relative error
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 2.88 s, sys: 20.3 ms, total: 2.9 s
Wall time: 518 ms
--- FFT Convolution ---
CPU times: user 1.96 s, sys: 96.3 ms, total: 2.06 s
Wall time: 402 ms

Input shape: torch.Size([3, 3, 1048576])
Output shape: torch.Size([3, 2, 1048576])

Abs Error Mean: 9.550E-06
Abs Error Std Dev: 7.327E-06


### 2D Test

In [4]:
# Create dummy data
kernel_size = 11
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size, kernel_size)
signal = torch.randn(3, 3, 1024, 1024)
bias = torch.randn(2)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv2d(signal, kernel, bias=bias, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv2d(signal, kernel, bias=bias, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

# Compute relative error
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 1.54 s, sys: 12.1 ms, total: 1.55 s
Wall time: 262 ms
--- FFT Convolution ---
CPU times: user 617 ms, sys: 54.8 ms, total: 672 ms
Wall time: 116 ms

Input shape: torch.Size([3, 3, 1024, 1024])
Output shape: torch.Size([3, 2, 1024, 1024])

Abs Error Mean: 4.872E-06
Abs Error Std Dev: 3.803E-06


### 3D Test

In [5]:
# Create dummy data
kernel_size = 3
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size, kernel_size, kernel_size)
signal = torch.randn(3, 3, 128, 128, 128)
bias = torch.randn(2)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv3d(signal, kernel, bias=bias, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv3d(signal, kernel, bias=bias, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 2.5 s, sys: 435 ms, total: 2.94 s
Wall time: 555 ms
--- FFT Convolution ---
CPU times: user 563 ms, sys: 118 ms, total: 682 ms
Wall time: 115 ms

Input shape: torch.Size([3, 3, 128, 128, 128])
Output shape: torch.Size([3, 2, 128, 128, 128])

Abs Error Mean: 2.216E-06
Abs Error Std Dev: 1.731E-06
