In [None]:
import os
import shutil

import numpy as np

from sp_inference import logging, processes, model, sampling, postprocessing 

In [None]:
# ---- Run configuration -------------------------------------------------------
# All tunable knobs of this example. Later cells read these dicts; the data cell
# additionally stores derived entries (grid, FE and time-stepping settings) in
# dataSettings.

# Logging: output folder (deleted at the start of the logger cell), verbosity,
# and print interval.
logSettings = dict(
    output_directory="03_example_transient",
    verbose=True,
    print_interval=10,
)

# Synthetic data: ground-truth process and its parameters, observation-noise
# standard deviation, RNG seed, and the space/time observation grid.
dataSettings = dict(
    process_type="Dr31Di21Process",
    drift_parameters=[2, 3],
    diffusion_parameters=[1, 2],
    standard_deviation=0.01,
    rng_seed=0,
    num_domain_points=50,
    num_time_points=10,
    domain_bounds=[-1.5, 1.5],
    time_bounds=[0.1, 1],
)

# Which parameters to infer and which forward model (transient Fokker-Planck).
modelSettings = dict(
    params_to_infer="all",
    model_type="fokker_planck",
    is_stationary=False,
)

# Prior: one mean function per inferred coefficient (presumably ordered like the
# exact data: drift, then squared diffusion — confirm). gamma/delta are passed
# straight to the model; their meaning is defined in sp_inference.
priorSettings = dict(
    mean_function=[lambda x: -x, lambda x: 3*np.ones(x.shape)],
    gamma=0.5,
    delta=0.5,
    robin_bc=False,
)

# Finite-element discretization of the computational domain.
feSettings = dict(
    num_mesh_points=500,
    boundary_locations=[-3.5, 3.5],
    boundary_values=[0, 0],
    element_degrees=[1, 1],
)

# Time stepping. The initial condition is a zero-mean Gaussian density with
# standard deviation 0.5.
transientSettings = dict(
    start_time=0.1,
    end_time=1.1,
    time_step_size=0.025,
    initial_condition=lambda x: 1/(0.5*np.sqrt(2*np.pi)) * np.exp(-0.5*np.square(x)/0.5**2),
)

# Optimizer used for the linearized (point-estimate) posterior.
solverSettings = dict(
    rel_tolerance=1e-6,
    abs_tolerance=1e-12,
    max_iter=10,
    GN_iter=5,
    c_armijo=1e-4,
    max_backtracking_iter=10,
)

# Hessian eigenvalue computation: number of eigenvalues plus oversampling.
hessianSettings = dict(
    num_eigvals=35,
    num_oversampling=10,
)

# Plotting: whether to display figures and at which times to draw the solution.
visualizationSettings = dict(
    show=True,
    time_points=[0.1, 1.0],
)

In [None]:
# Start from a clean output directory so files from a previous run do not mix
# with this one. Removal is best-effort: a missing directory (first run) is fine.
# shutil.rmtree replaces the former `os.system('rm -r ' + ...)`, because:
#   * os.system never raises on failure, so the old try/except was dead code;
#   * it passed an unquoted path through a shell;
#   * `rm` does not exist on Windows.
shutil.rmtree(logSettings["output_directory"], ignore_errors=True)

logger = logging.Logger(logSettings["verbose"],
                        logSettings["output_directory"],
                        logSettings["print_interval"])

In [None]:
# Observation grid: uniformly random locations in space, seeded for reproducibility.
randGenerator = np.random.default_rng(dataSettings["rng_seed"])
randLocs = randGenerator.uniform(low=dataSettings["domain_bounds"][0],
                                 high=dataSettings["domain_bounds"][1],
                                 size=dataSettings["num_domain_points"])
# Despite the "rand" prefix, the observation times are evenly spaced, not random.
randTimes = np.linspace(dataSettings["time_bounds"][0],
                        dataSettings["time_bounds"][1],
                        num=dataSettings["num_time_points"])

# Hand the grid plus the FE and time-stepping settings to the data generator.
dataSettings.update(domain_points=randLocs,
                    time_points=randTimes,
                    fem_settings=feSettings,
                    solver_settings=transientSettings)

# Ground-truth process; generate_data returns (noisy, exact) forward data.
processType = processes.get_process(dataSettings["process_type"])
process = processType(dataSettings["drift_parameters"],
                      dataSettings["diffusion_parameters"],
                      logger)
forwardNoisy, forwardExact = process.generate_data(modelSettings["model_type"],
                                                   modelSettings["is_stationary"],
                                                   dataSettings)

# True coefficients at the observation locations: drift values first, then
# squared-diffusion values, stacked column-wise.
exactDrift = process.compute_drift(randLocs)
exactDiffusion = process.compute_squared_diffusion(randLocs)
exactParamValues = np.column_stack([exactDrift, exactDiffusion])

# [locations, (times,) values] bundles consumed by the plotting cells.
exactParamData = [randLocs, exactParamValues]
randForwardData = [randLocs, randTimes, forwardNoisy]
exactForwardData = [randLocs, randTimes, forwardExact]

# Data-misfit inputs: the noisy observations and their noise standard deviation.
misFitSettings = dict(data_locations=randLocs,
                      data_times=randTimes,
                      data_values=forwardNoisy,
                      data_std=dataSettings["standard_deviation"])

In [None]:
# Assemble the inference problem from the settings and the synthetic data.
inferenceModel = model.SDEInferenceModel(
    modelSettings, priorSettings, feSettings, misFitSettings, transientSettings, logger
)

# Prior mean, variance and forward data, used later in the comparison plots.
# NOTE(review): the meaning of the "Randomized" argument is not visible here
# (presumably a randomized variance estimator) — confirm in sp_inference.
priorMeanData, priorVarianceData, priorForwardData = inferenceModel.get_prior_info("Randomized")

In [None]:
# Linearized posterior: mean, variance and forward data, plus the Hessian
# eigenvalues, all from compute_gr_posterior (optimizer + Hessian settings).
(mapMeanData, mapVarianceData,
 mapForwardData, hessEigVals) = inferenceModel.compute_gr_posterior(solverSettings, hessianSettings)

In [None]:
# Coefficient estimates to compare: prior, linearized posterior, ground truth.
paramData = dict(prior_mean=priorMeanData,
                 prior_variance=priorVarianceData,
                 posterior_mean=mapMeanData,
                 posterior_variance=mapVarianceData,
                 exact=exactParamData)

# Forward solutions to compare, plus the times at which to plot them.
forwardData = dict(prior=priorForwardData,
                   posterior=mapForwardData,
                   noisy=randForwardData,
                   exact=exactForwardData,
                   times=visualizationSettings["time_points"])

postprocessor = postprocessing.Postprocessor(logger=logger, show=visualizationSettings["show"])
postprocessor.visualize_hessian_data(hessEigVals)
postprocessor.visualize_parameters(
    paramsInferred=modelSettings["params_to_infer"], mode="linearized", data=paramData
)
postprocessor.visualize_forward_solution(
    modelType=modelSettings["model_type"],
    mode="linearized",
    isStationary=modelSettings["is_stationary"],
    data=forwardData,
)

In [None]:
# Sample the posterior with MCMC, then evaluate the quantity of interest.
#
# NOTE(review): `samplerSettings` and `samplingRunSettings` are not defined in any
# earlier cell, so "Restart & Run All" stops here. Fail early with a message that
# lists every missing setting. TODO: add both dicts to the settings cell; their keys
# are defined by sampling.MCMCSampler / MCMCSampler.run.
_missingMcmcSettings = [name for name in ("samplerSettings", "samplingRunSettings")
                        if name not in globals()]
if _missingMcmcSettings:
    raise NameError("MCMC sampling requires "
                    + ", ".join(_missingMcmcSettings)
                    + " to be defined in the settings cell.")

# NOTE(review): the other components now receive `logger`, but the sampler still
# gets the raw `logSettings` dict — confirm which one MCMCSampler expects.
MCMCSampler = sampling.MCMCSampler(inferenceModel, logSettings, samplerSettings)
MCMCMeanData, MCMCVarianceData, MCMCForwardData = MCMCSampler.run(samplingRunSettings)
qoiData = MCMCSampler.evaluate_qoi()

In [None]:
# Visualize the MCMC posterior next to the prior and the ground truth, then
# post-process the quantity of interest computed by the sampler.
#
# This cell previously used an older Postprocessor API (a positional `logSettings`
# argument and per-array keyword arguments) and the undefined name `visTimePoints`.
# It now uses the same `data=` dictionaries and constructor as the linearized
# visualization cell. The linearized plots (including the Hessian eigenvalues)
# are already drawn by that cell, so only the MCMC results are plotted here.
# NOTE(review): assumes mode="mcmc" accepts the same dict layout as
# mode="linearized" — confirm against sp_inference.postprocessing.
mcmcParamData = {"prior_mean": priorMeanData,
                 "prior_variance": priorVarianceData,
                 "posterior_mean": MCMCMeanData,
                 "posterior_variance": MCMCVarianceData,
                 "exact": exactParamData}

mcmcForwardPlotData = {"prior": priorForwardData,
                       "posterior": MCMCForwardData,
                       "noisy": randForwardData,
                       "exact": exactForwardData,
                       "times": visualizationSettings["time_points"]}

postprocessor = postprocessing.Postprocessor(show=visualizationSettings["show"], logger=logger)
postprocessor.visualize_parameters(paramsInferred=modelSettings["params_to_infer"],
                                   mode="mcmc",
                                   data=mcmcParamData)
postprocessor.visualize_forward_solution(modelType=modelSettings["model_type"],
                                         mode="mcmc",
                                         isStationary=modelSettings["is_stationary"],
                                         data=mcmcForwardPlotData)
# maxLag is presumably the largest lag for the chain autocorrelation — confirm.
postprocessor.postprocess_qoi(qoiData, maxLag=100)