Merge pull request #5115 from YoshikawaMasashi/distributions/dirichlet
Add Dirichlet distribution
toslunar committed Sep 20, 2018
2 parents 8b75c50 + c9f9516 · commit fc99d9d
Showing 5 changed files with 165 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/distributions/__init__.py
@@ -3,6 +3,7 @@
from chainer.distributions.bernoulli import Bernoulli # NOQA
from chainer.distributions.beta import Beta # NOQA
from chainer.distributions.categorical import Categorical # NOQA
from chainer.distributions.dirichlet import Dirichlet # NOQA
from chainer.distributions.laplace import Laplace # NOQA
from chainer.distributions.log_normal import LogNormal # NOQA
from chainer.distributions.multivariate_normal import MultivariateNormal # NOQA
104 changes: 104 additions & 0 deletions chainer/distributions/dirichlet.py
@@ -0,0 +1,104 @@
import numpy

import chainer
from chainer.backends import cuda
from chainer import distribution
from chainer.functions.array import expand_dims
from chainer.functions.math import digamma
from chainer.functions.math import exponential
from chainer.functions.math import lgamma
from chainer.functions.math import sum as sum_mod


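# Log of the multivariate beta function, computed along the last axis:
# log B(x) = sum_i lgamma(x_i) - lgamma(sum_i x_i).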
def _lbeta(x):
    return sum_mod.sum(lgamma.lgamma(x), axis=-1) \
        - lgamma.lgamma(sum_mod.sum(x, axis=-1))


class Dirichlet(distribution.Distribution):

    """Dirichlet Distribution.

    The probability density function of the distribution is expressed as

    .. math::
        p(x) = \\frac{\\Gamma(\\sum_{i=1}^{K} \\alpha_i)}
            {\\prod_{i=1}^{K} \\Gamma(\\alpha_i)}
            \\prod_{i=1}^{K} {x_i}^{\\alpha_i-1}

    Args:
        alpha(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`): Parameter of distribution.
    """

    def __init__(self, alpha):
        self.__alpha = chainer.as_variable(alpha)

    @property
    def alpha(self):
        return self.__alpha

    @property
    def alpha0(self):
        return sum_mod.sum(self.alpha, axis=-1)

    @property
    def batch_shape(self):
        return self.alpha.shape[:-1]

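    # Differential entropy of the Dirichlet distribution:
    # H = log B(alpha) + (alpha0 - K) * psi(alpha0)
    #     - sum_i (alpha_i - 1) * psi(alpha_i),
    # where K is the event dimension and psi is the digamma function.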
    @property
    def entropy(self):
        return _lbeta(self.alpha) \
            + (self.alpha0 - self.event_shape[0]) \
            * digamma.digamma(self.alpha0) \
            - sum_mod.sum((self.alpha - 1)
                          * digamma.digamma(self.alpha), axis=-1)

    @property
    def event_shape(self):
        return self.alpha.shape[-1:]

    def log_prob(self, x):
        return - _lbeta(self.alpha) \
            + sum_mod.sum((self.alpha - 1) * exponential.log(x), axis=-1)

    @property
    def mean(self):
        alpha0 = expand_dims.expand_dims(self.alpha0, axis=-1)
        return self.alpha / alpha0

    def sample_n(self, n):
        # Draw one batch element at a time: random.dirichlet accepts a
        # single concentration vector per call.
        obo_alpha = self.alpha.data.reshape(-1, self.event_shape[0])
        xp = cuda.get_array_module(self.alpha)
        # NumPy and CuPy expose the same random.dirichlet interface, so a
        # single code path serves both backends.
        eps = [xp.random.dirichlet(
            one_alpha, size=(n,)).astype(numpy.float32)
            for one_alpha in obo_alpha]
        eps = [xp.expand_dims(eps_, 0) for eps_ in eps]
        eps = xp.swapaxes(xp.vstack(eps), 0, 1)
        eps = eps.reshape((n,) + self.alpha.shape)
        noise = chainer.Variable(eps)
        return noise

    @property
    def support(self):
        return '[0, 1]'

    @property
    def variance(self):
        alpha0 = expand_dims.expand_dims(self.alpha0, axis=-1)
        return self.alpha * (alpha0 - self.alpha) \
            / alpha0 ** 2 / (alpha0 + 1)


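# Analytic KL divergence between two Dirichlet distributions:
# KL(d1 || d2) = log B(alpha2) - log B(alpha1)
#     + sum_i (alpha1_i - alpha2_i) * (psi(alpha1_i) - psi(alpha1_0)),
# where alpha1_0 = sum_i alpha1_i.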
@distribution.register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(dist1, dist2):
    return - _lbeta(dist1.alpha) + _lbeta(dist2.alpha) \
        + sum_mod.sum((dist1.alpha - dist2.alpha) * (
            digamma.digamma(dist1.alpha)
            - expand_dims.expand_dims(digamma.digamma(
                dist1.alpha0), axis=-1)), axis=-1)
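
As a quick illustration (not part of the diff), here is a minimal usage sketch of the new class; the concentration values below are arbitrary:

import numpy
from chainer import distributions

# Arbitrary concentration parameters: a batch of two 3-dimensional vectors.
alpha = numpy.array([[1., 2., 3.],
                     [5., 5., 5.]], dtype=numpy.float32)
dist = distributions.Dirichlet(alpha)

dist.batch_shape             # (2,)
dist.event_shape             # (3,)
mean = dist.mean             # alpha / alpha0 along the last axis
samples = dist.sample_n(4)   # Variable of shape (4, 2, 3); each vector sums to 1
log_p = dist.log_prob(samples.data)  # log density, shape (4, 2)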
1 change: 1 addition & 0 deletions docs/source/reference/distributions.rst
@@ -17,6 +17,7 @@ Distributions
   chainer.distributions.Bernoulli
   chainer.distributions.Beta
   chainer.distributions.Categorical
   chainer.distributions.Dirichlet
   chainer.distributions.Laplace
   chainer.distributions.LogNormal
   chainer.distributions.MultivariateNormal
40 changes: 40 additions & 0 deletions tests/chainer_tests/distributions_tests/test_dirichlet.py
@@ -0,0 +1,40 @@
from chainer import distributions
from chainer import testing
import numpy


@testing.parameterize(*testing.product({
    'shape': [(2, 3), ()],
    'is_variable': [True, False],
    'sample_shape': [(3, 2), ()],
}))
@testing.fix_random()
@testing.with_requires('scipy')
class TestDirichlet(testing.distribution_unittest):

    scipy_onebyone = True

    def setUp_configure(self):
        from scipy import stats
        self.dist = distributions.Dirichlet
        self.scipy_dist = stats.dirichlet

        self.test_targets = set([
            "batch_shape", "entropy", "event_shape", "mean", "sample",
            "support", "variance"])

        alpha = numpy.random.uniform(
            0, 10, self.shape + (3,)).astype(numpy.float32)
        self.params = {"alpha": alpha}
        self.scipy_params = {"alpha": alpha}
        self.support = '[0, 1]'
        self.event_shape = (3,)

    def sample_for_test(self):
        # Map Gaussian noise onto the probability simplex via softmax, so
        # the test points lie in the support of the distribution.
        smp = numpy.random.normal(size=self.shape + (3,)).astype(numpy.float32)
        smp = numpy.exp(smp)
        smp /= numpy.expand_dims(smp.sum(axis=-1), axis=-1)
        return smp


testing.run_module(__name__, __file__)
19 changes: 19 additions & 0 deletions tests/chainer_tests/distributions_tests/test_kldivergence.py
@@ -55,6 +55,12 @@ def make_categorical_dist(self, is_gpu=False):
        params = self.encode_params({"p": p}, is_gpu)
        return distributions.Categorical(**params)

    def make_dirichlet_dist(self, is_gpu=False):
        alpha = numpy.random.uniform(
            1, 10, self.shape + (3,)).astype(numpy.float32)
        params = self.encode_params({"alpha": alpha}, is_gpu)
        return distributions.Dirichlet(**params)

    def make_laplace_dist(self, is_gpu=False):
        loc = numpy.random.uniform(-1, 1, self.shape).astype(numpy.float32)
        scale = numpy.exp(
@@ -149,6 +155,19 @@ def test_categorical_categorical_gpu(self):
        dist2 = self.make_categorical_dist(True)
        self.check_kl(dist1, dist2)

    @testing.with_requires('scipy')
    def test_dirichlet_dirichlet_cpu(self):
        dist1 = self.make_dirichlet_dist()
        dist2 = self.make_dirichlet_dist()
        self.check_kl(dist1, dist2)

    @testing.with_requires('scipy')
    @attr.gpu
    def test_dirichlet_dirichlet_gpu(self):
        dist1 = self.make_dirichlet_dist(True)
        dist2 = self.make_dirichlet_dist(True)
        self.check_kl(dist1, dist2)

    def test_laplace_laplace_cpu(self):
        dist1 = self.make_laplace_dist()
        dist2 = self.make_laplace_dist()
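For reference, a minimal sketch of what these new KL tests exercise, assuming the public chainer.kl_divergence entry point dispatches to the pair registered above; the concentration values are arbitrary:

import numpy
import chainer
from chainer import distributions

alpha1 = numpy.array([1., 2., 3.], dtype=numpy.float32)
alpha2 = numpy.array([2., 2., 2.], dtype=numpy.float32)
d1 = distributions.Dirichlet(alpha1)
d2 = distributions.Dirichlet(alpha2)

kl = chainer.kl_divergence(d1, d2)  # non-negative scalar Variable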
