In [81]:
from jax import vmap, jit
import jax.numpy as jnp
from jax.lax import lgamma
from jax.random import multivariate_normal, gamma, choice, permutation, PRNGKey
from jax.random import fold_in, split
from jax.scipy.special import digamma

import numpy as np
import torch

In [82]:
def pdfMST(y, mu, A, D, nu):
    """Evaluate a multiple-scale Student-t density at ``y`` for every component.

    The centred observation ``y - mu`` is projected onto the columns of each
    ``D[k]``. For every projected coordinate ``p`` the log-density term is

        lgamma((nu + 1) / 2) - lgamma(nu / 2) - 0.5 * log(pi * nu * A)
            - (nu + 1) / 2 * log(1 + p ** 2 / (nu * A))

    and the terms are summed over axis 1 before exponentiating.

    Args:
        y: observation; broadcast against ``mu``.
        mu: component locations (presumably ``(K, M)`` -- TODO confirm).
        A: per-coordinate scale terms, broadcast against the projections.
        D: stack of matrices; axes 1 and 2 are transposed before projecting.
        nu: per-coordinate degrees of freedom (floating point, as required by
            ``lgamma``).

    Returns:
        The density values, one per entry of the leading axis.
    """
    dof_scale = A * nu

    # Coordinates of the centred observation in each component's D basis.
    centred = jnp.expand_dims(y - mu, -1)
    proj = (jnp.swapaxes(D, 1, 2) @ centred)[..., 0]

    log_kernel = (-(nu + 1) / 2) * jnp.log(1 + proj ** 2 / dof_scale)
    log_norm = lgamma((nu + 1) / 2) - (lgamma(nu / 2) + 0.5 * jnp.log(np.pi * dof_scale))

    # Sum the per-coordinate log-densities over axis 1, then leave log space.
    return jnp.exp((log_kernel + log_norm).sum(1))

def pdfMMST(pi, MST=None, mu=None, A=None, D=None, nu=None, y=None):
    """Density of a mixture of multiple-scale Student-t components.

    Computes ``(pi * f).sum()`` where ``f`` holds the per-component densities.

    Args:
        pi: mixture weights.
        MST: optional precomputed per-component densities. When given, the
            remaining arguments are ignored.
        mu, A, D, nu: component parameters forwarded to ``pdfMST`` when
            ``MST`` is not supplied.
        y: observation at which to evaluate the density when ``MST`` is not
            supplied. For backward compatibility, a module-level ``y`` is used
            if this argument is omitted.

    Returns:
        The scalar mixture density.

    Raises:
        ValueError: if neither ``MST`` nor an observation ``y`` is available.
    """
    if MST is not None:
        return (pi * MST).sum()

    if y is None:
        # Legacy behaviour: the previous body silently read a global ``y``.
        y = globals().get('y')
    if y is None:
        raise ValueError(
            "pdfMMST needs either precomputed `MST` densities or an observation `y`"
        )
    return (pi * pdfMST(y, mu, A, D, nu)).sum()
    
def sampleMST(N, mu, A, D, nu, seed=42):
    """Draw ``N`` samples from each multiple-scale Student-t component.

    Every sample is built as ``mu + D @ diag(sqrt(A)) @ (X / sqrt(W))`` with
    ``X`` standard normal and ``W ~ Gamma(shape=nu / 2, rate=nu / 2)`` drawn
    independently per coordinate.

    All randomness now comes from JAX keys derived from ``seed``. Previously
    the Gamma weights were drawn through torch's global generator, so two
    calls with the same ``seed`` returned different samples.

    Args:
        N: number of samples per component.
        mu: component locations, shape ``(batch, M)``.
        A: per-coordinate scale terms, one row per component.
        D: stack of ``(M, M)`` matrices, one per component.
        nu: per-coordinate degrees of freedom, one row per component.
        seed: integer seed for the JAX PRNG.

    Returns:
        Array of shape ``(batch, N, M)``; entry ``[k, n]`` is the n-th draw
        from component ``k``.
    """
    # Independent keys for the Gaussian part and for the mixing weights.
    key_normal, key_gamma = split(PRNGKey(seed))
    batch, M = mu.shape

    X = multivariate_normal(
        key_normal, jnp.zeros(M), jnp.eye(M), (batch, N), dtype=jnp.float32
    )

    # Gamma(shape=a, rate=a) is Gamma(shape=a, rate=1) / a. The sample axis is
    # inserted at position 1 so the weights line up with X.
    half_nu = jnp.expand_dims(jnp.asarray(nu, dtype=jnp.float32) / 2, 1)
    w_shape = half_nu.shape[:1] + (N,) + half_nu.shape[2:]
    W = gamma(key_gamma, half_nu, shape=w_shape, dtype=jnp.float32) / half_nu

    X = X / jnp.sqrt(W)

    # coef[k] = D[k] @ diag(sqrt(A[k]))
    coef = D @ vmap(jnp.diag)(jnp.sqrt(A))

    return jnp.expand_dims(mu, 1) + jnp.swapaxes(coef @ jnp.swapaxes(X, 2, 1), 1, 2)

def sampleMMST(N, pi, mu, A, D, nu, seed=42):
    """Draw ``N`` samples from a mixture of multiple-scale Student-t components.

    A component label is drawn for every sample with probabilities ``pi``;
    sample ``n`` is then the n-th draw of its own component as produced by
    ``sampleMST``. Because the labels are i.i.d., the returned rows are not
    grouped by component.

    Fixes over the previous version: ``seed`` is now forwarded to
    ``sampleMST`` (it was silently ignored), the label key is no longer the
    same key used for the component draws, and the output is no longer left
    sorted by class (the result of ``permutation`` used to be discarded).

    Args:
        N: number of samples to return.
        pi: mixture weights, one per component.
        mu, A, D, nu: component parameters forwarded to ``sampleMST``.
        seed: integer seed controlling both the labels and the draws.

    Returns:
        Array with one row per sample and one column per coordinate of ``mu``.
    """
    # fold_in yields a key distinct from whatever sampleMST derives from
    # PRNGKey(seed), so labels are not tied to the component draws.
    key_classes = fold_in(PRNGKey(seed), 1)
    classes = choice(key_classes, len(pi), (N,), p=pi)

    # gen[k, n] is the n-th draw from component k.
    gen = sampleMST(N, mu, A, D, nu, seed=seed)

    # For each sample keep only the draw belonging to its own component.
    return gen[classes, jnp.arange(N)]


In [91]:
@jit
def alpha_beta(y, mu, A, D, nu):
    """Compute the ``(alpha, beta)`` pair for observation ``y``.

    ``alpha = nu / 2 + 0.5`` and ``beta = nu / 2 + p ** 2 / (2 * A)``, where
    ``p`` holds the coordinates of ``y - mu`` projected onto the columns of
    each ``D[k]``.

    Args:
        y: observation; broadcast against ``mu``.
        mu: component locations (presumably ``(K, M)`` -- TODO confirm).
        A: per-coordinate scale terms.
        D: stack of matrices; axes 1 and 2 are transposed before projecting.
        nu: per-coordinate degrees of freedom.

    Returns:
        Tuple ``(alpha, beta)``.
    """
    half_nu = nu / 2
    centred = jnp.expand_dims(y - mu, -1)
    proj = (jnp.swapaxes(D, 1, 2) @ centred)[..., 0]
    return half_nu + 0.5, half_nu + proj ** 2 / (2 * A)

@jit
def U(alpha, beta):
    """Return the element-wise ratio ``alpha / beta``.

    Presumably the mean of a Gamma(shape=alpha, rate=beta) weight -- confirm
    against the derivation this notebook follows.
    """
    ratio = alpha / beta
    return ratio

@jit
def Utilde(alpha, beta):
    """Return ``digamma(alpha) - log(beta)`` element-wise.

    Presumably the expected log of a Gamma(shape=alpha, rate=beta) weight --
    confirm against the derivation this notebook follows.

    ``digamma`` is ``jax.scipy.special.digamma``. It has to be imported in the
    notebook's import cell, otherwise the first call fails with ``NameError``
    while the function is being traced.
    """
    return digamma(alpha) - jnp.log(beta)

def updateStat(y, mu, A, D, nu, r, gam, stat):
    """Blend the statistics of one observation into the running ``stat``.

    Every entry is replaced by ``gam * new + (1 - gam) * old``, where the new
    values are:

    * ``'s0'``: ``r``
    * ``'s1'``: outer product of ``r * u`` with ``y``
    * ``'S2'``: outer product of ``r * u`` with ``y y^T``
    * ``'s3'``: ``r * u``
    * ``'s4'``: ``r * utilde``

    with ``u = U(alpha, beta)``, ``utilde = Utilde(alpha, beta)`` and
    ``(alpha, beta) = alpha_beta(y, mu, A, D, nu)``.

    Args:
        y: a single observation (1-D, as required by the einsum subscripts).
        mu, A, D, nu: component parameters forwarded to ``alpha_beta``.
        r: per-component weights (presumably responsibilities -- TODO
            confirm); a trailing axis is appended before multiplying.
        gam: step size of the blend.
        stat: dict holding the keys ``'s0'``, ``'s1'``, ``'S2'``, ``'s3'`` and
            ``'s4'``. It is modified in place.

    Returns:
        The same ``stat`` dict, after the update.
    """
    def _blend(name, fresh):
        # Convex combination of the new value and the stored one.
        stat[name] = gam * fresh + (1 - gam) * stat[name]

    _blend('s0', r)

    alpha, beta = alpha_beta(y, mu, A, D, nu)
    r_col = jnp.expand_dims(r, -1)
    ru = r_col * U(alpha, beta)
    rutilde = r_col * Utilde(alpha, beta)

    # y y^T built from a column view of the observation.
    y_col = jnp.expand_dims(y, -1)
    yyT = y_col @ y_col.T

    _blend('s1', jnp.einsum('ij,k->ijk', ru, y, optimize=True))
    _blend('S2', jnp.einsum('ij,kl->ijkl', ru, yyT, optimize=True))
    _blend('s3', ru)
    _blend('s4', rutilde)

    return stat

In [95]:
# Micro-benchmark of the jitted alpha_beta on the first row of gen_mix.
# NOTE(review): gen_mix, mu, A, D and nu are assigned in cells that appear
# further down in this notebook, so those cells have to be run first.
# NOTE(review): JAX may return before the computation has finished; consider
# timing with block_until_ready() on the result -- confirm.
%timeit alpha_beta(gen_mix[0], mu, A, D, nu)

797 µs ± 38.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [83]:
# Mixture weights, one per component (4 components).
pi = jnp.array([.1, .2, .3, .4], dtype=jnp.float32)
# Component locations, one 2-D row per component.
mu = jnp.array([[0, -6], [0, 0], [0, 6], [-6, 6]], dtype=jnp.float32)

# Every component uses the same 2-D rotation by pi/6 as its D matrix.
# All four entries are computed with numpy so that both diagonal terms are
# the same value (the original mixed np.cos with jnp.cos).
angle = jnp.pi / 6
matRot = [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
D = jnp.array([matRot] * 4, dtype=jnp.float32)

# Per-coordinate scale terms and degrees of freedom, one row per component.
A = jnp.array([[2, 3], [1, 2.5], [5, 2], [1.5, 0.9]], dtype=jnp.float32)
nu = jnp.array([[1, 3], [1, 3], [1, 3], [1, 3]], dtype=jnp.float32)

In [89]:
# Draw 3000 observations from the mixture configured in the previous cell;
# no seed is passed, so sampleMMST's default (seed=42) applies.
gen_mix = sampleMMST(3000, pi, mu, A, D, nu)