In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt

from vMFne.logpartition import log_besseli, ratio_besseli
from vMFne.negentropy import banerjee_44
from vMFne.logpartition import A

def besseli_sra(s,x, tol=1e-12):
    """Modified Bessel function of the first kind I_s(x), following Sra (2012).

    Evaluates the ascending power series

        I_s(x) = (x/2)^s / Γ(s) * Σ_{k>=0} Γ(s) (x/2)^{2k} / (k! Γ(s+k+1)),

    where the prefactor (x/2)^s / Γ(s) uses the Stirling series
        Γ(s) ≈ sqrt(2π/s) (s/e)^s (1 + 1/(12s) + 1/(288s²) - 139/(51840s³)).
    Successive summands obey R_k = R_{k-1} (x/2)² / (k (s+k)), with R_0 = Γ(s)/Γ(s+1) = 1/s.

    This is the "naive" variant: the prefactor is formed as a plain power
    (x e / 2s)^s, so it under/overflows for large orders s.

    Parameters
    ----------
    s : float
        Order, s > 0 (the Stirling series divides by s).
    x : array_like
        Non-negative arguments.
    tol : float
        Stop once every new summand is below ``tol`` relative to its running sum.

    Returns
    -------
    Approximation of I_s(x) with the shape of ``x``.
    """
    stirling = 1 + 1/(12*s) + 1/(288*s**2) - (139/51840)/s**3
    # (x/2)^s / Γ(s), with Γ(s) approximated by the Stirling series above
    t1 = (x*np.exp(1)/(2.0*s))**s * np.sqrt(0.5*s/np.pi) / stirling
    # The k = 0 summand is Γ(s)/Γ(s+1) = 1/s, so R must start at 1/s with k = 1.
    # Starting R at 1 (with or without an index shift) mis-weights every k >= 1 term
    # relative to the k = 0 term.
    R = np.ones_like(x)/s
    M = np.ones_like(x)/s
    k = 1.
    while np.any(R/M >= tol):
        R = R * (0.5*x)**2 / (k*(s+k))
        M = M + R
        k = k + 1.
    return t1 * M

def A_sra(κs, D, tol=1e-12):
    """Bessel ratio I_{D/2}(κ) / I_{D/2-1}(κ), both evaluated with ``besseli_sra``.

    ``tol`` is forwarded to both series evaluations.
    """
    order = D/2.0
    numerator = besseli_sra(order, κs, tol=tol)
    denominator = besseli_sra(order - 1.0, κs, tol=tol)
    return numerator / denominator

def sra_2012(rbar,D,order=2,tol=1e-12):
    """Estimate the concentration κ from rbar = ||μ|| with Newton steps (Sra, 2012).

    Starts from ``banerjee_44(rbar, D)``. It then applies ``order`` Newton updates
    to f(κ) = A(κ) - rbar, using f'(κ) = 1 - A² - (D-1) A/κ. A(κ) is evaluated with
    Sra's power-series routine ``A_sra``. That routine is not computed in log space,
    so it can under/overflow (NaN/inf) for large D.

    Parameters
    ----------
    rbar : array_like
        Mean resultant lengths ||μ||.
    D : int
        Dimensionality.
    order : int
        Number of Newton iterations.
    tol : float
        Series tolerance forwarded to ``A_sra``.

    Returns
    -------
    κ estimates with the shape of ``rbar``.
    """
    κ = banerjee_44(rbar,D)
    for _ in range(order):
        # Use Sra's Bessel evaluation; calling A_song here made the "Sra (2012)"
        # estimator silently depend on Song's log-space routine instead.
        Aκ = A_sra(κ,D,tol=tol)
        f = Aκ - rbar
        df = 1. - Aκ**2 - (D-1) * Aκ/κ
        κ = κ - f / df
    return κ

def logbesseli_song(s, x, tol=1e-12):
    """log I_s(x) via a truncated ascending series evaluated partly in log space.

    log I_s(x) = s*log(x/2) - log Γ(s+1) + log Σ_k (x²/4)^k / (k! (s+1)_k).
    log Γ(s+1) uses a 6-term Lanczos-type approximation; the coefficients appear
    to be those of Numerical Recipes' ``gammln``. The sum is accumulated until
    every new term falls below ``tol`` relative to the running sum.
    """
    lanczos = [1.000000000190015, 76.18009172947146, -86.50532032941677,
               24.01409824083091, -1.231739572450155, 0.001208650973866179,
               -0.000005395239384953]
    series = lanczos[0] + sum([lanczos[j] / (s+j) for j in range(1, len(lanczos))])
    shifted = (s+0.5) * np.log(s+5.5) - s - 5.5
    # s*log(x/2) - log Γ(s+1)
    log_prefactor = s * np.log(0.5*x) - 0.5 * np.log(2.0*np.pi) - shifted - np.log(series)

    quarter_x_sq = (0.5*x)**2
    term, total, k = 1.0, 1.0, 1
    while np.any(term/total >= tol):
        term = term * quarter_x_sq / (k*(s+k))
        total = total + term
        k = k + 1.
    return log_prefactor + np.log(total)

def A_song(κs, D, tol=1e-12):
    """Bessel ratio I_{D/2}(κ) / I_{D/2-1}(κ) computed as exp of a log-space difference.

    Both logs come from ``logbesseli_song`` with the same ``tol``.
    """
    order = D/2.0
    log_num = logbesseli_song(order, κs, tol=tol)
    log_den = logbesseli_song(order - 1.0, κs, tol=tol)
    return np.exp(log_num - log_den)

def song_2012(rbar,D,order=2,tol=1e-12):
    """Estimate κ from rbar = ||μ|| with Halley iterations (Song et al., 2012).

    Starts from ``banerjee_44(rbar, D)`` and solves f(κ) = A(κ) - rbar = 0.
    A(κ) comes from ``A_song``; f' and f'' are the analytic first and second
    derivatives of A. ``order`` is the number of Halley steps.
    """
    κ = banerjee_44(rbar,D)
    for _ in range(order):
        Aκ = A_song(κ,D,tol=tol)
        residual = Aκ - rbar
        slope = 1. - Aκ**2 - (D-1) * Aκ/κ
        curvature = 2. * Aκ**3 + 3. * (D-1) * Aκ**2/κ + (D*(D-1)/κ**2 - 2.0) * Aκ - (D-1)/κ
        # Halley update: κ <- κ - 2 f f' / (2 f'^2 - f f'')
        κ = κ - 2. * residual * slope / (2. * slope**2 - residual * curvature)
    return κ

def pcf_rec(u,v,w,ρ,x,xp,s,k):
    """One step of the Perron continued-fraction (PCF) recursion used by ``A_pcf``.

    On the first call (k == 1) the incoming u, v, w, ρ are ignored and the state
    is initialised from x, xp and s. On later calls the previous state is advanced.

    Parameters
    ----------
    u, v, w, ρ : state returned by the previous call (``A_pcf`` passes None for k == 1).
    x : arguments; only read when k == 1.
    xp : 0.5 * x, as precomputed by ``A_pcf``.
    s : order parameter (``A_pcf`` passes D/2).
    k : 1-based step index; only k == 1 vs. k > 1 matters here.

    Returns
    -------
    (u, v, w, ρ) : updated state. ``A_pcf`` multiplies its running product by ρ
        and accumulates that product into the continued-fraction sum.
    """
    if k==1:
        # initialise the recursion state
        v = s + x + 0.5
        u = (s + x) * v
        w = xp * (s + 0.5)
        ρ = w / ((s+xp) * v - w)
    else:
        # advance the state; u must be updated with the *previous* v, so order matters
        u = u + v
        v = v + 0.5
        w = w + xp
        t = w * (1. + ρ)
        ρ = t / (u - t)
    return u,v,w,ρ

def A_pcf(x,D,K):
    """Approximate A_D(x) = I_{D/2}(x) / I_{D/2-1}(x) with a truncated Perron continued fraction.

    This is the approach labelled "Hornik (2014)" in this notebook; each step of
    the recursion is done by ``pcf_rec``. Arguments below 1e-6 instead use the
    fifth-order Taylor expansion x/D - x³/(D²(D+2)) + 2x⁵/(D³(D+2)(D+4)).

    Parameters
    ----------
    x : array_like
        Arguments κ. Lists and integer arrays are accepted and promoted to float.
    D : int
        Dimensionality; the order parameter passed to ``pcf_rec`` is s = D/2.
    K : int
        Fixed truncation (no convergence test): the recursion runs for
        k = 1, ..., K-1, and always at least once.

    Returns
    -------
    Float array with the shape of ``x``.
    """
    # Promote to float: np.empty_like on an integer array would truncate every ratio in [0, 1) to 0.
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    idx = x >= 1e-6
    s = D/2.
    x_big = x[idx]
    xp = 0.5 * x_big
    p, psum = 1.0, 1.0
    u, v, w, ρ = None, None, None, None
    k = 1
    while True:
        u, v, w, ρ = pcf_rec(u, v, w, ρ, x_big, xp, s, k)
        p = ρ * p
        psum = psum + p
        k = k + 1
        if not k < K:
            break

    out[idx] = psum / (1. + 2. * s/x_big)
    nidx = np.invert(idx)
    if np.any(nidx):
        xs = x[nidx]
        # Taylor expansion of A_D around 0 for tiny arguments
        out[nidx] = xs / D - xs**3 / (D**2 * (D + 2)) + 2. * xs**5 / (D**3 * (D + 2) * (D + 4))
    return out

def hornik_2014(rbar,D,order=2,K=5):
    """Estimate κ from rbar = ||μ|| with Halley iterations, evaluating A(κ) via ``A_pcf``.

    Starts from ``banerjee_44(rbar, D)``. ``order`` is the number of Halley steps;
    ``K`` is the continued-fraction truncation forwarded to ``A_pcf``.
    """
    κ = banerjee_44(rbar,D)
    for _ in range(order):
        Aκ = A_pcf(κ,D,K=K)
        residual = Aκ - rbar
        slope = 1. - Aκ**2 - (D-1) * Aκ/κ
        curvature = 2. * Aκ**3 + 3. * (D-1) * Aκ**2/κ + (D*(D-1)/κ**2 - 2.0) * Aκ - (D-1)/κ
        # Halley update: κ <- κ - 2 f f' / (2 f'^2 - f f'')
        κ = κ - 2. * residual * slope / (2. * slope**2 - residual * curvature)
    return κ

def logbesseli_hornik(v,x):
    """Closed-form approximation of log I_v(x), plotted as "Hornik (2014)" in this notebook.

    Returns sqrt(x² + (v+1)²) + (v+1/2) log((2v+3/2) x / ((v+1/2+sqrt(x²+(v+1)²)) (2v+2)))
            - log(x/2)/2 - log(2π)/2.

    NOTE(review): for large x this tends to x - log(πx)/2 + (v+1/2) log((2v+3/2)/(2v+2)),
    which differs from the standard x - log(2πx)/2 by a constant. That is plausible
    if this is a bound rather than an asymptotic expansion — confirm against the reference.
    """
    root = np.sqrt(x**2 + (v+1)**2)
    log_ratio = np.log((2.0*v+1.5)*x/((v+0.5+root)*(2.0*v+2.0)))
    return root + (v+0.5)*log_ratio - 0.5*np.log(x/2.) - 0.5*np.log(2.0*np.pi)


In [None]:


# Compare, per dimension D: the κ estimators (column 1), the Bessel ratio A(κ) (column 2)
# and log I_{D/2}(κ) (column 3).
Ds = [100, 1000, 10000]
mu_norms = np.linspace(0,1, 102)[1:-1]  # ||mu|| grid on the open interval (0, 1)

plt.figure(figsize=(12,4*len(Ds)))

for i,D in enumerate(Ds):
    plt.subplot(len(Ds),3,3*i+1)
    plt.semilogy(mu_norms, banerjee_44(mu_norms,D), label='Banerjee (with mpmath)')
    plt.semilogy(mu_norms, sra_2012(mu_norms,D,order=10), label='Sra (2012)')
    plt.semilogy(mu_norms, song_2012(mu_norms,D,order=10), label='Song (2012)')
    plt.semilogy(mu_norms, hornik_2014(mu_norms,D,order=2,K=5), label='Hornik (2014)')

    plt.ylabel('D='+str(D))
    plt.title(r"$\Psi'(||\mu||)$")
    plt.legend()
    if i == len(Ds)-1:
        plt.xlabel(r'$||\mu||$')

    plt.subplot(len(Ds),3,3*i+2)
    ks = 10**np.linspace(-2, np.log10(D)+0.1, 100)  # κ from 1e-2 to slightly beyond D
    plt.loglog(ks, A(ks,D), label='mpmath')
    plt.loglog(ks, A_sra(ks,D,tol=1e-12), label='Sra (2012)')
    plt.loglog(ks, A_song(ks,D,tol=1e-12), label='Song (2012)')
    plt.loglog(ks, A_pcf(ks,D,K=5), label='Hornik (2014)')
    plt.title(r"$A(\kappa)$")
    plt.legend()
    if i == len(Ds)-1:
        plt.xlabel(r'$||\eta||$')

    plt.subplot(len(Ds),3,3*i+3)
    # log I_{D/2}(κ) is negative wherever I_{D/2}(κ) < 1 (small κ), which a log y-axis
    # cannot show, so only the κ axis is logarithmic here.
    plt.semilogx(ks, np.array([log_besseli(D/2.0, k) for k in ks]), label='mpmath')
    plt.semilogx(ks, np.log(besseli_sra(D/2.0, ks, tol=1e-12)), label='Sra (2012)')
    plt.semilogx(ks, logbesseli_song(D/2.0, ks, tol=1e-12), label='Song (2012)')
    plt.semilogx(ks, logbesseli_hornik(D/2.0, ks), label='Hornik (2014)')
    plt.title(r"$\log I_{D/2}(\kappa)$")
    plt.legend()
    if i == len(Ds)-1:
        plt.xlabel(r'$||\eta||$')
plt.suptitle('naive implementation of Sra (2012) for Bessel function and gradient of log-partition',y=0.95)
# NOTE(review): the file name says D10_to_D1000 but Ds = [100, 1000, 10000]; kept unchanged
# in case something else references this path.
plt.savefig('naive_Sra_2012_D10_to_D1000.pdf', bbox_inches='tight')
plt.show()



In [None]:
from vMFne.negentropy import Ψ, Ψ_base, gradΨ, dΨ_base

D = 10
K = 100

# K interior points of a uniform grid on [0, 1] (both endpoints dropped)
μs_norm = np.linspace(0, 1, num=K+2)[1:-1]
# K identical unit-length directions; some functions need vector-valued inputs
V = np.full((K, D), 1.0/np.sqrt(D))

def psi_hornik_u(μ_norm, D):
    """Antiderivative in ||μ|| of the 'upper' κ approximation, normalised so the value at 0 is 0.

    It integrates Ginv(||μ||, (D-1)/2, (D+1)/2), i.e. ``dPsi_hornik_u``. With
    r = sqrt((D+1)² - 4D||μ||²) it returns
        (1-D)/2 * log(r + 1 - D) - r/2 + (D-1)/2 * log 2 + (D+1)/2.
    """
    r = np.sqrt((D+1)**2 - 4.*D*μ_norm**2)
    log_term = (1.-D)/2. * np.log(r + 1.-D)
    offset = (D-1)/2. * np.log(2.) + (D+1)/2.
    return log_term - 0.5 * r + offset

def psi_hornik_l1(μ_norm, D):
    """Antiderivative in ||μ|| of Ginv(||μ||, D/2-1, D/2+1), normalised so the value at 0 is 0.

    Note: that integrand is ``dPsi_hornik_l2``, so the l1/l2 suffixes are swapped
    between the psi_* and dPsi_* families. With r = sqrt((D+2)² - 8D||μ||²) it returns
        (2-D)/2 * log(r + 2 - D) - r/2 + (D-2)/2 * log 4 + (D+2)/2.
    """
    r = np.sqrt((D+2)**2 - 8.*D*μ_norm**2)
    log_term = (2.-D)/2. * np.log(r + 2. - D)
    offset = (D-2)/2. * np.log(4.) + (D+2)/2.
    return log_term - 0.5 * r + offset

def psi_hornik_l2(μ_norm, D):
    """Antiderivative in ||μ|| of Ginv(||μ||, (D-1)/2, sqrt(D²-1)/2), normalised so the value at 0 is 0.

    Note: that integrand is ``dPsi_hornik_l1``, so the l1/l2 suffixes are swapped
    between the psi_* and dPsi_* families. With r = sqrt((D-1)(D+1-2||μ||²)) it returns
        (1-D)/2 * log(r + 1 - D) - r/2 + offset, where offset zeroes the value at ||μ|| = 0.
    """
    r = np.sqrt((D - 1.) * (D - 2.*μ_norm**2 + 1.))
    r0 = np.sqrt(D**2 - 1.)
    offset = 0.5 * r0 + (D-1)/2. * np.log(r0 + 1. - D)
    return (1.-D)/2. * np.log(r + 1. - D) - 0.5 * r + offset

def psi_hornik_l(μ_norm, D):
    """Piecewise 'lower' antiderivative approximation of Ψ(||μ||).

    Below ``crossing`` it uses ``psi_hornik_l1`` (the antiderivative of
    Ginv(., D/2-1, D/2+1), i.e. of dPsi_hornik_l2). At or above ``crossing`` it uses
    ``psi_hornik_l2`` (the antiderivative of Ginv(., (D-1)/2, sqrt(D²-1)/2), i.e. of
    dPsi_hornik_l1), shifted by a constant so both pieces agree at ``crossing``.
    This mirrors the piece selection in dPsi_hornik_l.

    The output is aligned element-wise with ``μ_norm``: unsorted input is fine, and
    NaN entries stay in place instead of being dropped.
    """
    μ_norm = np.asarray(μ_norm, dtype=float)
    crossing = np.sqrt((3*D**2 + 8*D + 5)/(3*D+1)**2)
    below = μ_norm < crossing
    psi = np.empty_like(μ_norm)
    # assign by mask (not concatenation) so results stay aligned with the input order
    if np.any(below):
        psi[below] = psi_hornik_l1(μ_norm[below], D)
    above = ~below
    if np.any(above):
        # constant shift that makes the upper piece meet the lower one at `crossing`
        shift = psi_hornik_l1(crossing, D) - psi_hornik_l2(crossing, D)
        psi[above] = psi_hornik_l2(μ_norm[above], D) + shift
    return psi

def relfac_psi(μs_norm, D):
    """Hand-tuned heuristic for the ratio (upper - Ψ) / (Ψ - lower) used to weight psi_hornik.

    For ||μ|| <= 0.6 it returns 0.5 / ||μ||^2.25. Above 0.6 it returns
    1.5 * exp(3.8 * | ||μ|| - 2/3 |^1.5). ``D`` is currently unused; the constants
    were presumably fitted by eye (see the ratio plots in this notebook) — TODO confirm.

    The output is aligned element-wise with ``μs_norm`` (unsorted input is fine).
    """
    μs_norm = np.asarray(μs_norm, dtype=float)
    crossing = 0.6 #np.sqrt((3*D**2 + 8*D + 5)/(3*D+1)**2)
    out = np.empty_like(μs_norm)
    # assign by mask (not concatenation) so results stay aligned with the input order
    low = μs_norm <= crossing
    out[low] = 0.5/(μs_norm[low]**2.25)
    high = ~low
    out[high] = 1.5*np.exp(3.8*np.abs(μs_norm[high]-2/3)**1.5)
    return out

def comp_hornik_w_opt_heuristic_psi(μs_norm, D):
    """Heuristic weight on ``psi_hornik_u`` in psi_hornik: w = 1 / (1 + relfac_psi(μs_norm, D))."""
    rel = relfac_psi(μs_norm, D)
    return 1 / (1. + rel)

def psi_hornik(μ_norm, D, w=None):
    """Blend of the upper and lower antiderivative approximations: w*upper + (1-w)*lower.

    When ``w`` is None it defaults to comp_hornik_w_opt_heuristic_psi(μ_norm, D).
    """
    weight = comp_hornik_w_opt_heuristic_psi(μ_norm, D) if w is None else w
    upper = psi_hornik_u(μ_norm, D)
    lower = psi_hornik_l(μ_norm, D)
    return weight * upper + (1.-weight) * lower

def antiblunt(μ_norm, D):
    """Antiderivative in ||μ|| of the 'blunt' κ approximation (D-1)(μ/(1-μ²) + μ/(μ⁴ + (D-2)μ² + D - 1)).

    It is not normalised; callers subtract antiblunt(0, D). The inner square root
    sqrt(D²/4 - 2D + 2) is real only outside (4-2√2, 4+2√2), so integer
    D = 2..6 give NaN.
    """
    v = D/2. - 1.
    S = np.sqrt(D**2/4. - 2.*D + 2.)
    μ2 = μ_norm**2
    log_ratio = np.log(v + μ2 - S) - np.log(v + μ2 + S)
    return (D-1.) * (log_ratio/(4.*S) - 0.5 * np.log(1.-μ2))

Ψμ, _ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[0., 1e-5], t0=0., return_grad=True)
Ψμ_app, _ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[0., 0.0], t0=0., return_grad=True)

w = comp_hornik_w_opt_heuristic_psi(μs_norm, D)

# Closed-form comparison curves, each evaluated once and reused in all panels.
Ψμ_hornik_u = psi_hornik_u(μs_norm,D)
Ψμ_hornik_l = psi_hornik_l(μs_norm,D)
Ψμ_hornik = psi_hornik(μs_norm,D,w=w)
Ψμ_blunt = antiblunt(μs_norm,D) - antiblunt(0,D)  # shifted so that the value at 0 is 0
# Banerjee (4.4) antiderivative. Named so it does not shadow the imported banerjee_44()
# function, which other cells call (shadowing broke them on re-run).
Ψμ_banerjee = 0.5 *((1-D) * np.log(1 - μs_norm**2) + μs_norm**2)

plt.figure(figsize=(16,6))
plt.subplot(1,3,1)
plt.semilogy(μs_norm, Ψμ_hornik_u, label='antidev of Hornik (2014), upper')
plt.semilogy(μs_norm, Ψμ_hornik_l, label='antidev of Hornik (2014), max(lower)')
plt.semilogy(μs_norm, Ψμ_hornik, label='antidev of Hornik (2014), avg')
plt.semilogy(μs_norm, Ψμ_app, ':', label='approx. ODE')
plt.semilogy(μs_norm, Ψμ_blunt, 'o--', label='blunt')
plt.semilogy(μs_norm, Ψμ_banerjee,  label='Banerjee (4.4)')
plt.legend()  # after all curves, so Banerjee (4.4) is included

plt.subplot(1,3,2)
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_hornik_u), label='antidev of Hornik (2014), upper')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_hornik_l),  label='antidev of Hornik (2014), max(lower)')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_hornik), label='antidev of Hornik (2014), avg')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_app), ':', label='approx. ODE')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_banerjee),  label='Banerjee (4.4)')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_blunt), 'o--', label='blunt')
plt.legend()

plt.subplot(1,3,3)
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_hornik_u)/Ψμ, label='antidev of Hornik (2014), upper')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_hornik_l)/Ψμ,  label='antidev of Hornik (2014), max(lower)')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_hornik)/Ψμ, label='antidev of Hornik (2014), avg')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_app)/Ψμ, ':', label='approx. ODE')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_banerjee)/Ψμ,  label='Banerjee (4.4)')
plt.semilogy(μs_norm, np.abs(Ψμ - Ψμ_blunt)/Ψμ, 'o--', label='blunt')

plt.legend()


In [None]:
from vMFne.negentropy import banerjee_44

D = 1000
K = 100

# K interior points of a uniform grid on [0, 1] (both endpoints dropped)
μs_norm = np.linspace(0, 1, num=K+2)[1:-1]
# K identical unit-length directions; some functions need vector-valued inputs
V = np.full((K, D), 1.0/np.sqrt(D))

def Ginv(ρ, a, b):
    """Return ρ/(1-ρ²) * (a + sqrt(b² + (a²-b²) ρ²)).

    In this notebook it maps ρ = ||μ|| to a κ approximation. The dPsi_hornik_*
    functions call it with different (a, b) pairs. Presumably this is the inverse
    of a rational bound on A(κ) — TODO confirm against Hornik & Grün (2014).
    """
    ρ2 = ρ**2
    root = np.sqrt(b**2 + (a**2 - b**2) * ρ2)
    return ρ/(1.-ρ2) * (a + root)

def dPsi_hornik_u(μ_norm, D):
    """'Upper' κ approximation of Ψ'(||μ||): Ginv with a = D/2 - 1/2 and b = D/2 + 1/2."""
    return Ginv(μ_norm, a=D/2-0.5, b=D/2+0.5)

def dPsi_hornik_l1(μ_norm, D):
    """First 'lower' κ approximation: Ginv with a = D/2 - 1/2 and b = sqrt(D² - 1)/2."""
    return Ginv(μ_norm, a=D/2-0.5, b=np.sqrt(D**2-1)/2)

def dPsi_hornik_l2(μ_norm, D):
    """Second 'lower' κ approximation: Ginv with a = D/2 - 1 and b = D/2 + 1."""
    return Ginv(μ_norm, a=D/2-1, b=D/2+1)

def dPsi_hornik_l(μ_norm, D):
    """Piecewise 'lower' κ approximation of Ψ'(||μ||).

    It uses ``dPsi_hornik_l2`` below ``crossing`` and ``dPsi_hornik_l1`` at or
    above it. The output is aligned element-wise with ``μ_norm``: unsorted input
    is fine, and NaN entries stay in place instead of being dropped.
    """
    μ_norm = np.asarray(μ_norm, dtype=float)
    crossing = np.sqrt((3*D**2 + 8*D + 5)/(3*D+1)**2) # below this, l2 is better, above l1 is
    below = μ_norm < crossing
    dPsi = np.empty_like(μ_norm)
    # assign by mask (not concatenation) so results stay aligned with the input order
    if np.any(below):
        dPsi[below] = dPsi_hornik_l2(μ_norm[below], D)
    above = ~below
    if np.any(above):
        dPsi[above] = dPsi_hornik_l1(μ_norm[above], D)
    return dPsi

def relfac_dpsi(μs_norm, D):
    """Hand-tuned heuristic ratio used to weight the κ bounds (equals 1 at the crossing point).

    With profile(m) = m^1.5 (1-m) it returns
        profile(crossing) * exp(4.5 |crossing - ||μ|||) / profile(||μ||).
    It appears to model (upper - κ)/(κ - lower); the exponents were presumably
    tuned by hand — TODO confirm.
    """
    crossing = np.sqrt((3*D**2 + 8*D + 5)/(3*D+1)**2)
    p1, p2, p3 = 1.5, 1.0, 4.5

    def profile(m):
        return m**p1 * (1-m)**p2

    return profile(crossing)*np.exp(p3*np.abs(crossing-μs_norm))/profile(μs_norm)


def w_opt_dpsi(μ_norm,D):
    """Estimated weight on the *upper* κ approximation, for use in dPsi_hornik.

    It builds a target κ = (D-1)||μ|| / (1 - ||μ||² - 1/κ'), where κ' is the
    μ-derivative of the Banerjee (4.4) approximation (D-1)μ/(1-μ²) + μ. It then
    solves target = (1-w_l) * upper + w_l * lower for w_l and returns 1 - w_l.
    """
    κ_upper = dPsi_hornik_u(μ_norm, D)
    κ_lower = dPsi_hornik_l(μ_norm, D)
    μ_norm2 = μ_norm**2
    dκ_dμ = 1. + (D-1.) * (1 + μ_norm2) / (1. - μ_norm2)**2  # banerjee 4.4
    κ_target = (D-1)*μ_norm/(1 - μ_norm**2 - 1/dκ_dμ)
    w_lower = (κ_target - κ_upper) / (κ_lower - κ_upper)
    return 1 - w_lower

def comp_hornik_w_opt_heuristic_dpsi(μs_norm, D):
    """Heuristic weight on the upper κ approximation: w = 1 / (1 + relfac_dpsi(μs_norm, D))."""
    rel = relfac_dpsi(μs_norm, D)
    return 1/(1. + rel)

def dPsi_hornik(μ_norm, D, w=None):
    """Blend of the κ approximations: w * dPsi_hornik_u + (1-w) * dPsi_hornik_l.

    When ``w`` is None it defaults to w_opt_dpsi(μ_norm, D).
    """
    weight = w_opt_dpsi(μ_norm, D) if w is None else w
    upper = dPsi_hornik_u(μ_norm, D)
    lower = dPsi_hornik_l(μ_norm, D)
    return weight * upper + (1.-weight) * lower

w = w_opt_dpsi(μs_norm, D)

_, dΨμ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[0., 1e-5], t0=0., return_grad=True)
_, dΨμ_app = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[0., 0.0], t0=0., return_grad=True)
dΨμ, dΨμ_app = np.linalg.norm(dΨμ,axis=-1), np.linalg.norm(dΨμ_app,axis=-1)

# Closed-form candidates, each evaluated once and reused in all three panels.
dΨμ_hornik = dPsi_hornik(μs_norm,D,w=w)
dΨμ_banerjee = banerjee_44(μs_norm,D)
#dpsidmu2_est = ((D-3) * μs_norm**2 + D + μs_norm**4) / (1-μs_norm**2)**2
#blunt = (D-1)*μs_norm/(1 - μs_norm**2 - 1/dpsidmu2_est)
#blunt = (D-1)*μs_norm/(1 - μs_norm**2) * ((D-3) * μs_norm**2 + D + μs_norm**4)  / ((D-2) * μs_norm**2 + D - 1+ μs_norm**4) 
blunt = (D-1)*(μs_norm/(1 - μs_norm**2) + μs_norm/(μs_norm**4 + (D-2) * μs_norm**2 + D -1 ))

# μ-derivative of `blunt`, plugged into the same κ formula used in w_opt_dpsi
dpsidmu2_est = (D-1) * ((2*μs_norm**2)/(1 - μs_norm**2)**2 + 1/(1 - μs_norm**2) - (μs_norm* (2 *(D-2) * μs_norm + 4 * μs_norm**3))/(D-1 + (D-2) * μs_norm**2 + μs_norm**4)**2 + 1/(D-1 + (D-2) * μs_norm**2 + μs_norm**4))
blunter = (D-1)*μs_norm/(1 - μs_norm**2 - 1/dpsidmu2_est)

plt.figure(figsize=(16,6))
plt.subplot(1,3,1)
plt.semilogy(μs_norm, dΨμ, ':', label='ODE')
plt.semilogy(μs_norm, dΨμ_hornik, label='Hornik (2014), avg')
plt.semilogy(μs_norm, dΨμ_app, ':', label='approx. ODE')
plt.semilogy(μs_norm, dΨμ_banerjee, ':', label='Banerjee eq. (4.4)')
plt.semilogy(μs_norm, blunt, label='blunt')
plt.semilogy(μs_norm, blunter, label='blunter')
plt.legend()

plt.subplot(1,3,2)
plt.semilogy(μs_norm, np.abs(dΨμ - dΨμ_hornik), label='Hornik (2014), avg')
plt.semilogy(μs_norm, np.abs(dΨμ - dΨμ_app), ':', label='approx. ODE')
plt.semilogy(μs_norm, np.abs(dΨμ - dΨμ_banerjee),  label='Banerjee eq. (4.4)')
plt.semilogy(μs_norm, np.abs(dΨμ - blunt),  label='blunt')
plt.semilogy(μs_norm, np.abs(dΨμ - blunter),  label='blunter')
plt.legend()

plt.subplot(1,3,3)
plt.semilogy(μs_norm, np.abs(dΨμ - dΨμ_hornik)/dΨμ, label='Hornik (2014), avg')
plt.semilogy(μs_norm, np.abs(dΨμ - dΨμ_app)/dΨμ, ':', label='approx. ODE')
plt.semilogy(μs_norm, np.abs(dΨμ - dΨμ_banerjee)/dΨμ,  label='Banerjee eq. (4.4)')
plt.semilogy(μs_norm, np.abs(dΨμ - blunt)/dΨμ,  label='blunt')
plt.semilogy(μs_norm, np.abs(dΨμ - blunter)/dΨμ,  label='blunter')
plt.legend()  # previously missing on this panel

plt.show()

In [None]:
# Standalone look at `blunter` from the previous cell (labelled, but no legend is drawn here).
plt.semilogy(μs_norm, blunter, label='blunter')


In [None]:
K = 100
μs_norm = np.linspace(0, 1, K+2)[1:-1]

Ds = [2, 10, 100, 1000]
plt.figure(figsize=(12,12))

for D in Ds:
    V = np.ones((K,D))/np.sqrt(D) # unit length vectors, some functions need vector-valued inputs
    _, dΨμ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[0., 1e-5], t0=0., return_grad=True)
    dΨμ = np.linalg.norm(dΨμ,axis=-1)

    # Empirical ratio r = (upper - true) / (true - lower). If true = w*upper + (1-w)*lower,
    # then r = (1-w)/w. w_opt_dpsi returns that weight on the *upper* approximation
    # (as used in dPsi_hornik), so the dotted reference line is (1-w)/w.
    # (The old code called an undefined `w_opt`, which raised NameError.)
    ratio = (dPsi_hornik_u(μs_norm,D) - dΨμ)/(dΨμ - dPsi_hornik_l(μs_norm,D))
    w_upper = w_opt_dpsi(μs_norm,D)

    plt.subplot(1,2,1)
    plt.semilogy(μs_norm, ratio, label='upper, D='+str(D) )
    plt.semilogy(μs_norm, (1-w_upper)/w_upper, 'k:')
    plt.subplot(1,2,2)
    plt.plot(μs_norm, ratio, label='upper, D='+str(D) )
plt.subplot(1,2,1)
plt.legend()
plt.subplot(1,2,2)
#plt.plot(μs_norm, relfac_dpsi(μs_norm,D), 'k--', label='bam')
plt.legend()

In [None]:
K = 100
μs_norm = np.linspace(0, 1, K+2)[1:-1]

Ds = [2, 10, 100, 1000]
plt.figure(figsize=(12,12))
for D in Ds:
    V = np.ones((K,D))/np.sqrt(D) # unit length vectors, some functions need vector-valued inputs
    Ψμ, _ = Ψ(μs_norm.reshape(-1,1)*V, D, Ψ0=[0., 1e-4], t0=0., return_grad=True)
    # empirical ratio (upper - Ψ) / (Ψ - lower), which relfac_psi is meant to approximate
    ratio = (psi_hornik_u(μs_norm,D) - Ψμ)/(Ψμ - psi_hornik_l(μs_norm,D))
    plt.subplot(1,2,1)
    plt.semilogy(μs_norm, ratio, label='upper, D='+str(D) )
    plt.subplot(1,2,2)
    plt.plot(μs_norm, ratio, label='upper, D='+str(D) )
# The heuristic is relfac_psi (the old code called an undefined `relfac`, which raised NameError).
# relfac_psi ignores D, so it is drawn once.
heuristic = relfac_psi(μs_norm,D)
plt.subplot(1,2,1)
plt.semilogy(μs_norm, heuristic, 'k--', label='bam')
plt.legend()
plt.subplot(1,2,2)
plt.plot(μs_norm, heuristic, 'k--', label='bam')
plt.legend()

# Towards mean-parameterized (hyper-)spherical VAEs
A quick idea for exploiting the mean parameterization in hyperspherical VAEs:
- hyperspherical VAEs are defined by von Mises-Fisher p(z), q(z|x) and general (typically Gaussian) p(x|z).
- as such they require the reparametrization trick to get training gradients for q(z|x) from the ELBO
- reparametrization for von Mises-Fisher latents is known, but it is i) cumbersome and ii) formulated in the natural parameterization (one samples a univariate $\omega \sim p(\omega \ | \ \kappa = ||\eta||, D)$).
- here we try a quick idea for $D=2$ and $D=3$ in which $q(z|x)$ is sampled only approximately: draw $\tilde{z} \sim \mathcal{N}(\tilde{z} \ | \ \mu(x), \sigma_\mu^2 I)$, where $\mu(x)$ is the mean parameter of the vMF $q(z|x)$, and set $z = \tilde{z}/||\tilde{z}||$, which is differentiable almost surely. The open question is which variance function $\sigma^2_\mu$ approximates the vMF best, i.e. a function of $\mu(x)$ (or, more sensibly, of $||\mu(x)||$).
- for $D=2,3$, $\sigma^2_\mu = \frac{1-||\mu||^{8-2D}}{\sqrt{2\pi}}$ seems to work quite well.
- the generalization to $D > 3$ is currently unclear.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import vonmises  
from matplotlib.pyplot import cm
from vMFne.utils_angular import cart2spherical, spherical_rotMat
from vMFne.sample import sample_vMF_Ulrich

# ambient dimension (the plotting loop below handles D = 2 and D = 3) and number of samples per panel
D, N = 3, 1_000_000

def sigma2(norm_mu, dim=None, c=2):
    """Heuristic isotropic variance for the Gaussian proposal N(mu, sigma2 * I).

    Returns (1 - ||mu||**((4 - D) * c)) / sqrt(2*pi). With the default c = 2 the
    exponent is 8 - 2D.

    Parameters
    ----------
    norm_mu : float or ndarray
        Mean-parameter norm ||mu||, expected in [0, 1).
    dim : int, optional
        Dimension D. Defaults to the notebook-level ``D`` (the previous, implicit behaviour).
    c : float, optional
        Exponent scale; 2 reproduces the previous hard-coded value.
    """
    D_eff = D if dim is None else dim
    renorm = 1./np.sqrt(2*np.pi) * (1 - norm_mu**((4-D_eff)*c))
    return renorm

# Seed numpy's global RNG so the Gaussian draws (and scipy's vonmises.rvs, which uses it by default)
# are reproducible. NOTE(review): whether sample_vMF_Ulrich also uses it is not visible here.
np.random.seed(0)

# unit mean direction; slicing a 3-vector means this only supports D <= 3
mu_base = np.array([0., 0.0, 1.0])[-D:].reshape(1,D)
mu_base = mu_base / np.sqrt( (mu_base**2).sum() )
mu_norms = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]

plt.figure(figsize=(8,16))
for jj in range(len(mu_norms)):
    mu = mu_norms[jj] * mu_base
    norm_mu = np.sqrt( (mu**2).sum() )

    # kappa = ||eta|| = ||grad Psi(mu)||; gradΨ is imported in an earlier cell (vMFne.negentropy)
    kappa = np.linalg.norm(gradΨ(μ=mu,D=D))

    # sample from the Gaussian proposal N(mu, sigma2 * I) and project onto the sphere
    renorm = sigma2(norm_mu)
    x = mu + np.random.normal(size=[N,D]) * np.sqrt(renorm)
    x_norm = x / np.sqrt((x**2).sum(axis=-1)).reshape(-1,1)  # only used by the disabled 3-D plot below

    phi_x = cart2spherical(x.T)

    plt.subplot(np.int32(np.ceil(len(mu_norms)/2.)), 2, jj+1)
    if D == 2:
        xx = np.linspace(-np.pi, np.pi, 100)
        h_x,bins_x = np.histogram(phi_x, bins=xx, density=True)
        phi_mu = np.arctan2(mu[...,1], mu[...,0])
        phi_vmf = np.mod(vonmises.rvs(kappa, size=N) + phi_mu + np.pi, 2*np.pi) - np.pi
        plt.hist(phi_vmf, bins=xx, density=True)
        plt.plot(bins_x[:-1]+np.diff(bins_x[:2])[0]/2, h_x)
    elif D == 3:
        xx = np.linspace(0, np.pi, 50)        
        x_vmf = sample_vMF_Ulrich(N=N, m=mu.flatten()/norm_mu, kappa=kappa)
        phi_vmf = np.mod(cart2spherical(x_vmf.T) + np.pi, 2*np.pi) - np.pi
        h_vmf,_ = np.histogram(phi_vmf[0], xx, density=True)
        h_x,_   = np.histogram(phi_x[0], xx, density=True)        
        plt.plot(xx[:-1] + (xx[1]-xx[0])/2., h_x, label='angles of Gaussian draws')
        plt.plot(xx[:-1] + (xx[1]-xx[0])/2., h_vmf, label='von Mises-Fisher distribution')
    # raw string: '\k' is an invalid escape sequence in a normal string (same resulting text)
    plt.title(r'$||\mu||=' + "{:10.2f}".format(norm_mu) + r', \kappa=' + "{:10.2f}".format(kappa) + '$')
    if jj == 0:
        plt.ylabel('radial profiles of angles')
        plt.legend()

    # NOTE(review): the disabled 3-D plotting block below normalises myheatmap_x from WW_vmf
    # instead of WW_x; fix that before re-enabling it.
    """
    For 3D plotting (plotting on S^2 in 3D plots), code adapted from
    https://stackoverflow.com/questions/22128909/plotting-the-temperature-distribution-on-a-sphere-with-python
    """
    """
    from mpl_toolkits.mplot3d import Axes3D
    from sklearn.metrics import pairwise

    if D == 3:
        fig = plt.figure()

        u = np.linspace( 0, 2 * np.pi, 120)
        v = np.linspace( 0, np.pi, 60 )

        # create the sphere surface
        XX = np.outer( np.cos( u ), np.sin( v ) )
        YY = np.outer( np.sin( u ), np.sin( v ) )
        ZZ = np.outer( np.ones( np.size( u ) ), np.cos( v ) )
        locs = np.stack([XX.flatten(),YY.flatten(),ZZ.flatten()], axis=-1)

        d0 = 0.1
        WW_vmf = (pairwise.pairwise_distances(locs, x_vmf.T)<d0).sum(axis=-1)
        myheatmap_vmf = WW_vmf.reshape(len(u), len(v)) / WW_vmf.max()
        WW_x = (pairwise.pairwise_distances(locs, x_norm.T)<d0).sum(axis=-1)
        myheatmap_x = WW_vmf.reshape(len(u), len(v)) / WW_x.max()

        # ~ ax.scatter( *zip( *pointList ), color='#dd00dd' )
        ax = fig.add_subplot( 1, 2, 1, projection='3d')
        ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_x ) )
        plt.title('angles of Gaussian')

        ax = fig.add_subplot( 1, 2, 2, projection='3d')
        ax.plot_surface( XX, YY,  ZZ, cstride=1, rstride=1, facecolors=cm.jet( myheatmap_vmf ) )
        plt.title('von Mises-Fisher')
        plt.show() 
    """


###### 