In [1]:
import jax
from jax import random, jit, grad
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_probability.substrates.jax as tfp

In [2]:
def norm(v):
    """Return the Euclidean length of `v` along its last axis.

    The reduced axis is kept with size 1 so the result broadcasts against `v`.
    """
    lengths = jnp.linalg.norm(v, axis=-1, keepdims=True)
    return lengths
    
def normalised(v):
    """Scale `v` so that its vectors along the last axis have unit Euclidean length."""
    length = norm(v)
    return v / length

def random_points_on_sphere(key, shape):
    """Draw standard-normal values and rescale them to unit length along the last axis.

    Args:
        key: JAX PRNG key used for the normal draw.
        shape: shape of the result. A bare int `n` is treated as `(n,)`.

    Returns:
        Array of the requested shape whose last-axis vectors have unit norm.
    """
    if isinstance(shape, int):
        # `(n)` is just the int `n`, not a 1-tuple; accept both spellings.
        shape = (shape,)
    pts = jax.random.normal(key, shape)
    return normalised(pts)

def cosine_sim(a, b):
    """Cosine similarity between a vector `a` and one vector or a stack of vectors `b`.

    Both inputs are rescaled to unit length along their last axis first.
    With a 1-D `b` the result is a single dot product; otherwise `b` is
    treated as (N, D) and the result holds one similarity per row of `b`.
    """
    unit_a, unit_b = normalised(a), normalised(b)
    if unit_b.ndim != 1:
        # (D,) x (N, D): the feature sizes must agree before contracting.
        assert unit_a.shape[0] == unit_b.shape[1]
        return jnp.dot(unit_a, unit_b.T)
    # (D,) x (D,) -> scalar
    return jnp.dot(unit_a, unit_b)

def mean_nll_vmf(mu, kappa, x):
    """Mean negative log-likelihood of `x` under a von Mises-Fisher distribution.

    `mu` is rescaled to unit length before use, so callers may pass an
    unnormalised direction; `kappa` is passed through as the concentration.
    """
    unit_mu = normalised(mu)
    vmf = tfp.distributions.VonMisesFisher(mean_direction=unit_mu, concentration=kappa)
    return jnp.mean(-vmf.log_prob(x))

def mean_nll_vmf_p(params, x):
    """Mean vMF negative log-likelihood of `x`, with the parameters given as a dict.

    `params` holds 'mu' (the mean direction) and 'log_kappa' (the log of the
    concentration, exponentiated here so the concentration is always positive).
    """
    return mean_nll_vmf(params['mu'], jnp.exp(params['log_kappa']), x)

In [7]:
def fit_von_mises_fisher(x, train_steps: int, seed: int, learning_rate: float = 1e-3):
    """Fit a von Mises-Fisher distribution to `x` by minimising the mean NLL with Adam.

    Args:
        x: observations whose last axis is the feature dimension.
            NOTE(review): nothing here rescales `x`; it is assumed to already be
            unit-norm — confirm for externally loaded embeddings.
        train_steps: number of Adam updates to run.
        seed: seed for the random initial mean direction.
        learning_rate: Adam step size; defaults to the previously hard-coded 1e-3.

    Returns:
        (fit_mu, fit_kappa): the unit-norm mean direction and the concentration.
    """
    # Convert once up front so a NumPy input is not re-transferred to the
    # device on every call of the jitted step below.
    x = jnp.asarray(x)
    dim = x.shape[-1]

    params = {
        'mu': random_points_on_sphere(jax.random.key(seed), shape=(dim,)),
        # Optimise log(kappa) so that kappa = exp(log_kappa) stays positive.
        'log_kappa': jnp.log(jnp.array(1.0)),
    }

    optimizer = optax.adam(learning_rate=learning_rate)
    opt_state = optimizer.init(params)

    @jit
    def training_step(params, opt_state, x):
        # One Adam update on the mean negative log-likelihood.
        grads = grad(mean_nll_vmf_p)(params, x)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    for _ in range(train_steps):
        params, opt_state = training_step(params, opt_state, x)

    # 'mu' is unconstrained during optimisation; project it back onto the sphere.
    fit_mu = normalised(params['mu'])
    fit_kappa = jnp.exp(params['log_kappa'])
    return fit_mu, fit_kappa
            

Decide on a true mu for sampling the training data, as well as another mu that will represent "out of distribution" data.

In [17]:
# Independent keys for the two mean directions.
k0, k1 = random.split(random.PRNGKey(123), 2)

D = 512

# define a true mean value
# Shape is written as the 1-tuple (D,), since (D) is just the int D.
true_mu = random_points_on_sphere(k0, (D,))
true_kappa = 2.0  # concentration

# some other mu
other_mu = random_points_on_sphere(k1, (D,))

# Sanity check: both directions should have shape (D,) and unit norm.
print('true_mu', true_mu.shape, norm(true_mu))
print('other_mu', other_mu.shape, norm(other_mu))

true_mu (512,) [1.]
other_mu (512,) [1.]


In [9]:
def sample(key, n_samples, mu, kappa):
    """Draw `n_samples` points from a von Mises-Fisher distribution.

    `mu` is passed as the mean direction, `kappa` as the concentration, and
    `key` as the sampling seed.
    """
    vmf = tfp.distributions.VonMisesFisher(mean_direction=mu, concentration=kappa)
    return vmf.sample(n_samples, seed=key)

# Number of samples to draw for the training set and for each test set.
n_train = 100_000
n_test  = 100

# One PRNG key per sampled dataset.
k = random.split(random.PRNGKey(456), 3)

# Train and test sets are sampled around true_mu; x_other is sampled around
# other_mu with the same concentration.
x_train = sample(k[0], n_train, true_mu, true_kappa)
x_test = sample(k[1], n_test, true_mu, true_kappa)
x_other = sample(k[2], n_test, other_mu, true_kappa)

In [10]:
# Fit a vMF to x_train with 100 optimisation steps.
fit_mu, fit_kappa = fit_von_mises_fisher(x_train, train_steps=100, seed=123)

In [12]:
# Cosine similarity of the fitted mean direction with true_mu and with other_mu.
print("cosine sim; fit vs true ", cosine_sim(fit_mu, true_mu))
print("cosine sim; fit vs other", cosine_sim(fit_mu, other_mu))

cosine sim; fit vs true  0.7360561
cosine sim; fit vs other -0.040914513


In [13]:
# Load precomputed arrays from .npy files; paths are relative to the working
# directory. Presumably CLIP image embeddings, going by the file names — TODO confirm.
# NOTE(review): this rebinds x_train, replacing the sampled training set from the earlier cell.
x_train = np.load('data/cat_dog/1k/train/clip_embed_img.npy')
x_test_in_distribution = np.load('data/cat_dog/1k/test/clip_embed_img.npy')
# Only the first 100 entries of the open_images array are kept.
x_test_out_of_distribution = np.load('data/open_images/1k/train/clip_embed_img.npy')[:100]

In [15]:
# Refit on the loaded x_train with 100 optimisation steps; this rebinds fit_mu and fit_kappa.
fit_mu, fit_kappa = fit_von_mises_fisher(x_train, train_steps=100, seed=234)

In [16]:
# Mean cosine similarity between the fitted mean direction and each test array
# (in-distribution first, then out-of-distribution).
# Assumes both arrays are 2-D (N, D), giving one similarity per row — TODO confirm.
print("cosine sim; fit vs true ", jnp.mean(cosine_sim(fit_mu, x_test_in_distribution)))
print("cosine sim; fit vs other", jnp.mean(cosine_sim(fit_mu, x_test_out_of_distribution)))

cosine sim; fit vs true  0.5695675
cosine sim; fit vs other 0.16227296
