# Sampling Exploration
This notebook explores different algorithms for sampling from a constrained subregion of a simplex.  

## Imports

In [48]:
using LinearAlgebra
using Optim
using Polyhedra
using CDDLib
using Statistics
using Distributions
using GLPK
using JuMP
using MathOptInterface

include("../model_builder/model_builder.jl")
using .ModelBuilder

include("../model_builder/design_initializer.jl")
using .DesignInitializer

include("../optimization/optimality_criterion.jl")
using .OptimalityCriterion

include("./tensor_ops.jl")
import .TensorOps: squeeze



## Gibbs Sampler
The Gibbs sampler for the Dirichlet distribution works by sampling from a single component beta distribution conditioned on the other components.   

In [2]:
# A: 2n x n matrix
# b: 2n x 1 vector
A = [
    I(3);
    -I(3)
]

b = [0.4, 0.7, 1, -0.1, -0.2, 0]

6-element Vector{Float64}:
  0.4
  0.7
  1.0
 -0.1
 -0.2
  0.0

In [49]:
function compute_centroid(A, b)
    p = polyhedron(hrep(A, b), CDDLib.Library())
    verts = collect(points(vrep(p)))
    centroid = mean(verts, dims=1)
    return centroid[1]
end

function get_optimizer(A, b, fixed_indices, fixed_values)
    n = size(A, 2)
    model = Model(GLPK.Optimizer)
    @variable(model, x[1:n])
    @constraint(model, A*x .<= b)

    # Fix the values of the fixed indices
    for (i, val) in zip(fixed_indices, fixed_values)
        set_lower_bound(x[i], val)
        set_upper_bound(x[i], val)
    end

    return model
end

function optim(model)
    optimize!(model)
    if termination_status(model) == MOI.OPTIMAL
        x = model[:x]
        return value.(x)
    else
        return nothing
    end
end

function maximize_at_index(model, target_index)
    x = model[:x]
    set_objective_function(model, Max, x[target_index])
    return optim(model)
end

function minimize_at_index(model, target_index)
    x = model[:x]
    set_objective_function(model, Min, x[target_index])
    return optim(model)
end

minimize_at_index (generic function with 1 method)

In [50]:
m = get_optimizer(A, b, [2, 3], [cand[2], cand[3]])

A JuMP Model
Feasibility problem with:
Variables: 3
`AffExpr`-in-`MathOptInterface.LessThan{Float64}`: 6 constraints
`VariableRef`-in-`MathOptInterface.GreaterThan{Float64}`: 2 constraints
`VariableRef`-in-`MathOptInterface.LessThan{Float64}`: 2 constraints
Model mode: AUTOMATIC
CachingOptimizer state: EMPTY_OPTIMIZER
Solver name: GLPK
Names registered in the model: x

In [51]:
maximize_at_index(m, 1)

UndefVarError: UndefVarError: `Max` not defined

In [35]:
# Initialize
cent = compute_centroid(A, b)
cand = cent

# Iterate through the vector
# for i in axes(cent, 1)
    
# end

# Fix the current element and consider the sub-vector of the remaining elems
i = 1
subvec = vcat(cand[1:i-1], cand[i+1:end])
xi_star = 1 - sum(subvec)


2-element Vector{Float64}:
 0.44999999999999996
 0.5

In [10]:
dist = truncated(Beta(1, 1), .3, .6)
rand(dist, 10)

10-element Vector{Float64}:
 0.32279157066047887
 0.563001066739471
 0.44065402202756315
 0.3175684759392291
 0.36800795658459345
 0.37249660983465255
 0.5928533865087219
 0.5875394829875608
 0.4846212249720399
 0.46337527620424346

In [8]:
# Gibbs sampler for unconstrained simplex
# A is a matrix describing the upper and lower bounds for each component
# function gibbs_sample_constrained_simplex(A, b, N, K; n = 1)

m, n = size(A)

# Initialize with a value respecting all of the upper and lower bound constraints
# Since we only have upper and lower bounds we can take the mean of the two bounds
x = zeros(n)
dists = []
for i in 1:n
    # Get the positive indices and negative indices for upper/lower bounds
    pos_idx = findall(x -> x > 0, A[:, i])
    neg_idx = findall(x -> x < 0, A[:, i])

    upper, lower = b[pos_idx[1]], abs(b[neg_idx[1]])

    # Compute initial value in feasible region
    x[i] = (upper - lower) / 2

    # Build beta distribution for component
    dist = Beta(1, n - 1)
    trunc_dist = Truncated(dist, lower, upper)
    push!(dists, trunc_dist)
end

# Enforce sum-to-one constraint
x = x / sum(x)


3-element Vector{Float64}:
 0.16666666666666669
 0.27777777777777773
 0.5555555555555556

## Smallest Bounding Hypercube Rejection Sampling

In [10]:
function simple_simplex_sampler(N, K)
    a = ones(K)
    sampler = Dirichlet(a)
    samples = zeros(N, K)

    for i in axes(samples, 1)
        samples[i, :] .= rand(sampler)
    end

    return samples
end



# Rejection sampler
function rejection_sampler(n, K, A, b, sampler)
    X = sampler(n, K)
    satisfies = (x) -> all(A * x .<= b)

    while true
        # Get the design points in X that satisfy the constraints
        good_points = vec(mapslices(satisfies, X; dims=2))
        num_bad_points = sum(.!good_points)

        # If there are no bad points, return the good points
        if num_bad_points == 0
            return X
        end

        # Resample bad points
        X[.!good_points, :] .= sampler(num_bad_points, K)
    end
end

function get_simplex_constraints(n)
    # Non-negativity constraints
    A = [
        -1 * I(n);
        ones(1, n)
    ]
    b = zeros(n + 1)

    # Sum to one constraint
    b[end] = 1
    return A, b
end

# Given a set of linear constraints, return a dirichlet distribution for the smallest bounding simplex containing the feasible region
# function get_min_bounding_simplex(A, b)
#     p = polyhedron(vrep(A, b), CDDLib.Library())



get_simplex_constraints (generic function with 1 method)

In [9]:
A = [
    0 -1 0;
    0 0 1;
    5 4 0;
    -20 5 0;
]

b = [
    -1/10;
    3/5;
    39/10;
    -3;
]

4-element Vector{Float64}:
 -0.1
  0.6
  3.9
 -3.0

In [11]:
A_simplex, b_simplex = get_simplex_constraints(size(A, 2))
A = vcat(A, A_simplex)
b = vcat(b, b_simplex)

8-element Vector{Float64}:
 -0.1
  0.6
  3.9
 -3.0
  0.0
  0.0
  0.0
  1.0

In [13]:
p = polyhedron(hrep(A, b), CDDLib.Library())

Polyhedron CDDLib.Polyhedron{Float64}:
8-element iterator of HalfSpace{Float64, Vector{Float64}}:
 HalfSpace([0.0, -1.0, 0.0], -0.1)
 HalfSpace([0.0, 0.0, 1.0], 0.6)
 HalfSpace([5.0, 4.0, 0.0], 3.9)
 HalfSpace([-20.0, 5.0, 0.0], -3.0)
 HalfSpace([-1.0, 0.0, 0.0], 0.0)
 HalfSpace([0.0, -1.0, 0.0], 0.0)
 HalfSpace([0.0, 0.0, -1.0], 0.0)
 HalfSpace([1.0, 1.0, 1.0], 1.0)

In [50]:
verts = collect(points(vrep(p)))
centroid = mean(verts, dims=1)[1]

# Compute vector between centroid and simplex centroid
simplex_centroid = ones(size(A, 2)) / size(A, 2)
translation_vector = centroid - simplex_centroid

# Find the vertex having the maximum distance from the centroid
dists = [norm(centroid - v) for v in verts]
max_dist = max(dists)

# Scale the simplex to have the maximum distance from the centroid


8-element Vector{Float64}:
 0.40701850387912347
 0.34736733078975635
 0.373800163857642
 0.36875
 0.4510837643941533
 0.45384640849080216
 0.3754684573968898
 0.4012188461426008