In [21]:
import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from ott.geometry import costs
from ott.problems.linear import barycenter_problem
from ott.solvers.linear import continuous_barycenter
from ott.tools.gaussian_mixture import gaussian_mixture

In [22]:
# --- problem sizes and hyper-parameters -------------------------------------
dim = 2                     # dimension of every Gaussian component
n_components = (2, 3, 5)    # number of components in each GMM
n_gmms = len(n_components)  # how many GMMs enter the barycenter computation
epsilon = 0.1               # entropic regularization strength
alpha = 50.0                # concentration parameter of the Dirichlet draw

# --- pseudo-randomness -------------------------------------------------------
# Every random quantity below is derived from this single root key, so the
# cell is reproducible on a fresh kernel.
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 3)

# Barycentric weights: one sample from a symmetric Dirichlet with concentration
# `alpha` per GMM.
barycentric_weights = jax.random.dirichlet(keys[0], alpha=jnp.ones(n_gmms) * alpha)

# One integer seed in [0, 100) per GMM, used later to generate each measure.
seeds = jax.random.randint(keys[1], shape=(n_gmms,), minval=0, maxval=100)

In [27]:
# Per-GMM offsets for the component means: one 2-D vector per GMM.
cs = jnp.array([[-20, -15], [60, -15], [50, 65]])
# Per-GMM mean scale: one tenth of the average coordinate of each offset row.
# NOTE(review): row 0 averages to -17.5, so ms[0] is negative although it is
# later passed as `stdev_mean`; confirm that this is intended.
ms = jnp.mean(cs, axis=1) * 0.1
# Per-GMM scale that controls the covariance matrices.
ss = jnp.array([4, 3, 5])
# Show the shapes of the three parameter arrays, one per line.
print(*(arr.shape for arr in (cs, ms, ss)), sep="\n")

(3, 2)
(3,)
(3,)


In [24]:
# Sanity checks: every per-GMM parameter array has exactly one entry per GMM.
assert cs.shape[0] == n_gmms
assert ss.size == n_gmms
assert seeds.size == n_gmms
# Trivially true when n_gmms = len(n_components); kept as a guard in case
# n_gmms is ever set independently.
assert len(n_components) == n_gmms
# Compare elementwise *before* reducing: `x.all() > 0` compares one bool with 0,
# so it only checks "no element is zero", never "all elements are positive".
# The row means of cs (scaled by 0.1 into ms) are used as scales, so they must
# be non-zero.
# NOTE(review): row 0 of cs averages to -17.5, so a strict `> 0` check would
# fail here; confirm whether a negative mean scale is intended.
assert (jnp.mean(cs, axis=1) != 0).all(), "offset rows must have non-zero means"
assert (ss > 0).all(), "covariance scales must be strictly positive"

In [28]:
# One random GMM per entry of `n_components`. GMM `k` is generated from a key
# built from seeds[k]:
#   - ss[k] controls its covariance matrices;
#   - ms[k] is passed as the mean scale;
#   - cs[k] is passed as `ridge`, described earlier as the offset for its means.
gmm_generators = [
    gaussian_mixture.GaussianMixture.from_random(
        jax.random.PRNGKey(seeds[k]),
        n_dimensions=dim,
        n_components=n_components[k],
        ridge=cs[k],
        stdev_mean=ms[k],
        stdev_cov=ss[k],
    )
    for k in range(n_gmms)
]
print(type(gmm_generators[0]))

<class 'ott.tools.gaussian_mixture.gaussian_mixture.GaussianMixture'>


In [26]:
# Pair each GMM's component locations with its component covariances,
# in the same order as `gmm_generators`.
means_covs = [
    (gmm.loc, gmm.covariance)
    for gmm in (gmm_generators[k] for k in range(n_gmms))
]
# Summary: number of GMMs, tuple length, and the first GMM's loc / covariance shapes.
print(
    len(means_covs),
    len(means_covs[0]),
    means_covs[0][0].shape,
    means_covs[0][1].shape,
)

3 2 (2, 2) (2, 2, 2)
