Skip to content

Commit

Permalink
More renaming and fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bgroenks96 committed May 27, 2024
1 parent 19e2d8a commit d46bdad
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 54 deletions.
2 changes: 1 addition & 1 deletion examples/Tsurf_inv/run_Tsurf_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ snpe_output, _ = produce_or_load(snpe_config, outdir, filename="snpe_inference_d
θ = zero(inference_prob.u0) + θ
ϕ = SBI.forward_map(inference_prob.prior, θ)
sol = solve(forward_prob, LiteImplicitEuler(), p=ϕ.model)
map(retrieve, sol.prob.observables)
map(getvalue, sol.prob.observables)
end
preds = SBI.ntreduce(hcat, observables)
@strdict sol posterior preds
Expand Down
4 changes: 2 additions & 2 deletions src/SimulationBasedInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using UnPack
# Re-exported packages
@reexport using Bijectors
@reexport using ComponentArrays
@reexport using DimensionalData: Dimension, Dim, DimArray, X, Y, Z, Ti
@reexport using DimensionalData: DimensionalData, Dimension, Dim, DimArray, X, Y, Z, Ti
@reexport using Distributions
@reexport using PosteriorStats: PosteriorStats, summarize
@reexport using SciMLBase
Expand Down Expand Up @@ -75,7 +75,7 @@ export store!, getinputs, getoutputs, getmetadata
include("simulation_data.jl")

export SimulatorObservable, TimeSampledObservable, TransientObservable
export observe!, retrieve, coordinates
export observe!, getvalue, coordinates
include("observables.jl")

export AbstractSimulatorPrior, NamedProductPrior
Expand Down
2 changes: 1 addition & 1 deletion src/ensembles/ensemble_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ function ensemble_solve(
pred = reduce(hcat, map((i,out) -> pred_func(out, i, iter), 1:N_ens, enssol.u))
observables = map(enssol) do sol
# extract observables data
map(retrieve, sol.prob.observables)
map(getvalue, sol.prob.observables)
end
return (; pred, observables)
end
Expand Down
17 changes: 11 additions & 6 deletions src/forward_solve_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,25 +64,26 @@ function CommonSolve.init(
p=forward_prob.prob.p,
saveat=[],
save_everystep=false,
copy_observables=false,
solve_kwargs...
)
# collect and combine sample points from all obsevables
t_sample = map(sampletimes, forward_prob.observables)
t_sample_all = forward_prob.config.obs_to_prob_time.(sort(unique(union(t_sample...))))
t_points = if isempty(t_sample_all) || t_sample_all[end] < forward_prob.tspan[end]
vcat(t_sample_all, [forward_prob.tspan[end]])
adstrip.(vcat(t_sample_all, [forward_prob.tspan[end]]))
else
t_sample_all
adstrip.(t_sample_all)
end
# reinitialize inner problem with new parameters
newprob = remake(forward_prob, p=p, copy_observables=false)
newprob = remake(forward_prob; p, copy_observables)
# initialize integrator with built-in saving disabled
integrator = init(newprob.prob, ode_alg; saveat, save_everystep, solve_kwargs...)
# initialize observables
for obs in newprob.observables
initialize!(obs, integrator)
end
return SimulatorODEForwardSolver(newprob, integrator, adstrip.(t_points), 1)
return SimulatorODEForwardSolver(newprob, integrator, t_points, 1)
end

function CommonSolve.step!(forward::SimulatorODEForwardSolver)
Expand All @@ -96,8 +97,12 @@ function CommonSolve.step!(forward::SimulatorODEForwardSolver)
if forward.step_idx > length(forward.tstops)
return forwardstep!(forward.integrator)
end
# otherwise, evaluate the next step and observables
retval = forwardstep!(integrator, dt, true)
# otherwise, evaluate the next step and observables if dt > 0
retval = if dt > zero(dt)
forwardstep!(integrator, dt, true)
else
nothing
end
# iterate over observables and update those for which t is a sample point
for obs in prob.observables
if t map(prob.config.obs_to_prob_time, sampletimes(obs))
Expand Down
4 changes: 2 additions & 2 deletions src/inference_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ SimulatorInferenceProblem(forward_prob::SimulatorForwardProblem, prior::Abstract
SimulatorInferenceProblem(forward_prob, nothing, prior, likelihoods...; kwargs...)

"""
prior(prob::SimulatorInferenceProblem)
getprior(prob::SimulatorInferenceProblem)
Retrieves the prior from the given `SimulatorInferenceProblem`.
"""
prior(prob::SimulatorInferenceProblem) = prob.prior
getprior(prob::SimulatorInferenceProblem) = prob.prior

SciMLBase.isinplace(prob::SimulatorInferenceProblem) = false

Expand Down
2 changes: 1 addition & 1 deletion src/likelihoods/joint_prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
Constructs a `JointPrior` from the given prior and likelihoods.
"""
function JointPrior(modelprior::AbstractSimulatorPrior, liks::SimulatorLikelihood...)
lik_priors = with_names(filter(!isnothing, map(prior, liks)))
lik_priors = (; filter(x -> !isnothing(x[2]), map(lik -> nameof(lik) => getprior(lik), liks))...)
param_nt = merge(
(model=rand(modelprior),),
map(d -> rand(d), lik_priors),
Expand Down
2 changes: 1 addition & 1 deletion src/likelihoods/likelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Base.nameof(l::SimulatorLikelihood) = l.name

observable(lik::SimulatorLikelihood)::SimulatorObservable = lik.obs

prior(lik::SimulatorLikelihood) = lik.prior
getprior(lik::SimulatorLikelihood) = lik.prior

"""
predictive_distribution(lik::SimulatorLikelihood, args...)
Expand Down
18 changes: 15 additions & 3 deletions src/observables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ function observe!(obs::SimulatorObservable{N,Transient}, state) where {N}
end

function getvalue(obs::SimulatorObservable{N,Transient}, ::Type{T}=Any) where {N,T}
return obs.output.state
data = obs.output.state
coords = coordinates(obs)
return DimArray(data, coords)
end

function setvalue!(obs::SimulatorObservable{N,Transient}, value) where {N}
Expand Down Expand Up @@ -239,9 +241,12 @@ end

function getvalue(obs::TimeSampledObservable, ::Type{TT}=Float64) where {TT}
@assert !isnothing(obs.output.buffer) "observable not yet initialized"
@assert length(obs.output.output) > 0 "output buffer is empty; check for errors in the model evaluation"
out = reduce(hcat, obs.output.output)
# drop first dimension if it is of length 1
return size(out,1) == 1 ? dropdims(out, dims=1) : out
coords = coordinates(obs)
darr = DimArray(out, coords)
singleton_dims = filter(c -> length(c) == 1, coords)
return dropdims(darr, dims=singleton_dims)
end

function setvalue!(obs::TimeSampledObservable, values::AbstractMatrix)
Expand All @@ -259,3 +264,10 @@ unflatten(obs::TimeSampledObservable, x::AbstractVector) = reshape(x, length(fir
_coerce(output::AbstractVector, shape::Dims) = reshape(output, shape)
_coerce(output::Number, ::Tuple{}) = [output] # lift to single element vector
_coerce(output, shape) = error("output of observable function must be a scalar or a vector! expected: $(shape), got $(typeof(output)) with $(size(output))")
function _coerce(output::Number, shape::Dims{1})
if shape[1] == 1
return [output]
else
error("scalar output does not match expected dimensions $shape")
end
end
103 changes: 68 additions & 35 deletions test/env/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,9 @@ version = "0.4.0"

[[deps.DiffEqBase]]
deps = ["ArrayInterface", "ConcreteStructs", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "FastClosures", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"]
git-tree-sha1 = "03b9555f4c3a7c2f530bb1ae13e85719c632f74e"
git-tree-sha1 = "37d49a1f8eedfe68b7622075ff3ebe3dd0e8f327"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
version = "6.151.1"
version = "6.151.2"

[deps.DiffEqBase.extensions]
DiffEqBaseCUDAExt = "CUDA"
Expand Down Expand Up @@ -566,6 +566,42 @@ git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.15.1"

[[deps.DifferentiationInterface]]
deps = ["ADTypes", "Compat", "DocStringExtensions", "FillArrays", "LinearAlgebra", "PackageExtensionCompat", "SparseArrays", "SparseMatrixColorings"]
git-tree-sha1 = "42f8413c004d68a392c76c02fe1db9b2567bffa9"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
version = "0.4.1"

[deps.DifferentiationInterface.extensions]
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
DifferentiationInterfaceDiffractorExt = "Diffractor"
DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTapirExt = "Tapir"
DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = "Zygote"

[deps.DifferentiationInterface.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.DimensionalData]]
deps = ["Adapt", "ArrayInterface", "ConstructionBase", "DataAPI", "Dates", "Extents", "Interfaces", "IntervalSets", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "PrecompileTools", "Random", "RecipesBase", "SparseArrays", "Statistics", "TableTraits", "Tables"]
git-tree-sha1 = "5f3bb465f4b06b25e9bbe8f1d9711834ef4697d6"
Expand Down Expand Up @@ -665,9 +701,9 @@ uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
version = "1.0.4"

[[deps.EnzymeCore]]
git-tree-sha1 = "18394bc78ac2814ff38fe5e0c9dc2cd171e2810c"
git-tree-sha1 = "0910982db2490a20f81dc7db5d4bbea236c027b3"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
version = "0.7.2"
version = "0.7.3"
weakdeps = ["Adapt"]

[deps.EnzymeCore.extensions]
Expand Down Expand Up @@ -1131,9 +1167,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LinearSolve]]
deps = ["ArrayInterface", "ChainRulesCore", "ConcreteStructs", "DocStringExtensions", "EnumX", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "LazyArrays", "Libdl", "LinearAlgebra", "MKL_jll", "Markdown", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "StaticArraysCore", "UnPack"]
git-tree-sha1 = "efd815eaa56c0ffdf86581df5aaefb7e901323a0"
git-tree-sha1 = "7648cc20100504f4b453917aacc8520e9c0ecfb3"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
version = "2.30.0"
version = "2.30.1"

[deps.LinearSolve.extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
Expand Down Expand Up @@ -1276,9 +1312,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[deps.MathOptInterface]]
deps = ["BenchmarkTools", "CodecBzip2", "CodecZlib", "DataStructures", "ForwardDiff", "JSON", "LinearAlgebra", "MutableArithmetics", "NaNMath", "OrderedCollections", "PrecompileTools", "Printf", "SparseArrays", "SpecialFunctions", "Test", "Unicode"]
git-tree-sha1 = "9cc5acd6b76174da7503d1de3a6f8cf639b6e5cb"
git-tree-sha1 = "fffbbdbc10ba66885b7b4c06f4bd2c0efc5813d6"
uuid = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
version = "1.29.0"
version = "1.30.0"

[[deps.MatrixFactorizations]]
deps = ["ArrayLayouts", "LinearAlgebra", "Printf", "Random"]
Expand All @@ -1288,9 +1324,9 @@ version = "2.2.0"

[[deps.MaybeInplace]]
deps = ["ArrayInterface", "LinearAlgebra", "MacroTools", "SparseArrays"]
git-tree-sha1 = "b1f2f92feb0bc201e91c155ef575bcc7d9cc3526"
git-tree-sha1 = "1b9e613f2ca3b6cdcbfe36381e17ca2b66d4b3a1"
uuid = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
version = "0.1.2"
version = "0.1.3"

[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
Expand Down Expand Up @@ -1386,9 +1422,9 @@ version = "1.2.0"

[[deps.NonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "SymbolicIndexingInterface", "TimerOutputs"]
git-tree-sha1 = "3939ebffebff79db0442103b6d3a5e8c50cbf43c"
git-tree-sha1 = "a5bc9c06e28108e04de0485273f0b5933cec66ed"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
version = "3.12.0"
version = "3.12.3"

[deps.NonlinearSolve.extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
Expand Down Expand Up @@ -1503,9 +1539,9 @@ version = "1.6.3"

[[deps.OrdinaryDiffEq]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "DataStructures", "DiffEqBase", "DocStringExtensions", "EnumX", "ExponentialUtilities", "FastBroadcast", "FastClosures", "FillArrays", "FiniteDiff", "ForwardDiff", "FunctionWrappersWrappers", "IfElse", "InteractiveUtils", "LineSearches", "LinearAlgebra", "LinearSolve", "Logging", "MacroTools", "MuladdMacro", "NonlinearSolve", "Polyester", "PreallocationTools", "PrecompileTools", "Preferences", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "SciMLStructures", "SimpleNonlinearSolve", "SimpleUnPack", "SparseArrays", "SparseDiffTools", "StaticArrayInterface", "StaticArrays", "TruncatedStacktraces"]
git-tree-sha1 = "b2cd04c49e8c7c6f705b527898fc843c3aa90605"
git-tree-sha1 = "75b0d2bf28d0df92931919004a5be5304c38cca2"
uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
version = "6.78.0"
version = "6.80.1"

[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
Expand Down Expand Up @@ -1602,9 +1638,9 @@ version = "1.4.3"

[[deps.PrettyTables]]
deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"]
git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660"
git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7"
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
version = "2.3.1"
version = "2.3.2"

[[deps.Printf]]
deps = ["Unicode"]
Expand All @@ -1627,9 +1663,9 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.10.0"

[[deps.PtrArrays]]
git-tree-sha1 = "077664975d750757f30e739c870fbbdc01db7913"
git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759"
uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
version = "1.1.0"
version = "1.2.0"

[[deps.PythonCall]]
deps = ["CondaPkg", "Dates", "Libdl", "MacroTools", "Markdown", "Pkg", "REPL", "Requires", "Serialization", "Tables", "UnsafePointers"]
Expand Down Expand Up @@ -1816,9 +1852,9 @@ version = "0.6.42"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "265f1a7a804d8093fa0b17e33e45373a77e56ca5"
git-tree-sha1 = "9f59654e2a85017ee27b0f59c7fac5a57aa10ced"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
version = "2.38.0"
version = "2.39.0"

[deps.SciMLBase.extensions]
SciMLBaseChainRulesCoreExt = "ChainRulesCore"
Expand Down Expand Up @@ -1887,27 +1923,18 @@ deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[deps.SimpleNonlinearSolve]]
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "DiffResults", "FastClosures", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "MaybeInplace", "PrecompileTools", "Reexport", "SciMLBase", "StaticArraysCore"]
git-tree-sha1 = "c020028bb22a2f23cbd88cb92cf47cbb8c98513f"
deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "DiffResults", "DifferentiationInterface", "FastClosures", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "MaybeInplace", "PrecompileTools", "Reexport", "SciMLBase", "Setfield", "StaticArraysCore"]
git-tree-sha1 = "f3c50acd5145f2c6ee85343ce6f433dd2caab41e"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
version = "1.8.0"
version = "1.9.0"
weakdeps = ["ChainRulesCore", "ReverseDiff", "Tracker", "Zygote"]

[deps.SimpleNonlinearSolve.extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveTrackerExt = "Tracker"
SimpleNonlinearSolveZygoteExt = "Zygote"

[deps.SimpleNonlinearSolve.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
Expand Down Expand Up @@ -1977,6 +2004,12 @@ git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852"
uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada"
version = "0.1.2"

[[deps.SparseMatrixColorings]]
deps = ["ADTypes", "Compat", "DocStringExtensions", "LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "d4adedbcc8776c567e0e22ef19f13cf10cb6ecaa"
uuid = "0a514795-09f3-496d-8182-132a7b665d35"
version = "0.3.2"

[[deps.Sparspak]]
deps = ["Libdl", "LinearAlgebra", "Logging", "OffsetArrays", "Printf", "SparseArrays", "Test"]
git-tree-sha1 = "342cf4b449c299d8d1ceaf00b7a49f4fbc7940e7"
Expand Down Expand Up @@ -2268,9 +2301,9 @@ version = "1.0.0"

[[deps.VectorizationBase]]
deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"]
git-tree-sha1 = "6129a4faf6242e7c3581116fbe3270f3ab17c90d"
git-tree-sha1 = "e863582a41c5731f51fd050563ae91eb33cf09be"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.21.67"
version = "0.21.68"

[[deps.VertexSafeGraphs]]
deps = ["Graphs"]
Expand Down
6 changes: 4 additions & 2 deletions test/problem_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ using Test
observable = SimulatorObservable(:u, state -> state.u, 0.0, 0.1:0.1:1.0, size(odeprob.u0), samplerate=0.01)
forwardprob = SimulatorForwardProblem(odeprob, observable)
forward_sol = solve(forwardprob, Tsit5())
@test forward_sol.sol.retcode == ReturnCode.Success
@test isa(forward_sol, SimulatorForwardSolution)
obs = getvalue(observable)
@test size(obs) == (10,)
@test all(diff(obs[1,:]) .< 0.0)
@test all(diff(obs) .< 0.0)
end

@testset "Forward NonlinearProblem" begin
Expand All @@ -25,9 +26,10 @@ end
observable = SimulatorObservable(:u, state -> state.u, size(nlprob.u0))
forwardprob = SimulatorForwardProblem(nlprob, observable)
forward_sol = solve(forwardprob, NewtonRaphson(), abstol=1e-6, reltol=1e-8)
@test forward_sol.sol.retcode == ReturnCode.Success
@test isa(forward_sol, SimulatorForwardSolution)
obs = getvalue(observable)
@test isa(obs, Vector{Float64})
@test isa(obs, AbstractVector{Float64})
@test round(obs[1], digits=3) == -20.271
end

Expand Down

0 comments on commit d46bdad

Please sign in to comment.