In [None]:
using Revise
using ARFF
using Random
using Statistics
using Distributions
using Printf
using LinearAlgebra
using Plots

In [None]:
# Global Plots.jl defaults for every figure below: 2-px lines, size-6 markers,
# 10-pt tick/legend text and 12-pt axis-label/title text.
default(;
    lw = 2,
    markersize = 6,
    xtickfont = font(10),
    ytickfont = font(10),
    guidefont = font(12),
    legendfont = font(10),
    titlefont = font(12),
)

In [None]:
# Scalar test functions of a point `x` in R^2; only x[1] and x[2] are read.
# f1 is the saddle x₁x₂, f2 is x₁² - x₂².
function f1(x)
    a, b = x[1], x[2]
    return a * b
end

function f2(x)
    a, b = x[1], x[2]
    return a^2 - b^2
end;

In [None]:
xx = LinRange(-2, 2, 100);
yy = LinRange(-2, 2, 100);

# Evaluate both test functions on the 100×100 grid. The first comprehension
# iterator (y) indexes rows and the second (x) indexes columns, matching the
# z[i, j] ↔ (x[j], y[i]) layout that `contourf(x, y, z)` uses.
z1, z2 = let pts = [[xv, yv] for yv in yy, xv in xx]
    (f1.(pts), f2.(pts))
end;

p1 = contourf(xx, yy, z1)
p2 = contourf(xx, yy, z2)
plot(p1, p2; layout = (1, 2))

In [None]:
N = 10^3;  # number of training samples

Random.seed!(100)
# Inputs: N standard-normal points in R^2. Targets: both test functions stacked.
x_data = map(_ -> randn(2), 1:N);
y_data = map(pt -> [f1(pt), f2(pt)], x_data);
data = DataSet(x_data, y_data)

In [None]:
# `d` must be defined before it is used below; previously it was only set in the
# next cell, so running the notebook top-to-bottom failed with an UndefVarError.
d = 2;  # input dimension (the targets in `data` are also 2-dimensional)
@show K = 2^6;  # number of Fourier features
Random.seed!(200)
# Random initial model: `d` complex coefficient vectors of length K
# (NOTE(review): these presumably correspond to output components; here the input
# and output dimensions are both 2, so `1:d` works; confirm against FourierModel)
# and K standard-normal frequency vectors in R^d.
F0 = FourierModel([1.0 * randn(ComplexF64, K) for _ in 1:d],
    [1.0 * randn(d) for _ in 1:K])

In [None]:
d = 2;                          # input dimension
δ = 2.4^2 / (15 * d);           # rwm step size
Σ0 = Matrix{Float64}(I, d, d);  # identity covariance passed to train_rwm!

n_epochs = 5 * 10^2;            # total number of iterations
n_ω_steps = 2;                  # number of steps between full β updates
# Burn-in is 10% of the epochs. Integer division always yields an Int, whereas
# `Int(0.1 * n_epochs)` throws an InexactError unless n_epochs is a multiple of 10.
n_burn = n_epochs ÷ 10;
γ = optimal_γ(d);
# γ = 200;                      # fixed alternative to optimal_γ(d)
@show γ;
ω_max = Inf;
adapt_covariance = true;

λ = 1e-6;                       # regularization strength
"""
    reg_β_solver!(β, S, y, λ, ω, r)

Overwrite `β` with the solution of the frequency-weighted ridge system

    (S'S + λ N W) β = S'y,    W = Diagonal((1 .+ norm.(ω) .^ 2) .^ r),

where `N` is the number of samples (rows of `S`) and `ω` holds one frequency
vector per column of `S`. For any nonzero frequency, a larger `r` makes its
penalty grow faster with `norm(ω_k)`. Returns `β`.
"""
function reg_β_solver!(β, S, y, λ, ω, r)
    N = size(S, 1)  # equals length(y) for a vector right-hand side
    # Frequency-dependent ridge weights. `Diagonal` avoids materializing a dense K×K matrix.
    weights = (1 .+ norm.(ω) .^ 2) .^ r
    β .= (S' * S + λ * N * Diagonal(weights)) \ (S' * y)
    return β
end

r = 1;  # exponent of the (1 + |ω|²) frequency weight in the ridge penalty

# Bind λ and r by value. A closure over the untyped globals would look them up
# dynamically on every solve and silently pick up later reassignments.
β_solver! = let λ = λ, r = r
    (β, S, y, ω) -> reg_β_solver!(β, S, y, λ, ω, r)
end;

opts = ARFFOptions(n_epochs, n_ω_steps, δ, n_burn, γ, ω_max,
    adapt_covariance, β_solver!, ARFF.mse_loss);

In [None]:
# Train a copy so F0 stays pristine for re-runs (train_rwm! mutates its model
# argument, per Julia's `!` convention).
F = deepcopy(F0);
Random.seed!(1000);
(Σ_mean, acceptance_rate, loss) = train_rwm!(F, data, Σ0, opts; show_progress = true);