In [1]:
import dolfin as dl
import matplotlib.pyplot as plt
import numpy as np

from spin.core import problem
from spin.fenics import converter as fex_converter
from spin.hippylib import misfit, prior

In [2]:
mesh = dl.IntervalMesh(100, -1, 1)
#mesh = dl.RectangleMesh(dl.Point(-1, -1), dl.Point(1, 1), 10, 10)
problem_settings = problem.SPINProblemSettings(
    mesh=mesh,
    pde_type="mean_exit_time",
    inference_type="drift_only",
    drift=(("-x[0]",)),
    log_squared_diffusion=("std::log(1.0)",),
)

In [3]:
problem_builder = problem.SPINProblemBuilder(problem_settings)
spin_problem = problem_builder.build()

In [4]:
forward_vector = spin_problem.hippylib_variational_problem.generate_state()
adjoint_vector = spin_problem.hippylib_variational_problem.generate_state()
parameter_vector = spin_problem.hippylib_variational_problem.generate_parameter()
parameter_vector.set_local(-np.linspace(-1, 1, parameter_vector.size()))
state_list = [forward_vector, parameter_vector, adjoint_vector]
spin_problem.hippylib_variational_problem.solveFwd(forward_vector, state_list)
forward_array = fex_converter.convert_to_numpy(forward_vector, spin_problem.function_space_variables)
forward_array.shape

(101,)

In [7]:
prior_settings = prior.PriorSettings(
    function_space = spin_problem.function_space_parameters,
    mean=("-0.5*x[0]",),
    variance=("0.1",),
    correlation_length=("0.1",),
)
prior_builder = prior.BilaplacianVectorPriorBuilder(prior_settings)
spin_prior = prior_builder.build()
mean_array = spin_prior.mean_array
mean_array.shape

(101,)

In [6]:
misfit_settings = misfit.MisfitSettings(
    function_space = spin_problem.function_space_variables,
    observation_points=np.linspace(-0.9, 0.9, 10),
    observation_values=None
)

BeartypeCallHintParamViolation: Method spin.hippylib.misfit.MisfitSettings.__init__() parameter observation_values="None" violates type hint numpy.ndarray[typing.Any, numpy.dtype[numpy.floating]] | collections.abc.Iterable[numpy.ndarray[typing.Any, numpy.dtype[numpy.floating]]] | collections.abc.Iterable[collections.abc.Iterable[numpy.ndarray[typing.Any, numpy.dtype[numpy.floating]]]], as <class "builtins.NoneType"> "None":
* Not <protocol ABC "collections.abc.Iterable">.
* Not instance of <class "numpy.ndarray">.