In [2]:
using Random
using LinearAlgebra: I
using Distributions

In [5]:
"""
Hyperparameters for the object-aware HDP+,infinity generative model.

m0      : prior mean for feature vectors (d-dimensional)
k_clu   : shrinkage from cluster means mu_k toward m0
k_obj   : shrinkage from object means phi_i toward their cluster mean mu_k
k_per   : shrinkage from percepts y_io toward their object mean phi_i
a0,b0   : Inverse-Gamma hyperparameters for sigma^2 (per category)
"""
struct ObjectAwareHDPHyperparams
    # Prior mean for feature vectors (d-dimensional).
    m0::Vector{Float64}

    # Shrinkage strengths, one per level of the hierarchy:
    # cluster means toward m0,
    k_clu::Float64
    # object means toward their cluster mean,
    k_obj::Float64
    # percepts toward their object mean.
    k_per::Float64

    # Inverse-Gamma hyperparameters for the per-category variance sigma^2.
    a0::Float64
    b0::Float64
end

"Data for one object: its cluster label and its percepts."
struct ObjectData
    # Index of the cluster this object is assigned to.
    z::Int
    # Latent object mean.
    phi::Vector{Float64}
    # Percepts y_o in R^d for o = 1..O_i, stored one vector per percept.
    percepts::Vector{Vector{Float64}}
end

"All data for a single category."
struct CategorySample
    # Variance sigma^2 of this category.
    sigma2::Float64
    # Cluster means mu_k, one vector per cluster.
    mus::Vector{Vector{Float64}}
    # Objects i = 1..I belonging to this category.
    objects::Vector{ObjectData}
end

"Full dataset across multiple categories."
struct Dataset
    # One CategorySample per category (length J).
    categories::Vector{CategorySample}
end

Dataset

In [6]:
"""
    sample_crp_cluster!(n_k, alpha; rng)

Given current cluster counts n_k (length K), sample a cluster index k for a new object
using a CRP(alpha). If k == K+1, this means "start a new cluster".
"""
function sample_crp_cluster!(n_k::Vector{Int}, alpha::Real; rng=Random.default_rng())
    # Draws a cluster index from a Chinese restaurant process and updates the
    # counts `n_k` in place. With K = length(n_k) and total = sum(n_k) + alpha,
    # existing cluster j is chosen with probability n_k[j] / total and a new
    # cluster (index K + 1) with probability alpha / total.
    #
    # Returns the sampled index k. When k == K + 1 a new entry has been
    # appended to `n_k`, so after the call n_k[k] >= 1 always holds.
    #
    # Throws ArgumentError when `alpha` is negative or non-finite, or when
    # sum(n_k) + alpha is not positive (e.g. no clusters and alpha == 0).
    (isfinite(alpha) && alpha >= 0) ||
        throw(ArgumentError("alpha must be finite and non-negative, got $alpha"))

    K = length(n_k)
    total = sum(n_k) + alpha
    total > 0 ||
        throw(ArgumentError("sum(n_k) + alpha must be positive, got $total"))

    # Inverse-CDF draw on the unnormalised weights (n_1, ..., n_K, alpha).
    # Working with the raw counts avoids building a weight vector and needs no
    # special case for K == 0: the loop is skipped and the new-cluster option,
    # which then carries all the mass, is selected.
    u = rand(rng) * total
    k = K + 1          # default: open a new cluster
    acc = 0
    for (j, n) in enumerate(n_k)
        acc += n
        if u < acc
            k = j
            break
        end
    end

    if k == K + 1
        push!(n_k, 0)  # initialize count for the new cluster
    end
    n_k[k] += 1
    return k
end


sample_crp_cluster!

In [None]:
"""
    sample_category(I, O, d, alpha, hyper; rng=Random.default_rng())

Sample one category under the object-aware HDP+,∞ generative model.

Arguments
--------
I     : number of objects in the category
O     : number of percepts per object (fixed for simplicity)
d     : feature dimension
alpha : CRP concentration (controls number of clusters)
hyper : ObjectAwareHDPHyperparams

Returns
-------
CategorySample with sigma^2, cluster means mu_k, and per-object percept data.
"""
function sample_category(I::Int, O::Int, d::Int,
                         alpha::Float64,
                         hyper::ObjectAwareHDPHyperparams;
                         rng = Random.default_rng())
    # Samples one category: a variance, CRP cluster assignments for `I`
    # objects, one mean per cluster, one mean per object, and `O` percepts per
    # object, all in R^d. Returns a CategorySample.
    #
    # NOTE: the parameter `I` (number of objects) shadows `LinearAlgebra.I`
    # inside this function, so `I(d)` cannot be used to build identity
    # covariances here. Isotropic Gaussians N(mean, v * I_d) are therefore
    # drawn directly as `mean .+ sqrt(v) .* randn(rng, d)`.

    # Validate up front so failures carry a clear message.
    I >= 0 || throw(ArgumentError("I must be non-negative, got $I"))
    O >= 0 || throw(ArgumentError("O must be non-negative, got $O"))
    d >= 0 || throw(ArgumentError("d must be non-negative, got $d"))
    length(hyper.m0) == d || throw(DimensionMismatch(
        "hyper.m0 has length $(length(hyper.m0)) but d = $d"))
    (hyper.k_clu > 0 && hyper.k_obj > 0 && hyper.k_per > 0) ||
        throw(ArgumentError("k_clu, k_obj and k_per must all be positive"))

    # 1. Category-specific variance
    sigma2 = rand(rng, InverseGamma(hyper.a0, hyper.b0))

    # Standard deviations of the three isotropic Gaussians; they depend only
    # on sigma2 and the hyperparameters, so compute them once.
    sd_clu = sqrt(sigma2 / hyper.k_clu)
    sd_obj = sqrt(sigma2 / hyper.k_obj)
    sd_per = sqrt(sigma2 / hyper.k_per)

    # 2. Storage for cluster means and counts
    mus = Vector{Vector{Float64}}()  # mu_k
    n_k = Int[]                      # counts per cluster

    # 3. Sample objects
    objects = Vector{ObjectData}(undef, I)

    for i in 1:I
        # 3a. CRP: choose cluster for object i
        k = sample_crp_cluster!(n_k, alpha; rng=rng)

        # A brand-new cluster has no mean yet: mu_k ~ N(m0, (sigma2 / k_clu) I_d)
        if k > length(mus)
            push!(mus, hyper.m0 .+ sd_clu .* randn(rng, d))
        end
        mu_k = mus[k]

        # 3b. Object-level mean: phi_i ~ N(mu_k, (sigma2 / k_obj) I_d)
        phi_i = mu_k .+ sd_obj .* randn(rng, d)

        # 3c. Percepts: y_io ~ N(phi_i, (sigma2 / k_per) I_d), o = 1..O
        ys_i = Vector{Float64}[phi_i .+ sd_per .* randn(rng, d) for _ in 1:O]

        objects[i] = ObjectData(k, phi_i, ys_i)
    end

    return CategorySample(sigma2, mus, objects)
end


Pseudocode for constructing the training/test sets

Within each category:

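10 objects & 50 percepts in training
- Skewed: percepts per object follow [18, 9, 6, 5, 4, 3, 2, 2, 1] (note: only 9 counts are listed, summing to 50; the count for the 10th object is not specified)
- Uniform: 5 percepts per object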

10 objects & 50 percepts in test
- 5 percepts per object

Use sample_category() to sample 10 objects per category, each with 18 percepts.
Construct the sets such that:
- Each object takes a turn as the most dominant one
- Each object appears in both training and test
- If an object is not the most dominant one in training, or is part of the test set, randomly subsample the number of percepts it needs

Get a permutation of object ids 1-20. The first 10 become training (with frequencies assigned in this order) and the last 10 become test.