In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt

from vMFne.logpartition import log_besseli, ratio_besseli

def besseli_sra(s,x, tol=1e-12):
    """Naive truncated-series evaluation of the modified Bessel function, after Sra (2012).

    The result is ``prefactor * total`` where

    * ``prefactor = (x*e/(2s))**s * sqrt(s/(2*pi)) / stirling`` and ``stirling`` is the
      polynomial ``1 + 1/(12s) + 1/(288s^2) - 139/(51840s^3)`` (its coefficients
      match the Stirling series for the Gamma function), and
    * ``total`` starts at ``1/s`` and accumulates terms generated by the
      recurrence ``term <- term * (x/2)**2 / (k*(s+k))``.

    Everything is computed in plain floating point (no log-space arithmetic).

    Parameters
    ----------
    s : float
        Order of the Bessel function; used as a divisor, so it must be non-zero.
    x : float or numpy array
        Argument(s); the output has the shape of ``x``.
    tol : float
        The series is extended while any element has ``term/total >= tol``.

    Returns
    -------
    Same shape as ``x``: the series value scaled by the prefactor.
    """
    # Closed-form factor in front of the series.
    stirling = 1 + 1/(12*s) + 1/(288*s**2) - (139/51840)/s**3
    prefactor = (x*np.exp(1)/(2.0*s))**s * np.sqrt(0.5*s/np.pi) / stirling

    # (x/2)^2 is the same in every step of the recurrence.
    half_x_squared = (0.5*x)**2

    term = np.ones_like(x)
    total = np.ones_like(x)/s
    # note this differs from Sra (2012) - to do an index shift k'<-k+1 below, so need to start from k'=2
    k = 2.
    # A NaN ratio compares False against tol, so NaN entries do not keep the loop alive.
    while np.any(term/total >= tol):
        term = term * half_x_squared / (k*(s+k))
        total = total + term
        k += 1.
    return prefactor * total

# Spot check at a single point: Bessel order v and argument x.
v = 1000.0
x = 1000.0

# Cell output is a pair: log of the naive series value from besseli_sra, and
# the value returned by log_besseli (imported from vMFne.logpartition) for the
# same order and argument.
np.log(besseli_sra(v,x, tol=1e-12)), log_besseli(v,x)


In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt

from vMFne.negentropy import banerjee_44
from vMFne.logpartition import A

def A_sra(κs, D, tol=1e-12):
    """Ratio of two naive Bessel evaluations: order D/2 over order D/2 - 1.

    Parameters
    ----------
    κs : float or numpy array
        Argument(s) passed unchanged to both ``besseli_sra`` calls.
    D : int or float
        Dimension; it determines the two Bessel orders ``D/2`` and ``D/2 - 1``.
    tol : float
        Series tolerance forwarded to ``besseli_sra``.

    Returns
    -------
    ``besseli_sra(D/2, κs) / besseli_sra(D/2 - 1, κs)``.
    """
    half_D = D/2.0
    bessel_upper = besseli_sra(half_D, κs, tol=tol)
    bessel_lower = besseli_sra(half_D-1.0, κs, tol=tol)
    return bessel_upper / bessel_lower

def sra_2012(rbar,D,order=2,tol=1e-12):
    """Refine the ``banerjee_44`` estimate of κ with ``order`` correction steps.

    Starting from ``κ = banerjee_44(rbar, D)``, each step applies

        κ <- κ - (A(κ) - rbar) / (1 - A(κ)**2 - (D-1) * A(κ)/κ)

    with ``A`` evaluated by ``A_sra``. This has the shape of a Newton update on
    ``A(κ) - rbar``; the denominator is presumably the derivative A'(κ) --
    confirm against Sra (2012).

    Parameters
    ----------
    rbar : float or numpy array
        Target value(s) that ``A(κ)`` should match.
    D : int or float
        Dimension, forwarded to ``banerjee_44`` and ``A_sra``.
    order : int
        Number of correction steps; ``0`` returns the initial estimate.
    tol : float
        Series tolerance forwarded to ``A_sra``.

    Returns
    -------
    The refined κ, with whatever type/shape ``banerjee_44`` returned.
    """
    κ = banerjee_44(rbar,D)
    for _ in range(order):
        Aκ = A_sra(κ,D,tol=tol)
        residual = Aκ - rbar
        slope = 1. - Aκ**2 - (D-1) * Aκ/κ
        κ = κ - residual / slope
    return κ

# Dimensions to compare; each one fills a row of three panels.
Ds = [100, 1000, 10000]
n_rows = len(Ds)

plt.figure(figsize=(12, 4*n_rows))

for i, D in enumerate(Ds):
    # Only the bottom row carries x-axis labels.
    on_bottom_row = (i == n_rows - 1)
    half_D = D/2.0

    # Left panel: kappa as a function of ||mu||, banerjee_44 vs. sra_2012.
    plt.subplot(n_rows, 3, 3*i + 1)
    # 100 interior points of [0, 1]; both endpoints are dropped.
    mu_norms = np.linspace(0, 1, 102)[1:-1]
    plt.semilogy(mu_norms, banerjee_44(mu_norms, D), label='Banerjee (with mpmath)')
    plt.semilogy(mu_norms, sra_2012(mu_norms, D, order=10), label='Sra (2012)')
    plt.ylabel(f'D={D}')
    plt.title(r"$\Psi'(||\mu||)$")
    plt.legend()
    if on_bottom_row:
        plt.xlabel(r'$||\mu||$')

    # Middle panel: A(kappa) from the package vs. the naive Bessel ratio.
    plt.subplot(n_rows, 3, 3*i + 2)
    # Log-spaced grid from 1e-2 up to slightly above D (10**0.1 times D).
    log10_D = np.log(D)/np.log(10)
    ks = 10**np.linspace(-2, log10_D + 0.1, 100)
    plt.loglog(ks, A(ks, D), label='mpmath')
    plt.loglog(ks, A_sra(ks, D, tol=1e-12), label='Sra')
    plt.title(r"$A(\kappa)$")
    plt.legend()
    if on_bottom_row:
        plt.xlabel(r'$||\eta||$')

    # Right panel: log I_{D/2}(kappa); log_besseli is evaluated one point at a time.
    # log_besseli is imported in the first code cell, so that cell must have run.
    plt.subplot(n_rows, 3, 3*i + 3)
    plt.loglog(ks, np.array([log_besseli(half_D, kappa_i) for kappa_i in ks]), label='mpmath')
    plt.loglog(ks, np.log(besseli_sra(half_D, ks, tol=1e-12)), label='Sra')
    plt.title(r"$\log I_{D/2}(\kappa)$")
    plt.legend()
    if on_bottom_row:
        plt.xlabel(r'$||\eta||$')

plt.suptitle('naive implementation of Sra (2012) for Bessel function and gradient of log-partition', y=0.95)
# NOTE(review): the file name says D10..D1000 while Ds above is 100..10000 -- confirm the intended name.
plt.savefig('naive_Sra_2012_D10_to_D1000.pdf', bbox_inches='tight')
plt.show()



# towards mean-parameterized (Hyper-)spherical VAEs
Quick idea to make something out of mean parameterization for hyperspherical VAEs:
- hyperspherical VAEs are defined by von Mises-Fisher p(z), q(z|x) and general (typically Gaussian) p(x|z).
- as such they require the reparametrization trick to get training gradients for q(z|x) from the ELBO
- reparametrization for von Mises-Fisher latents is known, but is i) cumbersome and ii) formulated in natural parameterization (one samples a univariate $\omega \sim p(\omega \ | \ \kappa = ||\eta||, D)$).
- we here try a quick idea for $D=2$ and $D=3$ in which one only approximately samples $q(z|x)$ by sampling $\tilde{z} \sim \mathcal{N}(\tilde{z}| \mu(x), \sigma_\mu^2)$, where $\mu(x)$ is the mean parameter of the vMF $q(z|x)$. Then $z = \tilde{z}/||\tilde{z}||$, which is differentiable almost surely. The open question is which variance function $\sigma^2_\mu$, i.e. a function of $\mu(x)$ (or more sensibly of $||\mu(x)||$), gives the best approximation.
- for $D=2,3$, it seems that $\sigma^2_\mu = \frac{1-||\mu||^{8-2D}}{\sqrt{2\pi}}$ works quite well.
- generalization to $D > 3$ currently unclear.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import vonmises  
from matplotlib.pyplot import cm
from vMFne.utils_angular import cart2spherical, spherical_rotMat
from vMFne.sample import sample_vMF_Ulrich

# Dimension of the ambient space; the code below has plotting branches for D == 2 and D == 3 only.
D = 3
# Number of random draws per ||mu|| value (used for both the Gaussian and the vMF samples).
N = 1000000

def sigma2(norm_mu):
    """Heuristic variance of the Gaussian proposal as a function of ||mu||.

    Evaluates ``(1 - norm_mu**(2*(4-D))) / sqrt(2*pi)``, where ``D`` is read
    from the notebook's global namespace at call time.

    Parameters
    ----------
    norm_mu : float or numpy array
        Norm(s) of the mean parameter.

    Returns
    -------
    The variance value(s), same shape as ``norm_mu``.
    """
    c = 2
    exponent = (4-D)*c
    gauss_norm = 1./np.sqrt(2*np.pi)
    return gauss_norm * (1 - norm_mu**exponent)

# Unit-norm base direction: the last D entries of (0, 0, 1) as a (1, D) row vector.
mu_base = np.array([0., 0.0, 1.0])[-D:].reshape(1,D)
mu_base = mu_base / np.sqrt( (mu_base**2).sum() )
# Mean-parameter norms ||mu|| to compare; each gets its own subplot.
mu_norms = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

plt.figure(figsize=(8,16))
for jj in range(len(mu_norms)):
    mu = mu_norms[jj] * mu_base
    norm_mu = np.sqrt( (mu**2).sum() )

    # numerically approximate \grad\Psi(||mu||) = ||eta|| = kappa
    # NOTE(review): `etas` and `target` are not read anywhere in this cell; they
    # look like leftovers from an earlier grid search -- confirm before removing.
    etas = np.linspace(0,100, 100000)
    target = norm_mu
    # NOTE(review): gradΨ is not imported in any of the visible cells -- confirm it
    # is defined or imported elsewhere, otherwise this raises NameError on a fresh kernel.
    kappa = np.linalg.norm(gradΨ(μ=mu,D=D))

    # sample from Gaussian proposal
    renorm = sigma2(norm_mu)
    x = mu + np.random.normal(size=[N,D]) * np.sqrt(renorm)
    # Draws projected onto the unit sphere (only referenced by the disabled 3D code below).
    x_norm = x / np.sqrt((x**2).sum(axis=-1)).reshape(-1,1)

    # Angular coordinates of the raw (unnormalised) Gaussian draws.
    phi_x = cart2spherical(x.T)

    # Two-column grid of subplots, one per ||mu|| value.
    plt.subplot(np.int32(np.ceil(len(mu_norms)/2.)), 2, jj+1)
    if D == 2:
        # Circle: compare angle histograms against scipy's von Mises sampler.
        xx = np.linspace(-np.pi, np.pi, 100)
        phi = cart2spherical(x.T)
        h_x,bins_x = np.histogram(phi, bins=xx, density=True)
        phi_mu = np.arctan2(mu[...,1], mu[...,0])
        # Shift the samples to the mean direction and wrap them into [-pi, pi).
        phi_vmf = np.mod(vonmises.rvs(kappa, size=N) + phi_mu + np.pi, 2*np.pi) - np.pi
        plt.hist(phi_vmf, bins=xx, density=True)
        # Plot the Gaussian-draw histogram at the bin centres.
        plt.plot(bins_x[:-1]+np.diff(bins_x[:2])[0]/2, h_x)
    elif D == 3:
        # Sphere: compare histograms of the first angular coordinate.
        xx = np.linspace(0, np.pi, 50)
        x_vmf = sample_vMF_Ulrich(N=N, m=mu.flatten()/norm_mu, kappa=kappa)
        phi_vmf = np.mod(cart2spherical(x_vmf.T) + np.pi, 2*np.pi) - np.pi
        h_vmf,_ = np.histogram(phi_vmf[0], xx, density=True)
        h_x,_   = np.histogram(phi_x[0], xx, density=True)
        plt.plot(xx[:-1] + (xx[1]-xx[0])/2., h_x, label='angles of Gaussian draws')
        plt.plot(xx[:-1] + (xx[1]-xx[0])/2., h_vmf, label='von Mises-Fisher distribution')
    # Raw string for the LaTeX fragment: '\k' is not a valid escape sequence and
    # triggers a SyntaxWarning on recent Python versions when written non-raw.
    plt.title(r'$||\mu||=' + "{:10.2f}".format(norm_mu) + r', \kappa=' + "{:10.2f}".format(kappa) + '$')
    if jj == 0:
        plt.ylabel('radial profiles of angles')
        plt.legend()

    """
    For 3D plotting (plotting on S^2 in 3D plots), code adapted from
    https://stackoverflow.com/questions/22128909/plotting-the-temperature-distribution-on-a-sphere-with-python
    """
    # The string below is disabled code, kept verbatim.
    # NOTE(review): inside it, `myheatmap_x` is built from WW_vmf rather than WW_x,
    # which looks like a copy-paste slip -- fix if the block is ever re-enabled.
    """
    from mpl_toolkits.mplot3d import Axes3D
    from sklearn.metrics import pairwise

    if D == 3:
        fig = plt.figure()

        u = np.linspace( 0, 2 * np.pi, 120)
        v = np.linspace( 0, np.pi, 60 )

        # create the sphere surface
        XX = np.outer( np.cos( u ), np.sin( v ) )
        YY = np.outer( np.sin( u ), np.sin( v ) )
        ZZ = np.outer( np.ones( np.size( u ) ), np.cos( v ) )
        locs = np.stack([XX.flatten(),YY.flatten(),ZZ.flatten()], axis=-1)

        d0 = 0.1
        WW_vmf = (pairwise.pairwise_distances(locs, x_vmf.T)<d0).sum(axis=-1)
        myheatmap_vmf = WW_vmf.reshape(len(u), len(v)) / WW_vmf.max()
        WW_x = (pairwise.pairwise_distances(locs, x_norm.T)<d0).sum(axis=-1)
        myheatmap_x = WW_vmf.reshape(len(u), len(v)) / WW_x.max()

        # ~ ax.scatter( *zip( *pointList ), color='#dd00dd' )
        ax = fig.add_subplot( 1, 2, 1, projection='3d')
        ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_x ) )
        plt.title('angles of Gaussian')

        ax = fig.add_subplot( 1, 2, 2, projection='3d')
        ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_vmf ) )
        plt.title('von Mises-Fisher')
        plt.show() 
    """


###### 