# Bregman clustering
## make results figure

In [None]:
import numpy as np
import matplotlib.pyplot as plt


# Algorithm labels, one per row of the stored score arrays (same order as
# get_fn_roots). Every algorithm except spkmeans appears twice: first without
# tying (dashed, "free variance"), then with tying (solid, "tied variance").
alg_names = ['spkmeans',
             'softBregClust',
             'softBregClust',
             'softmovMF',
             'softmovMF',
             'hardBregClust',
             'hardBregClust',
             'hardmovMF',
             'hardmovMF']

# Per dataset family (column j): plotted K interval and true number of classes.
Kminmax = [[2, 11], [4, 40]]
true_Ks = [3, 20]

n_repets = 10
version = 0
seed = 0

# Per-algorithm plot styles, aligned with alg_names.
clrs = ['b', 'magenta', 'magenta', 'black', 'black', 'orange', 'orange', 'green', 'green']
lnstl = ['-', '--', '-', '--', '-', '--', '-', '--', '-']
mrkrs = ['o', 'o', 'o', '.', '.', 'o', 'o', '.', '.']

# Column j: dataset family (classic3, news20); row i: full vs. reduced corpus.
fn_rootroots = [['results/classic3_', 'results/classic300_'],
                ['results/news20_', 'results/news20small_']]
plt.figure(figsize=(9, 8))

for j, family_rootroots in enumerate(fn_rootroots):
    K_min, K_max = Kminmax[j]
    for i, fn_rootroot in enumerate(family_rootroots):
        plt.subplot(2, 2, 2*i + j + 1)  # row i, column j
        # NOTE(review): with NMI=False the scoring cells below write
        # '<root>MIs_...' files, while this reads '<root>NMIs_...' -- confirm
        # which score files the figure is meant to use.
        out = dict(np.load(
            f'{fn_rootroot}NMIs_{n_repets}repets_seed_{seed}_v{version}__K{K_min}_{K_max}.npy',
            allow_pickle=True).tolist())
        mean_MIs = out['MIs'].mean(axis=1)  # average over repetitions -> (algorithm, K)
        y_lo, y_hi = 0.9 * mean_MIs.min(), 1.05 * mean_MIs.max()

        # One curve per algorithm. Only the solid (tied) curves get a legend
        # entry; the meaning of the dash style is given by the two proxies below.
        for f, alg_name in enumerate(alg_names):
            style = dict(color=clrs[f], linestyle=lnstl[f], marker=mrkrs[f])
            if lnstl[f] == '-':
                style['label'] = alg_name
            plt.plot(out['K_range'], mean_MIs[f], **style)
        plt.title(fn_rootroot[len('results/'):-1])
        # Legend proxies, drawn at x=0 which lies outside the xlim set below.
        plt.plot(0, mean_MIs[-1][0], 'k-', label='tied variance')
        plt.plot(0, mean_MIs[-1][0], 'k--', label='free variance')
        plt.xlim([K_min, K_max])
        # Vertical marker at the true number of classes.
        plt.plot([true_Ks[j], true_Ks[j]], [y_lo, y_hi], color='gray', alpha=0.5)
        plt.ylim([y_lo, y_hi])
        if i == 0 and j == 0:
            plt.legend()
        if j == 0:
            plt.ylabel('normalized mutual information')
        if i == 1:
            plt.xlabel('# of clusters K')

plt.savefig('main_results_clusteringonhypersphere.pdf')
plt.show()

# Bregman clustering (cont'd)
## compute & store (normalized) mutual information across all algorithms and datasets

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics

from exps import load_classic3_sklearn, load_news20_sklearn
from vMFne.bregman_clustering import posterior_marginal_vMF_mixture_Ψ
from vMFne.moVMF import posterior_marginal_vMF_mixture_Φ

# Ψ0 argument passed to posterior_marginal_vMF_mixture_Ψ when scoring the
# Bregman-clustering results (get_MIs reads this global); the same value is
# passed to the run_all_* experiments below. The meaning of the two entries is
# defined in vMFne.bregman_clustering -- TODO confirm.
Ψ0 = [None, 0.]

def mi(class_true, class_est):
    """Return the (unnormalized) mutual information between two labelings.

    A sparse contingency table is built first, cast to float64, and handed to
    sklearn's ``mutual_info_score`` so it does not rebuild the table itself.
    """
    contingency = metrics.cluster.contingency_matrix(
        class_true, class_est, sparse=True
    ).astype(np.float64, copy=False)
    return metrics.cluster.mutual_info_score(
        class_true, class_est, contingency=contingency
    )

def get_MIs(fn_rootroot, fn_roots, n_repets, K_range, class_true, NMI=False,
            X=None, Ψ0=None):
    """Score stored clustering results against ground-truth labels.

    For every result-file prefix in ``fn_roots`` and every K in ``K_range``,
    the results stored at ``fn_rootroot + fn_root + str(K)`` are loaded and
    each of the first ``n_repets`` repetitions is scored against
    ``class_true``. The algorithm family is recognized from the prefix:

    - ``spkmeans*``: a ``.npy`` file whose i-th entry is the label vector of
      repetition i.
    - ``*BregClust*``: a ``.npz`` file with an ``'out'`` dict holding weights
      ``'w'`` and parameters ``'μs'``; labels are the argmax of
      ``posterior_marginal_vMF_mixture_Ψ``.
    - ``*movMF*``: a ``.npz`` file with an ``'out'`` dict holding weights
      ``'w'`` and parameters ``'ηs'``; labels are the argmax of
      ``posterior_marginal_vMF_mixture_Φ``.

    Args:
        fn_rootroot: dataset-specific path prefix (e.g. 'results/classic3_').
        fn_roots: algorithm-specific file prefixes (see get_fn_roots).
        n_repets: number of repetitions to score per file.
        K_range: the K values whose files are read.
        class_true: ground-truth labels (flattened before scoring).
        NMI: if True use sklearn's normalized_mutual_info_score, else ``mi``.
        X: data matrix on which the posteriors are evaluated. ``None``
            (default) falls back to the notebook-global ``X``.
        Ψ0: forwarded to posterior_marginal_vMF_mixture_Ψ. ``None`` (default)
            falls back to the notebook-global ``Ψ0``.

    Returns:
        Array of shape ``(len(fn_roots), n_repets, len(K_range))``.

    Raises:
        ValueError: if a prefix matches none of the known algorithm families.
    """
    if X is None:
        X = globals()['X']
    if Ψ0 is None:
        Ψ0 = globals()['Ψ0']
    comp_mi = metrics.normalized_mutual_info_score if NMI else mi
    labels_true = class_true.flatten()  # identical for every comparison

    MIs = np.zeros((len(fn_roots), n_repets, len(K_range)))
    for f, fn_root in enumerate(fn_roots):
        # Dispatch on the file name rather than on the position in fn_roots,
        # so subsets / reorderings of get_fn_roots() are scored correctly.
        if fn_root.startswith('spkmeans'):
            family = 'spkmeans'
        elif 'BregClust' in fn_root:
            family = 'Ψ'
        elif 'movMF' in fn_root:
            family = 'Φ'
        else:
            raise ValueError(f'unknown algorithm for result files {fn_root!r}')

        for k, K in enumerate(K_range):
            stem = fn_rootroot + fn_root + str(K)
            if family == 'spkmeans':
                out = np.load(stem + '.npy', allow_pickle=True)
            else:
                # Context manager closes the NpzFile handle after reading.
                with np.load(stem + '.npz', allow_pickle=True) as npz:
                    out = npz['out'].tolist()

            for i in range(n_repets):
                if family == 'spkmeans':
                    class_est = out[i]
                elif family == 'Ψ':
                    ph_x, _ = posterior_marginal_vMF_mixture_Ψ(
                        X, out['w'][i], out['μs'][i], Ψ0=Ψ0)
                    class_est = np.argmax(ph_x, axis=1)
                else:
                    ph_x, _ = posterior_marginal_vMF_mixture_Φ(
                        X, out['w'][i], out['ηs'][i])
                    class_est = np.argmax(ph_x, axis=1)
                MIs[f, i, k] = comp_mi(labels_true, class_est.flatten())
    return MIs

def get_fn_roots(n_repets, version, seed=0):
    """Return the result-file prefixes of all nine clustering algorithms.

    Order: spherical k-means, then soft Bregman clustering, soft movMF, hard
    Bregman clustering and hard movMF, each first without ('no_tying') and
    then with ('with_tying') tying. Every prefix ends in '__K_', so appending
    ``str(K)`` and a file extension yields a full results file name.

    Args:
        n_repets: number of repetitions encoded in the file names.
        version: results version tag encoded in the file names.
        seed: random seed encoded in the file names (default 0, which matches
            the previously hard-coded 'seed_0').
    """
    stem = f'{n_repets}repets_seed_{seed}'
    fn_roots = [f'spkmeans_{stem}_v{version}__K_']
    for alg in ('softBregClust', 'softmovMF', 'hardBregClust', 'hardmovMF'):
        for tying in ('no_tying', 'with_tying'):
            fn_roots.append(f'{alg}_{stem}_{tying}__v{version}__K_')
    return fn_roots


In [None]:
fn_rootroot = 'results/classic3_'
K_range = np.arange(2,12,1)
n_repets = 10
version = 0
seed = 0
NMI = False  # False: raw mutual information (`mi`); True: sklearn's NMI

X, labels, dictionary = load_classic3_sklearn()
# Class index of each sample (as float): position of its label in np.unique(labels).
unique_labels = np.unique(labels)
class_true = sum([(1.*i) * (labels == unique_labels[i]) for i in range(len(unique_labels))])

fn_roots = get_fn_roots(n_repets, version)
MIs_classic3 = get_MIs(fn_rootroot, fn_roots, n_repets, K_range, class_true, NMI=NMI)

# Quick-look plot of the repetition-averaged score of every algorithm.
plt.figure(figsize=(6,4))
clrs = ['b', 'magenta', 'magenta', 'black', 'black', 'orange', 'orange', 'green', 'green' ]
lnstl = ['-', '--', '-', '--', '-', '--', '-', '--', '-']
mrkrs = ['o','o','o','.','.','o','o','.','.']
mean_MIs = MIs_classic3.mean(axis=1)
for f, fn_root in enumerate(fn_roots):
    plt.plot(K_range, mean_MIs[f], label=fn_root,
             color=clrs[f], linestyle=lnstl[f], marker=mrkrs[f])
plt.xlabel('K')
plt.ylabel('NMI' if NMI else 'MI')  # label the score that was actually computed
#plt.legend()
plt.show()

# Single save of scores plus metadata (a bare-array save to the same path
# would simply be overwritten). The prefix follows the score type.
score_prefix = 'NMIs' if NMI else 'MIs'
np.save(f'{fn_rootroot}{score_prefix}_{n_repets}repets_seed_{seed}_v{version}__K{min(K_range)}_{max(K_range)}',
        {'MIs' : MIs_classic3,
         'algorithms' : fn_roots,
         'K_range' : K_range,
         'seed' : seed,
         'n_repets' : n_repets,
        })


In [None]:
fn_rootroot = 'results/classic300_'
K_range = np.arange(2,12,1)
n_repets = 10
version = 0
seed = 0
NMI = False  # False: raw mutual information (`mi`); True: sklearn's NMI

X, labels, dictionary = load_classic3_sklearn(classic300=True, min_df=2, max_df=0.15)
# Class index of each sample (as float): position of its label in np.unique(labels).
unique_labels = np.unique(labels)
class_true = sum([(1.*i) * (labels == unique_labels[i]) for i in range(len(unique_labels))])

fn_roots = get_fn_roots(n_repets, version)
MIs_classic300 = get_MIs(fn_rootroot, fn_roots, n_repets, K_range, class_true, NMI=NMI)

# Quick-look plot of the repetition-averaged score of every algorithm.
plt.figure(figsize=(6,4))
clrs = ['b', 'magenta', 'magenta', 'black', 'black', 'orange', 'orange', 'green', 'green' ]
lnstl = ['-', '--', '-', '--', '-', '--', '-', '--', '-']
mrkrs = ['o','o','o','.','.','o','o','.','.']
mean_MIs = MIs_classic300.mean(axis=1)
for f, fn_root in enumerate(fn_roots):
    plt.plot(K_range, mean_MIs[f], label=fn_root,
             color=clrs[f], linestyle=lnstl[f], marker=mrkrs[f])
plt.xlabel('K')
plt.ylabel('NMI' if NMI else 'MI')  # label the score that was actually computed
#plt.legend()
plt.show()

# Single save of scores plus metadata (a bare-array save to the same path
# would simply be overwritten). The prefix follows the score type.
score_prefix = 'NMIs' if NMI else 'MIs'
np.save(f'{fn_rootroot}{score_prefix}_{n_repets}repets_seed_{seed}_v{version}__K{min(K_range)}_{max(K_range)}',
        {'MIs' : MIs_classic300,
         'algorithms' : fn_roots,
         'K_range' : K_range,
         'seed' : seed,
         'n_repets' : n_repets,
        })


In [None]:
fn_rootroot = 'results/news20_'
K_range = np.arange(4,41,4)
n_repets = 10
version = 0
seed = 0
NMI = False  # False: raw mutual information (`mi`); True: sklearn's NMI

X, labels, dictionary = load_news20_sklearn()
# Class index of each sample (as float): position of its label in np.unique(labels).
unique_labels = np.unique(labels)
class_true = sum([(1.*i) * (labels == unique_labels[i]) for i in range(len(unique_labels))])

fn_roots = get_fn_roots(n_repets, version)
MIs_news20 = get_MIs(fn_rootroot, fn_roots, n_repets, K_range, class_true, NMI=NMI)

# Quick-look plot of the repetition-averaged score of every algorithm.
plt.figure(figsize=(6,4))
clrs = ['b', 'magenta', 'magenta', 'black', 'black', 'orange', 'orange', 'green', 'green' ]
lnstl = ['-', '--', '-', '--', '-', '--', '-', '--', '-']
mrkrs = ['o','o','o','.','.','o','o','.','.']
mean_MIs = MIs_news20.mean(axis=1)
for f, fn_root in enumerate(fn_roots):
    plt.plot(K_range, mean_MIs[f], label=fn_root,
             color=clrs[f], linestyle=lnstl[f], marker=mrkrs[f])
plt.xlabel('K')
plt.ylabel('NMI' if NMI else 'MI')  # label the score that was actually computed
#plt.legend()
plt.show()

# Single save of scores plus metadata (a bare-array save to the same path
# would simply be overwritten). The prefix follows the score type.
score_prefix = 'NMIs' if NMI else 'MIs'
np.save(f'{fn_rootroot}{score_prefix}_{n_repets}repets_seed_{seed}_v{version}__K{min(K_range)}_{max(K_range)}',
        {'MIs' : MIs_news20,
         'algorithms' : fn_roots,
         'K_range' : K_range,
         'seed' : seed,
         'n_repets' : n_repets,
        })


In [None]:
fn_rootroot = 'results/news20small_'
K_range = np.arange(4,41,4)
n_repets = 10
version = 0
seed = 0
NMI = False  # False: raw mutual information (`mi`); True: sklearn's NMI
X, labels, dictionary = load_news20_sklearn(news20_small=True, min_df=2, max_df=0.15)
# Class index of each sample (as float): position of its label in np.unique(labels).
unique_labels = np.unique(labels)
class_true = sum([(1.*i) * (labels == unique_labels[i]) for i in range(len(unique_labels))])

fn_roots = get_fn_roots(n_repets, version)
MIs_news20small = get_MIs(fn_rootroot, fn_roots, n_repets, K_range, class_true, NMI=NMI)

# Quick-look plot of the repetition-averaged score of every algorithm.
plt.figure(figsize=(6,4))
clrs = ['b', 'magenta', 'magenta', 'black', 'black', 'orange', 'orange', 'green', 'green' ]
lnstl = ['-', '--', '-', '--', '-', '--', '-', '--', '-']
mrkrs = ['o','o','o','.','.','o','o','.','.']
mean_MIs = MIs_news20small.mean(axis=1)
for f, fn_root in enumerate(fn_roots):
    plt.plot(K_range, mean_MIs[f], label=fn_root,
             color=clrs[f], linestyle=lnstl[f], marker=mrkrs[f])
plt.xlabel('K')
plt.ylabel('NMI' if NMI else 'MI')  # label the score that was actually computed
#plt.legend()
plt.show()

# Single save of scores plus metadata (a bare-array save to the same path
# would simply be overwritten). The prefix follows the score type.
score_prefix = 'NMIs' if NMI else 'MIs'
np.save(f'{fn_rootroot}{score_prefix}_{n_repets}repets_seed_{seed}_v{version}__K{min(K_range)}_{max(K_range)}',
        {'MIs' : MIs_news20small,
         'algorithms' : fn_roots,
         'K_range' : K_range,
         'seed' : seed,
         'n_repets' : n_repets,
        })


# Bregman clustering (cont'd)
## Run all spherical clustering algorithms on both datasets (full and reduced versions of Classic3 and news20)

In [None]:
from exps import run_all_classic3

# Run every clustering algorithm on the Classic3 corpus in its classic300=True
# variant (vocabulary pruned with min_df=2, max_df=0.15), K = 2..11.
version = '0'
classic300 = True

run_all_classic3(
    fn_root='results/classic300_',
    n_repets=10,
    K_range=list(range(2, 12)),
    seed=0,
    max_iter=100,
    κ_max=10000.,
    Ψ0=[None, 0.],
    version=version,
    classic300=classic300,
    verbose=True,
    min_df=2,
    max_df=0.15,
)


In [None]:
from exps import run_all_news20

# Run every clustering algorithm on the news20 corpus in its news20_small=True
# variant (vocabulary pruned with min_df=2, max_df=0.15), K = 4, 8, ..., 40.
version = '0'
news20_small = True

run_all_news20(
    fn_root='results/news20small_',
    n_repets=10,
    K_range=list(range(4, 41, 4)),
    seed=0,
    max_iter=100,
    κ_max=10000.,
    Ψ0=[None, 0.],
    version=version,
    news20_small=news20_small,
    verbose=True,
    min_df=2,
    max_df=0.15,
)


In [None]:
from exps import run_all_classic3

# Run every clustering algorithm on the full Classic3 corpus (classic300=False,
# loader defaults for the vocabulary), K = 2..11.
version = '0'
classic300 = False

run_all_classic3(
    fn_root='results/classic3_',
    n_repets=10,
    K_range=list(range(2, 12)),
    seed=0,
    max_iter=100,
    κ_max=10000.,
    Ψ0=[None, 0.],
    version=version,
    classic300=classic300,
    verbose=True,
)


In [None]:
from exps import run_all_news20

# Run every clustering algorithm on the full news20 corpus (news20_small=False,
# loader defaults for the vocabulary), K = 4, 8, ..., 40.
version = '0'
news20_small = False

run_all_news20(
    fn_root='results/news20_',
    n_repets=10,
    K_range=list(range(4, 41, 4)),
    seed=0,
    max_iter=100,
    κ_max=10000.,
    Ψ0=[None, 0.],
    version=version,
    news20_small=news20_small,
    verbose=True,
)


# Numerical evaluation of the negentropy computation and its approximation

In [None]:
from vMFne.logpartition import gradΦ, invgradΦ, vMF_entropy_Φ, logχ, log_besseli, banerjee_44
from vMFne.negentropy import Ψ, Ψ_base, gradΨ, dΨ_base
from scipy.special import loggamma
import mpmath

import numpy as np
import matplotlib.pyplot as plt

# For each dimension D, compare evaluations of the negentropy Ψ(μ) along the
# ray μ = ||μ|| * (1,...,1)/sqrt(D):
#   - negH: minus the vMF entropy (vMF_entropy_Φ) at κ recovered by invgradΦ,
#     shifted by logχ(D) and the constant Ψμ0 (reference curve),
#   - Ψ_base(||μ||, D),
#   - Ψμ / Ψμ_delta: Ψ evaluated with solve_delta=False / with Ψ's default
#     solve_delta setting (not passed here -- confirm in vMFne.negentropy),
# and the analogous gradient norms ||∇Ψ(μ)|| vs. dΨ_base.
Ds = [2, 10, 100, 1000]

t0 = 0.0

fig = plt.figure(figsize=(12, 10))
for i,D in enumerate(Ds):

    # Ψ0 and t0 are forwarded to Ψ unchanged; their meaning is defined in
    # vMFne.negentropy (TODO confirm). The commented line derives Ψ0 from t0.
    #Ψ0 = [0., invgradΦ(np.array([t0]),D,max_iter=10, atol=1e-12)[0][0]]
    Ψ0 = [0., 1e-6]

    K = 100
    # K copies of the unit vector (1,...,1)/sqrt(D); every μ lies along it.
    V = np.ones((K,D))/np.sqrt(D)

    # K evenly spaced norms strictly inside (0, 1) (endpoints dropped).
    μs_norm = np.linspace(0, 1, K+2)[1:-1]
    _, gradΨμ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=Ψ0, t0=t0, return_grad=True)

    # Gradient norms are stored as κs and mapped back through gradΦ below.
    κs = np.linalg.norm(gradΨμ,axis=-1)

    # Re-compute ||μ|| = Φ'(||η||) to make sure H[η] is not at a disadvantage because η are bad - if anything, μ are bad!
    μs_norm = np.linalg.norm(gradΦ(κs.reshape(-1,1)*V),axis=-1)
    Ψμ, gradΨμ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=Ψ0, t0=t0, return_grad=True, solve_delta=False)
    atol = 1e-12
    # κ recovered from the recomputed ||μ|| by invgradΦ (up to 10 iterations).
    κs_est, diffs = invgradΦ(μs_norm,D,max_iter=10,atol=atol)
    # NOTE(review): κs_est_0 (max_iter=0, presumably the initial guess) is never
    # used below and overwrites `diffs`; candidate for removal.
    κs_est_0, diffs = invgradΦ(μs_norm,D,max_iter=0,atol=atol)

    Ψμ_delta, gradΨμ_delta = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=Ψ0, t0=t0, return_grad=True)

    # NegEntropy
    # Panel 1: negH (solid), Ψ_base (dashed), Ψμ_delta (dotted).
    plt.subplot(2,2,1)
    H = vMF_entropy_Φ(κs_est.reshape(-1,1) * V)
    # Constant offset: equals lim_{κ→0} [ν log κ - log I_ν(κ)] with ν = D/2 - 1
    # (small-κ Bessel asymptotics); presumably removes the constant part of the
    # entropy so negH is comparable to Ψ -- confirm.
    Ψμ0 = (D/2. - 1.) * np.log(2) + loggamma(D/2)
    negH = - H - logχ(D) - Ψμ0
    plt.plot(μs_norm, negH, '-', label='D='+str(D))
    plt.plot(μs_norm, Ψ_base(μs_norm,D=D), '--')
    plt.plot(μs_norm, Ψμ_delta, ':')
    plt.legend()

    # Panel 3: differences negH - Ψμ (solve_delta=False), negH - Ψ_base,
    # negH - Ψμ_delta.
    plt.subplot(2,2,3)
    plt.plot(μs_norm, (negH - Ψμ), ':', label='D='+str(D))
    plt.plot(μs_norm, (negH - Ψ_base(μs_norm,D=D)), '--')
    plt.plot(μs_norm, (negH - Ψμ_delta), ':')
    plt.legend()

    # Gradient of NegEntropy
    # Panel 2: ||∇Ψ|| with solve_delta=False (solid), dΨ_base (dashed),
    # ||∇Ψ|| with the default solve (dotted).
    plt.subplot(2,2,2)
    plt.plot(μs_norm, np.linalg.norm(gradΨμ,axis=-1), label='D='+str(D))
    plt.plot(μs_norm, dΨ_base(μs_norm,D=D), '--')
    plt.plot(μs_norm, np.linalg.norm(gradΨμ_delta,axis=-1), ':')
    plt.legend()

    # Panel 4: ||∇Ψ|| (solve_delta=False) minus dΨ_base, and minus the
    # default-solve gradient norm.
    plt.subplot(2,2,4)
    #plt.plot(μs_norm, np.linalg.norm(gradΨμ,axis=-1) - banerjee_44(μs_norm,D=D), ':',
    #             label='D='+str(D))
    plt.plot(μs_norm, np.linalg.norm(gradΨμ,axis=-1) - dΨ_base(μs_norm,D=D), '--',
                 label='D='+str(D))
    plt.plot(μs_norm, np.linalg.norm(gradΨμ,axis=-1) - np.linalg.norm(gradΨμ_delta,axis=-1), ':')
    plt.legend()

plt.show()


In [None]:
# Forward-difference derivative of the gap ||∇Ψ(μ)|| - banerjee_44(||μ||),
# shifted by +||μ||, next to a reference sine curve. Relies on μs_norm,
# gradΨμ and D left over from the previous cell.
grad_gap = np.linalg.norm(gradΨμ, axis=-1) - banerjee_44(μs_norm, D=D)
step = np.diff(μs_norm)[0]
μs_left = μs_norm[:-1]
plt.plot(μs_left, np.diff(grad_gap) / step + μs_left, ':')
plt.plot(μs_left, 0.08 * np.sin(2 * np.pi * μs_left))

In [None]:
from vMFne.logpartition import gradΦ, invgradΦ, vMF_entropy_Φ, logχ, log_besseli, banerjee_44
from vMFne.negentropy import Ψ, Ψ_base, gradΨ, dΨ_base
from scipy.special import loggamma
import mpmath

import numpy as np
import matplotlib.pyplot as plt

Ds = [10, 100, 1000, 25000]

t0 = 0.0

# For each D, compare Ψ(μ) computed with solve_delta=False / True against the
# Ψ_base / dΨ_base approximations; the bottom row shows relative deviations
# from the solve_delta=False values.
fig = plt.figure(figsize=(12, 10))
for i, D in enumerate(Ds):
    tag = 'D=' + str(D)
    Ψ0 = [0., 1e-4]

    K = 100
    V = np.ones((K, D)) / np.sqrt(D)          # K copies of (1,...,1)/sqrt(D)
    μs_norm = np.linspace(0, 1, K + 2)[1:-1]  # K norms strictly inside (0, 1)

    Ψμ, gradΨμ = Ψ(μs_norm.reshape(-1, 1) * V, D, Ψ0=Ψ0, t0=t0,
                   return_grad=True, solve_delta=False)
    Ψμ_delta, gradΨμ_delta = Ψ(μs_norm.reshape(-1, 1) * V, D, Ψ0=Ψ0, t0=t0,
                               return_grad=True, solve_delta=True)
    grad_norm = np.linalg.norm(gradΨμ, axis=-1)
    grad_norm_delta = np.linalg.norm(gradΨμ_delta, axis=-1)

    # Top left: negentropy (solid: solve_delta=False, dashed: Ψ_base, dotted: solve_delta=True).
    plt.subplot(2, 2, 1)
    plt.plot(μs_norm, Ψμ, label=tag)
    plt.plot(μs_norm, Ψ_base(μs_norm, D=D), '--')
    plt.plot(μs_norm, Ψμ_delta, ':')
    plt.legend()

    # Bottom left: relative deviation of the negentropy.
    plt.subplot(2, 2, 3)
    plt.plot(μs_norm, (Ψμ - Ψ_base(μs_norm, D=D)) / Ψμ, '--', label=tag)
    plt.plot(μs_norm, (Ψμ - Ψμ_delta) / Ψμ, ':')
    plt.legend()

    # Top right: gradient norms.
    plt.subplot(2, 2, 2)
    plt.plot(μs_norm, grad_norm, label=tag)
    plt.plot(μs_norm, dΨ_base(μs_norm, D=D), '--')
    plt.plot(μs_norm, grad_norm_delta, ':')
    plt.legend()

    # Bottom right: relative deviation of the gradient norms.
    plt.subplot(2, 2, 4)
    plt.plot(μs_norm, (grad_norm - dΨ_base(μs_norm, D=D)) / grad_norm, '--',
             label=tag)
    plt.plot(μs_norm, (grad_norm - grad_norm_delta) / grad_norm, ':')
    plt.legend()

plt.show()


In [None]:
from vMFne.logpartition import gradΦ
import numpy as np
import matplotlib.pyplot as plt
import scipy

# Synthetic benchmark: N points in D dimensions drawn from a K-component
# von Mises-Fisher mixture with the weights and concentrations below.
N = 3891
D = 4255

Ψ0 = [None, 0.]
seed = 0  # seeds every random draw below so the dataset is reproducible

#weights = np.array([0.10, 0.15, 0.2, 0.25, 0.3])
#kappas = D/50. * np.array([1., 10., 15., 25., 0.1])

weights = np.array([0.26548445, 0.37522488, 0.35929067])
kappas = D/50. * np.array([8., 10., 11.])

# Samples per component and the matching ground-truth labels (0., 1., ...).
Ns = np.int32(np.round(N*weights))
K = len(weights)
class_true = np.concatenate([k * np.ones(Ns[k]) for k in range(K)])

rng = np.random.default_rng(seed)
# Random unit mean directions, and the norms ||∇Φ(κ_k μ_k)|| of the
# corresponding mean parameters.
mus = rng.normal(size=(K, D))
mus = mus / np.linalg.norm(mus, axis=-1).reshape(-1, 1)
mu_norms = np.linalg.norm(gradΦ(kappas.reshape(-1, 1) * mus), axis=-1)

x_vmf = np.concatenate(
    [scipy.stats.vonmises_fisher(mu=muk, kappa=kappak).rvs(Nk, random_state=rng)
     for Nk, muk, kappak in zip(Ns, mus, kappas)],
    axis=0)

if D == 3:
    # Heat map of the sample density on the unit sphere (only meaningful in 3D).
    from mpl_toolkits.mplot3d import Axes3D
    from sklearn.metrics import pairwise
    from matplotlib.pyplot import cm

    fig = plt.figure()
    u = np.linspace( 0, 2 * np.pi, 60)
    v = np.linspace( 0, np.pi, 30 )

    # create the sphere surface
    XX = np.outer( np.cos( u ), np.sin( v ) )
    YY = np.outer( np.sin( u ), np.sin( v ) )
    ZZ = np.outer( np.ones( np.size( u ) ), np.cos( v ) )
    locs = np.stack([XX.flatten(),YY.flatten(),ZZ.flatten()], axis=-1)

    # Count samples within distance d0 of each surface grid point.
    d0 = 0.1
    WW_vmf = (pairwise.pairwise_distances(locs, x_vmf)<d0).sum(axis=-1)
    myheatmap_vmf = WW_vmf.reshape(len(u), len(v)) / WW_vmf.max()
    ax = fig.add_subplot( 1, 2, 2, projection='3d')
    ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_vmf ) )
    plt.title('von Mises-Fisher')
    plt.show()

from vMFne.moVMF import posterior_marginal_vMF_mixture_Φ
from vMFne.bregman_clustering import posterior_marginal_vMF_mixture_Ψ
from sklearn import metrics
from sklearn.metrics import confusion_matrix

def mi(class_true, class_est):
    """Unnormalized mutual information between true and estimated labels.

    The contingency table is precomputed (sparse, float64) and passed to
    sklearn's ``mutual_info_score``.
    """
    table = metrics.cluster.contingency_matrix(class_true, class_est, sparse=True)
    table = table.astype(np.float64, copy=False)
    score = metrics.cluster.mutual_info_score(class_true, class_est, contingency=table)
    return score

# Posteriors and log-likelihood of the data under the true generating mixture,
# using the mean parameters μ_k = ||μ_k|| * mus[k] (Ψ parameterization).
ph_x_μ_true_Ψ, log_px_true_Ψ = posterior_marginal_vMF_mixture_Ψ(x_vmf,weights,mu_norms.reshape(-1,1)*mus, Ψ0=Ψ0)
LL_true_Ψ = log_px_true_Ψ.sum()

if D <= 1000:
    # Natural-parameter (Φ) evaluation with η_k = κ_k * mus[k]; only done for
    # D <= 1000 (presumably because it is costly / unstable in high D -- confirm).
    ph_x_μ_true_Φ, log_px_true_Φ = posterior_marginal_vMF_mixture_Φ(x_vmf,weights,kappas.reshape(-1,1)*mus)
    LL_true_Φ = log_px_true_Φ.sum()
    class_est_μ_true = np.argmax(ph_x_μ_true_Φ,axis=1)
    LL_true = LL_true_Φ
else:
    class_est_μ_true = np.argmax(ph_x_μ_true_Ψ,axis=1)
    LL_true = LL_true_Ψ
M_μ_true = confusion_matrix(class_true, class_est_μ_true)
MI_true = mi(class_true.flatten(), class_est_μ_true.flatten())

# Confusion matrix of MAP assignments under the *true* model (best achievable reference).
plt.imshow(M_μ_true)
plt.colorbar()
plt.title('true model - LL= ' + str(LL_true))
plt.ylabel('MI=' + str(MI_true))
plt.show()

In [None]:
plt.figure(figsize=(16,6))

# Left panel: concentration κ_k vs. the corresponding mean-parameter norm.
# Always drawn in its own subplot (previously only selected when D <= 1000).
plt.subplot(1,3,1)
plt.plot(kappas, mu_norms.flatten(), 'o')

if D <= 1000:
    # Middle: per-point log-densities under the Ψ vs. Φ parameterizations.
    # NOTE(review): loglog drops non-positive values -- confirm log-densities are > 0 here.
    plt.subplot(1,3,2)
    plt.loglog(log_px_true_Ψ, log_px_true_Φ, '.')

    # Right: posterior responsibilities under both parameterizations.
    plt.subplot(1,3,3)
    plt.loglog(ph_x_μ_true_Ψ.flatten(), ph_x_μ_true_Φ.flatten(), '.')

plt.show()

In [None]:
from vMFne.bregman_clustering import spherical_kmeans, softBregmanClustering_vMF
from vMFne.bregman_clustering import posterior_marginal_vMF_mixture_Ψ
from sklearn.metrics import confusion_matrix
import scipy

def _plot_aligned_confusion(class_true, class_est):
    """Draw the confusion matrix of ``class_est`` against ``class_true``.

    Estimated cluster indices are first permuted to best match the true
    classes (linear_sum_assignment on the negated, transposed confusion
    matrix, i.e. maximizing the matched counts). The permutation only affects
    the display; mutual information does not depend on label permutations.
    """
    M = confusion_matrix(class_true, class_est)
    _, idx_class_align = scipy.optimize.linear_sum_assignment(-M.T)
    plt.imshow(confusion_matrix(class_true, idx_class_align[class_est]))
    plt.colorbar()


all_μs, all_w, all_LL = [], [], []
all_μs_kmean, all_w_kmean, all_c_kmean = [], [], []

n_repets = 10
for ii in range(n_repets):
    # Spherical k-means; its solution initializes soft Bregman clustering.
    _, w, c = spherical_kmeans(X=x_vmf, K=K, max_iter=100, verbose=False)
    # Per-cluster means of the k-means assignment.
    # NOTE(review): an empty cluster would produce a NaN mean here.
    μs = np.stack([x_vmf[c==k].mean(axis=0) for k in range(K)],axis=0)
    all_c_kmean.append(1 * c)
    all_w_kmean.append(1. * w)
    all_μs_kmean.append(1. * μs)
    μs, w, LL = softBregmanClustering_vMF(X=x_vmf, K=K, max_iter=100, w_init=w, μs_init=μs, Ψ0=Ψ0, verbose=False)

    all_μs.append(μs)
    all_w.append(w)
    all_LL.append(LL)
    print(' - ' + str(ii+1) + '/' + str(n_repets))

# Log-likelihood traces of all runs vs. the true model's log-likelihood.
plt.plot(np.stack(all_LL).T)
plt.plot([0, len(LL)], [LL_true, LL_true], 'k--')
plt.xlabel('iteration')
plt.ylabel('log-likelihood')
plt.show()

# Per-run confusion matrices. The MI of the k-means solution and of the soft
# Bregman solution are kept in separate arrays so neither overwrites the other.
MIs_kmean = np.zeros(n_repets)
MIs = np.zeros(n_repets)
for i in range(n_repets):
    plt.figure(figsize=(12,6))

    plt.subplot(1,2,1)
    ph_x, _ = posterior_marginal_vMF_mixture_Ψ(x_vmf, all_w_kmean[i], all_μs_kmean[i], Ψ0=Ψ0)
    class_est = np.argmax(ph_x, axis=1)
    _plot_aligned_confusion(class_true, class_est)
    plt.title('learned model - sph. K-means')
    MIs_kmean[i] = mi(class_true.flatten(), class_est.flatten())
    plt.ylabel('MI=' + str(MIs_kmean[i]))

    plt.subplot(1,2,2)
    ph_x, _ = posterior_marginal_vMF_mixture_Ψ(x_vmf, all_w[i], all_μs[i], Ψ0=Ψ0)
    class_est = np.argmax(ph_x, axis=1)
    _plot_aligned_confusion(class_true, class_est)
    plt.title('learned model - LL= ' + str(all_LL[i][-1]))
    MIs[i] = mi(class_true.flatten(), class_est.flatten())
    plt.ylabel('MI=' + str(MIs[i]))

    plt.show()

# Final log-likelihood vs. MI of each soft-Bregman run; dashed line: MI of the true model.
final_LLs = np.stack(all_LL, axis=0)[:, -1]
plt.plot(final_LLs, MIs, 'o')
plt.plot([final_LLs.min(), final_LLs.max()],
         np.ones(2) * mi(class_true.flatten(), class_est_μ_true.flatten()), 'k--')
plt.show()