# The von Mises-Fisher distribution on $\mathbb{S}^2$

In [None]:
import math
from random import random
from typing import TypeAlias

import matplotlib.pyplot as plt
import torch
import torch.linalg as LA

from distributions import SphericalUniformPrior3D
from utils import spherical_mesh
from visualisations import pairplot, heatmap, scatter, scatter3d, line3d

# Short alias for torch.Tensor, used in the type annotations of the functions below.
Tensor: TypeAlias = torch.Tensor

# Unicode shorthand for pi so that the code mirrors the formulas in the text.
π = math.pi

## Definition of the distribution

The von Mises-Fisher (vMF) family is a simple family of distributions on the $n$-sphere with two parameters: a mean direction $\boldsymbol{\mu} \in \mathbb{S}^n$ and a concentration parameter $\kappa \in [0, \infty)$, defined for $\mathbf{x} \in \mathbb{S}^n$ by the density

$$
f_n(\mathbf{x} \mid \kappa, \boldsymbol{\mu}) \propto e^{\kappa \boldsymbol{\mu} \cdot \mathbf{x}}
\, .
$$

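For $n = 2$ the normalisation follows from

$$
\int_{\mathbb{S}^2} \mathrm{d}\mathbf{x} \, e^{\kappa \boldsymbol{\mu} \cdot \mathbf{x}} = \frac{4\pi \sinh \kappa}{\kappa} \, ,
$$

so that the normalised density is $f_2(\mathbf{x} \mid \kappa, \boldsymbol{\mu}) = \frac{\kappa}{4\pi \sinh \kappa} e^{\kappa \boldsymbol{\mu} \cdot \mathbf{x}}$.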

A more numerically stable form of the normalised $n = 2$ density, which avoids evaluating $\sinh \kappa$ for large $\kappa$, is

$$
f_2(\mathbf{x} \mid \kappa, \boldsymbol{\mu})
= \begin{cases}
    \frac{1}{4\pi} & \kappa = 0 \, ,\\
    \frac{\kappa}{2\pi(1 - \exp(-2\kappa))} e^{\kappa (\boldsymbol{\mu} \cdot \mathbf{x} - 1)} & \kappa > 0 \, .
\end{cases}
$$

The $n=1$ family is called the *von Mises* distribution, which is better expressed in terms of angles

$$
f_1(\phi \mid \kappa, \theta) \propto e^{\kappa \cos(\phi - \theta)}
$$

for $\phi = \mathrm{atan2}(x_2, x_1)$ and $\theta = \mathrm{atan2}(\mu_2, \mu_1)$, both taken modulo $2\pi$ so that they lie in $[0, 2\pi)$.

In [None]:
def log_prob_vMF(x: Tensor, μ: Tensor, κ: float) -> Tensor:
    """Evaluate the un-normalised vMF log-density, κ μ·x, for every row of `x`.

    `x` is a matrix holding one point per row, `μ` is the mean direction
    (cast to float before the product) and `κ` is the concentration.
    The result is a vector with one log-density value per row of `x`.
    """
    projections = torch.mv(x, μ.float())
    return κ * projections

## Inverse-CDF sampling

Let $(U, V)$ be a random vector distributed uniformly on the unit circle, and $W \in [-1, 1]$ a random variable, independent of $(U, V)$, distributed according to

$$
f_W(w) = \frac{\kappa}{2 \sinh \kappa} e^{\kappa w} \, .
$$

The random vector

$$
\mathbf{X} \equiv (X, Y, Z) = \left(\sqrt{1 - W^2} \, U, \sqrt{1 - W^2} \, V, W \right)
$$

is distributed according to the vMF distribution with $\boldsymbol{\mu} = (0, 0, 1)$.

Realisations of $W$ can be generated by sampling $\xi$ uniformly on $[0, 1]$ and applying the inverse of the cumulative distribution function $F_W(w)$,

$$
F_W^{-1}(\xi) = \kappa^{-1} \log \left( e^{-\kappa} + 2 \xi \sinh \kappa \right) \, .
$$

In [None]:
def sample_vMF(N: int, κ: float, μ: Tensor = Tensor([0, 0, 1])) -> Tensor:
    """Draw `N` points on the 2-sphere from the von Mises-Fisher distribution.

    Samples are first generated around the north pole (0, 0, 1) by
    inverse-CDF sampling of the z-coordinate together with a uniform
    azimuthal angle, and are then mapped onto the requested mean direction.

    Args:
        N: Number of samples to draw.
        κ: Concentration parameter, κ >= 0. κ = 0 gives the uniform
            distribution on the sphere.
        μ: Mean direction as a 3-vector (tensor or sequence of numbers).
            It does not need to be normalised and is never modified.

    Returns:
        A tensor of shape (N, 3) whose rows are unit vectors.

    Raises:
        ValueError: If κ is negative or NaN, or if μ is not a non-zero,
            finite vector with exactly three components.
    """
    if not κ >= 0:
        raise ValueError(f"κ must be non-negative, got {κ}")

    μ = torch.as_tensor(μ, dtype=torch.get_default_dtype())
    if μ.shape != (3,):
        raise ValueError(f"μ must have shape (3,), got {tuple(μ.shape)}")
    norm = LA.vector_norm(μ).item()
    if not (norm > 0 and math.isfinite(norm)):
        raise ValueError("μ must be a non-zero, finite vector")
    # Out-of-place on purpose: an in-place division would silently modify the
    # caller's tensor (and the shared default argument).
    μ = μ / norm

    # Uniform part: azimuthal angle around the symmetry axis
    ϕ = torch.rand(N) * 2 * math.pi
    u = ϕ.cos()
    v = ϕ.sin()

    # Unimodal part: inverse CDF of the z-coordinate
    ξ = torch.rand(N)
    if κ == 0:
        # Limit κ -> 0 of the inverse CDF: z is uniform on [-1, 1]
        w = 2 * ξ - 1
    else:
        # Algebraically equal to log(exp(-κ) + 2 ξ sinh(κ)) / κ, but written
        # with expm1/log1p so that it neither overflows for large κ nor
        # loses all precision for small κ.
        w = 1 + torch.log1p((1 - ξ) * math.expm1(-2 * κ)) / κ
    # Guard against rounding (and log(0) when ξ == 0) pushing w out of range,
    # which would make the square root below NaN.
    w = w.clamp(-1, 1)
    ρ = (1 - w**2).sqrt()

    x = torch.stack([u * ρ, v * ρ, w], dim=1)

    # Map the pole (0, 0, 1) onto μ with a Householder reflection. The
    # distribution around the pole is symmetric about the z-axis, so a
    # reflection is as good as a rotation. Skipped when μ already is the pole.
    normal = (μ.new_tensor([0.0, 0.0, 1.0]) - μ).to(x)
    normal_sq = torch.dot(normal, normal)
    if normal_sq > torch.finfo(x.dtype).eps ** 2:
        x = x - torch.outer(x @ normal, 2 * normal / normal_sq)

    return x

# Sanity check: ten draws at κ = 1 must come back as a (10, 3) tensor of unit vectors.
x = sample_vMF(10, 1)

assert tuple(x.shape) == (10, 3)
assert torch.allclose((x**2).sum(dim=1), torch.ones(len(x)))

In [None]:
# Sample size and concentration used for this illustration
N = 5000
κ = 5

# Two 3-d panels side by side, both with the axes hidden
fig = plt.figure(figsize=(10, 10))
ax1, ax2 = (fig.add_subplot(1, 2, col, projection="3d") for col in (1, 2))
ax1.set_axis_off()
ax2.set_axis_off()

# Spherical mesh (X, Y, Z) and its 'unwrapped' cylindrical counterpart (U, V, W).
# These are only consumed by the surface plots below, which are currently disabled.
X, Y, Z = spherical_mesh(100)
RHO = (1 - Z**2).sqrt()
U = X / RHO
V = Y / RHO
W = Z
#ax1.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.1)
#ax2.plot_surface(U, V, W, rstride=1, cstride=1, alpha=0.1)

# Sample vMF variables with the inverse-CDF procedure outlined above, keeping the
# intermediate cylinder coordinates (u, v, w) so that they can be plotted as well.

# Uniform on the circle
ϕ = torch.rand(N) * 2 * π
u, v = torch.cos(ϕ), torch.sin(ϕ)

# Unimodal part: inverse CDF applied to uniform draws
ξ = torch.rand(N)
w = torch.log(math.exp(-κ) + 2 * ξ * math.sinh(κ)).div(κ)

# Wrap the cylinder onto the sphere
ρ = torch.sqrt(1 - w**2)
x = u * ρ
y = v * ρ
z = w

# Left panel: points on the sphere; right panel: the same points on the cylinder
ax1.scatter3D(x, y, z, s=1, c="black")
ax2.scatter3D(u, v, w, s=1, c="black")

plt.show()

## Visualisations

In [None]:
# Draw 10000 samples at κ = 10 and pass them to each of the plotting helpers in turn
x = sample_vMF(κ=10, N=10000)
scatter(x, s=2)
heatmap(x)
scatter3d(x, s=2)
line3d(x, lw=0.1, ls="--", marker="")
pairplot(x)