In [1]:
using Random
using Statistics
using Combinatorics
using ProgressMeter
using Gurobi
using JuMP
using Suppressor
using LinearAlgebra
using Plots
using Distributions
using DataFrames
using CSV
using LaTeXStrings

In [2]:
"""
    brute_force(x, y, k, score; min_=false) -> (best_idxs, comp_idxs)

Exhaustively search all size-`k` subsets of `1:n` (`n = size(x, 1)`) and return the one
that maximizes `score(x[idxs], y[idxs])` (minimizes it when `min_ = true`), together with
its complement in `1:n`.

Comparisons are strict, so among tied subsets the first one produced by `combinations`
wins, and a subset whose score is `NaN` is never selected.

# Arguments
- `x`, `y`: data vectors of the same size.
- `k`: subset size; must be smaller than `n`.
- `score`: called as `score(x_subset, y_subset)`; its result is compared with `<` / `>`.

# Keywords
- `min_`: minimize instead of maximize the score.

# Throws
- `AssertionError` if `x` and `y` differ in size or `k >= n`.
- `ArgumentError` if no subset is ever selected, e.g. when every score is `NaN`
  (as with `var` on single-element subsets, i.e. `k == 1`).
"""
function brute_force(x::Vector{Float64}, y::Vector{Float64}, k::Int, score::Function;
                     min_::Bool=false)::Tuple{Vector{Int64},Vector{Int64}}
    @assert size(x)==size(y) "x and y must be the same shape"
    @assert k<size(x, 1) "k must be less than n"

    n = size(x, 1)
    best_val = min_ ? Inf : -Inf
    # Typed empty start value (instead of `nothing`) keeps `best_idxs` a single type.
    best_idxs = Int[]
    found = false

    # Try each subset of size k
    for idxs in combinations(1:n, k)
        val = score(x[idxs], y[idxs])
        if min_ ? val < best_val : val > best_val
            best_val = val
            best_idxs = idxs
            found = true
        end
    end

    # Without a selected subset there is no meaningful complement; fail with a clear
    # message instead of an opaque error from `setdiff`.
    found || throw(ArgumentError(
        "no size-$k subset was selected; `score` never beat $(best_val) (all NaN?)"))

    # Get complement of best set
    comp_idxs = setdiff(1:n, best_idxs)

    return best_idxs, comp_idxs
end;

In [3]:
"""
    get_triple(X, Y, idxs, lift)

Return the flattened null-space basis of the lifted, affinely augmented points `idxs`.

`lift(X, Y)` is assumed to return a matrix with one row per point, as the lifts in this
notebook do. For the selected rows `LX = lift(X, Y)[idxs, :]`, the augmented matrix is
`[LX ones(size(LX, 1))]`. Any vector `N` in its null space satisfies
`LX[i, :]' * N[1:end-1] + N[end] ≈ 0` for every selected row, i.e. it describes a
hypersurface in the lifted space passing through those points.

# Returns
`vec(nullspace(...))`. If the null space has more than one basis vector, their
columns are concatenated.
"""
function get_triple(X, Y, idxs, lift)
    LX = lift(X, Y)[idxs, :]
    # The ones column must match the number of *selected rows*, not the number of lifted
    # features; the two only coincide when length(idxs) == size(LX, 2).
    LX_aug = hcat(LX, ones(size(LX, 1)))
    return vec(nullspace(LX_aug))
end

get_triple (generic function with 1 method)

In [4]:
"""
    get_rect_hyperbola(N, x) -> (xs, y1, y2)

Solve the axis-aligned conic `a*x^2 + b*y^2 + c*x + d*y + e = 0`, with coefficients
`(a, b, c, d, e)` taken from the first five entries of `N`, for `y` at each abscissa in `x`.

With `q = a*x^2 + c*x + e` and `disc = d^2 - 4*b*q`, the two branches are
`y1 = (-d + sqrt(disc)) / (2b)` and `y2 = (-d - sqrt(disc)) / (2b)`. Abscissae with
`disc < 0` (no real `y`) are dropped, so `xs`, `y1` and `y2` have equal length.

When `b == 0` the equation is linear in `y`. Every abscissa is kept, and both branches
equal `-q / d` (still Inf/NaN if `d == 0` as well).
"""
function get_rect_hyperbola(N::AbstractVector{<:Real}, x::AbstractVector{<:Real})
    a, b, c, d, e = N
    q = @. a * x^2 + c * x + e

    if iszero(b)
        # Degenerate (linear in y): the quadratic formula would divide by zero.
        y = @. -q / d
        return collect(x), y, copy(y)
    end

    disc = @. d^2 - 4 * b * q
    # Keep only abscissae with real solutions; sqrt is then taken on reals directly.
    keep = disc .>= 0
    xs = x[keep]
    root = sqrt.(disc[keep])
    y1 = @. (-d + root) / (2 * b)
    y2 = @. (-d - root) / (2 * b)
    return xs, y1, y2
end;

In [5]:
"""
    get_pivots(L, best_idxs; eps=1e-10)

Find every "pivot": a set `P` of `d = size(L, 2)` rows of `L` whose affine hyperplane in
the lifted space separates the rows in `best_idxs` from the remaining rows.

For each `P`, `N` is a null-space basis of the augmented rows `[L[P, :] ones]`, so each
row `i ∈ P` satisfies `[L[i, :]; 1]' * N ≈ 0`. Over the rows not in `P`, the signed values
`[L[i, :]; 1]' * N` are tested. `P` is kept when they are `<= eps` on `best_idxs` and
`>= -eps` on the complement, or the other way round.

# Arguments
- `L`: `n × d` matrix with one (lifted) point per row.
- `best_idxs`: row indices forming one side of the split.

# Keywords
- `eps`: tolerance of the side test; `eps = 0` demands exact (weak) separation.

# Returns
A vector of `(P, vec(N))` tuples. If the rows in `P` are degenerate, `N` has several
columns; the side test must then hold for every column, and `vec` concatenates them.
"""
function get_pivots(L, best_idxs; eps=1e-10)
    n, d = size(L)
    L_aug = hcat(L, ones(n))
    comp_idxs = setdiff(1:n, best_idxs)
    # Concretely typed result instead of an untyped `[]` (which is a `Vector{Any}`).
    pivots = Tuple{Vector{Int},Vector{float(eltype(L_aug))}}[]
    for P in combinations(1:n, d)
        N = nullspace(L_aug[P, :])
        # Signed side of every non-pivot point with respect to the hyperplane(s) `N`.
        best_prod = L_aug[setdiff(best_idxs, P), :] * N
        comp_prod = L_aug[setdiff(comp_idxs, P), :] * N
        # `eps` lets points lying (numerically) on the boundary count for either side.
        best_below = all(<=(eps), best_prod) && all(>=(-eps), comp_prod)
        best_above = all(>=(-eps), best_prod) && all(<=(eps), comp_prod)
        if best_below || best_above
            push!(pivots, (P, vec(N)))
        end
    end
    return pivots
end

get_pivots (generic function with 1 method)

## Visualize the pivots and their boundary curves for a random dataset

In [None]:
n = 15
k = 8
lift = (x, y) -> hcat(x.^2, y.^2, x, y)

# Draw a fresh seed every run but report it, so an interesting configuration can be
# reproduced by replacing this line with e.g. `seed = 123`.
seed = rand(UInt128)
@info "Random seed" seed
Random.seed!(seed)

x = rand(n)
y = rand(n)

best_idxs, comp_idxs = brute_force(x, y, k, (x, y) -> var(x) - var(y))

p = scatter(x[best_idxs], y[best_idxs], label="I")
scatter!(p, x[comp_idxs], y[comp_idxs], label="O")

# Use `lift` itself so the pivots live in the same feature space (x^2, y^2, x, y) that
# `get_rect_hyperbola` assumes when it reads the coefficients (a, b, c, d, e).
pivots = get_pivots(lift(x, y), best_idxs)

# Draw both y-branches of each pivot's conic over the unit interval.
xgrid = collect(0:0.01:1)
for (_, N) in pivots
    xplot, y1, y2 = get_rect_hyperbola(N, xgrid)
    plot!(p, xplot, y1, color=:red, label=nothing)
    plot!(p, xplot, y2, color=:red, label=nothing)
end

xlims!(p, 0, 1)
ylims!(p, 0, 1)
p

## Check pivots: for random datasets, some separating pivot shares more than one point with the optimal subset

In [6]:
n = 15
k = 8

# Other (score, lift) pairs this check has been run with:
#   score = (x, y) -> cor(x, y);  lift = (x, y) -> hcat(x.^2, y.^2, x.*y, x, y)
#   score = (x, y) -> cov(x, y);  lift = (x, y) -> hcat(x.*y, x, y)
score = (x, y) -> -var(x) - var(y)
lift = (x, y) -> hcat(x.^2, y.^2, x, y)

# For many random datasets, find the optimal size-k subset by brute force. Then check
# that at least one exactly-separating pivot (eps = 0) has more than one of its defining
# points inside that subset.
@showprogress for trial in 1:10_000
    seed = rand(UInt128)
    Random.seed!(seed)

    x = rand(n)
    y = rand(n)

    best_idxs, comp_idxs = brute_force(x, y, k, score)

    pivots = get_pivots(lift(x, y), best_idxs, eps=0.0)

    # Number of each pivot's defining points that belong to the optimal subset.
    inters = [length(intersect(P, best_idxs)) for (P, _) in pivots]

    # `maximum` throws on an empty collection, so report "no pivots" explicitly. The
    # seed is included so that a failing instance can be reproduced.
    if isempty(inters) || maximum(inters) <= 1
        detail = isempty(inters) ? "no separating pivots found" :
                 "max overlap with the best set is $(maximum(inters))"
        error("pivot check failed on trial $trial (seed = $seed): $detail")
    end
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:01[39m
