## Fast Fourier Transform and Spectral Reconstruction

This notebook explores the FFT in PyTorch and shows how to build a continuous representation of a function from a discrete sampling of that function over a single period.

In [None]:
# Parameters
# 1D example
# Number of linspace points spanning one period; the sampling cell drops the
# last point, so the sampled wave has num_points - 1 values.
num_points = 100
# Amplitude of each sine component.
amplitudes = [1.0, 2.0, 0.4]
# Fundamental frequency; the sampled domain is (0, 1 / base_frequency).
base_frequency = 0.2
# Multiplier applied to base_frequency for each component. The sampling cell
# zips amplitudes, multipliers and phases together, so only as many components
# as there are multipliers listed here are summed.
frequency_multipliers = [1.0]  # e.g. [1.0, 2.0, 3.0]; these need to be integers
# Phase offset of each component, added inside the sine.
phases = [0.0, 0.2, 0.5]
# 2D example (the 2D cells in this notebook are currently commented out)
num_x_points = 1000
num_y_points = 1000
amplitudes_2d = [1.0]
base_x_frequency = 1.0
base_y_frequency = 1.0
x_frequency_multipliers = [1.0]  # These need to be integers
y_frequency_multipliers = [1.0]  # These need to be integers
phases_2d = [0.0]

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch

from pyinsulate.spectral import SpectralReconstruction

### 1D Example

In [None]:
# Factory for a sine wave with fixed amplitude, frequency and phase
def construct_wave_equation(amplitude, frequency, phase):
    """Return a callable evaluating ``amplitude * sin(2*pi*frequency*xs + phase)``.

    Args:
        amplitude: scale applied to the sine.
        frequency: multiplied by ``2 * pi`` and by the sample positions.
        phase: offset added inside the sine.

    Returns:
        A one-argument function mapping a tensor ``xs`` to the wave values.
    """

    def evaluate(xs):
        angle = 2 * np.pi * frequency * xs + phase
        return amplitude * torch.sin(angle)

    return evaluate

In [None]:
# Sample the wave and plot it
# The domain runs from 0 to 1 / base_frequency; the last linspace point is dropped.
bounds = (0, 1.0 / base_frequency)
x = torch.linspace(bounds[0], bounds[1], steps=num_points)[:-1]

# Accumulate the sine components; zip stops at the shortest parameter list.
wave = x.new_zeros(x.size())
for amplitude, multiplier, phase in zip(amplitudes, frequency_multipliers, phases):
    component = construct_wave_equation(amplitude, base_frequency * multiplier, phase)
    wave += component(x)

fig = plt.figure()
plt.plot(x.numpy(), wave.numpy())
plt.xlabel(r"$x$")
plt.title("Original wave")
plt.show()

In [None]:
# One-sided, unnormalized FFT of the sampled wave
wave_fft = torch.rfft(wave, signal_ndim=1, normalized=False, onesided=True)
print(wave_fft.size())

# Stem plot of the two components stored along the last axis:
# index 0 in red and index 1 in blue (labelled "real" and "imag" below).
fig = plt.figure()
wave_fft_np = wave_fft.numpy()
for component, color in enumerate("rb"):
    plt.stem(
        wave_fft_np[..., component],
        markerfmt=f"{color}o",
        linefmt=f"{color}-",
        basefmt=" ",
        use_line_collection=True,
    )
plt.title("DFT(wave)")
plt.xlabel(r"$f$")
plt.legend(["real", "imag"])
plt.show()

In [None]:
# Turn DFT coefficients into a function that can be evaluated at arbitrary points
def reconstruct(f_transformed):
    """Return a callable that evaluates a trigonometric sum built from ``f_transformed``.

    Args:
        f_transformed: tensor whose last axis is indexed as ``[..., 0]`` (weights
            of the cosine terms) and ``[..., 1]`` (weights of the sine terms);
            coefficient ``k`` is paired with the angle ``2*pi*k*x``.

    Returns:
        A function ``f(x)`` taking a 1D tensor of positions and returning the
        mean over coefficients of ``real * cos - imag * sin``.

    Note:
        The average is taken over ``len(f_transformed)`` coefficients, not over
        the length of the signal the coefficients were computed from.
    """

    def f(x):
        harmonics = torch.arange(len(f_transformed), dtype=f_transformed.dtype)
        # torch.ger is the outer product: theta[k, j] = 2*pi * k * x[j]
        theta = 2 * np.pi * torch.ger(harmonics, x)
        real_weights = f_transformed[..., 0].unsqueeze(-1)
        imag_weights = f_transformed[..., 1].unsqueeze(-1)
        cos_term = (real_weights * torch.cos(theta)).mean(dim=0)
        sin_term = (imag_weights * torch.sin(theta)).mean(dim=0)
        return cos_term - sin_term

    return f

In [None]:
# Sample points from reconstruction and plot
# The reconstruction is evaluated on a domain normalized to one period, [0, 1).
larger_x = torch.linspace(-0.1, 1.1, steps=int(6 / 5 * len(x)))[:-1]
wave_recon_ext = reconstruct(wave_fft)(larger_x)
# `x` already excludes its right endpoint, so build len(x) + 1 points and drop
# the last one: this yields exactly len(x) normalized positions matching `x`.
wave_recon = reconstruct(wave_fft)(torch.linspace(0, 1, steps=len(x) + 1)[:-1])


# Take the absolute difference; the signed maximum would hide negative errors.
print(f"worst error: {torch.max(torch.abs(wave - wave_recon)).item()}")

# Upsample by zero-padding the one-sided spectrum and inverting it.
extension_factor = 2
num_complex_coefs = int(np.floor(float(wave.size()[-1]) / 2.0))
print(num_complex_coefs)
final_size = tuple((*wave.size()[:-1], extension_factor * wave.size()[-1]))
# A one-sided spectrum of a real signal of length M holds M // 2 + 1 bins, so
# size the padded spectrum from the target length (valid for odd and even N).
frequencies_size = tuple(
    (*wave.size()[:-1], final_size[-1] // 2 + 1, 2)
)  # real + imag
wave_frequencies = wave.new_zeros(frequencies_size)
wave_rfft = torch.rfft(wave, 1, onesided=True, normalized=False)
# Copy the low-frequency bins; the remaining (higher) bins stay zero.
wave_frequencies[..., : wave_rfft.size(-2), :] = wave_rfft
# The unnormalized inverse divides by the longer output length, hence the rescale.
torch_recon = extension_factor * torch.irfft(
    wave_frequencies, 1, onesided=True, normalized=False, signal_sizes=final_size
)
# Upsampled samples sit on a periodic grid that excludes the right endpoint,
# so every extension_factor-th position coincides with a point of `x`.
torch_recon_x = torch.linspace(*bounds, steps=len(torch_recon) + 1)[:-1]
# Plain FFT round trip, used as the numerical-noise reference.
baseline = torch.irfft(torch.rfft(wave, 1, onesided=False), 1, onesided=False)
print(wave.size())
print(torch_recon.size())

In [None]:
# wave2 = (wave[:-1] + wave[1:])/2
# x2 = x[:-1]

# plot
# Three panels: the reconstructions, their signed error, and |error| on a log scale.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
recon_ax, error_ax, abs_error_ax = axes

x_np = x.numpy()
wave_np = wave.numpy()
baseline_np = baseline.numpy()
series_np = wave_recon.numpy()
# Keep every extension_factor-th upsampled value so it can be compared with wave.
upsampled_on_grid = torch_recon.numpy()[::extension_factor]
# larger_x is rescaled by the width of `bounds` and shifted by its lower edge.
extended_x = (bounds[1] - bounds[0]) * larger_x.numpy() + bounds[0]

# Panel 0: all reconstructions on top of the truth.
recon_ax.plot(x_np, wave_np, "k", label="Truth")
recon_ax.plot(x_np, baseline_np, "g", label="Baseline")
recon_ax.plot(torch_recon_x.numpy(), torch_recon.numpy(), "b", label="IFFT(ext(FFT(.)))")
recon_ax.plot(x_np, series_np, "r", label="Fourier Series")
recon_ax.plot(
    extended_x,
    wave_recon_ext.numpy(),
    "y",
    label="Fourier Series Extended",
    zorder=-1,
)
recon_ax.set_title("Reconstructed wave")
recon_ax.set_xlabel(r"$x$")
recon_ax.legend()

# Panel 1: signed errors (the Fourier-series curve is intentionally left out here).
signed_errors = [
    ("IFFT(ext(FFT(.)))", "b", upsampled_on_grid - wave_np),
    ("Baseline", "g", baseline_np - wave_np),
]
for label, color, err in signed_errors:
    error_ax.plot(x_np, err, color, label=label)
error_ax.set_title("Error")
error_ax.set_xlabel(r"$x$")
error_ax.legend()

# Panel 2: error magnitudes, Fourier series first, on a logarithmic y-axis.
abs_errors = [("Fourier Series", "r", series_np - wave_np)] + signed_errors
for label, color, err in abs_errors:
    abs_error_ax.plot(x_np, np.abs(err), color, label=label)
abs_error_ax.set_title("Magnitude of error")
abs_error_ax.set_xlabel(r"$x$")
abs_error_ax.set_yscale("log")
abs_error_ax.legend()
plt.show()

# # Plot DFT of Reconstruction
# recon_fft = torch.rfft(torch_recon, signal_ndim=1, normalized=False, onesided=True)
# # plot the fft
# fig = plt.figure()
# plt.stem(wave_fft.numpy()[0,:,0], markerfmt='ro', linefmt='r-', basefmt=' ', use_line_collection=True)
# plt.stem(wave_fft.numpy()[0,:,1], markerfmt='bo', linefmt='b-', basefmt=' ', use_line_collection=True)
# plt.title("DFT(wave)")
# plt.xlabel(r"$f$")
# plt.legend(["real", "imag"])
# plt.show()

### 1D Example using the SpectralReconstruction layer

In [None]:
def wave_equation(amplitude, frequency, phase, xs):
    """Evaluate ``amplitude * sin(2*pi*frequency*xs + phase)``.

    The arguments are combined with ordinary arithmetic operators, so tensors
    of compatible shapes broadcast against each other.
    """
    angle = 2 * np.pi * frequency * xs + phase
    return amplitude * torch.sin(angle)

In [None]:
# Build the reconstruction layer that stands in for a "neural network"
net = SpectralReconstruction(1)

# Draw a random sine wave: one (amplitude, frequency, phase) triple per batch row
# batch_size = np.random.randint(1, 10)
batch_size = 1
param_shape = (batch_size, 1)
amplitude = torch.rand(param_shape) * 10
# needs to be integers
frequency = torch.randint(1, 10, size=param_shape, dtype=amplitude.dtype)
phase = torch.rand(param_shape) * 2 * np.pi
# an odd number of points allows us to test the middle of the domain
xs = torch.linspace(0, 1, steps=501).unsqueeze(0)
wave = wave_equation(amplitude, frequency, phase, xs)

# Query the layer at the sample positions themselves, with a trailing axis added
query_points = xs[..., None]  # torch.rand(amplitude.size())

recon = net(wave, query_points)
print(f"recon: {recon.size()}")
# print(f"recon: {recon}")

# print(
#     f"wave(query_points): {wave_equation(amplitude, frequency, phase, query_points)}"
# )

In [None]:
# Plot the wave and the reconstruction
# FFT round trip used as the numerical-noise baseline. It is computed once,
# outside the loop, with the two-sided transform: a one-sided irfft without
# signal_sizes returns an even number of samples, which would not match the
# odd-length wave built above.
roundtrip_baseline = torch.irfft(
    torch.rfft(wave, 1, onesided=False), 1, onesided=False
).numpy()
# xs holds a single row of sample positions shared by every batch entry.
x_row = xs.numpy()[0]
wave_np = wave.numpy()
recon_np = recon.numpy()
for b in range(batch_size):
    fig, axes = plt.subplots(1, 2)
    # Left panel: truth against the layer's reconstruction.
    axes[0].plot(x_row, wave_np[b], "k", label="Truth")
    axes[0].plot(x_row, recon_np[b], "r", label="Reconstruction")
    axes[0].set_title("Reconstructed wave")
    axes[0].set_xlabel(r"$x$")
    axes[0].legend()
    # Right panel: error magnitudes on a logarithmic y-axis.
    axes[1].plot(
        x_row,
        np.abs(recon_np[b] - wave_np[b]),
        "r",
        label="Reconstruction",
    )
    axes[1].plot(
        x_row,
        np.abs(roundtrip_baseline[b] - wave_np[b]),
        "k",
        label="Baseline",
    )
    axes[1].set_title("Magnitude of error")
    axes[1].set_xlabel(r"$x$")
    axes[1].set_yscale("log")
    axes[1].legend()
    plt.tight_layout()
    plt.show()

### 2D Example

In [None]:
# # Define a functional form of 2d a "sine" function
# def construct_wave_equation(amplitude, x_frequency, y_frequency, phase):
#     def wave_equation(xs, ys):
#         return amplitude * torch.cos(2 * np.pi * (x_frequency * xs + y_frequency * ys) + phase)

#     return wave_equation

In [None]:
# # Sample points and plot
# x_bounds = (0, 1.0 / base_x_frequency)
# y_bounds = (0, 1.0 / base_y_frequency)
# x = torch.linspace(*x_bounds, steps=num_x_points)
# y = torch.linspace(*y_bounds, steps=num_y_points)
# xs, ys = torch.meshgrid(x, y)
# wave = xs.new_zeros(xs.size())
# for amplitude, x_multiplier, y_multiplier, phase in zip(amplitudes, x_frequency_multipliers, y_frequency_multipliers, phases):
#     wave += construct_wave_equation(amplitude, base_x_frequency * x_multiplier, base_y_frequency * y_multiplier, phase)(xs, ys)

# fig = plt.figure()
# plt.imshow(wave.numpy(), extent=(*x_bounds, *y_bounds), cmap=mpl.cm.get_cmap('hot'))
# plt.colorbar()
# plt.title("Original wave")
# plt.xlabel(r"$x$")
# plt.ylabel(r"$y$", rotation=0)
# plt.show()

In [None]:
# # Compute the fft of the function
# wave_fft = torch.rfft(wave, signal_ndim=2, normalized=False)
# # plot the fft
# fig, axes = plt.subplots(1, 2)
# im0 = axes[0].imshow(wave_fft.numpy()[...,0], origin="lower")
# im1 = axes[1].imshow(wave_fft.numpy()[...,1], origin="lower")
# fig.colorbar(im0, ax=axes[0])
# fig.colorbar(im1, ax=axes[1])
# axes[0].set_xlabel(r"$f_x$")
# axes[0].set_ylabel(r"$f_y$", rotation=0)
# axes[1].set_xlabel(r"$f_x$")
# axes[1].set_ylabel(r"$f_y$", rotation=0)
# axes[0].set_title("Real")
# axes[1].set_title("Imag")
# plt.suptitle("DFT(wave)")
# plt.tight_layout()
# plt.show()

In [None]:
# # Define a new function, which is the reconstruction of the sine
# def reconstruct(f_transformed):
#     def f(query_points):
#         # torch.ger = outer product
#         # TODO figure out how this actually needs to be done.
#         frequency_grid = torch.cartesian_prod(*[torch.arange(x, dtype=f_transformed.dtype) for x in f_transformed.size()[:-1]])
#         angles = 2 * np.pi * torch.einsum('...,bj->b...j', frequency_grid, query_points)
#         sum_dims = tuple(range(1, len(angles.size()))) # all dims but batch
#         cosines = torch.mean(f_transformed[..., 0].unsqueeze(-1) * torch.prod(torch.cos(angles), dim=-1), dim=sum_dims)
#         sines = torch.mean(f_transformed[..., 1].unsqueeze(-1) * torch.prod(torch.sin(angles), dim=-1), dim=sum_dims)
#         return cosines - sines
#     return f

In [None]:
# # Sample points from reconstruction and plot
# x_bounds_large = (-0.3, 1.3)
# y_bounds_large = (-0.3, 1.3)
# x = torch.linspace(*x_bounds_large, steps=num_x_points)
# y = torch.linspace(*y_bounds_large, steps=num_y_points)
# xs, ys = torch.meshgrid(x, y)
# query_points = torch.stack([xs.reshape(-1), ys.reshape(-1)])
# wave_recon = reconstruct(wave_fft)(query_points).reshape(xs.size())

# fig = plt.figure()
# plt.imshow(wave_recon.numpy(), extent=(*x_bounds_large, *y_bounds_large), cmap=mpl.cm.get_cmap('hot'))
# plt.colorbar()
# plt.title("Reconstructed wave")
# plt.xlabel(r"$x$")
# plt.ylabel(r"$y$", rotation=0)
# plt.show()