Skip to content

Commit 4d74943

Browse files
author
romain.veltz@inria.fr
committed
🚦 simplify syntax for callback in newton
1 parent 7186d46 commit 4d74943

File tree

10 files changed

+24
-24
lines changed

10 files changed

+24
-24
lines changed

β€Žexamples/SHpde_snaking.jlβ€Ž

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ br, = @time continuation(
103103
normN = x -> norm(x, Inf64),
104104
recordFromSolution = (x, p) -> (s5 = x[end Γ· 5], n2 = norm(x), nw = normweighted(x), s = sum(x), s2 = x[end Γ· 2], s4 = x[end Γ· 4],),
105105
# tangentAlgo = SecantPred(),
106-
# callbackN = (x, f, J, res, iteration, itlinear, options; kwargs...) ->(true)
107106
)
108107

109108
plot(br..., legend=false, linewidth=1, vars = (:param, :n2))

β€Žexamples/brusselator.jlβ€Ž

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ outpo_f, _, flag = @time newton(poTrap,
189189
# jacobianPO = :BorderedLU,
190190
# jacobianPO = :FullSparseInplace,
191191
normN = norminf,
192-
callback = (x, f, J, res, iteration, itl, options; kwargs...) -> (println("--> amplitude = ", BK.amplitude(x, n, M; ratio = 2));true)
192+
callback = (state; kwargs...) -> (println("--> amplitude = ", BK.amplitude(state.x, n, M; ratio = 2));true)
193193
)
194194
flag && printstyled(color=:red, "--> T = ", outpo_f[end], ", amplitude = ", BK.amplitude(outpo_f, n, M; ratio = 2),"\n")
195195
BK.plotPeriodicPOTrap(outpo_f, n, M; ratio = 2)
@@ -207,7 +207,6 @@ opts_po_cont = ContinuationPar(dsmin = 0.001, dsmax = 0.1, ds= 0.01, pMax = 3.0,
207207
jacobianPO = :BorderedSparseInplace,
208208
# tangentAlgo = BorderedPred(),
209209
verbosity = 3, plot = true,
210-
# callbackN = (x, f, J, res, iteration, options; kwargs...) -> (println("--> amplitude = ", BK.amplitude(x, n, M));true),
211210
# finaliseSolution = (z, tau, step, contResult; k...) ->
212211
# (Base.display(contResult.eig[end].eigenvals) ;true),
213212
plotSolution = (x, p;kwargs...) -> heatmap!(reshape(x[1:end-1], 2*n, M)'; ylabel="time", color=:viridis, kwargs...),
@@ -247,7 +246,6 @@ opts_po_cont = ContinuationPar(dsmin = 0.0001, dsmax = 0.05, ds= 0.01, pMax = 2.
247246
opts_po_cont; jacobianPO = :BorderedMatrixFree,
248247
verbosity = 2,
249248
plot = true,
250-
# callbackN = (x, f, J, res, iteration, options; kwargs...) -> (println("--> amplitude = ", BK.amplitude(x, n, M));true),
251249
# plotSolution = (x, p;kwargs...) -> heatmap!(reshape(x[1:end-1], 2*n, M)'; ylabel="time", color=:viridis, kwargs...)
252250
normC = norminf)
253251
####################################################################################################

β€Žexamples/cGL2d-shooting.jlβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ br_po, = continuation(
172172
# probSh;
173173
ShootingProblem(Mt, prob_sp, ETDRK2(krylov = true); abstol = 1e-10, reltol = 1e-8) ;
174174
verbosity = 3, plot = true, ampfactor = 1.5, Ξ΄p = 0.01,
175-
# callbackN = (x, f, J, res, iteration, itl, options; kwargs...) -> (println("--> amplitude = ", BK.amplitude(x, n, M; ratio = 2));true),
175+
# callbackN = (state; kwargs...) -> (println("--> amplitude = ", BK.amplitude(state.x, n, M; ratio = 2));true),
176176
linearAlgo = MatrixFreeBLS(@set ls.N = Mt*2n+2),
177177
finaliseSolution = (z, tau, step, contResult; k...) ->begin
178178
BK.haseigenvalues(contResult) && Base.display(contResult.eig[end].eigenvals)

β€Žexamples/cGL2d.jlβ€Ž

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ opt_po = @set opt_newton.verbose = true
226226
orbitguess_f, (@set par_cgl.r = r_hopf - 0.01),
227227
(@set opt_po.linsolver = ls); jacobianPO = :FullMatrixFree,
228228
normN = norminf,
229-
# callback = (x, f, J, res, iteration, options) -> (println("--> amplitude = ", BK.amplitude(x, Nx*Ny, M; ratio = 2));true)
229+
# callback = (state; k...) -> (println("--> amplitude = ", BK.amplitude(state.x, Nx*Ny, M; ratio = 2));true)
230230
)
231231
flag && printstyled(color=:red, "--> T = ", outpo_f[end], ", amplitude = ", BK.amplitude(outpo_f, Nx*Ny, M; ratio = 2),"\n")
232232
plot();BK.plotPeriodicPOTrap(outpo_f, M, Nx, Ny; ratio = 2);title!("")
@@ -255,7 +255,6 @@ br_po, _ = continuation(
255255
opts_po_cont, poTrapMF;
256256
ampfactor = 3., jacobianPO = :FullMatrixFree,
257257
verbosity = 3, plot = true,
258-
# callbackN = (x, f, J, res, iteration, itl, options; kwargs...) -> (println("--> amplitude = ", BK.amplitude(x, n, M; ratio = 2));true),
259258
finaliseSolution = (z, tau, step, contResult; k...) ->
260259
(BK.haseigenvalues(contResult) && Base.display(contResult.eig[end].eigenvals) ;true),
261260
plotSolution = (x, p; kwargs...) -> BK.plotPeriodicPOTrap(x, M, Nx, Ny; ratio = 2, kwargs...),
@@ -281,12 +280,12 @@ opt_po = @set opt_newton.verbose = true
281280
(@set opt_po.linsolver = ls); jacobianPO = :BorderedMatrixFree,
282281
normN = norminf)
283282

284-
function callbackPO(x, f, J, res, iteration, linsolver = ls, prob = poTrap, p = par_cgl; kwargs...)
283+
function callbackPO(state; linsolver = ls, prob = poTrap, p = par_cgl, kwargs...)
285284
@show ls.N keys(kwargs)
286285
# we update the preconditioner every 10 continuation steps
287286
if mod(kwargs[:iterationC], 10) == 9 && iteration == 1
288287
@info "update Preconditioner"
289-
Jpo = poTrap(Val(:JacCyclicSparse), x, (@set p.r = kwargs[:p]))
288+
Jpo = poTrap(Val(:JacCyclicSparse), state.x, (@set p.r = kwargs[:p]))
290289
Precilu = @time ilu(Jpo, Ο„ = 0.003)
291290
ls.Pl = Precilu
292291
end
@@ -585,7 +584,6 @@ outpo_f, hist, flag = @time newton(
585584
poTrapMF, orbitguess_f, (@set par_cgl.r = r_hopf - 0.01),
586585
(@set opt_po.linsolver = ls); jacobianPO = :FullMatrixFree,
587586
normN = x -> maximum(abs.(x)),
588-
# callback = (x, f, J, res, iteration, options) -> (println("--> amplitude = ", amplitude(x));true)
589587
) #14s
590588
flag && printstyled(color=:red, "--> T = ", outpo_f[end], ", amplitude = ", amplitude(outpo_f, Nx*Ny, M),"\n")
591589

@@ -596,7 +594,6 @@ opt_po = @set opt_newton.verbose = true
596594
orbitguess_cu, (@set par_cgl_gpu.r = r_hopf - 0.01),
597595
(@set opt_po.linsolver = lsgpu); jacobianPO = :FullMatrixFree,
598596
normN = x -> maximum(abs.(x)),
599-
# callback = (x, f, J, res, iteration, options) -> (println("--> amplitude = ", BK.amplitude(x, Nx*Ny, M));true)
600597
) #7s
601598
flag && printstyled(color=:red, "--> T = ", outpo_f[end:end], ", amplitude = ", amplitude(outpo_f, Nx*Ny, M),"\n")
602599

β€Žexamples/carrier.jlβ€Ž

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ outdef1, _, flag = @time newton(
7676
# perturbsol(deflationOp[1],0,0), par_def,
7777
perturbsol(-out, 0, 0), par_def,
7878
optdef, deflationOp;
79-
# callback = (x, f, J, res, iteration, itlinear, options; kwargs...) ->(res < 1e8)
8079
)
8180
flag && push!(deflationOp, outdef1)
8281

β€Žexamples/pd-1d.jlβ€Ž

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ ls = GMRESIterativeSolvers(reltol = 1e-7, N = length(initpo), maxiter = 50, verb
170170
optn = NewtonPar(verbose = true, tol = 1e-9, maxIter = 20, linsolver = ls)
171171
# deflationOp = BK.DeflationOperator(2 (x,y) -> dot(x[1:end-1], y[1:end-1]),1.0, [outpo])
172172
outposh, _, flag = @time newton(probSh, initpo, par_br_hopf, optn;
173-
callbackN = (x, f, J, res, iteration; kw...) -> (@show x[end];true),
173+
callbackN = (state; kw...) -> (@show state.x[end];true),
174174
normN = norminf)
175175
flag && printstyled(color=:red, "--> T = ", outposh[end], ", amplitude = ", BK.getAmplitude(probSh, outposh, par_br_hopf; ratio = 2),"\n")
176176

@@ -219,7 +219,7 @@ ls = GMRESIterativeSolvers(reltol = 1e-7, N = length(initpo_pd), maxiter = 50, v
219219
optn = NewtonPar(verbose = true, tol = 1e-9, maxIter = 120, linsolver = ls)
220220
# deflationOp = BK.DeflationOperator(2 (x,y) -> dot(x[1:end-1], y[1:end-1]),1.0, [outpo])
221221
outposh_pd, _, flag = @time newton(probSh, initpo_pd, par_br_pd, optn;
222-
callback = (x, f, J, res, iteration, itlinear, options; kwargs...) -> (@show x[end];true),
222+
callback = (state; kwargs...) -> (@show state.x[end];true),
223223
normN = norminf)
224224
flag && printstyled(color=:red, "--> T = ", outposh_pd[end], ", amplitude = ", BK.getAmplitude(probSh, outposh_pd, (@set par_br.C = -0.86); ratio = 2),"\n")
225225

β€Žsrc/DeflatedContinuation.jlβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ function continuation(F, J, par, lens::Lens, contParams::ContinuationPar, defOp:
141141
recordFromSolution = (x, p) -> norm(x),
142142
plotSolution = (x, p ;kwargs...) -> plot!(x; kwargs...),
143143
perturbSolution = (x, p, id) -> x,
144-
callbackN = (x, f, J, res, iteration, itlinear, options; kwargs...) -> true,
144+
callbackN = (state; kwargs...) -> true,
145145
acceptSolution = (x, p) -> true,
146146
updateDeflationOp = (defOp, x, p) -> push!(defOp, x),
147147
normN = norm) where vectype

β€Žsrc/Newton.jlβ€Ž

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ This is the Newton-Krylov Solver for `F(x, p0) = 0` with Jacobian w.r.t. `x` wri
4343
- `x0` initial guess
4444
- `p0` set of parameters to be passed to `F` and `J`
4545
- `options::NewtonPar` variable holding the internal parameters used by the `newton` method
46-
- `callback` function passed by the user which is called at the end of each iteration. The default one is the following `cbDefault(x, f, J, res, it, itlinear, options; k...) = true`. Can be used to update a preconditionner for example. You can use for example `cbMaxNorm` to limit the residuals norms. If yo want to specify your own, the arguments passed to the callback are as follows
46+
- `callback` function passed by the user which is called at the end of each iteration. The default one is the following `cbDefault(state; k...) = true`. Can be used to update a preconditionner for example. You can use for example `cbMaxNorm` to limit the residuals norms. If yo want to specify your own, the elements passed in `state` to the callback are the following
4747
- `x` current solution
4848
- `f` current residual
4949
- `J` current jacobian
@@ -112,7 +112,7 @@ function newton(Fhandle, Jhandle, x0, p0, options::NewtonPar; normN = norm, call
112112
verbose && displayIteration(it, res)
113113

114114
# invoke callback before algo really starts
115-
compute = callback(x, f, nothing, res, it, 0, options; x0 = x0, resHist = resHist, fromNewton = true, kwargs...)
115+
compute = callback((;x, f, nothing, res, it, options); x0 = x0, resHist = resHist, fromNewton = true, kwargs...)
116116
# Main loop
117117
while (res > tol) && (it < maxIter) && compute
118118
J = Jhandle(x, p0)
@@ -130,10 +130,10 @@ function newton(Fhandle, Jhandle, x0, p0, options::NewtonPar; normN = norm, call
130130

131131
verbose && displayIteration(it, res, itlinear)
132132

133-
compute = callback(x, f, J, res, it, itlinear, options; x0 = x0, resHist = resHist, fromNewton = true, kwargs...)
133+
compute = callback((;x, f, J, res, it, itlinear, options, x0, resHist); fromNewton = true, kwargs...)
134134
end
135135
((resHist[end] > tol) && verbose) && @error("\n--> Newton algorithm failed to converge, residual = $(res[end])")
136-
flag = (resHist[end] < tol) & callback(x, f, nothing, res, it, nothing, options; x0 = x0, resHist = resHist, fromNewton = true, kwargs...)
136+
flag = (resHist[end] < tol) & callback((;x, f, res, it, options, x0, resHist); fromNewton = true, kwargs...)
137137
verbose && displayIteration(0, res, 0, true) # display last line of the table
138138
return x, resHist, flag, it, itlineartot
139139
end
@@ -146,7 +146,7 @@ end
146146

147147

148148
# default callback
149-
cbDefault(x, f, J, res, it, itlinear, options; k...) = true
149+
cbDefault(state; k...) = true
150150

151151
# newton callback to limit residual
152152
"""
@@ -157,4 +157,4 @@ Create a callback used to reject residals larger than `cb.maxres` in the Newton
157157
struct cbMaxNorm{T}
158158
maxres::T
159159
end
160-
(cb::cbMaxNorm)(x, f, J, res, it, itlinear, options; k...) = (return res < cb.maxres)
160+
(cb::cbMaxNorm)(state; k...) = (return state.res < cb.maxres)

β€Žsrc/Predictor.jlβ€Ž

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ function newtonPALC(F, Jh, par, paramlens::Lens,
503503
line_step = true
504504

505505
# invoke callback before algo really starts
506-
compute = callback(x, res_f, nothing, res, 0, 0, contparams; p = p, resHist = resHist, fromNewton = false, kwargs...)
506+
compute = callback((;x, res_f, res, contparams, p, resHist); fromNewton = false, kwargs...)
507507

508508
# Main loop
509509
while (res > tol) && (it < maxIter) && line_step && compute
@@ -565,10 +565,10 @@ function newtonPALC(F, Jh, par, paramlens::Lens,
565565
verbose && displayIteration(it, res, itlinear)
566566

567567
# shall we break the loop?
568-
compute = callback(x, res_f, J, res, it, itlinear, contparams; p = p, resHist = resHist, fromNewton = false, kwargs...)
568+
compute = callback((;x, res_f, J, res, it, itlinear, contparams, z0, p, resHist); fromNewton = false, kwargs...)
569569
end
570570
verbose && displayIteration(it, res, 0, true) # display last line of the table
571-
flag = (resHist[end] < tol) & callback(x, res_f, nothing, res, it, -1, contparams; p = p, resHist = resHist, fromNewton = false, kwargs...)
571+
flag = (resHist[end] < tol) & callback((;x, res_f, res, it, contparams, p, resHist); fromNewton = false, kwargs...)
572572
return BorderedArray(x, p), resHist, flag, it, itlineartot
573573
end
574574

β€Žsrc/Utils.jlβ€Ž

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
rightmost(ev) = ev[sortperm(ev, by = abs∘real)]
22
getinterval(a, b) = (min(a, b), max(a, b))
33
####################################################################################################
4+
# display eigenvals with color
5+
function displayEV(eigenvals, color = :black)
6+
for r in eigenvals
7+
printstyled(color=color, r, "\n")
8+
end
9+
end
10+
####################################################################################################
411
function displayIteration(i, residual, itlinear = 0, lastRow = false)
512
if lastRow
613
lastRow && println("β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")

0 commit comments

Comments
Β (0)