Skip to content

Commit

Permalink
default :EMPTY
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgeny Metelkin committed May 4, 2022
1 parent fbdefe8 commit 1e47670
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion TODO.md
@@ -1,6 +1,6 @@
# TODO

- allow function for `loss_grad`
- use scale option for get_interval
- remove bound for scanned parameter
- remove warning for autodiff
- use fitting reset as an option
9 changes: 7 additions & 2 deletions src/cico_one_pass.jl
Expand Up @@ -17,8 +17,7 @@ function get_right_endpoint(
local_alg::Symbol = :LN_NELDERMEAD,
# options for local fitter :max_iter
max_iter::Int = 10^5,
#autodiff::Bool = true,
loss_grad::Union{Function, Symbol} = :AUTODIFF, #:EMPTY,
loss_grad::Union{Function, Symbol} = :EMPTY,
kwargs...
)
# dim of the theta vector
Expand All @@ -33,6 +32,12 @@ function get_right_endpoint(
show(findall(zeroParameter))
end
end

# checking loss_grad
is_gradient = occursin(r"^LD_", String(local_alg))
if loss_grad == :EMPTY && is_gradient
throw(ArgumentError("`loss_grad` must be set for gradient local fitter `$(local_alg)`"))
end

# optimizer
local_opt = Opt(local_alg, n_theta)
Expand Down
6 changes: 6 additions & 0 deletions src/get_optimal.jl
Expand Up @@ -87,6 +87,12 @@ function get_optimal(
end
end

# checking loss_grad
is_gradient = occursin(r"^LD_", String(local_alg))
if loss_grad == :EMPTY && is_gradient
throw(ArgumentError("`loss_grad` must be set for gradient local fitter `$(local_alg)`"))
end

# progress info
prog = ProgressUnknown("Fitter counter:"; spinner=false, enabled=!silent, showspeed=true)
count = 0
Expand Down
6 changes: 4 additions & 2 deletions test/test_bands.jl
Expand Up @@ -21,7 +21,8 @@ res2 = get_interval(
#theta_bounds = [(1e-2,1e2)],
loss_crit = 8.,
local_alg = :LD_MMA,
silent = true
silent = true,
loss_grad = :AUTODIFF
)
@test isapprox(res2.result[1].value, (3-sqrt(3))^2,atol=1e-2)
@test isapprox(res2.result[2].value, (3+sqrt(3))^2,atol=1e-2)
Expand All @@ -35,7 +36,8 @@ res3 = get_interval(
#theta_bounds = [(1e-2,1e2)],
loss_crit = 8.,
local_alg = :LD_MMA,
silent = true
silent = true,
loss_grad = :AUTODIFF
)
@test isapprox(res3.result[1].value, 2*log10(3-sqrt(3)),atol=1e-2)
@test isapprox(res3.result[2].value, 2*log10(3+sqrt(3)),atol=1e-2)

0 comments on commit 1e47670

Please sign in to comment.