In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [17, 6]
from IPython.display import Audio
from ipywidgets import HTML

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft

import numpy as np

### Single shared kernel (e.g. reverb): `fft_conv1d` matches `F.conv1d`

In [None]:
def fft_conv1d(signal, kernel):
    """Valid-mode 1-D cross-correlation computed with the real FFT.

    ``F.conv1d`` also computes a cross-correlation (the kernel is not flipped),
    and for single-input-channel kernels this function matches it. Unlike
    ``F.conv1d``, channels are NOT summed. ``signal`` and ``kernel`` are
    broadcast against each other over all leading dimensions, so each row of
    ``signal`` is correlated with the matching (broadcast) row of ``kernel``.

    Args:
        signal: tensor of shape (..., L).
        kernel: tensor of shape (..., K) with K <= L, broadcastable against
            ``signal`` in its leading dimensions.

    Returns:
        Real tensor of shape (broadcast leading dims..., L - K + 1).

    Raises:
        ValueError: if the kernel is longer than the signal.
    """
    n_signal = signal.shape[-1]
    n_kernel = kernel.shape[-1]
    if n_kernel > n_signal:
        # A negative pad would silently crop the kernel and give garbage.
        raise ValueError(
            f"kernel length ({n_kernel}) must not exceed signal length ({n_signal})"
        )

    # Zero-pad the kernel to the signal length so both spectra share the same bins.
    padded_kernel = F.pad(kernel, (0, n_signal - n_kernel))

    f_signal = fft.rfft(signal)
    f_kernel = fft.rfft(padded_kernel)

    # Multiplying by the conjugate kernel spectrum gives circular
    # cross-correlation (not convolution), matching F.conv1d.
    f_conv = f_signal * torch.conj(f_kernel)

    # Pass n explicitly. Without it irfft assumes an even length and returns
    # n_signal - 1 samples for odd n_signal, corrupting every output.
    f_conv = fft.irfft(f_conv, n=n_signal)

    # Only the first L - K + 1 lags are free of circular wrap-around.
    # They are exactly the "valid" outputs.
    return f_conv[..., :n_signal - n_kernel + 1]

In [None]:
# Fix the RNG so the random test tensors (and hence the tolerance checks
# below) are reproducible on Restart & Run All.
torch.manual_seed(0)
# N, C, L: 10 signals, 1 channel, 1000 samples each
signal = torch.randn(10, 1, 1000)
# OC, IC, K: a single shared 100-tap kernel (e.g. a reverb impulse response)
kernel = torch.randn(1, 1, 100)

In [None]:
# Reference result: PyTorch's direct cross-correlation, valid mode, stride 1.
# Output layout: (N, OC, OL)
conv = F.conv1d(signal, kernel, bias=None, stride=1, padding=0)

In [None]:
# The same computation through the FFT path; output layout (N, OC, OL).
f_conv = fft_conv1d(signal=signal, kernel=kernel)

In [None]:
np.testing.assert_allclose(actual=conv, desired=f_conv, rtol=1e-5, atol=1e-5)  # FFT path == direct conv1d

### Batched sequences with a different kernel per frame (`fft_conv1d`)

In [None]:
# N, S, C, L: 5 items, 10 frames each, 1 channel, 1000 samples per frame
signal = torch.randn(5, 10, 1, 1000)
# N, S, C, B: 51 random real-valued rfft bins per frame
bands = torch.randn(5, 10, 1, 51)
# N, S, C, K: irfft of 51 bins gives 2 * (51 - 1) = 100 taps. fftshift centres
# the zero-lag tap. Shift only the time axis: without dim=-1 fftshift also
# rolls the batch/sequence/channel axes, pairing each kernel with the wrong
# frame's bands.
kernel = fft.fftshift(fft.irfft(bands), dim=-1)

In [None]:
# Collapse (N, S, C) into a single axis of N*S*C independent rows.
signal_ = signal.flatten(start_dim=0, end_dim=2)
kernel_ = kernel.flatten(start_dim=0, end_dim=2)
# conv1d input layout is (batch, channels, length): a single batch whose
# channels are the flattened rows.
signal_ = signal_[None]
# conv1d weight layout is (out_channels, in_channels / groups, K):
# one input channel per group.
kernel_ = kernel_[:, None]

### Looping over kernels is equivalent to group convolution

In [None]:
# One group per flattened row, so kernel row i only ever sees signal row i;
# then restore the (N, S, C, OL) layout.
conv = F.conv1d(signal_, kernel_, groups=kernel_.shape[0]).reshape(*signal.shape[:3], -1)

In [None]:
# N*S items, each shaped (1, C, L) / (1, C, K) for a separate conv1d call.
l_signal = signal.reshape(-1, 1, *signal.shape[2:])
l_kernel = kernel.reshape(-1, 1, *kernel.shape[2:])

In [None]:
# Correlate each item's signal with its own kernel: (1, C, L) x (1, C, K) -> (1, 1, OL).
l_conv = [F.conv1d(sig, ker) for sig, ker in zip(l_signal, l_kernel)]

In [None]:
# Stack the per-item outputs along the channel axis, giving (1, N*S, OL),
# then restore (N, S, C, OL).
# NOTE(review): this rebinds l_conv from a list to a tensor, so re-running this
# cell alone would pass a tensor to torch.cat. Re-run the loop cell first.
l_conv = torch.cat(l_conv, dim=1).reshape(*signal.shape[:3], -1)

In [None]:
np.testing.assert_allclose(actual=conv, desired=l_conv, rtol=1e-5, atol=1e-5)  # grouped conv == per-item loop

### The FFT version reproduces the same batched, per-frame convolution

In [None]:
# fft_conv1d broadcasts instead of grouping. Lay the kernels out as
# (1, N*S*C, K) so they line up row-for-row with signal_ (1, N*S*C, L),
# then restore (N, S, C, OL).
f_conv = fft_conv1d(signal_, kernel_.permute(1, 0, 2)).reshape(*signal.shape[:3], -1)

In [None]:
np.testing.assert_allclose(actual=conv, desired=f_conv, rtol=1e-5, atol=1e-5)  # FFT path == grouped conv

### Wrapping it in a function: `grouped_fft_conv1d`

In [None]:
def grouped_fft_conv1d(signal, kernel):
    """Row-wise valid cross-correlation of ``signal`` with ``kernel`` via FFT.

    Every flattened leading index (e.g. each (batch, frame, channel) triple)
    is correlated with its own kernel. This is the same result as a grouped
    ``F.conv1d`` with one group per row, which the notebook checks.

    Args:
        signal: tensor of shape (..., L), e.g. (N, S, C, L).
        kernel: tensor of shape (..., K) whose leading dims flatten to the same
            number of rows as ``signal``'s, or to a single row that is then
            shared by every signal row.

    Returns:
        Tensor of shape (*signal.shape[:-1], L - K + 1).
    """
    # Merge all leading (batch, sequence, channel, ...) dims into one axis and
    # add a batch dim: (1, rows, L). Works for any number of leading dims.
    signal_ = signal.reshape(1, -1, signal.shape[-1])
    # Same layout for the kernels, (1, rows, K), so fft_conv1d pairs them with
    # the signal rows by broadcasting.
    kernel_ = kernel.reshape(1, -1, kernel.shape[-1])

    f_conv = fft_conv1d(signal_, kernel_)
    # Restore the caller's leading dims.
    return f_conv.reshape(*signal.shape[:-1], -1)

In [None]:
f_conv = grouped_fft_conv1d(signal=signal, kernel=kernel)  # (N, S, C, OL)

In [None]:
np.testing.assert_allclose(actual=conv, desired=f_conv, rtol=1e-5, atol=1e-5)  # wrapper == grouped conv

### Framing noise into overlapping windows

In [None]:
# Framing parameters: samples per hop, frames per item, items per batch.
hop_size, seq_len, batch_size = 480, 10, 5

In [None]:
# Uniform noise in [-1, 1) with shape (batch_size, 1, (seq_len + 1) * hop_size).
noise = 2.0 * torch.rand(batch_size, 1, (seq_len + 1) * hop_size) - 1.0
# Windows of 2 * hop_size samples taken every hop_size samples (50% overlap).
# unfold:    (B, 1, L) -> (B, 1, seq_len, 2 * hop_size)
# transpose: -> (B, seq_len, 1, 2 * hop_size), the (N, S, C, L) layout used above.
framed_noise = noise.unfold(dimension=-1, size=2 * hop_size, step=hop_size).transpose(1, 2)
print(noise.shape, framed_noise.shape)