In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch, os, math
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io

from torch.utils.cpp_extension import load_inline
from types import SimpleNamespace as ns

%matplotlib widget

In [None]:
def check_output(output, filter_td, input_signal=None, tol=1e-8):
    """Compare a streamed convolution result with a direct ``np.convolve`` reference.

    Parameters
    ----------
    output : array-like, shape (num_filters, num_samples)
        Streamed output; row ``c`` is the result for filter ``c``.
    filter_td : array-like, shape (num_filters, filter_len)
        Time-domain filters.
    input_signal : array-like, shape (1 or num_filters, signal_len), optional
        Input signal. A single input channel is broadcast to every filter.
        Defaults to the notebook-level ``signal`` (previous behaviour).
    tol : float, optional
        Maximum allowed squared error per sample.

    Returns
    -------
    bool
        True when every filter channel matches its reference within ``tol``.
        One "Results are equal: ..." line is also printed per channel.
    """
    if input_signal is None:
        input_signal = signal  # notebook global, kept for backward compatibility
    output = np.asarray(output)
    filter_td = np.asarray(filter_td)
    input_signal = np.atleast_2d(np.asarray(input_signal))

    # zip(filter_td, signal) used to stop after min(C, C_in) pairs, so with a single
    # input channel only filter 0 was ever checked. Broadcast the input instead.
    signals = np.broadcast_to(input_signal, (filter_td.shape[0], input_signal.shape[-1]))

    n_samples = output.shape[1]
    all_equal = True
    for filter_idx, (filter_part, signal_part) in enumerate(zip(filter_td, signals)):
        output_ref = np.convolve(signal_part, filter_part)
        equal = bool(((output[filter_idx, :] - output_ref[:n_samples]) ** 2 < tol).all())
        print("Results are equal:", equal)
        all_equal = all_equal and equal
    return all_equal

In [None]:
SEED = 0  # fixed seed so "Restart & Run All" reproduces the same signal and filters
np.random.seed(SEED)

C = 10  # Number of filter channels
C_in = 1  # Number of input channels
B = 256  # Block length
FL = 15 * B  # Filter length
K = -(-FL // B)  # Number of partitions (ceil division: FL need not be a multiple of B)
I = 1000  # Number of input blocks

output = []

# Define the signal batch: (C_in, B * I) -> (I, C_in, B), one block per streaming step
signal = np.random.randn(C_in, B * I)
signal_batch = signal.reshape(C_in, I, B).swapaxes(0, 1)

# Define the frequency-domain delay line (FDL): one (B + 1)-bin spectrum per partition
fdl = np.zeros((C_in, B + 1, K), dtype=np.complex64)
fdl_cursor = 0

In [None]:
# Random time-domain filters, one row per filter channel: shape (C, FL)
filter_td = np.random.randn(C, FL).astype(np.float64)

# Zero-pad each filter to K * B taps and cut it into K partitions of B taps,
# so filter_parts[c, :, k] holds taps k*B ... k*B + B - 1 of filter c: shape (C, B, K)
filter_parts = (
    np.pad(filter_td, ((0, 0), (0, K * B - FL)), mode='constant')
    .reshape(C, K, B)
    .transpose(0, 2, 1)
)

# Sanity check: the first partition is exactly the first B taps
assert (filter_parts[0, :B, 0] == filter_td[0, :B]).all()

# Append B zeros to every partition so that its circular product with a 2B-sample
# input buffer contains the linear-convolution result (overlap-save): shape (C, 2 * B, K)
filters_padded = np.pad(filter_parts, ((0, 0), (0, B), (0, 0)), mode='constant')

# Real-to-complex FFT along the tap axis: shape (C, B + 1, K)
filters_fd = np.fft.rfft(filters_padded, axis=1)

In [None]:
# ============================================================================
# %% Simulate input streaming by looping over the batches (NumPy reference)
# Reset the streaming state here so the cell is idempotent: the FDL and its cursor
# used to be initialised only in the setup cell, so re-running this cell started
# from the previous run's delay-line contents and produced wrong output.
fdl = np.zeros((C_in, B + 1, K), dtype=np.complex64)
fdl_cursor = 0

output = []
input_buffer_td = np.zeros((C_in, 2 * B))
for signal_td in signal_batch:

    # 2. Packing and RFFT of the signal (online)
    # ============================================================================
    # Overlap-save input buffer: slide the previous block into the first half and
    # put the incoming block (shape (C_in, B)) into the second half
    input_buffer_td[:, :B] = input_buffer_td[:, B:]
    input_buffer_td[:, B:] = signal_td

    # Compute the RFFT of the signals (real-to-complex FFT)
    input_fd = np.fft.rfft(input_buffer_td, axis=1)  # shape: (C_in, B + 1)

    # The following code is executed on the GPU
    # ============================================================================
    # Store the fd signal in a frequency-domain delay line
    fdl[:, :, fdl_cursor] = input_fd

    # Multiply-accumulate over partitions: filter partition k pairs with the input
    # spectrum stored k blocks ago. With C_in == 1 the (1, B + 1) FDL slice
    # broadcasts against the (C, B + 1) filter slice.
    output_fd = np.zeros((C, B + 1), dtype=np.complex64)
    for k in range(K):
        cursor = (fdl_cursor - k + K) % K
        output_fd += fdl[:, :, cursor] * filters_fd[:, :, k]

    fdl_cursor = (fdl_cursor + 1) % K  # Advance the ring-buffer write index
    # ============================================================================
    # The following code is executed on the CPU

    # Perform the inverse RFFT to obtain the output signal
    output_td = np.fft.irfft(output_fd, axis=1)  # shape: (C, 2 * B)
    # Keep the last B samples: they are the valid overlap-save output for this block
    output.append(output_td[:, B:].T)

output = np.array(output).reshape(-1, C).T  # shape: (C, num_blocks * B)

In [None]:
# Compare the NumPy-reference streamed output against a direct np.convolve result
check_output(output, filter_td)

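# GPU simulation in Python

The kernel is first emulated in plain Python: `blk_kernel` calls a per-thread function once for every (block, thread) pair, so the kernel's indexing logic can be checked against the NumPy reference above.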

In [None]:
# Simulate kernel execution
def blk_kernel(f, blocks, threads, *args):
    """Sequentially emulate a 1-D kernel launch of ``blocks`` x ``threads`` threads.

    ``f`` is invoked block-major as ``f(block_idx, thread_idx, threads, *args)``;
    its return value is ignored.
    """
    launch_grid = ((blk, thr) for blk in range(blocks) for thr in range(threads))
    for block_idx, thread_idx in launch_grid:
        f(block_idx, thread_idx, threads, *args)

def part_conv_gpu(input_fd, fdl_cursor):
    """Emulated GPU step of the uniformly partitioned convolution.

    Writes ``input_fd`` into the notebook-global ``fdl`` at ``fdl_cursor``, runs
    ``conv_kernel`` over every (channel, bin) output via ``blk_kernel`` using the
    globals ``filters_fd``, ``K``, ``B`` and ``C``, then advances the cursor.

    Returns
    -------
    tuple
        ``(output_fd, next_cursor)`` where ``output_fd`` is a complex64 torch
        tensor reshaped to ``(C, B + 1)`` and ``next_cursor`` is
        ``(fdl_cursor + 1) % K``.
    """
    n_outputs = C * (B + 1)
    threads_per_block = 256
    n_blocks = math.ceil(n_outputs / threads_per_block)

    # Flattened output spectrum, filled element-by-element by conv_kernel
    spectrum = torch.empty(n_outputs, dtype=torch.complex64)

    # Store the fd signal in the frequency-domain delay line (mutates the global fdl)
    fdl[:, :, fdl_cursor] = input_fd
    blk_kernel(conv_kernel, n_blocks, threads_per_block,
               fdl.ravel(), filters_fd.ravel(), fdl_cursor, spectrum, K, B, C)

    next_cursor = (fdl_cursor + 1) % K
    return spectrum.reshape(C, B + 1), next_cursor

def conv_kernel(blockidx, threadidx, blockdim, fdl, filters_fd,
                fdl_cursor, output_fd, K, B, C):
    """Emulated CUDA thread: complex multiply-accumulate over the K partitions.

    The global thread id selects one (channel, bin) element of the flattened
    ``C * (B + 1)`` output; ids past the end return without writing. ``fdl`` is
    indexed as a flattened ``(B + 1, K)`` array (the same for every channel) and
    ``filters_fd`` as a flattened ``(C, B + 1, K)`` array. Partition ``k`` is
    multiplied with the FDL slot reached after stepping back ``k`` times from
    ``fdl_cursor`` (wrapping modulo ``K``).
    """
    n_bins = B + 1
    gid = blockidx * blockdim + threadidx
    if gid >= C * n_bins:
        return  # padding thread of the last block

    channel, bin_ = divmod(gid, n_bins)
    fdl_base = bin_ * K
    filt_base = (channel * n_bins + bin_) * K

    acc = 0
    slot = fdl_cursor
    for k in range(K):
        acc += fdl[fdl_base + slot] * filters_fd[filt_base + k]
        slot = (slot - 1 + K) % K

    output_fd[channel * n_bins + bin_] = acc
    

In [None]:
# ============================================================================
# %% Simulate input streaming by looping over the batches (emulated GPU kernel)
# Start from an empty delay line so this cell can be re-run on its own
fdl = np.zeros((C_in, B + 1, K), dtype=np.complex64)
fdl_cursor = 0

input_buffer_td = np.zeros((C_in, 2 * B))
output = []

for signal_td in signal_batch:
    # --- CPU: overlap-save packing and RFFT ------------------------------------
    # Previous block goes to the first half of the buffer, the new block to the second
    input_buffer_td[:, :B] = input_buffer_td[:, B:]
    input_buffer_td[:, B:] = signal_td
    input_fd = np.fft.rfft(input_buffer_td, axis=1)  # shape: (C_in, B + 1)

    # --- emulated GPU: FDL update + multiply-accumulate ------------------------
    output_fd, fdl_cursor = part_conv_gpu(input_fd, fdl_cursor)
    assert not output_fd.isnan().any()

    # --- CPU: inverse RFFT ------------------------------------------------------
    output_td = np.fft.irfft(output_fd, axis=1)  # shape: (C, 2 * B)
    # The last B samples are the valid overlap-save output for this block
    output.append(output_td[:, B:].T)

output = np.array(output).reshape(-1, C).T  # shape: (C, num_blocks * B)

In [None]:
# Compare the emulated-GPU streamed output against a direct np.convolve result
check_output(output, filter_td)

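# GPU implementation in CUDA

The same kernel is now compiled with `torch.utils.cpp_extension.load_inline`; the problem sizes `C`, `B` and `K` are baked in as preprocessor macros, so the extension must be rebuilt whenever they change.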

In [None]:
# Make CUDA kernel launches synchronous so errors surface at the failing call
# (debugging aid; NOTE: only takes effect if set before CUDA is initialised in this process)
os.environ['CUDA_LAUNCH_BLOCKING']='1'
# wurlitzer: forward C-level stdout/stderr (e.g. printf from compiled code) to the notebook
%load_ext wurlitzer

In [None]:
# Common prelude for the inline CUDA sources: headers, input-check macros and cdiv().
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

// The using-directive must follow the header that declares torch::indexing.
using namespace torch::indexing;

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// do/while(0) keeps both checks together when used as a single statement (e.g. unbraced if)
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
'''

In [None]:
# Inline CUDA source: two multiply-accumulate kernels plus the host wrapper bound to Python.
# NUM_CHANNELS, BLOCK_SIZE and NUM_PARTS are not defined here; they must be supplied as
# -D compiler flags when the extension is built.
cuda_src = cuda_begin + r'''
#include <ATen/cuda/CUDAContext.h>

// One thread per (channel, bin) element of the (NUM_CHANNELS, BLOCK_SIZE + 1) output.
// fdl is a single (BLOCK_SIZE + 1, NUM_PARTS) delay line used by every channel;
// filters_fd is (NUM_CHANNELS, BLOCK_SIZE + 1, NUM_PARTS). Partition k is paired with
// the FDL slot k steps behind fdl_cursor.
__global__ void conv_kernel(const c10::complex<float>* fdl, const c10::complex<float>* filters_fd, int fdl_cursor, c10::complex<float>* output_fd) {
    
    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (thread_id >= NUM_CHANNELS * (BLOCK_SIZE + 1)) return;

    const int channel_id = thread_id / (BLOCK_SIZE + 1);
    const int bin_id = thread_id % (BLOCK_SIZE + 1);
    int cursor = fdl_cursor;

    const int fdl_offset = bin_id * NUM_PARTS;
    const int filter_offset = channel_id * ((BLOCK_SIZE + 1) * NUM_PARTS) + bin_id * NUM_PARTS;
    const int output_offset = channel_id * (BLOCK_SIZE + 1) + bin_id;

    c10::complex<float> out = 0;
    for (int k = 0; k < NUM_PARTS; ++k) {
        out += fdl[fdl_offset + cursor] * filters_fd[filter_offset + k];
        cursor = (cursor - 1 + NUM_PARTS) % NUM_PARTS;
    }
    output_fd[output_offset] = out;
}

// Per-channel delay line variant: fdl has the same (NUM_CHANNELS, BLOCK_SIZE + 1, NUM_PARTS)
// layout as filters_fd, so both are addressed from the same base offset.
__global__ void conv_kernel_multi(const c10::complex<float>* fdl, const c10::complex<float>* filters_fd, int fdl_cursor, c10::complex<float>* output_fd) {
    
    const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (thread_id >= NUM_CHANNELS * (BLOCK_SIZE + 1)) return;

    const int channel_id = thread_id / (BLOCK_SIZE + 1);
    const int bin_id = thread_id % (BLOCK_SIZE + 1);
    int cursor = fdl_cursor;

    const int filter_offset = channel_id * ((BLOCK_SIZE + 1) * NUM_PARTS) + bin_id * NUM_PARTS;
    const int output_offset = channel_id * (BLOCK_SIZE + 1) + bin_id;

    c10::complex<float> out = 0;
    for (int k = 0; k < NUM_PARTS; ++k) {
        out += fdl[filter_offset + cursor] * filters_fd[filter_offset + k];
        cursor = (cursor - 1 + NUM_PARTS) % NUM_PARTS;
    }
    output_fd[output_offset] = out;
}

// Writes input_fd into fdl at fdl_cursor, then launches the matching kernel.
// Returns a new (NUM_CHANNELS, BLOCK_SIZE + 1) tensor; the caller advances the cursor.
torch::Tensor part_conv_gpu(torch::Tensor input_fd, torch::Tensor fdl, torch::Tensor filters_fd, int fdl_cursor) {
    CHECK_INPUT(input_fd);
    CHECK_INPUT(fdl);
    CHECK_INPUT(filters_fd);

    // The kernels index with the compile-time sizes, so reject inputs that would make
    // them read or write out of bounds.
    TORCH_CHECK(fdl_cursor >= 0 && fdl_cursor < NUM_PARTS,
                "fdl_cursor must be in [0, NUM_PARTS), got ", fdl_cursor);
    TORCH_CHECK(filters_fd.numel() == (int64_t)NUM_CHANNELS * (BLOCK_SIZE + 1) * NUM_PARTS,
                "filters_fd must hold NUM_CHANNELS * (BLOCK_SIZE + 1) * NUM_PARTS elements");

    auto output_fd = torch::empty({NUM_CHANNELS, BLOCK_SIZE+1}, input_fd.options());

    int threads = 256;
    int blocks = cdiv(NUM_CHANNELS * (BLOCK_SIZE + 1), threads);
    // Launch on PyTorch's current stream so the kernel is ordered after index_put_.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fdl.dim() == 2) {
        TORCH_CHECK(fdl.size(0) == BLOCK_SIZE + 1 && fdl.size(1) == NUM_PARTS,
                    "2-D fdl must have shape (BLOCK_SIZE + 1, NUM_PARTS)");
        // Store the fd signal in a frequency-domain delay line
        fdl.index_put_({Slice(0, BLOCK_SIZE + 1), fdl_cursor}, input_fd);
        conv_kernel<<<blocks, threads, 0, stream>>>(fdl.data_ptr<c10::complex<float>>(), filters_fd.data_ptr<c10::complex<float>>(), fdl_cursor, output_fd.data_ptr<c10::complex<float>>());

    } else if (fdl.dim() == 3) {
        TORCH_CHECK(fdl.size(0) == NUM_CHANNELS && fdl.size(1) == BLOCK_SIZE + 1 && fdl.size(2) == NUM_PARTS,
                    "3-D fdl must have shape (NUM_CHANNELS, BLOCK_SIZE + 1, NUM_PARTS)");
        // Store the fd signal in a frequency-domain delay line
        fdl.index_put_({Slice(), Slice(0, BLOCK_SIZE + 1), fdl_cursor}, input_fd);
        conv_kernel_multi<<<blocks, threads, 0, stream>>>(fdl.data_ptr<c10::complex<float>>(), filters_fd.data_ptr<c10::complex<float>>(), fdl_cursor, output_fd.data_ptr<c10::complex<float>>());

    } else {
        throw std::runtime_error("fdl must be 2D or 3D");
    }

    C10_CUDA_KERNEL_LAUNCH_CHECK(); // Check for errors
    return output_fd;
}
'''

In [None]:
def load_cuda(cuda_src, cpp_src, funcs, K, B, C, verbose=False):
    """JIT-compile the inline CUDA extension with the problem sizes baked in.

    ``C``, ``B`` and ``K`` are passed to nvcc as the ``NUM_CHANNELS``,
    ``BLOCK_SIZE`` and ``NUM_PARTS`` macros (in that order, after ``-O2``), so
    different sizes require a recompile. ``funcs`` names the functions to bind.
    Returns whatever ``load_inline`` returns.
    """
    size_macros = {"NUM_CHANNELS": C, "BLOCK_SIZE": B, "NUM_PARTS": K}
    nvcc_flags = ["-O2"] + [f"-D{macro}={value}" for macro, value in size_macros.items()]
    return load_inline(
        name="inline_ext",
        cpp_sources=[cpp_src],
        cuda_sources=[cuda_src],
        functions=funcs,
        extra_cuda_cflags=nvcc_flags,
        verbose=verbose,
    )

# C++ declaration of the host wrapper defined in cuda_src, needed so load_inline can bind it
cpp_src = "torch::Tensor part_conv_gpu(torch::Tensor input_fd, torch::Tensor fdl, torch::Tensor filters_fd, int fdl_cursor);"
# Compile with the current (K, B, C); re-run this cell after changing any of them
module = load_cuda(cuda_src, cpp_src, ['part_conv_gpu'], K, B, C, verbose=True)

In [None]:
# ===========================================================
# Reset parameters: copy the filter spectra to the GPU and start from an empty FDL
filters_fd_gpu = torch.tensor(filters_fd, dtype=torch.complex64, device='cuda').contiguous()
# 3-D delay line (one per filter channel), so part_conv_gpu takes the conv_kernel_multi path
fdl_gpu = torch.zeros((C, B + 1, K), dtype=torch.complex64, device='cuda').contiguous()

fdl_cursor = 0
# C rows: each incoming (C_in, B) block is broadcast into every row (relies on C_in == 1 or C_in == C)
input_buffer_td = np.zeros((C, 2 * B))
output = []

# Simulate input streaming by looping over the batches
for signal_td in signal_batch:
    # --- CPU: overlap-save packing and RFFT ------------------------------------
    input_buffer_td[:, :B] = input_buffer_td[:, B:]
    input_buffer_td[:, B:] = signal_td
    input_fd = np.fft.rfft(input_buffer_td)  # shape: (C, B + 1)

    # --- GPU: FDL update + multiply-accumulate ---------------------------------
    input_fd_gpu = torch.tensor(input_fd, dtype=torch.complex64, device='cuda').contiguous()
    output_fd = module.part_conv_gpu(input_fd_gpu, fdl_gpu, filters_fd_gpu, fdl_cursor)
    # The extension does not return the next cursor, so advance it here
    fdl_cursor = (fdl_cursor + 1) % K

    # --- CPU: inverse RFFT ------------------------------------------------------
    output_td = np.fft.irfft(output_fd.cpu(), axis=1)  # shape: (C, 2 * B)
    # The last B samples are the valid overlap-save output for this block
    output.append(output_td[:, B:].T)

output = np.array(output).reshape(-1, C).T  # shape: (C, num_blocks * B)


In [None]:
# Compare the CUDA streamed output against a direct np.convolve result
check_output(output, filter_td)