Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ndtr and log_ndtr in normal distribution #5240

Merged
merged 1 commit into from Aug 20, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 6 additions & 5 deletions chainer/distributions/normal.py
Expand Up @@ -7,9 +7,10 @@
from chainer import distribution
from chainer.functions.array import expand_dims
from chainer.functions.array import repeat
from chainer.functions.math import erfc
from chainer.functions.math import erfcinv
from chainer.functions.math import exponential
from chainer.functions.math import log_ndtr
from chainer.functions.math import ndtr
from chainer.utils import argument


Expand Down Expand Up @@ -75,7 +76,7 @@ def batch_shape(self):
return self.loc.shape

def cdf(self, x):
return 0.5 * erfc.erfc((self.loc - x) / (2 ** 0.5 * self.scale))
return ndtr.ndtr((x - self.loc) / self.scale)

@property
def entropy(self):
Expand All @@ -95,14 +96,14 @@ def _is_gpu(self):
return isinstance(self.loc.data, cuda.ndarray)

def log_cdf(self, x):
return exponential.log(self.cdf(x))
return log_ndtr.log_ndtr((x - self.loc) / self.scale)

def log_prob(self, x):
return LOGPROBC - self.log_scale \
- 0.5 * (x - self.loc) ** 2 / self.scale ** 2

def log_survival_function(self, x):
return exponential.log(self.survival_function(x))
return log_ndtr.log_ndtr((self.loc - x) / self.scale)

@property
def mean(self):
Expand Down Expand Up @@ -135,7 +136,7 @@ def support(self):
return 'real'

def survival_function(self, x):
return 0.5 * erfc.erfc((x - self.loc) / (2 ** 0.5 * self.scale))
return ndtr.ndtr((self.loc - x) / self.scale)

@property
def variance(self):
Expand Down