In [None]:


using Random
using Statistics
using Combinatorics
using ProgressMeter
using Gurobi
using JuMP
using Suppressor
using LinearAlgebra
using Plots
using Distributions
     

"""
    brute_force(x, y, k, score)

Exhaustively search every size-`k` subset of the points `(x[i], y[i])` and keep the
subset with the largest `score`.

NOTE(review): in the source this function was truncated between the second
`@assert` and the `val > best_val` comparison. The initialisation and the loop
header below are a reconstruction (based on the surviving lines and the file's
`Combinatorics` / `ProgressMeter` imports) -- confirm against the original.

# Arguments
- `x`, `y`: coordinate vectors of equal length.
- `k`: size of the subset to select; must not exceed `length(x)`.
- `score`: callable invoked as `score(x[idxs], y[idxs])`; larger is better.

# Returns
`(best_idxs, comp_idxs)`: the indices of the best-scoring subset and the remaining
indices of `1:length(x)`. On ties the first subset enumerated wins (strict `>`).
"""
function brute_force(x::Vector{Float64}, y::Vector{Float64}, k::Int, score::Function)::Tuple{Vector{Int64},Vector{Int64}}
    @assert size(x)==size(y) "x and y must be the same shape"
    n = length(x)
    @assert k <= n "k must not exceed the number of points"

    best_val = -Inf
    best_idxs = Int[]

    # Enumerate all k-subsets; there are binomial(n, k) of them, so show progress.
    @showprogress for idxs in combinations(1:n, k)
        val = score(x[idxs], y[idxs])
        if val > best_val
            best_val = val
            best_idxs = idxs
        end
    end

    # Get complement of best set
    comp_idxs = setdiff(1:n, best_idxs)

    return best_idxs, comp_idxs
end;
     

"""
    find_hyperplane(A, B, ϵ=1e-10)

Find an affine function `p -> dot(w[1:end-1], p) + w[end]` that is at least `ϵ` on
every row of `A` and at most `-ϵ` on every row of `B`, using a JuMP model solved
with Gurobi. Solver output is silenced with `@suppress`.

# Arguments
- `A`, `B`: matrices whose rows are points; both must have the same number of columns.
- `ϵ`: separation margin enforced on both sides.

# Returns
A vector of length `size(A, 2) + 1`: the normal coefficients followed by the bias.

# Throws
- `DimensionMismatch` if `A` and `B` have different numbers of columns.
- `ErrorException` if the solver finishes without a primal solution
  (e.g. the two point sets are not separable).
"""
function find_hyperplane(A::Matrix{Float64}, B::Matrix{Float64}, ϵ::Float64 = 1e-10)::Vector{Float64}
    d = size(A, 2)
    if size(B, 2) != d
        throw(DimensionMismatch("A has $d columns but B has $(size(B, 2))"))
    end

    @suppress begin
        # Create a Gurobi model
        model = Model(Gurobi.Optimizer)

        # w[1:d] is the normal vector of the hyperplane, w[d+1] is the bias term
        @variable(model, w[1:d+1])

        # Rows of A must lie strictly on the positive side, rows of B on the negative side
        @constraint(model, A * w[1:end-1] .+ w[end] .>= ϵ)
        @constraint(model, B * w[1:end-1] .+ w[end] .<= -ϵ)

        # NOTE(review): the objective maximizes the squared bias term; it does NOT
        # minimize the norm of w as an earlier comment claimed. Any feasible w can be
        # scaled up and stay feasible, so this objective looks unbounded -- confirm
        # that this is the intended way to steer the solver away from w ≈ 0.
        @objective(model, Max, w[end]^2)

        # Solve the model
        optimize!(model)

        # Fail with a readable message instead of an opaque error from `value`
        if !has_values(model)
            error("find_hyperplane: solver returned no solution " *
                  "(termination status: $(termination_status(model)))")
        end

        return value.(w)
    end
end;
     

"""
    get_rect_hyperbola(w, x)

Evaluate the curve `w[1]*x + w[2]*y + w[3]*x*y + w[4] = 0` solved for `y` on the
grid `x`, and cut it into two branches at the grid point where the denominator
`w[2] + w[3]*x` is smallest in magnitude (that grid point itself is dropped).

# Returns
`((x_left, y_left), (x_right, y_right))`: the samples before and after the cut.
Either branch may be empty when the cut falls on the first or last grid point.
"""
function get_rect_hyperbola(w::Vector{Float64}, x::Vector{Float64})::Tuple{Tuple{Vector{Float64},Vector{Float64}},Tuple{Vector{Float64},Vector{Float64}}}
    denom = w[2] .+ w[3] .* x
    y = (-w[4] .- w[1] .* x) ./ denom

    # Grid index closest to the vertical asymptote (first one on ties)
    _, pole = findmin(abs.(denom))

    before = 1:(pole - 1)
    after = (pole + 1):lastindex(x)
    return (x[before], y[before]), (x[after], y[after])
end;
     

"""
    get_hyperbola(w, x, ϵ=1e-2)

Solve the conic `a1*x^2 + b1*x*y + c1*y^2 + d1*x + e1*y + f1 = 0`, with
`(a1, b1, c1, d1, e1, f1) = w`, for `y` at each grid value in `x`.

For fixed `x` the conic is a quadratic `a*y^2 + b*y + c = 0` in `y`. Only grid
values whose discriminant exceeds `ϵ` are kept.

# Returns
`(x_kept, y1, y2)`: the retained grid values and the two roots at each of them
(`y1` uses `-sqrt`, `y2` uses `+sqrt`). When the `y^2` coefficient `c1` is exactly
zero the equation is linear in `y`, and both `y1` and `y2` hold its single
solution `-c/b` instead of the Inf/NaN the quadratic formula would produce.
"""
function get_hyperbola(w::Vector{Float64}, x::Vector{Float64}, ϵ::Float64 = 1e-2)::Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}
    a1, b1, c1, d1, e1, f1 = w

    # Coefficients of the quadratic in y for every grid value of x
    a = c1
    b = b1 .* x .+ e1
    c = a1 .* x .^ 2 .+ d1 .* x .+ f1
    disc = b .^ 2 .- 4 .* a .* c

    # Keep only grid values with a clearly positive discriminant
    keep = disc .> ϵ
    xs = x[keep]
    b = b[keep]

    if iszero(a)
        # Linear in y: b*y + c = 0. Here disc == b^2 > ϵ, so b is non-zero for ϵ >= 0.
        y = -c[keep] ./ b
        return xs, y, copy(y)
    end

    root = sqrt.(disc[keep])
    y1 = (-b .- root) ./ (2 .* a)
    y2 = (-b .+ root) ./ (2 .* a)
    return xs, y1, y2
end;
     

# Experiment size: n points are generated and a subset of size k is selected
n, k = 30, 10;
     
Covariance and rectangular hyperbolas

# Score a subset by its sample covariance. Lifting each point to (x, y, x*y) turns a
# separating hyperplane into a curve w1*x + w2*y + w3*x*y + w4 = 0 in the plane.
score = (u, v) -> cov(u, v)
lift = (u, v) -> hcat(u, v, u .* v)

# Sample n points uniformly from the unit square
x, y = rand(n), rand(n)

# Exhaustive search for the best-scoring subset of size k
best_idxs, comp_idxs = brute_force(x, y, k, score)

# Separate the selected points from the rest in the lifted space
lifted = lift(x, y)
A, B = lifted[best_idxs, :], lifted[comp_idxs, :]
w = find_hyperplane(A, B)

# Draw the two branches of the boundary separately so that no line is drawn
# across the vertical asymptote
x_plot = collect(range(extrema(x)...; length=1000))
(x_left, y_left), (x_right, y_right) = get_rect_hyperbola(w, x_plot)
plot(x_left, y_left; color=:red, label=nothing)
plot!(x_right, y_right; color=:red, label="Boundary")

# Overlay the data: complement first, then the selected subset
scatter!(x[comp_idxs], y[comp_idxs]; label="Outliers", legend=:topleft)
scatter!(x[best_idxs], y[best_idxs]; label="Inliers", marker=:circle)

# Clip the view to the bounding box of the data
xlims!(extrema(x)...)
ylims!(extrema(y)...)
     
General conic sections and correlation

Bivariate uniform

# Score a subset by its sample correlation. Lifting each point to
# (x^2, x*y, y^2, x, y) turns a separating hyperplane into a general conic.
score = (u, v) -> cor(u, v)
lift = (u, v) -> hcat(u .^ 2, u .* v, v .^ 2, u, v)

# Sample n points uniformly from the unit square
x, y = rand(n), rand(n)

# Exhaustive search for the best-scoring subset of size k
best_idxs, comp_idxs = brute_force(x, y, k, score)

# Separate the selected points from the rest in the lifted space
lifted = lift(x, y)
A, B = lifted[best_idxs, :], lifted[comp_idxs, :]
w = find_hyperplane(A, B)

# Trace both roots of the conic over the x-range of the data; get_hyperbola
# returns only the grid values where the conic has real solutions
x_plot = collect(range(extrema(x)...; length=10000))
x_plot, y_plot1, y_plot2 = get_hyperbola(w, x_plot)
plot(x_plot, y_plot1; label=nothing, color=:red)
plot!(x_plot, y_plot2; label="Boundary", color=:red)

# Overlay the data: complement first, then the selected subset
scatter!(x[comp_idxs], y[comp_idxs]; label="Outliers")
scatter!(x[best_idxs], y[best_idxs]; label="Inliers")

# Clip the view to the bounding box of the data
xlims!(extrema(x)...)
ylims!(extrema(y)...)
     

# Score a subset by its sample correlation. Lifting each point to
# (x^2, x*y, y^2, x, y) turns a separating hyperplane into a general conic.
score = (u, v) -> cor(u, v)
lift = (u, v) -> hcat(u .^ 2, u .* v, v .^ 2, u, v)

# Sample n points from a zero-mean bivariate normal with unit variances and
# covariance 0.7; each column of `points` is one sample
μ = [0.0, 0.0]
Σ = [1.0 0.7; 0.7 1.0]
dist = MvNormal(μ, Σ)
points = rand(dist, n)
x, y = points[1, :], points[2, :]

# Exhaustive search for the best-scoring subset of size k
best_idxs, comp_idxs = brute_force(x, y, k, score)

# Separate the selected points from the rest in the lifted space
lifted = lift(x, y)
A, B = lifted[best_idxs, :], lifted[comp_idxs, :]
w = find_hyperplane(A, B)

# Trace both roots of the conic over the x-range of the data; get_hyperbola
# returns only the grid values where the conic has real solutions
x_plot = collect(range(extrema(x)...; length=10000))
x_plot, y_plot1, y_plot2 = get_hyperbola(w, x_plot)
plot(x_plot, y_plot1; label=nothing, color=:red)
plot!(x_plot, y_plot2; label="Boundary", color=:red)

# Overlay the data: complement first, then the selected subset
scatter!(x[comp_idxs], y[comp_idxs]; label="Outliers")
scatter!(x[best_idxs], y[best_idxs]; label="Inliers")

# Clip the view to the bounding box of the data
xlims!(extrema(x)...)
ylims!(extrema(y)...)
     
Difference of variances and rectangular hyperbolas

# Score a subset by var(x) - var(y). Lifting each point to (x, x^2, y, y^2) turns a
# separating hyperplane into the curve a1*x + b1*x^2 + c1*y + d1*y^2 + e1 = 0.
score = (u, v) -> var(u) - var(v)
lift = (u, v) -> hcat(u, u .^ 2, v, v .^ 2)

# Sample n points uniformly from the unit square
x, y = rand(n), rand(n)

# Exhaustive search for the best-scoring subset of size k
best_idxs, comp_idxs = brute_force(x, y, k, score)

# Separate the selected points from the rest in the lifted space
lifted = lift(x, y)
A, B = lifted[best_idxs, :], lifted[comp_idxs, :]
a1, b1, c1, d1, e1 = find_hyperplane(A, B)

# For each x on the grid, the boundary is the quadratic a*y^2 + b*y + c = 0 in y
x_plot = range(extrema(x)...; length=10000)
a, b = d1, c1
c = a1 .* x_plot .+ b1 .* x_plot .^ 2 .+ e1
disc = b^2 .- (4 * a) .* c

# Keep only the grid values where the discriminant is clearly positive
idxs = disc .> 1e-2
disc = disc[idxs]
y_plot1 = (-b .- sqrt.(disc)) ./ (2 * a)
y_plot2 = (-b .+ sqrt.(disc)) ./ (2 * a)

# Draw both roots of the boundary
plot(x_plot[idxs], y_plot1; label=nothing, color=:red)
plot!(x_plot[idxs], y_plot2; label="Boundary", color=:red)

# Overlay the data: complement first, then the selected subset
scatter!(x[comp_idxs], y[comp_idxs]; label="Comp Set", legend=:topleft)
scatter!(x[best_idxs], y[best_idxs]; label="Best Set", marker=:circle)

# Clip the view to the bounding box of the data
xlims!(extrema(x)...)
ylims!(extrema(y)...)
     
