In [None]:
import matplotlib.pyplot as plt
import numpy as np
import probnum as pn
import scipy.sparse
import scipy.sparse.linalg

import linpde_gp

In [None]:
# Static inline figures for the first part of the notebook; the interactive section
# further down switches to the `widget` backend.
%matplotlib inline

# Render inline figures as vector graphics (PDF and SVG) instead of raster images.
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats("pdf", "svg")

In [None]:
# Reference problem used by the cells below: a 1D Poisson equation on (-1, 1) with the
# constant right-hand side f(x) = 2 and Dirichlet boundary values 0.0 at both ends.
bvp = linpde_gp.problems.pde.PoissonEquationDirichletProblem(
    domain=(-1.0, 1.0),
    rhs=linpde_gp.functions.Constant(input_shape=(), value=2.0),
    boundary_values=(0.0, 0.0),
)

## 1D Fourier Basis Functions

In [None]:
def fourier_basis_element_1d(x, n, domain):
    """Evaluate the ``n``-th sine basis function of ``domain`` at ``x``.

    For ``domain = (l, r)`` the basis function is
    ``sin(n * pi * (x - l) / (r - l))``, which is zero at both endpoints for
    integer ``n``. ``x`` and ``n`` may be arrays and broadcast against each
    other under the usual NumPy rules.
    """
    left, right = domain
    width = right - left
    angle = n * np.pi * (x - left) / width
    return np.sin(angle)

In [None]:
xs_plot = np.linspace(*bvp.domain, 100)

# Plot the first four sine basis functions; the legend identifies each frequency.
fig, ax = plt.subplots()
for n in range(1, 5):
    ax.plot(xs_plot, fourier_basis_element_1d(xs_plot, n, bvp.domain), label=f"$n = {n}$")
ax.set(title="First four sine basis functions", xlabel="$x$", ylabel=r"$\phi_n(x)$")
ax.legend()

plt.show()

In [None]:
def coords2fn(coords, domain):
    """Turn sine-basis coefficients into a callable function.

    Args:
        coords: Coefficients ``c_1, ..., c_N`` of the basis functions
            ``fourier_basis_element_1d(., n, domain)`` for ``n = 1, ..., N``.
            Any array-like is accepted; it is flattened to 1D.
        domain: Interval ``(l, r)`` on which the basis is defined.

    Returns:
        A function ``f(grid)`` that evaluates ``sum_n c_n * phi_n(grid)``
        elementwise. Its output has the same shape as ``grid``, and a scalar
        ``grid`` gives a scalar result.
    """
    coords = np.asarray(coords).reshape(-1)
    ns = np.arange(1, coords.size + 1)

    def f(grid):
        # Append a trailing basis axis: grid shape (...,) -> values shape (..., N).
        basis_values = fourier_basis_element_1d(
            np.asarray(grid)[..., None],
            ns,
            domain,
        )
        # Contract over the basis axis. The multiply is deliberately not in place,
        # so complex (or otherwise wider) coefficient dtypes promote correctly.
        return np.sum(basis_values * coords, axis=-1)

    return f

In [None]:
xs_plot = np.linspace(*bvp.domain, 100)

# Illustrate `coords2fn` with an arbitrary 5-term coefficient vector.
fig, ax = plt.subplots()
ax.plot(
    xs_plot,
    coords2fn(coords=np.array([-4.0, 1.0, 2.0, 2.0, 1.0]), domain=bvp.domain)(xs_plot),
)
ax.set(
    title="Sine expansion with coefficients (-4, 1, 2, 2, 1)",
    xlabel="$x$",
    ylabel="$u(x)$",
)
plt.show()

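## Approximate Laplace Operator in the Fourier Basis

With $\phi_n(x) = \sin\left(\frac{n \pi (x - l)}{r - l}\right)$, the Galerkin matrix of $-\frac{d^2}{dx^2}$ has entries $A_{mn} = \int_l^r \phi_m'(x)\, \phi_n'(x) \, dx = \delta_{mn} \frac{(n \pi)^2}{2 (r - l)}$, so it is diagonal in this basis.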

In [None]:
def poisson_1d_zero_boundary_operator_fourier(N: int, domain):
    """Galerkin stiffness matrix of ``-d^2/dx^2`` in the first ``N`` sine basis functions.

    With ``phi_n(x) = sin(n * pi * (x - l) / (r - l))`` on ``domain = (l, r)``,
    the entries are ``A[m, n] = int_l^r phi_m'(x) * phi_n'(x) dx``. The
    derivatives are mutually orthogonal cosines, so ``A`` is diagonal with

        A[n, n] = (n * pi)**2 / (2 * (r - l)).

    The general antiderivative also has a term proportional to
    ``sin(2 * pi * n)``. That term is exactly zero for integer ``n``, so it is
    left out rather than evaluated to floating-point noise.

    Args:
        N: Number of basis functions (frequencies ``1, ..., N``).
        domain: Interval ``(l, r)``.

    Returns:
        An ``N x N`` probnum linear operator wrapping a sparse CSR diagonal matrix.
    """
    l, r = domain
    Ns = np.arange(1, N + 1)

    return pn.linops.Matrix(
        scipy.sparse.diags(
            (Ns * np.pi) ** 2 / (2 * (r - l)),
            offsets=0,
            format="csr",
            dtype=np.double,
        )
    )

In [None]:
# Inspect the operator for the first 6 basis functions; being the cell's last
# expression, Jupyter displays it.
poisson_1d_zero_boundary_operator_fourier(6, bvp.domain)

In [None]:
# Dense view of the 6 x 6 operator: only the main diagonal is non-zero (it is built
# with `offsets=0`).
plt.imshow(poisson_1d_zero_boundary_operator_fourier(6, bvp.domain).todense())
plt.show()

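## Approximate RHS in the Fourier Basis

For a constant right-hand side $f \equiv \alpha$, the load vector has entries $b_n = \int_l^r \alpha\, \phi_n(x) \, dx = \frac{\alpha (r - l)}{n \pi} \left(1 - (-1)^n\right)$, so all even modes vanish.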

In [None]:
def poisson_1d_rhs_fourier(alpha: float, N: int, domain):
    """Load vector of a constant right-hand side ``f(x) = alpha`` in the sine basis.

    The entries are

        b[n] = int_l^r alpha * sin(n * pi * (x - l) / (r - l)) dx
             = alpha * (r - l) / pi * (1 - (-1)**n) / n,   n = 1, ..., N,

    so even modes are exactly zero.

    Args:
        alpha: Value of the constant right-hand side. Any real scalar is
            accepted (``int``, ``float``, NumPy real scalars), but not ``bool``.
        N: Number of basis functions.
        domain: Interval ``(l, r)``.

    Returns:
        Array of shape ``(N,)``.

    Raises:
        TypeError: If ``alpha`` is not a real scalar.
    """
    import numbers

    # Previously only `float` was accepted, which rejected e.g. `2` or `np.float32(2)`.
    if isinstance(alpha, bool) or not isinstance(alpha, numbers.Real):
        raise TypeError(
            f"`alpha` must be a real scalar, got {type(alpha).__name__}."
        )

    l, r = domain
    Ns = np.arange(1, N + 1)
    # Exact value of 1 - cos(pi * n): 2 for odd n, 0 for even n (no round-off).
    one_minus_cos = np.where(Ns % 2 == 1, 2.0, 0.0)

    return alpha * (r - l) / np.pi * one_minus_cos / Ns

In [None]:
# Load vector of f(x) = 1 for the first 6 modes; the even-index entries come out
# (numerically) zero because 1 - cos(pi * n) vanishes for even n.
poisson_1d_rhs_fourier(1.0, 6, bvp.domain)

## Solution

In [None]:
def discrete_1d_fourier_solve(N: int, domain, alpha: float = 2.0):
    """Solve the Galerkin system of the 1D Poisson problem in the sine basis.

    Args:
        N: Number of basis functions.
        domain: Interval ``(l, r)``.
        alpha: Value of the constant right-hand side, passed to
            ``poisson_1d_rhs_fourier``. Defaults to ``2.0``, the value that
            was previously hard-coded.

    Returns:
        Coefficient vector of shape ``(N,)``, suitable for ``coords2fn``.

    Raises:
        RuntimeError: If conjugate gradients does not report convergence.
    """
    A = poisson_1d_zero_boundary_operator_fourier(N, domain)
    b = poisson_1d_rhs_fourier(alpha, N, domain)

    # `A.A` is the matrix wrapped by the probnum operator; SciPy's CG gets it directly.
    coeffs, info = scipy.sparse.linalg.cg(A.A, b)

    # `cg` reports non-convergence (info > 0) or illegal input (info < 0) through
    # `info` instead of raising, so do not silently return an unconverged iterate.
    if info != 0:
        raise RuntimeError(f"Conjugate gradients did not converge (info={info}).")

    return coeffs

In [None]:
# Solve with only N = 3 sine modes, then wrap the coefficients as a callable on the domain.
u_fourier_coords = discrete_1d_fourier_solve(N=3, domain=bvp.domain)
u_fourier = coords2fn(u_fourier_coords, bvp.domain)

In [None]:
# Coefficients of the 3-mode approximation. The even-index (2nd) coefficient should be
# (near) zero because the corresponding load-vector entry vanishes.
u_fourier_coords

In [None]:
# Use the problem's own domain rather than a hard-coded interval, as the other cells do.
xs_plot = np.linspace(*bvp.domain, 100)

fig, ax = plt.subplots()
ax.plot(xs_plot, bvp.solution(xs_plot), label="Exact Solution")
ax.plot(xs_plot, u_fourier(xs_plot), label="Fourier Solution")
ax.set(
    title="1D Poisson problem: exact vs. Fourier-Galerkin solution",
    xlabel="$x$",
    ylabel="$u(x)$",
)
ax.legend()

plt.show()

### Implementation in `linpde_gp`

In [None]:
import ipywidgets

# Switch to the interactive `widget` (ipympl) backend so the figure below can be
# redrawn in place by the slider callback.
%matplotlib widget

# One figure/axes pair shared by all slider updates; `interact` clears and redraws it.
fig, ax = plt.subplots(num="Solution to the 1D Poisson Problem with g(x) = 0")

def interact(domain: tuple, rhs: float, n: int):
    """Solve the 1D Poisson problem for the current widget values and redraw the plot.

    Args:
        domain: Interval ``(l, r)`` selected on the range slider.
        rhs: Value of the constant right-hand side ``f(x) = rhs``.
        n: Number of frequencies of the Fourier Galerkin basis.

    Draws into the module-level ``fig``/``ax`` and returns nothing.
    """
    l, r = domain

    # The range slider lets both handles coincide. A zero-width interval is degenerate
    # (a sine basis on (l, r) is scaled by 1 / (r - l); presumably FourierBasis is too),
    # so show a hint instead of letting the solve raise inside the widget callback.
    if not l < r:
        ax.cla()
        ax.set_title("Please select a domain of non-zero width.")
        fig.canvas.draw()
        return

    # Define the problem
    bvp = linpde_gp.problems.pde.PoissonEquationDirichletProblem(
        domain=domain,
        rhs=linpde_gp.functions.Constant(input_shape=(), value=rhs),
        boundary_values=(0.0, 0.0),
    )

    # Define a finite basis
    basis = linpde_gp.galerkin.bases.FourierBasis(
        domain=bvp.domain,
        num_frequencies=n,
    )

    # Project the boundary value problem onto the finite basis.
    discrete_problem = linpde_gp.galerkin.project(bvp, basis)

    # Pick a linear solver
    solver = linpde_gp.linalg.solvers.ConjugateGradients()

    # Solve the problem; `.support` presumably holds the point estimate of the
    # coefficient vector (TODO confirm against the solver's return type).
    sol_coords_fourier = solver.solve(discrete_problem).support
    sol_fourier = basis.coords2fn(sol_coords_fourier)

    # Plot the solution
    plot_grid = np.linspace(l, r, 200)

    ax.cla()
    ax.plot(plot_grid, bvp.solution(plot_grid), label="Exact Solution")
    ax.plot(plot_grid, sol_fourier(plot_grid), label="Fourier Solution")
    ax.set(xlabel="$x$", ylabel="$u(x)$")
    ax.legend()

    fig.canvas.draw()
    
# Connect one control to each parameter of `interact`; each change calls it again,
# and it redraws the shared figure. As the cell's last expression, the returned
# widget is displayed by Jupyter.
ipywidgets.interactive(
    interact,
    # Interval (l, r) of the problem domain.
    domain=ipywidgets.FloatRangeSlider(
        value=(-1.0, 1.0),
        min=-3.0,
        max=3.0,
        description="Domain",
    ),
    # Value of the constant right-hand side f.
    rhs=ipywidgets.FloatSlider(
        value=2.0,
        min=-3.0,
        max=3.0,
        description="f(x)",
    ),
    # Number of frequencies passed to `FourierBasis(num_frequencies=n)`.
    n=ipywidgets.IntSlider(
        value=1,
        min=1,
        max=20,
        continuous_update=True,  # re-solve while the handle is being dragged
    ),
)