In [None]:
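# One-time setup: uncomment to install the packages used in this notebook.
# import Pkg
# Pkg.add("LinearRegression")
# Pkg.add("JLD")
# Pkg.add("PyPlot")
# Pkg.add("Colors")
# Pkg.add("LaTeXStrings")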

In [None]:
using Random
rng = MersenneTwister(1234)
import Dates
using JLD
import LinearAlgebra as LA
import LinearRegression as LinReg

In [None]:
using PyPlot
using Colors
using LaTeXStrings

In [None]:
"""
    eta_opt_simGDA(M)

Step-size bound for simultaneous GDA on the linear operator `M`: the smallest real part
among the reciprocals `1/λ` of the eigenvalues `λ` of `M`.
Reference: Gidel et al. '19, "Negative momentum...", Thm 2.
"""
function eta_opt_simGDA(M)
    reciprocals = map(λ -> 1 / λ, LA.eigvals(M))
    return minimum(real, reciprocals)
end

"""
    eta_opt_altGDA(M, eta0, cnt_max=10)

Find a step size for alternating GDA by repeated halving. Starting from `eta0`, the step
is halved while the spectral radius of the approximate alternating-GDA iteration matrix
`I - η M + (η²/2) A M`, with `A = (M - M')/2`, is `>= 1`.

At most `cnt_max - 1` halvings are performed. If that budget runs out, a warning is
logged and the last step size is returned even though its radius may still be `>= 1`.
"""
function eta_opt_altGDA(M, eta0, cnt_max=10)
    skew = (M - M') / 2
    n = size(M, 1)
    # NB: this matrix is only an approximation of the true alternating-GDA Jacobian
    radius(h) = maximum(abs.(LA.eigvals(LA.I(n) .- h * M .+ (h^2 / 2) * skew * M)))

    stepsize = eta0
    tries = 0
    while radius(stepsize) >= 1
        tries += 1
        tries < cnt_max || break
        stepsize /= 2
    end
    tries == cnt_max && @warn("cnt == cnt_max")
    return stepsize
end

# Smallest real part among the eigenvalues of `M`.
minRe_exact(M) = minimum(λ -> real(λ), LA.eigvals(M))

"""
    rho_altGDA(M, eta)

Spectral radius of the approximate alternating-GDA iteration matrix
`I - eta*M + (eta^2/2) * A * M`, where `A = (M - M')/2` is the skew part of `M`.
"""
function rho_altGDA(M, eta)
    skew = (M - M') / 2
    iter_mat = LA.I(size(M, 1)) .- eta * M .+ (eta^2 / 2) * skew * M
    return maximum(abs, LA.eigvals(iter_mat))
end

In [None]:
Random.seed!(rng, 1234)

"""
    run_altGDA_dists(P, alpha, x0, y0, eta, T)

Run `T` steps of alternating GDA on the bilinear game with coupling matrix `P`. The first
coordinate of `x` additionally takes the regularization step `-eta*alpha*x[1]`, which
matches `M[1,1] = alpha`.

Return a vector of length `T+1` whose entry `t` is `‖x‖ + ‖y‖` (Euclidean norms) before
step `t`; entry `T+1` is the value after the last step. `x0` and `y0` are not mutated.
The hot loop lives in a function so it is compiled with concrete types instead of
operating on untyped globals.
"""
function run_altGDA_dists(P, alpha, x0, y0, eta, T)
    dists = zeros(T + 1)
    x, y = x0, y0
    for t in 1:T
        dists[t] = sqrt(sum(x .^ 2)) + sqrt(sum(y .^ 2))
        x1prev = x[1]
        x = x .- eta .* P * y                 # fresh vector, so x0 is never written to
        x[1] = x[1] - eta * alpha * x1prev
        y = y .+ eta .* P' * x                # uses the already-updated x (alternating)
    end
    dists[T + 1] = sqrt(sum(x .^ 2)) + sqrt(sum(y .^ 2))
    return dists
end

dimx, dimy = 2, 2  # (need to be equal for the NE to be unique)
alpha_s = 10. .^ (-3:0.25:1)
N = 10
T_max = Int(4e7)
T0 = 8000  # must be large enough to capture the regime where the dynamics is driven by the least eigenvalue, with all other eigenspaces fully fitted
etamult = 0.1  # safety margin to ensure small-stepsize regime
tol_machine = 5e-16  # distances below this are treated as machine precision; the fit stops there
P_s = zeros(dimx, dimy, N)
eta_s = zeros(length(alpha_s), N)
rate_s, linreg_relerr_s = zeros(length(alpha_s), N), zeros(length(alpha_s), N)
minRe_s, rho_s = zeros(length(alpha_s), N), zeros(length(alpha_s), N)

for n=1:N
    P = randn(rng, dimx, dimy)
    P_s[:,:,n] = P
    x0, y0 = randn(rng, dimx) .* 1e-3, randn(rng, dimy) .* 1e-3
    M = zeros(dimx+dimy, dimx+dimy)
    M[1:dimx, dimx+1:end] = P
    M[dimx+1:end, 1:dimx] = -P'

    for (i, alpha) in enumerate(alpha_s)
        # very weak regularization converges slowly: simulate longer
        if alpha <= 4e-3   T0 = 12000 else T0 = 8000 end

        M[1,1] = alpha
        eta_opt = eta_opt_altGDA(M, 1/LA.opnorm(P))
        eta = etamult * eta_opt
        eta_s[i,n] = eta

        minRe_s[i, n] = minRe_exact(M)
        rho_s[i,n] = rho_altGDA(M, eta)

        T = min(T_max, Int(floor(T0/eta)))

        # (Simultaneous GDA variant, not run here: z = vcat(x0, y0); z ← z - eta*M*z,
        # recording sum(z .^ 2) at each step.)

        ## Alternating GDA
        dists = run_altGDA_dists(P, alpha, x0, y0, eta, T)

        # Fit log(dist) ≈ intercept + slope*t on the last quarter of the run, truncated at
        # the first iterate that reaches machine precision.
        T_machineprecision = findfirst(<(tol_machine), dists)
        T = isnothing(T_machineprecision) ? T : T_machineprecision
        T_burnin = max(1, floor(Int, 3*T/4))
        T_rg = collect(T_burnin:T)
        logd = log.(dists[T_rg])
        slope, intercept = LinReg.coef(LinReg.linregress(T_rg, logd))
        rate_s[i,n] = -slope
        linreg_err = sum((intercept .+ T_rg * slope .- logd) .^ 2) / length(T_rg)
        # RMS residual relative to |mean log-distance|. The log-distances are negative
        # (dists < 1), so use abs to keep the relative error positive.
        linreg_relerr_s[i,n] = sqrt(linreg_err) / abs(sum(logd) / length(T_rg))
    end
end
rate_normalized_s = (1 .- exp.(-rate_s)) ./ eta_s
linreg_relerr_s

In [None]:
# Persist the arrays computed by the experiment cell to a JLD file, so the plots below can
# be regenerated without rerunning the simulation loop. Note: rate_normalized_s is not
# stored; it is recomputed from rate_s and eta_s when reloading (see below).
save("reg_bilin_data.jld", 
    "alpha_s", alpha_s, 
    "P_s", P_s,
    "eta_s", eta_s,
    "rate_s", rate_s, 
    "linreg_relerr_s", linreg_relerr_s, 
    "minRe_s", minRe_s, 
    "rho_s", rho_s)

# To reload previously saved results instead of recomputing them, uncomment:
# d = load("reg_bilin_data.jld")
# alpha_s = d["alpha_s"]
# eta_s   = d["eta_s"]
# rate_s  = d["rate_s"]
# minRe_s = d["minRe_s"]
# rho_s   = d["rho_s"]
# rate_normalized_s = (1 .- exp.(-rate_s)) ./ eta_s

In [None]:
# For each random instance n, compare the observed normalized rate r/η (markers) with the
# smallest real part of the eigenvalues of M (lines), as functions of alpha.
cm = get_cmap(:tab20)
colorrange = (0:19) ./ 20
figure(figsize=[5,3])
for n in 1:N
    # tab20 has 20 colors: wrap around rather than index past the end when N > 20
    c = cm(colorrange[mod1(n, length(colorrange))])
    loglog(alpha_s, rate_normalized_s[:,n], "o", lw=3,
        label=(n==1 ? L"$r/\eta$ (observed)" : ""),
        color=c)
    loglog(alpha_s, minRe_s[:,n], lw=1,
        label=(n==1 ? L"$\tilde{\mu}_M$" : ""),
        color=c)
end
legend()
xlabel(L"\alpha")
grid("on")
savefig("reg_bilin_rates__dim$(dimx)$(dimy).png", bbox_inches="tight", dpi=200)

In [None]:
"""
    spMapprox(S, A, alpha; ord=1, tol_eigvals_sep=1e-5)

Perturbative approximation (order `ord` ∈ {1, 2} in `alpha`) of the eigenvalues of
`A + alpha*S`.

Assumes `A` is real antisymmetric, so that `im*A` is Hermitian, and `S` is real
symmetric (TODO confirm for other uses). The second-order term uses
`|v_i' S v_j|^2`, which relies on `S` being Hermitian.

Entry `i` of the result is the approximation of the eigenvalue that continues the `i`-th
eigenpair `(λ_i, v_i)` of `A`, where `λ_i = -im*μ_i` and `μ_i` is the i-th (ascending)
eigenvalue of `im*A`.

Throws `ArgumentError` if the eigenvalues of `A` are not simple (spacing
`<= tol_eigvals_sep`) or if `ord` is not 1 or 2.
"""
function spMapprox(S, A, alpha; ord=1, tol_eigvals_sep=1e-5)
    dim = size(A, 1)
    # im*A is Hermitian, so `eigen` uses the Hermitian solver: real mus, unitary eigvecs
    mus, eigvecs = LA.eigen(im * A)
    # (im*A) v = mu v  ⟹  A v = -im*mu v
    eigvals = -im .* mus
    if dim > 1
        mus_sep = minimum(diff(sort(mus)))
        mus_sep > tol_eigvals_sep ||
            throw(ArgumentError("the eigenvalues of A are not simple, not implemented"))
    end

    # first order: λ_i + alpha * v_i' S v_i
    spM01 = [eigvals[i] + alpha * eigvecs[:,i]' * S * eigvecs[:,i] for i in 1:dim]
    ord == 1 && return spM01
    ord == 2 || throw(ArgumentError("ord=$ord not implemented (only 1 or 2)"))

    # second order: alpha^2 * Σ_{j≠i} |v_i' S v_j|^2 / (λ_i - λ_j)
    spM2 = zeros(eltype(spM01), dim)
    for i in 1:dim, j in 1:dim
        i == j && continue
        spM2[i] += alpha^2 * abs2(eigvecs[:,i]' * S * eigvecs[:,j]) / (eigvals[i] - eigvals[j])
    end
    return spM01 .+ spM2
end

In [None]:
Random.seed!(rng, 1234)

dimx, dimy = 3, 3  # (need to be equal for the NE to be unique)
P = randn(rng, dimx, dimy)
alpha_s = 10. .^ (-3:1:0)
# concrete element type: the abstract `Complex` eltype would box every entry
spM_s       = zeros(ComplexF64, dimx+dimy, length(alpha_s))  # exact eigenvalues
spMapprox_s = zeros(ComplexF64, dimx+dimy, length(alpha_s))  # 2nd-order approximations

# A: antisymmetric bilinear-game operator; S1: regularization acting on the first coordinate only
A = zeros(dimx+dimy, dimx+dimy)
A[1:dimx, dimx+1:end] = P
A[dimx+1:end, 1:dimx] = -P'
S1 = zeros(dimx+dimy, dimx+dimy)
S1[1,1] = 1
for (i, alpha) in enumerate(alpha_s)
    spM_s[:, i] = LA.eigvals(A .+ alpha*S1)
    spMapprox_s[:, i] = spMapprox(S1, A, alpha; ord=2)
end

## axis limits for the spectrum plot (the x axis is log-scaled, so use multiplicative margins)
all_eigs = vcat(spM_s, spMapprox_s)
maxIm = maximum(abs.(imag.(all_eigs)))
ymin, ymax = maxIm .* (-1.3, 1.3)
xmin, xmax = minimum(real.(all_eigs)) / 2,  maximum(real.(all_eigs)) * 40

In [None]:
# Plot exact eigenvalues (blue) against their 2nd-order approximations (gray) in the
# complex plane, one marker shape per alpha.
markers = ["o", "D", "P", "X", "s"]
figure(figsize=[5,3])
plot([xmin,xmax], [0, 0], label="", color=:red)
for (i, alpha) in enumerate(alpha_s)
    # cycle marker shapes instead of indexing past the end when there are more alphas
    mk = markers[mod1(i, length(markers))]
    # alpha_s are powers of ten; round instead of `Integer` so a log10 result off by one
    # ulp cannot throw an InexactError
    p = round(Int, log10(alpha))
    semilogx(real.(spMapprox_s[:,i]), imag.(spMapprox_s[:,i]), mk, markersize=11, lw=3,
        color=:gray, alpha=0.7)
    semilogx(real.(spM_s[:,i]), imag.(spM_s[:,i]), mk, markersize=7, lw=3,
        label=(i == 1 ? L"\alpha=10^{%$(p)}" : L"~~~~~~10^{%$(p)}"),
        color=:blue, alpha=0.7)
end
legend()
xlabel(L"\Re(\lambda)")
ylabel(L"\Im(\lambda)")
xlim([xmin, xmax])
ylim([ymin, ymax])
grid("on")
savefig("reg_bilin_spM__dim$(dimx)$(dimy).png", bbox_inches="tight", dpi=200)