In [None]:
import numpy as np
from numpy.random import multivariate_normal, gamma, choice, shuffle
import autograd.numpy as autonp


from scipy.special import loggamma, digamma
from scipy.optimize import brentq

from sklearn.cluster import KMeans

import pymanopt
from pymanopt.manifolds import Stiefel
from pymanopt.solvers import TrustRegions, ConjugateGradient, SteepestDescent

import matplotlib.pyplot as plt

from tqdm import tqdm

import torch

In [None]:
def batch_diagonal(A):
    """
    Turn each row of A into a diagonal matrix.
    A: batch x N
    returns: batch x NxN, with A[b] on the diagonal of the b-th matrix
    """
    size = A.shape[1]
    identity = np.eye(size)
    # Broadcasting (batch, 1, N) against (N, N) keeps A[b, j] only where row == column j.
    return A[:, np.newaxis] * identity

## PDF

In [None]:
def pdfMST(y, μ, A, D, ν):
    """
    pdf of MST at a single point, for a batch of parameter sets.
    The density is the product, over the M axes, of univariate Student-t
    densities (ν[m] degrees of freedom, squared scale A[m]) evaluated at the
    coordinates of y - μ in the basis formed by the columns of D.

    y: M
    μ: batch x M
    D: batch x MxM
    A: batch x M
    ν: batch x M
    returns: batch
    """
    # Coordinates of the centred observation along each column of D.
    proj = (np.swapaxes(D, 1, 2) @ np.expand_dims(y - μ, -1))[..., 0]
    scale = A * ν

    # Per-axis log-density, split into the y-dependent kernel and the normaliser.
    log_kernel = -(ν + 1) / 2 * np.log(1 + proj ** 2 / scale)
    log_norm = loggamma((ν + 1) / 2) - (loggamma(ν / 2) + 0.5 * np.log(np.pi * scale))

    log_pdf = log_kernel + log_norm
    return np.exp(log_pdf.sum(1))

def pdfMMST(π, MST=None, μ=None, A=None, D=None, ν=None, y=None):
    """
    pdf (not the log-pdf) of a mixture of MST at a single observation.

    Either pass the per-component densities MST (as returned by pdfMST), or
    pass y together with μ, A, D and ν so that they are computed here.

    π: K
    MST: K        pre-computed component densities (optional)
    μ: K x M
    D: K x MxM
    A: K x M
    ν: K x M
    y: M          observation, only needed when MST is None

    returns: scalar, sum_k π[k] * MST[k]
    raises: ValueError when MST is None and y, μ, A, D, ν are not all given
    """
    if MST is None:
        # The observation used to be read from a global `y`; it is now an
        # explicit argument so the function no longer depends on notebook state.
        if any(arg is None for arg in (y, μ, A, D, ν)):
            raise ValueError("pdfMMST needs either MST or all of y, μ, A, D and ν")
        MST = pdfMST(y, μ, A, D, ν)
    return (π * MST).sum()

## Sampling

In [None]:
def sampleMST(N, μ, A, D, ν):
    """
    sampling from MST: N draws for each parameter set of the batch
    N: int
    μ: batch x M
    D: batch x MxM
    A: batch x M
    ν: batch x M
    returns: batch x N x M
    """
    batch, M = μ.shape
    X = multivariate_normal(np.zeros(M), cov=np.diag(np.ones(M)), size=(batch, N))

    # One Gamma(shape=ν/2, rate=ν/2) weight per draw and per axis. numpy's
    # gamma broadcasts its parameters against `size` and takes a scale
    # (= 1 / rate), so the batch is sampled in one call without torch.
    half_ν = np.asarray(ν, dtype=np.float64)[:, None, :] / 2
    W = gamma(shape=half_ν, scale=1 / half_ν, size=(batch, N, M))

    X /= np.sqrt(W)

    # D @ diag(sqrt(A)) only rescales the columns of D, so no dense diagonal is built.
    coef = D * np.sqrt(A)[:, None, :]

    gen = np.expand_dims(μ, 1) + np.swapaxes(coef @ np.swapaxes(X, 2, 1), 1, 2)

    return gen

def sampleMMST(N, π, μ, A, D, ν):
    """
    sampling from MMST
    N: int
    π: K
    μ: K x M
    D: K x MxM
    A: K x M
    ν: K x M
    returns: N x M, rows in random order
    """
    # Component label of each of the N draws.
    labels = choice(len(π), N, p=π)

    # N candidate draws per component; component k keeps those at the positions labelled k.
    per_component = sampleMST(N, μ, A, D, ν)
    chunks = [per_component[k, labels == k, :] for k in range(len(π))]

    gen_mix = np.concatenate(chunks)
    # The chunks are grouped by component, so mix the rows in place.
    shuffle(gen_mix)
    return gen_mix

In [None]:
# 2-D example: three components centred at (0, -6), (0, 0) and (0, 6),
# all sharing the same rotation (pi / 6), unit A and degrees of freedom (1, 30).
π = np.array([0.3, 0.5, 0.2], dtype=np.float64)
μ = np.column_stack([np.zeros(3), [-6.0, 0.0, 6.0]])
angle = np.pi / 6
matRot = [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
D = np.array([matRot] * 3, dtype=np.float64)
A = np.ones((3, 2), dtype=np.float64)
ν = np.tile(np.array([1, 30], dtype=np.float64), (3, 1))

gen_mix = sampleMMST(4000, π, μ, A, D, ν)

In [None]:
# Scatter of the 2-D sample, restricted to the [-10, 10] square.
fig, ax = plt.subplots(figsize=(10, 5))

ax.scatter(gen_mix[:, 0], gen_mix[:, 1], s=1.9, c='green')

ax.set(
    xlim=(-10, 10),
    xlabel='First dimension',
    ylim=(-10, 10),
    ylabel='Second dimension',
    title='MMST',
)
plt.show()

In [None]:
# 3D mixtures: same layout as the 2-D example, with the rotation acting on the last two axes.
π = np.array([0.3, 0.5, 0.2], dtype=np.float64)
μ = np.zeros((3, 3), dtype=np.float64)
μ[:, 1] = [-6, 0, 6]
angle = np.pi / 6
matRot = [
    [1, 0, 0],
    [0, np.cos(angle), -np.sin(angle)],
    [0, np.sin(angle), np.cos(angle)],
]
D = np.array([matRot] * 3, dtype=np.float64)
A = np.ones((3, 3), dtype=np.float64)
ν = np.tile(np.array([1, 10, 30], dtype=np.float64), (3, 1))
gen_mix = sampleMMST(4000, π, μ, A, D, ν)

In [None]:
# 3-D scatter of the sample, restricted to the [-10, 10] cube.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(projection='3d')

# The three columns of gen_mix are the x, y and z coordinates.
ax.scatter(*gen_mix.T, s=1, c='green')
ax.set(
    xlim=(-10, 10),
    ylim=(-10, 10),
    zlim=(-10, 10),
    title="3D mixtures",
    xlabel="x",
    ylabel="y",
    zlabel="z",
)

plt.show()

## EM updates

### Update Stats

In [None]:
def alpha_beta(y, μ, A, D, ν):
    """
    Per-axis quantities α and β used by U and Utilde.
    α = ν / 2 + 1 / 2
    β = ν / 2 + (D^T (y - μ))² / (2 A)

    y: M
    μ: batch x M
    D: batch x MxM
    A: batch x M
    ν: batch x M
    returns: (α, β), each batch x M
    """
    half_ν = ν / 2
    # Coordinates of the centred observation along each column of D.
    proj = (np.swapaxes(D, 1, 2) @ np.expand_dims(y - μ, -1))[..., 0]

    α = half_ν + 0.5
    β = half_ν + proj ** 2 / (2 * A)
    return α, β

def U(α, β):
    """
    Elementwise ratio α / β (the mean of a Gamma variable with shape α and rate β).
    U: batch x M
    """
    ratio = α / β
    return ratio

def Utilde(α, β):
    """
    Elementwise digamma(α) - log(β) (the mean of the log of a Gamma variable
    with shape α and rate β).
    Utilde: batch x M
    """
    log_rate = np.log(β)
    return digamma(α) - log_rate

In [None]:
def updateStat(y, μ, A, D, ν, r, γ, stat):
    """
    Update stats only: every entry of stat becomes
        γ * (contribution of y) + (1 - γ) * (previous value).
    The entries of the dict passed in are replaced, and that same dict is returned.

    y: M
    μ: K x M
    D: K x MxM
    A: K x M
    ν: K x M
    r: K
    γ: float
    stat: dict of stat with keys 's0' (K), 's1' (K x M x M), 'S2' (K x M x MxM),
          's3' (K x M), 's4' (K x M)
    """
    keep = 1 - γ
    stat['s0'] = γ * r + keep * stat['s0']

    α, β = alpha_beta(y, μ, A, D, ν)
    r_col = np.expand_dims(r, -1)
    ru = r_col * U(α, β)
    rutilde = r_col * Utilde(α, β)

    # Contribution of y to each remaining statistic: ru ⊗ y, ru ⊗ (y yᵀ), ru and rutilde.
    contributions = {
        's1': np.multiply.outer(ru, y),
        'S2': np.multiply.outer(ru, np.outer(y, y)),
        's3': ru,
        's4': rutilde,
    }
    for name, value in contributions.items():
        stat[name] = γ * value + keep * stat[name]

    return stat

### Update params

In [None]:
# Update ν
def fun_ν(νkm, s3km, s4km):
    """
    Function of νkm whose root is the updated degrees of freedom for one
    (component, axis) pair, given the scalar statistics s3km and s4km.
    Since log(x) - digamma(x) decreases in x, this function decreases in νkm.
    """
    return s4km - s3km - digamma(νkm / 2) + np.log(νkm / 2) + 1

def update_ν(s3, s4, ν_min=.001, ν_max=100):
    """
    Solve fun_ν(ν) = 0 independently for every (component, axis) pair.

    s3: K x M
    s4: K x M
    ν_min, ν_max: search bracket for each root
    returns: K x M array of float64

    When fun_ν has no sign change on [ν_min, ν_max] the result is clamped to
    the nearer end of the bracket (ν_max if the function is still non-negative
    there, ν_min if it is already non-positive at ν_min) instead of letting
    brentq raise ValueError.
    """
    K, M = s3.shape
    new_ν = np.empty((K, M), dtype=np.float64)
    for k in range(K):
        for m in range(M):
            args = (s3[k, m], s4[k, m])
            # fun_ν is decreasing, so its signs at the two ends tell where the root lies.
            if fun_ν(ν_max, *args) >= 0:
                new_ν[k, m] = ν_max
            elif fun_ν(ν_min, *args) <= 0:
                new_ν[k, m] = ν_min
            else:
                new_ν[k, m] = brentq(fun_ν, ν_min, ν_max, args=args)
    return new_ν

In [None]:
# Update μ
def update_μ(D, s1, s3):
    """
    D: K x MxM
    s1: K x M x M
    s3: K x M
    returns: (μ, v), each K x M, with
        v[k, m] = D[k, :, m] · s1[k, m, :]
        μ[k]    = D[k] @ (v[k] / s3[k])
    """
    # Only the diagonal of D^T s1^T is needed, so contract it directly
    # instead of forming the full M x M product.
    v = np.einsum('kjm,kmj->km', D, s1)
    # Dividing elementwise replaces the product with a dense diag(1 / s3).
    μ_new = np.einsum('kij,kj->ki', D, v / s3)
    return μ_new, v


In [None]:
# Update A
def update_A(D, v, S2, s3):
    """
    D: K x MxM
    v: K x M
    S2: K x M x MxM
    s3: K x M
    returns: K x M, with
        A[k, m] = D[k, :, m]^T S2[k, m] D[k, :, m] - v[k, m]² / s3[k, m]
    """
    # Loop-free: contract each S2[k, m] with the m-th column of D[k] on both
    # sides, rather than computing every D^T S2[k, m] D and keeping one entry.
    quad = np.einsum('kim,kmij,kjm->km', D, S2, D)
    return quad - v ** 2 / s3

In [None]:
# Update π
def update_π(s0):
    """
    s0: K
    returns: K, a new array holding s0 rescaled to sum to 1

    s0 is only guaranteed to sum to 1 when it was initialised with
    probabilities, so it is normalised here. When its sum is not strictly
    positive (all zeros, or NaN) an unnormalised copy is returned.
    """
    s0 = np.asarray(s0, dtype=np.float64)
    total = s0.sum()
    if not total > 0:
        return s0.copy()
    return s0 / total

In [None]:
# Update D

# objective
def fun(D, s1, S2, s3):
    """
    Vectorised objective for D.
    For every component k returns
        sum_m log( d_m^T (S2[k, m] - s1[k, m] s1[k, m]^T / s3[k, m]) d_m )
    where d_m is the m-th column of D[k].

    D: K x MxM
    s1: K x M x M
    S2: K x M x MxM
    s3: K x M
    returns: K
    """
    K, M = s3.shape  # not used below; still fails fast when s3 is not 2-D

    # S2[k, m] minus the outer product of s1[k, m] / s3[k, m] with s1[k, m].
    scaled_s1 = s1 / s3[..., None]
    centred = S2 - scaled_s1[..., None] @ s1[:, :, None, :]

    # D[k]^T centred[k, m] D[k] for every (k, m); entry (m, m) is the quadratic form in d_m.
    D_b = D[:, None, ...]
    sandwiched = np.swapaxes(D_b, -2, -1) @ centred @ D_b
    quad = np.diagonal(np.diagonal(sandwiched, 0, -2, -1), 0, -2, -1)
    return np.log(quad).sum(-1)

#### package ####
def update_D(D, s1, S2, s3, solver=TrustRegions()):
    """
    Update D: for every component k, hand the solver the cost
        sum_m log( d_m^T (S2[k, m] - s1[k, m] s1[k, m]^T / s3[k, m]) d_m )
    over a Stiefel manifold, where d_m is the m-th column of the candidate matrix.

    D: K x MxM   only its shape is read here; the current values are not
                 passed to the solver
    s1: indexed as s1[k, m], a vector (K x M x M judging from updateStat)
    S2: indexed as S2[k, m], a matrix (K x M x MxM judging from updateStat)
    s3: K x M
    solver: object exposing solve(problem)
    returns: array with the shape of D, one solver result per component

    NOTE(review): the default TrustRegions() instance is created once, when
    this def is executed, and is then shared by every call.
    NOTE(review): presumably the solver minimises the cost; confirm against
    the installed pymanopt version.
    """
    # The same manifold (dimensions D.shape[1] x D.shape[2]) is used for every component.
    manifold = Stiefel(D.shape[1], D.shape[2])
    
    def find_cost(k, s1, S2, s3):
        # Build the cost of component k; k and the statistics are captured by the closure.
        @pymanopt.function.autograd(manifold)
        def cost(D):
            # Here D is the single candidate matrix for component k (it shadows the batch D).
            sum_all = 0
            M = len(D[0])
            for m in range(M):
                # matQuad does not depend on the candidate D; it is rebuilt on every evaluation.
                tmp = s1[k, m] / s3[k, m]
                matQuad = (S2[k, m] - np.expand_dims(tmp, -1)@np.expand_dims(s1[k, m], -1).T)
                quadForm = D[:, m].T @ matQuad @ D[:, m]
                # autonp.log (autograd's numpy) is applied to the D-dependent quantity.
                sum_all += autonp.log(quadForm)
            return sum_all
        return cost
    
    D_new = np.zeros(D.shape)
    for k in range(len(D)):
        # One independent optimisation problem per component.
        cost = find_cost(k, s1, S2, s3)
        problem = pymanopt.Problem(manifold, cost, verbosity=0)
        D_new[k] = solver.solve(problem)
    return D_new

In [None]:
def updateParams(y, π, μ, A, D, ν, stat):
    """
    Compute all the parameters from the current statistics.
    Only D and stat are read; y, π, μ, A and ν are accepted but not used.

    stat: dict with keys 's0' (K), 's1' (K x M x M), 'S2' (K x M x MxM),
          's3' (K x M), 's4' (K x M)
    returns: (π_new, μ_new, A_new, D_new, ν_new)
    """
    s0 = stat['s0']
    # s1, S2, s3 and s4 are used divided by the per-component weight s0.
    s0_col = s0[:, None]
    s1 = stat['s1'] / s0[:, None, None]
    S2 = stat['S2'] / s0[:, None, None, None]
    s3 = stat['s3'] / s0_col
    s4 = stat['s4'] / s0_col

    π_new = update_π(s0)
    # D first: μ and A are computed from the updated D.
    D_new = update_D(D, s1, S2, s3)
    μ_new, v = update_μ(D_new, s1, s3)
    A_new = update_A(D_new, v, S2, s3)
    ν_new = update_ν(s3, s4)
    return π_new, μ_new, A_new, D_new, ν_new

## EM

In [None]:
# K, M = 3, 2 real data
π = np.array([0.3, 0.5, 0.2], dtype=np.float64)
μ = np.column_stack([np.zeros(3), [-6.0, 0.0, 6.0]])
angle = np.pi / 6
matRot = [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
D = np.array([matRot] * 3, dtype=np.float64)
A = np.ones((3, 2), dtype=np.float64)
ν = np.tile(np.array([1, 30], dtype=np.float64), (3, 1))

gen_mix = sampleMMST(10000, π, μ, A, D, ν)
# Put the samples whose coordinates all lie inside (-10, 10) first, followed by
# the others; each group ends up in reverse sampling order.
inside = (np.abs(gen_mix.max(axis=1)) < 10) & (np.abs(gen_mix.min(axis=1)) < 10)
gen_mix = np.concatenate([gen_mix[inside][::-1], gen_mix[~inside][::-1]]).astype(np.float64)

# real estimation: initial values from k-means on the first 500 samples
model = KMeans(3, max_iter=3000, tol=1e-5)
gen_mix_init = gen_mix[:500]
model.fit(gen_mix_init)

# Cluster proportions and cluster means; identity D, unit A and ν = 30 everywhere.
π_pred = np.bincount(model.labels_, minlength=3) / len(model.labels_)
μ_pred = np.stack([gen_mix_init[model.labels_ == k].mean(0) for k in range(3)]).astype(np.float64)
D_pred = np.tile(np.eye(2), (3, 1, 1))
A_pred = np.ones_like(A, dtype=np.float64)
ν_pred = np.full(ν.shape, 30, dtype=np.float64)

In [None]:
# Burn-in: accumulate the statistics over the first 500 samples with the
# parameters held at their initial values (no parameter update here).
#shuffle(gen_mix)
import time
stat = {
    's0': np.zeros(len(π)),
    's1': np.zeros(D.shape),
    'S2': np.zeros(D.shape + μ.shape[-1:]),
    's3': np.zeros(A.shape),
    's4': np.zeros(A.shape),
}
# Step sizes: (1 - 10e-10) * k ** (-6/10) for k = 1 .. len(gen_mix).
γ = (1 - 10e-10) * np.arange(1, len(gen_mix) + 1) ** (-6 / 10)

for i in range(500):
    y = gen_mix[i]
    # Responsibilities r[k] = π_k * pdf_k(y) / sum_j π_j * pdf_j(y).
    mst = pdfMST(y, μ_pred, A_pred, D_pred, ν_pred)
    r = π_pred * mst / pdfMMST(π_pred, mst)
    stat = updateStat(y, μ_pred, A_pred, D_pred, ν_pred, r, γ[i], stat)

In [None]:
# Mini-batch variant: update the parameters once per batch of m samples, then
# replace the statistics by the average of the m per-sample updates, all of
# them computed from the statistics in force at the start of the batch.
import copy
shuffle(gen_mix)
m = 50
for i in tqdm(range(0, len(gen_mix) - m, m)):
    # y is whatever sample was processed last; updateParams does not read it.
    π_pred, μ_pred, A_pred, D_pred, ν_pred = updateParams(y, π_pred, μ_pred, A_pred, D_pred, ν_pred, stat)
    stat_new = {'s0': np.zeros(len(π)), 's1': np.zeros(D.shape), 'S2': np.zeros((*D.shape, μ.shape[-1])),
         's3': np.zeros(A.shape), 's4': np.zeros(A.shape)}
    for k in range(m):
        # k-th sample of the current batch (the index used to be i + m, i.e.
        # the same sample m times).
        y = gen_mix[i + k]
        mst = pdfMST(y, μ_pred, A_pred, D_pred, ν_pred)
        r = π_pred * mst / pdfMMST(π_pred, mst)
        # updateStat replaces the entries of the dict it receives, so give it a
        # shallow copy: stat itself must stay fixed for the whole batch.
        # NOTE(review): γ is indexed by the sample position i, not by the batch
        # counter — confirm this is the intended step-size schedule.
        stat_tmp = updateStat(y, μ_pred, A_pred, D_pred, ν_pred, r, γ[i], copy.copy(stat))
        for key in stat_new:
            stat_new[key] += stat_tmp[key]
    stat = {key: total / m for key, total in stat_new.items()}

In [None]:
# Online EM over samples 500 .. 1199, one observation per iteration.
for i in tqdm(range(500, 1200)):
    y = gen_mix[i]
    # Responsibilities r of each component for y, under the parameters left by
    # the previous iteration.
    mst = pdfMST(y, μ_pred, A_pred, D_pred, ν_pred)
    r = π_pred * mst / pdfMMST(π_pred, mst)
    # Refresh the parameters from the statistics accumulated so far
    # (y is passed along, but updateParams does not read it).
    π_pred, μ_pred, A_pred, D_pred, ν_pred = updateParams(y, π_pred, μ_pred, A_pred, D_pred, ν_pred, stat)
    # NOTE(review): the statistics are updated with the freshly updated
    # parameters but with r computed from the previous ones — confirm that
    # this ordering is intended.
    stat = updateStat(y, μ_pred, A_pred, D_pred, ν_pred, r, γ[i], stat)

## Clustering


In [None]:
# Assign every sample to the component with the largest responsibility.
cluster_lab = np.zeros(len(gen_mix))

for i, y in enumerate(gen_mix):
    mst = pdfMST(y, μ_pred, A_pred, D_pred, ν_pred)
    r = π_pred * mst / pdfMMST(π_pred, mst)
    cluster_lab[i] = r.argmax()

cdict = {0: 'red', 1: 'blue', 2: 'green'}

# One scatter call per cluster so that each gets its own colour and label.
fig, ax = plt.subplots(figsize=(10, 5))

for g in np.unique(cluster_lab):
    ix = cluster_lab == g
    ax.scatter(gen_mix[ix, 0], gen_mix[ix, 1], c=cdict[int(g)], label=int(g), s=1)
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
plt.show()