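# FFT Convolution Demo

This notebook shows how to call the `fft_conv` functions and modules. It then compares FFT-based convolution against PyTorch's direct `conv1d` / `conv2d` / `conv3d` on 1D, 2D, and 3D inputs, reporting the run time and the absolute difference between the two outputs.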

In [1]:
import torch
import torch.nn.functional as f

from fft_conv import fft_conv1d, fft_conv2d, fft_conv3d, FFTConv1d

### Example Usage -- 1D Convolutions

In [2]:
# Create dummy data.
#     Signal shape: (batch, in_channels, length)
#     Kernel shape: (out_channels, in_channels, kernel_size)
# For a single (unbatched) signal, use batch=1.
signal = torch.randn(3, 3, 1024 * 1024)
kernel = torch.randn(2, 3, 128)
out_channels, in_channels, kernel_size = kernel.shape

# Functional execution.  (Easiest for generic use cases.)
out_functional = fft_conv1d(signal, kernel)

# Object-oriented execution.  (Requires some extra work, since the
# defined classes were designed for use in neural networks.)
# The layer must be declared with the same channel counts / kernel size as
# the weight we assign to it, otherwise its configuration is inconsistent
# with its parameters.  NOTE(review): the argument order
# (in_channels, out_channels, kernel_size) assumes FFTConv1d mirrors
# torch.nn.Conv1d -- confirm against fft_conv.
conv_layer = FFTConv1d(in_channels, out_channels, kernel_size, bias=False)
conv_layer.weight = torch.nn.Parameter(kernel)
out_module = conv_layer(signal)

### 1D Test

In [3]:
# Create dummy data
kernel_size = 255
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size)
signal = torch.randn(3, 3, 1024 * 1024)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv1d(signal, kernel, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv1d(signal, kernel, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

# Compute relative error
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 2.75 s, sys: 22.1 ms, total: 2.77 s
Wall time: 477 ms
--- FFT Convolution ---
CPU times: user 1.88 s, sys: 80.1 ms, total: 1.96 s
Wall time: 380 ms

Input shape: torch.Size([3, 3, 1048576])
Output shape: torch.Size([3, 2, 1048576])

Abs Error Mean: 9.582E-06
Abs Error Std Dev: 7.355E-06


### 2D Test

In [4]:
# Create dummy data
kernel_size = 11
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size, kernel_size)
signal = torch.randn(3, 3, 1024, 1024)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv2d(signal, kernel, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv2d(signal, kernel, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

# Compute relative error
abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 1.42 s, sys: 18.9 ms, total: 1.44 s
Wall time: 241 ms
--- FFT Convolution ---
CPU times: user 676 ms, sys: 53.1 ms, total: 729 ms
Wall time: 125 ms

Input shape: torch.Size([3, 3, 1024, 1024])
Output shape: torch.Size([3, 2, 1024, 1024])

Abs Error Mean: 4.800E-06
Abs Error Std Dev: 3.746E-06


### 3D Test

In [5]:
# Create dummy data
kernel_size = 3
padding = kernel_size // 2
kernel = torch.randn(2, 3, kernel_size, kernel_size, kernel_size)
signal = torch.randn(3, 3, 128, 128, 128)

# Perform both direct and FFT convolutions
print('--- Direct Convolution ---')
%time y0 = f.conv3d(signal, kernel, padding=padding)
print('--- FFT Convolution ---')
%time y1 = fft_conv3d(signal, kernel, padding=padding)

# Print input/output tensor shapes
print(f'\nInput shape: {signal.shape}')
print(f'Output shape: {y0.shape}')

abs_error = torch.abs(y0 - y1)
print(f'\nAbs Error Mean: {abs_error.mean():.3E}')
print(f'Abs Error Std Dev: {abs_error.std():.3E}')

--- Direct Convolution ---
CPU times: user 2.24 s, sys: 598 ms, total: 2.84 s
Wall time: 508 ms
--- FFT Convolution ---
CPU times: user 475 ms, sys: 200 ms, total: 676 ms
Wall time: 114 ms

Input shape: torch.Size([3, 3, 128, 128, 128])
Output shape: torch.Size([3, 2, 128, 128, 128])

Abs Error Mean: 2.464E-06
Abs Error Std Dev: 1.956E-06
