-
Notifications
You must be signed in to change notification settings - Fork 4
/
runtests.jl
148 lines (124 loc) · 3.67 KB
/
runtests.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# CI-only environment setup: PLOTS_TEST puts Plots into test mode, and
# GKSwstype="100" selects an off-screen GR workstation so plotting does not
# segfault on headless CI runners.
if haskey(ENV, "CI")
ENV["PLOTS_TEST"] = "true"
ENV["GKSwstype"] = "100" # gr segfault workaround
end
using FluxOptTools, Optim, Zygote, Flux, Plots, Test, Statistics, Random
##
@testset "FluxOptTools" begin
@info "Testing FluxOptTools"
# Exercises the flat-vector <-> Params/Grads copy utilities provided by
# FluxOptTools (`veclength`, `zeros(::Params)`, `copy!` in both directions).
@testset "copy" begin
@info "Testing copy"
m = Chain(Dense(1,5,tanh), Dense(5,5,tanh) , Dense(5,1))
x = collect(LinRange{Float32}(-pi,pi,100)')
y = sin.(x)
sp = sortperm(x[:])
loss() = mean(abs2, m(x) .- y)
Zygote.refresh()
pars = Flux.params(m)
pars0 = deepcopy(pars)
# Total scalar parameter count for this architecture:
# (1*5+5) + (5*5+5) + (5*1+1) = 46.
npars = veclength(pars)
@test npars == 46
# copy!(pars, v) loads a flat vector into the Params in place;
# zeros(pars) builds a flat zero vector of matching length.
copy!(pars, zeros(pars))
@test all(all(iszero, p) for p in pars)
p = zeros(pars)
copy!(pars, 1:npars)
copy!(p, pars)
# Round-trip vector -> Params -> vector must be the identity.
@test p == 1:npars
# The same copy machinery must also work on Zygote.Grads.
grads = Zygote.gradient(loss, pars)
grads0 = deepcopy(grads)
copy!(grads, zeros(grads))
@test all(all(iszero,grads[k]) for k in keys(grads.grads))
p = zeros(grads)
copy!(grads, 1:npars)
copy!(p, grads)
# Round-trip vector -> Grads -> vector must be the identity.
@test p == 1:npars
end
## Test optimization ============================================
end
# NOTE: tests below fail if they are in a testset, probably Zygote's fault
# Fit a small MLP to sin(x): warm-start with 500 Adam steps, then refine with
# BFGS through the Optim.jl interface produced by `optfuns`.
m = Chain(Dense(1, 5, tanh), Dense(5, 5, tanh), Dense(5, 1))
x = collect(LinRange{Float32}(-pi, pi, 100)')
y = sin.(x)
sp = sortperm(x[:])
loss() = mean(abs2, m(x) .- y)
@show loss()
Zygote.refresh()
pars = Flux.params(m)
opt = Flux.Adam(0.01)
@show loss()
# Adam warm start (loop index unused).
for _ = 1:500
Flux.Optimise.update!(opt, pars, Zygote.gradient(loss, pars))
end
@show loss()
@test loss() < 1e-1
display(plot(x[sp], [y[sp] m(x)[sp]]))
display(contourf(() -> log10(1 + loss()), pars, color=:turbo, npoints=50, lnorm=1, seed=1234))
# Second stage: quasi-Newton refinement should reach a much lower loss.
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)
res = Optim.optimize(Optim.only_fg!(fg!), p0, BFGS())
@test loss() < 1e-3
display(contourf(() -> log10(1 + loss()), pars, color=:turbo, npoints=50, lnorm=1))
display(plot(x[sp], [y[sp] m(x)[sp]]))
## Benchmark Optim vs ADAM
# Run 10 seeded Adam trainings of the same sin-fitting problem and collect
# the loss trace (initial loss plus one entry per epoch) for each seed.
losses_adam = map(1:10) do i
@show i
Random.seed!(i)
m = Chain(Dense(1, 5, tanh), Dense(5, 5, tanh), Dense(5, 1))
x = collect(LinRange{Float32}(-pi, pi, 100)')
y = sin.(x)
loss() = mean(abs2, m(x) .- y)
Zygote.refresh()
pars = Flux.params(m)
opt = Flux.Adam(0.2)
trace = [loss()]
# `epoch` avoids shadowing the outer do-block argument `i`.
for epoch = 1:500
l, back = Zygote.pullback(loss, pars)
push!(trace, l)
# Seed the pullback with one(l), not l: seeding with the loss value
# scales every gradient by the current loss, silently shrinking the
# effective Adam step as training progresses.
grads = back(one(l))
Flux.Optimise.update!(opt, pars, grads)
end
trace
end
# Run 10 seeded BFGS trainings via Optim.jl, storing the full optimization
# trace (store_trace=true) so loss-vs-iteration can be plotted later.
res_lbfgs = map(1:10) do i
@show i
Random.seed!(i)
m = Chain(Dense(1, 5, tanh), Dense(5, 5, tanh), Dense(5, 1))
# collect the range, matching the other benchmarks; without it a lazy
# LinRange adjoint flows into the network instead of a dense matrix.
x = collect(LinRange{Float32}(-pi, pi, 100)')
y = sin.(x)
loss() = mean(abs2, m(x) .- y)
Zygote.refresh()
pars = Flux.params(m)
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)
# The optimize result is the value of the do-block; no intermediate
# binding needed.
Optim.optimize(Optim.only_fg!(fg!), p0, BFGS(), Optim.Options(iterations=500, store_trace=true))
end
# Run 10 seeded stochastic L-BFGS (SLBFGS) trainings, recording a loss trace
# per seed for comparison with Adam and BFGS above.
losses_SLBFGS = map(1:10) do i
@show i
Random.seed!(i)
m = Chain(Dense(1, 5, tanh), Dense(5, 5, tanh), Dense(5, 1))
# collect the range, matching the other benchmarks.
x = collect(LinRange{Float32}(-pi, pi, 100)')
y = sin.(x)
loss() = mean(abs2, m(x) .- y)
Zygote.refresh()
pars = Flux.params(m)
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)
opt = SLBFGS(lossfun, p0; m=3, ᾱ=1., ρ=false, λ=.0001, κ=0.1)
# Take `iters` SLBFGS steps starting from p0 and return the loss trace.
# Closes over `loss`, `gradfun` and `pars` from this do-block.
function train(opt, p0, iters=20)
p = copy(p0)
g = zeros(veclength(pars))
trace = [loss()]
for _ = 1:iters
g = gradfun(g, p)
p = apply(opt, g, p)
# opt.fold appears to be the loss recorded by the optimizer at the
# previous point — TODO confirm against the SLBFGS implementation.
push!(trace, opt.fold)
end
trace
end
# The trace is the value of the do-block; no intermediate binding needed.
train(opt, p0, 500)
end
##
# Pull the objective value out of every entry in an Optim result's trace.
valuetrace(r) = [entry.value for entry in r.trace]
valuetraces = valuetrace.(res_lbfgs)
# Overlay the three optimizers: BFGS (red), Adam (blue), SLBFGS (green).
plot(valuetraces, yscale=:log10, xscale=:identity, lab="", c=:red)
plot!(losses_adam, lab="", c=:blue, xlabel="Epochs", ylabel="Loss")
plot!(losses_SLBFGS, lab="", c=:green)