diff --git a/src/FluxOptTools.jl b/src/FluxOptTools.jl index 69b8a2f..4e88041 100644 --- a/src/FluxOptTools.jl +++ b/src/FluxOptTools.jl @@ -27,13 +27,13 @@ function optfuns(loss, pars::Union{Flux.Params, Zygote.Params}) end fg! = function (F,G,w) copy!(pars, w) - if G != nothing + if !isnothing(G) l, back = Zygote.pullback(loss, pars) grads = back(1) copy!(G, grads) return l end - if F != nothing + if !isnothing(F) return loss() end end diff --git a/test/runtests.jl b/test/runtests.jl index 0284f10..b1ca43e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using FluxOptTools, Optim, Zygote, Flux, Plots, Test, Statistics, Random @info "Testing copy" m = Chain(Dense(1,5,tanh), Dense(5,5,tanh) , Dense(5,1)) -x = collect(LinRange(-pi,pi,100)') +x = collect(LinRange{Float32}(-pi,pi,100)') y = sin.(x) sp = sortperm(x[:]) @@ -50,7 +50,7 @@ end # NOTE: tests below fail if they are in a testset, probably Zygote's fault m = Chain(Dense(1,5,tanh), Dense(5,5,tanh) , Dense(5,1)) -x = collect(LinRange(-pi,pi,100)') +x = collect(LinRange{Float32}(-pi,pi,100)') y = sin.(x) sp = sortperm(x[:]) @@ -59,7 +59,7 @@ loss() = mean(abs2, m(x) .- y) Zygote.refresh() pars = Flux.params(m) -opt = ADAM(0.01) +opt = Flux.Adam(0.01) @show loss() for i = 1:500 grads = Zygote.gradient(loss, pars) @@ -84,12 +84,12 @@ losses_adam = map(1:10) do i @show i Random.seed!(i) m = Chain(Dense(1,5,tanh), Dense(5,5,tanh) , Dense(5,1)) - x = collect(LinRange(-pi,pi,100)') + x = collect(LinRange{Float32}(-pi,pi,100)') y = sin.(x) loss() = mean(abs2, m(x) .- y) Zygote.refresh() pars = Flux.params(m) - opt = Flux.ADAM(0.2) + opt = Flux.Adam(0.2) trace = [loss()] for i = 1:500 l,back = Zygote.pullback(loss, pars) @@ -104,7 +104,7 @@ res_lbfgs = map(1:10) do i @show i Random.seed!(i) m = Chain(Dense(1,5,tanh), Dense(5,5,tanh) , Dense(5,1)) - x = LinRange(-pi,pi,100)' + x = LinRange{Float32}(-pi,pi,100)' y = sin.(x) loss() = mean(abs2, m(x) .- y) Zygote.refresh() @@ -118,7 +118,7 @@ losses_SLBFGS = map(1:10) do i @show i Random.seed!(i) m = Chain(Dense(1,5,tanh), Dense(5,5,tanh) , Dense(5,1)) - x = LinRange(-pi,pi,100)' + x = LinRange{Float32}(-pi,pi,100)' y = sin.(x) loss() = mean(abs2, m(x) .- y) Zygote.refresh()