In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch, os, math
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io

from torch.utils.cpp_extension import load_inline
from types import SimpleNamespace as ns

%matplotlib widget

In [2]:
def check_output(output, filter_td, input_signal=None):
    """Print, per channel, whether the streamed output matches direct convolution.

    Parameters
    ----------
    output : array-like
        Streamed filter output, one row per channel (indexed as ``output[c, :]``).
    filter_td : array-like
        Time-domain filters, one row per channel.
    input_signal : array-like, optional
        The signal that was streamed through the filters. When omitted, the
        notebook-level ``signal`` is used (the original behaviour).

    A channel passes when every squared sample error is below 1e-8 over the
    length of the full linear convolution (``len(input) + len(filter) - 1``);
    any trailing samples of ``output`` beyond that length are ignored.

    Raises
    ------
    ValueError
        If ``output`` is shorter than the full linear convolution.
    """
    if input_signal is None:
        input_signal = signal  # fallback to the notebook global for backward compatibility
    output = np.asarray(output)

    for channel_idx, channel_filter in enumerate(filter_td):
        output_ref = np.convolve(input_signal, channel_filter)
        n_ref = output_ref.shape[0]
        if output.shape[1] < n_ref:
            raise ValueError(
                f"output has {output.shape[1]} samples per channel, "
                f"expected at least {n_ref}"
            )
        print("Results are equal:", ((output[channel_idx, :n_ref] - output_ref) ** 2 < 1e-8).all())

In [3]:
C = 10  # Number of channels
B = 256  # Block length (samples per streamed input block)
FL = 10 * B  # Filter length
K = (FL + B - 1) // B  # Number of filter partitions (ceil; a partial last partition gets zero-padded)
I = 1000  # Number of input blocks

output = []

# Test signal: a unit impulse, so each channel's output is its filter's impulse response.
# For a random test signal use: signal = np.random.randn(I * B)
signal = np.zeros(I * B, dtype=np.float64)
signal[0] = 1

# Simulate input streaming by dividing the signal into I blocks of length B
signal_batch = signal.reshape(I, B)
# Append K blocks of zeros to flush the filter tail; shape becomes (I + K, B)
signal_batch = np.concatenate([signal_batch, np.zeros((K, B))], axis=0)

# Define the frequency-domain delay line (FDL): B + 1 rfft bins x K partitions.
# complex128 to match the precision of the filter spectra and the CUDA path.
fdl = np.zeros((B + 1, K), dtype=np.complex128)
fdl_cursor = 0

In [None]:
# Random test filters, one row per channel. A seeded generator keeps the notebook
# reproducible under Restart & Run All.
filter_td = np.random.default_rng(seed=0).standard_normal((C, FL))

# Partition each filter into K blocks of length B (zero-padding the tail if needed):
# filter_parts[c, b, k] == filter_td[c, k * B + b]
filter_parts = np.pad(filter_td, ((0, 0), (0, K * B - FL)), mode='constant').reshape(C, B, K, order='F')

# Sanity check: concatenating the partitions in time order reproduces every filter
assert np.array_equal(filter_parts.transpose(0, 2, 1).reshape(C, K * B)[:, :FL], filter_td)

# Zero-pad every partition with another B samples so the 2B-point FFT gives linear convolution
filters_padded = np.pad(filter_parts, ((0, 0), (0, B), (0, 0)), mode='constant')  # shape: (C, 2 * B, K)

# Compute the RFFT of the padded partitions along the time axis (real-to-complex FFT)
filters_fd = np.fft.rfft(filters_padded, axis=1)  # shape: (C, B + 1, K)

In [5]:
# ============================================================================
# NumPy reference: simulate input streaming by looping over the batches
# (uniformly partitioned overlap-save convolution).
# Reset all streaming state first so this cell gives the same result when re-run.
fdl = np.zeros((B + 1, K), dtype=np.complex128)  # frequency-domain delay line, one column per partition
fdl_cursor = 0  # column of fdl that receives the newest input spectrum
input_buffer_td = np.zeros(2 * B)
output = []

for signal_td in signal_batch:

    # Packing and RFFT of the signal (online)
    # ============================================================================
    # Slide the buffer by one block: the previous block moves to the first half,
    # the incoming block fills the second half.
    input_buffer_td[:B] = input_buffer_td[B:]
    input_buffer_td[B:] = signal_td

    # Compute the RFFT of the 2B-sample buffer (real-to-complex FFT)
    input_fd = np.fft.rfft(input_buffer_td)  # shape: (B + 1,)

    # The following code is executed on the GPU
    # ============================================================================
    # Store the fd signal in the frequency-domain delay line
    fdl[:, fdl_cursor] = input_fd

    # Multiply-accumulate: filter partition k is paired with the input spectrum
    # stored k blocks ago (walking the FDL ring backwards from the cursor).
    output_fd = np.zeros((C, B + 1), dtype=np.complex128)
    for k in range(K):
        cursor = (fdl_cursor - k + K) % K
        output_fd += fdl[np.newaxis, :, cursor] * filters_fd[:, :, k]

    fdl_cursor = (fdl_cursor + 1) % K  # Update the index
    # ============================================================================
    # The following code is executed on the CPU

    # Perform the inverse RFFT to obtain the output signal
    output_td = np.fft.irfft(output_fd, axis=1)  # shape: (C, 2 * B)
    # Keep only the last B samples; the first B are corrupted by the circular
    # wrap-around of the 2B-point FFT (overlap-save).
    output.append(output_td[:, B:].T)

output = np.array(output).reshape(-1, C).T  # shape: (C, (I + K) * B)

In [6]:
# Compare every channel of the NumPy streaming output against np.convolve
check_output(output=output, filter_td=filter_td)

Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True


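# GPU simulation in Python

Before writing CUDA, the kernel is emulated in plain Python: `blk_kernel` calls `conv_kernel` once per (block, thread) pair, so the indexing can be validated against the NumPy reference above.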

In [None]:
# Serial stand-in for a 1-D CUDA kernel launch
def blk_kernel(f, blocks, threads, *args):
    """Call ``f`` once for every (block, thread) pair, blocks in the outer loop.

    Mimics ``f<<<blocks, threads>>>(*args)``: each call receives
    ``(block_index, thread_index, threads, *args)``, where ``threads`` plays
    the role of ``blockDim.x``.
    """
    launch_grid = ((blk, thr) for blk in range(blocks) for thr in range(threads))
    for block_index, thread_index in launch_grid:
        f(block_index, thread_index, threads, *args)

def part_conv_gpu(input_fd, fdl_cursor):
    """Simulate one streaming step of the partitioned-convolution GPU kernel.

    Writes ``input_fd`` into column ``fdl_cursor`` of the notebook-global delay
    line ``fdl`` (mutated in place), "launches" ``conv_kernel`` through
    ``blk_kernel`` with one simulated thread per (channel, bin) output element,
    and advances the cursor. Also reads the globals ``C``, ``B``, ``K`` and
    ``filters_fd``.

    Parameters
    ----------
    input_fd : array-like, shape (B + 1,)
        Spectrum of the current 2B-sample input buffer.
    fdl_cursor : int
        Column of ``fdl`` that receives ``input_fd``.

    Returns
    -------
    (torch.Tensor, int)
        Output spectra of shape (C, B + 1) with dtype complex128, and the
        advanced cursor ``(fdl_cursor + 1) % K``.
    """
    # complex128 to match filters_fd and the CUDA kernel; a complex64 buffer
    # would truncate every accumulated fdl x filter product.
    output_fd = torch.empty((C * (B + 1)), dtype=torch.complex128)

    threads = 256
    blocks = math.ceil(C * (B + 1) / threads)  # enough threads to cover every (channel, bin)

    # Store the fd signal in a frequency-domain delay line
    fdl[:, fdl_cursor] = input_fd
    blk_kernel(conv_kernel, blocks, threads, fdl.ravel(), filters_fd.ravel(), fdl_cursor, output_fd,
               K, B, C)

    fdl_cursor = (fdl_cursor + 1) % K  # Update the index

    return output_fd.reshape(C, B + 1), fdl_cursor

def conv_kernel(blockidx, threadidx, blockdim, fdl, filters_fd,
                fdl_cursor, output_fd, K, B, C):
    """Python stand-in for one CUDA thread of the partitioned-convolution kernel.

    The global thread index selects one (channel, frequency-bin) output
    element; threads at or beyond ``C * (B + 1)`` return without writing.
    The element is the sum over k = 0..K-1 of
    ``fdl[bin, slot_k] * filters_fd[channel, bin, k]`` where ``slot_0`` is
    ``fdl_cursor`` and each following slot steps back by one (mod K).

    ``fdl`` and ``filters_fd`` are flat row-major buffers of logical shape
    (B + 1, K) and (C, B + 1, K); ``output_fd`` is flat with C * (B + 1)
    entries and is written in place.
    """
    n_bins = B + 1
    tid = blockidx * blockdim + threadidx
    if tid >= C * n_bins:
        return  # surplus thread in the last block

    channel, freq_bin = divmod(tid, n_bins)
    fdl_row = freq_bin * K
    filt_row = (channel * n_bins + freq_bin) * K

    acc = 0
    slot = fdl_cursor
    for part in range(K):
        acc += fdl[fdl_row + slot] * filters_fd[filt_row + part]
        slot = (slot + K - 1) % K  # previous block, wrapping around the ring

    output_fd[channel * n_bins + freq_bin] = acc
    

In [8]:
# ============================================================================
# Simulate input streaming through the Python GPU simulation.
# Reset all streaming state so the cell can be re-run on its own.
fdl = np.zeros((B + 1, K), dtype=np.complex128)  # complex128 to match filters_fd and the CUDA path
fdl_cursor = 0

input_buffer_td = np.zeros(2 * B)
output = []

for signal_td in signal_batch:

    # Packing and RFFT of the signal (online)
    # ============================================================================
    # Slide the buffer by one block: the previous block moves to the first half,
    # the incoming block fills the second half.
    input_buffer_td[:B] = input_buffer_td[B:]
    input_buffer_td[B:] = signal_td

    # Compute the RFFT of the 2B-sample buffer (real-to-complex FFT)
    input_fd = np.fft.rfft(input_buffer_td)  # shape: (B + 1,)

    # The following code is executed on the GPU
    # ============================================================================
    # Store the spectrum in the FDL and multiply-accumulate against all filter partitions
    output_fd, fdl_cursor = part_conv_gpu(input_fd, fdl_cursor)
    # output_fd is allocated with torch.empty; a NaN may indicate an element no thread wrote
    assert not output_fd.isnan().any()
    # ============================================================================
    # The following code is executed on the CPU

    # Perform the inverse RFFT to obtain the output signal
    output_td = np.fft.irfft(output_fd, axis=1)  # shape: (C, 2 * B)
    # Keep only the last B samples; the first B are corrupted by the circular
    # wrap-around of the 2B-point FFT (overlap-save).
    output.append(output_td[:, B:].T)

output = np.array(output).reshape(-1, C).T  # shape: (C, (I + K) * B)

In [9]:
# Compare every channel of the simulated-GPU streaming output against np.convolve
check_output(output=output, filter_td=filter_td)

Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True


# GPU implementation in CUDA

In [10]:
os.environ['CUDA_LAUNCH_BLOCKING']='1'
%load_ext wurlitzer

In [11]:
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    """Compile CUDA and C++ source strings into a Python extension module.

    Parameters
    ----------
    cuda_src : str
        CUDA translation unit (kernels plus host wrappers).
    cpp_src : str
        C++ declarations of the functions to expose.
    funcs : list of str
        Names of the functions to bind into the returned module.
    opt : bool
        When True, pass ``-O2`` to the CUDA compiler.
    verbose : bool
        Forwarded to ``load_inline`` to print build progress.

    Returns the loaded module (always named ``inline_ext``).
    """
    cuda_flags = ["-O2"] if opt else []
    return load_inline(
        name="inline_ext",
        cpp_sources=[cpp_src],
        cuda_sources=[cuda_src],
        functions=funcs,
        extra_cuda_cflags=cuda_flags,
        verbose=verbose,
    )

In [12]:
# Common CUDA/C++ prelude: headers, input-check macros and a ceiling-division helper.
# `using namespace torch::indexing` is placed after the torch headers that declare it,
# instead of relying on the headers load_inline implicitly prepends to CUDA sources.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

using namespace torch::indexing;

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''

In [None]:
# CUDA source: the partitioned-convolution kernel (one thread per (channel, bin)
# output element) and the host wrapper exposed to Python as part_conv_gpu.
cuda_src = cuda_begin + r'''
#include <ATen/cuda/CUDAContext.h>

// One thread per output element (channel, frequency bin). Each thread sums K
// complex products: FDL column `cursor` (starting at fdl_cursor and stepping
// back one column per partition, wrapping mod K) times filter partition k.
//
// Row-major layouts (validated by the host wrapper):
//   fdl        : (B + 1, K)
//   filters_fd : (C, B + 1, K)
//   output_fd  : (C, B + 1)
__global__ void conv_kernel(const c10::complex<double>* fdl, const c10::complex<double>* filters_fd, int fdl_cursor, c10::complex<double>* output_fd, int K, int B, int C) {

    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;

    if (thread_id >= C * (B + 1)) return;  // surplus threads in the last block

    const int channel_id = thread_id / (B + 1);
    const int bin_id = thread_id % (B + 1);
    int cursor = fdl_cursor;

    const int fdl_offset = bin_id * K;
    const int filter_offset = channel_id * ((B + 1) * K) + bin_id * K;
    const int output_offset = channel_id * (B + 1) + bin_id;

    c10::complex<double> out = 0;
    for (int k = 0; k < K; ++k) {
        out += fdl[fdl_offset + cursor] * filters_fd[filter_offset + k];
        cursor = (cursor - 1 + K) % K;  // previous block, wrapping around the ring
    }
    output_fd[output_offset] = out;
}

torch::Tensor part_conv_gpu(torch::Tensor input_fd, torch::Tensor fdl, torch::Tensor filters_fd, int fdl_cursor, int K, int B, int C) {
    CHECK_INPUT(input_fd);
    CHECK_INPUT(fdl);
    CHECK_INPUT(filters_fd);

    // The kernel indexes raw pointers, so wrong dtypes/shapes or an out-of-range
    // cursor must be rejected here rather than reading/writing out of bounds.
    TORCH_CHECK(K > 0 && B > 0 && C > 0, "K, B and C must be positive");
    TORCH_CHECK(fdl_cursor >= 0 && fdl_cursor < K, "fdl_cursor must be in [0, K)");
    TORCH_CHECK(input_fd.scalar_type() == c10::ScalarType::ComplexDouble, "input_fd must be complex128");
    TORCH_CHECK(fdl.scalar_type() == c10::ScalarType::ComplexDouble, "fdl must be complex128");
    TORCH_CHECK(filters_fd.scalar_type() == c10::ScalarType::ComplexDouble, "filters_fd must be complex128");
    TORCH_CHECK(input_fd.numel() == B + 1, "input_fd must have B + 1 elements");
    TORCH_CHECK(fdl.dim() == 2 && fdl.size(0) == B + 1 && fdl.size(1) == K,
                "fdl must have shape (B + 1, K)");
    TORCH_CHECK(filters_fd.dim() == 3 && filters_fd.size(0) == C && filters_fd.size(1) == B + 1 && filters_fd.size(2) == K,
                "filters_fd must have shape (C, B + 1, K)");

    auto output_fd = torch::empty({C, B+1}, input_fd.options());

    const int threads = 256;
    const int blocks = cdiv(C * (B + 1), threads);

    // Store the fd signal in a frequency-domain delay line (column fdl_cursor)
    fdl.index_put_({Slice(0, B+1), fdl_cursor}, input_fd);

    // Launch on PyTorch's current stream so the kernel is ordered after the
    // index_put_ above and before later ops (e.g. .cpu()) on the same stream.
    conv_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        fdl.data_ptr<c10::complex<double>>(), filters_fd.data_ptr<c10::complex<double>>(),
        fdl_cursor, output_fd.data_ptr<c10::complex<double>>(), K, B, C);

    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output_fd;
}
'''

In [None]:
# C++ declaration of the host wrapper, so load_inline can generate its Python binding.
cpp_src = (
    "torch::Tensor part_conv_gpu(torch::Tensor input_fd, torch::Tensor fdl, "
    "torch::Tensor filters_fd, int fdl_cursor, int K, int B, int C);"
)

module = load_cuda(cuda_src, cpp_src, funcs=['part_conv_gpu'], verbose=True)

Using /home/dspuser/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/dspuser/.cache/torch_extensions/py311_cu124/inline_ext/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module inline_ext...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module inline_ext...


In [15]:
# Upload the partitioned filter spectra once; the kernel reads them every block.
filters_fd_gpu = torch.tensor(filters_fd, device='cuda', dtype=torch.complex128).contiguous()
# Device-side frequency-domain delay line: B + 1 bins x K partitions, zero-initialised.
fdl_gpu = torch.zeros(B + 1, K, device='cuda', dtype=torch.complex128).contiguous()

In [16]:
# ===========================================================
# Reset all streaming state, host and device, so this cell can be re-run
fdl_cursor = 0
fdl_gpu.zero_()  # clear the device delay line left over from any previous run
input_buffer_td = np.zeros(2 * B)
output = []

# Simulate input streaming by looping over the batches
for signal_td in signal_batch:
    # Slide the buffer by one block: the previous block moves to the first half,
    # the incoming block fills the second half.
    input_buffer_td[:B] = input_buffer_td[B:]
    input_buffer_td[B:] = signal_td

    # Compute the RFFT of the 2B-sample buffer (real-to-complex FFT)
    input_fd = np.fft.rfft(input_buffer_td)  # shape: (B + 1,)

    # The following code is executed on the GPU
    # =================================================================

    # Move the input to the GPU
    input_fd_gpu = torch.tensor(input_fd, dtype=torch.complex128, device='cuda').contiguous()

    # Store the spectrum in the device FDL and multiply-accumulate against all partitions
    output_fd = module.part_conv_gpu(input_fd_gpu, fdl_gpu, filters_fd_gpu, fdl_cursor, K, B, C)
    fdl_cursor = (fdl_cursor + 1) % K  # Update the index

    # ============================================================================
    # The following code is executed on the CPU

    # Perform the inverse RFFT to obtain the output signal
    output_td = np.fft.irfft(output_fd.cpu().numpy(), axis=1)  # shape: (C, 2 * B)
    # Keep only the last B samples; the first B are corrupted by the circular
    # wrap-around of the 2B-point FFT (overlap-save).
    output.append(output_td[:, B:].T)

output = np.array(output).reshape(-1, C).T  # shape: (C, (I + K) * B)


In [17]:
# Compare every channel of the CUDA streaming output against np.convolve
check_output(output=output, filter_td=filter_td)

Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
Results are equal: True
