---
# Section 7.4: Hyperparameters
---

In [None]:
using Plots, LaTeXStrings
using LinearAlgebra, Statistics, Random, Printf
using Parameters: @with_kw

using MLDataUtils
using MLBase

using Flux
using Flux: params
using Flux.Losses: mse
using Flux.Data: DataLoader
using Flux.Optimise
using Flux.Losses: binarycrossentropy

In [None]:
# Himmelblau function f(x, y) = (x² + y - 11)² + (x + y² - 7)².
# It is a sum of squares, so f ≥ 0; f(3, 2) == 0 is one of its minima.
function f(x)
    u, v = x[1], x[2]
    return (u^2 + v - 11)^2 + (u + v^2 - 7)^2
end

# Two-scalar convenience method (used by `contour!`): packs the point into a
# vector so both methods share a single implementation.
f(x, y) = f([x, y])

# Plotting window [ax, bx] × [ay, by] and the 200×200 evaluation grid shared
# by every contour plot in this notebook.
ax, bx = -6, 6
ay, by = -6, 6

xx = range(ax, bx; length=200)
yy = range(ay, by; length=200)

# Contour levels of f to draw (deliberately non-uniform spacing).
flevels = [0, 5, 20, 40, 60, 80, 100, 120, 150, 180, 300, 400, 600]

plt1 = plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none,
    size=(600, 600), xlims=(ax, bx), ylims=(ay, by), legend=:none)
contour!(xx, yy, f; levels=flevels, color=1, contour_labels=true)

In [None]:
# Sample N points uniformly from the square [-6, 6)², label each one with the
# Himmelblau function, and split 80/20 into training and test sets.
# The samples come from a dedicated, seeded RNG so the data set (and therefore
# every score computed below) is reproducible from run to run.
N = 2000

data_rng = MersenneTwister(1234)
X = 12 * rand(data_rng, Float32, 2, N) .- 6
y = reshape([f(col) for col in eachcol(X)], 1, N)

train_inds, test_inds = splitobs(N, at=0.8)

X_train, X_test = X[:, train_inds], X[:, test_inds]
y_train, y_test = y[:, train_inds], y[:, test_inds]
size(X_train), size(X_test)

In [None]:
# Training hyperparameters; every field can be overridden by keyword,
# e.g. `Args(; lr=1e-3)`.
#   optalg    – optimiser constructor, called as `optalg()` or `optalg(lr)`
#   batchsize – mini-batch size handed to `DataLoader`
#   lr        – learning rate; `NaN` means "use the optimiser's default"
#   epochs    – number of passes over the training data
@with_kw mutable struct Args
    optalg::Any = ADAM
    batchsize::Int = 80
    lr::Float64 = NaN
    epochs::Int = 1_000
end

args = Args()

In [None]:
"""
    estfun(train_inds; kws...)

Train a fresh fully-connected ReLU regression network on the columns
`train_inds` of the global data `X`, `y` and return the trained model.
Keyword arguments override the matching fields of `Args` (`optalg`,
`batchsize`, `lr`, `epochs`). The weights are initialised right after
`Random.seed!(1234)`, so every call starts from the same initial weights.
"""
function estfun(train_inds; kws...)
    Xsub = X[:, train_inds]
    ysub = y[:, train_inds]

    hp = Args(; kws...)

    # Hidden widths 2 → 16 → 16 → 16 → 8 → 8 → 8 → 4 → 4 → 2 (ReLU throughout),
    # followed by a linear 2 → 1 output layer.
    Random.seed!(1234)
    widths = [2, 16, 16, 16, 8, 8, 8, 4, 4, 2]
    hidden = [Dense(widths[k], widths[k + 1], relu) for k in 1:length(widths) - 1]
    model = Chain(hidden..., Dense(2, 1))

    loss(a, b) = mse(model(a), b)
    θ = params(model)
    batches = DataLoader((Xsub, ysub), batchsize=hp.batchsize)
    # A NaN learning rate means "fall back to the optimiser's own default".
    opt = isnan(hp.lr) ? hp.optalg() : hp.optalg(hp.lr)

    for _ in 1:hp.epochs
        train!(loss, θ, batches, opt)
    end

    return model
end

In [None]:
"""
    evalfun(model, test_inds; kws...)

Return the mean-squared error of `model` on the columns `test_inds` of the
global data `X`, `y`. Extra keyword arguments are accepted and ignored.
"""
function evalfun(model, test_inds; kws...)
    Xsub, ysub = X[:, test_inds], y[:, test_inds]
    return mse(model(Xsub), ysub)
end

In [None]:
# Sanity check: a short 100-epoch run with otherwise default hyperparameters,
# scored by test-set MSE.
model = estfun(train_inds, epochs = 100)
score = evalfun(model, test_inds)

---

# Cross-validation

[MLBase cross-validation documentation](https://mlbasejl.readthedocs.io/en/latest/crossval.html)

In [None]:
# Cross-validation for optimization algorithm.
# The folds are drawn from the training indices only, so the held-out test set
# (used later for the final score) never influences hyperparameter selection.
# `cross_validate` works on positions 1:n_cv, mapped back via `train_inds[...]`.

n_cv = length(train_inds)
@printf("%14s %20s\n", "algorithm", "score")
for optalg in [Descent, Momentum, Nesterov, ADAM]
    scores = cross_validate(
        inds -> estfun(train_inds[inds]; optalg=optalg, epochs=100),
        (model, inds) -> evalfun(model, train_inds[inds]),
        n_cv,
        Kfold(n_cv, 4))

    m, s = mean_and_std(scores)

    score = @sprintf("%.1f ± %.1f", m, s)
    @printf("%14s %20s\n", string(optalg), score)
end

In [None]:
# Cross-validation for batchsize.
# The folds are drawn from the training indices only, so the held-out test set
# (used later for the final score) never influences hyperparameter selection.
# `cross_validate` works on positions 1:n_cv, mapped back via `train_inds[...]`.

n_cv = length(train_inds)
@printf("%14s %20s\n", "batchsize", "score")
for batchsize in [40, 80, 100]
    scores = cross_validate(
        inds -> estfun(train_inds[inds]; batchsize=batchsize, epochs=1000),
        (model, inds) -> evalfun(model, train_inds[inds]),
        n_cv,
        Kfold(n_cv, 4))

    m, s = mean_and_std(scores)

    score = @sprintf("%.1f ± %.1f", m, s)
    @printf("%14d %20s\n", batchsize, score)
end

In [None]:
# Cross-validation for learning rate.
# The folds are drawn from the training indices only, so the held-out test set
# (used later for the final score) never influences hyperparameter selection.
# `cross_validate` works on positions 1:n_cv, mapped back via `train_inds[...]`.

n_cv = length(train_inds)
@printf("%14s %20s\n", "learning rate", "score")
for lr in [1e-2, 1e-3, 1e-4]
    scores = cross_validate(
        inds -> estfun(train_inds[inds]; lr=lr, epochs=1000),
        (model, inds) -> evalfun(model, train_inds[inds]),
        n_cv,
        Kfold(n_cv, 4))

    m, s = mean_and_std(scores)

    score = @sprintf("%.1f ± %.1f", m, s)
    @printf("%14.0e %20s\n", lr, score)
end

In [None]:
# Final fit on the whole training set with the default `Args()` hyperparameters
# (ADAM, batch size 80, 1000 epochs), scored by test-set MSE.
model = estfun(train_inds)
score = evalfun(model, test_inds)

In [None]:
# Network surrogate at a single point (x, y): the point is fed in as a 2×1
# matrix (a batch of one column) and the scalar prediction is extracted.
F(x, y) = model(reshape([x, y], 2, 1))[1]

# Overlay: contours of the true function (colour 1) and of the network (black).
plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none,
    size=(600, 600), xlims=(ax, bx), ylims=(ay, by))
contour!(xx, yy, f; levels=flevels, color=1, contour_labels=true)
contour!(xx, yy, F; levels=flevels, color=:black, contour_labels=true)

In [None]:
# Side-by-side comparison: true function (left) vs. trained network (right).
panel_kw = (xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none,
    size=(600, 600), xlims=(ax, bx), ylims=(ay, by))

plt1 = plot(; panel_kw..., legend=:none)
contour!(xx, yy, f; levels=flevels, color=1, contour_labels=true)

plt2 = plot(; panel_kw...)
contour!(xx, yy, F; levels=flevels, color=:black, contour_labels=true)

plot(plt1, plt2; layout=(1, 2), size=(900, 500))

---
# Batch Normalization

In [None]:
"""
    estfun_batch(train_inds; kws...)

Train a sigmoid regression network with a `BatchNorm` layer after every hidden
`Dense` layer on the columns `train_inds` of the global data `X`, `y`, and return
the trained model. Keyword arguments override the matching fields of `Args`;
the weights are initialised right after `Random.seed!(1234)`.
"""
function estfun_batch(train_inds; kws...)
    Xsub = X[:, train_inds]
    ysub = y[:, train_inds]

    hp = Args(; kws...)

    # Widths 2 → 16 → 8 → 4 → 2, each sigmoid Dense followed by BatchNorm over
    # its outputs, then a linear 2 → 1 output layer.
    Random.seed!(1234)
    model = Chain(
        Dense(2, 16, sigmoid), BatchNorm(16),
        Dense(16, 8, sigmoid), BatchNorm(8),
        Dense(8, 4, sigmoid), BatchNorm(4),
        Dense(4, 2, sigmoid), BatchNorm(2),
        Dense(2, 1),
    )

    loss(a, b) = mse(model(a), b)
    θ = params(model)
    batches = DataLoader((Xsub, ysub), batchsize=hp.batchsize)
    # A NaN learning rate means "fall back to the optimiser's own default".
    opt = isnan(hp.lr) ? hp.optalg() : hp.optalg(hp.lr)

    for _ in 1:hp.epochs
        train!(loss, θ, batches, opt)
    end

    return model
end

In [None]:
# Same experiment with the BatchNorm network and default hyperparameters,
# scored by test-set MSE.
model = estfun_batch(train_inds)
score = evalfun(model, test_inds)

In [None]:
# Covariance, across training samples (columns), of the outputs of layer k,
# drawn as a heatmap with a colour scale symmetric about zero.
k = 8
@show k
@show model[k]
layers = Flux.activations(model, X_train)
C = cov(layers[k]; dims=2)
maxC = maximum(abs, C)
heatmap(C; c=:balance, clims=(-maxC, maxC), yflip=true)

In [None]:
# Display the numeric covariance matrix of layer k's outputs computed above.
C

In [None]:
# Show the trainable parameters of layer k side by side
# (for a Dense layer presumably [weight bias] — confirm against Flux's params order).
k = 9
@show model[k]
psk = params(model[k])
hcat(psk[1], psk[2])

In [None]:
# Surrogate from the BatchNorm network at a single point (x, y), fed in as a
# 2×1 matrix (a batch of one column).
F(x, y) = model(reshape([x, y], 2, 1))[1]

# Overlay: contours of the true function (colour 1) and of the network (black).
plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, colorbar=:none,
    size=(600, 600), xlims=(ax, bx), ylims=(ay, by))
contour!(xx, yy, f; levels=flevels, color=1, contour_labels=true)
contour!(xx, yy, F; levels=flevels, color=:black, contour_labels=true)

---
# Dropout

[Dropout: A Simple Way to Prevent Neural Networks from Overfitting](https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf)

In [None]:
# Turn regression into binary classification: a point is "positive" when
# f ≤ cutoff. Count positives and negatives in the training set.
cutoff = 60

posinds = findall(<=(cutoff), vec(y_train))
neginds = findall(>(cutoff), vec(y_train))

length(posinds), length(neginds)

In [None]:
# Training points on top of the true region {f ≤ cutoff}
# (colour 2: positive, colour 3: negative).
plot(; xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, size=(600, 600),
    xlims=(ax, bx), ylims=(ay, by))
contourf!(xx, yy, (u, v) -> f(u, v) <= cutoff; c=:binary)
Plots.scatter!(X_train[1, posinds], X_train[2, posinds]; c=2, label=:none)
Plots.scatter!(X_train[1, neginds], X_train[2, neginds]; c=3, label=:none)

In [None]:
# Binary targets: 1f0 where f ≤ cutoff, 0f0 otherwise (Float32 like the data).
# Split with the same index sets used for X_train / X_test rather than assuming
# the training set is exactly the first N_train columns.
N_train = length(y_train)
yb = 1f0 * (y .<= cutoff)
yb_train, yb_test = yb[:, train_inds], yb[:, test_inds]
size(X_train), size(yb_train)

In [None]:
# Binary classifier with a single Dropout layer (p = 0.5) in the middle and a
# sigmoid output, so predictions lie in (0, 1).
Random.seed!(1234)

model = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 8, relu),
    Dropout(0.5),
    Dense(8, 4, relu),
    Dense(4, 2, relu),
    Dense(2, 1, sigmoid))

loss(x, y) = binarycrossentropy(model(x), y)
# Percentage of samples whose rounded prediction equals the 0/1 label.
# `count` on the Bool comparison replaces the former no-op `sum(abs.(...))`.
accuracy(x, y) = 100 * count(round.(model(x)) .== y) / length(y)

ps = params(model)

# Predicted class (0 or 1) at a single point, used for decision-region plots.
F(x, y) = round(model([x, y])[1])

In [None]:
# Mini-batches of 400 training samples, and ADAM with its default learning rate.
data = DataLoader((X_train, yb_train); batchsize=400)

opt = ADAM()

In [None]:
# Train with Dropout forced on, then evaluate with it switched off.
Flux.trainmode!(model)

@time begin
    epochs = 2000
    for epoch in 1:epochs
        train!(loss, ps, data, opt)
        if epoch % 100 == 0
            # Report the loss of the deterministic network: Dropout is disabled
            # just for this diagnostic and re-enabled for further training, so
            # the printed value is not perturbed by random dropout masks.
            Flux.testmode!(model)
            @show loss(X_train, yb_train)
            Flux.trainmode!(model)
        end
    end
end

Flux.testmode!(model)
@show accuracy(X_train, yb_train)
@show accuracy(X_test, yb_test)

plot(xlabel=L"x", ylabel=L"y", aspect_ratio=:equal, size=(600,600),
    xlims=(ax,bx), ylims=(ay,by))
contourf!(xx, yy, F, c=:binary)

In [None]:
# Test points over the true region {f ≤ cutoff} (left) and over the region
# predicted by the trained classifier (right); colour 2 = positive, 3 = negative.
posinds = findall(<=(cutoff), vec(y_test))
neginds = findall(>(cutoff), vec(y_test))

panel_kw = (aspect_ratio=:equal, size=(600, 600), xlims=(ax, bx), ylims=(ay, by),
    legend=:none)

plt1 = plot(; panel_kw...)
contourf!(xx, yy, (u, v) -> f(u, v) <= cutoff; c=:binary)
Plots.scatter!(X_test[1, posinds], X_test[2, posinds]; c=2, label=:none)
Plots.scatter!(X_test[1, neginds], X_test[2, neginds]; c=3, label=:none)

plt2 = plot(; panel_kw...)
contourf!(xx, yy, F; c=:binary)
Plots.scatter!(X_test[1, posinds], X_test[2, posinds]; c=2, label=:none)
Plots.scatter!(X_test[1, neginds], X_test[2, neginds]; c=3, label=:none)

plot(plt1, plt2; layout=(1, 2), size=(900, 500))

---