-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
dirichlet.py
122 lines (99 loc) · 3.37 KB
/
dirichlet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy
import chainer
from chainer import distribution
from chainer.functions.array import expand_dims
from chainer.functions.math import digamma
from chainer.functions.math import exponential
from chainer.functions.math import lgamma
from chainer.functions.math import sum as sum_mod
from chainer.utils import cache
def _lbeta(x):
    """Log of the multivariate beta function of ``x`` along the last axis.

    Computes ``sum_i lgamma(x_i) - lgamma(sum_i x_i)``, the log
    normalization constant of a Dirichlet distribution with parameter ``x``.
    """
    sum_of_lgamma = sum_mod.sum(lgamma.lgamma(x), axis=-1)
    lgamma_of_sum = lgamma.lgamma(sum_mod.sum(x, axis=-1))
    return sum_of_lgamma - lgamma_of_sum
class Dirichlet(distribution.Distribution):

    """Dirichlet Distribution.

    The probability density function of the distribution is expressed as

    .. math::
        p(x) = \\frac{\\Gamma(\\sum_{i=1}^{K} \\alpha_i)}
            {\\prod_{i=1}^{K} \\Gamma (\\alpha_i)}
            \\prod_{i=1}^{K} {x_i}^{\\alpha_i-1}

    Args:
        alpha(:class:`~chainer.Variable` or :ref:`ndarray`): Parameter of
            distribution. The last axis indexes the ``K`` concentration
            parameters; leading axes are batch dimensions.
    """

    def __init__(self, alpha):
        self.__alpha = alpha

    @cache.cached_property
    def alpha(self):
        """Concentration parameter as a :class:`~chainer.Variable`."""
        return chainer.as_variable(self.__alpha)

    @cache.cached_property
    def alpha0(self):
        """Sum of concentrations over the last axis (one value per batch)."""
        return sum_mod.sum(self.alpha, axis=-1)

    @property
    def batch_shape(self):
        return self.alpha.shape[:-1]

    @cache.cached_property
    def entropy(self):
        """Differential entropy.

        ``H = log B(alpha) + (alpha0 - K) * psi(alpha0)
        - sum_i (alpha_i - 1) * psi(alpha_i)``
        where ``psi`` is the digamma function and ``K`` the event size.
        """
        return (
            _lbeta(self.alpha)
            + ((self.alpha0 - self.event_shape[0])
               * digamma.digamma(self.alpha0))
            - sum_mod.sum(
                (self.alpha - 1) * digamma.digamma(self.alpha),
                axis=-1))

    @property
    def event_shape(self):
        return self.alpha.shape[-1:]

    def log_prob(self, x):
        """Log density ``-log B(alpha) + sum_i (alpha_i - 1) * log x_i``.

        Args:
            x(:class:`~chainer.Variable` or :ref:`ndarray`): Points on the
                simplex at which to evaluate the density.
        """
        return (
            - _lbeta(self.alpha)
            + sum_mod.sum(
                (self.alpha - 1) * exponential.log(x),
                axis=-1))

    @cache.cached_property
    def mean(self):
        # E[x_i] = alpha_i / alpha0, broadcasting alpha0 over the last axis.
        alpha0 = expand_dims.expand_dims(self.alpha0, axis=-1)
        return self.alpha / alpha0

    @property
    def params(self):
        return {'alpha': self.alpha}

    def sample_n(self, n):
        """Draw ``n`` independent samples per batch element.

        Returns:
            ~chainer.Variable: Samples of shape ``(n,) + self.alpha.shape``.
        """
        # Flatten batch dims so each row is one K-vector of concentrations.
        obo_alpha = self.alpha.data.reshape(-1, self.event_shape[0])
        xp = chainer.backend.get_array_module(self.alpha)
        # NumPy and CuPy expose the same ``random.dirichlet`` call here, so a
        # single code path suffices (the former CPU/GPU branches were
        # byte-identical).
        eps = [
            xp.random.dirichlet(one_alpha, size=(n,)).astype(numpy.float32)
            for one_alpha in obo_alpha]
        # Stack per-row samples, then move the sample axis to the front and
        # restore the original batch shape.
        eps = [xp.expand_dims(eps_, 0) for eps_ in eps]
        eps = xp.swapaxes(xp.vstack(eps), 0, 1)
        eps = eps.reshape((n,) + self.alpha.shape)
        noise = chainer.Variable(eps)
        return noise

    @property
    def support(self):
        return '[0, 1]'

    @cache.cached_property
    def variance(self):
        # Var[x_i] = alpha_i * (alpha0 - alpha_i) / (alpha0^2 * (alpha0 + 1))
        alpha0 = expand_dims.expand_dims(self.alpha0, axis=-1)
        return (
            self.alpha
            * (alpha0 - self.alpha)
            / alpha0 ** 2
            / (alpha0 + 1))
@distribution.register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(dist1, dist2):
    """KL divergence ``KL(dist1 || dist2)`` between Dirichlet distributions.

    ``KL = log B(alpha2) - log B(alpha1)
    + sum_i (alpha1_i - alpha2_i) * (psi(alpha1_i) - psi(alpha1_0))``
    """
    # psi(alpha1_0), broadcast over the component axis.
    digamma_alpha0 = expand_dims.expand_dims(
        digamma.digamma(dist1.alpha0), axis=-1)
    per_component = (
        (dist1.alpha - dist2.alpha)
        * (digamma.digamma(dist1.alpha) - digamma_alpha0))
    return (
        _lbeta(dist2.alpha)
        - _lbeta(dist1.alpha)
        + sum_mod.sum(per_component, axis=-1))