# Let's try out von Mises–Fisher (vMF) sampling in TensorFlow Probability (JAX substrate)

In [None]:
# Setup: NumPy, plotting, TFP distributions (JAX substrate), and a root JAX PRNG key.
import numpy as np
# Seeds NumPy's global RNG. NOTE(review): the vMF sampling in the cells shown draws its
# randomness from explicit JAX PRNG keys, so this seed does not affect those samples.
np.random.seed(12345)

import matplotlib.pyplot as plt
# IPython magic (not plain Python): render inline figures at high ("retina") resolution.
%config InlineBackend.figure_format = 'retina'

# Typeset all figure text with LaTeX; matplotlib's usetex mode needs a LaTeX
# installation on the machine running the notebook.
plt.rcParams.update({
    "text.usetex": True,
})

# TensorFlow Probability's JAX substrate: the same distributions API, operating on JAX arrays
# and taking JAX PRNG keys as sampling seeds.
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import jax.random as jr
import jax.numpy as jnp

# Root PRNG key; later cells split it to obtain fresh per-draw sample keys.
init_key = jr.key(123)

In [None]:
# Mean direction in R^10: equal weight on coordinates 0 and 2, rescaled to unit norm
# (a vMF mean direction must lie on the unit sphere).
mu = jnp.zeros(10).at[jnp.array([0, 2])].set(1.0)
mu = mu / jnp.linalg.norm(mu)

# Concentration: for a vMF, larger kappa packs the mass more tightly around mu.
kappa = jnp.array(100.0)
vmf = tfd.VonMisesFisher(mu, kappa)

In [None]:
# Split the root key: `sample_key` is consumed by this draw, `key` is carried to later cells.
key, sample_key = jr.split(init_key, num=2)
# Five independent draws from `vmf`; as the cell's last expression, the array is displayed.
vmf.sample(sample_shape=(5,), seed=sample_key)

Now let's draw many samples for several values of the concentration $\kappa$ and look at how each coordinate is distributed:

In [None]:
kappa_list = [0., 10., 100.]

# One histogram per coordinate of `mu`, on a grid derived from its dimension. A hard-coded
# 2 x 5 grid only fits exactly 10 dims: with fewer, JAX clamps the out-of-range column
# index and silently re-plots the last coordinate; with more, coordinates are dropped.
n_dim = mu.shape[-1]
n_cols = 5
n_rows = -(-n_dim // n_cols)  # ceiling division

for kappa in kappa_list:
    vmf = tfd.VonMisesFisher(mu, kappa)

    # Fresh subkey per concentration so each figure uses independent samples.
    key, sample_key = jr.split(key)
    samples = vmf.sample((1_000,), sample_key)

    # squeeze=False keeps `ax` 2-D even when the grid has a single row.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(18, 8), squeeze=False)
    for idx, a in enumerate(ax.reshape(-1)):
        if idx >= n_dim:
            # Leftover grid cells when n_dim is not a multiple of n_cols.
            a.axis("off")
            continue
        a.hist(samples[:, idx], bins=25)
        # Reference line at the mean-direction coordinate mu_i: samples should cluster
        # near it as kappa grows and spread across [-1, 1] at kappa = 0.
        a.axvline(float(mu[idx]), color="k", linestyle="--", linewidth=1)
        # Coordinates of unit vectors always lie in [-1, 1].
        a.set_xlim(-1, 1)
    fig.suptitle(f"Samples for $\\kappa={vmf.concentration}$", fontsize=16)

# Higher dimensions

Next, we try a higher-dimensional example (100 dimensions) to check that sampling still behaves as intended.

In [None]:
# Smooth 100-dimensional mean direction: one full period of a sine wave across the
# coordinates, rescaled to unit norm.
n_dim = 100
mu = jnp.sin(jnp.arange(n_dim) * 2 * jnp.pi / n_dim)
mu /= jnp.linalg.norm(mu)

# Plot on an explicit new figure rather than pyplot's implicit "current" figure (which,
# outside the inline backend, would still be the last histogram grid). The trailing `;`
# stops the cell from echoing the repr of the returned Text object.
fig, ax = plt.subplots()
ax.plot(mu)
ax.set_title("vMF mean vector $\\mu$")
ax.set_xlabel("Dimension")
ax.set_ylabel("Coordinate value");

In [None]:
kappa_list = [0., 100., 1000.]

# One histogram per coordinate of `mu` (a 10 x 10 grid for the 100-dim example),
# derived from the dimension so no coordinate is dropped or duplicated.
n_dim = mu.shape[-1]
n_cols = 10
n_rows = -(-n_dim // n_cols)  # ceiling division

for kappa in kappa_list:
    vmf = tfd.VonMisesFisher(mu, kappa)

    # Fresh subkey per concentration so each figure uses independent samples.
    key, sample_key = jr.split(key)
    samples = vmf.sample((1_000,), sample_key)

    # Symmetric x-range taken from the data, shared by all panels of this figure.
    # A fixed (-0.15, 0.15) window clips the tails. At kappa = 0 each coordinate of a
    # uniform point on the sphere has std 1/sqrt(n_dim) = 0.1. For the sine-wave mu of
    # the previous cell, |mu_i| alone reaches ~0.14.
    lim = float(jnp.abs(samples).max())

    # squeeze=False keeps `ax` 2-D even when the grid has a single row.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(18, 8), squeeze=False)
    for idx, a in enumerate(ax.reshape(-1)):
        if idx >= n_dim:
            # Leftover grid cells when n_dim is not a multiple of n_cols.
            a.axis("off")
            continue
        a.hist(samples[:, idx], bins=25)
        # Reference line at the mean-direction coordinate mu_i.
        a.axvline(float(mu[idx]), color="k", linestyle="--", linewidth=1)
        a.set_xlim(-lim, lim)
        a.set_yticks([])
    fig.suptitle(f"Samples for $\\kappa={vmf.concentration}$", fontsize=16)