New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Dirichlet distribution #5115
Add Dirichlet distribution #5115
Conversation
jenkins, test this please. |
1 similar comment
jenkins, test this please. |
Jenkins CI test (for commit de38f99) failed with status FAILURE. |
jenkins, test this please. |
Jenkins CI test (for commit 05c1ddf) succeeded without errors! |
jenkins, test this please. |
Jenkins CI test (for commit ee1a1a1, target branch master) failed with status FAILURE. |
chainer/distributions/dirichlet.py
Outdated
from chainer.functions.math import exponential | ||
from chainer.functions.math import lgamma | ||
from chainer.functions.math import sum as sum_mod | ||
import numpy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you fix the import grouping?
chainer/distributions/dirichlet.py
Outdated
@property | ||
def entropy(self): | ||
return sum_mod.sum(lgamma.lgamma(self.alpha), axis=-1) \ | ||
- lgamma.lgamma(self.alpha0) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The multivariate log beta function (not necessarily a FunctionNode
) could be useful in entropy
, log_prob
, and kl.
chainer/distributions/dirichlet.py
Outdated
return self.__alpha | ||
|
||
@property | ||
def k(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to encourage users to access this via event_shape
, instead of providing the property k
.
# Conflicts: # chainer/distributions/__init__.py # docs/source/reference/distributions.rst # tests/chainer_tests/distributions_tests/test_kldivergence.py
jenkins, test this please. |
Jenkins CI test (for commit c9f9516, target branch master) failed with status FAILURE. |
dist2 = self.make_dirichlet_dist() | ||
self.check_kl(dist1, dist2) | ||
|
||
@testing.with_requires('scipy') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need scipy
for the GPU test.
@@ -55,6 +55,12 @@ def make_categorical_dist(self, is_gpu=False): | |||
params = self.encode_params({"p": p}, is_gpu) | |||
return distributions.Categorical(**params) | |||
|
|||
def make_dirichlet_dist(self, is_gpu=False): | |||
alpha = numpy.random.uniform( | |||
1, 10, self.shape + (3,)).astype(numpy.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you tried #5088 (comment) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK to merge this.
split from #4678.