# MCMC experiment: round 2

Mixture of Gaussians

In [None]:
using Gen
using PyPlot

┌ Info: Precompiling Gen [ea4f424c-a589-11e8-07c0-fd5c91b9da4a]
└ @ Base loading.jl:1242


## Define a model

In [None]:
# Generative model for a one-dimensional mixture of Gaussians.
#
# Arguments:
#   n_clusters -- number of mixture components
#   n_samples  -- number of data points to draw
#
# Traced addresses:
#   :means   => i  -- location of component i, drawn from normal(0, 10)
#   :spreads => i  -- spread of component i, drawn from gamma(1, 1)
#   :cluster => j  -- component label of data point j (uniform over components)
#   :z       => j  -- value of data point j
#
# Returns an (n_samples x 1) array holding the sampled data points.
@gen function gauss_mix(n_clusters::Int64, n_samples::Int64)

    # Per-component parameters. Comprehensions are used so the element type
    # follows whatever the traced calls return.
    mus = [@trace(normal(0.0, 10.0), :means => i) for i in 1:n_clusters]
    sigmas = [@trace(gamma(1.0, 1.0), :spreads => i) for i in 1:n_clusters]

    # Every component is equally likely a priori.
    weights = fill(1.0 / n_clusters, n_clusters)

    draws = zeros(n_samples, 1)
    for j in 1:n_samples
        # Pick a component, then draw the observation from it.
        label = @trace(categorical(weights), :cluster => j)
        draws[j] = @trace(normal(mus[label], sigmas[label]), :z => j)
    end

    return draws

end;

## Sanity check: simulate & visualize

In [None]:
# Size of the synthetic problem: number of mixture components and data points.
n_clusters = 3
n_samples = 500

# Run the model once without constraints to obtain a synthetic trace.
tr = simulate(gauss_mix, (n_clusters, n_samples));

In [None]:
"""
    plot_trace(tr)

Plot the data points, cluster means and cluster assignments stored in the
`gauss_mix` trace `tr` by forwarding them to `plot_choices`.

The number of clusters and data points is read from the trace's own arguments
(`gauss_mix` takes `(n_clusters, n_samples)`), so the plot is correct for any
trace and not only for one that matches the notebook's global settings.
"""
function plot_trace(tr)

    # The trace remembers the arguments the model was run with.
    n_clusters, n_samples = Gen.get_args(tr)

    zs = [tr[:z => j] for j in 1:n_samples]
    means = [tr[:means => i] for i in 1:n_clusters]
    assignments = [tr[:cluster => j] for j in 1:n_samples]

    plot_choices(zs, means, assignments)

end

"""
    plot_choices(zs, means, assignments)

Draw one histogram (20 bins) per cluster together with a black vertical line at
each cluster mean.

# Arguments
- `zs`: data points.
- `means`: one location per cluster; cluster `i` is the `i`-th entry.
- `assignments`: cluster label of each data point, compared against `i` with
  `==`, so integer-valued floating point labels work as well.

The vertical lines span `0` to `length(zs)`. Colours are cycled when there are
more clusters than entries in the palette.
"""
function plot_choices(zs, means, assignments)

    hist_colors = ["gray", "blue", "red", "orange", "yellow"]

    # Height of the mean markers: derived from the data instead of a global.
    marker_height = length(zs)

    for (i, center) in enumerate(means)
        plot([center; center], [0; marker_height], color="k")
        cluster_z = zs[assignments .== i]
        # mod1 wraps the index so a sixth cluster reuses the first colour
        # instead of raising a BoundsError.
        hist(cluster_z, 20, color=hist_colors[mod1(i, length(hist_colors))])
    end

end;

In [None]:
# Plot the simulated trace, restrict the y-axis to [0, 100] and render the figure.
plot_trace(tr)
ylim(0, 100)
show()

## Inference: block resimulation

In [None]:
"""
    block_resimulation_update(tr, n_clusters, n_data)

Perform one sweep of single-address Metropolis-Hastings moves (`Gen.mh`) over
the latent choices of trace `tr` and return the updated trace.

For every cluster `k` the addresses `:means => k` and `:spreads => k` are
updated in that order; afterwards `:cluster => j` is updated for every data
point `j` in `1:n_data`. The acceptance flag returned by `Gen.mh` is discarded.
"""
function block_resimulation_update(tr, n_clusters, n_data)

    # Block 1: cluster locations and spreads, one address at a time.
    for k in 1:n_clusters
        for addr in (:means => k, :spreads => k)
            tr, _ = Gen.mh(tr, select(addr))
        end
    end

    # Block 2: the cluster label of each data point.
    for point in 1:n_data
        tr, _ = Gen.mh(tr, select(:cluster => point))
    end

    return tr
end

In [None]:
"""
    block_resim_sample(n_clusters, data; n_sweeps=1)

Produce one trace of `gauss_mix` with the `:z => i` choices constrained to
`data`.

A fresh trace is initialised with `Gen.generate` on every call and then
refined with `n_sweeps` calls to `block_resimulation_update`.

# Arguments
- `n_clusters`: number of mixture components passed to `gauss_mix`.
- `data`: iterable of observations; element `i` is bound to `:z => i`.

# Keywords
- `n_sweeps`: number of full update sweeps applied to the fresh trace. The
  default of `1` reproduces the previous hard-coded behaviour.
"""
function block_resim_sample(n_clusters, data; n_sweeps::Integer=1)

    # Constrain every observed data point.
    observations = Gen.choicemap()
    for (i, z) in enumerate(data)
        observations[:z => i] = z
    end

    # Initialise the state; the importance weight is not needed here.
    (tr, _) = Gen.generate(gauss_mix, (n_clusters, length(data)), observations)

    # Every call starts a new chain, so these sweeps determine how far that
    # chain moves away from its initial state before it is returned.
    for _ in 1:n_sweeps
        tr = block_resimulation_update(tr, n_clusters, length(data))
    end

    return tr
end

In [None]:
"""
    block_resim_inference(n_clusters, data, n_samples)

Call `block_resim_sample` `n_samples` times and summarise the resulting traces.

Returns the tuple `(centers, assignments)`:
- `centers[i, k]` is the value at `:means => k` in the `i`-th sampled trace;
- `assignments[p, c]` counts how many sampled traces put data point `p` into
  cluster `c`.

Progress is printed as one line per sample.
"""
function block_resim_inference(n_clusters, data, n_samples)

    n_points = length(data)
    centers = zeros((n_samples, n_clusters))
    assignments = zeros((n_points, n_clusters))

    for draw in 1:n_samples
        println("Sample ", draw)
        sampled_tr = block_resim_sample(n_clusters, data)

        # Record the cluster locations of this sample.
        for c in 1:n_clusters
            centers[draw, c] = sampled_tr[:means => c]
        end

        # Tally which cluster each data point ended up in.
        for p in 1:n_points
            label = sampled_tr[:cluster => p]
            assignments[p, label] += 1
        end
    end

    return centers, assignments
end;

In [None]:
# Pull the simulated data points out of the trace and run block-resimulation
# inference on them: 3 clusters, 20 samples.
zs = [tr[:z => i] for i=1:n_samples];
centers, assignments = block_resim_inference(3, zs, 20)

In [None]:
# Display the sampled cluster means: one row per sample, one column per cluster.
centers

## Inference: k-means guided resampling

Our straightforward approach had some issues -- especially with identifiability.

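During resampling, there was little reason for a given cluster index to land on the same mode from one sample to the next, because every sample starts from a freshly initialized trace.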

In [None]:
using Statistics

"""
    mean_update(data, cluster_assignments, k)

Return a vector with the mean of the entries of `data` assigned to each of the
clusters `1:k`, where `cluster_assignments[j]` is the label of `data[j]`.

The mean of a cluster without members is whatever `mean` returns for an empty
selection (`NaN` for floating point data).
"""
function mean_update(data, cluster_assignments, k)
    return map(1:k) do label
        members = data[cluster_assignments .== label]
        mean(members)
    end
end

"""
    cluster_update(data, means)

Assign every entry of the vector `data` to the closest entry of the vector
`means` (absolute distance) and return the vector of winning indices into
`means`. Ties go to the lowest index.
"""
function cluster_update(data, means)

    # Row j, column c holds the distance between data[j] and means[c].
    dists = abs.(data .- transpose(means))

    # One CartesianIndex per row; its second component is the column, i.e. the
    # index of the nearest mean.
    nearest = vec(argmin(dists, dims=2))

    return map(ci -> ci[2], nearest)
end


"""
    kmeans(data, k, spread=1.0, max_iter=1000)

Cluster the one-dimensional `data` into `k` groups with Lloyd's algorithm.

# Arguments
- `data`: vector of observations.
- `k`: number of clusters; must be at least 1.
- `spread`: currently unused; kept so existing calls keep working.
- `max_iter`: upper bound on the number of mean updates.

# Returns
A tuple `(means, assignments)`. `means` is sorted in increasing order (to
reduce label-switching between runs) and `assignments[j]` is the integer index
into the sorted `means` of the cluster that `data[j]` belongs to.

# Notes
- Initial means are drawn from `data` with `rand`, so results are random.
- A cluster that loses all of its members keeps its previous mean. Taking the
  mean of an empty selection would give `NaN`, which would corrupt the
  assignments and make the convergence test impossible to satisfy.
- The returned assignments are always computed from the returned means.

# Throws
- `ArgumentError` if `k < 1`.
"""
function kmeans(data::Array{Float64,1}, k::Int64, spread=1.0, max_iter::Int64=1000)

    k >= 1 || throw(ArgumentError("k must be at least 1, got $k"))

    # Random initialisation: pick k data points as starting means.
    means = rand(data, k)
    cluster_assignments = cluster_update(data, means)

    for _ in 1:max_iter

        updated = mean_update(data, cluster_assignments, k)

        # Only accept the new mean of clusters that actually have members.
        new_means = copy(means)
        for j in 1:k
            if j in cluster_assignments
                new_means[j] = updated[j]
            end
        end

        # Fixed point reached: the assignments already match `means`.
        if new_means == means
            break
        end

        means = new_means
        cluster_assignments = cluster_update(data, means)

    end

    # Sort the means in order to ameliorate identifiability issues, and
    # relabel the assignments so they index into the sorted means.
    srt_inds = sortperm(means)
    relabel = invperm(srt_inds)

    return means[srt_inds], relabel[cluster_assignments]
end

In [None]:

# Run k-means on the simulated data with 5 clusters (spread 10.0, at most 1000
# iterations) and plot the result. Note that this rebinds the global
# `assignments` set by the block-resimulation cell.
means, assignments = kmeans(zs, 5, 10.0, 1000)

plot_choices(zs, means, assignments)
ylim(0.0, 30.0)
show()

In [None]:
# NOTE(review): `generate` is called without arguments here, whereas the call
# in block_resim_sample passes a model, its arguments and constraints. This
# looks like an unfinished cell -- confirm and complete or remove.
generate()