-
Notifications
You must be signed in to change notification settings - Fork 7
/
linreg_sgd.jl
33 lines (25 loc) · 1.06 KB
/
linreg_sgd.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# use SGD for linear regression
using SGDOptim
"""
    linreg_sgd(θ_g, n, σ; batchsize=10, λ=1.0e-4,
               lrate=t -> 1.0 / (100.0 + t), cbinterval=5)

Generate a synthetic linear-regression problem from the ground truth `θ_g`, fit it
with SGD, and print how far the zero initial guess and the SGD solution are from
`θ_g`, together with their average risk on the generated data.

`θ_g` is laid out as `[w; b]`: the first `d = length(θ_g) - 1` entries are the
weights and the last entry is the bias. Features are drawn as a `d × n` matrix of
standard-normal samples (one column per sample) and responses as
`y = X' * w .+ b .+ σ .* randn(n)`.

# Arguments
- `θ_g::Vector{Float64}`: ground-truth parameters `[w; b]`; must be non-empty.
- `n::Int`: number of samples to generate; must be positive.
- `σ::Float64`: scale factor applied to the standard-normal noise added to `y`.

# Keywords
- `batchsize::Int = 10`: minibatch size handed to `minibatch_seq`.
- `λ::Float64 = 1.0e-4`: coefficient passed to the `SqrL2Reg` regularizer.
- `lrate = t -> 1.0 / (100.0 + t)`: learning-rate policy, a function of the
  iteration counter `t`.
- `cbinterval::Int = 5`: how frequently the callback is invoked by `sgd`.

# Throws
- `ArgumentError` if `θ_g` is empty or `n` is not positive.

Returns `nothing`; the comparison is printed to standard output.
"""
function linreg_sgd(θ_g::Vector{Float64}, n::Int, σ::Float64;
                    batchsize::Int = 10,
                    λ::Float64 = 1.0e-4,
                    lrate = t -> 1.0 / (100.0 + t),
                    cbinterval::Int = 5)
    # Validate up front: an empty θ_g would give d = -1 (an invalid matrix size),
    # and n = 0 would turn the average risk below into a 0/0 division.
    isempty(θ_g) && throw(ArgumentError("θ_g must contain at least the bias term"))
    n > 0 || throw(ArgumentError("n must be positive, got $n"))

    # prepare experimental data
    d = length(θ_g) - 1
    w = θ_g[1:d]       # ground-truth weights
    b = θ_g[d+1]       # ground-truth bias
    X = randn(d, n)    # one sample per column
    # Dotted operators are required: adding a scalar to an array with plain `+`
    # is deprecated and no longer supported as of Julia 1.0.
    y = X' * w .+ b .+ σ .* randn(n)

    # initialize solution at the origin
    θ_0 = zeros(d + 1)

    # optimize
    rmodel = riskmodel(AffinePred(d), SqrLoss())
    θ = sgd(rmodel, θ_0,
            minibatch_seq(X, y, batchsize),  # configure the way data are supplied
            reg = SqrL2Reg(λ),               # regularization
            lrate = lrate,                   # learning rate policy
            cbinterval = cbinterval,         # how frequently callback is invoked
            callback = gtcompare_trace(θ_g)) # callback built from the ground truth

    # compare both the initial guess and the solution against the ground truth θ_g;
    # `norm` of a vector equals the old `vecnorm`, which Julia 1.0 removed.
    # NOTE(review): on Julia 1.x, `norm` needs `using LinearAlgebra` and `@printf`
    # needs `using Printf`; this file appears to target pre-1.0 Julia (SGDOptim).
    println()
    @printf("Initial: deviation = %.4e | avg.risk = %.4e\n",
            norm(θ_0 - θ_g), value(rmodel, θ_0, X, y) / n)
    @printf("Solution: deviation = %.4e | avg.risk = %.4e\n",
            norm(θ - θ_g), value(rmodel, θ, X, y) / n)
    return nothing
end
# Demo run: ground truth has weights [3.0, 5.0] and bias 2.0; draw 10000 samples
# with the noise scaled by 0.1.
let θ_true = [3.0, 5.0, 2.0],
    nsamples = 10000,
    noise_scale = 0.1
    linreg_sgd(θ_true, nsamples, noise_scale)
end