Skip to content

Commit

Permalink
Merge 93bb349 into f97db2a
Browse files Browse the repository at this point in the history
  • Loading branch information
YoshikawaMasashi committed Apr 27, 2018
2 parents f97db2a + 93bb349 commit ee5355f
Show file tree
Hide file tree
Showing 20 changed files with 1,636 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainer/__init__.py
Expand Up @@ -27,6 +27,7 @@
from chainer.configuration import config # NOQA
from chainer.configuration import global_config # NOQA
from chainer.configuration import using_config # NOQA
from chainer.distribution import Distribution # NOQA
from chainer.function import force_backprop_mode # NOQA
from chainer.function import Function # NOQA
from chainer.function import FunctionAdapter # NOQA
Expand Down
270 changes: 270 additions & 0 deletions chainer/distribution.py
@@ -0,0 +1,270 @@
import copy


class Distribution(object):

"""Interface of Distribution
`Distribution` is a bass class to treat probability distributions.
When initialization, it takes parameter as input.
"""

def _copy_to(self, target):
target.__dict__ = copy.copy(self.__dict__)
return target

@property
def batch_shape(self):
"""Returns the shape of a sample.
Returns:
~chainer.Variable: Output variable representing the shape of a
sample.
"""
raise NotImplementedError

def cdf(self, x):
"""Returns Cumulative Distribution Function for a input variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing Cumulative
Distribution Function.
"""
raise NotImplementedError

@property
def covariance(self):
"""Returns covariance.
Returns:
~chainer.Variable: Output variable representing covariance.
"""
raise NotImplementedError

@property
def entropy(self):
"""Returns entropy.
Returns:
~chainer.Variable: Output variable representing entropy.
"""
raise NotImplementedError

@property
def enumerate_support(self):
"""Returns support values of discrete distribution.
Returns:
~chainer.Variable: Output variable containing candidates.
"""
raise NotImplementedError

@property
def event_shape(self):
"""Returns the shape of an event.
Returns:
~chainer.Variable: Output variable representing the shape of an
event.
"""
raise NotImplementedError

def icdf(self, x):
"""Returns Inverse Cumulative Distribution Function for a input Variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing Inverse Cumulative
Distribution Function.
"""
raise NotImplementedError

def log_cdf(self, x):
"""Returns logarithm of Cumulative Distribution Function for a input Variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing logarithm of
Cumulative Distribution Function.
"""
raise NotImplementedError

def log_prob(self, x):
"""Returns logarithm of probability for a input variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing logarithm of
probability.
"""
raise NotImplementedError

def log_survival_function(self, x):
"""Returns logarithm of survival function for a input Variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing logarithm of
survival function for a input variable.
"""
raise NotImplementedError

@property
def mean(self):
"""Returns mean value.
Returns:
~chainer.Variable: Output variable representing mean value.
"""
raise NotImplementedError

@property
def mode(self):
"""Returns mode.
Returns:
~chainer.Variable: Output variable representing mode.
"""
raise NotImplementedError

def perplexity(self, x):
"""Returns perplexity function for a input variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing perplexity function
for a input variable.
"""
raise NotImplementedError

def prob(self, x):
"""Returns probability for a input variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing probability.
"""
raise NotImplementedError

def sample(self, shape=()):
"""Samples from this distribution.
Args:
shape(:class:`tuple` of :class:`int`): Sampling shape.
Returns:
~chainer.Variable: Output variable representing sampled random
variable.
"""
final_shape = self.batch_shape + self.event_shape
if shape == ():
n = 1
elif isinstance(shape, int):
n = shape
final_shape = (n,) + final_shape
else:
n = 1
for shape_ in shape:
n *= shape_
final_shape = shape + final_shape
samples = self._sample_n(n)
return samples.reshape(final_shape)

def _sample_n(self, n):
"""Samples from this distribution.
Args:
n(`int`): Sampling size.
Returns:
~chainer.Variable: Output variable representing sampled random
variable.
"""
raise NotImplementedError

@property
def stddev(self):
"""Returns standard deviation.
Returns:
~chainer.Variable: Output variable representing standard deviation.
"""
raise NotImplementedError

@property
def support(self):
"""Returns support.
Returns:
string: Output string that means support of this distribution.
"""
raise NotImplementedError

def survival_function(self, x):
"""Returns survival function for a input variable.
Args:
x(:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
:class:`cupy.ndarray`): Input variable representing a random
variable.
Returns:
~chainer.Variable: Output variable representing survival function
for a input variable.
"""
raise NotImplementedError

@property
def variance(self):
"""Returns variance.
Returns:
~chainer.Variable: Output variable representing variance.
"""
raise NotImplementedError

0 comments on commit ee5355f

Please sign in to comment.