In [1]:
import dolfin as dl
import hippylib as hl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from spin.core import problem
from spin.fenics import converter
from spin.hippylib import hessian, laplace, misfit, optimization, prior

# Plot theme for every figure below: seaborn "ticks" style with the default "notebook" context.
sns.set_theme(context="notebook", style="ticks")

In [2]:
# Forward model: mean exit time tau(x) on the interval [-1, 1] discretized with
# 100 cells, where only the drift b(x) is inferred. Drift and log-squared
# diffusion are dolfin C++ expression strings (log(1.0), i.e. squared diffusion 1).
mesh = dl.IntervalMesh(100, -1, 1)
problem_settings = problem.SPINProblemSettings(
    pde_type="mean_exit_time",
    inference_type="drift_only",
    mesh=mesh,
    log_squared_diffusion=("std::log(1.0)",),
    drift=("-x[0]",),
)
problem_builder = problem.SPINProblemBuilder(problem_settings)
spin_problem = problem_builder.build()

In [30]:
# Synthetic data: evaluate the true drift on the parameter space, solve the
# forward problem, keep every `data_stride`-th solution node and add i.i.d.
# Gaussian noise. The generator is seeded inside this cell, so re-running it
# reproduces exactly the same data.
parameter_coordinates = spin_problem.coordinates_parameters
solution_coordinates = spin_problem.coordinates_variables
# Same expression as the `drift` passed to `problem_settings`.
true_parameter_function = converter.create_dolfin_function(
    ("-x[0]",), spin_problem.function_space_parameters
)
true_parameter = true_parameter_function.vector().get_local()
true_solution = spin_problem.solve_forward(true_parameter)

data_stride = 5
data_locations = solution_coordinates[::data_stride]
data_values_exact = true_solution[::data_stride]
rng = np.random.default_rng(seed=0)
noise_std = 0.02
noise = rng.normal(loc=0, scale=noise_std, size=data_values_exact.size)
data_values = data_values_exact + noise

# One observed value per observation location.
assert len(data_locations) == data_values.size

In [None]:
# Left: the true drift b(x). Right: the exact forward solution tau(x) with the
# noisy observations overlaid.
_, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), layout="constrained")

axs[0].plot(parameter_coordinates, true_parameter)
axs[0].set_yticks((-1, -0.5, 0, 0.5, 1))
axs[0].set_title("True drift")
axs[0].set_ylabel(r"$b(x)$")

axs[1].plot(solution_coordinates, true_solution, label="Exact")
axs[1].scatter(data_locations, data_values, color="firebrick", label="Data")
axs[1].set_yticks((0, 0.3, 0.6, 0.9, 1.2, 1.5))
axs[1].set_title("True solution")
axs[1].set_ylabel(r"$\tau(x)$")
axs[1].legend()

# Shared x-axis formatting. Ending the cell with a loop (not a method call)
# keeps matplotlib reprs out of the cell output.
for panel in axs:
    panel.set_xticks((-1, -0.5, 0, 0.5, 1))
    panel.set_xlabel(r"$x$")

In [32]:
# Bilaplacian prior on the drift with mean -0.5*x, variance 0.1 and correlation
# length 0.1, built without Robin boundary conditions. Its pointwise variance is
# then estimated with the randomized method using 100 eigenvalues.
prior_settings = prior.PriorSettings(
    mean=("-0.5*x[0]",),
    variance=("0.1",),
    correlation_length=("0.1",),
    function_space=spin_problem.function_space_parameters,
    robin_bc=False,
)
prior_builder = prior.BilaplacianVectorPriorBuilder(prior_settings)
spin_prior = prior_builder.build()
prior_variance = spin_prior.compute_variance_with_boundaries(
    num_eigenvalues_randomized=100,
    method="Randomized",
)

In [None]:
# Prior mean with a pointwise 95% credible band (mean +/- z * standard deviation).
z_95 = 1.96  # two-sided 95% quantile of the standard normal distribution
# Precaution: clip at zero so a slightly negative randomized variance estimate
# cannot turn the square root into NaN (TODO confirm whether this ever occurs).
prior_std = np.sqrt(np.clip(prior_variance, 0, None))

_, ax = plt.subplots(figsize=(5, 5), layout="constrained")
ax.plot(parameter_coordinates, spin_prior.mean_array, label="Mean")
ax.fill_between(
    parameter_coordinates.flatten(),
    spin_prior.mean_array - z_95 * prior_std,
    spin_prior.mean_array + z_95 * prior_std,
    alpha=0.3,
    label="95% CI",
)
ax.set_title("Prior")
ax.set_xticks((-1, -0.5, 0, 0.5, 1))
ax.set_xlim((-1, 1))
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$b(x)$")
ax.legend();

In [34]:
# Pointwise Gaussian misfit: each observation gets the same noise variance
# noise_std**2 that was used to perturb the synthetic data.
misfit_settings = misfit.MisfitSettings(
    function_space=spin_problem.function_space_variables,
    observation_points=data_locations,
    observation_values=data_values,
    # One variance per observed value. `data_values.size` counts observations;
    # if `data_locations` is 2-D (points x dim), its `.size` would overcount.
    noise_variance=np.full(data_values.size, noise_std**2),
)
misfit_builder = misfit.MisfitBuilder(misfit_settings)
spin_misfit = misfit_builder.build()

In [35]:
# hippylib model combining the forward problem, the prior and the data misfit.
forward_problem = spin_problem.hippylib_variational_problem
hippylib_prior = spin_prior.hippylib_prior
inference_model = hl.Model(forward_problem, hippylib_prior, spin_misfit)

In [None]:
# Newton-CG solve for the MAP point, started from the prior mean.
optimization_settings = optimization.SolverSettings(
    relative_tolerance=1e-8, absolute_tolerance=1e-8, verbose=True
)
# Defensive copy: if the solver updates its starting point in place, it must
# not silently modify the prior's mean array (which earlier cells plot and
# re-running this cell relies on).
initial_guess = np.copy(spin_prior.mean_array)
newton_solver = optimization.NewtonCGSolver(optimization_settings, inference_model)
solver_solution = newton_solver.solve(initial_guess)

In [37]:
# Low-rank Hessian eigendecomposition at the optimization result
# (forward, parameter, adjoint), requesting 20 eigenpairs with oversampling 5.
map_point = [
    solver_solution.forward_solution,
    solver_solution.optimal_parameter,
    solver_solution.adjoint_solution,
]
hessian_settings = hessian.LowRankHessianSettings(
    inference_model=inference_model,
    evaluation_point=map_point,
    num_eigenvalues=20,
    num_oversampling=5,
)
eigenvalues, eigenvectors = hessian.compute_low_rank_hessian(hessian_settings)

In [None]:
# Spectrum of the low-rank Hessian on a log scale. Note: semilogy cannot show
# non-positive values, so any such eigenvalues would be silently omitted.
index_vector = np.arange(1, eigenvalues.size + 1)
_, ax = plt.subplots(figsize=(5, 5), layout="constrained")
ax.semilogy(index_vector, eigenvalues, marker="o")
ax.set_title("Low-rank Hessian eigenvalues")
ax.set_xlabel(r"$i$")
ax.set_ylabel(r"$\lambda_i$");

In [39]:
# Low-rank Laplace approximation of the posterior, centered at the optimal
# parameter and built from the Hessian eigenpairs; its pointwise variance is
# estimated with the randomized method using 100 eigenvalues.
laplace_approximation_settings = laplace.LowRankLaplaceApproximationSettings(
    low_rank_hessian_eigenvalues=eigenvalues,
    low_rank_hessian_eigenvectors=eigenvectors,
    mean=solver_solution.optimal_parameter,
    inference_model=inference_model,
)
laplace_approximation = laplace.LowRankLaplaceApproximation(
    laplace_approximation_settings
)
posterior_variance = laplace_approximation.compute_pointwise_variance(
    num_eigenvalues_randomized=100,
    method="Randomized",
)

In [None]:
# Laplace approximation: MAP estimate with a pointwise 95% credible band,
# compared against the true drift that generated the synthetic data.
z_95 = 1.96  # two-sided 95% quantile of the standard normal distribution
# Precaution: clip at zero so a slightly negative randomized variance estimate
# cannot turn the square root into NaN (TODO confirm whether this ever occurs).
posterior_std = np.sqrt(np.clip(posterior_variance, 0, None))
map_parameter = solver_solution.optimal_parameter

_, ax = plt.subplots(figsize=(5, 5), layout="constrained")
ax.plot(parameter_coordinates, map_parameter, label="Mean")
ax.fill_between(
    parameter_coordinates.flatten(),
    map_parameter - z_95 * posterior_std,
    map_parameter + z_95 * posterior_std,
    alpha=0.3,
    label="95% CI",
)
ax.plot(
    parameter_coordinates,
    true_parameter,
    color="black",
    linestyle="--",
    label="True",
)
ax.set_title("Laplace approximation")
ax.set_xticks((-1, -0.5, 0, 0.5, 1))
ax.set_xlim((-1, 1))
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$b(x)$")
ax.legend();