In [1]:
import numpy as np
from numpy.random import multivariate_normal, gamma, choice, shuffle
import autograd.numpy as autonp


from scipy.special import loggamma, digamma
from scipy.optimize import brentq

from sklearn.cluster import KMeans

import multiprocessing
from joblib import Parallel, delayed
from itertools import permutations

import matplotlib.pyplot as plt

from tqdm import tqdm

import copy

import torch

from numba import jit

from GenStudentMixtures.Mixture_Multivariate_Student_Generalized import MMST

@jit(nopython=True)
def batch_diagonal(A):
    N = A.shape[1]
    A = np.expand_dims(A, axis=1)
    return A*np.eye(N)

# Distributions

In [2]:
def pdfMST(y, mu, A, D, nu):
    th2 = A * nu
    th1 = np.log(1 + (np.swapaxes(D, 1, 2)@np.expand_dims((y - mu), -1))[..., 0] ** 2 / th2)
    exponent = - (nu + 1) / 2
    
    main = exponent * th1
    
    gam1 = loggamma((nu + 1) / 2)
    gam2 = loggamma(nu / 2)
    th2 = gam1 - (gam2 + 0.5 * np.log(np.pi * th2))
    
    main += th2
    
    return np.exp(main.sum(1))

def pdfMMST(pi, MST=None, mu=None, A=None, D=None, nu=None):
    if MST is not None:
        return (pi * MST).sum()
    else:
        return (pi * pdfMST(y, mu, A, D, nu)).sum()

In [12]:
def sampleMST(N, mu, A, D, nu):
    
    batch, M = mu.shape
    X = multivariate_normal(np.zeros(M), cov=np.diag(np.ones(M)), size = (batch, N,))
    
    # TODO comment tirer en batch sur numpy ?????
    W = torch.distributions.Gamma(torch.tensor(nu) / 2, torch.tensor(nu) / 2).sample((N,)).numpy()
    W = np.swapaxes(W, 0, 1)
    
    X /= np.sqrt(W)
    
    matA = batch_diagonal(np.sqrt(A))
    coef = D@matA
    
    gen = np.expand_dims(mu, 1) + np.swapaxes(coef@np.swapaxes(X, 2, 1), 1, 2)
    
    return gen 

def sampleMMST(N, pi, mu, A, D, nu):
    classes = choice(len(pi), N, p=pi)
    
    gen = sampleMST(N, mu, A, D, nu)
    gen_mix = np.zeros((1, len(mu[0])))
  
    for k in range(len(pi)):
        gen_mix = np.concatenate((gen_mix, gen[k, classes == k, :]),)
    gen_mix = gen_mix[1:]
    shuffle(gen_mix)
    return gen_mix

# Update statistics

In [4]:
def alpha_beta(y, mu, A, D, nu):
    tmp = nu / 2 
    alpha = tmp + 0.5
    beta = tmp + (np.transpose(D, (0, 2, 1))@np.expand_dims((y - mu), -1))[..., 0] ** 2 / (2 * A)
    return alpha, beta

@jit(nopython=True)
def U(alpha, beta):
    return alpha / beta

def Utilde(alpha, beta):
    return digamma(alpha) - np.log(beta)

def updateStat(y, mu, A, D, nu, r, gam, stat):   
    stat['s0'] = gam * r  + (1 - gam) * stat['s0']
    
    alpha, beta = alpha_beta(y, mu, A, D, nu)
    u, utilde = U(alpha, beta), Utilde(alpha, beta)
    r = np.expand_dims(r, -1)
    ru, rutilde = r * u, r * utilde
    
    y_unsqueeze = np.expand_dims(y, -1)
    ymat = y_unsqueeze@y_unsqueeze.T
    
    stat['s1'] = gam * np.einsum('ij,k->ijk', ru , y, optimize=True) + (1 - gam) * stat['s1']
    stat['S2'] = gam * np.einsum('ij,kl->ijkl', ru , ymat, optimize=True) + (1 - gam) * stat['S2']
    stat['s3'] = gam * ru + (1 - gam) * stat['s3']
    stat['s4'] = gam * rutilde  + (1 - gam) * stat['s4']
    
    return stat

# Update parameters

In [5]:
# Update pi
def update_pi(s0):
    return s0 # / s0.sum()

# Update nu
def fun_nu(nukm, s3km, s4km):
    return s4km - s3km - digamma(nukm / 2) + np.log(nukm / 2) + 1

def update_nu(s3, s4):
    K, M = s3.shape
    new_nu = np.zeros((K, M))
    for k in range(K):
        for m in range(M):
            s3km, s4km = s3[k, m] , s4[k, m] 
            fun = lambda x : fun_nu(x, s3km, s4km)
            new_nu[k, m] = brentq(fun, .01, 100)
    return new_nu.astype(np.float64)

# Update mu
def update_mu(D, s1, s3):
    S3_inv = batch_diagonal(1 / s3)
    v = np.expand_dims(np.diagonal(np.transpose(D, (0, 2, 1))@np.transpose(s1, (0, 2, 1)), 0, -2, -1), -1)
    return (D@(S3_inv@v))[..., 0], v[..., 0]

# Update A
def update_A(D, v, S2, s3):
    tmp = np.swapaxes(D[:, None, ...], -2, -1)@S2
    tmp = tmp@D[:, None, ...]
    tmp = np.diagonal(tmp, 0, -2, -1)
    return np.diagonal(tmp, 0, -2, -1) - v ** 2 / s3

## D update pymanopt autodiff

In [6]:
import pymanopt
from pymanopt.manifolds import Stiefel
from pymanopt.solvers import ConjugateGradient

In [7]:
# Update D
def loss(D, s1k, S2k, s3k):
    loss = 0
    M = len(D[0])
    for m in range(M):
        tmp = s1k[m] / s3k[m]
        matQuad = (S2k[m] - np.expand_dims(tmp, -1)@np.expand_dims(s1k[m], -1).T)
        quadForm = D[:, m].T @ matQuad @ D[:, m]
        loss += np.log(quadForm)
    return loss  

def update_D(s1, S2, s3):
    def find_cost(s1k, S2k, s3k, manifold):
        @pymanopt.function.autograd(manifold)
        def cost(D):
            loss = 0
            M = len(D[0])
            E = np.eye(M)
            for m in range(M):
                tmp = s1k[m] / s3k[m]
                matQuad = (S2k[m] - np.expand_dims(tmp, -1)@np.expand_dims(s1k[m], -1).T)
                quadForm = (D@E[:,m]).T @ matQuad @ (D@E[:,m])
                loss += quadForm
            return loss
        return cost
    
    def opti_D(s1k, S2k, s3k):
        manifold = Stiefel(*s1k.shape)
        solver=ConjugateGradient(maxiter=4000)
        cost = find_cost(s1k, S2k, s3k, manifold)
        problem = pymanopt.Problem(manifold, cost, verbosity=0)
        return solver.solve(problem)
    
    d = (delayed(opti_D)(s1[k], S2[k], s3[k]) for k in range(len(s1)))
    D_tmp = np.array(Parallel(n_jobs=multiprocessing.cpu_count())(d))
    D_opt = np.zeros(D_tmp.shape)
    
    for k in range(len(D_tmp)):
        minim_permuted = np.inf
        for e in permutations(list(D_tmp[k].T)):
            D_permuted = np.vstack(e).T
            cost = loss(D_permuted, s1[k], S2[k], s3[k])
            if cost < minim_permuted:
                D_opt[k] = D_permuted.copy()
                minim_permuted = cost
    
    return D_opt

def updateParams(stat, D_prev):
    s0 = stat['s0']
    s1 = stat['s1'] / s0[:, None, None]
    S2 = stat['S2'] / s0[:, None, None, None]
    s3 = stat['s3'] / np.expand_dims(s0, -1)
    s4 = stat['s4'] / np.expand_dims(s0, -1)
    
    pi_new = update_pi(s0)  
    D_new = update_D(s1, S2, s3) 
    mu_new, v = update_mu(D_new, s1, s3)
    A_new = update_A(D_new, v, S2, s3)
    nu_new = update_nu(s3, s4)
    return pi_new, mu_new, A_new, D_new, nu_new

## D update pymanopt grad home

In [9]:
import pymanopt
from pymanopt.manifolds import Stiefel
from pymanopt.solvers import ConjugateGradient

In [10]:
# Update D
def loss(D, s1k, S2k, s3k):
    loss = 0
    M = len(D[0])
    for m in range(M):
        tmp = s1k[m] / s3k[m]
        matQuad = (S2k[m] - np.expand_dims(tmp, -1)@np.expand_dims(s1k[m], -1).T)
        quadForm = D[:, m].T @ matQuad @ D[:, m]
        loss += quadForm
    return loss  

def update_D(s1, S2, s3):
    def find_cost(s1k, S2k, s3k, manifold):
        @pymanopt.function.numpy(manifold)
        def cost(D):
            loss = 0
            M = len(D[0])
            for m in range(M):
                tmp = s1k[m] / s3k[m]
                matQuad = (S2k[m] - np.expand_dims(tmp, -1)@np.expand_dims(s1k[m], -1).T)
                quadForm = D[:, m].T @ matQuad @ D[:, m]
                loss += quadForm
            return loss
        
        @pymanopt.function.numpy(manifold)
        def grad(D):
            grad = np.zeros(D.shape)
            M = len(D[0])
            for m in range(M):
                tmp = s1k[m] / s3k[m]
                matQuad = (S2k[m] - np.expand_dims(tmp, -1)@np.expand_dims(s1k[m], -1).T)
                grad[m] = 2 * matQuad @ D[:, m]
            return grad.T
        return cost, grad
    
    def opti_D(s1k, S2k, s3k):
        manifold = Stiefel(*s1k.shape)
        solver = ConjugateGradient(maxiter=4000)
        cost, grad = find_cost(s1k, S2k, s3k, manifold)
        problem = pymanopt.Problem(manifold, cost, egrad=grad, verbosity=0)
        return solver.solve(problem)
    
    d = (delayed(opti_D)(s1[k], S2[k], s3[k]) for k in range(len(s1)))
    D_tmp = np.array(Parallel(n_jobs=multiprocessing.cpu_count())(d))

    D_opt = np.zeros(D_tmp.shape)
    for k in range(len(D_tmp)):
        minim_permuted = np.inf
        for e in permutations(list(D_tmp[k].T)):
            D_permuted = np.vstack(e).T
            cost = loss(D_permuted, s1[k], S2[k], s3[k])
            if cost < minim_permuted:
                D_opt[k] = D_permuted.copy()
                minim_permuted = cost

    return D_opt

def updateParams(stat, D_prev):
    s0 = stat['s0']
    s1 = stat['s1'] / s0[:, None, None]
    S2 = stat['S2'] / s0[:, None, None, None]
    s3 = stat['s3'] / np.expand_dims(s0, -1)
    s4 = stat['s4'] / np.expand_dims(s0, -1)
    
    pi_new = update_pi(s0)  
    D_new = update_D(s1, S2, s3) 
    mu_new, v = update_mu(D_new, s1, s3)
    A_new = update_A(D_new, v, S2, s3)
    nu_new = update_nu(s3, s4)
    return pi_new, mu_new, A_new, D_new, nu_new

In [36]:
(0.5 * stat['S2'] - np.expand_dims(mu_pred, -1) @ stat['s1'][:, :, None, :]) / A_pred

array([[[[ 1.89073126e-01,  5.94174152e-02, -9.74373226e-02],
         [-9.77522124e-03, -7.46409206e+00, -8.10082460e-01],
         [-1.05898699e-01,  9.42064077e-02,  9.30426027e-01]],

        [[ 8.14282408e-01, -1.68496362e+00, -1.00910753e-01],
         [-2.19118339e+00,  2.25220213e+01,  5.38620218e-01],
         [-1.15680227e-01,  5.61252217e-01,  4.16146729e-01]],

        [[ 2.64973509e+00,  7.05965538e+00,  2.89256784e-01],
         [-4.62499765e-01,  5.60294284e+00,  1.11892734e-01],
         [-6.77849117e-02, -7.03572813e-02,  2.59884593e-01]]],


       [[[ 1.29966218e-01, -1.24983928e-01,  1.23356521e-02],
         [-5.03907420e-01,  7.60161614e+00, -7.69324012e-01],
         [ 6.91333899e-03, -1.84987040e-01,  2.60443286e-01]],

        [[ 8.32225989e-01, -1.07665927e-01, -3.38920839e-02],
         [ 3.92484810e-01, -3.61540813e+00,  1.81120573e-01],
         [ 1.30038877e-02, -3.97388005e-01,  8.61519511e-02]],

        [[ 5.38264463e+00, -5.85156749e+00,  6.72248363e-0

In [37]:
tmp = mu_pred * np.expand_dims(stat['s3'], -1)
tmp = np.expand_dims(tmp, -1) @ mu_pred[:, :, None, :]

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed

In [113]:
np.repeat(mat_mu[:,None,...], 3, 1).shape

(3, 3, 3, 3)

In [115]:
mat_mu[0] * stat['s3'][0, 0]

array([[ 1.99490182e-03, -2.12964245e-01,  2.88609181e-05],
       [-2.12964245e-01,  2.27348378e+01, -3.08102561e-03],
       [ 2.88609181e-05, -3.08102561e-03,  4.17540644e-07]])

In [135]:
np.repeat(mat_mu[:,None,...], 3, 1)[0] 

array([[[ 3.15236994e-03, -3.36528883e-01,  4.56064000e-05],
        [-3.36528883e-01,  3.59258880e+01, -4.86867695e-03],
        [ 4.56064000e-05, -4.86867695e-03,  6.59803184e-07]],

       [[ 3.15236994e-03, -3.36528883e-01,  4.56064000e-05],
        [-3.36528883e-01,  3.59258880e+01, -4.86867695e-03],
        [ 4.56064000e-05, -4.86867695e-03,  6.59803184e-07]],

       [[ 3.15236994e-03, -3.36528883e-01,  4.56064000e-05],
        [-3.36528883e-01,  3.59258880e+01, -4.86867695e-03],
        [ 4.56064000e-05, -4.86867695e-03,  6.59803184e-07]]])

In [127]:
stat['s3'][0]

array([0.63282605, 0.40823528, 0.47646178])

In [111]:
stat['s3'][0, 0]

0.63282605172886

In [110]:
np.repeat(stat['s3'][:,None,...], 3, 1)[0, 0]

array([0.63282605, 0.40823528, 0.47646178])

# 3D test

In [13]:
pi = np.array([.3, .5, .2], dtype=np.float64)
mu = np.array([[0, -6, 0], [0, 0, 0], [0, 6, 0]], dtype=np.float64)
angle = np.pi / 6
matRot = [[1, 0, 0], [0, np.cos(angle), -np.sin(angle)], [0, np.sin(angle), np.cos(angle)]]
D = np.array([matRot, matRot, matRot], dtype=np.float64)
A = np.ones((3, 3), dtype=np.float64)
nu = np.array([[1, 3, 5], [1, 3, 5], [1, 3, 5]], dtype=np.float64)


gen_mix = sampleMMST(50000, pi, mu, A, D, nu)
gen_mix = np.array(sorted(gen_mix, key=lambda e: abs(e.max()) < 10 and abs(e.min()) < 10)[::-1], dtype=np.float64)

In [14]:
# real estimation 
model = KMeans(4, max_iter=3000, tol=1e-5)
gen_mix_init = gen_mix[:500]
model.fit(gen_mix_init)
pi_pred = np.array([(model.labels_ == k).sum() / len(model.labels_) for k in range(3)], dtype=np.float64)
mu_pred = np.array([gen_mix_init[model.labels_ == k].mean(0) for k in range(3)], dtype=np.float64)
D_pred = np.array([np.eye(3), np.eye(3), np.eye(3)], dtype=np.float64)
A_pred = np.ones(A.shape, dtype=np.float64)
nu_pred = np.ones(nu.shape, dtype=np.float64)

shuffle(gen_mix)
stat = {'s0': np.zeros(len(pi)), 's1': np.zeros(D.shape), 'S2': np.zeros((*D.shape, mu.shape[-1])), 
             's3': np.zeros(A.shape), 's4': np.zeros(A.shape)}
m = 50
gam = (1-10e-10)*np.array([k for k in range(1, len(gen_mix) + 1)]) ** (-6/10)

    
D_hist, mu_hist, pi_hist, nu_hist, A_hist = [], [], [], [], []

for i in tqdm(range(0, len(gen_mix) - m, m)):
    stat_new = {'s0': np.zeros(len(pi)), 's1': np.zeros(D.shape), 'S2': np.zeros((*D.shape, mu.shape[-1])), 
             's3': np.zeros(A.shape), 's4': np.zeros(A.shape)}
    for k in range(m):
        y = gen_mix[i + k] 
        mst = pdfMST(y, mu_pred, A_pred, D_pred, nu_pred)
        r = pi_pred * mst / pdfMMST(pi_pred, mst)
        stat_tmp = updateStat(y, mu_pred, A_pred, D_pred, nu_pred, r, gam[i // m], copy.deepcopy(stat))
        stat_new['s0'] += stat_tmp['s0'] / m
        stat_new['s1'] += stat_tmp['s1'] / m
        stat_new['S2'] += stat_tmp['S2'] / m
        stat_new['s3'] += stat_tmp['s3'] / m
        stat_new['s4'] += stat_tmp['s4'] / m
    stat = copy.deepcopy(stat_new)
    pi_pred, mu_pred, A_pred, D_pred, nu_pred = updateParams(copy.deepcopy(stat), copy.deepcopy(D_pred))
    D_hist.append(D_pred)
    pi_hist.append(pi_pred)
    mu_hist.append(mu_pred)
    nu_hist.append(nu_pred)
    A_hist.append(A_pred)
    if (i // m) % 10 == 0:
        print(pi_pred)

  0%|          | 0/999 [00:04<?, ?it/s]


TypeError: __init__() got an unexpected keyword argument 'maxiter'

# 2D test

In [None]:
pi = np.array([.1, .2, .3, .4], dtype=np.float64)
mu = np.array([[0, -6], [0, 0], [0, 6], [-6, 6]], dtype=np.float64)
angle = np.pi / 6
matRot = [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
D = np.array([matRot, matRot, matRot, matRot], dtype=np.float64)
A = np.ones((4, 2), dtype=np.float64)
A = np.array([[2, 3], [1, 2.5], [5, 2], [1.5, 0.9]], dtype=np.float64)
nu = np.array([[1, 30], [1, 30], [1, 30], [1, 30]], dtype=np.float64)

gen_mix = sampleMMST(1000000, pi, mu, A, D, nu)
gen_mix = np.array(sorted(gen_mix, key=lambda e: abs(e.max()) < 10 and abs(e.min()) < 10)[::-1], dtype=np.float64)

# real estimation 
model = KMeans(4, max_iter=3000, tol=1e-5)
gen_mix_init = gen_mix[:500]
model.fit(gen_mix_init)

pi_pred = np.array([(model.labels_ == k).sum() / len(model.labels_) for k in range(4)], dtype=np.float64)
mu_pred = np.array([gen_mix_init[model.labels_ == k].mean(0) for k in range(4)], dtype=np.float64)
D_pred = np.array([np.eye(2), np.eye(2), np.eye(2), np.eye(2)], dtype=np.float64)
A_pred = np.ones(A.shape, dtype=np.float64)
nu_pred = 30 * np.ones(nu.shape, dtype=np.float64)

shuffle(gen_mix)
stat = {'s0': np.zeros(len(pi)), 's1': np.zeros(D.shape), 'S2': np.zeros((*D.shape, mu.shape[-1])), 
             's3': np.zeros(A.shape), 's4': np.zeros(A.shape)}
m = 50
gam = (1-10e-10)*np.array([k for k in range(1, len(gen_mix) + 1)]) ** (-6/10)

In [None]:
D_hist, mu_hist, pi_hist, nu_hist, A_hist, stat_hist = [], [], [], [], [], []

for i in tqdm(range(0, len(gen_mix) - m, m)):
    stat_new = {'s0': np.zeros(len(pi)), 's1': np.zeros(D.shape), 'S2': np.zeros((*D.shape, mu.shape[-1])), 
             's3': np.zeros(A.shape), 's4': np.zeros(A.shape)}
    for k in range(m):
        y = gen_mix[i + k] 
        mst = pdfMST(y, mu_pred, A_pred, D_pred, nu_pred)
        r = pi_pred * mst / pdfMMST(pi_pred, mst)
        stat_tmp = updateStat(y, mu_pred, A_pred, D_pred, nu_pred, r, gam[i // m], copy.deepcopy(stat))
        stat_new['s0'] += stat_tmp['s0'] / m
        stat_new['s1'] += stat_tmp['s1'] / m
        stat_new['S2'] += stat_tmp['S2'] / m
        stat_new['s3'] += stat_tmp['s3'] / m
        stat_new['s4'] += stat_tmp['s4'] / m
    stat = copy.deepcopy(stat_new)
    stat_hist.append(stat)
    pi_pred, mu_pred, A_pred, D_pred, nu_pred = updateParams(copy.deepcopy(stat), copy.deepcopy(D_pred))
    D_hist.append(D_pred)
    pi_hist.append(pi_pred)
    mu_hist.append(mu_pred)
    nu_hist.append(nu_pred)
    A_hist.append(A_pred)
    if (i // m) % 500 == 0:
        print(pi_pred)