## Imports

In [4]:
# Custom module for building model matrices
include("model_builder.jl")
using .ModelBuilder

include("design_initializer.jl")
using .DesignInitializer

using LinearAlgebra



In [16]:
# Build the default model and apply it to a random 2-element input.
model_builder = ModelBuilder.create()
rand(2) |> model_builder

2-element Vector{Float64}:
 1.0
 0.11537119699121912

## Initialization

In [17]:
# Fill an nxNxK matrix with values sampled from a uniform dist on [lower, upper]
function init_design(N, K; n=1, lower=-1, upper=1)
    lower .+ rand(n, N, K) .* (upper - lower)
end

# Fill an nxNxK matrix with random values ensuring each row sums to 1
function init_mixture_design(N, K; n=1)
    designs = rand(n, N, K)
    designs ./= sum(designs, dims=3)
    designs
end

init_mixture_design (generic function with 1 method)

In [76]:
"""
    fill_invalid!(X, model_builder, initializer; max_iter = 1000)

Redraw, in place, every design in `X` whose information matrix is numerically
singular. Repeats until all designs are valid or `max_iter` rounds have passed.

# Arguments
- `X`: `n × N × K` array; `X[i, :, :]` is design `i`.
- `model_builder`: callable applied to `X`. Dims 2-3 of its result are treated as
  the model matrix of each design. Assumed to return an `n × N × p` array —
  confirm against `ModelBuilder`.
- `initializer`: called as `initializer(N, K, n = m)`. Expected to return an
  `m × N × K` array of fresh designs.

# Keywords
- `max_iter`: maximum number of redraw rounds before giving up.

# Returns
`X`, after mutation.

# Throws
- `ArgumentError` if `max_iter` is negative.
- `ErrorException` if singular designs remain after `max_iter` rounds. This
  happens, e.g., when the model has more parameters than `N`, so every design is
  singular. The previous recursive version overflowed the stack in that case.
"""
function fill_invalid!(X, model_builder, initializer; max_iter::Integer = 1000)
    max_iter >= 0 || throw(ArgumentError("max_iter must be non-negative, got $max_iter"))
    _, N, K = size(X)

    invalid = _singular_designs(model_builder(X))
    rounds = 0
    while !isempty(invalid)
        if rounds >= max_iter
            error("fill_invalid!: $(length(invalid)) design(s) still singular after " *
                  "$max_iter redraw round(s); the model may not be estimable with N = $N")
        end
        # Replace only the invalid designs, then re-check the whole batch.
        X[invalid, :, :] = initializer(N, K, n = length(invalid))
        rounds += 1
        invalid = _singular_designs(model_builder(X))
    end
    return X
end

# Return the indices `i` for which `det(M[i, :, :]' * M[i, :, :]) < eps()`, i.e.
# designs whose information matrix is numerically singular. `vec` flattens the
# n×1×1 `mapslices` result into a length-n vector for every n, including n = 1.
function _singular_designs(M)
    return findall(vec(mapslices(x -> det(x'x) < eps(), M, dims = [2, 3])))
end

"""
    init_filtered_design(N, K, model_builder; n = 1, initializer = init_design)

Draw `n` candidate designs with `initializer(N, K, n = n)`. Then pass them to
`fill_invalid!` with `model_builder` and `initializer`, which redraws any design
whose model matrix yields a numerically singular information matrix.

Returns the array produced by `initializer`, mutated in place. It is expected to
be `n × N × K`.
"""
function init_filtered_design(N, K, model_builder; n = 1, initializer = init_design)
    candidates = initializer(N, K; n = n)
    # Resample singular designs in place; the return value is not needed.
    fill_invalid!(candidates, model_builder, initializer)
    return candidates
end

init_filtered_design (generic function with 1 method)

In [77]:
# Draw n = 2 mixture designs with N = 3, K = 2, resampling any whose
# information matrix is numerically singular.
model_builder = ModelBuilder.create()
init_filtered_design(3, 2, model_builder; n = 2, initializer = init_mixture_design)

2×3×2 Array{Float64, 3}:
[:, :, 1] =
 0.54212   0.839057  0.241001
 0.793928  0.252795  0.92137

[:, :, 2] =
 0.45788   0.160943  0.758999
 0.206072  0.747205  0.0786297