Skip to content

Commit

Permalink
Merge pull request #5240 from gwtnb/normal-ndtr
Browse files Browse the repository at this point in the history
Use ndtr and log_ndtr in normal distribution
  • Loading branch information
toslunar committed Aug 20, 2018
2 parents 91aa32f + d2f51f3 commit 92a288c
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions chainer/distributions/normal.py
Expand Up @@ -7,9 +7,10 @@
from chainer import distribution
from chainer.functions.array import expand_dims
from chainer.functions.array import repeat
from chainer.functions.math import erfc
from chainer.functions.math import erfcinv
from chainer.functions.math import exponential
from chainer.functions.math import log_ndtr
from chainer.functions.math import ndtr
from chainer.utils import argument


Expand Down Expand Up @@ -75,7 +76,7 @@ def batch_shape(self):
return self.loc.shape

def cdf(self, x):
return 0.5 * erfc.erfc((self.loc - x) / (2 ** 0.5 * self.scale))
return ndtr.ndtr((x - self.loc) / self.scale)

@property
def entropy(self):
Expand All @@ -95,14 +96,14 @@ def _is_gpu(self):
return isinstance(self.loc.data, cuda.ndarray)

def log_cdf(self, x):
return exponential.log(self.cdf(x))
return log_ndtr.log_ndtr((x - self.loc) / self.scale)

def log_prob(self, x):
return LOGPROBC - self.log_scale \
- 0.5 * (x - self.loc) ** 2 / self.scale ** 2

def log_survival_function(self, x):
return exponential.log(self.survival_function(x))
return log_ndtr.log_ndtr((self.loc - x) / self.scale)

@property
def mean(self):
Expand Down Expand Up @@ -135,7 +136,7 @@ def support(self):
return 'real'

def survival_function(self, x):
return 0.5 * erfc.erfc((x - self.loc) / (2 ** 0.5 * self.scale))
return ndtr.ndtr((self.loc - x) / self.scale)

@property
def variance(self):
Expand Down

0 comments on commit 92a288c

Please sign in to comment.