In [None]:
using Logging
# log records go to stdout; the commented-out logger below would log to stderr at Debug level
logger = ConsoleLogger(stdout)
# debuglogger = ConsoleLogger(stderr, Logging.Debug)
global_logger(logger)
using JLD  # `save` for the .jld result files written by later cells

using Random
# shared RNG; the dataset and initialization cells re-seed it with Random.seed!(rng, 1234)
rng = MersenneTwister(1234)
import Dates

In [None]:
using Plots
pyplot()

In [None]:
using ForwardDiff
using ProgressMeter
using LinearAlgebra: dot
include("utils/misc.jl")
using .MiscUtils: myCircle, mySphere

In [None]:
ts = Dates.now()

## model parameters
input_dim = 2
bias = true
m = 50  # nb hidden neurons for each sign
N = 5   # nb samples
activation_str = "ReLUcub"

## logging
# timestamp formatted without ':' so it can be used in a directory name
ts_fsfriendly = Dates.format(ts, "yyyy-mm-ddTHHMMSS")
resultdir = mkpath("results/maxF1margin__dim$(input_dim)bias$(bias)__m$(m)__N$(N)__$(activation_str)__$(ts_fsfriendly)")

# start the run log with the model parameters (one "name=value" / "name: value" per line)
logfile = "$resultdir/log.txt"
touch(logfile)
open(logfile, "a") do io
    println(io, "input_dim=$input_dim")
    println(io, "bias: $bias")
    println(io, "m=$m")
    println(io, "N=$N")
    println(io, "activation: $activation_str")
end

In [None]:
alpha = 8
scaling = 1 # unknown what the "correct" scaling typically is!

## algo parameters
# nb of iterations, divided by alpha^scaling while the stepsizes below grow proportionally to alpha
T = floor(Int, 4000 / alpha^scaling)
eta0_a = 1e-1 * alpha     # initial stepsize for the adversary weights a
eta0_w = 1e-1 * alpha     # initial stepsize for the network weights w
eta0_theta = 1e-2 * alpha # initial stepsize for the neuron positions theta
constrain_theta = true    # project each neuron position back onto the unit l2 sphere after a step
extrasteps = 2 # extrasteps=1: CP-MDA, extrasteps=2: CP-MP

# append the algorithm parameters to the run log
open(logfile, "a") do io
    for line in ("T=$T",
                 "eta0_a=$(eta0_a)",
                 "eta0_w=$(eta0_w)",
                 "eta0_theta=$(eta0_theta)",
                 "constrain_theta: $(constrain_theta)",
                 "extrasteps=$(extrasteps)")
        println(io, line)
    end
end

In [None]:
## plotting parameters
TNI_min = 1 # plot NI error
TNI_max = T
# evaluate the NI error at ~50 iterates; clamped to >= 1 because a step of 0 would make the
# range TNI_min:evalNIevery:TNI_max+1 invalid whenever T < 51
evalNIevery = max(1, (TNI_max - TNI_min) ÷ 50)
Ntheta = Int(1e4) # requested nb of grid points passed to glob_NI_err
skip_avg = true # avg not implemented

# the decision-region plots evaluate the network on a 2D grid of inputs (x1, x2)
input_dim == 2 || error("the decision-region plots require input_dim == 2, got input_dim=$input_dim")
r = range(-2.2, 2.2, length=101)

Tplotreg_min = 1 # plot decision regions gif
Tplotreg_max = T
# one gif frame every ~1/50 of the run (at least every iterate, see evalNIevery)
plotregevery = max(1, (Tplotreg_max - Tplotreg_min) ÷ 50)
hidetitle = false
hidelabel = false
skip_gifs = true

In [None]:
# Select the scalar activation by name. Each branch defines the generic function
# `activation`, which the network code (gfun, logit) calls.
if activation_str == "ReLU"
    activation(x) = max(0, x)
elseif activation_str == "abs"
    activation(x) = abs(x)
elseif activation_str == "ReLUsq"
    activation(x) = max(0, x)^2
elseif activation_str == "ReLUcub"
    activation(x) = max(0, x)^3
elseif activation_str == "ReLUquar"
    activation(x) = max(0, x)^4
# (a plain polynomial activation such as x^2 is intentionally not offered: "polynomial activation is silly")
elseif activation_str == "sigmoid"
    activation(x) = 1/(1+exp(-x))
elseif activation_str == "tanh"
    activation(x) = tanh(x)
else
    error("activation_str not recognized (got \"$activation_str\"), should be one of " *
          "\"ReLU\", \"abs\", \"ReLUsq\", \"ReLUcub\", \"ReLUquar\", \"sigmoid\" or \"tanh\"")
end

# parameter dimension: with a bias, every input gets an extra constant coordinate
# (makedataset sets the last row of x to 1)
d = bias ? input_dim + 1 : input_dim

In [None]:
"""
    makedataset(d, N; bias, rng=MersenneTwister(1234))

Draw a random dataset of `N` labelled points in dimension `d`.

Returns `(x, y)`:
- `x`: `d × N` matrix of standard normal entries; when `bias` is true its last row is set
  to 1, acting as a constant bias feature.
- `y`: length-`N` vector of labels `sign(u - 0.5)` with `u` uniform on [0, 1), i.e. -1 or
  +1 with equal probability.

`x` is drawn from `rng` before `y`.
"""
function makedataset(d, N; bias, rng=MersenneTwister(1234))
    inputs = randn(rng, d, N)
    bias && fill!(view(inputs, d, :), 1)
    labels = map(u -> sign(u - 0.5), rand(rng, N))
    return inputs, labels
end

Random.seed!(rng, 1234)
x, y = makedataset(d, N; bias=bias, rng=rng)
save("$resultdir/dataset.jld", "x", x, "y", y)

# scatter the first two coordinates: green circles for label +1, red triangles for label -1
let ispos = y .== 1, isneg = y .== -1
    scatter(x[1, ispos], x[2, ispos]; m=:circ, markersize=8, label="+", color=:green)
    scatter!(x[1, isneg], x[2, isneg]; m=:utriangle, markersize=8, label="-", color=:red)
end

In [None]:
## Problem solved (cf Appdx A of Chizat21):
##     max_nu min_i int_Theta yi*phi(theta, xi) dnu(theta)
##         subj. to nu in M(Theta) and ||nu||=1
##   = min_a max_nu sum_i int_{Theta+ U Theta-} ai yi*phitilde(theta, xi) dnu(theta)
##         subj. to nu in P(Theta+ U Theta-) and a in P([N]),
## where phitilde(theta, x) = phi(theta, x) if theta in Theta+ and -phi(theta, x) if theta in Theta-.
# The network is parametrized such that the first m neurons are positively weighted and the last m are negatively weighted.

"""
    gfun(i, theta)

Payoff `y[i] * activation(<theta, x[:, i]>)` of a single neuron with parameters `theta` on
sample `i`. Reads the global dataset `x`, `y` and the global `activation`.
"""
function gfun(i, theta)
    xi = x[:, i]
    return y[i] * activation(dot(theta, xi))
end

"""
    logit(x, w, theta)

Real-valued network output at input `x`:

    sum_{j <= m} w[j] * activation(<theta[:, j], x>) - sum_{m < j <= 2m} w[j] * activation(<theta[:, j], x>)

i.e. the first `m` neurons (columns of `theta`) are positively weighted and the last `m`
negatively weighted. `m = length(w) ÷ 2` is inferred from the weights, so `w` must have
length `2m` and `theta` at least `2m` columns.
"""
function logit(x, w, theta)
    m = length(w) ÷ 2
    # start from a zero of the promoted element type so `out` keeps a single type in the loops
    out = zero(promote_type(eltype(w), eltype(theta), eltype(x)))
    for j in 1:m
        out += w[j] * activation(dot(view(theta, :, j), x))
    end
    for j in m+1:2m
        out += -w[j] * activation(dot(view(theta, :, j), x))
    end
    return out
end

# predict(x1, x2; w, theta, hard=true): network output on the 2D input (x1, x2), with the
# constant bias coordinate 1 appended when `bias` is set. `hard=true` returns sign(logit),
# `hard=false` the raw logit.
if bias
    function predict(x1, x2; w, theta, hard=true)
        score = logit([x1, x2, 1], w, theta)
        return hard ? sign(score) : score
    end
else
    function predict(x1, x2; w, theta, hard=true)
        score = logit([x1, x2], w, theta)
        return hard ? sign(score) : score
    end
end

In [None]:
Random.seed!(rng, 1234)
# adversary weights: uniform over the N samples
a = ones(N) ./ N
# network nu = (w, theta): uniform weights over the 2m neurons; positions drawn Gaussian,
# then each column (one neuron) rescaled onto the unit l2 sphere
w = ones(2m) ./ (2m)
theta = randn(rng, d, 2m)
for col in eachcol(theta)
    col ./= sqrt(sum(col .^ 2))
end

# buffers for the extragradient ("ghost") iterate
ap = similar(a)
wp = similar(w)
thetap = similar(theta)
# iterate history: index t holds the iterate before step t, index T+1 the final iterate
copies_a = Array{Float64}(undef, N, T+1)
copies_w = Array{Float64}(undef, 2m, T+1)
copies_theta = Array{Float64}(undef, d, 2m, T+1)
;

In [None]:
## plot the decision region at random initialization https://discourse.julialang.org/t/plotting-decision-boundary-regions-for-classifier/21397
# filled contour of the soft (real-valued) network output over the grid r × r
let softpred = (x1, x2) -> predict(x1, x2; w=w, theta=theta, hard=false)
    contour(r, r, softpred; f=true)
end

In [None]:
"""
Take a CP gradient step
- starting from a, w, theta
- evaluating the gradients at ap, wp, thetap
- with stepsizes eta_a, eta_w, eta_theta
(taking care of the fact that the network is parametrized such that the first m neurons are positively weighted and the last m are negatively weighted)
Returns the updated particles a1, w1, theta1
"""
function step_CPMDA(
        f,
        a, w, theta,
        ap, wp, thetap,
        eta_a, eta_w, eta_theta;
        true_prox=false,
        constrain_theta=true
)
    N = length(a)
    d = size(theta)[1]
    m = Int(size(theta)[2] / 2)
    
    neuronsigns = ones(2m)
    neuronsigns[m+1:2m] .= -1

    Dfp = Array{Float64}(undef, d, N, 2m) # gradient of f w.r.t theta at thetap
    for i=1:N, j=1:2m
        Dfp[:,i,j] = ForwardDiff.gradient(tt -> f(i,tt), thetap[:,j])
    end

    # take step: adversary
    s = Array{Float64}(undef, N)
    for i=1:N
        s[i] = sum( neuronsigns[j] * wp[j] * f(i, thetap[:,j]) for j=1:2m )
    end
    if eta_a == Inf
        a1 = zeros(N)
        a1[argmax(s)] = 1.
    else
        a1 = a .* exp.(-eta_a * s)
        a1 ./= sum(a1)
    end
    
    # take step: network
    w1 = similar(w)
    theta1 = similar(theta)
    for j=1:2m
        s = sum( ap[i] * f(i, thetap[:,j]) for i=1:N )
        w1[j] = w[j] * exp(eta_w*neuronsigns[j]*s)
    end
    w1 ./= sum(w1)
    for j=1:2m
        s = sum( ap[i] * Dfp[:,i,j] for i=1:N )
        @assert size(s) == (d,)
        if true_prox
            theta1[:,j] = theta[:,j] + eta_theta * neuronsigns[j] * wp[j] / w[j] * s
        else
            theta1[:,j] = theta[:,j] + eta_theta * neuronsigns[j] * s
        end
    end
    # retract (just project) theta1 back to unit l2 sphere
    if constrain_theta
        for j=1:2m
            theta1[:,j] ./= sqrt(sum(theta1[:,j].^2))
        end
    end
    
    return a1, w1, theta1
end


In [None]:
eta_a, eta_w, eta_theta = eta0_a, eta0_w, eta0_theta # constant stepsizes

@showprogress 1 for t in 1:T # minimum update interval of 1 second
    # explicit `global`: in a notebook (interactive soft scope) these assignments already
    # update the globals, but when this file is run as a script they would otherwise
    # create new, initially undefined locals
    global a, w, theta, ap, wp, thetap

    # record the iterate before step t
    copies_a[:, t] = a
    copies_w[:, t] = w
    copies_theta[:, :, t] = theta

    # extragradient ("ghost" step): every substep restarts from (a, w, theta) and evaluates
    # the gradients at the previous ghost iterate
    # extrasteps=1: CP-MDA, extrasteps=2: CP-MP
    ap, wp, thetap = a, w, theta
    for _ in 1:extrasteps
        ap, wp, thetap = step_CPMDA(
            gfun,
            a, w, theta,
            ap, wp, thetap,
            eta_a, eta_w, eta_theta;
            true_prox=true,
            constrain_theta=constrain_theta,
        )
    end

    # take step
    a, w, theta = ap, wp, thetap
end

# final iterate
copies_a[:, T+1] = a
copies_w[:, T+1] = w
copies_theta[:, :, T+1] = theta

af, wf, thetaf = a, w, theta
save("$resultdir/iterates.jld", "copies_a", copies_a, "copies_w", copies_w, "copies_theta", copies_theta)

In [None]:
# final adversary weights over the N samples (shown as the cell output)
af

In [None]:
## plot the NI error
"""
    glob_NI_err(gfun, a, w, theta; Ntheta=Int(1e4))

Compute the "global" Nikaido-Isoda (NI) error of the iterate `(a, nu = (w, theta))`:

    max_{a0, nu0} <a|F|nu0> - <a0|F|nu> = max_theta <a|F|delta_theta> - min_i <ei|F|nu>

# How the max over `theta` is approximated
- It is taken over the grid returned by `myCircle(Ntheta)` (`d == 2`) or
  `mySphere(Ntheta)` (`d == 3`); presumably unit-norm points, matching the sphere
  constraint on neuron positions.
- It covers both neuron signs: `max(max_theta s, -min_theta s)` with
  `s(theta) = sum_i a[i] * gfun(i, theta)`.

# Shapes
`w` has length `2m` and `theta` is `d × 2m`. The first `m` neurons are positively
weighted, the last `m` negatively weighted.

Raises an error unless `d` is 2 or 3.
"""
function glob_NI_err(gfun, a, w, theta; Ntheta=Int(1e4))
    N = length(a)
    d = size(theta, 1)
    m = Int(size(theta, 2) / 2)

    # grid of candidate single-neuron positions (one per column) and its size
    if d == 2
        thetagrid = myCircle(Ntheta)
        ngrid = size(thetagrid, 2)
    elseif d == 3
        thetagrid, ngrid = mySphere(Ntheta)
    else
        error("glob_NI_err only implemented for d=2 or d=3, got d=$d")
    end

    # range of s(theta0) = sum_i a[i] * gfun(i, theta0) over the grid
    maxtheta = -Inf
    mintheta = +Inf
    for k in 1:ngrid
        theta0 = thetagrid[:, k]
        s = 0.0
        for i in 1:N
            s += a[i] * gfun(i, theta0)
        end
        maxtheta = max(maxtheta, s)
        mintheta = min(mintheta, s)
    end

    # <ei|F|nu>: payoff of sample i against the signed network
    Fnu(i) = sum(w[j] * gfun(i, theta[:, j]) for j in 1:m) - sum(w[j] * gfun(i, theta[:, j]) for j in m+1:2m)
    mini = minimum(Fnu(i) for i in 1:N)
    # a negatively weighted neuron at theta0 achieves -s(theta0), hence -mintheta
    return max(maxtheta, -mintheta) - mini
end

In [None]:
# iterations at which the NI error is evaluated; index t refers to the iterate stored at
# copies_a[:, t], copies_w[:, t], copies_theta[:, :, t] (t = T+1 is the final iterate)
tNI = TNI_min:evalNIevery:TNI_max+1
# only the entries at tNI (and T+2, reserved for the averaged iterate) are filled
nierrs = Array{Float64}(undef, T+2)
@showprogress 1 for t in tNI # minimum update interval of 1 second
    nierrs[t] = glob_NI_err(gfun, copies_a[:, t], copies_w[:, t], copies_theta[:, :, t]; Ntheta=Ntheta)
end
if !skip_avg
    # NOTE(review): averaging is not implemented, so avg_a, avg_w, avg_theta must be
    # defined before setting skip_avg = false
    nierrs[T+2] = glob_NI_err(gfun, avg_a, avg_w, avg_theta; Ntheta=Ntheta)
end
open(logfile, "a") do f
    for t in tNI
        write(f, "glob_NI_err at iteration#$t: $(nierrs[t])\n")
    end
    if !skip_avg
        write(f, "glob_NI_err at avg iterate: $(nierrs[T+2])\n")
    end
end
save("$resultdir/nierrs__every$(evalNIevery)__t=$(TNI_min)--$(TNI_max).jld", "nierrs", nierrs)

plt_NI = plot(tNI, nierrs[tNI], xlabel="k", label="")
if !skip_avg
    hline!([nierrs[T+2]], label=(hidelabel ? "" : "avg iterate"))
end
if !hidetitle
    title!("NI error of iterates")
end
fn = "$resultdir/NI_errors.png"
savefig(plt_NI, fn)

# Floor for the log-scale plot. The NI error is only approximated (finite grid of theta in
# glob_NI_err), so it may come out as 0 or slightly negative; clamp at 0, then shift by a
# tiny constant. Named `logfloor` rather than `eps` so as not to shadow Base.eps.
logfloor = 1e-10
plt_NI_log = plot(tNI, logfloor .+ max.(0, nierrs[tNI]), xlabel="k", label="", yscale=:log10)
if !skip_avg
    hline!([logfloor + max(0, nierrs[T+2])], label=(hidelabel ? "" : "avg iterate"))
end
if !hidetitle
    title!("NI error of iterates (log-linear scale)")
end
fn = "$resultdir/NI_errors_logscale.png"
savefig(plt_NI_log, fn)

plt_NI_log

In [None]:
## decision region at the last iterate
# filled contour of the soft network output, its zero level set (dashed), and the dataset
pltt = let softpred = (x1, x2) -> predict(x1, x2; w=w, theta=theta, hard=false)
    canvas = contour(r, r, softpred; f=true)
    contour!(canvas, r, r, softpred;
        levels=[0.],
        seriescolor=:blues,
        linestyle=:dash,
        linewidth=3)
    ispos, isneg = y .== 1, y .== -1
    scatter!(canvas, x[1, ispos], x[2, ispos]; m=:circ, markersize=8, label="+", color=:green)
    scatter!(canvas, x[1, isneg], x[2, isneg]; m=:utriangle, markersize=8, label="-", color=:red)
    canvas
end

fn = "$resultdir/contour_soft_lastiter.png"
savefig(pltt, fn)
pltt

In [None]:
## plot the neurons (are they sparsely concentrated?)
"""
    plot_neurons_2d(w, theta; resultdir=nothing)

Draw every neuron `j` as a segment from the origin to `w[j] * theta[:, j]` (`theta` must
have 2 rows), plus a marker at its tip.
- The first half of the neurons (positively weighted) are red, the second half
  (negatively weighted) blue.
- They are drawn over a circle of radius `1/m`, with `m = length(w)/2`.

When `resultdir` is given, the figure is also saved as `neurons_lastiter.png` in that
directory. Returns the plot.
"""
function plot_neurons_2d(w, theta; resultdir=nothing)
    dim = size(theta)[1]
    half = Int(length(w) / 2)
    @assert dim == 2
    ring = (1 / half) .* myCircle(500)
    plt = plot(ring[1, :], ring[2, :], aspect_ratio=1.0, label=false)
    # tip of neuron j: its position scaled by its weight
    tips = [w[j] * theta[k, j] for k in 1:2, j in 1:2half]
    for (colour, neurons) in ((:red, 1:half), (:blue, half+1:2half))
        for j in neurons
            plot!([0, tips[1, j]], [0, tips[2, j]], color=colour, label=false)
        end
    end
    pos, neg = 1:half, half+1:2half
    scatter!(tips[1, pos], tips[2, pos], markersize=4, markercolor=:red, label=false)
    scatter!(tips[1, neg], tips[2, neg], markersize=4, markercolor=:blue, label=false)
    if resultdir !== nothing
        savefig(plt, "$resultdir/neurons_lastiter.png")
    end
    return plt
end

"""
    plot_neurons_fake3d(w, theta; hidetitle=false, resultdir=nothing)

Visualize 3D neurons through their three coordinate-plane projections (dropping theta_1,
theta_2, theta_3 in turn), each drawn with `plot_neurons_2d`, plus a 1×3 combined figure.
When `resultdir` is given, all four figures are saved there.
Returns `(plt_e1, plt_e2, plt_e3, plt_combined)`.
"""
function plot_neurons_fake3d(w, theta; hidetitle=false, resultdir=nothing)
    # (rows of theta kept, x label, y label, title, file name) for each projection
    projections = (
        ([2, 3], "theta_2", "theta_3", "neurons projected along axis theta_1", "neurons_lastiter_projtheta1.png"),
        ([1, 3], "theta_1", "theta_3", "neurons projected along axis theta_2", "neurons_lastiter_projtheta2.png"),
        ([1, 2], "theta_1", "theta_2", "neurons projected along axis theta_3", "neurons_lastiter_projtheta3.png"),
    )
    panels = map(projections) do proj
        rows, xlab, ylab, ttl, _ = proj
        pnl = plot_neurons_2d(w, theta[rows, :])
        xlabel!(xlab)
        ylabel!(ylab)
        hidetitle || title!(ttl)
        pnl
    end
    plt_combined = plot(panels...; layout=(1, 3), size=(1500, 500))
    if resultdir !== nothing
        for (pnl, proj) in zip(panels, projections)
            savefig(pnl, "$resultdir/$(proj[5])")
        end
        savefig(plt_combined, "$resultdir/neurons_lastiter.png")
    end
    return panels[1], panels[2], panels[3], plt_combined
end

In [None]:
## neurons at last iterate
# d == 2: a single planar plot; d == 3: three planar projections plus a combined figure
if d == 2
    plt_combined = plot_neurons_2d(copies_w[:, T+1], copies_theta[:, :, T+1]; resultdir=resultdir)
elseif d == 3
    plt_e1, plt_e2, plt_e3, plt_combined = plot_neurons_fake3d(
        copies_w[:, T+1], copies_theta[:, :, T+1];
        hidetitle=hidetitle, resultdir=resultdir)
end
plt_combined

In [None]:
# final network weights w (shown as the cell output)
copies_w[:,T+1]

In [None]:
## plot the decision region across iterations (gif)
if !skip_gifs
    frame_ts = Tplotreg_min:plotregevery:Tplotreg_max
    p = Progress(length(frame_ts))

    anim = @animate for t in frame_ts
        # slice iterate t once per frame; indexing copies_* inside the closure would copy it
        # again for every one of the grid evaluations done by each contour call
        wt = copies_w[:, t]
        thetat = copies_theta[:, :, t]
        softpred = (x1, x2) -> predict(x1, x2; w=wt, theta=thetat, hard=false)
        contour(r, r, softpred, f=true)
        contour!(r, r, softpred,
            levels=[0.],
            seriescolor=:blues,
            linestyle=:dash,
            linewidth=3)
        scatter!(x[1, y.==1], x[2, y.==1], m=:circ, markersize=8, label="+", color=:green)
        scatter!(x[1, y.==-1], x[2, y.==-1], m=:utriangle, markersize=8, label="-", color=:red)
        title!("Iteration $t")
        next!(p)
    end

    fn = "$resultdir/contour_soft__every$(plotregevery)__$Tplotreg_min-$Tplotreg_max.gif"
    gif(anim, fn, fps=10)

    # keep the individual frames next to the gif
    cp(anim.dir, resultdir*"/frames_contour_soft__every$(plotregevery)__$Tplotreg_min-$Tplotreg_max", force=true)
end