Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5338 from YoshikawaMasashi/distributions/chisquare
add `D.Chisquare`
- Loading branch information
Showing
4 changed files
with
112 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy | ||
|
||
import chainer | ||
from chainer.backends import cuda | ||
from chainer import distribution | ||
from chainer.functions.math import digamma | ||
from chainer.functions.math import exponential | ||
from chainer.functions.math import lgamma | ||
|
||
|
||
class Chisquare(distribution.Distribution): | ||
|
||
"""Chi-Square Distribution. | ||
The probability density function of the distribution is expressed as | ||
.. math:: | ||
p(x;k) = \\frac{1}{2^{k/2}\\Gamma(k/2)}x^{k/2-1}e^{-x/2} | ||
Args: | ||
k(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \ | ||
:class:`cupy.ndarray`): Parameter of distribution. | ||
""" | ||
|
||
def __init__(self, k): | ||
super(Chisquare, self).__init__() | ||
self.__k = chainer.as_variable(k) | ||
|
||
@property | ||
def k(self): | ||
return self.__k | ||
|
||
@property | ||
def batch_shape(self): | ||
return self.k.shape | ||
|
||
@property | ||
def entropy(self): | ||
return 0.5 * self.k + numpy.log(2.) + lgamma.lgamma(0.5 * self.k) \ | ||
+ (1 - 0.5 * self.k) * digamma.digamma(0.5 * self.k) | ||
|
||
@property | ||
def event_shape(self): | ||
return () | ||
|
||
def log_prob(self, x): | ||
return - lgamma.lgamma(0.5 * self.k) - 0.5 * self.k * numpy.log(2.) \ | ||
+ (0.5 * self.k - 1) * exponential.log(x) - 0.5 * x | ||
|
||
@property | ||
def mean(self): | ||
return self.k | ||
|
||
def sample_n(self, n): | ||
xp = cuda.get_array_module(self.k) | ||
if xp is cuda.cupy: | ||
eps = xp.random.chisquare( | ||
self.k.data, (n,)+self.k.shape, dtype=self.k.dtype) | ||
else: | ||
eps = xp.random.chisquare( | ||
self.k.data, (n,)+self.k.shape).astype(self.k.dtype) | ||
noise = chainer.Variable(eps) | ||
return noise | ||
|
||
@property | ||
def support(self): | ||
return 'positive' | ||
|
||
@property | ||
def variance(self): | ||
return 2 * self.k |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from chainer import distributions | ||
from chainer import testing | ||
import numpy | ||
|
||
|
||
@testing.parameterize(*testing.product({ | ||
'shape': [(2, 3), ()], | ||
'is_variable': [True, False], | ||
'sample_shape': [(3, 2), ()], | ||
})) | ||
@testing.fix_random() | ||
@testing.with_requires('scipy') | ||
class TestChisquare(testing.distribution_unittest): | ||
|
||
scipy_onebyone = True | ||
|
||
def setUp_configure(self): | ||
from scipy import stats | ||
self.dist = distributions.Chisquare | ||
self.scipy_dist = stats.chi2 | ||
|
||
self.test_targets = set([ | ||
"batch_shape", "entropy", "event_shape", "log_prob", "mean", | ||
"sample", "support", "variance"]) | ||
|
||
k = numpy.random.randint(1, 10, self.shape).astype(numpy.float32) | ||
self.params = {"k": k} | ||
self.scipy_params = {"df": k} | ||
|
||
self.support = "positive" | ||
|
||
def sample_for_test(self): | ||
smp = numpy.random.chisquare( | ||
df=1, size=self.sample_shape + self.shape | ||
).astype(numpy.float32) | ||
return smp | ||
|
||
|
||
testing.run_module(__name__, __file__) |