In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

import torch
import torch.fft as fft
import torch.nn.functional as F 

from performer.models.components.fft_conv import fft_conv1d, pad_to, unfold, fft_conv1d_new

In [None]:
# Test fixtures.
# signal: batch of 4 mono signals, 1920 samples each, unit impulses at 0/300/500/800.
signal = torch.zeros(4, 1, 480 * 4)
signal[..., [0, 300, 500, 800]] = 1.
# ir: 2 output channels x 1 input channel, 1440 taps, unit impulses at the
# first and last tap (so the IR is symmetric under time reversal).
ir = torch.zeros(2, 1, 480 * 3)
ir[..., [0, 1439]] = 1.

In [None]:
# Reference result: direct full (linear) convolution via F.conv1d.
# F.conv1d computes cross-correlation, so the IR is flipped to obtain true
# convolution. Zero-padding both ends by kernel_size - 1 produces every
# output sample: length = signal_len + kernel_size - 1.
sig_pad = ir.shape[-1] - 1
padded_signal = F.pad(signal, (sig_pad, sig_pad))

conv_out = F.conv1d(padded_signal, ir.flip(-1))
conv_out.shape

In [None]:
# FFT-based version of the full convolution above.
# Pad the signal exactly like the direct version, then zero-pad the IR on the
# right to the same length so both spectra share one FFT size.
sig_pad = ir.shape[-1] - 1
padded_signal = F.pad(signal, (sig_pad, sig_pad))

end_pad = padded_signal.shape[-1] - ir.shape[-1]
padded_ir = F.pad(ir, (0, end_pad))

ir_z = fft.rfft(padded_ir)
signal_z = fft.rfft(padded_signal)

# With a single input channel, signal_z is (batch, 1, freq) and ir_z is
# (out_channels, 1, freq). Swapping the first two axes of signal_z broadcasts
# the product to (out_channels, batch, freq); swapping back gives
# (batch, out_channels, freq).
# Multiplying by conj(ir_z) yields circular *cross-correlation* with the
# unflipped IR (the F.conv1d convention), not convolution with it.
fft_conv_out = (signal_z.transpose(0, 1) * ir_z.conj()).transpose(0, 1)

# Pass n explicitly: irfft's default output length is 2 * (freq - 1), which is
# one sample short whenever the padded length is odd.
fft_conv_out = fft.irfft(fft_conv_out, n=padded_signal.shape[-1])

In [None]:
# Keep the linearly valid part of the circular result:
# padded_len - kernel_size + 1 == signal_len + kernel_size - 1 samples,
# i.e. the same length as the direct full convolution.
fft_conv_out = fft_conv_out[..., :padded_signal.size(-1) - ir.size(-1) + 1]
fft_conv_out.shape

In [None]:
# The default atol (1e-8) is below typical float32 FFT round-off on samples that
# are exactly zero in the direct result, so use a float32-appropriate tolerance.
# NOTE: conv_out flips the IR (convolution) while the FFT path multiplies by
# conj(ir_z) (cross-correlation); they agree here only because the test IR is
# symmetric (taps at 0 and 1439 of 1440).
torch.allclose(fft_conv_out, conv_out, atol=1e-5)

In [None]:
# Largest absolute deviation between the FFT and direct results.
torch.max(torch.abs(fft_conv_out - conv_out))

In [None]:
# Shapes of the FFT result and the direct result, side by side.
(fft_conv_out.size(), conv_out.size())

In [None]:
# FFT convolution result for the first batch item / first output channel.
fig, ax = plt.subplots()
ax.plot(fft_conv_out[0, 0])
ax.set(title="FFT convolution (batch 0, channel 0)", xlabel="sample", ylabel="amplitude");

In [None]:
# Direct (F.conv1d) result for the first batch item / first output channel.
fig, ax = plt.subplots()
ax.plot(conv_out[0, 0])
ax.set(title="Direct convolution (batch 0, channel 0)", xlabel="sample", ylabel="amplitude");

In [None]:
# Project implementation under test; its result is unpacked as a pair (out, tail).
(out, tail) = fft_conv1d_new(signal, ir)

In [None]:
# Main output of fft_conv1d_new for batch 0 / channel 0.
fig, ax = plt.subplots()
ax.plot(out[0, 0])
ax.set(title="fft_conv1d_new output (batch 0, channel 0)", xlabel="sample", ylabel="amplitude");

In [None]:
# Tail returned by fft_conv1d_new for batch 0 / channel 0.
fig, ax = plt.subplots()
ax.plot(tail[0, 0])
ax.set(title="fft_conv1d_new tail (batch 0, channel 0)", xlabel="sample", ylabel="amplitude");

In [None]:
# Shapes of the output and tail returned by fft_conv1d_new.
(out.size(), tail.size())

In [None]:
# Result of the existing project implementation (rebinds `out` from the cells above).
out = fft_conv1d(signal, ir)
out.size()

In [None]:
# Step-by-step replica of the FFT convolution, for comparison with fft_conv1d.
# Left-pad by kernel_size - 1 so every output sample sees a full IR window.
padded = F.pad(signal, (ir.shape[-1] - 1, 0))

batch, channels, length = padded.shape
out_channels, _, kernel_size = ir.shape

# Zero-pad the IR to the padded signal length so the spectra line up.
# NOTE(review): assumes pad_to right-pads with zeros -- TODO confirm.
ir_ = pad_to(ir, length)
ir_z = fft.rfft(ir_)

# With a single input channel: (batch, 1, 1, freq) * (out_channels, 1, freq)
# broadcasts to (batch, out_channels, 1, freq).
frames_z = fft.rfft(padded).unsqueeze(2)
print(frames_z.shape)
print(ir_z.shape)
out_z = frames_z * ir_z.conj()
# n=length is required: here length = 1920 + 1439 = 3359 is odd, and irfft's
# default output length 2 * (freq - 1) would silently drop a sample.
_out = fft.irfft(out_z, n=length)
print(_out.shape)

_out = _out.reshape(batch, out_channels, -1)

# A size-`length` circular correlation matches the linear one only for the first
# length - kernel_size + 1 lags; later samples are wrap-around. Slicing to
# target_length directly also stays correct for kernel_size == 1.
target_length = (length - kernel_size) + 1

# TODO: this keeps only the valid part. A real-time synth tail would need an FFT
# size of at least length + kernel_size - 1 to avoid wrap-around.
_out = _out[..., :target_length]

In [None]:
# Shape of the step-by-step replica's output.
_out.size()

In [None]:
# Does the step-by-step replica match fft_conv1d? (Errors if the shapes cannot be broadcast.)
out.allclose(_out)

In [None]:
# Shapes of fft_conv1d's output and the replica's output, side by side.
(out.size(), _out.size())