In [None]:
from vMFne.logpartition import gradΦ, invgradΦ, vMF_entropy_Φ, logχ
from vMFne.negentropy import Ψ, gradΨ

import numpy as np
import matplotlib.pyplot as plt

# Compare the negative entropy Ψ(μ) against -H[η] - log χ for several ambient
# dimensions D.  Top row: the curves themselves; bottom row: absolute
# differences on a log scale.  One subplot column per entry of Ds.
Ds = [50, 500, 5000]

# Number of ||μ|| grid points and the absolute tolerance handed to invgradΦ.
# Neither depends on D, so both are set once instead of on every iteration.
K = 100
atol = 1e-12

fig = plt.figure(figsize=(12, 12))
for i, D in enumerate(Ds):

    # K identical rows, each the unit-norm vector (1, ..., 1)/sqrt(D); scaling
    # row k by a scalar yields a D-dimensional point with that norm.
    V = np.ones((K, D)) / np.sqrt(D)

    # K interior grid points of [0, 1] (both endpoints are dropped).
    μs_norm = np.linspace(0, 1, K+2)[1:-1]
    # NOTE(review): the meaning of Ψ0=[None, 1e-5] is defined by vMFne.negentropy.Ψ
    # and is not visible here -- confirm against that module.
    _, gradΨμ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[None, 1e-5], return_grad=True)

    κs = np.linalg.norm(gradΨμ, axis=-1)

    # Re-compute ||μ|| = Φ'(||η||) from the κ obtained above, so that H[η] is
    # not put at a disadvantage by inaccurate η; if anything, it is the μ
    # that are inaccurate.
    μs_norm = np.linalg.norm(gradΦ(κs.reshape(-1,1)*V), axis=-1)
    Ψμ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[None, 1e-5])

    # Invert μ -> κ twice: with up to 10 iterations and with max_iter=0
    # (presumably the un-refined initial approximation -- confirm in invgradΦ).
    # `diffs` is overwritten by the second call, as before.
    κs_est, diffs = invgradΦ(μs_norm, D, max_iter=10, atol=atol)
    κs_est_0, diffs = invgradΦ(μs_norm, D, max_iter=0, atol=atol)

    H = vMF_entropy_Φ(κs_est.reshape(-1,1) * V)
    H0 = vMF_entropy_Φ(κs_est_0.reshape(-1,1) * V)

    # logχ(D) depends only on D; evaluate it once per dimension.
    log_χ_D = logχ(D)

    plt.subplot(2, len(Ds), i+1)
    # Ψμ was computed above with exactly these arguments, so it is reused
    # here instead of evaluating Ψ a second time.
    plt.plot(μs_norm, Ψμ, label='negative entropy Ψ(μ)')
    plt.plot(μs_norm, - H - log_χ_D, '--', color='orange', label='-H[η(μ)] - log χ(x)')
    plt.plot(μs_norm, - H0 - log_χ_D, '--', color='violet', label='-H[κ~(μ)] - log χ(x)')
    plt.xlabel("||μ|| ")
    plt.legend()
    plt.ylabel('D = ' + str(D))

    plt.subplot(2, len(Ds), i+1+len(Ds))
    plt.semilogy(μs_norm, np.abs(Ψμ + (H + log_χ_D)), '--', color='orange', label='|Ψ(μ) - (-H[η(μ)] - log χ(x))|')
    plt.semilogy(μs_norm, np.abs(Ψμ + (H0 + log_χ_D)), '--', color='violet', label='|Ψ(μ) - (-H[κ~(μ)] - log χ(x))|')
    plt.xlabel("||μ|| ")
    plt.legend()
plt.show()

In [None]:
# numerically check negEntropies vs negative entropies 

from vMFne.logpartition import log_besseli, ratio_besseli
from vMFne.negentropy import Ψ
import scipy.stats

def vMF_entropy_Φ(ηs):
    """Evaluate the entropy expression for each row of natural parameters.

    ``ηs`` is promoted to a 2-D array; each row is one parameter vector η and
    the number of columns is taken as the dimension D.  With κ = ||η|| the
    value returned per row is

        -(D/2 - 1) log κ + (D/2) log(2π)
            + log_besseli(D/2 - 1, κ) - κ * ratio_besseli(D/2, κ)

    Returns a 1-D array with one entry per row.
    """
    natural_params = np.atleast_2d(ηs)
    _, dim = natural_params.shape
    kappas = np.linalg.norm(natural_params, axis=-1)

    # The Bessel helpers are evaluated one κ at a time.
    log_bessel = np.array([log_besseli(dim/2.-1, kappa) for kappa in kappas])
    bessel_ratio = np.array([ratio_besseli(dim/2, kappa) for kappa in kappas])

    # Part of the expression that involves no Bessel function.
    entropy = - (dim/2.-1.) * np.log(kappas) + dim/2. * np.log(2.0*np.pi)
    entropy = entropy + log_bessel - kappas * bessel_ratio
    return entropy

# Work on scaled copies: multiplying by 1. produces new objects (assuming
# μs_true / ηs_true are arrays defined in an earlier cell), so the in-place
# shrinking of the first entry below leaves the originals untouched.
μs = 1. * μs_true
ηs = 1. * ηs_true

# Shrink the first parameter pair by 1e-9 to probe the near-zero regime.
μs[0] *= 1e-9
ηs[0] *= 1e-9

# Reference entropies from SciPy.  The row norms are computed once up front
# instead of being re-evaluated for the whole array on every element.
η_norms = np.linalg.norm(ηs, axis=1)
Hsp = [scipy.stats.vonmises_fisher.entropy(mu=η / κ, kappa=κ)
       for η, κ in zip(ηs, η_norms)]

# Cell output: Ψ(μ) + H[η] - (D/2) log(2π), displayed by the notebook.
# NOTE(review): presumably expected to be close to zero -- confirm.
Ψ(μs, D=D) + vMF_entropy_Φ(ηs) - D/2 *np.log(2*np.pi)


# towards mean-parameterized (Hyper-)spherical VAEs
Quick idea to make something out of mean parameterization for hyperspherical VAEs:
- hyperspherical VAEs are defined by von Mises-Fisher p(z), q(z|x) and general (typically Gaussian) p(x|z).
- as such they require the reparametrization trick to get training gradients for q(z|x) from the ELBO
- reparametrization for von Mises-Fisher latents is known, but is i) cumbersome and ii) formulated in natural parameterization (one samples a univariate $\omega \sim p(\omega \ | \ \kappa = ||\eta||, D)$).
- here we try a quick idea for $D=2$ and $D=3$ in which one only approximately samples $q(z|x)$, by drawing $\tilde{z} \sim \mathcal{N}(\tilde{z}| \mu(x), \sigma_\mu^2)$, where $\mu(x)$ is the mean parameter of the vMF $q(z|x)$, and then setting $z = \tilde{z}/||\tilde{z}||$, which is differentiable almost surely. The open question is which variance function $\sigma^2_\mu$ gives the best approximation, i.e. a function of $\mu(x)$ (or more sensibly of $||\mu(x)||$).
- for $D=2,3$, it seems that $\sigma^2_\mu = \frac{1-||\mu||^{8-2D}}{\sqrt{2\pi}}$ works quite well.
- generalization to $D > 3$ currently unclear.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import vonmises  
from matplotlib.pyplot import cm
from vMFne.utils_angular import cart2spherical, spherical_rotMat
from vMFne.sample import sample_vMF_Ulrich

# Dimension used by this cell (the branches below handle D == 2 and D == 3).
D = 3
# Number of random draws per ||mu|| value.
N = 10**6

def sigma2(norm_mu):
    """Return the proposal variance (1 - norm_mu**(2*(4 - D))) / sqrt(2*pi).

    ``norm_mu`` is the norm of the mean parameter; ``D`` is read from the
    enclosing module scope at call time.
    """
    exponent = (4 - D) * 2
    prefactor = 1. / np.sqrt(2 * np.pi)
    return prefactor * (1 - norm_mu**exponent)

# For a range of mean-parameter norms ||mu||, compare the angular distribution
# of normalised Gaussian draws (variance sigma2(||mu||)) against draws from the
# von Mises-Fisher distribution whose kappa is ||gradΨ(mu)||.

# Unit vector along the last coordinate axis, shaped (1, D).
mu_base = np.array([0., 0.0, 1.0])[-D:].reshape(1,D)
mu_base = mu_base / np.sqrt( (mu_base**2).sum() )
mu_norms = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

# Two subplot columns; the row count does not change inside the loop.
n_rows = int(np.ceil(len(mu_norms) / 2.))

plt.figure(figsize=(8,16))
for jj, mu_norm in enumerate(mu_norms):
    mu = mu_norm * mu_base
    norm_mu = np.sqrt( (mu**2).sum() )

    # \grad\Psi(mu) = eta, hence kappa = ||eta||
    kappa = np.linalg.norm(gradΨ(μ=mu,D=D))

    # sample from Gaussian proposal
    renorm = sigma2(norm_mu)
    x = mu + np.random.normal(size=[N,D]) * np.sqrt(renorm)
    # Draws projected onto the unit sphere; only referenced by the disabled
    # 3D-plotting snippet at the end of the loop body.
    x_norm = x / np.sqrt((x**2).sum(axis=-1)).reshape(-1,1)

    phi_x = cart2spherical(x.T)

    plt.subplot(n_rows, 2, jj+1)
    if D == 2:
        xx = np.linspace(-np.pi, np.pi, 100)
        h_x,bins_x = np.histogram(phi_x, bins=xx, density=True)
        phi_mu = np.arctan2(mu[...,1], mu[...,0])
        # von Mises draws shifted by the mean angle and wrapped into [-pi, pi)
        phi_vmf = np.mod(vonmises.rvs(kappa, size=N) + phi_mu + np.pi, 2*np.pi) - np.pi
        # Labels mirror the D == 3 branch so that plt.legend() has entries.
        plt.hist(phi_vmf, bins=xx, density=True, label='von Mises-Fisher distribution')
        plt.plot(bins_x[:-1]+np.diff(bins_x[:2])[0]/2, h_x, label='angles of Gaussian draws')
    elif D == 3:
        xx = np.linspace(0, np.pi, 50)
        x_vmf = sample_vMF_Ulrich(N=N, m=mu.flatten()/norm_mu, kappa=kappa)
        phi_vmf = np.mod(cart2spherical(x_vmf.T) + np.pi, 2*np.pi) - np.pi
        # Histogram the first angular coordinate of both sample sets.
        h_vmf,_ = np.histogram(phi_vmf[0], xx, density=True)
        h_x,_   = np.histogram(phi_x[0], xx, density=True)
        plt.plot(xx[:-1] + (xx[1]-xx[0])/2., h_x, label='angles of Gaussian draws')
        plt.plot(xx[:-1] + (xx[1]-xx[0])/2., h_vmf, label='von Mises-Fisher distribution')
    # Raw strings throughout: '\k' in a normal string is an invalid escape sequence.
    plt.title(r'$||\mu||=' + "{:10.2f}".format(norm_mu) + r', \kappa=' + "{:10.2f}".format(kappa) + '$')
    if jj == 0:
        plt.ylabel('radial profiles of angles')
        plt.legend()

    # NOTE(review): the two string literals below are disabled code and are
    # never executed.  In the second one, `myheatmap_x` is built from WW_vmf
    # rather than WW_x, which looks unintended -- confirm before re-enabling.
    """
    For 3D plotting (plotting on S^2 in 3D plots), code adapted from
    https://stackoverflow.com/questions/22128909/plotting-the-temperature-distribution-on-a-sphere-with-python
    """
    """
    from mpl_toolkits.mplot3d import Axes3D
    from sklearn.metrics import pairwise

    if D == 3:
        fig = plt.figure()

        u = np.linspace( 0, 2 * np.pi, 120)
        v = np.linspace( 0, np.pi, 60 )

        # create the sphere surface
        XX = np.outer( np.cos( u ), np.sin( v ) )
        YY = np.outer( np.sin( u ), np.sin( v ) )
        ZZ = np.outer( np.ones( np.size( u ) ), np.cos( v ) )
        locs = np.stack([XX.flatten(),YY.flatten(),ZZ.flatten()], axis=-1)

        d0 = 0.1
        WW_vmf = (pairwise.pairwise_distances(locs, x_vmf.T)<d0).sum(axis=-1)
        myheatmap_vmf = WW_vmf.reshape(len(u), len(v)) / WW_vmf.max()
        WW_x = (pairwise.pairwise_distances(locs, x_norm.T)<d0).sum(axis=-1)
        myheatmap_x = WW_vmf.reshape(len(u), len(v)) / WW_x.max()

        # ~ ax.scatter( *zip( *pointList ), color='#dd00dd' )
        ax = fig.add_subplot( 1, 2, 1, projection='3d')
        ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_x ) )
        plt.title('angles of Gaussian')

        ax = fig.add_subplot( 1, 2, 2, projection='3d')
        ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_vmf ) )
        plt.title('von Mises-Fisher')
        plt.show() 
    """
