In [None]:
using Gen
using GaussianMixtures
using PyPlot
using Logging

In [None]:
# Globally mute low-level log records before fitting any GaussianMixtures models.
# Logging.disable_logging(level) suppresses every message whose level is <= `level`.
# LogLevel(0) is Info, so the default setting silences @info as well as the
# GaussianMixtures messages that (per the original note) are logged at LogLevel(-1).
# NOTE(review): with debug_gm_print = true the threshold becomes LogLevel(-1), which
# still suppresses level -1 messages — confirm this really re-enables GM output.
# LL_idx is also passed as `loglevel=` to GaussianMixtures.GMM further down.
debug_gm_print = false
if debug_gm_print
    LL_idx = -1
else
    LL_idx = 0
end
Logging.disable_logging(LogLevel(LL_idx))

In [None]:
# Synthetic data: n draws from a known 3-component 2-D Gaussian mixture.
# Layout convention used throughout this notebook:
#   the data index (i) is always the first (row) index of X;
#   the feature dimension index (k) is always the second (column) index.
n = 1000
features = 2

# Ground-truth component parameters. The unit covariance is deliberately NOT called
# `identity`: that name would shadow Base.identity in Main (and assigning it fails
# outright if Base.identity has already been used in this session).
cov0 = [1.0 0.0;
        0.0 1.0]
cov1 = [1.0 0.1;
        0.1 3.0]
covs = [cov0, cov1, 0.2 * cov1]
mus = [-5 * ones(features), [0.0, 3.0], [-1.0, 0.0]]
mix_weights = [0.4, 0.2, 0.4]   # component probabilities (sum to 1)

X = zeros(n, features)
for i in 1:n
    # Choose a component, then draw one point from that component's Gaussian.
    idx = Gen.categorical(mix_weights)
    X[i, :] = mvnormal(mus[idx], covs[idx])
end

In [None]:
# Full-covariance GMM with nclusters = 10 components, initialised with :kmeansdet and
# rng_seed=3; its logging level follows LL_idx from the logging cell.
nclusters = 10
sgmm_det_2 = GaussianMixtures.GMM(
    nclusters, X;
    method=:kmeansdet,
    kind=:full,
    nInit=10, nIter=10, nFinal=10,
    rng_seed=3,
    loglevel=LL_idx,
)

In [None]:
# Timing experiment (10 repetitions): allocate an nclusters x 2 x 2 array, then fill
# slice x[c, :, :] with the covariance matrix recovered from each stored component.
for rep in 1:10
    @time x = zeros(nclusters, 2, 2)
    @time for c in 1:nclusters
        x[c, :, :] = GaussianMixtures.covar(sgmm_det_2.Σ[c])
    end
    println("~~~")
end


In [None]:
# Timing experiment (10 repetitions): build the same nclusters x 2 x 2 layout in one
# expression — stack covars(...) along dim 3, then move the component axis to the front.
for rep in 1:10
    @time x = permutedims(cat(GaussianMixtures.covars(sgmm_det_2)...; dims=3), [3, 1, 2])
end

In [None]:
# 3-component full-covariance GMM initialised with :kmeans (no rng_seed given);
# the trailing semicolon suppresses the cell's display output.
sgmm = GaussianMixtures.GMM(
    3, X;
    method=:kmeans,
    kind=:full,
    nInit=10, nIter=10, nFinal=10,
);

In [None]:
# Raw data scatter: first feature on the x-axis, second feature on the y-axis.
let xs = X[:, 1], ys = X[:, 2]
    scatter(xs, ys)
end

In [None]:
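# Reference only: field layout of GaussianMixtures.GMM, written in pre-1.0 Julia syntax
# (`type`, `Union(...)`), so it will not parse as-is.
# type GMM
#     n::Int                         # number of Gaussians
#     d::Int                         # dimension of each Gaussian
#     w::Vector                      # weights: n
#     μ::Array                       # means: n x d
#     Σ::Union(Array, Vector{Array}) # diagonal covariances n x d, or a Vector of n d x d full-covariance entries
#     hist::Array{History}           # history of this GMM
# end
# NOTE: for kind=:full the entries of Σ are stored in a transformed form (per the
# GaussianMixtures docs, the inverse Cholesky factor — confirm against the installed
# version), not as plain covariance matrices. That is why the cells here convert them
# with GaussianMixtures.covar / covars before sampling.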

In [None]:
# Draw n points from the fitted 3-component model `sgmm` and overlay them on the data.
Y = zeros(n, 2)
for row in 1:n
    k = Gen.categorical(sgmm.w)                # component chosen by its mixture weight
    Σk = GaussianMixtures.covar(sgmm.Σ[k])     # usable covariance matrix for component k
    Y[row, :] = mvnormal(sgmm.μ[k, :], Σk)
end
scatter(X[:, 1], X[:, 2])
scatter(Y[:, 1], Y[:, 2])
legend(["Actual data", "Sample from posterior"])

In [None]:
# 3-component full-covariance GMM initialised with :kmeansdet and rng_seed=1.
sgmm_det = GaussianMixtures.GMM(
    3, X;
    method=:kmeansdet,
    kind=:full,
    nInit=10, nIter=10, nFinal=10,
    rng_seed=1,
)

In [None]:
# Same check for the seeded fit `sgmm_det`: overlay n model samples on the data,
# drawn semi-transparently (alpha=0.1) so dense regions remain visible.
Y = zeros(n, 2)
for row in axes(Y, 1)
    comp = Gen.categorical(sgmm_det.w)
    mean_c = sgmm_det.μ[comp, :]
    Y[row, :] = mvnormal(mean_c, GaussianMixtures.covar(sgmm_det.Σ[comp]))
end
scatter(X[:, 1], X[:, 2], alpha=0.1)
scatter(Y[:, 1], Y[:, 2], alpha=0.1)
legend(["Actual data", "Sample from posterior"])

In [None]:
# Re-fit `sgmm_det` with exactly the same settings (:kmeansdet, rng_seed=1).
sgmm_det = GaussianMixtures.GMM(
    3, X;
    method=:kmeansdet,
    kind=:full,
    nInit=10, nIter=10, nFinal=10,
    rng_seed=1,
)

In [None]:
# Switch to nclusters = 2 and fit with :kmeansdet, rng_seed=3 (overwrites sgmm_det_2).
nclusters = 2
sgmm_det_2 = GaussianMixtures.GMM(
    nclusters, X;
    method=:kmeansdet,
    kind=:full,
    nInit=10, nIter=10, nFinal=10,
    rng_seed=3,
)

In [None]:
# Second fit with settings identical to sgmm_det_2 (same nclusters, same rng_seed=3),
# kept for a side-by-side comparison.
sgmm_det_3 = GaussianMixtures.GMM(
    nclusters, X;
    method=:kmeansdet,
    kind=:full,
    nInit=10, nIter=10, nFinal=10,
    rng_seed=3,
)

In [None]:
# Side-by-side comparison of two fits that both used rng_seed=3: subplot i overlays the
# data with n samples drawn (via the global RNG) from run i's model.
plt.figure(figsize=(8,8))
for (i, sgmm_to_use) in enumerate((sgmm_det_2, sgmm_det_3))
    subplot(2, 1, i)

    # Sample loop uses its own index `s` so the subplot index `i` is not shadowed.
    Y = zeros(n, 2)
    for s in 1:n
        comp = Gen.categorical(sgmm_to_use.w)
        Y[s, :] = mvnormal(sgmm_to_use.μ[comp, :], GaussianMixtures.covar(sgmm_to_use.Σ[comp]))
    end

    scatter(X[:, 1], X[:, 2], alpha=0.1)
    scatter(Y[:, 1], Y[:, 2], alpha=0.1)
    legend(["Actual data", "Sample from posterior"])
    title("Run $i, with rng_seed=3 (different rng for sampling), nclusters=$nclusters")
end
plt.tight_layout()

In [None]:
# Sample from the fitted GMM through the project's Gen distribution `mvn_mixture`
# (defined in extra_distributions.jl) and overlay the samples on the data.
include("../model/extra_distributions.jl")
plt.figure(figsize=(8,8))
for i = [1]

    # Completed ternary (the original `i == 1 ? sgmm_det_2` did not parse); mirrors the
    # two-run comparison cell so extending the loop to [1, 2] also plots sgmm_det_3.
    sgmm_to_use = i == 1 ? sgmm_det_2 : sgmm_det_3

    # Covariances stacked as an nclusters x d x d array (component index first) — the
    # same layout as the per-component `covar` loop. The original bound the function
    # `GaussianMixtures.covars` itself here instead of calling it on the model.
    # NOTE(review): assumes mvn_mixture expects (means n x d, covs n x d x d, weights) — confirm.
    cs = permutedims(cat(GaussianMixtures.covars(sgmm_to_use)...; dims=3), [3, 1, 2])

    subplot(2,1,i)
    Y = zeros(n, 2)
    for s = 1:n   # separate index so the subplot index `i` used in the title is not shadowed
        Y[s,:] = Gen.random(mvn_mixture, sgmm_to_use.μ, cs, sgmm_to_use.w)
    end
    scatter(X[:,1],X[:,2],alpha=0.1)
    scatter(Y[:,1],Y[:,2],alpha=0.1)
    legend(["Actual data", "Sample from posterior"])
    title("Using Gen.random mvn_mixture: Run $i, with rng_seed=3 for inference, nclusters=$nclusters")

end
plt.tight_layout()