In [None]:
using Pkg; Pkg.activate(".")
using Plots, LinearAlgebra, LaTeXStrings, ForwardDiff
include("tools.jl")

## Fourier Spectral Methods in d > 1

We start by generalizing our 1D tools to $d$ dimensions (concretely 2D and 3D):
- the x grid
- the k grid
- the trigonometric interpolant
- evaluation of the trigonometric interpolant on a finer grid

In [None]:

"""
    tensorgrid(d, x1)

Return a `d`-tuple of `d`-dimensional arrays built from the one-dimensional array `x1`:

     x1 ⊗ 1  ⊗ ... ⊗ 1    (first coordinate)
     1  ⊗ x1 ⊗ ... ⊗ 1    (second coordinate)
     ...
     1  ⊗ ... ⊗ 1  ⊗ x1   (d-th coordinate)

Every array has size `(n, …, n)` with `n = length(x1)`. The `i`-th array varies only
along dimension `i`, i.e. `tensorgrid(d, x1)[i][j_1, …, j_d] == x1[j_i]`.
"""
function tensorgrid(d, x1)
    n = length(x1)
    fullsize = ntuple(_ -> n, d)
    filler = ones(Bool, fullsize)
    # lay x1 out along axis `ax` (singleton everywhere else), then broadcast it
    # against the all-`true` array to fill the whole tensor grid
    along(ax) = reshape(x1, ntuple(k -> k == ax ? n : 1, d)) .* filler
    return ntuple(along, d)
end

"""
    xgrid(d, N)

Tensor-product version of the one-dimensional node grid `xgrid(N)`.

Returns a `d`-tuple whose `i`-th entry is a `d`-dimensional array holding the
`i`-th coordinate of every grid node (see `tensorgrid`).
"""
function xgrid(d, N)
    nodes_1d = xgrid(N)
    return tensorgrid(d, nodes_1d)
end

"""
    kgrid(d, N)

Tensor-product version of the one-dimensional wave-number grid `kgrid(N)`.

Returns a `d`-tuple whose `i`-th entry is a `d`-dimensional array holding the
`i`-th component of every wave vector (see `tensorgrid`).
"""
function kgrid(d, N)
    waves_1d = kgrid(N)
    return tensorgrid(d, waves_1d)
end


"""
    triginterp_fft(f, N, d) -> F̂

Coefficients of the `d`-dimensional trigonometric interpolant of `f`.

`f` is called with `d` coordinate arguments at every node of `xgrid(d, N)`, and the
nodal values are transformed with a `d`-dimensional `fft`. The result is divided by
`(2N)^d`. This assumes `xgrid(N)` has `2N` nodes per axis, so `F̂` holds the
interpolation coefficients in FFT ordering rather than the raw FFT output.
"""
function triginterp_fft(f::Function, N, d::Integer)
    nodes = xgrid(d, N)
    samples = f.(nodes...)      # nodal values of f
    npoints = (2 * N)^d         # total number of interpolation nodes
    return fft(samples) / npoints
end

"""
    evaltrig_grid(F̂::AbstractArray{T, 2}, M::Integer) -> (Fx, x)

Evaluate a two-dimensional trigonometric polynomial on the `2M × 2M` tensor grid.

`F̂` is a `2N × 2N` coefficient array in FFT ordering (e.g. from
`triginterp_fft(f, N, 2)`). Along each axis, indices `1:N` hold the wave numbers
`k = 0, …, N-1` and indices `N+1:2N` hold `k = -N, …, -1`. The coefficients are
zero-padded to a `2M × 2M` array and transformed back with `ifft`.

Returns `Fx`, the real part of the values on the fine grid (the imaginary part is
discarded), and `x = xgrid(M)`, the 1D grid along each axis.

Throws `ArgumentError` if `F̂` is not of size `2N × 2N`, or if `M < N`.
"""
function evaltrig_grid(F̂::AbstractArray{T, 2}, M::Integer) where {T}
    N = size(F̂, 1) ÷ 2
    size(F̂) == (2N, 2N) ||
        throw(ArgumentError("F̂ must have size (2N, 2N), got $(size(F̂))"))
    M >= N || throw(ArgumentError("need M ≥ N, got M = $M and N = $N"))

    # Zero-pad in Fourier space: non-negative modes keep their position at the
    # front of each axis, negative modes move to the back of the longer fine axis.
    F̂_M = zeros(ComplexF64, (2M, 2M))
    lo = 1:N              # k = 0, …, N-1 (same position on both grids)
    hi = N+1:2N           # k = -N, …, -1 on the coarse grid
    hi_M = 2M-N+1:2M      # k = -N, …, -1 on the fine grid
    slots = ((lo, lo), (hi, hi_M))   # (source range, destination range)
    # copy each of the 2^2 = 4 quadrants (sign pattern of k) to the fine grid
    for (s1, d1) in slots, (s2, d2) in slots
        F̂_M[d1, d2] .= F̂[s1, s2]
    end

    x = xgrid(M)
    # ifft scales by 1/(2M)^2; multiply it back out
    Fx = real.(ifft(F̂_M) * (2M)^2)
    return Fx, x
end

In [None]:
# Demo: a random real coefficient array in FFT ordering with N = 2 (a 4 × 4 array).
# randn is not seeded, so the picture changes from run to run.
N = 2 
F̂ = randn(2*N, 2*N)
x = xgrid(N)
# values on the coarse grid: undo ifft's 1/(2N)^2 scaling, keep only the real part
# (ifft of an arbitrary real F̂ is in general complex)
F = real.(ifft(F̂)*(2*N)^2)
surface(x, x, F; size=(400,300), colorbar=nothing)

In [None]:
# Evaluate the same random trigonometric polynomial (F̂ from the previous cell)
# on the finer grid with parameter M = 32
M = 32 
FM, x = evaltrig_grid(F̂, M)
surface(x, x, FM; size=(400,300), colorbar=nothing)

### Approximation Rates 

We explore approximation rates for a simple multivariate generalization of the periodic Witch of Agnesi:
$$
f(x_1, \dots, x_d) = \frac{1}{1 + c^2 \sum_{t = 1}^d \sin^2(x_t)}
$$
But to make things a bit clearer, we change it slightly to 
$$
f(x_1, \dots, x_d) = \frac{1}{1 + c^2 \sum_{t = 1}^d \sin^2(x_t/2-\pi/2)}
= \frac{1}{1+ \frac{c^2}{2} \sum_{t=1}^d \big(1 + \cos(x_t)\big)}
= \frac{1}{1+ \frac{c^2}{2} \big(d + \sum_{t=1}^d \cos(x_t)\big)}
$$
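This ensures that there is just a single peak, located at the center $(\pi, \dots, \pi)$ of the domain $[0, 2\pi)^d$. In the rewriting we used that $\sin^2(y) = \frac12 - \frac12 \cos(2y)$ (here with $2y = x - \pi$) and $\cos(x-\pi) = -\cos(x)$. Viewed as a function of any single variable $x_t$ (with the other variables real), $f$ is analytic in the strip $|\mathrm{Im}\, x_t| < \alpha = 2\,\mathrm{asinh}(1/c)$. This $\alpha$ is the predicted exponential rate used in the convergence plots below.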

### Two Dimensions

In [None]:
f2_fun, α = let c = 4.0
    # 2D periodic Witch of Agnesi (shifted form), peaked at (π, π)
    agnesi2(x1, x2) = 1 / (1 + c^2 / 2 * (2 + cos(x1) + cos(x2)))
    # predicted exponential convergence rate α = 2 asinh(1/c)
    (agnesi2, 2 * asinh(1 / c))
end

In [None]:
# plot the target function
# sampled on the tensor grid xgrid(2, N) by broadcasting f2_fun over the
# two coordinate arrays
N = 64; 
X1, X2 = xgrid(2, N)
F = f2_fun.(X1, X2)
x = xgrid(N)
surface(x, x, F; size=(400,300), colorbar=nothing)

In [None]:
D = 2  # dimension
NN = 8:8:64

# Target function on a fine grid (reference values for the error)
Ne = 256
X1e, X2e = xgrid(D, Ne)
FM_e = f2_fun.(X1e, X2e)

# max-norm error of the trigonometric interpolant for each N, measured on the fine grid
errs = Float64[]
for N in NN
    F̂ = triginterp_fft(f2_fun, N, D)
    FN_e, x = evaltrig_grid(F̂, Ne)
    # norm(A, Inf) treats an array of any rank as a vector: no A[:] copies needed
    err_N = norm(FM_e - FN_e, Inf)
    push!(errs, err_N)
end


In [None]:
# semilog plot of the max-norm interpolation errors against N
plot(NN, errs, m=:o, lw=3, 
        yscale = :log10, label = "error", 
        size = (400, 250), )
# reference line e^{-αN} (scaled by 10) for the predicted exponential rate
plot!([30,50], 10 * exp.(- α * [30, 50]), 
        lw=2, ls=:dash, c=:black, 
        label = L"e^{-\alpha N}")

### Three Dimensions

Nothing really changes in 3D, except that the computations become much more expensive. The following code snippets can serve as a starting point for implementing 3-dimensional codes.

In [None]:
"""
    evaltrig_grid(F̂::AbstractArray{T, 3}, M::Integer) -> (Fx, x)

Evaluate a three-dimensional trigonometric polynomial on the `2M × 2M × 2M` grid.

`F̂` is a `2N × 2N × 2N` coefficient array in FFT ordering (e.g. from
`triginterp_fft(f, N, 3)`). Along each axis, indices `1:N` hold the wave numbers
`k = 0, …, N-1` and indices `N+1:2N` hold `k = -N, …, -1`. The coefficients are
zero-padded to a `2M × 2M × 2M` array and transformed back with `ifft`.

Returns `Fx`, the real part of the values on the fine grid (the imaginary part is
discarded), and `x = xgrid(M)`, the 1D grid along each axis.

Throws `ArgumentError` if `F̂` is not of size `2N × 2N × 2N`, or if `M < N`.
"""
function evaltrig_grid(F̂::AbstractArray{T, 3}, M::Integer) where {T}
    N = size(F̂, 1) ÷ 2
    size(F̂) == (2N, 2N, 2N) ||
        throw(ArgumentError("F̂ must have size (2N, 2N, 2N), got $(size(F̂))"))
    M >= N || throw(ArgumentError("need M ≥ N, got M = $M and N = $N"))

    # Zero-pad in Fourier space: non-negative modes keep their position at the
    # front of each axis, negative modes move to the back of the longer fine axis.
    F̂_M = zeros(ComplexF64, (2M, 2M, 2M))
    lo = 1:N              # k = 0, …, N-1 (same position on both grids)
    hi = N+1:2N           # k = -N, …, -1 on the coarse grid
    hi_M = 2M-N+1:2M      # k = -N, …, -1 on the fine grid
    slots = ((lo, lo), (hi, hi_M))   # (source range, destination range)
    # copy each of the 2^3 = 8 octants (sign pattern of k) to the fine grid
    for (s1, d1) in slots, (s2, d2) in slots, (s3, d3) in slots
        F̂_M[d1, d2, d3] .= F̂[s1, s2, s3]
    end

    x = xgrid(M)
    # ifft scales by 1/(2M)^3; multiply it back out
    Fx = real.(ifft(F̂_M) * (2M)^3)
    return Fx, x
end

In [None]:
f3_fun, α = let c = 4.0
    # 3D periodic Witch of Agnesi (shifted form), peaked at (π, π, π)
    agnesi3(x1, x2, x3) = 1 / (1 + c^2 / 2 * (3 + cos(x1) + cos(x2) + cos(x3)))
    # predicted exponential convergence rate α = 2 asinh(1/c)
    (agnesi3, 2 * asinh(1 / c))
end

In [None]:
D = 3  # dimension
NN = 4:4:32   # grid sizes

# Target function on a fine grid (reference values for the error).
# Note: Ne = 128 already means (2Ne)^3 = 256^3 = 2^24 > 10^7 grid points!!
Ne = 128
X1e, X2e, X3e = xgrid(D, Ne)
FM_e = f3_fun.(X1e, X2e, X3e)

# max-norm error of the trigonometric interpolant for each N, measured on the fine grid.
# (triginterp_fft builds the coarse grid itself, so it is not constructed here.)
errs = Float64[]
for N in NN
    F̂ = triginterp_fft(f3_fun, N, D)
    FN_e, x = evaltrig_grid(F̂, Ne)
    # norm(A, Inf) treats an array of any rank as a vector: avoids two 2^24-element copies
    err_N = norm(FM_e - FN_e, Inf)
    push!(errs, err_N)
end


In [None]:
# semilog plot of the max-norm interpolation errors against N (3D)
plot(NN, errs, m=:o, lw=3, 
        yscale = :log10, label = "error", 
        size = (400, 250), )
# reference line e^{-αN} (scaled by 10) for the predicted exponential rate
plot!([16,30], 10 * exp.(- α * [16, 30]), 
        lw=2, ls=:dash, c=:black, 
        label = L"e^{-\alpha N}")