## Particle Swarm Optimization Visualizer
---
Kyle Spurlock

12/11/2023

# Benchmark Functions

In [1]:
"""
    rastrigin(x)

Rastrigin benchmark function `10n + Σᵢ (xᵢ² - 10 cos(2π xᵢ))` with `n = length(x)`.
The global minimum is `0` at the origin. For 2-D input the constant term is `20`.
"""
function rastrigin(x)
    # The cosine term is *subtracted*; adding it moves the minimum away from the origin.
    return 10 * length(x) + sum(@. x^2 - 10 * cos(2π * x))
end

"""
    schaffer_n4(x)

Schaffer function N. 4 evaluated on the first two coordinates of `x`:

    0.5 + (cos²(sin(|x₁² - x₂²|)) - 0.5) / (1 + 0.001 (x₁² + x₂²))²

The global minimum is ≈ 0.292579 at (0, ±1.25313) and (±1.25313, 0).
"""
function schaffer_n4(x)
    x1, x2 = x[1], x[2]
    numerator = cos(sin(abs(x1^2 - x2^2)))^2 - 0.5
    denominator = (1 + 0.001 * (x1^2 + x2^2))^2
    # The numerator is divided by the denominator; adding them (0.5 + a + b) produces an
    # unbounded bowl instead of the Schaffer landscape.
    return 0.5 + numerator / denominator
end

"""
    levy(x)

Levy benchmark function for any dimension `d = length(x) ≥ 1`. With `wᵢ = 1 + (xᵢ - 1)/4`:

    sin²(π w₁) + Σ_{i=1}^{d-1} (wᵢ - 1)² (1 + 10 sin²(π wᵢ + 1)) + (w_d - 1)² (1 + sin²(2π w_d))

The global minimum is `0` at `x = (1, …, 1)`. For 2-D input this is the same two-term
formula as before; longer inputs use every coordinate rather than only the first two.
"""
function levy(x)
    w = @. 1 + (x - 1) / 4
    d = length(w)

    head = sin(π * w[1])^2
    # Middle terms cover coordinates 1 … d-1 (no terms when d == 1, hence `init`).
    middle_term(wi) = (wi - 1)^2 * (1 + 10 * sin(π * wi + 1)^2)
    middle = sum(middle_term, w[1:d-1]; init=zero(head))
    tail = (w[d] - 1)^2 * (1 + sin(2π * w[d])^2)

    return head + middle + tail
end


"""
    surface_grid(func, domain)

Evaluate `func([x, y])` on the Cartesian grid `domain × domain`. Return the flattened
(column-major) coordinate and value grids `(xs, ys, zs)` as `Vector{Float64}`s. Grid
entry `(i, j)` holds `x = domain[i]` and `y = domain[j]`.
"""
function surface_grid(func, domain)
    side = length(domain)
    grid_x = fill(0.0, side, side)
    grid_y = fill(0.0, side, side)
    grid_z = fill(0.0, side, side)

    # Rows follow x and columns follow y; x is the outer (slowest-varying) loop.
    for (row, xv) in pairs(domain), (col, yv) in pairs(domain)
        grid_x[row, col] = xv
        grid_y[row, col] = yv
        grid_z[row, col] = func([xv, yv])
    end

    return map(vec, (grid_x, grid_y, grid_z))
end;

# Particle Swarm Optimization (PSO) Algorithm

In [2]:
using Random

"""
    SwarmState

Snapshot of the swarm after one step. `swarm_optimize` returns a vector with one per step.

Fields are untyped. NOTE(review): their concrete types can change between steps. For
example, `positions` starts with the element type of the sampled `domain` and later becomes
the promoted type of `positions + velocities`. A concretely-typed parametrization would
therefore break `push!` into the state history; check this before tightening the types.
"""
struct SwarmState
    positions       # current positions: one row per particle, one column per dimension
    velocities      # velocities to apply on the next step (same layout as `positions`)
    costs           # criterion value at each particle's current position

    best_costs      # lowest cost each particle has seen so far
    best_pos        # position at which each particle reached its `best_costs` entry

    best_global_cost  # lowest cost found by any particle so far
    best_global_pos   # position that achieved `best_global_cost`
end

"""
    update_best(costs, positions, best_costs, best_pos, best_global_cost, best_global_pos)

Update each particle's personal best and the swarm's global best after a move.

A personal best is replaced when the new cost is less than *or equal to* the stored one,
so ties favour the new position. The global best is replaced only on a strict
improvement. No argument is mutated.

# Arguments
- `costs`: current cost of each particle.
- `positions`: current positions, one row per particle.
- `best_costs`: personal-best costs (a vector; an `n×1` matrix is also accepted).
- `best_pos`: personal-best positions, one row per particle.
- `best_global_cost`, `best_global_pos`: current global-best cost and position.

# Returns
`(new_best_costs, new_best_pos, new_best_global_cost, new_best_global_pos)`.
`new_best_costs` is a `Vector`, matching the layout of the initial state.
"""
function update_best(costs, positions, best_costs, best_pos, best_global_cost, best_global_pos)
    # Column 1 holds the new costs and column 2 the previous bests. findmin returns the
    # first minimum, so on a tie the new position wins.
    row_mins, winners = findmin(hcat(costs, best_costs); dims=2)
    new_best_costs = vec(row_mins)

    # Widen the element type so new positions are stored without rounding, e.g.
    # Float64 positions written into a best-position matrix that started as Float32.
    new_best_pos = similar(best_pos, promote_type(eltype(best_pos), eltype(positions)))
    new_best_pos .= best_pos
    for (row, winner) in enumerate(winners)
        if winner[2] == 1
            new_best_pos[row, :] = positions[row, :]
        end
    end

    candidate_cost, candidate_row = findmin(new_best_costs)
    if candidate_cost < best_global_cost
        return new_best_costs, new_best_pos, candidate_cost, new_best_pos[candidate_row, :]
    end
    return new_best_costs, new_best_pos, best_global_cost, copy(best_global_pos)
end

"""
    update_velocity(velocities, positions, best_pos, best_global_pos, rng;
                    constrict=0.01, inertia=1, damping=[-1, 1])

Compute the next particle velocities with the constricted PSO update

    v′ = constrict * (inertia * v + φ₁ ⊙ (p_best - x) + φ₂ ⊙ (g_best - x))

Here φ₁ and φ₂ are per-element uniform `Float32` draws from `rng`. Each component of the
result is then clamped to `[damping[1], damping[2]]`.

# Arguments
- `velocities`, `positions`, `best_pos`: matrices with one row per particle.
- `best_global_pos`: vector with one entry per dimension, broadcast across the rows.
- `rng`: random number generator used for φ₁ and φ₂.

# Returns
A new velocity matrix; the inputs are not mutated.
"""
function update_velocity(velocities, positions, best_pos, best_global_pos, rng;
                         constrict=0.01, inertia=1, damping=[-1, 1])
    # Draw from the supplied rng so that seeded runs do not depend on the global RNG.
    phi1 = rand(rng, Float32, size(positions))
    phi2 = rand(rng, Float32, size(positions))

    # The random weights scale the *displacement* toward each attractor. The former
    # `phi1 .* best_pos .- positions` instead pulled particles toward `phi1 .* best_pos`,
    # a randomly shrunken copy of the attractor, which is biased toward the origin.
    cognitive = phi1 .* (best_pos .- positions)
    social = phi2 .* (best_global_pos' .- positions)
    new_velocities = constrict .* (inertia .* velocities .+ cognitive .+ social)

    # Clamp each component: raise to the lower bound first, then cap at the upper bound.
    new_velocities .= min.(max.(new_velocities, damping[1]), damping[2])
    return new_velocities
end


"""
    update_position(positions, velocities; bounds=nothing)

Move every particle by its velocity and return the new position matrix. When `bounds`
is given as `[lower, upper]`, out-of-range coordinates are clamped: values below `lower`
are set to it first, then values above `upper` are set to `upper`. The inputs are not
mutated.
"""
function update_position(positions, velocities; bounds=nothing)
    moved = positions + velocities
    bounds === nothing && return moved

    lower, upper = bounds[1], bounds[2]
    moved[moved .< lower] .= lower
    moved[moved .> upper] .= upper
    return moved
end

"""
    swarm_optimize(num_p, criterion, domain=-5.12:0.1:5.12;
                   iterations=50, constrict=0.01, inertia=1, damping=[-1, 1], seed=nothing)

Run particle swarm optimization of `criterion` in two dimensions and return the
history as a `Vector` of `SwarmState`.

The `num_p` particles start at coordinates sampled from the values of `domain`.
Positions are clamped to `[minimum(domain), maximum(domain)]`. The history contains the
initial state plus `iterations - 1` updates, or only the initial state when
`iterations ≤ 1`. `seed` is passed to `Random.seed!`, which reseeds the global RNG
(`nothing` lets Julia pick a fresh seed).
"""
function swarm_optimize(num_p, criterion, domain=-5.12:0.1:5.12; iterations=50,
                        constrict=0.01, inertia=1, damping=[-1, 1], seed=nothing)
    bounds = collect(extrema(domain))
    rng = Random.seed!(seed)
    step_kw = (; constrict, inertia, damping)

    # Initial swarm: random positions, zero velocity, personal bests = starting points.
    x0 = rand(rng, domain, num_p, 2)
    f0 = criterion.(eachrow(x0))
    pbest_cost = copy(f0)
    pbest_pos = copy(x0)
    gcost, gidx = findmin(f0)
    gpos = x0[gidx, :]
    v0 = update_velocity(zeros(num_p, 2), x0, pbest_pos, gpos, rng; step_kw...)

    history = [SwarmState(x0, v0, f0, pbest_cost, pbest_pos, gcost, gpos)]

    # Each step moves from the most recent state: move, evaluate, update bests, steer.
    for _ in 2:iterations
        prev = last(history)
        x = update_position(prev.positions, prev.velocities; bounds)
        f = criterion.(eachrow(x))
        pbc, pbp, gc, gp = update_best(f, x, prev.best_costs, prev.best_pos,
                                       prev.best_global_cost, prev.best_global_pos)
        v = update_velocity(prev.velocities, x, pbp, gp, rng; step_kw...)
        push!(history, SwarmState(x, v, f, pbc, pbp, gc, gp))
    end
    return history
end;

# Visualizer Code

In [3]:
using GLMakie
GLMakie.activate!(title="Particle Swarm Visualizer")
# Statistics provides `mean`; LinearAlgebra provides `norm` (both used by the callbacks below).
using Statistics, LinearAlgebra


# Figure components ----------------------------------------------
# Column 1 holds the option controls (grid_options); column 2 holds the plots (grid_main).
fig = Figure(resolution=(1600, 800), backgroundcolor=:lightgrey)

grid_options = fig[1, 1] = GridLayout()
grid_main = fig[1, 2] = GridLayout()

# grid_main row 1 is reserved for the colorbar placed later. Row 2 has the 3-D surface
# (left) and the 2-D contour view (right). Row 3 spans both columns with the cost-over-time plot.
ax3 = Axis3(grid_main[2, 1])
# Rectangle zoom and zooming are disabled on the contour view.
ax21 = Axis(grid_main[2, 2], xrectzoom=false, yrectzoom=false, yzoomlock=true, xzoomlock=true)
ax22 = Axis(grid_main[3, 1:2], xlabel="Time", ylabel="Cost")


# Menu -----------------------------------------------------------
# Benchmark functions offered in the dropdown; `func` holds the currently selected one.
funcs = [rastrigin, schaffer_n4, levy]
func = Observable{Any}(funcs[1])

menu = Menu(fig,
    options = zip(["Rastrigin", "Schaffer N4", "Levy"], funcs),
    default = "Rastrigin")

# Mirror every menu selection into `func`, then fire once so it reflects the default.
on(chosen -> (func[] = chosen), menu.selection)
notify(menu.selection)


# System parameters ---------------------------------------------------------------
# Every tunable is an Observable: the textboxes write to them, and the buttons read them.
low_bound = Observable{Float32}(-5.12)
step_size = Observable{Float32}(0.01)
high_bound = Observable{Float32}(5.12)
domain = Observable{Vector{Float32}}(collect(low_bound[]:step_size[]:high_bound[]))

num_particles = Observable{Int32}(5)
iterations = Observable{Int32}(50)
constrict = Observable{Float32}(0.01)
inertia = Observable{Float32}(1)
damp_low = Observable{Float32}(-1)
damp_high = Observable{Float32}(1)

# `nothing` means no fixed seed.
seed = Observable{Any}(nothing)

# Initial run, using the same settings the "Optimize" button passes.
initial_states = swarm_optimize(
        num_particles[],
        func[],
        domain[];
        iterations=iterations[],
        constrict=constrict[],
        inertia=inertia[],
        damping=[damp_low[], damp_high[]],
        seed=seed[]
        )

states = Observable(initial_states)

px, py, pz = [Observable{Vector{Float64}}() for _ in range(1, 3)] # Positions
vu, vv = [Observable{Vector{Float64}}() for _ in range(1, 2)] # Velocities

# These are the stored statistics to avoid recomputing when changing slider.
# Timesteps 1..N are taken from the run itself, so they stay in sync with `iterations`.
stats_x = Observable(collect(1:length(initial_states)))
mins = Observable([minimum(s.costs) for s in initial_states])
means = Observable([mean(s.costs) for s in initial_states])
maxs = Observable([maximum(s.costs) for s in initial_states])

# These are the statistics points plotted up to time t
data_min = Observable{Matrix{Float64}}()
data_mean = Observable{Matrix{Float64}}()
data_max = Observable{Matrix{Float64}}()

# Slider -------------------------------------------------------------------------
# One slider position per stored swarm state, derived from the run instead of hard-coded.
sg = SliderGrid(fig, (label="Time", range = 1:1:length(initial_states), startvalue=1))

# Push the state at time step `i` into the plotted observables. `lift` also runs this once
# immediately, which is what first fills px/py/pz, vu/vv and data_*.
lift(sg.sliders[1].value) do i
    # Update positions for plotting
    pos = states[][i].positions
    npx = pos[:, 1]
    npy = pos[:, 2]
    npz = states[][i].costs

    # Set the raw values without notifying; the single notify below redraws consistently.
    px.val = npx
    py.val = npy
    pz.val = npz

    # Update velocities for plotting: keep only the direction (each row scaled to unit length)
    vel = states[][i].velocities
    norms = norm.(eachrow(vel))
    # A stationary particle would otherwise give 0/0 = NaN components; dividing by 1
    # instead leaves it as a zero-length arrow.
    norms[iszero.(norms)] .= 1
    normed_vel = vel ./ norms
    nvu = normed_vel[:, 1]
    nvv = normed_vel[:, 2]

    vu.val = nvu
    vv[] = nvv # Delay update
    notify(pz) # Notify pz to also notify px and py

    # update statistics for plotting
    selected_x = stats_x[][1:i]
    data_min[] = hcat(selected_x, mins[][1:i])
    data_mean[] = hcat(selected_x, means[][1:i])
    data_max[] = hcat(selected_x, maxs[][1:i])
end

# Buttons ------------------------------------------------------------------------
# "Optimize": rerun the swarm with the current parameters and reset the time slider.
b_opt = Button(fig, label="Optimize", tellwidth=:false)
on(b_opt.clicks) do n
    # Get new swarm states
    new_states = swarm_optimize(
        num_particles[],
        func[],
        domain[];
        iterations=iterations[],
        constrict=constrict[],
        inertia=inertia[],
        damping=[damp_low[], damp_high[]],
        seed=seed[]
        )

    # Size the timeline by the states actually returned. swarm_optimize returns at least
    # one state even when iterations[] ≤ 0, in which case 1:iterations[] would be empty.
    n_steps = length(new_states)

    # Refresh everything the slider callback reads *before* changing the slider range,
    # because a range change may move the slider value and re-run that callback.
    states[] = new_states
    stats_x[] = collect(1:n_steps)
    mins[] = [minimum(s.costs) for s in new_states]
    means[] = [mean(s.costs) for s in new_states]
    maxs[] = [maximum(s.costs) for s in new_states]

    sg.sliders[1].range = 1:n_steps # Change slider range

    # Update axes of statistics plot to reflect new statistics
    xlims!(ax22, 1, n_steps + 0.5)
    ylims!(ax22, minimum(mins[])-0.5, maximum(maxs[])+0.5)

    set_close_to!(sg.sliders[1], 1) # Reset the slider to the first time step
    # Force a redraw even if the slider was already at 1 and its value did not change.
    notify(sg.sliders[1].value)
end

# "Adjust plot": rebuild the surface/contour grid from the bound textboxes, rescale
# the axes, then rerun the optimizer on the new domain.
b_plot = Button(fig, label="Adjust plot", tellwidth=:false)
on(b_plot.clicks; priority=-1) do n
    # Reject settings that would yield fewer than two grid points. Such a grid makes the
    # `minimum`/`maximum` axis-limit calls below fail or produce zero-width limits.
    # The step is checked first because building a range with a zero step throws.
    if !(step_size[] > 0) || length(low_bound[]:step_size[]:high_bound[]) < 2
        @warn "Invalid plot bounds; keeping the current plot" low=low_bound[] step=step_size[] high=high_bound[]
        return
    end
    domain[] = collect(low_bound[]:step_size[]:high_bound[])

    # Change the plotted data
    nxx, nyy, nzz = surface_grid(func[], domain[])
    xx.val = nxx
    yy.val = nyy
    zz[] = nzz # Delay update

    # Surface plot
    xlims!(ax3, minimum(xx[]), maximum(xx[]))
    ylims!(ax3, minimum(yy[]), maximum(yy[]))
    zlims!(ax3, minimum(zz[]), maximum(zz[]))

    # Contour plot
    xlims!(ax21, minimum(xx[]), maximum(xx[]))
    ylims!(ax21, minimum(yy[]), maximum(yy[]))
    notify(b_opt.clicks)
end

xx = Observable{Vector{Float64}}()
yy = Observable{Vector{Float64}}()
zz = Observable{Vector{Float64}}()
notify(b_plot.clicks) # Will initialize xx, yy, zz

surf = surface!(ax3, xx, yy, zz, colormap=:viridis)
co = contour!(ax21, xx, yy, zz, colormap=:viridis)
Colorbar(grid_main[1, 1:2], surf, vertical=false)

# Pad the statistics limits the same way the "Optimize" handler does. This also avoids
# equal lower/upper limits when the run has a single time step or constant costs.
ylims!(ax22, minimum(mins[]) - 0.5, maximum(maxs[]) + 0.5)
xlims!(ax22, 1, maximum(stats_x[]) + 0.5)

# Textboxes ----------------------------------------------------------------------------------------------

# Plot bounds
t_low_bound = Textbox(fig, placeholder="-5.12", tellwidth=false, validator=Float32, boxcolor=:white)
t_step = Textbox(fig, placeholder="0.01", tellwidth=false, validator=Float32, boxcolor=:white)
t_high_bound = Textbox(fig, placeholder="5.12", tellwidth=false, validator=Float32, boxcolor=:white)

# Swarm parameters
t_num_particles = Textbox(fig, placeholder="5", tellwidth=false, width=100, validator=Int32, boxcolor=:white, halign=:right)
t_iterations = Textbox(fig, placeholder="50", tellwidth=false, width=100, validator=Int32, boxcolor=:white, halign=:right)
t_constrict = Textbox(fig, placeholder="0.01", tellwidth=false, width=100, validator=Float32, boxcolor=:white, halign=:right)
t_inertia = Textbox(fig, placeholder="1", tellwidth=false, width=100, validator=Float32, boxcolor=:white, halign=:right)
t_low_damp = Textbox(fig, placeholder="-1", tellwidth=false, validator=Float32, boxcolor=:white)
t_high_damp = Textbox(fig, placeholder="1", tellwidth=false, validator=Float32, boxcolor=:white)
t_seed = Textbox(fig, placeholder="-1", tellwidth=false, width=100, validator=Int32, boxcolor=:white, halign=:right)

# Textbox actions: on submit, parse the text as the given type and store it in the
# matching parameter Observable.
for (box, target, T) in (
        (t_low_bound, low_bound, Float32),
        (t_step, step_size, Float32),
        (t_high_bound, high_bound, Float32),
        (t_num_particles, num_particles, Int32),
        (t_iterations, iterations, Int32),
        (t_constrict, constrict, Float32),
        (t_inertia, inertia, Float32),
        (t_low_damp, damp_low, Float32),
        (t_high_damp, damp_high, Float32),
    )
    on(box.stored_string) do text
        target[] = parse(T, text)
    end
end

# Seed: a negative value means "no fixed seed".
on(t_seed.stored_string) do text
    parsed = parse(Int64, text)
    seed[] = parsed < 0 ? nothing : parsed
end

# Arranging option objects ------------------------------------------
# Label/textbox pairs laid out side by side for the plot bounds and for the damping range.
plot_bounds_grid = hgrid!(Label(fig, "LB"), t_low_bound, Label(fig, "Step"), t_step, Label(fig, "UB"), t_high_bound)
damp_grid = hgrid!(Label(fig, "Low Damp"), t_low_damp, Label(fig, "High Damp"), t_high_damp)

# Stack all controls vertically in the options column, top to bottom: function menu,
# plot bounds and "Adjust plot", swarm parameters, seed, time slider, then "Optimize".
grid_options[1:3, 1, ] = vgrid!(
    Label(fig, "Function", tellwidth=:false, halign=:left),
    menu,
    plot_bounds_grid,
    b_plot,
    hgrid!(Label(fig, "Particles", halign=:right), t_num_particles),
    hgrid!(Label(fig, "Iterations", halign=:right), t_iterations),
    hgrid!(Label(fig, "Constrict", halign=:right), t_constrict),
    hgrid!(Label(fig, "Inertia", halign=:right),t_inertia),
    damp_grid,
    hgrid!(Label(fig, "Seed", halign=:right), t_seed),
    sg,
    b_opt
    )

# Final plotting -----------------------------------------------------
# Particles as red markers on both views. These follow px/py/pz, which the slider callback sets.
scatter!(ax3, px, py, pz; color=:red)
scatter!(ax21, px, py, pz; color=:red)
# Velocity direction arrows; the slider callback divides each velocity row by its norm.
arrows!(ax21, px, py, vu, vv; lengthscale=0.5)
# Max/mean/min particle cost up to the selected time step.
max_line = lines!(ax22, data_max, label="max")
mean_line = lines!(ax22, data_mean, label="mean")
min_line = lines!(ax22, data_min, label="min")

axislegend(ax22, position=:lb)

# Give the plot column 7/10 of the figure width and a gap of 100 from the options column.
colsize!(fig.layout, 2, Relative(7/10))
colgap!(fig.layout, 100)

# Last expression of the cell, so the notebook displays the figure.
fig