# Dirichlet Process Clustering

In [1]:
using Gen
using Statistics
using PyPlot

We'll use a Cauchy distribution as the base distribution of our Dirichlet process (DP).

In [35]:
"""
    Cauchy

Gen distribution over `Float64` values with arguments `(loc, scale)`.
Use the singleton `cauchy` to sample (`cauchy(loc, scale)`) or inside `@trace`.
"""
struct Cauchy <: Gen.Distribution{Float64} end
const cauchy = Cauchy()

# Inverse-CDF sampling: tan(pi * (u - 0.5)) with u ~ Uniform(0, 1) is a standard
# Cauchy draw, which is then shifted by `loc` and stretched by `scale`.
Gen.random(::Cauchy, loc, scale) = loc + scale * tan(pi * (rand() - 0.5))

# Log density of the Cauchy distribution at `x`.
Gen.logpdf(::Cauchy, x, loc, scale) = -log(pi * scale * (1.0 + ((x - loc) / scale)^2))

"""
    Gen.logpdf_grad(::Cauchy, x, loc, scale)

Return the partial derivatives of `Gen.logpdf` with respect to `x`, `loc` and
`scale`, in that order.
"""
function Gen.logpdf_grad(::Cauchy, x, loc, scale)
    delta = x - loc
    denom = scale^2 + delta^2
    dx = -2 * delta / denom
    # The density depends on x and loc only through (x - loc), so d/dloc = -d/dx.
    dloc = -dx
    dscale = (delta^2 - scale^2) / (scale * denom)
    return (dx, dloc, dscale)
end

# The log density is differentiable in the output and in both arguments.
Gen.has_output_grad(::Cauchy) = true
Gen.has_argument_grads(::Cauchy) = (true, true)

# Calling the distribution object draws a sample, e.g. `cauchy(0.0, 1.0)`.
(::Cauchy)(loc, scale) = Gen.random(Cauchy(), loc, scale);

Build the Dirichlet Process model:

In [184]:
"""
    compute_cluster_probs(counts, alpha, n)

Return a new vector of assignment probabilities: one entry `counts[k] / (alpha + n)`
per existing cluster, followed by a final entry `alpha / (alpha + n)` for opening
a new cluster. `counts` itself is left untouched.
"""
function compute_cluster_probs(counts, alpha, n)
    total = alpha + n
    # Appending alpha to the counts lets a single broadcast normalise both the
    # existing-cluster weights and the new-cluster weight.
    return vcat(counts, alpha) ./ total
end
    
"""
    update_counts!(counts, chosen_cluster)

Record one more member for `chosen_cluster`, modifying `counts` in place and
returning it. An index past the end of `counts` opens a new cluster with a
count of 1; otherwise the existing entry is incremented.
"""
function update_counts!(counts, chosen_cluster)
    opens_new_cluster = chosen_cluster > length(counts)
    if opens_new_cluster
        push!(counts, 1)
        return counts
    end
    counts[chosen_cluster] += 1
    return counts
end

# Unfold kernel: one Chinese-restaurant-process assignment step.
#   n      - step index supplied by Unfold; with the initial state `[1]` it equals
#            the number of points already assigned, i.e. sum(counts)
#   counts - per-cluster member counts from the previous step
#   alpha  - concentration parameter (weight given to opening a new cluster)
# Traces the chosen cluster at address :cl_idx and returns the updated counts.
@gen (static) function sample_point(n::Int64, counts::Array{Int,1}, alpha::Float64)
    probs = compute_cluster_probs(counts, alpha, n)
    cluster_idx = @trace(Gen.categorical(probs), :cl_idx)
    # Update a copy rather than the incoming state: Unfold keeps every step's
    # return value, so mutating `counts` in place would make all stored states
    # alias the same array.
    new_counts = update_counts!(copy(counts), cluster_idx)
    return new_counts
end

# Applies `sample_point` repeatedly; called as dp_dist(num_steps, init_counts, alpha).
dp_dist = Gen.Unfold(sample_point)

# Counts after the last step, or the initial counts when no step was run (M == 1).
final_dp_counts(states, init_counts) = isempty(states) ? init_counts : last(states)

# Simulate cluster assignments for M points and return the fraction of points
# in each cluster. The first point always starts cluster 1, so only M - 1 steps
# are sampled. `N` is accepted but not used by this function.
@gen (static) function dp_cluster(M::Int64, N::Int64, alpha::Float64)
    dp_states = @trace(dp_dist(M-1, [1], alpha), :dp)
    # The final state holds the counts for all M points; earlier entries are
    # the intermediate counts after each step.
    dp_counts = final_dp_counts(dp_states, [1])
    dp_probs = dp_counts ./ sum(dp_counts)
    return dp_probs
end

getfield(Main, Symbol("##StaticGenFunction_dp_cluster#727"))(Dict{Symbol,Any}(), Dict{Symbol,Any}())

In [185]:
# Load the code generated for the `@gen (static)` functions defined above.
# NOTE(review): presumably required by this Gen release before static functions
# can be run — confirm against the Gen version in use.
Gen.load_generated_functions()

In [191]:
# Simulate one trace of dp_cluster with M = 1000, N = 0, alpha = 1.0 and show
# its return value (the per-cluster proportions).
tr = Gen.simulate(dp_cluster, (1000, 0, 1.0));
get_retval(tr)

6-element Array{Float64,1}:
 0.962
 0.028
 0.005
 0.003
 0.001
 0.001