## Overview
In order to compute the I-Criterion for designs in any design space, it is necessary to compute the region moments matrix, which is constant for a given feasible region and model expansion function. For a given model expansion function $$\mathbf f:\mathcal{X} \rightarrow \mathbb R^{p}$$ that expands a design point into a feature vector of the $p$ corresponding model terms and a given design space $\mathcal X\subseteq \mathbb R^K$, the region moments matrix is the $p \times p$ matrix $$\int_\mathcal{X} \mathbf f(\mathbf x)\mathbf f'(\mathbf x)\,\text{d}\mathbf x,$$ where $\mathbf f'(\mathbf x)$ denotes the transpose of $\mathbf f(\mathbf x)$. For constrained feasible regions defined by linear constraints in high dimensions, this integral can be difficult to compute analytically or symbolically, and numerical approximation is typically employed. This notebook includes a simple Monte Carlo integration technique along with a rejection sampling algorithm to generate large, uniformly distributed subsets of the design space used in the integral estimation.

## Imports

In [1]:
using LinearAlgebra
using Optim
using Polyhedra
using CDDLib
using Statistics
using Distributions

include("./model_builder/model_builder.jl")
using .ModelBuilder

include("./model_builder/design_initializer.jl")
using .DesignInitializer

include("./optimization/optimality_criterion.jl")
using .OptimalityCriterion

include("./utility/util.jl")
import .Util: squeeze

## Rejection Sampling and Region Moments Computation

In [2]:
# Sample points from the simplex using the Dirichlet distribution
"""
    simple_simplex_sampler(N, K)

Draw `N` points from a `K`-component Dirichlet distribution whose concentration
parameters are all one, and return them as an `N`×`K` matrix with one point per row.
"""
function simple_simplex_sampler(N, K)
    dirichlet = Dirichlet(ones(K))
    points = zeros(N, K)

    # One independent draw per row, written into the preallocated matrix.
    for row in eachrow(points)
        row .= rand(dirichlet)
    end

    return points
end

# Rejection sampler
"""
    rejection_sampler(n, K, A, b, sampler; max_rounds=nothing)

Return an `n`×`K` matrix of points drawn with `sampler` in which every row `x`
satisfies `A * x .<= b`.

`sampler(m, K)` must return a mutable `m`×`K` matrix of candidate points. Rows that
violate the constraints are redrawn until none remain.

# Keywords
- `max_rounds`: maximum number of resampling rounds. The default `nothing` means
  "no limit", in which case the call does not terminate if `sampler` can never
  produce a feasible point. When the limit is exceeded an `ErrorException` is thrown.
"""
function rejection_sampler(n, K, A, b, sampler; max_rounds=nothing)
    X = sampler(n, K)
    satisfies = (x) -> all(A * x .<= b)

    # Indices of rows that have not yet been verified as feasible.
    pending = collect(axes(X, 1))
    rounds = 0

    while true
        # Accepted rows are never modified again, so only the rows drawn in the
        # most recent round need to be checked.
        filter!(i -> !satisfies(X[i, :]), pending)
        isempty(pending) && return X

        rounds += 1
        if max_rounds !== nothing && rounds > max_rounds
            error(
                "rejection_sampler: $(length(pending)) of $n points still violate " *
                "the constraints after $max_rounds resampling rounds",
            )
        end

        # Redraw exactly the rejected rows (kept in ascending row order).
        X[pending, :] .= sampler(length(pending), K)
    end
end

"""
    get_simplex_constraints(n)

Build the inequality system `A * x .<= b` describing the region with
non-negative coordinates whose sum is at most one, for `x` of length `n`.

Returns the tuple `(A, b)`: the first `n` rows of `A` are the negated identity
(paired with zeros in `b`, i.e. `-x[i] <= 0`), and the last row is all ones
(paired with `1` in `b`, i.e. `sum(x) <= 1`). Note that the sum condition is
encoded as an inequality, not as an equality.
"""
function get_simplex_constraints(n)
    lhs = vcat(-1 * I(n), ones(1, n))
    rhs = vcat(zeros(n), 1.0)
    return lhs, rhs
end

"""
    compute_volume(A, b; affines=BitSet([]))

Return the volume of the polyhedron built from the H-representation
`hrep(A, b, affines)` using the CDDLib backend.

`affines` is forwarded unchanged as the third argument of `hrep`
(presumably the indices of rows to treat as equalities -- confirm against the
Polyhedra documentation).
"""
function compute_volume(A, b; affines=BitSet([]))
    halfspaces = hrep(A, b, affines)
    region = polyhedron(halfspaces, CDDLib.Library())
    return volume(region)
end

"""
    compute_outer_product_mean(X, f)

Return the average of the outer products `r * r'` taken over the rows `r` of the
matrix `f(X)`.

`f(X)` must be an `n`×`p` matrix; the result is `p`×`p`. The sum of row-wise outer
products equals the Gram matrix `f(X)' * f(X)`, so it is computed with a single
matrix product instead of allocating one `p`×`p` temporary per row.
"""
function compute_outer_product_mean(X, f)
    expanded = f(X)
    n = size(expanded, 1)
    return (expanded' * expanded) ./ n
end

"""
    compute_elem_mean(X, f)

Return the mean of `f(X)` along its first dimension, computed as
`sum(f(X), dims=1) ./ size(f(X), 1)`.

The result keeps a singleton first dimension. Arrays of any rank are accepted;
only the length of the first axis is used as the divisor.
"""
function compute_elem_mean(X, f)
    expanded = f(X)
    return sum(expanded, dims=1) ./ size(expanded, 1)
end

"""
    compute_mc_integral(X, f, A, b)

Monte Carlo estimate formed by multiplying the sample mean returned by
`compute_outer_product_mean(X, f)` with the volume returned by
`compute_volume(A, b)`.
"""
function compute_mc_integral(X, f, A, b)
    region_volume = compute_volume(A, b)
    return compute_outer_product_mean(X, f) * region_volume
end

"""
    mc_integrate_constrained_simplex(A, b, f; n=100_000)

Monte Carlo estimate of the outer-product integral of `f` for a mixture region.

`n` points are drawn with `simple_simplex_sampler` and filtered through
`rejection_sampler` so that each satisfies `A * x .<= b`. The volume factor is
computed from `A`, `b` stacked with the rows from `get_simplex_constraints`.

NOTE(review): the sampled points come from a Dirichlet distribution, while the
appended simplex rows describe `sum(x) <= 1` as an inequality -- confirm that the
volume of that region is the intended normalisation.
"""
function mc_integrate_constrained_simplex(A, b, f; n=100_000)
    K = size(A, 2)

    # Feasible sample: only the caller's constraints are checked here.
    samples = rejection_sampler(n, K, A, b, simple_simplex_sampler)

    # Full constraint system (caller's rows first, simplex rows after) for the volume.
    simplex_lhs, simplex_rhs = get_simplex_constraints(K)
    full_lhs = vcat(A, simplex_lhs)
    full_rhs = vcat(b, simplex_rhs)

    return compute_mc_integral(samples, f, full_lhs, full_rhs)
end

"""
    mc_integrate_constrained_hypercube(A, b, f; n=100_000)

Monte Carlo estimate of the outer-product integral of `f` for a non-mixture region.

Candidate points come from `DesignInitializer.init_design` (passed through
`squeeze`) and are filtered by `rejection_sampler` against `A * x .<= b`. The
volume factor is computed from `A` and `b` exactly as given; no bounding
constraints are appended here, so `A`, `b` presumably need to describe a bounded
region matching the sampler's range -- confirm against `init_design`.
"""
function mc_integrate_constrained_hypercube(A, b, f; n=100_000)
    draw_points(count, dims) = squeeze(DesignInitializer.init_design(count, dims))
    samples = rejection_sampler(n, size(A, 2), A, b, draw_points)
    return compute_mc_integral(samples, f, A, b)
end

"""
    mc_integrate(A, b, f; n=100_000, mixture=true)

Dispatch to `mc_integrate_constrained_simplex` when `mixture` is `true` and to
`mc_integrate_constrained_hypercube` otherwise, forwarding `A`, `b`, `f` and `n`.
"""
function mc_integrate(A, b, f; n=100_000, mixture=true)
    integrator = if mixture
        mc_integrate_constrained_simplex
    else
        mc_integrate_constrained_hypercube
    end
    return integrator(A, b, f; n=n)
end

# Thanks ChatGPT
"""
    format_matrix(matrix; precision=4, threshold=1e-4)

Return a copy of `matrix` prepared for display: entries whose absolute value is
below `threshold` become zero, and all other entries are rounded to `precision`
decimal digits.

Any real-valued `AbstractMatrix` is accepted (including lazy wrappers such as an
adjoint), and `threshold` may be any real number.
"""
function format_matrix(
    matrix::AbstractMatrix{<:Real}; precision::Integer=4, threshold::Real=1e-4
)
    # `zero(x)` keeps the element type of the input instead of forcing Float64.
    format_number = x -> abs(x) < threshold ? zero(x) : round(x, digits=precision)
    return map(format_number, matrix)
end

format_matrix (generic function with 1 method)

In [3]:
# Base model expansion taken from ModelBuilder.
f = ModelBuilder.quadratic_interaction
# Wrap `f` so that the columns of its output are returned in the order
# 1, 2, 4, 6, 3, 5.
# NOTE(review): the closure looks up the global `f` when it is called, and a later
# cell rebinds both `f` and `model_builder`; re-running cells out of order will
# change what this computes.
model_builder = (x) -> f(x)[:, [1, 2, 4, 6, 3, 5]]

#20 (generic function with 1 method)

### Hypercube Example

In [5]:
# Constraints in the form A * x .<= b for a two-factor design space.
# Row by row: -x1 <= 1, -x2 <= 1, x1 <= 1, x2 <= 1, i.e. the square [-1, 1]^2.
A = [
    -1 0;
    0 -1;
    1 0;
    0 1;
]

# Right-hand sides matching the rows of A.
b = [
    1;
    1;
    1;
    1;
]

4-element Vector{Int64}:
 1
 1
 1
 1

In [10]:
# Estimate the region moments matrix for the constraints above from 1,000,000
# samples, using the non-mixture (hypercube) code path.
mat_integral = mc_integrate(A, b, model_builder; n=1_000_000, mixture=false)

6×6 Matrix{Float64}:
  4.0          -0.000556312  -0.00231114  …   1.33433       1.33255
 -0.000556312   1.33433       0.00295631     -0.000518295  -0.00164794
 -0.00231114    0.00295631    1.33255        -3.73151e-5   -0.00233222
  0.00295631   -3.73151e-5   -0.00164794      0.00137761    0.00207064
  1.33433      -0.000518295  -3.73151e-5      0.800964      0.444516
  1.33255      -0.00164794   -0.00233222  …   0.444516      0.799174

In [8]:
# Display the estimate rounded to 4 digits with near-zero entries set to zero.
format_matrix(mat_integral)

6×6 Matrix{Float64}:
  4.0     -0.0034   0.001    0.0      1.3346   1.3336
 -0.0034   1.3346   0.0      0.001   -0.0028  -0.0013
  0.001    0.0      1.3336  -0.0013   0.001    0.0015
  0.0      0.001   -0.0013   0.445   -0.0003   0.0006
  1.3346  -0.0028   0.001   -0.0003   0.8011   0.445
  1.3336  -0.0013   0.0015   0.0006   0.445    0.7999

In [9]:
# Dividing the integral estimate by the region volume recovers the sample mean of
# the outer products (the integral was formed as mean * volume).
vol = compute_volume(A, b)
format_matrix(mat_integral ./ vol)

6×6 Matrix{Float64}:
  1.0     -0.0009   0.0003   0.0      0.3336   0.3334
 -0.0009   0.3336   0.0      0.0003  -0.0007  -0.0003
  0.0003   0.0      0.3334  -0.0003   0.0003   0.0004
  0.0      0.0003  -0.0003   0.1113   0.0      0.0001
  0.3336  -0.0007   0.0003   0.0      0.2003   0.1113
  0.3334  -0.0003   0.0004   0.0001   0.1113   0.2

### Constrained Simplex Example

In [11]:
# Additional linear constraints A * x .<= b on a three-component design point.
# Row by row: -x2 <= -1/10 (x2 >= 0.1), x3 <= 3/5, 5*x1 + 4*x2 <= 39/10,
# and -20*x1 + 5*x2 <= -3.
A = [
    0 -1 0;
    0 0 1;
    5 4 0;
    -20 5 0;
]

# Right-hand sides matching the rows of A.
b = [
    -1/10;
    3/5;
    39/10;
    -3;
]

4-element Vector{Float64}:
 -0.1
  0.6
  3.9
 -3.0

In [14]:
# Estimate the region moments matrix for the constrained mixture region from
# 1,000,000 samples, using the mixture (simplex) code path.
mat_integral = mc_integrate(A, b, model_builder; n=1_000_000, mixture=true)

6×6 Matrix{Float64}:
 0.0445833   0.0178404   0.0121779   0.014565    0.00765987   0.00393217
 0.0178404   0.00765987  0.00464327  0.0055373   0.00350854   0.00144064
 0.0121779   0.00464327  0.00393217  0.00360251  0.00188454   0.00144197
 0.014565    0.0055373   0.00360251  0.00542514  0.00226679   0.00104955
 0.00765987  0.00350854  0.00188454  0.00226679  0.00170058   0.000556067
 0.00393217  0.00144064  0.00144197  0.00104955  0.000556067  0.000579413

In [15]:
# Display the estimate rounded to 4 digits with near-zero entries set to zero.
format_matrix(mat_integral)

6×6 Matrix{Float64}:
 0.0446  0.0178  0.0122  0.0146  0.0077  0.0039
 0.0178  0.0077  0.0046  0.0055  0.0035  0.0014
 0.0122  0.0046  0.0039  0.0036  0.0019  0.0014
 0.0146  0.0055  0.0036  0.0054  0.0023  0.001
 0.0077  0.0035  0.0019  0.0023  0.0017  0.0006
 0.0039  0.0014  0.0014  0.001   0.0006  0.0006

In [45]:
# Append the simplex rows (x >= 0, sum(x) <= 1) to the user constraints and compute
# the volume of the resulting region.
# NOTE: this rebinds the globals `A` and `b`; re-running the cell stacks the simplex
# rows again on top of the already augmented system.
A_simplex, b_simplex = get_simplex_constraints(size(A, 2))
A = vcat(A, A_simplex)
b = vcat(b, b_simplex)
vol = compute_volume(A, b)

0.044583333333333336

### Integrating the Objective Function
Integrating the D-Criterion for $N=12, K=3$ mixture designs.

In [46]:
# Rebind `model_builder` to the plain quadratic-interaction expansion (replacing
# the column-reordering wrapper defined earlier) and compose it with the D-criterion.
model_builder = ModelBuilder.quadratic_interaction
obj = OptimalityCriterion.d_criterion
# `f` is rebound here: it now maps a design to obj(model_builder(design)).
f = obj ∘ model_builder
# Draw 12 * 1_000_000 feasible points (one per row, 3 columns) ...
X = rejection_sampler(12 * 1_000_000, 3, A, b, simple_simplex_sampler)
# ... and regroup them into a 1_000_000 x 12 x 3 array. Because reshape is
# column-major, slice X[i, j, :] is row i + (j - 1) * 1_000_000 of the flat sample,
# so every point keeps its three coordinates together.
X = reshape(X, (1_000_000, 12, 3))