# Constraint Enforcement
## Penalty Methods
Add a penalty term to the fitness score of a candidate $\bm x$ that violates the inequality constraints $\bm A\bm x\leq \bm b$:

$$ g^*(\bm x)=g(\bm x)+\lambda || \max (0, \bm A\bm x-\bm b) ||^2$$

where $\max$ is taken element-wise and $\lambda$ is the penalty weight, so feasible candidates incur no penalty.
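
As a minimal sketch of the formula above, the penalized objective can be built as a closure around any objective. The names `quadratic_penalty`, `penalized`, and `sphere`, and the weight `10.0`, are illustrative assumptions and not part of the modules used later in this notebook.

```julia
# Squared norm of the element-wise constraint violations max(0, A*x - b)
quadratic_penalty(A, b, x) = sum(abs2, max.(A * x .- b, 0))

# Wrap an objective g into g*(x) = g(x) + λ * penalty(x)
penalized(g, A, b, λ) = x -> g(x) + λ * quadratic_penalty(A, b, x)

# Example: unit square 0 <= x <= 1 in two dimensions, written as A*x <= b
A_sq = [1.0 0.0; 0.0 1.0; -1.0 0.0; 0.0 -1.0]
b_sq = [1.0, 1.0, 0.0, 0.0]
sphere(x) = sum(abs2, x)              # stand-in objective (assumption)
g_star = penalized(sphere, A_sq, b_sq, 10.0)

g_star([0.5, 0.5])   # feasible point: equals sphere([0.5, 0.5])
g_star([1.5, 0.5])   # violates x1 <= 1 by 0.5, so 10 * 0.5^2 is added
```

Larger values of the weight push the swarm harder toward the feasible region, at the cost of a steeper fitness landscape near the boundary.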

## Constraint Repair
### Linear Intersection
At iteration $t-1$ the particle is inside the feasible region at $\bm x_{int}$, and at iteration $t$ it has moved to $\bm x_{ext}$, which violates at least one constraint. Find the point on the boundary where the particle left the feasible region:
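$$
\begin{array}{c}
    \lambda_{proj} := \underset{j}{\min} \ \lambda_j \\[12pt]
    \text{s.t.} \quad \bm A_j\left((1-\lambda_j)\bm x_{int} + \lambda_j\bm x_{ext}\right)=b_j
\end{array}
$$

where $j$ ranges over the constraints violated at $\bm x_{ext}$, so that $\lambda_j = \dfrac{b_j - \bm A_j\bm x_{int}}{\bm A_j(\bm x_{ext}-\bm x_{int})}$ is the fraction of the step from $\bm x_{int}$ at which constraint $j$ becomes active. The projected point is $$\bm x_{proj}=(1-\lambda_{proj})\bm x_{int} + \lambda_{proj}\bm x_{ext}.$$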
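
A small worked example may make the construction concrete. It uses the unit square and two hand-picked points; `A_sq`, `b_sq`, and `steps` are names chosen for this sketch only. Each entry of `steps` is the fraction of the way from `x_int` to `x_ext` at which one violated boundary is reached, and the first boundary reached gives the projected point.

```julia
# Unit square 0 <= x <= 1, written as A*x <= b
A_sq = [1.0 0.0; 0.0 1.0; -1.0 0.0; 0.0 -1.0]
b_sq = [1.0, 1.0, 0.0, 0.0]

x_int = [0.5, 0.5]   # feasible position at iteration t-1
x_ext = [1.5, 2.5]   # infeasible position at iteration t

d = x_ext .- x_int                    # direction of travel
violated = A_sq * x_ext .> b_sq       # here: x1 <= 1 and x2 <= 1

# Fraction of the step at which each violated boundary is reached
steps = (b_sq[violated] .- A_sq[violated, :] * x_int) ./ (A_sq[violated, :] * d)

# The smallest fraction is the first boundary crossed
x_proj = x_int .+ minimum(steps) .* d
```

Here the boundary $x_2 = 1$ is reached a quarter of the way along the step, before $x_1 = 1$, so the particle is placed at $(0.75, 1)$.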

### Method of Alternating Projections
The method of alternating projections produces a sequence of orthogonal projections onto the boundaries of the violated constraints:
$$\bm x_{i}^{(n+1)} = \bm x_i^{(n)} - \frac{\bm A_{j_n}\bm x_i^{(n)} - b_{j_n}}{||\bm A_{j_n}||^2}\bm A_{j_n}'$$

where $\bm A_{j_n}$ is a row of $\bm A$ taken from the currently violated set $\{\bm A_j\mid \bm A_j\bm x_i^{(n)} \nleq b_j\}$.
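
The update can be sketched for a single point as follows. The text does not say in which order the violated rows are visited, so this sketch assumes the most violated row is projected first; that rule, the function name `alternating_projections`, and the `tol` and `max_iter` keywords are all assumptions.

```julia
using LinearAlgebra

# Repeatedly project x onto the hyperplane of a violated row of A*x <= b.
# Assumption: the most violated row is chosen at each step.
function alternating_projections(A, b, x; tol=1e-10, max_iter=1_000)
    x = float.(collect(x))                 # work on a copy of the input
    for _ in 1:max_iter
        residual = A * x .- b
        viol, j = findmax(residual)        # largest residual and its row index
        viol <= tol && break               # no constraint is violated
        a = A[j, :]
        x .-= (viol / dot(a, a)) .* a      # orthogonal projection onto row j
    end
    return x
end

# Example on the unit square 0 <= x <= 1
A_sq = [1.0 0.0; 0.0 1.0; -1.0 0.0; 0.0 -1.0]
b_sq = [1.0, 1.0, 0.0, 0.0]
x_fixed = alternating_projections(A_sq, b_sq, [1.5, -0.25])
```

In this example the point is first projected onto $x_1 = 1$ and then onto $x_2 = 0$, ending at the corner $(1, 0)$.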



## Imports

In [14]:
include("./designs.jl")
using .Designs: hypercube_initializer, constrained_initializer

include("../src/optim/objectives.jl")
using .Objectives: rastrigin, griewank

include("./pso.jl")
using .PSO: initialize_swarm, optimize, get_optimizer, aggregate_results

include("./hit_and_run.jl")
using .HitAndRun: hit_and_run

include("./types.jl")
using .PSOTypes: LinearSwarmConstraint

using Statistics
using LinearAlgebra
using HiGHS
using Polyhedra



# Constraint Enforcement


In [32]:
N = 100
K = 4

initializer = hypercube_initializer(N, K; lower=-10, upper=10)
swarm = initialize_swarm(initializer, griewank)
final_swarm, history = optimize(swarm, aggregate_results(save_world=true))

(Main.PSO.PSOTypes.RunnerState(Main.PSO.PSOTypes.Swarm(Main.PSO.PSOTypes.ParticleState([-9.411838133271198e-10 1.5643547568798861e-10 … 2.1211274451702546e-9 -1.7905815038530097e-9; -6.154228118721113e-10 7.923876146866303e-10 … 1.2242341687014784e-9 4.125336741452174e-10; … ; -2.211623529720464e-10 -2.578125315082353e-10 … 1.8453912882035214e-9 -1.4520389194523428e-9; -5.123572564317772e-10 6.887699089351172e-10 … -6.046743815569994e-10 -8.903378712181337e-10;;; …

## Penalty Methods

In [12]:
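# Constraints for the unit hypercube 0 ≤ x ≤ 1 in d dimensions, written as A*x ≤ b
function unit_hypercube_constraints(d)
    A = [ Matrix{Float64}(I, d, d); -Matrix{Float64}(I, d, d) ]  # rows: x ≤ 1, then -x ≤ 0
    b = [ ones(d); zeros(d) ]
    return A, b
end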

N = 100

100

In [22]:
A, b = unit_hypercube_constraints(4)
bad_initializer = hypercube_initializer(N, 4)
good_initializer = constrained_initializer(N, A, b)
constraints = LinearSwarmConstraint(A, b, 0.5)

LinearSwarmConstraint([1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; … ; 0.0 0.0 -1.0 0.0; 0.0 0.0 0.0 -1.0], [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], 0.5)

## Resampling

In [49]:
function resample_violating_rows!(X, constraints::LinearSwarmConstraint, initializer::Function)
    n, N, K = size(X)

    # Flatten X to (n*N, K) so that each design point is a row.
    X_2d = reshape(X, n*N, K)

    # Constraint residuals A*x - b, one column per design point; positive entries are violations
    violation_mat = (constraints.A * X_2d') .- constraints.b

    is_violating_flat = vec(any(violation_mat .> 0.0, dims=1))  # shape: (n*N,)
    violation_mask = reshape(is_violating_flat, n, N)
    violating_particles = any(violation_mask, dims=2) |> vec  # Bool mask, shape: (n,)

    # Number of particles with at least one violating design point
    p = count(violating_particles)

    # Re-sample only the violating particles
    if p > 0
        X[violating_particles, :, :] .= initializer(p)
    end

    return p
end

resample_violating_rows! (generic function with 1 method)

In [None]:
X = bad_initializer(100)
num_changed_rows = resample_violating_rows!(X, constraints, good_initializer)
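
To confirm that no violating design points remain after resampling, the constraints can be re-checked over the whole array. The helper below is a sketch under the assumption that the array has shape `(n, N, K)` as above; `all_feasible` and its `tol` keyword are hypothetical names, not part of the included modules.

```julia
# Hypothetical helper: true when every design point of an (n, N, K) array
# satisfies A*x <= b + tol.
function all_feasible(X, A, b; tol=0.0)
    n, N, K = size(X)
    points = reshape(X, n * N, K)          # one design point per row
    return all(A * points' .<= b .+ tol)
end

# Small check on the unit square
A_sq = [1.0 0.0; 0.0 1.0; -1.0 0.0; 0.0 -1.0]
b_sq = [1.0, 1.0, 0.0, 0.0]
X_ok = fill(0.5, 3, 2, 2)
X_bad = copy(X_ok)
X_bad[2, 1, 1] = 1.5                       # push one coordinate outside
all_feasible(X_ok, A_sq, b_sq), all_feasible(X_bad, A_sq, b_sq)
```

In this notebook the call would be `all_feasible(X, constraints.A, constraints.b)`. A small positive `tol` is useful when points sit exactly on the boundary, where rounding can produce a tiny positive residual.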

## Constraint Repair

In [56]:
X_int = good_initializer(100)
X_ext = X_int .+ 100

100×100×4 Array{Float64, 3}:
[:, :, 1] =
 100.608  100.397  100.527  100.884  …  100.791  100.993  100.534  100.201
 100.473  100.651  100.362  100.803     100.615  100.984  100.939  100.161
 100.786  100.583  100.596  100.812     100.416  100.993  100.853  100.15
 100.772  100.573  100.715  100.691     100.395  100.976  100.856  100.174
 100.923  100.662  100.687  100.834     100.413  100.914  100.749  100.142
 100.931  100.464  100.098  100.888  …  100.124  100.914  100.467  100.091
 100.934  100.572  100.736  100.94      100.192  100.991  100.443  100.028
 100.901  100.658  100.733  100.953     100.333  100.969  100.472  100.038
 100.939  100.782  100.697  100.912     100.543  100.954  100.529  100.568
 100.891  100.907  100.699  100.932     100.422  100.332  100.492  100.433
 100.94   100.595  100.666  100.839  …  100.873  100.845  100.602  100.37
 100.899  100.53   100.95   100.8       100.951  100.855  100.541  100.444
 ⋮

In [None]:
function linear_intersection(A, b, x_int, x_ext)

    # Identify which rows of A are violated
    violated = (A * x_ext) .> b

    # If none violated, return x_ext
    if !any(violated)
        return x_ext
    end

    # Solve A_j * (x_int + λ_j * (x_ext - x_int)) = b_j for each violated row j
    A_violated = A[violated, :]
    b_violated = b[violated]
    numerator = b_violated .- (A_violated * x_int)
    denominator = A_violated * (x_ext .- x_int)

    # For a violated row, A_j*x_ext > b_j ≥ A_j*x_int whenever x_int is feasible,
    # so the denominator is strictly positive and each λ_j lies in [0, 1)
    λ_vec = numerator ./ denominator

    # The smallest λ is the first boundary crossed on the way from x_int to x_ext
    λ_min = minimum(λ_vec)

    # Return the intersection point
    return x_int .+ λ_min .* (x_ext .- x_int)
end

function repair_violations!(X_int, X_ext, A, b)
    n, N, K = size(X_int)

    # Flatten to (n*N, K) so each design point is a row
    X_int_2d = reshape(X_int, n*N, K)
    X_ext_2d = reshape(X_ext, n*N, K)

    # Identify violating design points
    violation_mat = A * X_ext_2d' .- b

    # Create a mask for violating rows (any positive entry of A*x - b)
    is_violating = vec(any(violation_mat .> 0, dims=1))

    # Loop over violating rows
    violating_indices = findall(is_violating)
    for i in violating_indices
        # Extract the row
        x_int_row = @view X_int_2d[i, :]
        x_ext_row = @view X_ext_2d[i, :]

        # Repair that row
        x_repaired = linear_intersection(A, b, x_int_row, x_ext_row)

        # In-place update
        x_ext_row .= x_repaired
    end

    # X_ext_2d shares memory with X_ext, so the repairs are already written to X_ext
    return count(is_violating)
end

repair_violations! (generic function with 2 methods)

In [76]:
num_repairs = repair_violations!(X_int, X_ext, A, b)

10000

In [77]:
X_ext

100×100×4 Array{Float64, 3}:
[:, :, 1] =
 0.613763  0.606228   1.0       1.0       …  1.0       0.924836  0.41895
 0.477757  1.0        0.877984  1.0          1.0       1.0       0.469383
 0.870406  1.0        0.64856   0.973388     1.0       1.0       0.458065
 0.843278  1.0        0.880605  1.0          1.0       1.0       0.46398
 0.945494  1.0        1.0       1.0          0.95495   1.0       0.358353
 0.955488  0.792296   0.765869  1.0       …  0.955049  0.892977  0.341737
 0.975703  1.0        0.799475  1.0          1.0       0.75467   0.261791
 0.937857  0.850819   0.795714  1.0          1.0       0.898513  0.301498
 0.94736   1.0        0.803838  1.0          1.0       0.982232  1.0
 0.939689  1.0        0.880116  1.0          0.575497  1.0       1.0
 1.0       0.946597   1.0       1.0       …  0.882042  1.0       0.935033
 1.0       0.883452   1.0       1.0          0.900933  0.97461   1.0
 1.0       0.746363   0.693886  0.869955     0.583846  1.0       1.0
 ⋮                 